In [1]:
import torch
import torchvision
from networks import ResnetGenerator
from IPython.display import Video
from tqdm.notebook import trange
from skimage.exposure import match_histograms

In [2]:
def to_tensor(x, device):
    """Turn a batch of 0-255 frames into normalized, channels-first model input.

    Args:
        x: frame batch in channels-last layout, indexed (frame, H, W, C) as the
            permute below implies; values presumably uint8 in [0, 255].
        device: torch device the result is moved to.

    Returns:
        float32 tensor indexed (frame, C, H, W), rescaled from [0, 255] to
        [-1, 1], on ``device``.
    """
    unit_range = x.type(torch.float) / 255.0
    symmetric = (unit_range - 0.5) / 0.5
    channels_first = symmetric.permute(0, 3, 1, 2)
    return channels_first.to(device)

def to_video(x):
    """Convert normalized model output back into uint8 video frames on the CPU.

    Undoes the [-1, 1] normalization: maps [-1, 1] to [0, 255] and moves the
    channel axis last.

    Args:
        x: float tensor indexed (frame, C, H, W) (channels-first, as the permute
            below implies); values expected roughly in [-1, 1].

    Returns:
        uint8 CPU tensor indexed (frame, H, W, C) with values in [0, 255].
    """
    x = x.permute(0, 2, 3, 1)
    x = (x * 0.5) + 0.5
    x = x * 255.0
    # Round instead of truncating: float error would otherwise turn e.g.
    # 254.9999 into 254. Clamp so out-of-range model outputs saturate at 0/255
    # instead of overflowing in the uint8 cast.
    x = x.round().clamp(0, 255)
    x = x.type(torch.uint8)
    x = x.cpu()
    return x

In [3]:
def match_histogram(video):
    """Histogram-match every frame of a video to its first frame.

    Uses the first frame as the reference, which reduces frame-to-frame
    flicker in overall colour distribution.

    Args:
        video: CPU tensor indexed (frame, H, W, C), channels last (the channel
            axis passed to scikit-image is -1); presumably uint8.

    Returns:
        New tensor with the same shape and dtype as ``video``. An empty video
        yields an empty tensor.
    """
    output_video = torch.zeros_like(video)
    if video.size(0) == 0:
        return output_video
    reference = video[0].numpy()
    for i in trange(video.size(0)):
        matched = torch.from_numpy(_match_frame(video[i].numpy(), reference))
        if matched.is_floating_point() and not output_video.is_floating_point():
            # Round instead of letting the copy into an integer buffer truncate.
            matched = matched.round()
        output_video[i] = matched
    return output_video


def _match_frame(frame, reference):
    """Histogram-match one channels-last frame, across scikit-image versions."""
    try:
        return match_histograms(frame, reference, channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 has no `channel_axis`; it used `multichannel`.
        return match_histograms(frame, reference, multichannel=True)

In [4]:
def process_video(model, input_video_path, output_video_path, device=None, fps=None):
    """Run ``model`` over a video and save the input and output side by side.

    Pipeline:
    1. Read the video and histogram-match every frame to the first one.
    2. Push each frame through ``model`` and histogram-match the outputs.
    3. Concatenate the raw input and the output along the width axis.
    4. Write the result to ``output_video_path``.

    Args:
        model: generator whose forward returns a 3-tuple; the first element is
            the output frame batch (only that element is used).
        input_video_path: path of the video to read.
        output_video_path: path of the side-by-side video to write.
        device: inference device. Defaults to the device of the model's first
            parameter, so no notebook-global ``device`` is required.
        fps: output frame rate. Defaults to the input's ``video_fps`` metadata,
            falling back to 24.0 when the reader does not report it.
    """
    raw_input_video, _, info = torchvision.io.read_video(input_video_path, pts_unit="sec")
    if device is None:
        device = next(model.parameters()).device
    if fps is None:
        fps = info.get("video_fps", 24.0)

    input_video = match_histogram(raw_input_video)

    output_frames = []
    with torch.no_grad():
        for i in trange(input_video.size(0)):
            # Move one frame at a time so the whole clip (and an equally large
            # output buffer) never has to fit on the device at once.
            input_frame = to_tensor(input_video[i : i + 1], device)
            output_frame, _, _ = model(input_frame)
            output_frames.append(to_video(output_frame))

    output_video = torch.cat(output_frames, dim=0)
    output_video = match_histogram(output_video)

    # Frames are indexed (frame, H, W, C); dim=2 places input | output side by side.
    merge = torch.cat((raw_input_video, output_video), dim=2)
    torchvision.io.write_video(output_video_path, merge, fps=fps)

In [5]:
def get_model(device, checkpoint_path="state_dict/genA2B_best.pt"):
    """Build the ResnetGenerator, load its weights, and prepare it for inference.

    Args:
        device: torch device the model is moved to.
        checkpoint_path: state-dict file to load.

    Returns:
        The generator in eval mode on ``device``, with gradients disabled on
        every parameter.
    """
    model = ResnetGenerator()
    # Load onto the CPU first so a checkpoint saved from a GPU this machine does
    # not have still opens; the model is moved to `device` at the end.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    for p in model.parameters():
        # `requires_grad_` actually freezes the parameter. Assigning a misspelled
        # attribute (e.g. `required_grad`) would silently do nothing.
        p.requires_grad_(False)
    model.eval()
    model = model.to(device)
    return model

In [6]:
# Inference device: prefer cuda:3 when it exists, otherwise the default CUDA
# device, otherwise the CPU, so the notebook also runs on other machines.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
input_video_path = "media/video2.mp4"
output_video_path = "media/output_video2.mp4"

In [7]:
# Load the generator once and reuse it for every video processed below.
model = get_model(device=device)

In [8]:
# Writes the side-by-side (input | output) clip to output_video_path.
process_video(
    model=model,
    input_video_path=input_video_path,
    output_video_path=output_video_path,
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1440.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1440.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1440.0), HTML(value='')))




In [9]:
# Display the side-by-side (original | model output) clip written to output_video_path.
Video(output_video_path)