In [18]:
# INFERENCE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.dataset import LoadCoCoDataset
from utils.model import construct_style_loss_model,construct_decoder_from_encoder,AdaIN
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN
from utils.utility import video_to_frame_generator,video_to_frames,normalize
import cv2
import os
from copy import deepcopy
import torch.utils.data as data
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image,ImageOps
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# INFERENCE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("On device: ", device)

On device:  cuda


In [3]:
# INFERENCE

IMAGE_SIZE = (256,256)

In [4]:
# INFERENCE

CONTENT_LAYERS = [6]
STYLE_LAYERS = [9,12,14]
CONTENT_LAYERS_WEIGHTS = [1.0]
STYLE_LAYERS_WEIGHTS = [1.0,1.0,1.0]

In [5]:
# INFERENCE

# load standard vgg19 model
vgg = vgg19(VGG19_Weights.DEFAULT)

# remove classification head
vgg = vgg.features

# prepend a normalization layer
vgg = nn.Sequential(Normalize(mean = (0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)), *vgg)

style_loss_model = construct_style_loss_model(vgg,CONTENT_LAYERS,STYLE_LAYERS)
style_loss_model = style_loss_model.eval()
style_loss_model.requires_grad_(False)
style_loss_model = style_loss_model.to(device)

# as an encoder for our feedforward model we will just use the already existing vgg layers up to some point
vgg_encoder = vgg[:8]

# based on that we build a decoder that reverses our encoder and matches the shapes through interpolation
decoder = construct_decoder_from_encoder(vgg_encoder.cpu(),3,*IMAGE_SIZE)



range(0, 15)


In [6]:
# INFERENCE

class StyleModel(nn.Module):
    def __init__(self, encoder, decoder) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.adain = AdaIN()

    def forward(self, image, image_style, alpha = 1.0):
        # alpha creates a weighted average between image and style_image features
        # note that this should not be done during training, but only during inference when controlling the style level
        f = self.encoder(image)
        f_style = self.encoder(image_style)

        
        f = (1-alpha)*f + alpha*self.adain(f,f_style)

        image = self.decoder(f)

        return image

In [7]:
# INFERENCE

style_model = StyleModel(vgg_encoder,decoder).to(device)
optimimizer = optim.Adam(decoder.parameters(),lr=0.001)

In [8]:
COCO_PATH = "./test2017/"
STYLE_IMAGE_PATH = "./wave.jpg"
BATCH_SIZE = 12

In [9]:
coco_dataloader = torch.utils.data.DataLoader(LoadCoCoDataset(COCO_PATH, BATCH_SIZE ,IMAGE_SIZE),batch_size=BATCH_SIZE)
style_img = Image.open(STYLE_IMAGE_PATH).convert('RGB')
# center crop image
style_img = ImageOps.fit(style_img,(min(style_img.size),min(style_img.size))).resize(IMAGE_SIZE)
style_img = ToTensor()(style_img).permute(0,2,1).unsqueeze(0).repeat(BATCH_SIZE,1,1,1).to(device)

In [10]:
# INFERENCE

EPOCHS = 10000
ACCUM_GRAD = 1
PRINT_STATS_EVERY = 500
STYLE_WEIGHT = 10.0
LOSS_CONTENT = content_gatyes
LOSS_STYLE = adaIN
SAVE_PATH = "./adain_save/"

In [11]:
os.makedirs(SAVE_PATH,exist_ok = True)

content_loss_aggregator = []
style_loss_aggregator = []

for epoch in tqdm(range(EPOCHS)):
    optimimizer.zero_grad()

    for _ in range(ACCUM_GRAD):
        content_img = next(iter(coco_dataloader)).to(device)

        with torch.no_grad():
            _,content_features_target,_ = style_loss_model((content_img,[],[]))
            _,_,style_features_target = style_loss_model((style_img,[],[]))
        

        prediction = style_model(content_img,style_img)

        _,content_features,style_features= style_loss_model((prediction,[],[]))

        content_loss = 0.0
        for f,f_target,weight in zip(content_features,content_features_target, CONTENT_LAYERS_WEIGHTS):
            content_loss += weight*LOSS_CONTENT(*normalize(f,f_target)).mean()

        style_loss = 0.0
        for f,f_target,weight in zip(style_features,style_features_target, STYLE_LAYERS_WEIGHTS):
            style_loss += weight*LOSS_STYLE(*normalize(f,f_target)).mean()

        style_loss *= STYLE_WEIGHT

        content_loss_aggregator.append(content_loss.detach().cpu())
        style_loss_aggregator.append(style_loss.detach().cpu())

        loss = content_loss+style_loss

        loss.backward()

    optimimizer.step()  

    if epoch%PRINT_STATS_EVERY == 0 and not epoch == 0:
        print("Ending epoch: ", str(epoch), " with content loss: ", torch.stack(content_loss_aggregator).mean().numpy().item(),  " and style loss: ", torch.stack(style_loss_aggregator).mean().numpy().item())
        ToPILImage()(content_img[0].permute(0,2,1)).save(SAVE_PATH + "example_input_model.jpg")
        with torch.no_grad():   
            img = style_model(content_img[0].unsqueeze(0), style_img[0].unsqueeze(0))
        img = ToPILImage()(img.cpu().squeeze(0).permute(0,2,1))
        img.save(SAVE_PATH + "example_output_model.jpg")
        torch.save(style_model.state_dict(), SAVE_PATH + "model_weights.pth")
        torch.save(optimimizer.state_dict(), SAVE_PATH + "optim_weights.pth")


    
    

  5%|▌         | 501/10000 [01:55<40:31,  3.91it/s]

Ending epoch:  500  with content loss:  1.6700294017791748  and style loss:  1.5538650751113892


 10%|█         | 1001/10000 [03:49<37:04,  4.04it/s]

Ending epoch:  1000  with content loss:  1.6408993005752563  and style loss:  1.0866953134536743


 15%|█▌        | 1501/10000 [05:43<36:48,  3.85it/s]

Ending epoch:  1500  with content loss:  1.619925618171692  and style loss:  0.8881847858428955


 20%|██        | 2001/10000 [07:37<34:04,  3.91it/s]

Ending epoch:  2000  with content loss:  1.6041851043701172  and style loss:  0.7736931443214417


 25%|██▌       | 2501/10000 [09:33<32:07,  3.89it/s]

Ending epoch:  2500  with content loss:  1.5910303592681885  and style loss:  0.6981415748596191


 30%|███       | 3001/10000 [11:27<28:39,  4.07it/s]

Ending epoch:  3000  with content loss:  1.579669713973999  and style loss:  0.6427444815635681


 35%|███▌      | 3501/10000 [13:19<26:41,  4.06it/s]

Ending epoch:  3500  with content loss:  1.5697897672653198  and style loss:  0.6003834009170532


 40%|████      | 4001/10000 [15:14<25:08,  3.98it/s]

Ending epoch:  4000  with content loss:  1.5606598854064941  and style loss:  0.5665675401687622


 45%|████▌     | 4501/10000 [17:06<23:16,  3.94it/s]

Ending epoch:  4500  with content loss:  1.5525904893875122  and style loss:  0.5390615463256836


 50%|█████     | 5001/10000 [18:57<20:53,  3.99it/s]

Ending epoch:  5000  with content loss:  1.545293927192688  and style loss:  0.5163284540176392


 55%|█████▌    | 5501/10000 [20:47<18:18,  4.10it/s]

Ending epoch:  5500  with content loss:  1.538694977760315  and style loss:  0.4970754086971283


 60%|██████    | 6001/10000 [22:35<16:06,  4.14it/s]

Ending epoch:  6000  with content loss:  1.5325473546981812  and style loss:  0.4804943799972534


 65%|██████▌   | 6501/10000 [24:22<13:40,  4.27it/s]

Ending epoch:  6500  with content loss:  1.5268853902816772  and style loss:  0.46600043773651123


 70%|███████   | 7001/10000 [26:09<12:17,  4.06it/s]

Ending epoch:  7000  with content loss:  1.5217007398605347  and style loss:  0.4534367024898529


 75%|███████▌  | 7501/10000 [27:55<09:45,  4.27it/s]

Ending epoch:  7500  with content loss:  1.5168037414550781  and style loss:  0.4419664740562439


 80%|████████  | 8001/10000 [29:40<07:51,  4.24it/s]

Ending epoch:  8000  with content loss:  1.5123587846755981  and style loss:  0.4318700432777405


 85%|████████▌ | 8501/10000 [31:25<05:59,  4.17it/s]

Ending epoch:  8500  with content loss:  1.5082148313522339  and style loss:  0.422717809677124


 90%|█████████ | 9001/10000 [33:12<03:53,  4.27it/s]

Ending epoch:  9000  with content loss:  1.5042678117752075  and style loss:  0.4143603444099426


 95%|█████████▌| 9501/10000 [35:08<02:05,  3.98it/s]

Ending epoch:  9500  with content loss:  1.5007281303405762  and style loss:  0.40696725249290466


100%|██████████| 10000/10000 [36:59<00:00,  4.51it/s]


In [None]:
# INFERENCE

style_model.load_state_dict(torch.load(SAVE_PATH + "model_weights.pth",map_location=device))

In [12]:
# INFERENCE

CONTENT_IMAGE_PATH = "./dragon.jpg"
STYLE_IMAGE_PATH = "./wave.jpg"

In [20]:
# INFERENCE

content_img = Image.open(CONTENT_IMAGE_PATH).convert('RGB')
# center crop image
content_img = ImageOps.fit(content_img,(min(content_img.size),min(content_img.size))).resize(IMAGE_SIZE)
content_img = ToTensor()(content_img).permute(0,2,1).to(device)

style_img = Image.open(STYLE_IMAGE_PATH).convert('RGB')
# center crop image
style_img = ImageOps.fit(style_img,(min(style_img.size),min(style_img.size))).resize(IMAGE_SIZE)
style_img = ToTensor()(style_img).permute(0,2,1).to(device)

In [None]:
# INFERENCE

# controlls the "amount of style transfer"
ALPHA = 0.7

In [21]:
# INFERENCE

content_style_image = style_model(content_img.unsqueeze(0),style_img.unsqueeze(0), ALPHA).squeeze(0).cpu()
content_style_image = ToPILImage()(content_style_image.permute(0,2,1))
content_style_image.show()