Permalink
Browse files

SAE: remove rare sample booster. DSSIM->MSE transition now between 5-20k epochs.
  • Loading branch information...
iperov committed Feb 11, 2019
1 parent f8e6397 commit 470fb9287a1f84a06b513492315e72c7d8a48870
Showing with 24 additions and 21 deletions.
  1. +23 −21 models/Model_SAE/Model.py
  2. +1 −0 samples/SampleGeneratorFace.py
@@ -302,32 +302,34 @@ def onTrainOneEpoch(self, generators_samples, generators_list):
warped_src, target_src, target_src_mask, src_sample_idxs = generators_samples[0]
warped_dst, target_dst, target_dst_mask, dst_sample_idxs = generators_samples[1]

dssim_pixel_alpha = np.clip ( self.epoch / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 15k epochs
dssim_pixel_alpha = np.clip ( (self.epoch - 5000) / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 5-20k epochs
dssim_pixel_alpha = np.repeat( dssim_pixel_alpha, (self.batch_size,) )
dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)

src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])

#gathering array of sample_losses
self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]

if len(self.src_sample_losses) >= 48: #array is big enough
#fetching idxs which losses are bigger than average
x = np.array (self.src_sample_losses)
self.src_sample_losses = []
b = x[:,1]
idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs


if len(self.dst_sample_losses) >= 48: #array is big enough
#fetching idxs which losses are bigger than average
x = np.array (self.dst_sample_losses)
self.dst_sample_losses = []
b = x[:,1]
idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
# 'worst' sample booster gives no good result, or I don't know how to filter worst samples properly.
#
##gathering array of sample_losses
#self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
#self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
#
#if len(self.src_sample_losses) >= 48: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.src_sample_losses)
# self.src_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
# generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
#
#
#if len(self.dst_sample_losses) >= 48: #array is big enough
# #fetching idxs which losses are bigger than average
# x = np.array (self.dst_sample_losses)
# self.dst_sample_losses = []
# b = x[:,1]
# idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
# generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs

if self.options['learn_mask']:
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
@@ -54,6 +54,7 @@ def __next__(self):
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)

#forces to repeat these sample idxs as fast as possible
def repeat_sample_idxs(self, idxs): # [ idx, ... ]
#send idxs list to all sub generators.
for gen_sq in self.generators_sq:

0 comments on commit 470fb92

Please sign in to comment.