In [1]:
%matplotlib notebook
%cd ..

import matplotlib.pyplot as plt
import numpy as np
from skimage import io, color
import torch
from dataset.user_guided_dataset import UserGuidedVideoDataset
from model.zhang_model import SIGGRAPHGenerator
from model.user_guided_unet import UserGuidedUNet
from dataset.util import unnormalize_lab
from skimage.color import lab2rgb, rgb2lab
import matplotlib.gridspec as gridspec
import ipywidgets as widgets

device = torch.device('cuda')

/mnt/Research/ShanghaiProject/code


In [2]:
image_path = 'datasets/bw-frames/test/00078.png'
model_path = 'checkpoint/siggraph_caffemodel/latest_net_G.pth'

In [3]:
# Both 3D tensors
L_channel, ab_channels, _, _ = UserGuidedVideoDataset('', False, [image_path])[0]
L_channel = L_channel.to(device)
ab_channels = ab_channels.to(device)

### Create Zhang model

In [4]:
model = SIGGRAPHGenerator(4, 2)

### Create UserGuidedUNet model

In [10]:
model = UserGuidedUNet()

### Load weights

In [5]:
model = model.to(device)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

### Create default inputs for the model

In [6]:
# All 4D tensors, with first dim being batch size (1)
input_L = L_channel.unsqueeze(0)
input_ab = torch.zeros_like(ab_channels, device=device).unsqueeze(0)
input_mask = torch.zeros_like(input_L, device=device)

### Get output from Zhang model

In [7]:
_, predicted_ab = model(input_L, input_ab, input_mask)

### Get output from UserGuidedUNet model

In [None]:
predicted_ab = model(torch.cat((input_L, input_ab, input_mask), dim=1))

### Visualize predicted values

In [8]:
predicted_ab = predicted_ab.squeeze()
Lab = unnormalize_lab(L_channel, predicted_ab)
Lab = Lab.permute((1, 2, 0))
rgb = lab2rgb(Lab.detach().cpu().numpy())

grayscale_and_hints = unnormalize_lab(L_channel, torch.zeros_like(predicted_ab))
grayscale_and_hints = grayscale_and_hints.permute((1, 2, 0))
grayscale_and_hints = lab2rgb(grayscale_and_hints.detach().cpu().numpy())

plt.figure(figsize = (9, 4))
gs1 = gridspec.GridSpec(1, 2)
gs1.update(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

ax1 = plt.subplot(gs1[0])
ax2 = plt.subplot(gs1[1])

ax1.imshow(grayscale_and_hints)
ax2.imshow(rgb)
ax1.set_axis_off()
ax2.set_axis_off()

hint_xy = None
hint_size = None
hint_rgb = None
hint_lab = None
hint_text = ax1.text(0, 0, "", va="bottom", ha="left")

candidate_ab_hint = None
candidate_ab_mask = None

def update_text_and_pred():
    global candidate_ab_hint
    global candidate_ab_mask
    hint_text.set_text(f'Hint pos: {hint_xy}, size: {hint_size}, rgb: {hint_rgb}, lab: {hint_lab}')
    
    if hint_xy is not None and hint_size is not None and hint_rgb is not None:
        candidate_ab_hint = input_ab.clone()
        candidate_ab_mask = input_mask.clone()

        x, y = hint_xy
        candidate_ab_mask[0, 0, y : y+hint_size, x : x+hint_size] = 1
        # Update hint and normalize it
        candidate_ab_hint[0, :, y : y+hint_size, x : x+hint_size] = \
            torch.tensor(hint_lab[0, 0, 1:]).unsqueeze(1).unsqueeze(1) / 110
        
        grayscale_and_hints = unnormalize_lab(L_channel, candidate_ab_hint.squeeze())
        grayscale_and_hints = grayscale_and_hints.permute((1, 2, 0))
        grayscale_and_hints = lab2rgb(grayscale_and_hints.detach().cpu().numpy())
        ax1.imshow(grayscale_and_hints)
        
        _, predicted_ab = model(input_L, candidate_ab_hint, candidate_ab_mask)
        predicted_ab = predicted_ab.squeeze()
        Lab = unnormalize_lab(L_channel, predicted_ab)
        Lab = Lab.permute((1, 2, 0))
        rgb = lab2rgb(Lab.detach().cpu().numpy())
        ax2.imshow(rgb)
        

def update_hint_pos(event):
    global hint_xy
    x = int(event.xdata)
    y = int(event.ydata)
    hint_xy = [x, y]
    update_text_and_pred()
    
def update_hint_size(change):
    global hint_size
    hint_size = int(change['new'])
    update_text_and_pred()
    
def update_hint_color(change):
    global hint_rgb
    global hint_lab
    hex_value = change['new'].lstrip('#')
    hint_rgb = np.array([[[int(hex_value[i : i+2], 16) for i in (0, 2, 4)]]]).astype('float')
    hint_lab = rgb2lab(hint_rgb / 255)
    print(hint_lab)
    update_text_and_pred()

ka = ax1.figure.canvas.mpl_connect('button_press_event', update_hint_pos)

color_picker = widgets.ColorPicker(
    concise=False,
    description='Pick a color',
    value='blue',
    disabled=False
)

slider = widgets.IntSlider(
    value=3,
    min=1,
    max=20,
    step=1,
    description='Test:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

slider.observe(update_hint_size, names='value')
color_picker.observe(update_hint_color, names='value')

display(color_picker)
display(slider)

<IPython.core.display.Javascript object>

ColorPicker(value='blue', description='Pick a color')

IntSlider(value=3, continuous_update=False, description='Test:', max=20, min=1)

[[[54.45579     9.90637153 56.39578632]]]
[[[42.97823848 39.9911633  53.55862888]]]
[[[27.63465735 27.52842793 38.69137808]]]
[[[28.58182706 22.22245459 33.41424619]]]
[[[ 24.67494025  59.12600973 -83.75000242]]]
[[[ 3.26076794  1.67133471 -4.54888578]]]
[[[23.64892416 24.30427374 34.00932934]]]
[[[29.00923967 52.06330569 42.41993678]]]
[[[14.36925146 35.12886763 22.44121105]]]


### Commit selected value to hint and mask

In [12]:
input_ab = candidate_ab_hint.clone()
input_mask = candidate_ab_mask.clone()