In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import models, transforms
import os

In [None]:
def downsamplingBlock(in_features, out_features):
    """Build a two-stage convolutional encoder block.

    Each stage is ``Conv2d`` (3x3 kernel, padding 1) -> ``ReLU`` ->
    ``BatchNorm2d``. Stride is left at its default, so the spatial size of the
    input is preserved and only the channel count changes.

    Args:
        in_features: number of channels of the incoming feature map.
        out_features: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the six layers in order.
    """
    stages = []
    # The first stage maps in_features -> out_features, the second keeps the width.
    for channels_in in (in_features, out_features):
        stages.append(nn.Conv2d(channels_in, out_features, kernel_size=3, padding=1))
        stages.append(nn.ReLU())
        stages.append(nn.BatchNorm2d(out_features))
    return nn.Sequential(*stages)
def upsamplingBlock(in_features, mid_features, out_features):
    """Build a decoder block: two conv stages followed by a 2x transposed convolution.

    The two ``Conv2d`` (3x3, padding 1) -> ``ReLU`` -> ``BatchNorm2d`` stages
    keep the spatial size. The final ``ConvTranspose2d`` (3x3, stride 2,
    padding 1, output_padding 1) doubles height and width.

    Args:
        in_features: number of channels of the incoming feature map.
        mid_features: number of channels used by the two convolution stages.
        out_features: number of channels produced by the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the seven layers in order.
    """
    layers = []
    for channels_in in (in_features, mid_features):
        layers += [
            nn.Conv2d(channels_in, mid_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_features),
        ]
    # Learned upsampling: output size is exactly twice the input size.
    layers.append(
        nn.ConvTranspose2d(
            mid_features,
            out_features,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
    )
    return nn.Sequential(*layers)
class Unet(nn.Module):
    """Three-level U-Net mapping a 3-channel image to a 3-channel image.

    Encoder: three ``downsamplingBlock`` stages (3->32, 32->64, 64->128
    channels), each followed by 2x2 max pooling. Decoder: three
    ``upsamplingBlock`` stages, each doubling the spatial size; the result of
    every decoder stage is concatenated channel-wise with the encoder output
    of matching resolution. A final convolutional head reduces the 64
    concatenated channels back to 3.

    Because the input is pooled three times and upsampled by exactly 2x three
    times, height and width must be divisible by 8 for the skip
    concatenations to line up.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)

        # Encoder path.
        self.block1 = downsamplingBlock(3, 32)
        self.block2 = downsamplingBlock(32, 64)
        self.block3 = downsamplingBlock(64, 128)

        # Decoder path; block4/block5 take twice the channels because their
        # input is a skip connection concatenated with the upsampled map.
        self.bottom = upsamplingBlock(128, 256, 128)
        self.block4 = upsamplingBlock(256, 128, 64)
        self.block5 = upsamplingBlock(128, 64, 32)

        # Output head at full resolution (no activation on the last layer).
        head_layers = [
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 3, 3, padding=1),
        ]
        self.final = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with 3 channels and the same spatial size."""
        # Encoder: keep each resolution's features for the skip connections.
        enc_full = self.block1(x)
        enc_half = self.block2(self.pool(enc_full))
        enc_quarter = self.block3(self.pool(enc_half))

        # Bottleneck at 1/8 resolution, upsampled back to 1/4.
        up_quarter = self.bottom(self.pool(enc_quarter))

        # Decoder: concatenate the skip connection first, then refine and upsample.
        up_half = self.block4(torch.cat((enc_quarter, up_quarter), 1))
        up_full = self.block5(torch.cat((enc_half, up_half), 1))

        return self.final(torch.cat((enc_full, up_full), 1))

The content and style feature extractors below follow the algorithms from "A Neural Algorithm of Artistic Style" (Gatys et al.): https://arxiv.org/abs/1508.06576

In [None]:
class PretrainedContentNetwork(nn.Module):
    """Frozen pretrained VGG-19 used as the content feature extractor.

    All VGG parameters have ``requires_grad`` disabled, and the layers at
    indices 4, 9, 18, 27 and 36 of ``features`` are replaced by 2x2 average
    pooling (presumably the max-pooling positions of torchvision's VGG-19 --
    verify against the torchvision model definition).
    """

    def __init__(self):
        super().__init__()
        self.sota = models.vgg19(pretrained = True)
        # The extractor is never trained; only its activations are used.
        for weight in self.sota.parameters():
            weight.requires_grad = False
        for pool_index in (4, 9, 18, 27, 36):
            self.sota.features[pool_index] = nn.AvgPool2d(2, 2)

    def forward(self, x):
        """Return the activations after the first 12 feature layers, flattened to 1-D.

        The flattening includes the batch dimension, so callers that need
        per-sample features must reshape the result themselves.
        """
        for layer in self.sota.features[:12]:
            x = layer(x)
        return x.view(-1)
    
class PretrainedStyleNetwork(nn.Module):
    """Frozen pretrained VGG-19 split into five stages that emit Gram matrices.

    The ``features`` layers 0-29 are partitioned into ``slice1`` .. ``slice5``
    (layer ranges [0, 2), [2, 7), [7, 12), [12, 21), [21, 30)). The layers at
    indices 4, 9, 18, 27 and 36 are replaced by 2x2 average pooling before
    slicing, and all VGG parameters have ``requires_grad`` disabled. The full
    VGG model is dropped once the slices hold the layers they need.
    """

    def __init__(self):
        super().__init__()
        self.sota = models.vgg19(pretrained = True)
        for weight in self.sota.parameters():
            weight.requires_grad = False
        for pool_index in (4, 9, 18, 27, 36):
            self.sota.features[pool_index] = nn.AvgPool2d(2, 2)

        # Create the five stages first so they are registered in order.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()

        # Consecutive boundaries of the VGG feature indices owned by each stage.
        boundaries = (0, 2, 7, 12, 21, 30)
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        for stage, first, stop in zip(stages, boundaries[:-1], boundaries[1:]):
            for layer_index in range(first, stop):
                # Sub-modules keep their original VGG index as their name.
                stage.add_module(str(layer_index), self.sota.features[layer_index])

        # Only the sliced layers are needed from here on.
        del self.sota

    def forward(self, x):
        """Return the concatenated, flattened Gram matrices of the five stages.

        For each stage the activations are reshaped to (batch, channels, -1)
        and the Gram matrix ``F @ F^T`` is divided by the channel count. The
        result has shape (batch, sum of channels**2 over the five stages).
        The batch size of the last call is stored on ``self.batch_size``.
        """
        self.batch_size = x.size()[0]
        flat_grams = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            x = stage(x)
            features = x.view(self.batch_size, x.size()[1], -1)
            gram = torch.matmul(features, torch.transpose(features, 1, 2)) / features.size()[1]
            flat_grams.append(gram.view(self.batch_size, -1))
        return torch.cat(flat_grams, 1)

In [None]:
# Converts a PIL image to a float tensor (channels first).
image_transform = transforms.Compose([transforms.ToTensor()])
# Directory holding the training images.
cocodir = "coco2017/train2017"
# os.listdir returns entries in arbitrary order; sort them so that a given
# dataset index refers to the same file on every run and every machine.
filelist = sorted(os.listdir(cocodir))
class DataLoader:
    """Map-style dataset of centre-cropped 600x600 RGB training images.

    Despite its name this is a dataset, not a batch loader: it only provides
    ``__getitem__`` and ``__len__`` and is meant to be wrapped by
    ``torch.utils.data.DataLoader``. Files are read from the module-level
    ``cocodir`` directory using the names in ``filelist``.
    """

    def __init__(self):
        return

    def __getitem__(self, idx):
        """Load image ``idx``, scale its short side to 600 px and centre-crop to 600x600.

        Returns:
            The result of ``image_transform`` applied to the cropped RGB image.
        """
        # Join with a path separator; plain concatenation would produce
        # "coco2017/train2017<name>" and never find the file.
        path = os.path.join(cocodir, filelist[idx])
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as source:
            im = source.convert("RGB")

        # PIL reports size as (width, height).
        width, height = im.size
        proportion = width / height
        if height <= width:
            # Landscape or square: height becomes 600, crop horizontally.
            im = im.resize((int(proportion * 600), 600), Image.LANCZOS)
            # Centre on the width (im.size[0]); using the height here would
            # always give an offset of 0.
            left_border = (im.size[0] - 600) // 2
            im = im.crop((left_border, 0, left_border + 600, 600))
        else:
            # Portrait: width becomes 600, crop vertically.
            im = im.resize((600, int(600 / proportion)), Image.LANCZOS)
            # Centre on the height (im.size[1]).
            top_border = (im.size[1] - 600) // 2
            im = im.crop((0, top_border, 600, top_border + 600))
        return image_transform(im)

    def __len__(self):
        """Number of files listed in ``filelist``."""
        return len(filelist)

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Force three channels, as is done for the training images: a grayscale or
# RGBA style image would otherwise not fit the 3-channel network input.
style_image = Image.open("StylizationImage.jpg").convert("RGB")
plt.imshow(style_image)
# Channels-first float tensor on the training device.
style_image = transforms.ToTensor()(style_image).to(device)

In [None]:
# Build the networks in a fixed order: two frozen VGG-19 feature extractors
# and the trainable U-Net generator.
pretrainedContentNet = PretrainedContentNetwork()
pretrainedStyleNet = PretrainedStyleNetwork()
unet = Unet()

# Move everything to the training device. The VGG extractors are put in
# evaluation mode; the U-Net stays in its default training mode.
pretrainedStyleNet = pretrainedStyleNet.to(device).eval()
pretrainedContentNet = pretrainedContentNet.to(device).eval()
unet = unet.to(device)

# Data pipeline.
batch_size = 2
loader = torch.utils.data.DataLoader(dataset=DataLoader(), batch_size=batch_size)

# Both loss terms are mean squared errors.
style_criterion = nn.MSELoss(reduction="mean")
content_criterion = nn.MSELoss(reduction="mean")

# Only the U-Net parameters are optimised. The learning rate is multiplied by
# 0.25 after 2000 scheduler steps without improvement of the monitored value.
optimizer = optim.Adam(params=unet.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.25,
    patience=2000,
    verbose=True,
)

In [None]:
# Weight of the style term in the total loss.
# NOTE(review): the original code multiplied the style loss by a literal 0,
# which disables style transfer entirely; the value is kept so behaviour does
# not change silently -- confirm whether this was intentional.
style_weight = 0

# Target Gram matrices of the style image, computed once. unsqueeze(0) adds
# the batch dimension without assuming a 600x600 image. The Gram matrices are
# not normalised by spatial size, so the style image should still match the
# 600x600 training crops for the magnitudes to be comparable.
style_gram = pretrainedStyleNet(style_image.unsqueeze(0))

for i, data in enumerate(loader):
    optimizer.zero_grad()
    images = data.to(device)
    # The last batch may be smaller than the configured batch size.
    current_batch = images.size(0)
    # Reuse the tensor already on the device instead of transferring it again.
    unetOutput = unet(images)

    # Content loss: VGG activations of the source vs. the generated image.
    content_source_preds = pretrainedContentNet(images)
    content_generated_preds = pretrainedContentNet(unetOutput)
    content_loss = content_criterion(content_source_preds.view(current_batch, -1),
                                     content_generated_preds.view(current_batch, -1))

    # Style loss: expand the single style target explicitly to the batch shape
    # so MSELoss does not rely on (and warn about) implicit broadcasting.
    generated_grams = pretrainedStyleNet(unetOutput)
    style_loss = style_criterion(generated_grams, style_gram.expand_as(generated_grams))

    # Quadratic penalty on output values outside the [0, 1] range.
    borderloss = ((unetOutput > 1) * (unetOutput - 1) ** 2) + ((unetOutput < 0) * (-unetOutput) ** 2)
    loss = style_weight * style_loss + content_loss + torch.sum(borderloss)
    loss.backward()
    optimizer.step()
    # The plateau scheduler only needs the scalar value.
    scheduler.step(loss.item())

    if i % 1000 == 0:
        # Min-max normalise the first generated image using the extrema of the
        # whole batch, then save it together with its source image.
        minimal = torch.min(unetOutput)
        maximal = torch.max(unetOutput)
        saveim = (unetOutput[0] - minimal) / (maximal - minimal)
        print('\n', i // 1000 + 1, style_loss.item(), content_loss.item())
        print(minimal.item(), maximal.item())
        # Convert to height x width x channels once and reuse it for both files.
        generated_np = saveim.cpu().detach().numpy().transpose(1, 2, 0)
        matplotlib.image.imsave(f'source_image{i // 1000}.png', data[0].cpu().detach().numpy().transpose(1, 2, 0))
        matplotlib.image.imsave(f'generated_image{i // 1000}.png', generated_np)
        np.save(f'numpy_image{i // 1000}.npy', generated_np)