From 813dd897ec41c4d3d518e170c52415571d6adacc Mon Sep 17 00:00:00 2001 From: moonbings Date: Fri, 11 Nov 2022 00:31:55 +0900 Subject: [PATCH 1/3] Update seed feature --- synthtiger/__init__.py | 12 ++++++++++- synthtiger/gen.py | 36 ++++++++++++++++++++++---------- synthtiger/main.py | 1 + synthtiger/templates/template.py | 2 ++ 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/synthtiger/__init__.py b/synthtiger/__init__.py index c71fc86..2a382e4 100644 --- a/synthtiger/__init__.py +++ b/synthtiger/__init__.py @@ -6,7 +6,14 @@ from synthtiger import components, layers, templates, utils from synthtiger._version import __version__ -from synthtiger.gen import generator, read_config, read_template +from synthtiger.gen import ( + generator, + get_global_random_states, + read_config, + read_template, + set_global_random_seed, + set_global_random_states, +) __all__ = [ "components", @@ -14,6 +21,9 @@ "templates", "utils", "generator", + "get_global_random_states", "read_config", "read_template", + "set_global_random_seed", + "set_global_random_states", ] diff --git a/synthtiger/gen.py b/synthtiger/gen.py index 65baa37..1f7660e 100644 --- a/synthtiger/gen.py +++ b/synthtiger/gen.py @@ -62,6 +62,27 @@ def generator(path, name, config=None, count=None, worker=0, seed=None, verbose= yield task_idx, data +def get_global_random_states(): + states = { + "random": random.getstate(), + "numpy": np.random.get_state(), + "imgaug": imgaug.random.get_global_rng().state, + } + return states + + +def set_global_random_states(states): + random.setstate(states["random"]) + np.random.set_state(states["numpy"]) + imgaug.random.get_global_rng().state = states["imgaug"] + + +def set_global_random_seed(seed): + random.seed(seed) + np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state()) + imgaug.seed(seed) + + def _run(func, args): proc = Process(target=func, args=args) proc.daemon = True @@ -89,13 +110,9 @@ def _worker(path, name, config, task_queue, data_queue, verbose): def _generate(template, seed, verbose): - temp_state = random.getstate() - temp_np_state = np.random.get_state() - temp_imgaug_state = imgaug.random.get_global_rng().state - - random.seed(seed) - np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state()) - imgaug.seed(seed) + states = get_global_random_states() + set_global_random_seed(seed) + template.seed = seed while True: try: @@ -106,8 +123,5 @@ def _generate(template, seed, verbose): continue break - random.setstate(temp_state) - np.random.set_state(temp_np_state) - imgaug.random.get_global_rng().state = temp_imgaug_state - + set_global_random_states(states) return data diff --git a/synthtiger/main.py b/synthtiger/main.py index d752d72..d1620ce 100644 --- a/synthtiger/main.py +++ b/synthtiger/main.py @@ -17,6 +17,7 @@ def run(args): pprint.pprint(config) + synthtiger.set_global_random_seed(args.seed) template = synthtiger.read_template(args.script, args.name, config) generator = synthtiger.generator( args.script, diff --git a/synthtiger/templates/template.py b/synthtiger/templates/template.py index 276773a..2983dff 100644 --- a/synthtiger/templates/template.py +++ b/synthtiger/templates/template.py @@ -8,6 +8,8 @@ class Template(ABC): + seed = None + def __init__(self, config=None): pass From bc688d7587b8d4242d395fd9a7c437896c704ed1 Mon Sep 17 00:00:00 2001 From: moonbings Date: Fri, 11 Nov 2022 01:13:24 +0900 Subject: [PATCH 2/3] Fix minor --- synthtiger/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synthtiger/main.py b/synthtiger/main.py index d1620ce..ade65d8 100644 --- a/synthtiger/main.py +++ b/synthtiger/main.py @@ -17,7 +17,6 @@ def run(args): pprint.pprint(config) - synthtiger.set_global_random_seed(args.seed) template = synthtiger.read_template(args.script, args.name, config) generator = synthtiger.generator( args.script, @@ -29,6 +28,9 @@ def run(args): verbose=args.verbose, ) + synthtiger.set_global_random_seed(args.seed) + template.seed = args.seed + if args.output is not None: template.init_save(args.output) From e9fe40b23fdd24c81627244a353f843f82dac575 Mon Sep 17 00:00:00 2001 From: moonbings Date: Fri, 11 Nov 2022 01:39:28 +0900 Subject: [PATCH 3/3] Fix minor --- synthtiger/gen.py | 1 - synthtiger/main.py | 4 +--- synthtiger/templates/template.py | 2 -- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/synthtiger/gen.py b/synthtiger/gen.py index 1f7660e..aa849e8 100644 --- a/synthtiger/gen.py +++ b/synthtiger/gen.py @@ -112,7 +112,6 @@ def _worker(path, name, config, task_queue, data_queue, verbose): def _generate(template, seed, verbose): states = get_global_random_states() set_global_random_seed(seed) - template.seed = seed while True: try: diff --git a/synthtiger/main.py b/synthtiger/main.py index ade65d8..d1620ce 100644 --- a/synthtiger/main.py +++ b/synthtiger/main.py @@ -17,6 +17,7 @@ def run(args): pprint.pprint(config) + synthtiger.set_global_random_seed(args.seed) template = synthtiger.read_template(args.script, args.name, config) generator = synthtiger.generator( args.script, @@ -28,9 +29,6 @@ def run(args): verbose=args.verbose, ) - synthtiger.set_global_random_seed(args.seed) - template.seed = args.seed - if args.output is not None: template.init_save(args.output) diff --git a/synthtiger/templates/template.py b/synthtiger/templates/template.py index 2983dff..276773a 100644 --- a/synthtiger/templates/template.py +++ b/synthtiger/templates/template.py @@ -8,8 +8,6 @@ class Template(ABC): - seed = None - def __init__(self, config=None): pass