In [48]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import cv2

# Demo of the trained Colorization Model

For now, this notebook only colorizes the frames of each video and stores the result as a new video.

In [49]:
def make_dir(path):
    """Create the directory `path`, including missing parents.

    Calling this for a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If `path` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and still reports a non-directory in the way.
    os.makedirs(path, exist_ok=True)

## 1. Data Loading

Uses the PyTorch `DataLoader` together with the dataset implementation from `DataLoad.py`.

In [50]:
from DataLoad import CustomImageDataset, data_to_images
from format_transforms import *

In [51]:
# Toggle for the flow-based variant; it also selects the model's input
# channel count and whether flow arrays are loaded during rendering.
use_flow = True

# Dataset at the 176p resolution.
dataset = CustomImageDataset(use_flow=use_flow, resolution='176p')

dataset_size = len(dataset)
print('Dataset size:', dataset_size)

vid labels: ['basketball-game', 'bmx-rider', 'butterfly', 'car-competition', 'cat', 'chairlift', 'circus', 'dog-competition', 'dolphins-show', 'drone-flying', 'ducks', 'giraffes', 'gym-ball', 'helicopter-landing', 'horse-race', 'hurdles-race', 'ice-hockey', 'jet-ski', 'juggling-selfie', 'kids-robot', 'mantaray', 'mascot', 'motorbike-race', 'obstacles', 'plane-exhibition', 'robot-battle', 'snowboard-race', 'swimmer', 'tram', 'trucks-race']
Dataset size: 2264


## 2. Model Loading

In [52]:
from Model_UNet import UNet
from DataLoad import CustomImageDataset, data_to_images, images_to_data

# Report whether PyTorch can see a CUDA device.
print("Cuda available: ", torch.cuda.is_available())

# The CPU is selected unconditionally; the commented-out line below is the
# automatic CUDA/CPU selection, kept for reference.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

print("Using device: ", device)

Cuda available:  True
Using device:  cpu


In [53]:
# Build the network: 6 input channels when flow is used, 4 otherwise.
n_input_channels = 6 if use_flow else 4

model = UNet(n_input_channels=n_input_channels)
model.to(device)

# Hyperparameters (`epochs` is also used to build the checkpoint folder name).
learning_rate, batch_size, epochs = 1e-3, 16, 300

# Loss function
loss_fn = nn.MSELoss()

# Shuffled, batched loader over the dataset
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [54]:
# Root folder holding all saved models, and the sub-folder for UNet checkpoints.
model_path = '~/Documents/Colorization/Models/'
UNet_folder = 'UNet'
model_path = os.path.expanduser(model_path)
UNet_folder = os.path.join(model_path, UNet_folder)


def checkpoint(model, filename):
    """Save the model's state dict to `filename` inside `UNet_folder`.

    Args:
        model: Module whose `state_dict()` is saved.
        filename: Path relative to `UNet_folder` (an absolute path is used
            as-is, following `os.path.join` semantics).
    """
    filename = os.path.join(UNet_folder, filename)
    # torch.save does not create missing parent folders.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(model.state_dict(), filename)


def resume(model, filename, map_location=None):
    """Load a state dict from `filename` inside `UNet_folder` into `model`.

    Args:
        model: Module that receives the weights.
        filename: Path relative to `UNet_folder` (an absolute path is used
            as-is, following `os.path.join` semantics).
        map_location: Passed to `torch.load`. When None, tensors are mapped
            to the device of the model's first parameter, so a checkpoint
            written on a GPU can still be loaded on a CPU-only machine.
    """
    filename = os.path.join(UNet_folder, filename)
    if map_location is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(filename, map_location=map_location))

In [55]:
# Checkpoint location relative to UNet_folder: <model_folder>/<filename>.
model_folder = 'UNet_Flow_'+str(epochs)+'_epochs'

filename = 'UNet_Flow_epoch_295.pth'

resume(model, os.path.join(model_folder, filename))

# The model is only used for inference below, so put layers such as dropout
# and batch normalization (if the network has any) into evaluation mode.
# The trailing semicolon keeps the notebook from echoing the module repr.
model.eval();

## 3. Video Rendering

In [57]:

# Location of the DAVIS dataset on disk.
dataset_name = 'DAVIS'
image_folder = 'JPEGImages'
dataset_path = os.path.join(
    os.path.expanduser('~/Documents/Colorization/Datasets/'),
    dataset_name,
)

# Folder with the per-frame .npy arrays loaded when flow is used.
nd_array_path = os.path.join(dataset_path, 'nd_arrays', '176p_deepflow')

# Folder with the input frames: <dataset>/JPEGImages/<resolution>_<label>.
resolution = '176p'
label = 'gray'
img_folder_path = os.path.join(dataset_path, image_folder, resolution + '_' + label)

In [58]:
# Name of the sub-folder under the results directory that the rendered
# stills and videos are written to.
save_label = 'UNet_Flow'

In [59]:
def load_image(image_path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image array as read by OpenCV, converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (missing,
            unreadable or not a supported image).
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising;
    # without this check the error surfaces as an opaque cvtColor assertion.
    if image is None:
        raise FileNotFoundError('Could not read image: {}'.format(image_path))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def save_image(image, image_path):
    """Write `image` to `image_path` with OpenCV.

    The array is handed to cv2.imwrite unchanged, so the caller is
    responsible for the channel order (OpenCV writes BGR).

    Args:
        image: Image array to write.
        image_path: Destination file path; the extension selects the format.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    # cv2.imwrite returns False on failure (e.g. missing target folder)
    # instead of raising, which would otherwise drop frames silently.
    if not cv2.imwrite(image_path, image):
        raise OSError('Could not write image: {}'.format(image_path))

In [68]:
# Colorize every video: the first color frame seeds the process, and each
# predicted frame becomes the color reference for the next one.

# The results root does not depend on the video, so resolve it once.
results_root = os.path.join(
    os.path.expanduser('~/Documents/Colorization/Results/'), save_label)

for video_name in os.listdir(img_folder_path):
    print('Video name:', video_name)
    video_path = os.path.join(img_folder_path, video_name)
    save_path = os.path.join(results_root, video_name)
    make_dir(save_path)
    still_path = os.path.join(save_path, 'stills')
    make_dir(still_path)

    n_frames = len(os.listdir(video_path))
    print('Number of frames:', n_frames)

    # Frame 0 is taken from the <resolution> folder (not the '_gray' one)
    # and serves as the initial color reference.
    first_frame_path = os.path.join(dataset_path, image_folder, resolution, video_name, '00000.jpg')
    col_img = load_image(first_frame_path)
    print(first_frame_path)

    gray_img = None
    flow_img = None
    for i in range(1, n_frames):
        frame_name = str(i).zfill(5) + '.jpg'
        gray_img = load_image(os.path.join(video_path, frame_name))
        if use_flow:
            flow_frame = os.path.join(nd_array_path, video_name, frame_name + '.npy')
            flow_img = np.load(flow_frame)

        model_input, _ = images_to_data(col_img, flow_img, gray_img)

        # Inference only: skip building the autograd graph for every frame.
        with torch.no_grad():
            prediction = model(model_input.unsqueeze(0).to(device))
        prediction = prediction.cpu().squeeze(0)

        _, flow_img, _, pred_img = data_to_images(model_input, prediction, use_flow=use_flow, input_only=False)

        # NOTE(review): even frames are written as-is while odd frames are
        # channel-swapped first. Presumably this compensates for a channel
        # order flip each time the prediction is fed back as the reference;
        # confirm against images_to_data / data_to_images before changing.
        still_file = os.path.join(still_path, frame_name)
        if i % 2 == 0:
            save_image(pred_img, still_file)
        else:
            save_image(cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR), still_file)

        # The prediction becomes the color reference for the next frame.
        col_img = pred_img

    images_2_video(still_path, os.path.join(save_path, video_name+'.avi'), fps=10)

Video name: circus
Number of frames: 73
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/circus/00000.jpg
Video name: mascot
Number of frames: 78
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/mascot/00000.jpg
Video name: helicopter-landing
Number of frames: 77
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/helicopter-landing/00000.jpg
Video name: ducks
Number of frames: 75
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/ducks/00000.jpg
Video name: kids-robot
Number of frames: 75
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/kids-robot/00000.jpg
Video name: car-competition
Number of frames: 66
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/car-competition/00000.jpg
Video name: jet-ski
Number of frames: 83
/home/jansp/Documents/Colorization/Datasets/DAVIS/JPEGImages/176p/jet-ski/00000.jpg
Video name: cat
Number of frames: 52
/home/jansp/Documents/Colorization/Datasets/DAVIS/JP