In [12]:
import cv2
import numpy as np

import os
import json

import torch
from torchvision.transforms import transforms

from generator_model import Generator
from image_utils import revert_normalisation 

from tqdm.notebook import tqdm

In [13]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
# Checkpoint selection: weights are read from
# <model_parent_directory>/<epoch_directory>/<model_name>, and the run's
# metadata JSON from <model_parent_directory>.
(model_parent_directory, epoch_directory, model_name) = (
    "./runs/RecycleGAN/1680898964.0020654",  # run directory (name looks like a Unix timestamp)
    "latest",                                # epoch sub-directory
    "F.pth",                                 # generator weights file
)

In [15]:
# Read the run's saved settings so the Generator in the next cell is built with
# the same block count and upsampling strategy that the run recorded.
# JSON is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform's locale default (e.g. cp1252 on Windows).
# NOTE(review): "info_None.json" looks like a formatting artifact from the
# training script (a None interpolated into the name) — confirm it is intended.
with open(f"{model_parent_directory}/info_None.json", "r", encoding="utf-8") as info_file:
    model_info = json.load(info_file)

upsample_strategy = model_info["upsample_strategy"]
block_count = model_info["block_count"]

In [16]:
# Rebuild the generator from the run's saved settings, put it in inference
# mode, and load the trained weights.
model = Generator(block_count, upsample_strategy).to(device)
# eval() switches training-only behaviour (dropout, batch-norm statistic
# updates) off for any such layers the Generator contains; torch.no_grad()
# alone does not do this. Called before loading so the cell still ends by
# displaying the load_state_dict result.
model.eval()
# map_location lets a checkpoint saved on one device (e.g. a GPU) load onto
# whichever `device` was selected. NOTE: torch.load unpickles the file, so
# only load checkpoints from trusted sources.
model.load_state_dict(
    torch.load(f"{model_parent_directory}/{epoch_directory}/{model_name}", map_location=device)
)

<All keys matched successfully>

In [17]:
# Clip to stylise, where to write the result, and how many frames go through
# the generator per forward pass.
input_video_path, output_video_path = (
    "./videos/input_test_movie.mp4",
    "./videos/recyclegan_transfer_other.mp4",
)
batch_size = 4  # frames per forward pass; larger batches need more memory on `device`

In [18]:
# Open the source clip and fail here with a clear message, rather than later
# with an opaque 'NoneType' error when the first frame read returns nothing.
video_in = cv2.VideoCapture(input_video_path)
if not video_in.isOpened():
    raise OSError(f"Could not open input video: {input_video_path}")

In [19]:
# Turn HxWx3 uint8 frames into CxHxW float tensors in [-1, 1]:
# ToTensor scales to [0, 1], then Normalize applies (x - 0.5) / 0.5 per channel.
preprocess_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [20]:
def transfer_style_to_batch(cv2_images):
    """Run the style-transfer generator on a batch of frames.

    Args:
        cv2_images: iterable of frames accepted by ``preprocess_transform``; in
            this notebook these are HxWx3 uint8 RGB arrays (converted from
            OpenCV's BGR by the caller). All frames must share a shape so they
            can be stacked into one batch.

    Returns:
        A list with one ``revert_normalisation`` result per input frame,
        computed from CPU tensors, in input order. An empty input returns ``[]``.

    NOTE(review): ``torch.no_grad()`` only disables autograd; layers such as
    dropout or batch-norm keep training behaviour unless ``model.eval()`` has
    been called on the global ``model``.
    """
    tensors = [preprocess_transform(img) for img in cv2_images]
    if not tensors:
        # torch.stack raises on an empty list; there is nothing to transfer.
        return []

    batch = torch.stack(tensors).to(device)
    with torch.no_grad():
        transferred = model(batch)

    # One device->host copy for the whole batch instead of one per frame.
    transferred = transferred.cpu()
    return [revert_normalisation(img_t) for img_t in transferred]

In [21]:
def _to_bgr_uint8(rgb_frame):
    """Convert one generator output frame into a BGR uint8 array for ``cv2.VideoWriter``.

    ``rgb_frame`` is expected to be a CPU tensor in HxWx3 RGB layout with values
    in [0, 1] (the layout ``cv2.cvtColor`` needs here). NOTE(review): confirm
    against ``revert_normalisation``.
    """
    scaled = rgb_frame.numpy() * 255 + 0.5
    # Clip before casting: out-of-range floats have no well-defined uint8 cast
    # and can wrap around (e.g. a slightly negative value turning bright). The
    # +0.5 rounds to nearest instead of truncating.
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)


def _iter_rgb_batches(video, first_frame, size, progress):
    """Yield lists of up to ``size`` RGB frames read from ``video``, starting with ``first_frame``.

    The last list may be shorter than ``size``. ``progress`` is advanced once per frame.
    """
    batch = []
    success, frame = True, first_frame
    while success:
        batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        progress.update(1)
        if len(batch) >= size:
            yield batch
            batch = []
        success, frame = video.read()
    if batch:
        # Trailing frames that don't fill a whole batch.
        yield batch


def _open_writer(save_loc, fourcc, fps, first_out_frame):
    """Open a ``cv2.VideoWriter`` sized to ``first_out_frame``; raise OSError if it cannot be opened."""
    height, width = first_out_frame.shape[:2]
    writer = cv2.VideoWriter(save_loc, cv2.VideoWriter_fourcc(*fourcc), fps, (width, height))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {save_loc!r} with codec {fourcc!r}")
    return writer


def process_video(video, save_loc, fourcc="mp4v"):
    """Style-transfer every remaining frame of ``video`` and write the result to ``save_loc``.

    Frames are read from ``video``, converted BGR -> RGB, and passed through
    ``transfer_style_to_batch`` in groups of the notebook-level ``batch_size``.
    They are then converted back to BGR uint8 and written at the source frame rate.

    Args:
        video: an opened ``cv2.VideoCapture``, read from its current position.
            It is not released here; the caller owns it.
        save_loc: output file path. Any existing file there is deleted first.
        fourcc: four-character codec code for the writer. The default
            ``"mp4v"`` suits ``.mp4`` output on every platform.

    Raises:
        ValueError: if no frame can be read from ``video``.
        OSError: if the output writer cannot be opened.
    """
    # Best-effort removal of a previous output; a missing file is expected.
    # Real problems with the path surface when the writer is opened below.
    try:
        os.remove(save_loc)
    except OSError:
        pass

    # Keep fractional rates (e.g. 29.97 fps); truncating with int() would change
    # the output's playback speed and duration.
    fps = video.get(cv2.CAP_PROP_FPS)
    # Only sizes the progress bar; containers may report an estimate.
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

    success, first_frame = video.read()
    if not success:
        raise ValueError(
            "Could not read a frame from the input video; "
            "check that it opened and has not already been read to the end."
        )

    video_out = None
    try:
        with tqdm(total=total_frames) as progress:
            for rgb_batch in _iter_rgb_batches(video, first_frame, batch_size, progress):
                out_frames = [_to_bgr_uint8(f) for f in transfer_style_to_batch(rgb_batch)]
                if video_out is None:
                    # Size the writer from the generator's actual output; OpenCV
                    # may silently skip frames whose size differs from the writer's.
                    video_out = _open_writer(save_loc, fourcc, fps, out_frames[0])
                for out_frame in out_frames:
                    video_out.write(out_frame)
    finally:
        # Release even on error so the file handle is freed and the container finalised.
        if video_out is not None:
            video_out.release()

In [22]:
# Rewind first: the capture is left at its end after a run, so re-running this
# cell on its own would otherwise find no frames to process.
video_in.set(cv2.CAP_PROP_POS_FRAMES, 0)
process_video(video_in, output_video_path)

  0%|          | 0/1645 [00:00<?, ?it/s]