In [None]:
from argparse import Namespace
from PIL import Image
from pathlib import Path
from utils import common
import os

from IPython.display import HTML, Javascript, display
import ipywidgets as wi
import ipycanvas as ca
from PIL import Image
from ipyfilechooser import FileChooser
from nokogiri.spylus import spylus
import numpy as np


from options.test_options import TestOptions
from configs import data_configs
from datasets.images_dataset import ImagesDataset
from scripts.inference import run_on_batch

import torch
import torchvision
from models.psp import pSp

In [None]:
# Inference configuration: pick the GPU, load a pSp checkpoint, and merge the
# options stored in the checkpoint with test-time options.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"  # must be set before CUDA is first initialised
root = Path("/data/natsuki/danbooru2020/psp/encavgsim_1632393929")
# Alternative root: encavg_1631706221 (encoding run)
epoch = "iteration_495000.pt"
# epoch = "best_model.pt"
# torch.load unpickles the file: only load trusted checkpoints.
ckpt = torch.load(root / f"checkpoints/{epoch}", map_location='cpu')
opts = ckpt['opts']
# Pass an explicit argv list instead of splitting one big string on whitespace,
# so that paths containing spaces are not broken into several arguments.
test_opts = TestOptions().parser.parse_args([
    f"--exp_dir={root}",
    f"--checkpoint_path={root}/checkpoints/{epoch}",
    "--data_path=/data/natsuki/whitechest_sim_val",
    "--test_batch_size=1",
    "--test_workers=1",
    # Layers used for latent style mixing. "--latent_mask=14,15" was also tried;
    # going from 15 alone to 14,15 changes the colours drastically.
    "--latent_mask=10,11,12,13,14,15",
])
# Test-time options override the values stored in the checkpoint.
opts.update(vars(test_opts))
# Older checkpoints may predate these options; fill in defaults when missing.
opts.setdefault('learn_in_w', False)
opts.setdefault('output_size', 1024)
opts = Namespace(**opts)

In [None]:
# Build the pSp model from the merged options, switch it to inference mode and
# move it to the GPU. Both calls return the module itself, so the trailing
# expression still echoes the model as the cell output.
net = pSp(opts)
net.eval().cuda()

In [None]:
def check_run_on_dataset(i=0, dataset_type="whitechest_sim"):
    """Run the global ``net`` on one test sample of a configured dataset.

    A sanity check that the model produces output for a dataset image,
    independent of the drawing UI.

    Args:
        i: index of the sample in the dataset's test split.
        dataset_type: key into ``data_configs.DATASETS``; defaults to the
            dataset this notebook was written for.

    Returns:
        Whatever ``run_on_batch`` returns for the single-sample batch.
    """
    dataset_args = data_configs.DATASETS[dataset_type]
    # Constructing the transforms with opts=None (self.opts = None) is fine here.
    transforms_dict = dataset_args['transforms'](None).get_transforms()
    test_dataset = ImagesDataset(
        source_root=dataset_args['test_source_root'],
        target_root=dataset_args['test_target_root'],
        source_transform=transforms_dict['transform_source'],
        target_transform=transforms_dict['transform_gt_train'],
        opts=Namespace(label_nc=1),
    )
    # Presumably (source/condition image, target image); only the source is fed to the model.
    cond, _real = test_dataset[i]
    input_cuda = cond.unsqueeze(0).cuda()  # add a batch dimension of size 1
    # Inference only: avoid building an autograd graph, which wastes GPU memory.
    with torch.no_grad():
        return run_on_batch(input_cuda, net, opts)

In [None]:
# --- Drawing UI widgets ---
# spylus returns a DOM id prefix and a widget holding two drawing canvases; the
# JavaScript at the end of this cell looks the canvases up as "{ID}0" and "{ID}1".
ID, canvas = spylus.ID_multicanvas()
# Empty marker <div>; the JS elem('{ID}', ...) helper presumably locates the
# widgets below relative to it. NOTE(review): confirm in spylus.
marker = HTML(f'<div id="{ID}"></div>')
# Selects the active canvas (0 or 1), starting on canvas 1; bound to the JS Selector.
selector = wi.Dropdown(options=range(2), value=1)

# Controls for canvas 0. The JS binds them by positional DOM index (n-th
# input/button/select), so widget types and their layout order are load-bearing.
save_input0 = wi.Text(layout=wi.Layout(width="30px"))
copy_button0 = wi.Button(description="copy", layout=wi.Layout(width="130px"))
paste_button0 = wi.Button(description="paste", layout=wi.Layout(width="130px"))
# Wired to load0 in Python (not bound in the JS): re-loads the file chosen in fc0.
reload_button0 = wi.Button(description="reload", layout=wi.Layout(width="130px"))
color_input0 = wi.ColorPicker(value="#000000", layout=wi.Layout(width="130px"))
# Presumably the stroke width, 1-19 (default 2).
width_select0 = wi.Dropdown(options=range(1,20), value=2, layout=wi.Layout(width="130px"))

# Controls for canvas 1 (same positional-binding caveat). synth() reads the
# sketch text from save_input1, and on_mouse_down writes picked colours to color_input1.
save_input1 = wi.Text(layout=wi.Layout(width="30px"))
copy_button1 = wi.Button(description="copy", layout=wi.Layout(width="130px"))
paste_button1 = wi.Button(description="paste", layout=wi.Layout(width="130px"))
clear_button1 = wi.Button(description="clear", layout=wi.Layout(width="130px"))
color_input1 = wi.ColorPicker(value='#000000', layout=wi.Layout(width="130px"))
width_select1 = wi.Dropdown(options=range(1,20), value=2, layout=wi.Layout(width="130px"))

# Captures output/tracebacks from the callbacks and hosts the injected JavaScript.
out = wi.Output()
# File chooser for a sketch image to load into canvas 1 (app1).
fc0 = FileChooser("/data/natsuki/training116/00030-v4-mirror-auto4-gamma100-batch64-noaug-resumecustom/011289_sketch")
@out.capture(clear_output=True)
def load0(chooser):
    """Load the image currently selected in ``fc0`` into canvas 1 (``app1``).

    Used both as the FileChooser callback and as a button click handler, so the
    argument (a FileChooser or a Button) is ignored and ``fc0.selected`` is read.
    """
    if fc0.selected is None:
        # e.g. "reload" pressed before any file was chosen
        print("No file selected in the file chooser.")
        return
    # Close the underlying file handle once the image has been encoded.
    with Image.open(fc0.selected) as image:
        text = spylus.encode(image)
    display(Javascript(f'app1.load("{text}");'))
fc0.register_callback(load0)

# 512x512 output canvas. Its pixels are synced back to Python so that
# get_image_data() can be read for colour picking.
out_ca = ca.Canvas(width=512, height=512, sync_image_data=True)
# Placeholder: light-blue background with a status line on the bottom edge.
out_ca.fill_style = '#a9cafc'
out_ca.fill_rect(0, 0, out_ca.width, out_ca.height)
out_ca.fill_style = 'black'
out_ca.font = '32px serif'
out_ca.fill_text('Initialized', 0, out_ca.height)

synth_button = wi.Button(description="synthesize", layout=wi.Layout(width="100px"))
@out.capture(clear_output=True)
def synth(b, filter=Image.BICUBIC):
    """Generate an image from the sketch saved from canvas 1 and draw it on ``out_ca``.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
        filter: PIL resampling filter used to resize the sketch to 256x256.
    """
    sketch_text = save_input1.value
    if not sketch_text:
        print("No sketch data to synthesize from (save_input1 is empty).")
        return
    out_ca.fill_text('Synthesizing...', 0, 512)
    # convert('RGB') guarantees three channels even for grayscale/palette sketches
    # (which would otherwise crash the transpose below); for RGB/RGBA input the
    # red channel is unchanged.
    sketch_pillow = spylus.decode(sketch_text).convert('RGB').resize((256, 256), filter)
    sketch_numpy = np.array(sketch_pillow, dtype='float32')
    # HWC -> 1xCxHxW, keep only the red channel, scale to [0, 1].
    sketch_torch = torch.from_numpy(sketch_numpy.transpose((2, 0, 1))[None, :1, :, :] / 255).cuda()
    # Inference only: no autograd graph, so GPU memory is not wasted on every click.
    with torch.no_grad():
        output_torch = run_on_batch(sketch_torch, net, opts)
    output_pillow = common.tensor2im(output_torch[0])
    out_ca.put_image_data(np.array(output_pillow))

synth_button.on_click(synth)
@out.capture(clear_output=True)
def on_mouse_down(x, y):
    """Copy the colour under the mouse on ``out_ca`` into canvas 1's colour picker.

    Args:
        x, y: click coordinates on the canvas (may be floats).
    """
    pixels = out_ca.get_image_data()  # indexed [row (y), column (x), channel]
    # Clamp so clicks on the very edge (x == width or y == height) stay in range.
    row = min(max(int(y), 0), pixels.shape[0] - 1)
    col = min(max(int(x), 0), pixels.shape[1] - 1)
    r, g, b = (int(c) for c in pixels[row, col][:3])
    # Zero-pad each component: hex(c)[2:] yields a single digit for c < 16,
    # producing wrong (#00a) or invalid (#0aff) colour strings.
    color_input1.value = f'#{r:02x}{g:02x}{b:02x}'
out_ca.on_mouse_down(on_mouse_down)

# "reload" re-runs the file-chooser handler, re-loading the currently chosen file.
reload_button0.on_click(load0)

# Assemble and show the UI. Widget order within the rows is load-bearing: the JS
# below binds inputs/buttons/selects by their positional index, so keep it as is.
app = wi.VBox([
    wi.HBox([out_ca, canvas]),
    wi.HBox([save_input0, copy_button0, paste_button0, reload_button0, color_input0, width_select0]),
    wi.HBox([save_input1, copy_button1, paste_button1, clear_button1, color_input1, width_select1]),
    wi.HBox([selector, synth_button]),
    fc0,
    out,
])
display(marker, app)

# JavaScript glue: spylus.js defines the helper classes, then two drawing apps
# are bound to the canvases and to the widgets created above:
#   app0 -> canvas "{ID}0", save input 0, copy/paste buttons 0-1, colour input 1,
#           width select 0.
#   app1 -> canvas "{ID}1" (mixes in White and Clear), save input 3,
#           copy/paste/clear buttons 3-5, colour input 4, width select 1.
#           load0() calls window.app1.load(...).
# A Selector on select 2 (the `selector` dropdown) switches between the two apps.
# NOTE(review): input indices skip 2 and 5, presumably because each ColorPicker
# renders two <input> elements; button 2 (reload) is handled in Python only.
# These indices must be updated whenever the widget layout changes.
# The f-string doubles literal braces ({{ }}) for the JS object literals.
jscode = Javascript(f"""{spylus.js}
window.app0 = new (mix(
    Save, Load, Copy, Paste, Color, Width
))({{
    canvas: document.getElementById("{ID}0"),
    save_input: elem('{ID}', 'input', 0),
    copy_button: elem('{ID}', 'button', 0),
    paste_button: elem('{ID}', 'button', 1),
    color_input: elem('{ID}', 'input', 1),
    width_select: elem('{ID}', 'select', 0),
}});

window.app1 = new (mix(
    White, Save, Load, Copy, Paste, Color, Width, Clear
))({{
    canvas: document.getElementById("{ID}1"),
    save_input: elem('{ID}', 'input', 3),
    copy_button: elem('{ID}', 'button', 3),
    paste_button: elem('{ID}', 'button', 4),
    clear_button: elem('{ID}', 'button', 5),
    color_input: elem('{ID}', 'input', 4),
    width_select: elem('{ID}', 'select', 1),
}});

new Selector({{
    apps: [app0, app1],
    selector_select: elem('{ID}', 'select', 2),
}});
""")
# Inject the script through the Output widget that is part of the displayed app.
with out:
    display(jscode)