Skip to content
This repository has been archived by the owner on Nov 9, 2023. It is now read-only.

Commit

Permalink
fixes and optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
iperov committed Jan 7, 2020
1 parent 842a489 commit d3e6b43
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 57 deletions.
16 changes: 11 additions & 5 deletions models/Model_AVATAR/Model.py
Expand Up @@ -58,13 +58,12 @@ def onInitialize(self, batch_size=-1, **in_options):
self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))

if self.is_first_run():
conv_weights_list = []
self.CA_conv_weights_list = []
if self.is_first_run():
for model, _ in self.get_model_filename_list():
for layer in model.layers:
if type(layer) == keras.layers.Conv2D:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
CAInitializerMP ( conv_weights_list )
self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights

if not self.is_first_run():
self.load_weights_safe( self.get_model_filename_list() )
Expand Down Expand Up @@ -247,7 +246,14 @@ def get_model_filename_list(self):
    #override
    def onSave(self):
        """Persist the model.

        Delegates to the base class's safe-save with the full
        (model, filename) list from get_model_filename_list(), so every
        sub-network (encoder/decoder/D/C, etc.) is written out.
        NOTE(review): save_weights_safe is defined outside this view —
        presumably it writes each model's weights to its paired filename;
        confirm against the base ModelBase implementation.
        """
        self.save_weights_safe( self.get_model_filename_list() )


    #override
    def on_success_train_one_iter(self):
        """One-shot deferred Conv2D kernel initialization.

        On a first run, onInitialize collects every Conv2D kernel weight
        tensor into self.CA_conv_weights_list instead of initializing it
        immediately (see the is_first_run branch of onInitialize).  After
        the first successful training iteration, this hook runs the CA
        initializer over the collected kernels once, then empties the
        list so subsequent iterations skip this block entirely.
        """
        if len(self.CA_conv_weights_list) != 0:
            # Project convention: nnlib.import_all() presumably returns
            # executable source whose exec() injects backend names
            # (including CAInitializerMP) into scope — TODO confirm
            # against nnlib; this is not standard import practice.
            exec(nnlib.import_all(), locals(), globals())
            # CA = "convolution-aware" initializer, judging by the name
            # and the MP suffix (multiprocess) — verify in nnlib.
            CAInitializerMP ( self.CA_conv_weights_list )
            # Clear so this initialization happens exactly once.
            self.CA_conv_weights_list = []

#override
def onTrainOneIter(self, generators_samples, generators_list):
warped_src64, src64, src64m = generators_samples[0]
Expand Down
42 changes: 26 additions & 16 deletions samplelib/SampleGeneratorFace.py
@@ -1,9 +1,9 @@
import multiprocessing
import traceback

import pickle
import cv2
import numpy as np

import time
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
Expand All @@ -23,6 +23,7 @@ def __init__ (self, samples_path, debug=False, batch_size=1,
sample_process_options=SampleProcessor.Options(),
output_sample_types=[],
add_sample_idx=False,
generators_count=4,
**kwargs):

super().__init__(samples_path, debug, batch_size)
Expand All @@ -33,27 +34,30 @@ def __init__ (self, samples_path, debug=False, batch_size=1,
if self.debug:
self.generators_count = 1
else:
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count)

samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
self.samples_len = len(samples_clis[0])
samples = SampleHost.load (SampleType.FACE, self.samples_path)
self.samples_len = len(samples)

if self.samples_len == 0:
raise ValueError('No training data provided.')

index_host = mp_utils.IndexHost(self.samples_len)

if random_ct_samples_path is not None:
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples) )
else:
ct_samples_clis = None
ct_samples = None
ct_index_host = None

pickled_samples = pickle.dumps(samples, 4)
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None

if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]

self.generator_counter = -1

Expand All @@ -70,21 +74,26 @@ def __next__(self):
return next(generator)

def batch_func(self, param ):
samples, index_host, ct_samples, ct_index_host = param
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param

samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None

bs = self.batch_size
while True:
batches = None

indexes = index_host.multi_get(bs)
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None

batch_samples = samples.multi_get (indexes)
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None

t = time.time()
for n_batch in range(bs):
sample_idx = indexes[n_batch]
sample = batch_samples[n_batch]
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
sample = samples[sample_idx]

ct_sample = None
if ct_samples is not None:
ct_sample = ct_samples[ct_indexes[n_batch]]

try:
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
Expand All @@ -102,4 +111,5 @@ def batch_func(self, param ):

if self.add_sample_idx:
batches[i_sample_idx].append (sample_idx)

yield [ np.array(batch) for batch in batches]
37 changes: 21 additions & 16 deletions samplelib/SampleGeneratorFaceTemporal.py
@@ -1,10 +1,13 @@
import pickle
import traceback
import numpy as np

import cv2
import numpy as np

from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils

from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase

'''
output_sample_types = [
Expand All @@ -24,14 +27,18 @@ def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sampl
self.generators_count = 1
else:
self.generators_count = generators_count

samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)

samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')

pickled_samples = pickle.dumps(samples, 4)
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ]

self.generator_counter = -1

def __iter__(self):
Expand All @@ -43,22 +50,20 @@ def __next__(self):
return next(generator)

def batch_func(self, param):
generator_id, samples = param

generator_id, pickled_samples = param
samples = pickle.loads(pickled_samples)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')


mult_max = 1
l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )

samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count]
samples_idxs = [ *range(l+1) ]

if len(samples_idxs) - self.temporal_image_count < 0:
raise ValueError('Not enough samples to fit temporal line.')

shuffle_idxs = []

while True:
batches = None
for n_batch in range(self.batch_size):
Expand Down
29 changes: 11 additions & 18 deletions samplelib/SampleHost.py
Expand Up @@ -2,7 +2,7 @@
import operator
import traceback
from pathlib import Path

import pickle
import samplelib.PackedFaceset
from DFLIMG import *
from facelib import FaceType, LandmarksProcessor
Expand Down Expand Up @@ -35,7 +35,7 @@ def get_person_id_max_count(samples_path):
return len(list(persons_name_idxs.keys()))

@staticmethod
def host(sample_type, samples_path, number_of_clis):
def load(sample_type, samples_path):
samples_cache = SampleHost.samples_cache

if str(samples_path) not in samples_cache.keys():
Expand All @@ -46,10 +46,8 @@ def host(sample_type, samples_path, number_of_clis):
if sample_type == SampleType.IMAGE:
if samples[sample_type] is None:
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE or \
sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = None


elif sample_type == SampleType.FACE:
if samples[sample_type] is None:
try:
result = samplelib.PackedFaceset.load(samples_path)
Expand All @@ -61,18 +59,13 @@ def host(sample_type, samples_path, number_of_clis):

if result is None:
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )

if sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.upgradeToFaceTemporalSortedSamples(result)

samples[sample_type] = mp_utils.ListHost(result)

list_host = samples[sample_type]

clis = [ list_host.create_cli() for _ in range(number_of_clis) ]

return clis

samples[sample_type] = result

elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.load (SampleType.FACE, samples_path)
result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
samples[sample_type] = result

return samples[sample_type]

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion utils/iter_utils.py
Expand Up @@ -22,7 +22,7 @@ def __next__(self):
return next(self.generator_func)

class SubprocessGenerator(object):
def __init__(self, generator_func, user_param=None, prefetch=3, start_now=False):
def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False):
super().__init__()
self.prefetch = prefetch
self.generator_func = generator_func
Expand Down
2 changes: 1 addition & 1 deletion utils/mp_utils.py
Expand Up @@ -125,7 +125,7 @@ def host_thread(self, indexes_count):
result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result)

time.sleep(0.005)
time.sleep(0.001)

def create_cli(self):
cq = multiprocessing.Queue()
Expand Down

0 comments on commit d3e6b43

Please sign in to comment.