In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.dataset import LoadFilesDataset
from utils.model import construct_style_loss_model,construct_decoder_from_encoder,AdaIN
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN
from utils.utility import video_to_frame_generator,video_to_frames,normalize
import cv2
import os
from copy import deepcopy
import torch.utils.data as data
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("On device: ", device)

On device:  cuda


In [14]:
# keep 16:9 ratio
IMAGE_SIZE = (512,288)
CONTENT_VIDEO_PATH = "./content_video/content_video.mp4"
STYLE_VIDEO_PATH = "./style_video/style_video.mp4"

In [None]:
# load a generator that returns each video frame
content_frame_generator = video_to_frame_generator(CONTENT_VIDEO_PATH,IMAGE_SIZE)
style_frame_generator = video_to_frame_generator(STYLE_VIDEO_PATH,IMAGE_SIZE)

In [15]:
CONTENT_FRAME_SAVE_PATH = "./content_frames/"
STYLE_FRAME_SAVE_PATH = "./style_frames/"

In [None]:
# save the single frames in a path to load later
video_to_frames(content_frame_generator, CONTENT_FRAME_SAVE_PATH)
video_to_frames(style_frame_generator, STYLE_FRAME_SAVE_PATH)

In [16]:
CONTENT_LAYERS = [6]
STYLE_LAYERS = [9,12,14]
CONTENT_LAYERS_WEIGHTS = [1.0]
STYLE_LAYERS_WEIGHTS = [1.0,1.0,1.0]

In [17]:
# load standard vgg19 model
vgg = vgg19(VGG19_Weights.DEFAULT)

# remove classification head
vgg = vgg.features

# prepend a normalization layer
vgg = nn.Sequential(Normalize(mean = (0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)), *vgg)

style_loss_model = construct_style_loss_model(vgg,CONTENT_LAYERS,STYLE_LAYERS)
style_loss_model = style_loss_model.eval()
style_loss_model.requires_grad_(False)
style_loss_model = style_loss_model.to(device)

# as an encoder for our feedforward model we will just use the already existing vgg layers up to some point
vgg_encoder = vgg[:8]

# based on that we build a decoder that reverses our encoder and matches the shapes through interpolation
decoder = construct_decoder_from_encoder(vgg_encoder.cpu(),3,*IMAGE_SIZE)

range(0, 15)


In [18]:
BATCH_SIZE = 12

In [19]:
content_dataloader =data.DataLoader(LoadFilesDataset(CONTENT_FRAME_SAVE_PATH,batch_size= BATCH_SIZE),batch_size=BATCH_SIZE)
style_dataloader = data.DataLoader(LoadFilesDataset(STYLE_FRAME_SAVE_PATH,batch_size= BATCH_SIZE),batch_size=BATCH_SIZE)

In [20]:
class StyleModel(nn.Module):
    def __init__(self, encoder, decoder) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.adain = AdaIN()

    def forward(self, image, image_style):
        f = self.encoder(image)
        f_style = self.encoder(image_style)

        f = self.adain(f,f_style)

        image = self.decoder(f)

        return image

In [21]:
style_model = StyleModel(vgg_encoder,decoder).to(device)
optimimizer = optim.Adam(decoder.parameters(),lr=0.001)

In [22]:
EPOCHS = 10000
ACCUM_GRAD = 1
PRINT_STATS_EVERY = 100
STYLE_WEIGHT = 1.0
LOSS_CONTENT = content_gatyes
LOSS_STYLE = adaIN
SAVE_PATH = "./adain_save/"

In [23]:
os.makedirs(SAVE_PATH,exist_ok = True)

content_loss_aggregator = []
style_loss_aggregator = []

for epoch in tqdm(range(EPOCHS)):
    optimimizer.zero_grad()

    for _ in range(ACCUM_GRAD):
        content_img = next(iter(content_dataloader)).to(device)
        style_img = next(iter(style_dataloader)).to(device)

        with torch.no_grad():
            _,content_features_target,_ = style_loss_model((content_img,[],[]))
            _,_,style_features_target = style_loss_model((style_img,[],[]))
        

        prediction = style_model(content_img,style_img)

        _,content_features,style_features= style_loss_model((prediction,[],[]))

        content_loss = 0.0
        for f,f_target,weight in zip(content_features,content_features_target, CONTENT_LAYERS_WEIGHTS):
            content_loss += weight*LOSS_CONTENT(normalize_cw(f,f_target)).mean()

        style_loss = 0.0
        for f,f_target,weight in zip(style_features,style_features_target, STYLE_LAYERS_WEIGHTS):
            style_loss += weight*LOSS_STYLE(normalize_cw(f,f_target)).mean()

        style_loss *= STYLE_WEIGHT

        content_loss_aggregator.append(content_loss.detach().cpu())
        style_loss_aggregator.append(style_loss.detach().cpu())

        loss = content_loss+style_loss

        loss.backward()

    optimimizer.step()  

    if epoch%PRINT_STATS_EVERY == 0 and not epoch == 0:
        print("Ending epoch: ", str(epoch), " with content loss: ", torch.stack(content_loss_aggregator).mean().numpy().item(),  " and style loss: ", torch.stack(style_loss_aggregator).mean().numpy().item())
        img = content_img[0]
        style_img = style_img[0]

        with torch.no_grad():   
            img = style_model(img.unsqueeze(0), style_img.unsqueeze(0))
        img = ToPILImage()(img.cpu().squeeze(0).permute(0,2,1))
        img.save(SAVE_PATH + "example_output_model.jpg")
        torch.save(style_model.state_dict(), SAVE_PATH + "model_weights.pth")
        torch.save(optimimizer.state_dict(), SAVE_PATH + "optim_weights.pth")


    
    

  1%|          | 100/10000 [00:56<1:25:28,  1.93it/s]

Ending epoch:  100  with content loss:  1.2336143255233765  and style loss:  0.5468908548355103


  2%|▏         | 201/10000 [01:46<1:29:39,  1.82it/s]

Ending epoch:  200  with content loss:  1.0675320625305176  and style loss:  0.4527081251144409


  3%|▎         | 301/10000 [02:38<1:31:29,  1.77it/s]

Ending epoch:  300  with content loss:  0.9881121516227722  and style loss:  0.4058834910392761


  4%|▍         | 401/10000 [03:29<1:38:03,  1.63it/s]

Ending epoch:  400  with content loss:  0.9392931461334229  and style loss:  0.3795006275177002


  5%|▌         | 500/10000 [04:22<1:25:19,  1.86it/s]

Ending epoch:  500  with content loss:  0.9066312313079834  and style loss:  0.36139097809791565


  6%|▌         | 601/10000 [05:17<1:32:09,  1.70it/s]

Ending epoch:  600  with content loss:  0.8845729827880859  and style loss:  0.3493477702140808


  7%|▋         | 701/10000 [06:11<1:32:54,  1.67it/s]

Ending epoch:  700  with content loss:  0.8648970127105713  and style loss:  0.33842453360557556


  8%|▊         | 801/10000 [07:05<1:32:52,  1.65it/s]

Ending epoch:  800  with content loss:  0.8515757322311401  and style loss:  0.3301888108253479


  9%|▉         | 901/10000 [07:57<1:16:39,  1.98it/s]

Ending epoch:  900  with content loss:  0.8387537598609924  and style loss:  0.3235618472099304


 10%|█         | 1001/10000 [08:43<1:16:54,  1.95it/s]

Ending epoch:  1000  with content loss:  0.8281869292259216  and style loss:  0.31743574142456055


 11%|█         | 1101/10000 [09:29<1:15:09,  1.97it/s]

Ending epoch:  1100  with content loss:  0.8191487193107605  and style loss:  0.31222712993621826


 12%|█▏        | 1201/10000 [10:15<1:14:10,  1.98it/s]

Ending epoch:  1200  with content loss:  0.8118916749954224  and style loss:  0.3073665201663971


 13%|█▎        | 1301/10000 [11:01<1:14:13,  1.95it/s]

Ending epoch:  1300  with content loss:  0.8046957850456238  and style loss:  0.30292466282844543


 14%|█▍        | 1401/10000 [11:47<1:12:55,  1.97it/s]

Ending epoch:  1400  with content loss:  0.7983420491218567  and style loss:  0.2994726598262787


 15%|█▌        | 1501/10000 [12:33<1:11:54,  1.97it/s]

Ending epoch:  1500  with content loss:  0.792836606502533  and style loss:  0.2962850332260132


 16%|█▌        | 1601/10000 [13:20<1:11:15,  1.96it/s]

Ending epoch:  1600  with content loss:  0.7878202795982361  and style loss:  0.2933942973613739


 17%|█▋        | 1701/10000 [14:07<1:11:04,  1.95it/s]

Ending epoch:  1700  with content loss:  0.7837140560150146  and style loss:  0.29135265946388245


 18%|█▊        | 1801/10000 [14:53<1:10:43,  1.93it/s]

Ending epoch:  1800  with content loss:  0.7795186638832092  and style loss:  0.2890585660934448


 19%|█▉        | 1901/10000 [15:40<1:09:34,  1.94it/s]

Ending epoch:  1900  with content loss:  0.7760916352272034  and style loss:  0.2869320511817932


 20%|██        | 2001/10000 [16:26<1:08:08,  1.96it/s]

Ending epoch:  2000  with content loss:  0.7723482847213745  and style loss:  0.285053551197052


 21%|██        | 2101/10000 [17:12<1:07:37,  1.95it/s]

Ending epoch:  2100  with content loss:  0.7685418725013733  and style loss:  0.2831972539424896


 22%|██▏       | 2201/10000 [17:59<1:07:00,  1.94it/s]

Ending epoch:  2200  with content loss:  0.765198290348053  and style loss:  0.28146564960479736


 23%|██▎       | 2301/10000 [18:45<1:05:45,  1.95it/s]

Ending epoch:  2300  with content loss:  0.7620775103569031  and style loss:  0.2795625627040863


 24%|██▍       | 2401/10000 [19:31<1:05:04,  1.95it/s]

Ending epoch:  2400  with content loss:  0.7587950229644775  and style loss:  0.2780008316040039


 25%|██▌       | 2501/10000 [20:18<1:10:54,  1.76it/s]

Ending epoch:  2500  with content loss:  0.756460964679718  and style loss:  0.276410847902298


 26%|██▌       | 2601/10000 [21:08<1:03:27,  1.94it/s]

Ending epoch:  2600  with content loss:  0.7541303038597107  and style loss:  0.2750207483768463


 27%|██▋       | 2701/10000 [21:55<1:01:43,  1.97it/s]

Ending epoch:  2700  with content loss:  0.7516283988952637  and style loss:  0.2736102342605591


 28%|██▊       | 2801/10000 [22:42<1:02:24,  1.92it/s]

Ending epoch:  2800  with content loss:  0.7492120265960693  and style loss:  0.27239158749580383


 29%|██▉       | 2901/10000 [23:30<1:02:16,  1.90it/s]

Ending epoch:  2900  with content loss:  0.7471269369125366  and style loss:  0.2711432874202728


 30%|███       | 3001/10000 [24:17<1:00:32,  1.93it/s]

Ending epoch:  3000  with content loss:  0.7448714971542358  and style loss:  0.27039122581481934


 31%|███       | 3101/10000 [25:06<1:03:03,  1.82it/s]

Ending epoch:  3100  with content loss:  0.7428904175758362  and style loss:  0.2693850100040436


 32%|███▏      | 3201/10000 [25:54<59:15,  1.91it/s]  

Ending epoch:  3200  with content loss:  0.7408581376075745  and style loss:  0.26829618215560913


 33%|███▎      | 3301/10000 [26:43<56:36,  1.97it/s]  

Ending epoch:  3300  with content loss:  0.7388423085212708  and style loss:  0.2673088312149048


 34%|███▍      | 3401/10000 [27:30<55:51,  1.97it/s]

Ending epoch:  3400  with content loss:  0.7370851635932922  and style loss:  0.26648199558258057


 35%|███▌      | 3501/10000 [28:17<55:10,  1.96it/s]

Ending epoch:  3500  with content loss:  0.7351782917976379  and style loss:  0.26535189151763916


 36%|███▌      | 3601/10000 [29:04<55:35,  1.92it/s]

Ending epoch:  3600  with content loss:  0.7332947254180908  and style loss:  0.26451677083969116


 37%|███▋      | 3701/10000 [29:53<56:59,  1.84it/s]

Ending epoch:  3700  with content loss:  0.7317313551902771  and style loss:  0.2637879550457001


 38%|███▊      | 3801/10000 [30:42<59:00,  1.75it/s]

Ending epoch:  3800  with content loss:  0.7303416132926941  and style loss:  0.26313719153404236


 39%|███▉      | 3901/10000 [31:30<54:12,  1.88it/s]

Ending epoch:  3900  with content loss:  0.7284301519393921  and style loss:  0.2624848186969757


 40%|████      | 4001/10000 [32:18<51:43,  1.93it/s]

Ending epoch:  4000  with content loss:  0.7267278432846069  and style loss:  0.26159051060676575


 41%|████      | 4101/10000 [33:05<50:19,  1.95it/s]

Ending epoch:  4100  with content loss:  0.7254154682159424  and style loss:  0.26081231236457825


 42%|████▏     | 4201/10000 [33:51<49:15,  1.96it/s]

Ending epoch:  4200  with content loss:  0.7240453362464905  and style loss:  0.2600870728492737


 43%|████▎     | 4301/10000 [34:37<48:09,  1.97it/s]

Ending epoch:  4300  with content loss:  0.7228165864944458  and style loss:  0.2595096230506897


 44%|████▍     | 4401/10000 [35:23<47:18,  1.97it/s]

Ending epoch:  4400  with content loss:  0.7217163443565369  and style loss:  0.25893253087997437


 45%|████▌     | 4501/10000 [36:10<48:04,  1.91it/s]

Ending epoch:  4500  with content loss:  0.7205666899681091  and style loss:  0.2582501471042633


 46%|████▌     | 4601/10000 [36:57<46:50,  1.92it/s]

Ending epoch:  4600  with content loss:  0.7192398905754089  and style loss:  0.25767838954925537


 47%|████▋     | 4701/10000 [37:44<45:28,  1.94it/s]

Ending epoch:  4700  with content loss:  0.7179045677185059  and style loss:  0.25709861516952515


 48%|████▊     | 4801/10000 [38:31<44:58,  1.93it/s]

Ending epoch:  4800  with content loss:  0.716627836227417  and style loss:  0.2564457654953003


 49%|████▉     | 4901/10000 [39:18<45:28,  1.87it/s]

Ending epoch:  4900  with content loss:  0.7153955101966858  and style loss:  0.25591152906417847


 50%|█████     | 5001/10000 [40:06<42:44,  1.95it/s]

Ending epoch:  5000  with content loss:  0.7143073678016663  and style loss:  0.2553049921989441


 51%|█████     | 5101/10000 [40:53<41:45,  1.96it/s]

Ending epoch:  5100  with content loss:  0.7132276296615601  and style loss:  0.2548961341381073


 52%|█████▏    | 5201/10000 [41:41<41:55,  1.91it/s]

Ending epoch:  5200  with content loss:  0.7122882008552551  and style loss:  0.25434932112693787


 53%|█████▎    | 5301/10000 [42:28<41:46,  1.87it/s]

Ending epoch:  5300  with content loss:  0.7113151550292969  and style loss:  0.2540483772754669


 54%|█████▍    | 5401/10000 [43:15<39:21,  1.95it/s]

Ending epoch:  5400  with content loss:  0.7103047370910645  and style loss:  0.25360944867134094


 55%|█████▌    | 5501/10000 [44:02<39:31,  1.90it/s]

Ending epoch:  5500  with content loss:  0.7093793749809265  and style loss:  0.2531166970729828


 56%|█████▌    | 5601/10000 [44:52<42:17,  1.73it/s]

Ending epoch:  5600  with content loss:  0.7084925770759583  and style loss:  0.25257569551467896


 57%|█████▋    | 5701/10000 [45:44<40:30,  1.77it/s]

Ending epoch:  5700  with content loss:  0.7077062726020813  and style loss:  0.25227221846580505


 58%|█████▊    | 5801/10000 [46:36<40:50,  1.71it/s]

Ending epoch:  5800  with content loss:  0.7067771553993225  and style loss:  0.2519291043281555


 59%|█████▉    | 5901/10000 [47:29<40:56,  1.67it/s]

Ending epoch:  5900  with content loss:  0.7058200240135193  and style loss:  0.25149938464164734


 60%|██████    | 6001/10000 [48:23<39:04,  1.71it/s]

Ending epoch:  6000  with content loss:  0.7051131129264832  and style loss:  0.25117090344429016


 61%|██████    | 6088/10000 [49:07<31:34,  2.07it/s]


KeyboardInterrupt: 