Skip to content
This repository has been archived by the owner on Nov 9, 2023. It is now read-only.

Commit

Permalink
Trainer: added option for all models
Browse files Browse the repository at this point in the history
Enable autobackup? (y/n ?:help skip:%s) :
Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01

SAE: added option only for CUDA builds:
Enable gradient clipping? (y/n, ?:help skip:%s) :
Gradient clipping reduces chance of model collapse, sacrificing speed of training.
  • Loading branch information
iperov committed Jun 20, 2019
1 parent ea1d59f commit 8484060
Show file tree
Hide file tree
Showing 14 changed files with 211 additions and 81 deletions.
Binary file modified doc/manual_en_google_translated.docx
Binary file not shown.
Binary file modified doc/manual_en_google_translated.pdf
Binary file not shown.
Binary file modified doc/manual_ru.pdf
Binary file not shown.
Binary file modified doc/manual_ru_source.docx
Binary file not shown.
171 changes: 121 additions & 50 deletions models/ModelBase.py
@@ -1,26 +1,30 @@
import os
import json
import time
import colorsys
import inspect
import json
import os
import pickle
import colorsys
import imagelib
import shutil
import time
from pathlib import Path
from utils import Path_utils
from utils import std_utils
from utils.cv2_utils import *
import numpy as np

import cv2
from samplelib import SampleGeneratorBase
from nnlib import nnlib
import numpy as np

import imagelib
from interact import interact as io
from nnlib import nnlib
from samplelib import SampleGeneratorBase
from utils import Path_utils, std_utils
from utils.cv2_utils import *

'''
You can implement your own model. Check examples.
'''
class ModelBase(object):


def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
ask_enable_autobackup=True,
ask_write_preview_history=True,
ask_target_iter=True,
ask_batch_size=True,
Expand Down Expand Up @@ -84,6 +88,12 @@ def __init__(self, model_path, training_data_src_path=None, training_data_dst_pa
if self.iter == 0:
io.log_info ("\nModel first run. Enter model options as default for each run.")

if ask_enable_autobackup and (self.iter == 0 or ask_override):
default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False)
self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01")
else:
self.options['autobackup'] = self.options.get('autobackup', False)

if ask_write_preview_history and (self.iter == 0 or ask_override):
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
Expand Down Expand Up @@ -127,6 +137,10 @@ def __init__(self, model_path, training_data_src_path=None, training_data_dst_pa
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)

self.autobackup = self.options.get('autobackup', False)
if not self.autobackup and 'autobackup' in self.options:
self.options.pop('autobackup')

self.write_preview_history = self.options.get('write_preview_history', False)
if not self.write_preview_history and 'write_preview_history' in self.options:
Expand Down Expand Up @@ -160,8 +174,16 @@ def __init__(self, model_path, training_data_src_path=None, training_data_dst_pa
if self.is_training_mode:
if self.device_args['force_gpu_idx'] == -1:
self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) )
else:
self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) )

if self.autobackup:
self.autobackup_current_hour = time.localtime().tm_hour

if not self.autobackups_path.exists():
self.autobackups_path.mkdir(exist_ok=True)

if self.write_preview_history or io.is_colab():
if not self.preview_history_path.exists():
Expand Down Expand Up @@ -205,8 +227,8 @@ def __init__(self, model_path, training_data_src_path=None, training_data_dst_pa

io.destroy_window(wnd_name)
else:
self.sample_for_preview = self.generate_next_sample()

self.sample_for_preview = self.generate_next_sample()
self.last_sample = self.sample_for_preview
model_summary_text = []

model_summary_text += ["===== Model summary ====="]
Expand Down Expand Up @@ -277,6 +299,10 @@ def onGetPreview(self, sample):
def get_model_name(self):
return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]

#overridable , return [ [model, filename],... ] list
def get_model_filename_list(self):
return []

#overridable
def get_converter(self):
raise NotImplementedError
Expand Down Expand Up @@ -314,7 +340,8 @@ def get_static_preview(self):
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr

def save(self):
Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
summary_path = self.get_strpath_storage_for_file('summary.txt')
Path( summary_path ).write_text(self.model_summary_text)
self.onSave()

model_data = {
Expand All @@ -325,6 +352,44 @@ def save(self):
}
self.model_data_path.write_bytes( pickle.dumps(model_data) )

bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]

if self.autobackup:
current_hour = time.localtime().tm_hour
if self.autobackup_current_hour != current_hour:
self.autobackup_current_hour = current_hour

for i in range(15,0,-1):
idx_str = '%.2d' % i
next_idx_str = '%.2d' % (i+1)

idx_backup_path = self.autobackups_path / idx_str
next_idx_packup_path = self.autobackups_path / next_idx_str

if idx_backup_path.exists():
if i == 15:
Path_utils.delete_all_files(idx_backup_path)
else:
next_idx_packup_path.mkdir(exist_ok=True)
Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)

if i == 1:
idx_backup_path.mkdir(exist_ok=True)
for filename in bckp_filename_list:
shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )

previews = self.get_previews()
plist = []
for i in range(len(previews)):
name, bgr = previews[i]
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]

for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )

def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
for model, filename in model_filename_list:
filename = self.get_strpath_storage_for_file(filename)
Expand All @@ -349,12 +414,16 @@ def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
print ("Unable to load ", opt_filename)


def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
def save_weights_safe(self, model_filename_list):
for model, filename in model_filename_list:
filename = self.get_strpath_storage_for_file(filename)
model.save_weights( filename + '.tmp' )

rename_list = model_filename_list

"""
#unused
, optimizer_filename_list=[]
if len(optimizer_filename_list) != 0:
opt_filename = self.get_strpath_storage_for_file('opt.h5')
Expand All @@ -374,7 +443,8 @@ def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
rename_list += [('', 'opt.h5')]
except Exception as e:
print ("Unable to save ", opt_filename)

"""

for _, filename in rename_list:
filename = self.get_strpath_storage_for_file(filename)
source_filename = Path(filename+'.tmp')
Expand All @@ -383,8 +453,7 @@ def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
if target_filename.exists():
target_filename.unlink()
source_filename.rename ( str(target_filename) )



def debug_one_iter(self):
images = []
for generator in self.generator_list:
Expand Down Expand Up @@ -490,45 +559,47 @@ def get_loss_history_preview(loss_history, iter, w, c):

lh_height = 100
lh_img = np.ones ( (lh_height,w,c) ) * 0.1
loss_count = len(loss_history[0])
lh_len = len(loss_history)

l_per_col = lh_len / w
plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
*[ loss_history[i_ab][p]
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
]
)
for p in range(loss_count)

if len(loss_history) != 0:
loss_count = len(loss_history[0])
lh_len = len(loss_history)

l_per_col = lh_len / w
plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
*[ loss_history[i_ab][p]
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
]
)
for p in range(loss_count)
]
for col in range(w)
]
for col in range(w)
]

plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
*[ loss_history[i_ab][p]
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
]
)
for p in range(loss_count)

plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
*[ loss_history[i_ab][p]
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
]
)
for p in range(loss_count)
]
for col in range(w)
]
for col in range(w)
]

plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2

for col in range(0, w):
for p in range(0,loss_count):
point_color = [1.0]*c
point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
for col in range(0, w):
for p in range(0,loss_count):
point_color = [1.0]*c
point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )

ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
ph_max = np.clip( ph_max, 0, lh_height-1 )
ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
ph_max = np.clip( ph_max, 0, lh_height-1 )

ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
ph_min = np.clip( ph_min, 0, lh_height-1 )
ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
ph_min = np.clip( ph_min, 0, lh_height-1 )

for ph in range(ph_min, ph_max+1):
lh_img[ (lh_height-ph-1), col ] = point_color
for ph in range(ph_min, ph_max+1):
lh_img[ (lh_height-ph-1), col ] = point_color

lh_lines = 5
lh_line_height = (lh_height-1)/lh_lines
Expand Down
1 change: 1 addition & 0 deletions models/Model_DEV_FANSEG/Model.py
Expand Up @@ -11,6 +11,7 @@ class Model(ModelBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs,
ask_enable_autobackup=False,
ask_write_preview_history=False,
ask_target_iter=False,
ask_sort_by_yaw=False,
Expand Down
1 change: 1 addition & 0 deletions models/Model_DEV_POSEEST/Model.py
Expand Up @@ -12,6 +12,7 @@ class Model(ModelBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs,
ask_enable_autobackup=False,
ask_write_preview_history=False,
ask_target_iter=False,
ask_sort_by_yaw=False,
Expand Down
11 changes: 8 additions & 3 deletions models/Model_DF/Model.py
Expand Up @@ -59,11 +59,16 @@ def onInitialize(self):
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types=output_sample_types)
])

#override
def get_model_filename_list(self):
return [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']]

#override
def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']] )
self.save_weights_safe( self.get_model_filename_list() )

#override
def onTrainOneIter(self, sample, generators_list):
Expand Down
10 changes: 7 additions & 3 deletions models/Model_H128/Model.py
Expand Up @@ -70,11 +70,15 @@ def onInitialize(self):
output_sample_types=output_sample_types )
])

#override
def get_model_filename_list(self):
return [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']]

#override
def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']] )
self.save_weights_safe( self.get_model_filename_list() )

#override
def onTrainOneIter(self, sample, generators_list):
Expand Down
10 changes: 7 additions & 3 deletions models/Model_H64/Model.py
Expand Up @@ -71,11 +71,15 @@ def onInitialize(self):
output_sample_types=output_sample_types)
])

#override
def get_model_filename_list(self):
return [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']]

#override
def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder_src, 'decoder_src.h5'],
[self.decoder_dst, 'decoder_dst.h5']] )
self.save_weights_safe( self.get_model_filename_list() )

#override
def onTrainOneIter(self, sample, generators_list):
Expand Down
12 changes: 8 additions & 4 deletions models/Model_LIAEF128/Model.py
Expand Up @@ -65,12 +65,16 @@ def onInitialize(self):
output_sample_types=output_sample_types)
])

#override
def get_model_filename_list(self):
return [[self.encoder, 'encoder.h5'],
[self.decoder, 'decoder.h5'],
[self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5']]

#override
def onSave(self):
self.save_weights_safe( [[self.encoder, 'encoder.h5'],
[self.decoder, 'decoder.h5'],
[self.inter_B, 'inter_B.h5'],
[self.inter_AB, 'inter_AB.h5']] )
self.save_weights_safe( self.get_model_filename_list() )

#override
def onTrainOneIter(self, sample, generators_list):
Expand Down
16 changes: 10 additions & 6 deletions models/Model_RecycleGAN/Model.py
Expand Up @@ -201,14 +201,18 @@ def opt():
else:
self.G_convert = K.function([real_B0],[fake_A0])

#override
def get_model_filename_list(self):
return [ [self.GA, 'GA.h5'],
[self.GB, 'GB.h5'],
[self.DA, 'DA.h5'],
[self.DB, 'DB.h5'],
[self.PA, 'PA.h5'],
[self.PB, 'PB.h5'] ]

#override
def onSave(self):
self.save_weights_safe( [[self.GA, 'GA.h5'],
[self.GB, 'GB.h5'],
[self.DA, 'DA.h5'],
[self.DB, 'DB.h5'],
[self.PA, 'PA.h5'],
[self.PB, 'PB.h5'] ])
self.save_weights_safe( self.get_model_filename_list() )

#override
def onTrainOneIter(self, generators_samples, generators_list):
Expand Down

0 comments on commit 8484060

Please sign in to comment.