In [1]:
# INFERENCE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.dataset import LoadCoCoDataset
from utils.model import construct_style_loss_model,construct_decoder_from_encoder,AdaIN,SANet
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN
from utils.utility import video_to_frame_generator,video_to_frames,normalize,normalize_cw
import cv2
import os
from copy import deepcopy
import torch.utils.data as data
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image,ImageOps
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

In [2]:
# INFERENCE

# Run on the GPU whenever CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("On device: ", device)

On device:  cuda


In [3]:
# INFERENCE

# Two-element size shared by the whole notebook: it is passed to the COCO dataset loader, used to resize the
# style/content images and used to build the dummy batch that probes the encoder output shapes.
IMAGE_SIZE = (256,256)

In [4]:
# INFERENCE

# Layer indices handed to construct_style_loss_model
# (presumably the layers whose activations are returned as style features - confirm in utils.model).
STYLE_LAYERS = [9,12,14]
# One weight per style feature map; each weight scales the matching term of the style loss in the training loop.
STYLE_LAYERS_WEIGHTS = [1.0,1.0,1.0]

In [5]:
# INFERENCE

# load standard vgg19 model with its default pretrained weights
# (passed by keyword: torchvision >= 0.13 deprecates passing `weights` positionally and warns about it)
vgg = vgg19(weights=VGG19_Weights.DEFAULT)

# remove classification head
vgg = vgg.features

# prepend a normalization layer
vgg = nn.Sequential(Normalize(mean = (0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)), *vgg)

# frozen feature extractor that is only used to evaluate the losses
style_loss_model = construct_style_loss_model(vgg,[],STYLE_LAYERS)
style_loss_model = style_loss_model.eval()
style_loss_model.requires_grad_(False)

# as encoders for our feedforward model we will just use the already existing vgg layers up to some point
vgg_encoder_r_4_1 = vgg[:10]
vgg_encoder_r_5_1 = vgg[10:13]

example_input = torch.rand(16,3,*IMAGE_SIZE)
# get the size of the respective encoder outputs (batch dimension stripped);
# this is a pure shape probe, so no autograd graph is recorded for it
with torch.no_grad():
    size_r_4_1 = vgg_encoder_r_4_1(example_input).size()[1:]
    size_r_5_1 = vgg_encoder_r_5_1(vgg_encoder_r_4_1(example_input)).size()[1:]

# Our Style Model will use the feature output of vgg_encoder_r_5_1 and vgg_encoder_r_4_1. But the vgg_encoder_r_5_1 will be upsampled to the size of vgg_encoder_r_4_1.
# After that a 3x3 convolution is applied and that is fed into our decoder. To make our decoder match that last 3x3 convolution size,
# we simply add it to the vgg_encoder_r_4_1 when constructing the decoder (NOT IN THE ACTUAL MODEL DIRECTLY AFTER THE ENCODER).
# This simulated encoder is then used to construct the decoder.
# NOTE(review): stride=3 makes this 3x3 convolution shrink the feature map about threefold; StyleModel.conv3 uses the
# same setting, so both must be changed together - confirm that stride=1 with padding=1 was not the intention.
simulated_encoder = nn.Sequential(*vgg[0:10],nn.Conv2d(size_r_4_1[0],size_r_4_1[0],kernel_size=3,stride=3))
# based on that we build a decoder that reverses our encoder and matches the shapes through interpolation
decoder = construct_decoder_from_encoder(simulated_encoder,3,*IMAGE_SIZE)



range(0, 15)


In [6]:
# INFERENCE

class Encoder(nn.Module):
    """
    Two-stage feature extractor that exposes both the intermediate and the final feature map.

    The two sub-encoders have to be consecutive slices of one network: whatever the first stage outputs is fed
    unchanged into the second stage.
    """

    def __init__(self, vgg_encoder_r_4_1, vgg_encoder_r_5_1):
        """
        Args:
            vgg_encoder_r_4_1: module producing the first (shallower) feature map from an image batch.
            vgg_encoder_r_5_1: module producing the second (deeper) feature map from the first one.
        """
        super().__init__()
        self.vgg_encoder_r_4_1 = vgg_encoder_r_4_1
        self.vgg_encoder_r_5_1 = vgg_encoder_r_5_1

    def forward(self, img):
        """Return the tuple (first-stage features, second-stage features) for `img`."""
        shallow_features = self.vgg_encoder_r_4_1(img)
        deep_features = self.vgg_encoder_r_5_1(shallow_features)
        return shallow_features, deep_features

In [7]:
# INFERENCE

class StyleModel(nn.Module):
    """
    Full style transfer network following the SANet paper.

    Content and style image are encoded at two depths, each depth is passed through its own SANet block and a
    1x1 convolution, the content features are added back, both depths are merged and the result is decoded.
    """

    def __init__(self, encoder, size_r_4_1, size_r_5_1,decoder) -> None:
        """
        Args:
            encoder: module returning the pair (r_4_1 features, r_5_1 features) for an image batch.
            size_r_4_1: size of one r_4_1 feature map without the batch dimension; entry 0 is the channel count.
            size_r_5_1: size of one r_5_1 feature map without the batch dimension; entry 0 is the channel count.
            decoder: module turning the merged feature map back into an image.
        """
        super().__init__()
        channels_r_4_1 = size_r_4_1[0]
        channels_r_5_1 = size_r_5_1[0]

        self.encoder = encoder
        # kept as the resampling target for the deeper feature map in forward()
        self.size_r_4_1 = size_r_4_1

        # one SANet per encoder depth, initialized with that depth's channel count
        self.SANet_r_4_1 = SANet(channels_r_4_1)
        self.SANet_r_5_1 = SANet(channels_r_5_1)
        self.decoder = decoder

        # 1x1 convolutions applied to the SANet outputs, 3x3 (stride 3) convolution applied to the merged features
        self.conv1 = nn.Conv2d(channels_r_4_1, channels_r_4_1, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(channels_r_5_1, channels_r_5_1, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(channels_r_4_1, channels_r_4_1, kernel_size=3, stride=3)

    def forward(self, img, img_style):
        """
        Stylize `img` with the style of `img_style`.

        Returns:
            Tuple (decoded image, r_4_1 content features, r_5_1 content features).
        """
        content_r_4_1, content_r_5_1 = self.encoder(img)
        style_r_4_1, style_r_5_1 = self.encoder(img_style)

        # attention output per depth, followed by the residual connection to the content features
        attended_r_4_1 = self.conv1(self.SANet_r_4_1(content_r_4_1, style_r_4_1))
        attended_r_5_1 = self.conv2(self.SANet_r_5_1(content_r_5_1, style_r_5_1))
        residual_r_4_1 = attended_r_4_1 + content_r_4_1
        residual_r_5_1 = attended_r_5_1 + content_r_5_1

        # A singleton axis is inserted so that interpolate treats the channel axis like a spatial one: the deeper
        # map is resampled (default nearest mode) to the full stored r_4_1 size, channel count included.
        # assumes size_r_4_1 is (channels, height, width) - TODO confirm against the caller
        lifted_r_5_1 = residual_r_5_1[:, None]
        resampled_r_5_1 = F.interpolate(lifted_r_5_1, size=self.size_r_4_1)[:, 0]

        merged = self.conv3(residual_r_4_1 + resampled_r_5_1)

        return self.decoder(merged), content_r_4_1, content_r_5_1


In [8]:
# INFERENCE

# Frozen two-depth feature extractor built from the VGG slices, moved to the active device.
encoder = Encoder(vgg_encoder_r_4_1, vgg_encoder_r_5_1)
encoder = encoder.to(device)
encoder.requires_grad_(False)

# Trainable style network; its embedded encoder stays frozen.
style_model = StyleModel(encoder, size_r_4_1, size_r_5_1, decoder)
style_model = style_model.to(device)
style_model.encoder.requires_grad_(False)

# The loss network has to live on the same device as the images it scores.
style_loss_model = style_loss_model.to(device)

# Created after the model was moved to the device. The variable keeps its original spelling because the
# training loop refers to it under this name.
optimimizer = optim.Adam(params=style_model.parameters(), lr=1e-4)

In [9]:
# Location handed to LoadCoCoDataset as the source of the content images.
COCO_PATH = "./test2017/"
# Style image the network is trained on.
STYLE_IMAGE_PATH = "./wave.jpg"
# Batch size of the content dataloader; the style image is repeated this many times to match it.
BATCH_SIZE = 1

In [10]:
# Content images: the COCO dataset wrapped in a dataloader.
coco_dataloader = torch.utils.data.DataLoader(
    LoadCoCoDataset(COCO_PATH, BATCH_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
)

# Style image: RGB, center-cropped to a square whose side is the shorter image edge, then resized.
style_img = Image.open(STYLE_IMAGE_PATH).convert('RGB')
style_img = ImageOps.fit(style_img, (min(style_img.size),) * 2).resize(IMAGE_SIZE)

# To tensor, swap the last two axes, add a batch axis and repeat it once per batch entry.
style_img = (
    ToTensor()(style_img)
    .permute(0, 2, 1)
    .unsqueeze(0)
    .repeat(BATCH_SIZE, 1, 1, 1)
    .to(device)
)

In [11]:
# INFERENCE

# Number of iterations of the outer training loop; each iteration performs exactly one optimizer step.
EPOCHS = 10000
# Number of batches whose gradients are accumulated before every optimizer step.
ACCUM_GRAD = 2
# Statistics, example images and checkpoints are written every this many iterations.
PRINT_STATS_EVERY = 100
# Multipliers of the individual loss terms that are summed into the total training loss.
CONTENT_WEIGHT = 25.0
STYLE_WEIGHT = 3.0
IDENTITY_1_WEIGHT = 50.0
IDENTITY_2_WEIGHT = 4000.0
# Style loss function (imported from utils.losses) applied to every pair of normalized style feature maps.
LOSS_STYLE = adaIN
# Directory that receives the example images and the model/optimizer checkpoints.
SAVE_PATH = "./SANet_save/"

In [12]:
os.makedirs(SAVE_PATH,exist_ok = True)


def _endless_batches(dataloader, batch_size):
    """
    Yield full-sized batches from `dataloader` forever.

    One iterator is kept alive across training steps so that consecutive steps walk through the dataset instead
    of re-reading the first batch of a freshly created iterator. A new pass is started whenever the loader is
    exhausted. Batches whose length differs from `batch_size` are skipped because the style tensor is built
    with exactly `batch_size` entries.

    Raises:
        RuntimeError: if a complete pass over the loader does not produce a single full batch.
    """
    while True:
        produced_batch = False
        for batch in dataloader:
            if len(batch) != batch_size:
                continue
            produced_batch = True
            yield batch
        if not produced_batch:
            raise RuntimeError("The content dataloader did not yield a single full batch.")


def _mean_l2(difference):
    """Return the L2 norm of every sample in `difference` (all non-batch axes flattened), averaged over the batch."""
    return torch.norm(difference.flatten(1), dim=-1).mean()


content_loss_aggregator = []
style_loss_aggregator = []

# The style image never changes, so its loss-network features are computed once instead of twice per batch.
# NOTE(review): assumes `normalize` does not modify its arguments in place - confirm in utils.utility.
with torch.no_grad():
    _,_,style_features_target = style_loss_model((style_img,[],[]))

content_batches = _endless_batches(coco_dataloader, BATCH_SIZE)

for epoch in tqdm(range(EPOCHS)):
    optimimizer.zero_grad()

    # gradients of ACCUM_GRAD batches are summed (not averaged) before the single optimizer step below
    for _ in range(ACCUM_GRAD):
        content_img = next(content_batches).to(device)

        prediction, f_c_r_4_1, f_c_r_5_1 = style_model(content_img,style_img)
        f_c_r_4_1, f_c_r_5_1 = normalize_cw(f_c_r_4_1, f_c_r_5_1)

        # content loss: re-encode the stylized image and compare it with the content features at both depths
        pred_cs_r_4_1,pred_cs_r_5_1 = style_model.encoder(prediction)
        pred_cs_r_4_1, pred_cs_r_5_1 = normalize_cw(pred_cs_r_4_1, pred_cs_r_5_1)

        content_loss = CONTENT_WEIGHT*(_mean_l2(pred_cs_r_4_1-f_c_r_4_1) + _mean_l2(pred_cs_r_5_1-f_c_r_5_1))

        # style loss: weighted sum over the loss-network feature maps of the stylized image
        _,_,style_features= style_loss_model((prediction,[],[]))

        style_loss = 0.0
        for f,f_target,weight in zip(style_features,style_features_target, STYLE_LAYERS_WEIGHTS):
            style_loss += weight*LOSS_STYLE(*normalize(f,f_target)).mean()

        style_loss *= STYLE_WEIGHT

        # identity losses: an image stylized with itself should be reproduced, in pixel space (identity 1)
        # and in the feature space of the loss network (identity 2)
        i_cc,_,_ = style_model(content_img,content_img)
        i_ss,_,_ = style_model(style_img,style_img)

        identity_1_loss = IDENTITY_1_WEIGHT*(_mean_l2(i_cc-content_img) + _mean_l2(i_ss-style_img))

        _,_,i_cc_style_features= style_loss_model((i_cc,[],[]))
        _,_,i_ss_style_features= style_loss_model((i_ss,[],[]))
        # pure target: no gradient flows into the features of the unmodified content image
        with torch.no_grad():
            _,_,content_img_style_features= style_loss_model((content_img,[],[]))

        identity_2_loss = 0.0
        for f_icc,f_ic,f_iss,f_is in zip(i_cc_style_features,content_img_style_features,i_ss_style_features,style_features_target):
            identity_2_loss += _mean_l2(f_icc-f_ic) + _mean_l2(f_iss-f_is)
        identity_2_loss *= IDENTITY_2_WEIGHT

        content_loss_aggregator.append(content_loss.detach().cpu())
        style_loss_aggregator.append(style_loss.detach().cpu())

        loss = content_loss + style_loss + identity_1_loss + identity_2_loss

        loss.backward()

    optimimizer.step()

    if epoch != 0 and epoch%PRINT_STATS_EVERY == 0:
        # the reported values are running means over every batch seen since the start of training
        print("Ending epoch: ", str(epoch), " with content loss: ", torch.stack(content_loss_aggregator).mean().numpy().item(),  " and style loss: ", torch.stack(style_loss_aggregator).mean().numpy().item())
        ToPILImage()(content_img[0].permute(0,2,1)).save(SAVE_PATH + "example_input_model.jpg")
        with torch.no_grad():
            img,_,_ = style_model(content_img[0].unsqueeze(0), style_img[0].unsqueeze(0))
        # clamp to the displayable range first: values outside [0,1] would wrap around in the uint8 conversion
        img = ToPILImage()(img.cpu().squeeze(0).clamp(0,1).permute(0,2,1))
        img.save(SAVE_PATH + "example_output_model.jpg")
        torch.save(style_model.state_dict(), SAVE_PATH + "model_weights.pth")
        torch.save(optimimizer.state_dict(), SAVE_PATH + "optim_weights.pth")


    
    

  1%|          | 100/10000 [01:02<1:37:40,  1.69it/s]

Ending epoch:  100  with content loss:  79777.1015625  and style loss:  3.8656041622161865


  2%|▏         | 200/10000 [02:00<1:36:03,  1.70it/s]

Ending epoch:  200  with content loss:  78774.6953125  and style loss:  3.3525102138519287


  3%|▎         | 301/10000 [02:59<1:44:57,  1.54it/s]

Ending epoch:  300  with content loss:  78384.921875  and style loss:  3.093285083770752


  4%|▍         | 400/10000 [03:57<1:29:38,  1.78it/s]

Ending epoch:  400  with content loss:  78125.390625  and style loss:  2.9220762252807617


  5%|▌         | 500/10000 [04:57<1:33:53,  1.69it/s]

Ending epoch:  500  with content loss:  77896.46875  and style loss:  2.798980236053467


  6%|▌         | 600/10000 [05:54<1:29:54,  1.74it/s]

Ending epoch:  600  with content loss:  77714.7578125  and style loss:  2.7042291164398193


  7%|▋         | 700/10000 [06:53<1:31:42,  1.69it/s]

Ending epoch:  700  with content loss:  77585.25  and style loss:  2.629682779312134


  8%|▊         | 800/10000 [07:52<1:30:52,  1.69it/s]

Ending epoch:  800  with content loss:  77473.140625  and style loss:  2.569742202758789


  9%|▉         | 900/10000 [08:52<1:30:05,  1.68it/s]

Ending epoch:  900  with content loss:  77349.796875  and style loss:  2.5172319412231445


 10%|█         | 1000/10000 [09:51<1:28:54,  1.69it/s]

Ending epoch:  1000  with content loss:  77254.8515625  and style loss:  2.4737021923065186


 11%|█         | 1100/10000 [10:50<1:27:49,  1.69it/s]

Ending epoch:  1100  with content loss:  77168.1328125  and style loss:  2.4358396530151367


 12%|█▏        | 1200/10000 [11:50<1:26:44,  1.69it/s]

Ending epoch:  1200  with content loss:  77069.625  and style loss:  2.4037458896636963


 13%|█▎        | 1300/10000 [12:49<1:25:46,  1.69it/s]

Ending epoch:  1300  with content loss:  76993.6875  and style loss:  2.3749423027038574


 14%|█▍        | 1400/10000 [13:48<1:24:26,  1.70it/s]

Ending epoch:  1400  with content loss:  76934.75  and style loss:  2.352935552597046


 15%|█▌        | 1500/10000 [14:47<1:23:47,  1.69it/s]

Ending epoch:  1500  with content loss:  76881.59375  and style loss:  2.330317258834839


 16%|█▌        | 1600/10000 [15:46<1:22:31,  1.70it/s]

Ending epoch:  1600  with content loss:  76822.828125  and style loss:  2.3106155395507812


 17%|█▋        | 1700/10000 [16:45<1:21:36,  1.70it/s]

Ending epoch:  1700  with content loss:  76761.7421875  and style loss:  2.2888224124908447


 18%|█▊        | 1800/10000 [17:43<1:19:56,  1.71it/s]

Ending epoch:  1800  with content loss:  76718.1796875  and style loss:  2.273223876953125


 19%|█▉        | 1900/10000 [18:43<1:19:43,  1.69it/s]

Ending epoch:  1900  with content loss:  76666.359375  and style loss:  2.2569961547851562


 20%|██        | 2000/10000 [19:42<1:17:57,  1.71it/s]

Ending epoch:  2000  with content loss:  76620.8046875  and style loss:  2.242840528488159


 21%|██        | 2100/10000 [20:41<1:17:25,  1.70it/s]

Ending epoch:  2100  with content loss:  76582.9140625  and style loss:  2.228301525115967


 22%|██▏       | 2200/10000 [21:40<1:16:30,  1.70it/s]

Ending epoch:  2200  with content loss:  76547.3828125  and style loss:  2.2149722576141357


 23%|██▎       | 2300/10000 [22:39<1:15:06,  1.71it/s]

Ending epoch:  2300  with content loss:  76499.828125  and style loss:  2.202899694442749


 24%|██▍       | 2400/10000 [23:38<1:14:28,  1.70it/s]

Ending epoch:  2400  with content loss:  76461.1875  and style loss:  2.1929774284362793


 25%|██▌       | 2500/10000 [24:37<1:13:55,  1.69it/s]

Ending epoch:  2500  with content loss:  76429.5078125  and style loss:  2.1819088459014893


 26%|██▌       | 2600/10000 [25:36<1:12:03,  1.71it/s]

Ending epoch:  2600  with content loss:  76404.2421875  and style loss:  2.171450138092041


 27%|██▋       | 2701/10000 [26:36<1:19:44,  1.53it/s]

Ending epoch:  2700  with content loss:  76373.2578125  and style loss:  2.1607820987701416


 28%|██▊       | 2800/10000 [27:34<1:10:52,  1.69it/s]

Ending epoch:  2800  with content loss:  76337.65625  and style loss:  2.15177059173584


 29%|██▉       | 2900/10000 [28:33<1:09:13,  1.71it/s]

Ending epoch:  2900  with content loss:  76310.5703125  and style loss:  2.14178729057312


 30%|███       | 3001/10000 [29:33<1:15:55,  1.54it/s]

Ending epoch:  3000  with content loss:  76278.1875  and style loss:  2.132821559906006


 31%|███       | 3101/10000 [30:32<1:14:46,  1.54it/s]

Ending epoch:  3100  with content loss:  76243.3046875  and style loss:  2.123976230621338


 32%|███▏      | 3201/10000 [31:31<1:13:42,  1.54it/s]

Ending epoch:  3200  with content loss:  76212.4375  and style loss:  2.115290641784668


 33%|███▎      | 3300/10000 [32:29<1:06:13,  1.69it/s]

Ending epoch:  3300  with content loss:  76186.078125  and style loss:  2.107769727706909


 34%|███▍      | 3400/10000 [33:28<1:04:53,  1.70it/s]

Ending epoch:  3400  with content loss:  76161.5078125  and style loss:  2.100367307662964


 35%|███▌      | 3500/10000 [34:27<1:03:56,  1.69it/s]

Ending epoch:  3500  with content loss:  76136.8828125  and style loss:  2.0928969383239746


 36%|███▌      | 3600/10000 [35:27<1:02:45,  1.70it/s]

Ending epoch:  3600  with content loss:  76107.8515625  and style loss:  2.085719585418701


 37%|███▋      | 3700/10000 [36:26<1:01:50,  1.70it/s]

Ending epoch:  3700  with content loss:  76085.1640625  and style loss:  2.0796351432800293


 38%|███▊      | 3800/10000 [37:25<1:01:05,  1.69it/s]

Ending epoch:  3800  with content loss:  76058.4921875  and style loss:  2.072112560272217


 39%|███▉      | 3900/10000 [38:24<58:14,  1.75it/s]  

Ending epoch:  3900  with content loss:  76037.234375  and style loss:  2.0658862590789795


 40%|████      | 4000/10000 [39:19<53:45,  1.86it/s]  

Ending epoch:  4000  with content loss:  76013.6796875  and style loss:  2.0596718788146973


 41%|████      | 4101/10000 [40:19<1:04:49,  1.52it/s]

Ending epoch:  4100  with content loss:  75994.7890625  and style loss:  2.0534729957580566


 42%|████▏     | 4201/10000 [41:17<1:01:43,  1.57it/s]

Ending epoch:  4200  with content loss:  75979.0546875  and style loss:  2.0478646755218506


 43%|████▎     | 4301/10000 [42:14<58:09,  1.63it/s]  

Ending epoch:  4300  with content loss:  75953.5390625  and style loss:  2.0424230098724365


 44%|████▍     | 4401/10000 [43:12<59:38,  1.56it/s]

Ending epoch:  4400  with content loss:  75931.484375  and style loss:  2.035874843597412


 45%|████▌     | 4501/10000 [44:10<56:07,  1.63it/s]

Ending epoch:  4500  with content loss:  75907.8515625  and style loss:  2.029507875442505


 46%|████▌     | 4600/10000 [45:08<54:23,  1.65it/s]

Ending epoch:  4600  with content loss:  75883.859375  and style loss:  2.024287700653076


 47%|████▋     | 4700/10000 [46:07<53:03,  1.66it/s]  

Ending epoch:  4700  with content loss:  75865.0859375  and style loss:  2.0188217163085938


 47%|████▋     | 4725/10000 [46:22<53:02,  1.66it/s]

In [None]:
# INFERENCE

# Restore the weights from the checkpoint file written by the training loop, mapped onto the active device.
style_model.load_state_dict(
    torch.load(
        SAVE_PATH + "model_weights.pth",
        map_location=device,
    )
)

In [None]:
# INFERENCE

# Image whose content is stylized at inference time.
CONTENT_IMAGE_PATH = "./dragon.jpg"
# Image providing the style at inference time.
STYLE_IMAGE_PATH = "./wave.jpg"

In [None]:
# INFERENCE

def _load_image_tensor(path, size):
    """
    Load the image at `path` as an RGB tensor prepared like the images used during training.

    The image is center-cropped to a square whose side is its shorter edge, resized to `size`, converted with
    ToTensor and its last two axes are swapped (the notebook applies the same `permute(0,2,1)` to the training
    style image and applies it again before converting model outputs back to PIL images).

    Args:
        path: location of the image file.
        size: two-element target size handed to PIL's `resize`.

    Returns:
        Float tensor with the channel axis first and the two spatial axes swapped.
    """
    # the context manager releases the file as soon as the pixel data has been converted
    with Image.open(path) as raw_image:
        image = raw_image.convert('RGB')
    # center crop image
    shortest_side = min(image.size)
    image = ImageOps.fit(image,(shortest_side,shortest_side)).resize(size)
    return ToTensor()(image).permute(0,2,1)


content_img = _load_image_tensor(CONTENT_IMAGE_PATH, IMAGE_SIZE).to(device)
style_img = _load_image_tensor(STYLE_IMAGE_PATH, IMAGE_SIZE).to(device)

In [None]:
# INFERENCE

# style_model returns a tuple (stylized image, content features, content features), so it has to be unpacked
# before any tensor method is called on the image. No gradients are needed here, and ToPILImage cannot
# convert a tensor that still requires grad.
with torch.no_grad():
    content_style_image,_,_ = style_model(content_img.unsqueeze(0),style_img.unsqueeze(0))
# drop the batch axis and clamp to the displayable range: values outside [0,1] would wrap in the uint8 conversion
content_style_image = content_style_image.squeeze(0).cpu().clamp(0,1)
content_style_image = ToPILImage()(content_style_image.permute(0,2,1))
content_style_image.show()