In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.dataset import LoadFilesDataset
from utils.model import construct_style_loss_model,construct_decoder_from_encoder
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN
from utils.utility import video_to_frame_generator,video_to_frames,normalize
import cv2
import os
from copy import deepcopy
import torch.utils.data as data
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("On device: ", device)

On device:  cuda


In [3]:
# keep 16:9 ratio
IMAGE_SIZE = (512,288)
CONTENT_VIDEO_PATH = "./content_video/content_video.mp4"
STYLE_VIDEO_PATH = "./style_video/style_video.mp4"

In [4]:
# load a generator that returns each video frame
content_frame_generator = video_to_frame_generator(CONTENT_VIDEO_PATH,IMAGE_SIZE)
style_frame_generator = video_to_frame_generator(STYLE_VIDEO_PATH,IMAGE_SIZE)

In [4]:
CONTENT_FRAME_SAVE_PATH = "./content_frames/"
STYLE_FRAME_SAVE_PATH = "./style_frames/"

In [6]:
# save the single frames in a path to load later
video_to_frames(content_frame_generator, CONTENT_FRAME_SAVE_PATH)
video_to_frames(style_frame_generator, STYLE_FRAME_SAVE_PATH)

KeyboardInterrupt: 

In [5]:
CONTENT_LAYERS = [6]
STYLE_LAYERS = [9,12,14]
CONTENT_LAYERS_WEIGHTS = [1.0]
STYLE_LAYERS_WEIGHTS = [1.0,1.0,1.0]

In [6]:
# load standard vgg19 model
vgg = vgg19(VGG19_Weights.DEFAULT)

# remove classification head
vgg = vgg.features

# prepend a normalization layer
vgg = nn.Sequential(Normalize(mean = (0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)), *vgg)

style_loss_model = construct_style_loss_model(vgg,CONTENT_LAYERS,STYLE_LAYERS)
style_loss_model = style_loss_model.eval()
style_loss_model.requires_grad_(False)
style_loss_model = style_loss_model.to(device)

# as an encoder for our feedforward model we will just use the already existing vgg layers up to some point
vgg_encoder = vgg[:8]

# based on that we build a decoder that reverses our encoder and matches the shapes through interpolation
decoder = construct_decoder_from_encoder(vgg_encoder.cpu(),3,*IMAGE_SIZE)



range(0, 15)


In [7]:
BATCH_SIZE = 12

In [8]:
content_dataloader =data.DataLoader(LoadFilesDataset(CONTENT_FRAME_SAVE_PATH,batch_size= BATCH_SIZE),batch_size=BATCH_SIZE)
style_dataloader = data.DataLoader(LoadFilesDataset(STYLE_FRAME_SAVE_PATH,batch_size= BATCH_SIZE),batch_size=BATCH_SIZE)

In [9]:
style_model = nn.Sequential(vgg_encoder,decoder).to(device)
optimimizer = optim.Adam(decoder.parameters(),lr=0.00001)

In [10]:
EPOCHS = 10000
ACCUM_GRAD = 1
PRINT_STATS_EVERY = 500
STYLE_WEIGHT = 1000000.0
LOSS_CONTENT = content_gatyes
LOSS_STYLE = style_gatyes
SAVE_PATH = "./feedforward_save/"

In [12]:
os.makedirs(SAVE_PATH,exist_ok = True)

content_loss_aggregator = []
style_loss_aggregator = []

for epoch in tqdm(range(EPOCHS)):

    optimimizer.zero_grad()

    for _ in range(ACCUM_GRAD):
        content_img = next(iter(content_dataloader)).to(device)
        style_img = next(iter(style_dataloader)).to(device)

        with torch.no_grad():
            _,content_features_target,_ = style_loss_model((content_img,[],[]))
            _,_,style_features_target = style_loss_model((style_img,[],[]))
        

        prediction = style_model(content_img)

        _,content_features,style_features= style_loss_model((prediction,[],[]))

        content_loss = 0.0
        for f,f_target,weight in zip(content_features,content_features_target, CONTENT_LAYERS_WEIGHTS):
            content_loss += weight*LOSS_CONTENT(normalize(f,f_target)).mean()

        style_loss = 0.0
        for f,f_target,weight in zip(style_features,style_features_target, STYLE_LAYERS_WEIGHTS):
            style_loss += weight*LOSS_STYLE(normalize(f,f_target)).mean()

        style_loss *= STYLE_WEIGHT

        content_loss_aggregator.append(content_loss.detach().cpu())
        style_loss_aggregator.append(style_loss.detach().cpu())

        loss = content_loss+style_loss

        loss.backward()

    optimimizer.step()

    if epoch%PRINT_STATS_EVERY == 0 and not epoch == 0:
        print("Ending epoch: ", str(epoch), " with content loss: ", torch.stack(content_loss_aggregator).mean().numpy().item(),  " and style loss: ", torch.stack(style_loss_aggregator).mean().numpy().item())
        img = next(iter(content_dataloader))[0]
        ToPILImage()(img.permute(0,2,1)).save(SAVE_PATH + "example_input_model.jpg")
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():   
            img = style_model(img)
        img = ToPILImage()(img.cpu().squeeze(0).permute(0,2,1))
        img.save(SAVE_PATH + "example_output_model.jpg")
        torch.save(style_model.state_dict(), SAVE_PATH + "model_weights.pth")
        torch.save(optimimizer.state_dict(), SAVE_PATH + "optim_weights.pth")


    
    

  5%|▌         | 501/10000 [04:10<1:21:32,  1.94it/s]

Ending epoch:  500  with content loss:  1.7328954935073853  and style loss:  5.0672783851623535


 10%|█         | 1001/10000 [07:57<1:16:14,  1.97it/s]

Ending epoch:  1000  with content loss:  1.673820972442627  and style loss:  4.473562717437744


 15%|█▌        | 1501/10000 [11:43<1:11:41,  1.98it/s]

Ending epoch:  1500  with content loss:  1.6176385879516602  and style loss:  4.264816761016846


 20%|██        | 2001/10000 [15:30<1:07:42,  1.97it/s]

Ending epoch:  2000  with content loss:  1.5646336078643799  and style loss:  4.079663276672363


 25%|██▌       | 2501/10000 [19:15<1:02:57,  1.98it/s]

Ending epoch:  2500  with content loss:  1.5181852579116821  and style loss:  4.000204563140869


 30%|███       | 3001/10000 [23:00<58:28,  1.99it/s]  

Ending epoch:  3000  with content loss:  1.4766038656234741  and style loss:  3.906033515930176


 35%|███▌      | 3501/10000 [26:46<54:04,  2.00it/s]

Ending epoch:  3500  with content loss:  1.4398655891418457  and style loss:  3.8338711261749268


 40%|████      | 4001/10000 [30:36<50:15,  1.99it/s]

Ending epoch:  4000  with content loss:  1.4078761339187622  and style loss:  3.775028705596924


 45%|████▌     | 4501/10000 [34:25<47:53,  1.91it/s]

Ending epoch:  4500  with content loss:  1.379982590675354  and style loss:  3.7367610931396484


 50%|█████     | 5001/10000 [38:17<41:11,  2.02it/s]

Ending epoch:  5000  with content loss:  1.3554295301437378  and style loss:  3.7128803730010986


 55%|█████▌    | 5501/10000 [42:05<37:42,  1.99it/s]

Ending epoch:  5500  with content loss:  1.3334014415740967  and style loss:  3.6720876693725586


 56%|█████▌    | 5559/10000 [42:31<33:58,  2.18it/s]


KeyboardInterrupt: 

In [None]:
style_model.load_state_dict(torch.load(SAVE_PATH + "model_weights.pth",map_location=device))
optimimizer.load_state_dict(torch.load(SAVE_PATH + "optim_weights.pth",map_location=device))