
fix DFLJPG

SAE: added "rare sample booster"
SAE: replaced the "pixel loss" option with a smooth transition from DSSIM to pixel loss over the first 15k epochs (enabled by default)
iperov committed Feb 9, 2019
1 parent f93b471 commit 4d37fd62cd21757ff6fda654f7b38d4d5fa641b6
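
The second SAE change replaces the all-or-nothing pixel_loss option with a per-epoch blend: a weight alpha ramps linearly from 0 to 1 over the first 15k epochs, so training starts on DSSIM (global face structure) and ends on plain pixel/MSE loss (fine detail). A minimal NumPy sketch of the schedule, assuming per-sample loss vectors are already computed (blended_loss is an illustrative helper, not code from this commit):

import numpy as np

def blended_loss(dssim_batch, pixel_batch, epoch, transition_epochs=15000):
    # alpha ramps linearly from 0 to 1 over the first `transition_epochs` epochs
    alpha = np.clip(epoch / transition_epochs, 0.0, 1.0)
    # epoch 0: pure DSSIM; epoch >= 15000: pure pixel (MSE) loss
    return dssim_batch * (1.0 - alpha) + pixel_batch * alpha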
@@ -199,7 +199,7 @@ def onSave(self):
         pass
 
     #overridable
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generator_list):
         #train your keras models here
 
         #return array of losses
@@ -293,7 +293,8 @@ def debug_one_epoch(self):
         images = []
         for generator in self.generator_list:
             for i,batch in enumerate(next(generator)):
-                images.append( batch[0] )
+                if len(batch.shape) == 4:
+                    images.append( batch[0] )
 
         return image_utils.equalize_and_stack_square (images)
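
With add_sample_idx enabled (see the SAE changes below), a generator can yield 1-D index arrays alongside its 4-D image batches, so the debug preview now keeps only batches with an image shape. A toy illustration of the filter:

import numpy as np

batches = [np.zeros((4, 64, 64, 3)), np.array([17, 3, 42, 8])]      # image batch + sample idxs
images = [batch[0] for batch in batches if len(batch.shape) == 4]   # drop the 1-D idx array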

@@ -305,14 +306,12 @@ def train_one_epoch(self):
             supressor = std_utils.suppress_stdout_stderr()
             supressor.__enter__()
 
-        self.last_sample = self.generate_next_sample()
-
-        epoch_time = time.time()
-
-        losses = self.onTrainOneEpoch(self.last_sample)
-
+        sample = self.generate_next_sample()
+        epoch_time = time.time()
+        losses = self.onTrainOneEpoch(sample, self.generator_list)
         epoch_time = time.time() - epoch_time
 
+        self.last_sample = sample
+
         self.loss_history.append ( [float(loss[1]) for loss in losses] )
 
         if self.supress_std_once:
@@ -55,7 +55,7 @@ def onSave(self):
                                  [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )
 
     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]
@@ -73,7 +73,7 @@ def onSave(self):
                                  [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]])
 
     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

@@ -75,7 +75,7 @@ def onSave(self):
                                  [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )
 
     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_full_mask = sample[0]
         warped_dst, target_dst, target_dst_full_mask = sample[1]

@@ -64,7 +64,7 @@ def onSave(self):
                                  [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)]] )
 
     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

@@ -52,13 +52,7 @@ def onInitializeOptions(self, is_first_run, ask_override):
             self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If style is learned good enough, set this value to 0.1-0.3 to prevent artifacts appearing."), 0.0, 100.0 )
         else:
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
 
-        if is_first_run or ask_override:
-            default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
-            self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 30-40k epochs to enhance fine details.")
-        else:
-            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
-
         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
         if is_first_run:
@@ -83,6 +77,7 @@ def onInitialize(self, **in_options):
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)
 
+        dssim_pixel_alpha = Input( (1,) )
         warped_src = Input(bgr_shape)
         target_src = Input(bgr_shape)
         target_srcm = Input(mask_shape)
@@ -199,6 +194,7 @@ def onInitialize(self, **in_options):
         def optimizer():
             return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
 
+        dssim_pixel_alpha_value = dssim_pixel_alpha[0][0]
 
         if self.options['archi'] == 'liae':
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
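
dssim_pixel_alpha is an ordinary Keras Input of shape (1,), and dssim_pixel_alpha[0][0] slices out the first element of the first sample as a scalar tensor for the loss blend. At call time it is fed as a (batch_size, 1) array; a standalone sketch equivalent to the feed built later in onTrainOneEpoch, with example epoch and batch size:

import numpy as np

epoch, batch_size = 6000, 4                              # example values
alpha = np.clip(epoch / 15000.0, 0.0, 1.0)               # 0.4 at epoch 6000
feed = np.expand_dims(np.repeat(alpha, batch_size), -1)  # shape (4, 1), matches Input((1,))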
@@ -208,29 +204,29 @@ def optimizer():
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
             if self.options['learn_mask']:
                 src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
 
-        if self.options['pixel_loss']:
-            src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
-        else:
-            src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar)) ])
+
+        src_loss_batch = src_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + src_pixel_loss_batch*dssim_pixel_alpha_value
+        src_loss = K.mean(src_loss_batch)
 
         if self.options['face_style_power'] != 0:
             face_style_power = self.options['face_style_power'] / 100.0
             src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
 
         if self.options['bg_style_power'] != 0:
             bg_style_power = self.options['bg_style_power'] / 100.0
-            if self.options['pixel_loss']:
-                src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
-            else:
-                src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
-
-        if self.options['pixel_loss']:
-            dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
-        else:
-            dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
-
-        self.src_dst_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+            bg_dssim_loss = K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+            bg_pixel_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
+            src_loss += bg_dssim_loss*(1.0-dssim_pixel_alpha_value) + bg_pixel_loss*dssim_pixel_alpha_value
+
+        dst_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+        dst_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar)) ])
+        dst_loss_batch = dst_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + dst_pixel_loss_batch*dssim_pixel_alpha_value
+        dst_loss = K.mean(dst_loss_batch)
 
+        self.src_dst_train = K.function ([dssim_pixel_alpha, warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss,src_loss_batch,dst_loss_batch], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
 
         if self.options['learn_mask']:
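
The key difference from the old scalar losses: tf_reduce_mean(..., axis=[1,2,3]) averages over height, width and channels but not over the batch axis, so src_loss_batch and dst_loss_batch hold one loss per sample. These per-sample values are blended by alpha and also returned from K.function, which is what feeds the rare sample booster. A NumPy sketch of the reduction, with toy shapes:

import numpy as np

sq_err = np.random.rand(4, 8, 8, 3)        # toy squared errors: 4 samples of 8x8x3
per_sample = sq_err.mean(axis=(1, 2, 3))   # shape (4,): one loss per sample
scalar = per_sample.mean()                 # the value the optimizer actually minimizes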
@@ -250,6 +246,9 @@ def optimizer():
             self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
 
         if self.is_training_mode:
+            self.src_sample_losses = []
+            self.dst_sample_losses = []
+
             f = SampleProcessor.TypeFlags
             face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
             self.set_training_data_generators ([
@@ -259,14 +258,14 @@ def optimizer():
                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                            ] ),
+                                            ], add_sample_idx=True ),
 
                 SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                            ] )
+                                            ], add_sample_idx=True )
                 ])
     #override
     def onSave(self):
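
With add_sample_idx=True, each generator's output tuple gains a trailing array of the dataset indices that produced the batch; onTrainOneEpoch (below) unpacks it as src_sample_idxs / dst_sample_idxs. An illustrative example of the assumed tuple shape:

import numpy as np

generator_output = (np.zeros((4, 128, 128, 3)),   # warped, batch of 4 at resolution 128
                    np.zeros((4, 128, 128, 3)),   # target
                    np.zeros((4, 128, 128, 1)),   # target mask
                    np.array([17, 3, 42, 8]))     # indices of the source samples
warped_src, target_src, target_src_mask, src_sample_idxs = generator_output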
@@ -289,13 +288,39 @@ def onSave(self):
 
         self.save_weights_safe(ar)
 
 
     #override
-    def onTrainOneEpoch(self, sample):
-        warped_src, target_src, target_src_mask = sample[0]
-        warped_dst, target_dst, target_dst_mask = sample[1]
+    def onTrainOneEpoch(self, generators_samples, generators_list):
+        warped_src, target_src, target_src_mask, src_sample_idxs = generators_samples[0]
+        warped_dst, target_dst, target_dst_mask, dst_sample_idxs = generators_samples[1]
+
+        dssim_pixel_alpha = np.clip ( self.epoch / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 15k epochs
+        dssim_pixel_alpha = np.repeat( dssim_pixel_alpha, (self.batch_size,) )
+        dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)
 
-        src_loss, dst_loss = self.src_dst_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
+        src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
+
+        #gather array of per-sample losses
+        self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
+        self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
+
+        if len(self.src_sample_losses) >= 48: #array is big enough
+            #fetch idxs whose losses are bigger than average
+            x = np.array (self.src_sample_losses)
+            self.src_sample_losses = []
+            b = x[:,1]
+            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+            generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
+
+        if len(self.dst_sample_losses) >= 48: #array is big enough
+            #fetch idxs whose losses are bigger than average
+            x = np.array (self.dst_sample_losses)
+            self.dst_sample_losses = []
+            b = x[:,1]
+            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+            generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
 
         if self.options['learn_mask']:
            src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
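
This is the "rare sample booster" from the commit message: (index, loss) pairs accumulate until 48 are buffered, then samples whose loss exceeds the buffer mean are queued for repetition via repeat_sample_idxs. A simplified sketch of the selection step (pick_hard_idxs is an illustrative helper, not code from the commit; it selects directly with a boolean mask rather than the argwhere-on-filtered-values chain above):

import numpy as np

def pick_hard_idxs(sample_losses):
    x = np.array(sample_losses)        # rows of [sample_idx, loss]
    idxs, losses = x[:, 0], x[:, 1]
    # keep indices of samples whose loss is above the buffer average
    return idxs[losses > losses.mean()].astype(np.uint32)

# usage sketch: generators_list[0].repeat_sample_idxs(pick_hard_idxs(buffered_pairs))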

@@ -453,6 +478,9 @@ def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
     strides = resolution // 32 if adapt_k_size else 2
     lowest_dense_res = resolution // 16
 
+    def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+        return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
+
     def downscale (dim):
         def func(x):
             return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
@@ -496,6 +524,10 @@ def DFDecFlow(output_nc, ed_ch_dims=21):
     exec (nnlib.import_all(), locals(), globals())
     ed_dims = output_nc * ed_ch_dims
 
+
+    def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+        return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
+
     def upscale (dim):
         def func(x):
             return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
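
Both DF flows now shadow keras.layers.Conv2D with a wrapper whose only change is the default kernel initializer, RandomNormal(0, 0.02) (the DCGAN-style init). A trimmed sketch of the same pattern, assuming Keras 2:

from keras.initializers import RandomNormal
import keras

def Conv2D(filters, kernel_size, **kwargs):
    # identical to keras.layers.Conv2D, but defaulting the kernel init to N(0, 0.02)
    kwargs.setdefault('kernel_initializer', RandomNormal(0, 0.02))
    return keras.layers.Conv2D(filters, kernel_size, **kwargs)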
@@ -68,7 +68,7 @@ def __init__ (self, force_gpu_idx = -1,
 
     @staticmethod
     def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
-        if not hasNVML and totalmemsize_gb <= 2:
+        if not hasNVML:
            return [0]
 
        result = []
@@ -52,6 +52,7 @@ class nnlib(object):
         tf = nnlib.tf
         tf_sess = nnlib.tf_sess
         tf_reduce_mean = tf.reduce_mean # todo tf 12+ = tf.math.reduce_mean
+        tf_total_variation = tf.image.total_variation
         tf_dssim = nnlib.tf_dssim
         tf_ssim = nnlib.tf_ssim
