In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image

In [None]:
# Only the convolutional feature extractor of a pretrained VGG19 is needed.
# Every parameter is frozen (requires_grad=False) because the optimization
# updates the target image, never the network weights.
vgg = models.vgg19(pretrained=True).features.requires_grad_(False)

In [None]:
def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized 4-D tensor ready for VGG.

    Parameters
    ----------
    img_path : str or path-like
        Path of the image to open; it is converted to 3-channel RGB.
    max_size : int, default 400
        Upper bound on the longer image edge. Images that are already smaller
        are not upscaled. Ignored when ``shape`` is given.
    shape : sequence of two ints, optional
        Explicit ``(height, width)`` to resize to, e.g. ``content.shape[-2:]``
        so a style image matches the content image's spatial size.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(1, 3, H, W)`` normalized with the channel means
        (0.485, 0.456, 0.406) and standard deviations (0.229, 0.224, 0.225).
    """
    # The context manager closes the file handle; convert() loads the pixel
    # data first, so the returned image stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape is not None:
        size = tuple(int(s) for s in shape)
    else:
        # transforms.Resize(int) scales the *shorter* edge to that value, which
        # let the longer edge exceed max_size and even upscaled small images.
        # Compute an explicit (h, w) that keeps the aspect ratio and caps the
        # longer edge at max_size instead.
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    # Add a leading batch dimension.
    return transform(image).unsqueeze(0)

In [None]:
content = load_image('images/white_dog.jpg')
# Resize the style image to the content image's trailing two dimensions
# (height, width) so both inputs share the same spatial size.
style = load_image('images/painting.jpg', shape=content.shape[-2:])
for image_tensor in (content, style):
    print(image_tensor.shape)

In [None]:
# load_image normalizes with these per-channel mean/std values; undo that so
# imshow receives RGB floats in [0, 1] instead of clipping the normalized data.
imagenet_mean = np.array((0.485, 0.456, 0.406))
imagenet_std = np.array((0.229, 0.224, 0.225))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
# content and style images side-by-side
for ax, img, title in ((ax1, content, 'Content'), (ax2, style, 'Style')):
    # (1, 3, H, W) -> (H, W, 3) channels-last, then de-normalize.
    rgb = img.numpy().squeeze(0).transpose(1, 2, 0) * imagenet_std + imagenet_mean
    ax.imshow(rgb.clip(0, 1))
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Print the layer-by-layer structure of the VGG feature extractor.
print(str(vgg))

In [None]:
# Child names printed here are the string keys that get_features() matches
# against its `layers` mapping.
for name, layer in vgg._modules.items():
    print(f"Name = {name}, Layer = {layer}")

In [None]:
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect activations of selected layers.

    Parameters
    ----------
    image : torch.Tensor
        Input fed to the first child module of ``model``.
    model : torch.nn.Module
        Module whose direct children are applied one after another, in
        registration order, and identified by their child names ('0', '1', ...
        for ``nn.Sequential``).
    layers : dict[str, str], optional
        Maps child names to the feature names used as keys in the result.
        Defaults to the VGG19 layers conv1_1 ... conv5_1 plus conv4_2 (the
        content representation).

    Returns
    -------
    dict[str, torch.Tensor]
        Activations keyed by the feature names from ``layers``.
    """
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  ## content representation
                  '28': 'conv5_1'}

    features = {}
    x = image
    # Iterate the `model` argument (previously the global `vgg` was used, which
    # silently ignored the parameter and failed when `vgg` was not defined).
    # The whole model is still run: the stored tensor is the same object fed to
    # the next child, so an in-place op there (e.g. ReLU(inplace=True), as
    # torchvision's VGG presumably uses — verify) also alters the stored value.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features

In [None]:
# Tensor.shape and Tensor.size() report the same dimensions.
demo_array = np.array([[1, 2, 3], [4, 5, 6]])
t = torch.tensor(demo_array)
for label, dims in (("Shape = ", t.shape), ("Size = ", t.size())):
    print(label, dims)

In [None]:
def gram_matrix(tensor):
    """Return the Gram matrix of a single feature map.

    ``tensor`` is unpacked as ``(batch, depth, height, width)``; the batch size
    must be 1, otherwise the ``view`` below fails. Each channel is flattened to a
    row and the result is the ``(depth, depth)`` matrix of inner products
    between channels.
    """
    depth, height, width = tensor.shape[1:]
    channels = tensor.view(depth, height * width)
    return channels @ channels.t()

In [None]:
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

# Gram matrices of the style image, one per captured layer. They stay fixed and
# serve as the style targets in the optimization loop.
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

In [None]:
# The optimized image starts as a copy of the content image. It is the tensor
# handed to the optimizer, so it must track gradients.
target = content.clone()
target.requires_grad_(True)

In [None]:
# Per-layer weights for the style loss (conv1_1 weighted most heavily).
style_weights = dict(conv1_1=1.,
                     conv2_1=0.75,
                     conv3_1=0.2,
                     conv4_1=0.2,
                     conv5_1=0.2)

# Relative weighting of the content and style terms in the total loss.
content_weight = 1   # alpha
style_weight = 1e6   # beta

In [None]:
# Dimensions of the conv4_2 activation used by the content loss.
print(content_features['conv4_2'].size())

In [None]:
steps = 2000       # number of optimization steps
optimizer = optim.Adam([target], lr=0.003)
show_every = 200   # report the loss and render the target every this many steps


def _to_display(tensor):
    """Convert a normalized ``(1, 3, H, W)`` image tensor to ``(H, W, 3)`` in [0, 1].

    Undoes the per-channel mean/std normalization applied by ``load_image`` so
    imshow shows the real colors instead of clipping normalized values.
    """
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    rgb = tensor.detach().cpu().numpy().squeeze(0).transpose(1, 2, 0)
    return (rgb * std + mean).clip(0, 1)


for step in range(steps):

    target_features = get_features(target, vgg)

    # Content loss: mean squared difference of the conv4_2 activations.
    content_loss = torch.mean((content_features['conv4_2'] - target_features['conv4_2'])**2)

    # Style loss: weighted MSE between Gram matrices, divided by feature-map size.
    style_loss = 0
    for layer in style_weights:
        style_gram = style_grams[layer]
        target_gram = gram_matrix(target_features[layer])
        _, d, h, w = target_features[layer].shape
        style_loss += (style_weights[layer]*torch.mean((style_gram - target_gram)**2))/(d*h*w)

    total_loss = content_weight*content_loss + style_weight*style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Report only every `show_every` steps; previously show_every was unused and
    # every one of the `steps` iterations printed a line and rendered a figure.
    if (step + 1) % show_every == 0:
        print("Total Loss at Iteration {} = {}".format(step + 1, total_loss.item()))
        fig, ax = plt.subplots()
        ax.imshow(_to_display(target))
        ax.set_title("Step {}".format(step + 1))
        ax.axis('off')
        plt.show()
        plt.close(fig)