#Imports

In [7]:
import math
from random import sample
from random import randint

!pip install comet-ml &> /dev/null
import comet_ml

!pip install pytorch-lightning &> /dev/null
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger
import tensorboard
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import cv2
from PIL import Image
from IPython import display

!pip install pytorch-msssim &> /dev/null
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

!pip install dlib &> /dev/null
import dlib

import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing  (1.1)

In [8]:
def load_video(filepath, start_frame=0, end_frame=-1, scale_percent=50):
    """Load a contiguous run of frames from a video as downscaled RGB images.

    Args:
        filepath: Path passed to ``cv2.VideoCapture``.
        start_frame: Index of the first frame to load. Negative values are
            clamped to 0.
        end_frame: Exclusive index of the last frame. ``-1`` (or any value
            beyond the frame count reported by the container) means "until
            the end of the video".
        scale_percent: Output size as a percentage of the source resolution.

    Returns:
        A ``uint8`` array of shape ``(n, height, width, 3)`` in RGB channel
        order, where ``n`` is the number of frames that were actually
        decoded. This can be fewer than requested if decoding stops early.
    """
    video = cv2.VideoCapture(filepath)
    try:
        # Source geometry and length, used to size the output array.
        frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        # Downscale target; cv2.resize expects (width, height).
        width = int(frame_width * scale_percent / 100)
        height = int(frame_height * scale_percent / 100)
        dim = (width, height)

        if end_frame == -1 or end_frame > frame_count:
            end_frame = frame_count
        start_frame = max(start_frame, 0)
        n_frames = max(end_frame - start_frame, 0)

        frames = np.empty((n_frames, height, width, 3), np.dtype('uint8'))

        # Without this seek the capture always starts decoding at frame 0,
        # regardless of the requested start_frame.
        if n_frames and start_frame:
            video.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        loaded = 0
        for _ in range(n_frames):
            success, img = video.read()
            if not success:
                break
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Index relative to the output array, not the absolute frame
            # number, so start_frame > 0 does not write out of range.
            frames[loaded] = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
            loaded += 1
    finally:
        video.release()

    # Drop the uninitialised tail left behind if decoding stopped early.
    return frames[:loaded]

def load_video_random_frames(filepath, number_of_frames, scale_percent=10):
    """Load distinct, randomly chosen frames from a video as downscaled RGB.

    Args:
        filepath: Path passed to ``cv2.VideoCapture``.
        number_of_frames: How many distinct frames to sample.
        scale_percent: Output size as a percentage of the source resolution.
            The original code notes the default of 10 gives 120x72 images,
            which presumably assumes a 1200x720 source (TODO confirm).

    Returns:
        A ``uint8`` array of shape ``(n, height, width, 3)`` in RGB channel
        order. ``n`` equals ``number_of_frames`` unless some sampled frames
        could not be decoded; those are skipped.

    Raises:
        ValueError: If more distinct frames are requested than the video
            reports having. Previously this case looped forever.
    """
    video = cv2.VideoCapture(filepath)
    try:
        # Source geometry and length, used to size the output array.
        frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        if number_of_frames > frame_count:
            raise ValueError(
                f"requested {number_of_frames} distinct frames but the video "
                f"reports only {frame_count}"
            )

        # Downscale target; cv2.resize expects (width, height).
        width = int(frame_width * scale_percent / 100)
        height = int(frame_height * scale_percent / 100)
        dim = (width, height)

        frames = np.empty((number_of_frames, height, width, 3), np.dtype('uint8'))

        loaded = 0
        # sample() draws without replacement, so no manual de-duplication
        # (and no unbounded retry loop) is needed.
        for num in sample(range(frame_count), number_of_frames):
            video.set(cv2.CAP_PROP_POS_FRAMES, num)
            success, img = video.read()
            if not success:
                # One undecodable frame says nothing about the others.
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            frames[loaded] = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
            loaded += 1
    finally:
        video.release()

    # Drop uninitialised slots belonging to frames that failed to decode.
    return frames[:loaded]

In [None]:
# Mount Google Drive so the source videos are reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# Paths of the four source videos on the mounted drive.
MAFIA_FILEPATH = "/content/drive/My Drive/MafiaVideogame.mp4"
GODFATHER_FILEPATH = "/content/drive/My Drive/TheGodfather.mp4"
IRISHMAN_FILEPATH = "/content/drive/My Drive/TheIrishman.mp4"
SOPRANOS_FILEPATH = "/content/drive/My Drive/TheSopranos.mp4"
# Display handle; the commented-out viewers below call d.update(...) on it.
d = display.display(None, display_id=True)
# Sample 10 random frames from the game footage.
MAFIA_FRAMES = load_video_random_frames(MAFIA_FILEPATH,10)


#for count, frame in enumerate(MAFIA_FRAMES):
#    face, success = extract_face(frame)
#    if success:
#        d.update(Image.fromarray(face))
#    else:
#        print(f"No face for frame {count}")

#GODFATHER_FRAMES = load_video_random_frames(GODFATHER_FILEPATH)
#IRISHMAN_FRAMES = load_video_random_frames(IRISHMAN_FILEPATH)
#SOPRANOS_FRAMES = load_video_random_frames(SOPRANOS_FILEPATH)


# view frames

#for count, frame in enumerate(MAFIA_FRAMES):
#    d.update(Image.fromarray(frame))



#Preprocessing (Faces 3.1)

In [4]:
def within_bounds(centre, width, height, frame_size):
    """Report whether a box centred on ``centre`` lies inside the frame.

    Args:
        centre: Object exposing ``x`` and ``y`` attributes (box centre).
        width: Box width; half of it (floor division) is taken either side.
        height: Box height; half of it (floor division) is taken either side.
        frame_size: Sequence whose first two entries are (rows, columns).

    Returns:
        True only when both the low edges are >= 0 and the high edges are
        strictly below the frame's column/row limits.
    """
    row_limit, col_limit = frame_size[0], frame_size[1]
    half_w, half_h = width // 2, height // 2

    fits_horizontally = 0 <= centre.x - half_w and centre.x + half_w < col_limit
    fits_vertically = 0 <= centre.y - half_h and centre.y + half_h < row_limit
    return bool(fits_horizontally and fits_vertically)

def extract_face(frame, dim=(120, 72)):
    """Crop the first acceptable frontal face out of ``frame`` and resize it.

    Detections are examined in the order dlib returns them. The first one
    with a non-negative score whose box lies fully inside the frame is
    cropped and resized.

    Args:
        frame: Image array handed to dlib's frontal face detector.
        dim: Output size passed to ``cv2.resize`` as (width, height).

    Returns:
        ``(face, True)`` with the resized crop on success, otherwise
        ``(None, False)``.
    """
    detector = dlib.get_frontal_face_detector()
    # Upsample once; threshold -1 returns weak detections too, which are
    # filtered by score below.
    dets, scores, _ = detector.run(frame, 1, -1)
    frame_size = np.shape(frame)

    for det, score in zip(dets, scores):
        if score < 0:
            continue
        # Location and size of the face.
        centre, width, height = det.dcenter(), det.width(), det.height()
        half_w, half_h = width // 2, height // 2
        if half_w == 0 or half_h == 0:
            # A degenerate box would give an empty crop and crash cv2.resize.
            continue
        if within_bounds(centre, width, height, frame_size):
            # dcenter() may yield float coordinates (TODO confirm against the
            # dlib version in use); slice indices must be integers.
            cx, cy = int(centre.x), int(centre.y)
            crop = frame[cy - half_h:cy + half_h, cx - half_w:cx + half_w]
            return cv2.resize(crop, dim, interpolation=cv2.INTER_AREA), True
    return None, False

# Frame-to-Frame Model (2.1)

##Generator

In [None]:
class Generator(nn.Module):
    """Encoder/decoder image-to-image generator with residual stages.

    The network halves the spatial size three times and doubles it three
    times, so the output has the same shape as the input when height and
    width are divisible by 8. The final ``Tanh`` bounds outputs to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        channels = 3
        nf = 64

        def downsample_convolution(in_features, out_features):
            # Stride-2 conv halves the spatial size. The norm must be sized
            # for the conv's output channels.
            return nn.Sequential(nn.Conv2d(in_features, out_features, 4, 2, 1),
                                 nn.InstanceNorm2d(out_features),
                                 nn.LeakyReLU(0.2))

        def residual_convolution(in_features):
            # Shape-preserving branch whose output is added back to its input.
            return nn.Sequential(nn.Conv2d(in_features, in_features, 3, 1, 1),
                                 nn.Conv2d(in_features, in_features, 3, 1, 1))

        def upsample_convolution(in_features, out_features):
            # Stride-2 transposed conv doubles the spatial size.
            # The second positional argument of InstanceNorm2d is eps, so the
            # former ``0.8`` silently set a huge epsilon; use the default.
            return nn.Sequential(nn.ConvTranspose2d(in_features, out_features, 4, 2, 1),
                                 nn.InstanceNorm2d(out_features),
                                 nn.ReLU())

        self.conv1 = nn.Sequential(
                    nn.Conv2d(channels, nf, 4, 2, 1),
                    nn.LeakyReLU())
        self.downsample1 = downsample_convolution(nf, nf*2)
        self.residual1 = residual_convolution(nf*2)

        self.downsample2 = downsample_convolution(nf*2, nf*4)
        self.residual2 = residual_convolution(nf*4)
        self.upsample1 = upsample_convolution(nf*4, nf*2)

        self.residual3 = residual_convolution(nf*2)
        self.upsample2 = upsample_convolution(nf*2, nf)
        self.conv2 = nn.Sequential(
                    nn.ConvTranspose2d(nf, channels, 4, 2, 1),
                    nn.Tanh())

    def forward(self, x):
        """Translate a batch of 3-channel images; output matches input shape."""
        x = self.conv1(x)

        # Residual branches are added, not concatenated. Concatenation would
        # double the channel count and break the width the next layer was
        # built for.
        x = self.downsample1(x)
        x = x + self.residual1(x)

        x = self.downsample2(x)
        x = x + self.residual2(x)

        x = self.upsample1(x)
        x = x + self.residual3(x)

        x = self.upsample2(x)
        return self.conv2(x)

##Discriminator

In [None]:
class Discriminator(nn.Module):
    """Patch-style discriminator with a dilated-convolution context branch.

    ``forward`` returns a one-channel map of raw scores (no sigmoid is
    applied) at 1/4 of the input resolution, plus a tuple of intermediate
    feature maps.
    """

    def __init__(self):
        super().__init__()
        nf = 64
        channels = 3

        def dilated_convolution(in_features, out_features, dilation):
            # A 3x3 kernel with padding == dilation keeps the spatial size
            # unchanged. The previous 4x4 kernel shrank the map by
            # ``dilation`` pixels, so layer5 could not be concatenated with
            # layer3 in forward().
            return nn.Sequential(nn.Conv2d(in_features, out_features, 3, 1, dilation, dilation),
                                 nn.InstanceNorm2d(out_features),
                                 nn.LeakyReLU(0.2, inplace=True))

        def convolution(in_features, out_features, kernel_size=4, stride=2, padding=1):
            return nn.Sequential(nn.Conv2d(in_features, out_features, kernel_size, stride, padding),
                                 nn.InstanceNorm2d(out_features),
                                 nn.LeakyReLU(0.2, inplace=True))

        self.layer1 = nn.Sequential(
                    nn.Conv2d(channels, nf, 4, 2, 1),
                    nn.LeakyReLU(0.2, inplace=True))
        self.layer2 = convolution(nf, nf*2)
        self.layer3 = convolution(nf*2, nf*4, 3, 1, 1)
        self.layer4 = dilated_convolution(nf*4, nf*4, 2)
        self.layer5 = dilated_convolution(nf*4, nf*4, 4)
        # layer6 is the concatenation done in forward(): nf*4 + nf*4 channels.
        self.layer7 = convolution(nf*8, nf*4, 3, 1, 1)
        self.layer8 = nn.Conv2d(nf*4, 1, 3, 1, 1)

    def forward(self, x):
        """Score a batch of 3-channel images.

        Returns:
            ``(scores, features)`` where ``scores`` is the raw output of the
            final convolution and ``features`` is the tuple
            ``(layer2, layer3, layer4, layer5, layer7)``.
        """
        layer1 = self.layer1(x)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer5 = self.layer5(layer4)
        # Skip connection: merge pre- and post-dilation features by channel.
        layer6 = torch.cat([layer3, layer5], dim=1)
        layer7 = self.layer7(layer6)
        layer8 = self.layer8(layer7)

        return layer8, (layer2, layer3, layer4, layer5, layer7)

##Model Training

In [None]:
class GameMovieGAN(pl.LightningModule):
    """Cycle-consistent GAN translating between movie and game frames.

    Two generators map movie->game and game->movie, and one discriminator
    per domain scores images. Each generator is trained with a weighted sum
    of an adversarial loss and an MS-SSIM cycle-consistency loss. All four
    optimizers are stepped manually in ``training_step``.
    """

    def __init__(self, mov_2_game_G, game_2_mov_G, mov_D, game_D, learning_rate=1e-3):
        """Store the four networks, the loss functions and the loss weights.

        Args:
            mov_2_game_G: Generator mapping movie images to game images.
            game_2_mov_G: Generator mapping game images to movie images.
            mov_D: Discriminator for the movie domain.
            game_D: Discriminator for the game domain.
            learning_rate: Learning rate shared by all four Adam optimizers.
        """
        super().__init__()
        self.learning_rate = learning_rate
        # Optimizers are zeroed/stepped by hand in training_step.
        self.automatic_optimization = False
        self.mov_2_game_G = mov_2_game_G
        self.game_2_mov_G = game_2_mov_G
        self.mov_D = mov_D
        self.game_D = game_D
        # NOTE(review): data_range=255 assumes images in [0, 255], whereas a
        # Tanh-terminated generator emits [-1, 1] - confirm the data scaling.
        # NOTE(review): MS-SSIM with its default window presumably needs the
        # smaller image side to exceed 160 px - verify against the frame size.
        self.ms_ssim = MS_SSIM(data_range=255, size_average=True, channel=3)
        # The discriminators emit raw scores (no sigmoid), so the
        # logits-based BCE is required; plain BCELoss rejects such inputs.
        self.diss_loss = nn.BCEWithLogitsLoss()
        self.lambda_adversary = 0.65
        self.lambda_cycle = 0.35

    def configure_optimizers(self):
        """Return one Adam optimizer per network.

        The order (movie->game G, game->movie G, game D, movie D) is what
        ``training_step`` unpacks.
        """
        mov_2_game_G_opt = torch.optim.Adam(self.mov_2_game_G.parameters(), lr=self.learning_rate)
        game_2_mov_G_opt = torch.optim.Adam(self.game_2_mov_G.parameters(), lr=self.learning_rate)
        game_D_opt = torch.optim.Adam(self.game_D.parameters(), lr=self.learning_rate)
        mov_D_opt = torch.optim.Adam(self.mov_D.parameters(), lr=self.learning_rate)
        return mov_2_game_G_opt, game_2_mov_G_opt, game_D_opt, mov_D_opt

    @staticmethod
    def _scores(discriminator, images):
        """Return only the score tensor of a discriminator call.

        A discriminator may return ``(scores, features)``; a bare tensor is
        passed through unchanged.
        """
        output = discriminator(images)
        if isinstance(output, (tuple, list)):
            output = output[0]
        return output

    def _generator_loss(self, generator, inverse_generator, discriminator, source_img):
        """Weighted adversarial + cycle loss for translating ``source_img``."""
        generated_img = generator(source_img)
        cycle_img = inverse_generator(generated_img)

        # Adversarial term: the generator wants its output scored as real.
        scores = self._scores(discriminator, generated_img)
        adversary_loss = self.diss_loss(scores, torch.ones_like(scores)) * self.lambda_adversary
        # Cycle term: translating there and back should reproduce the input.
        cycle_loss = (1 - self.ms_ssim(source_img, cycle_img)) * self.lambda_cycle
        return adversary_loss + cycle_loss

    def _discriminator_loss(self, discriminator, real_img, fake_img):
        """BCE of real images against ones plus fake images against zeros."""
        real_scores = self._scores(discriminator, real_img)
        # Detach so the discriminator update does not backpropagate into
        # the generator that produced the fakes.
        fake_scores = self._scores(discriminator, fake_img.detach())
        real_error = self.diss_loss(real_scores, torch.ones_like(real_scores))
        fake_error = self.diss_loss(fake_scores, torch.zeros_like(fake_scores))
        return real_error + fake_error

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step for all four networks.

        ``batch`` is unpacked as ``(game_images, movie_images)``; the exact
        loader format is an assumption carried over from the original code
        (TODO confirm).
        """
        mov_2_game_G_opt, game_2_mov_G_opt, game_D_opt, mov_D_opt = self.optimizers()
        true_game_img, true_movie_img = batch

        # Generator Training
        ## mov_2_game train ##
        total_loss = self._generator_loss(
            self.mov_2_game_G, self.game_2_mov_G, self.game_D, true_movie_img)
        self.log('mov_2_game_generator_loss', total_loss)
        mov_2_game_G_opt.zero_grad()
        self.manual_backward(total_loss)
        mov_2_game_G_opt.step()

        ## game_2_mov train ##
        # Generated movie images are judged by the movie discriminator.
        total_loss = self._generator_loss(
            self.game_2_mov_G, self.mov_2_game_G, self.mov_D, true_game_img)
        self.log('game_2_mov_generator_loss', total_loss)
        game_2_mov_G_opt.zero_grad()
        self.manual_backward(total_loss)
        game_2_mov_G_opt.step()

        # Discriminator Training
        ## Game discriminator
        total_error = self._discriminator_loss(
            self.game_D, true_game_img, self.mov_2_game_G(true_movie_img))
        self.log('game_discriminator_loss', total_error)
        game_D_opt.zero_grad()
        self.manual_backward(total_error)
        game_D_opt.step()

        ## Movie discriminator
        total_error = self._discriminator_loss(
            self.mov_D, true_movie_img, self.game_2_mov_G(true_game_img))
        self.log('movie_discriminator_loss', total_error)
        mov_D_opt.zero_grad()
        self.manual_backward(total_error)
        mov_D_opt.step()

    def validation_step(self, val_batch, batch_idx):
        """Log both generator losses on a validation batch (no updates)."""
        true_game_img, true_movie_img = val_batch

        ## mov_2_game val ##
        total_loss = self._generator_loss(
            self.mov_2_game_G, self.game_2_mov_G, self.game_D, true_movie_img)
        self.log('mov_2_game_generator_val_loss', total_loss)

        ## game_2_mov val ##
        total_loss = self._generator_loss(
            self.game_2_mov_G, self.mov_2_game_G, self.mov_D, true_game_img)
        self.log('game_2_mov_generator_val_loss', total_loss)