In [2]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def get_iamge(path, image_transform, size = (300, 300)):
    """Load an image file and return it as a batched tensor on `device`.

    Args:
        path: Path (or file object) accepted by ``PIL.Image.open``.
        image_transform: Callable applied to the resized PIL image; its
            result must support ``.unsqueeze(0)``.
        size: ``(width, height)`` the image is resized to before the
            transform is applied.

    Returns:
        The transformed image with a leading batch dimension added, moved to
        the module-level ``device``.
    """
    # The context manager releases the file handle once the pixels are decoded.
    with Image.open(path) as source:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise yield 1- or 4-channel tensors that do not match a
        # 3-channel mean/std normalisation.
        image = source.convert('RGB')
    image = image.resize(size, Image.LANCZOS)
    image = image_transform(image).unsqueeze(0) # add batch dimension
    return image.to(device)

def get_gram(m):
    """Compute the Gram matrix (channel-by-channel inner products) of a feature map.

    Args:
        m: Tensor of shape (B, C, H, W), or a single unbatched (C, H, W) map.

    Returns:
        A (C, C) tensor for an unbatched input or a batch of one (the shape
        this function has always returned), or a (B, C, C) tensor holding one
        Gram matrix per sample when B > 1.

    Raises:
        ValueError: If ``m`` is neither 3- nor 4-dimensional.
    """
    if m.dim() == 3:
        m = m.unsqueeze(0)
    elif m.dim() != 4:
        raise ValueError(
            'get_gram expects a (B, C, H, W) or (C, H, W) tensor, got shape {}'.format(tuple(m.shape))
        )
    bs, c, h, w = m.size()
    # reshape (rather than view) also accepts non-contiguous feature maps.
    flat = m.reshape(bs, c, h * w)
    # Batched matrix multiplication: one (C, H*W) @ (H*W, C) product per sample.
    gram = torch.bmm(flat, flat.transpose(1, 2))
    # Keep the historical 2-D result for the single-image case.
    return gram[0] if bs == 1 else gram

def denormalize_img(img):
    """Undo the per-channel mean/std normalisation and return an (H, W, C) array.

    Args:
        img: Channel-first (C, H, W) image, either a NumPy array or a
            ``torch.Tensor``. Tensors are detached and copied to the CPU
            first, because ``Tensor.transpose`` does not accept the
            three-axis permutation used below.

    Returns:
        A NumPy array of shape (H, W, C) with values clipped to [0, 1].
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.transpose(np.asarray(img), (1, 2, 0)) # (C, H, W) -> (H, W, C)
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    img = std * img + mean # invert (x - mean) / std, broadcasting over the channel axis
    img = np.clip(img, 0, 1) # values should be between 0 and 1
    return img

In [10]:
class FeatureExtractor(nn.Module):
    """Runs an image through the VGG16 convolutional stack and collects
    the activations produced by a fixed set of layers."""

    def __init__(self):
        super().__init__()
        # Indices into vgg16().features whose outputs are collected. These are
        # chosen right after ReLU activations: ReLU adds the non-linearity to
        # the feature maps, which is important for detecting patterns in the image.
        self.selected_layers = [3,8,15,22]
        # Pretrained VGG16; only the convolutional `features` part is kept.
        if hasattr(models, 'VGG16_Weights'):
            # Newer torchvision: `pretrained=True` is deprecated in favour of
            # the weights enum (IMAGENET1K_V1 is the same checkpoint).
            vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        else:
            vgg = models.vgg16(pretrained=True)
        self.vgg = vgg.features

    def forward(self, x):
        """Return the outputs of the selected layers, in layer order.

        Args:
            x: Input passed to the first layer of ``self.vgg``.

        Returns:
            A list with one tensor per index in ``self.selected_layers`` that
            exists in ``self.vgg``.
        """
        wanted = set(self.selected_layers)
        last_needed = max(wanted, default=-1)
        layer_features = []
        for layer_num, layer in enumerate(self.vgg):
            if layer_num > last_needed:
                # Nothing past the deepest selected layer is ever used, so
                # skip the rest of the network.
                break
            x = layer(x)
            if layer_num in wanted:
                layer_features.append(x)
        return layer_features


In [11]:
# PIL image -> float tensor, then per-channel normalisation with the
# mean/std values that denormalize_img later inverts.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Both images are loaded at the same resolution so their feature maps line up.
content_img, style_img = (
    get_iamge(image_path, img_transform, size=(300, 300))
    for image_path in ('dancing.jpg', 'picasso.jpg')
)

# The generated image starts as a copy of the content image and is the only
# tensor being optimised, so it must track gradients.
generated_img = content_img.clone().requires_grad_()

optimizer = torch.optim.Adam(params=[generated_img], lr=0.003, betas=[0.5, 0.999])
encoder = FeatureExtractor().to(device)

# Freeze the encoder: its weights are never updated, only the image is.
encoder.requires_grad_(False)



In [12]:
content_weight =1
style_weight = 100

# The content and style images never change during optimisation, so their
# features (and the style Gram matrices) are computed once up front instead
# of on every iteration. No gradients are needed for these fixed targets.
with torch.no_grad():
    content_features = encoder(content_img)
    style_features = encoder(style_img)
    content_target = content_features[-1]
    style_grams = [get_gram(sf) for sf in style_features]

for epoch in range(500):

    generated_features = encoder(generated_img)

    # content loss -> computed at the last selected layer only
    content_loss = torch.mean((content_target - generated_features[-1])**2)

    # style loss -> squared Gram-matrix difference, accumulated over all selected layers
    style_loss = 0
    for gf, gram_sf in zip(generated_features, style_grams):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        style_loss += torch.mean((gram_gf - gram_sf)**2) / (c * h * w) # normalize the loss by dividing by the total number of elements in the feature map

    loss = content_weight * content_loss + style_loss * style_weight
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('Epoch: {}, Content Loss: {:.4f}, Style Loss: {:.4f}'.format(epoch, content_loss.item(), style_loss.item()))

Epoch: 0, Content Loss: 0.0000, Style Loss: 160.9563
Epoch: 10, Content Loss: 0.6700, Style Loss: 113.7369
Epoch: 20, Content Loss: 0.8613, Style Loss: 89.9734
Epoch: 30, Content Loss: 0.9257, Style Loss: 73.6981
Epoch: 40, Content Loss: 0.9661, Style Loss: 60.4524
Epoch: 50, Content Loss: 1.0006, Style Loss: 49.3137
Epoch: 60, Content Loss: 1.0256, Style Loss: 40.2370
Epoch: 70, Content Loss: 1.0449, Style Loss: 33.0206
Epoch: 80, Content Loss: 1.0592, Style Loss: 27.3152
Epoch: 90, Content Loss: 1.0716, Style Loss: 22.7654
Epoch: 100, Content Loss: 1.0831, Style Loss: 19.0817
Epoch: 110, Content Loss: 1.0942, Style Loss: 16.0543
Epoch: 120, Content Loss: 1.1046, Style Loss: 13.5466
Epoch: 130, Content Loss: 1.1147, Style Loss: 11.4585
Epoch: 140, Content Loss: 1.1239, Style Loss: 9.7144
Epoch: 150, Content Loss: 1.1331, Style Loss: 8.2568
Epoch: 160, Content Loss: 1.1421, Style Loss: 7.0427
Epoch: 170, Content Loss: 1.1518, Style Loss: 6.0320
Epoch: 180, Content Loss: 1.1613, Style L

KeyboardInterrupt: 

In [None]:
# generated_img is a (1, C, H, W) tensor that requires grad: detach it, move
# it to the CPU, drop only the batch axis, and convert to NumPy, because
# denormalize_img uses the NumPy-style three-axis transpose.
inp = generated_img.detach().cpu().squeeze(0).numpy()
denorm_img = denormalize_img(inp)
plt.imshow(denorm_img)
