From d0e6e1891a36130b8b0186f7ef374b4b47da68e4 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Fri, 5 Aug 2022 23:31:50 +0200 Subject: [PATCH] feat: add image callback in sampling (#140) --- discoart/nn/helper.py | 18 +++++++++++++++ discoart/nn/transform.py | 6 +++++ discoart/persist.py | 13 ++++++----- discoart/runner.py | 48 +++++++++++++--------------------------- 4 files changed, 46 insertions(+), 39 deletions(-) create mode 100644 discoart/nn/helper.py diff --git a/discoart/nn/helper.py b/discoart/nn/helper.py new file mode 100644 index 0000000..365cdca --- /dev/null +++ b/discoart/nn/helper.py @@ -0,0 +1,18 @@ +import numpy as np +import torch +import random + + +def set_seed(seed: int) -> None: + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +def detach_gpu(val): + if isinstance(val, (int, float)): + return val + else: + return val.detach().cpu().item() diff --git a/discoart/nn/transform.py b/discoart/nn/transform.py index 09536c5..09fd3bc 100644 --- a/discoart/nn/transform.py +++ b/discoart/nn/transform.py @@ -1,4 +1,10 @@ import torch +import torchvision.transforms as T + +inv_normalize = T.Normalize( + mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711], + std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711], +) def symmetry_transformation_fn(x, use_horizontal_symmetry, use_vertical_symmetry): diff --git a/discoart/persist.py b/discoart/persist.py index 5000a4b..3a49c01 100644 --- a/discoart/persist.py +++ b/discoart/persist.py @@ -32,6 +32,7 @@ def _sample( is_save_gif, is_image_output, is_display_step, + image_callback, ): with threading.Lock(): is_sampling_done.clear() @@ -54,13 +55,13 @@ def _sample( if is_save_step: if is_image_output: if cur_t == -1: - c.save_uri_to_file( - os.path.join(output_dir, f'{_nb}-done-{k}.png') - ) + f_name = os.path.join(output_dir, f'{_nb}-done-{k}.png') else: - c.save_uri_to_file( - os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png') - ) + f_name = os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png') + c.save_uri_to_file(f_name) + + if callable(image_callback): + image_callback(f_name) da[k].chunks.append(c) diff --git a/discoart/runner.py b/discoart/runner.py index 84756a2..0c71d2c 100644 --- a/discoart/runner.py +++ b/discoart/runner.py @@ -1,14 +1,12 @@ import copy import os.path -import random import tempfile import threading +from typing import Callable, Optional import clip import lpips -import numpy as np import torch -import torchvision.transforms as T import torchvision.transforms.functional as TF import wandb from docarray import DocumentArray, Document @@ -26,20 +24,18 @@ get_output_dir, is_jupyter, ) +from .nn.helper import set_seed, detach_gpu from .nn.losses import spherical_dist_loss, tv_loss, range_loss from .nn.make_cutouts import MakeCutouts from .nn.sec_diff import alpha_sigma_to_t -from .nn.transform import symmetry_transformation_fn +from .nn.transform import symmetry_transformation_fn, inv_normalize from .persist import _sample_thread, _persist_thread, _save_progress_thread from .prompt import PromptPlanner -inv_normalize = T.Normalize( - mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711], - std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711], -) - -def do_run(args, models, device, events) -> 'DocumentArray': +def do_run( + args, models, device, events, image_callback: Optional[Callable[[str], None]] = None +) -> 'DocumentArray': skip_event, stop_event = events _is_jupyter = is_jupyter() @@ -114,7 +110,7 @@ def do_run(args, models, device, events) -> 'DocumentArray': init = None - _set_seed(args.seed) + set_seed(args.seed) if args.init_image: d = Document(uri=args.init_image).load_uri_to_image_tensor(side_x, side_y) init = ( @@ -293,12 +289,12 @@ def cond_fn(x, t, **kwargs): ) # min=-0.02, min=-clamp_max, traced_info = { - 'losses/total': _detach(loss) + cut_losses, - 'losses/tv': _detach(tv_losses), - 'losses/range': _detach(range_losses), - 'losses/sat': _detach(sat_losses), - 'losses/init': _detach(init_losses), - 'losses/cuts': _detach(cut_losses), + 'losses/total': detach_gpu(loss) + cut_losses, + 'losses/tv': detach_gpu(tv_losses), + 'losses/range': detach_gpu(range_losses), + 'losses/sat': detach_gpu(sat_losses), + 'losses/init': detach_gpu(init_losses), + 'losses/cuts': detach_gpu(cut_losses), } traced_info.update( @@ -351,7 +347,7 @@ def cond_fn(x, t, **kwargs): # set seed for each image in the batch new_seed = org_seed + _nb - _set_seed(new_seed) + set_seed(new_seed) args.seed = new_seed if _is_jupyter: redraw_widget( @@ -441,6 +437,7 @@ def cond_fn(x, t, **kwargs): args.gif_fps > 0, args.image_output, is_display_step, + image_callback, ) ) @@ -497,18 +494,3 @@ def redraw_widget(_handlers, _redraw_fn, args, _nb): _handlers.code.value = export_python(args) _redraw_fn() - - -def _set_seed(seed: int) -> None: - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def _detach(val): - if isinstance(val, (int, float)): - return val - else: - return val.detach().cpu().item()