Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add image callback in sampling #140

Merged
merged 1 commit into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions discoart/nn/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import torch
import random


def set_seed(seed: int) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


def detach_gpu(val):
if isinstance(val, (int, float)):
return val
else:
return val.detach().cpu().item()
6 changes: 6 additions & 0 deletions discoart/nn/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import torch
import torchvision.transforms as T

inv_normalize = T.Normalize(
mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],
std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711],
)


def symmetry_transformation_fn(x, use_horizontal_symmetry, use_vertical_symmetry):
Expand Down
13 changes: 7 additions & 6 deletions discoart/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _sample(
is_save_gif,
is_image_output,
is_display_step,
image_callback,
):
with threading.Lock():
is_sampling_done.clear()
Expand All @@ -54,13 +55,13 @@ def _sample(
if is_save_step:
if is_image_output:
if cur_t == -1:
c.save_uri_to_file(
os.path.join(output_dir, f'{_nb}-done-{k}.png')
)
f_name = os.path.join(output_dir, f'{_nb}-done-{k}.png')
else:
c.save_uri_to_file(
os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png')
)
f_name = os.path.join(output_dir, f'{_nb}-step-{j}-{k}.png')
c.save_uri_to_file(f_name)

if callable(image_callback):
image_callback(f_name)

da[k].chunks.append(c)

Expand Down
48 changes: 15 additions & 33 deletions discoart/runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import copy
import os.path
import random
import tempfile
import threading
from typing import Callable, Optional

import clip
import lpips
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import wandb
from docarray import DocumentArray, Document
Expand All @@ -26,20 +24,18 @@
get_output_dir,
is_jupyter,
)
from .nn.helper import set_seed, detach_gpu
from .nn.losses import spherical_dist_loss, tv_loss, range_loss
from .nn.make_cutouts import MakeCutouts
from .nn.sec_diff import alpha_sigma_to_t
from .nn.transform import symmetry_transformation_fn
from .nn.transform import symmetry_transformation_fn, inv_normalize
from .persist import _sample_thread, _persist_thread, _save_progress_thread
from .prompt import PromptPlanner

inv_normalize = T.Normalize(
mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],
std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711],
)


def do_run(args, models, device, events) -> 'DocumentArray':
def do_run(
args, models, device, events, image_callback: Optional[Callable[[str], None]] = None
) -> 'DocumentArray':
skip_event, stop_event = events

_is_jupyter = is_jupyter()
Expand Down Expand Up @@ -114,7 +110,7 @@ def do_run(args, models, device, events) -> 'DocumentArray':

init = None

_set_seed(args.seed)
set_seed(args.seed)
if args.init_image:
d = Document(uri=args.init_image).load_uri_to_image_tensor(side_x, side_y)
init = (
Expand Down Expand Up @@ -293,12 +289,12 @@ def cond_fn(x, t, **kwargs):
) # min=-0.02, min=-clamp_max,

traced_info = {
'losses/total': _detach(loss) + cut_losses,
'losses/tv': _detach(tv_losses),
'losses/range': _detach(range_losses),
'losses/sat': _detach(sat_losses),
'losses/init': _detach(init_losses),
'losses/cuts': _detach(cut_losses),
'losses/total': detach_gpu(loss) + cut_losses,
'losses/tv': detach_gpu(tv_losses),
'losses/range': detach_gpu(range_losses),
'losses/sat': detach_gpu(sat_losses),
'losses/init': detach_gpu(init_losses),
'losses/cuts': detach_gpu(cut_losses),
}

traced_info.update(
Expand Down Expand Up @@ -351,7 +347,7 @@ def cond_fn(x, t, **kwargs):

# set seed for each image in the batch
new_seed = org_seed + _nb
_set_seed(new_seed)
set_seed(new_seed)
args.seed = new_seed
if _is_jupyter:
redraw_widget(
Expand Down Expand Up @@ -441,6 +437,7 @@ def cond_fn(x, t, **kwargs):
args.gif_fps > 0,
args.image_output,
is_display_step,
image_callback,
)
)

Expand Down Expand Up @@ -497,18 +494,3 @@ def redraw_widget(_handlers, _redraw_fn, args, _nb):

_handlers.code.value = export_python(args)
_redraw_fn()


def _set_seed(seed: int) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


def _detach(val):
if isinstance(val, (int, float)):
return val
else:
return val.detach().cpu().item()