In [None]:
%matplotlib notebook

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import cv2
import numpy as np
import matplotlib.pyplot as plt
from dataset.user_guided_dataset import UserGuidedVideoDataset
import torch
from skimage import color
from math import round

In [None]:
# Input frames: optical flow is estimated from frame1 to frame2.
frame1, frame2 = (
    '../datasets/bw-frames/all/01130.png',
    '../datasets/bw-frames/all/01140.png',
)

In [None]:
# Read both frames as 3-channel images in OpenCV's BGR channel order.
color1 = cv2.imread(frame1, cv2.IMREAD_COLOR)
color2 = cv2.imread(frame2, cv2.IMREAD_COLOR)

# cv2.imread signals a missing/unreadable file by returning None rather than
# raising; without this check the failure would surface as an opaque
# cv2.cvtColor assertion below.
unreadable = [path for path, image in ((frame1, color1), (frame2, color2)) if image is None]
if unreadable:
    raise FileNotFoundError(f'cv2.imread could not read: {unreadable}')
# Farneback optical flow requires both frames to have the same size.
if color1.shape != color2.shape:
    raise ValueError(f'frame sizes differ: {color1.shape} vs {color2.shape}')

gray1 = cv2.cvtColor(color1, cv2.COLOR_BGR2GRAY)
gray2 = cv2.cvtColor(color2, cv2.COLOR_BGR2GRAY)

# HSV canvas used to visualise the flow: hue will encode direction and value the
# magnitude. Saturation carries no information, so it is fixed at full (255).
mask = np.zeros_like(color1)
mask[..., 1] = 255

In [None]:
# Dense Farneback optical flow from frame 1 to frame 2; flow[y, x] holds the
# per-pixel displacement as (dx, dy).
flow = cv2.calcOpticalFlowFarneback(
    gray1,
    gray2,
    None,
    pyr_scale=0.5,
    levels=5,
    winsize=11,
    iterations=5,
    poly_n=5,
    poly_sigma=1.1,
    flags=0,
)
magnitude, angle = cv2.cartToPolar(flow[..., 0], flow[..., 1])

# Encode the flow in the HSV canvas:
#   value <- magnitude, min-max stretched to 0..255
#   hue   <- direction, radians -> degrees, halved for OpenCV's 8-bit hue range
mask[..., 2] = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
mask[..., 0] = angle * 180 / np.pi / 2

# NOTE: despite its name, `rgb` is in OpenCV's BGR channel order.
rgb = cv2.cvtColor(mask, cv2.COLOR_HSV2BGR)

In [None]:
# Blend the flow visualisation (weight 2) onto frame 1 (weight 1).
dense_flow = cv2.addWeighted(color1, 1, rgb, 2, 0)

# OpenCV arrays are BGR while matplotlib's imshow expects RGB. Convert for
# display only; dense_flow and color1 themselves stay BGR for further OpenCV use.
plt.figure()
plt.imshow(cv2.cvtColor(dense_flow, cv2.COLOR_BGR2RGB))
plt.title('Flow visualisation blended over frame 1')

# Draw the flow vector of every `interval`-th pixel on top of frame 1.
interval = 20
plt.figure()
plt.imshow(cv2.cvtColor(color1, cv2.COLOR_BGR2RGB))
plt.title(f'Flow vectors every {interval} px on frame 1')
for y in range(0, flow.shape[0], interval):
    for x in range(0, flow.shape[1], interval):
        plt.arrow(x, y, flow[y, x, 0], flow[y, x, 1], fc="k", ec="k", head_width=3, head_length=2)

### Sample color hints on frame 1 and propagate them to frame 2 with optical flow

In [None]:
# Load frame 1 through the dataset to obtain its L channel, ab channels, and the
# sampled ab hints together with their mask.
L_channel, ab_channels, ab_hint, ab_mask, bounding_boxes = UserGuidedVideoDataset(
    '', False, [frame1], crop_to_fit=False
)[0]

# Make the hints visible: hinted pixels take the hint's ab values and have their
# L forced to -0.5 (which becomes L = 0 after the rescaling below).
replaced_ab = torch.where(torch.cat([ab_mask] * 2, dim=0) > 0, ab_hint, ab_channels)
replaced_l = torch.where(ab_mask > 0, -0.5 * torch.ones_like(L_channel), L_channel)

# Undo the channel scaling (presumably the dataset stores L/100 - 0.5 and
# ab/110 -- TODO confirm against UserGuidedVideoDataset) and convert to RGB.
replaced_lab = torch.cat((replaced_l * 100 + 50, replaced_ab * 110), dim=0)
replaced_rgb = color.lab2rgb(replaced_lab.permute(1, 2, 0).numpy())

plt.figure()
plt.imshow(replaced_rgb)

In [None]:
# Propagate frame 1's colour hints to frame 2: every hinted pixel (y, x) moves to
# (y + dy, x + dx) along the forward optical flow, clamped to the image bounds.
shifted_ab_hint = torch.zeros_like(ab_hint)
shifted_ab_mask = torch.zeros_like(ab_mask)

# The flow is computed from the raw image files, but the hints come from the
# dataset tensors. Looking up flow at hint coordinates only makes sense if both
# have the same spatial size (e.g. the dataset did not resize or crop the frame).
hint_hw = tuple(ab_mask.shape[1:])
assert flow.shape[:2] == hint_hw, f'flow is {flow.shape[:2]} but hints are {hint_hw}'

# ab_mask is indexed as (channel, y, x), so each nonzero() row is (c, y, x).
hint_indices = ab_mask.nonzero()
nonzero_y = hint_indices[:, 1]
nonzero_x = hint_indices[:, 2]

# from_numpy shares memory with `flow` (read-only use here), avoiding a copy.
flow_in_CHW = torch.from_numpy(flow).permute((2, 0, 1))
shifts_of_hint = flow_in_CHW[:, nonzero_y, nonzero_x]  # Values are in (delta_x, delta_y)!
y_after_shift = torch.clamp(torch.round(shifts_of_hint[1] + nonzero_y).long(),
                            0, ab_hint.shape[1] - 1)
x_after_shift = torch.clamp(torch.round(shifts_of_hint[0] + nonzero_x).long(),
                            0, ab_hint.shape[2] - 1)

# If several hints land on the same target pixel, only one value survives; which
# one is unspecified for duplicate indices in an indexed assignment.
shifted_ab_mask[0, y_after_shift, x_after_shift] = 1
shifted_ab_hint[:, y_after_shift, x_after_shift] = ab_hint[:, nonzero_y, nonzero_x]

# Load frame 2 and paint the propagated hints into its ab channels: hinted pixels
# take the hint's ab values and have their L forced to -0.5 (L = 0 after rescaling).
img2_L_channel, img2_ab_channels, _, _, _ = UserGuidedVideoDataset(
    '', False, [frame2], crop_to_fit=False
)[0]
img2_replaced_ab = torch.where(torch.cat((shifted_ab_mask, shifted_ab_mask), dim=0) > 0,
                               shifted_ab_hint, img2_ab_channels)
img2_replaced_l = torch.where(shifted_ab_mask > 0, torch.ones_like(img2_L_channel) * -0.5, img2_L_channel)
img2_replaced_lab = torch.cat((img2_replaced_l * 100 + 50, img2_replaced_ab * 110), dim=0)
img2_replaced_rgb = color.lab2rgb(img2_replaced_lab.permute((1, 2, 0)).numpy())

plt.figure()
plt.imshow(img2_replaced_rgb)
plt.title('Frame 2 with hints propagated by optical flow');