In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm


In [12]:
# Load the feature-extraction stack (`.features`: the conv / ReLU / max-pool
# layers printed below) of an ImageNet-pretrained VGG-19, in eval mode.
# torchvision >= 0.13 selects checkpoints through `weights=` and deprecates
# `pretrained=True`; IMAGENET1K_V1 is the checkpoint the legacy flag loads.
# Older torchvision has no weights enums, so fall back to the legacy flag there.
try:
    vgg_weights = models.VGG19_Weights.IMAGENET1K_V1
except AttributeError:
    vgg_19 = models.vgg19(pretrained=True).features.eval()
else:
    vgg_19 = models.vgg19(weights=vgg_weights).features.eval()
vgg_19



Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [13]:
class ModifiedVGG(nn.Module):
    """Truncated VGG feature stack that returns the activations of chosen layers.

    By default it wraps the notebook-global ``vgg_19`` and exposes layers
    0, 5, 10, 19 and 28. Per the printed model, 0, 5 and 10 are the first
    convolution of the first three VGG blocks. 19 and 28 are presumably
    conv4_1 / conv5_1, but the printed output above is truncated, so verify.

    Args:
        features: ``nn.Sequential`` feature stack to wrap. Defaults to the
            global ``vgg_19``.
        chosen_features: indices (ints or numeric strings) into ``features``
            whose outputs ``forward`` returns. Defaults to
            ``['0', '5', '10', '19', '28']``.

    Only the input image is optimised by the training loop. The wrapped layers
    are therefore frozen (``requires_grad=False``) so that ``backward()`` does
    not compute and accumulate gradients for the network weights. Slicing
    shares modules with ``features``, so this also freezes the stack passed in.
    """

    def __init__(self, features=None, chosen_features=None):
        super(ModifiedVGG, self).__init__()

        if features is None:
            features = vgg_19
        if chosen_features is None:
            chosen_features = ['0', '5', '10', '19', '28']
        # Kept as strings: this is the attribute format forward() compares against.
        self.chosen_features = [str(index) for index in chosen_features]
        # Layers past the deepest chosen one never contribute, so drop them
        # (with the defaults this is features[:29]).
        last_index = max((int(index) for index in self.chosen_features), default=-1)
        self.model = features[:last_index + 1]
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated stack.

        Returns:
            A list with the output of each chosen layer, shallowest first.
        """
        features = []
        for layer_number, layer in enumerate(self.model):
            x = layer(x)
            # NOTE: ``x`` is stored by reference. When the next layer is an
            # in-place ReLU (as printed for layers 0, 5 and 10), it overwrites
            # this tensor, so those returned features end up post-ReLU. The
            # last layer of the slice has no successor and stays pre-ReLU.
            if str(layer_number) in self.chosen_features:
                features.append(x)
        return features

In [14]:
# Feature extractor used for the content and style losses. Slicing vgg_19
# creates new Sequential containers that start in training mode, so switch the
# whole wrapper to eval mode explicitly. This keeps it consistent with
# vgg_19.eval() and guards against train-mode layers (dropout/batch-norm) if
# the chosen slice ever includes them.
model = ModifiedVGG().eval()

In [15]:

# Preprocessing shared by every input image: resize to 224x224 and convert the
# PIL image to a float tensor (ToTensor scales 8-bit pixels into [0, 1]).
# NOTE(review): no ImageNet mean/std normalisation is applied, although the
# backbone is ImageNet-pretrained. alpha/beta below were presumably tuned for
# these un-normalised inputs, so adding it would require retuning them.
loader = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def load_image(image_path, transform=None):
    """Load an image file as a batched tensor.

    Args:
        image_path: path to an image PIL can read; it is converted to RGB.
        transform: optional callable applied to the RGB PIL image. Defaults
            to ``loader`` (224x224 resize + ToTensor).

    Returns:
        The transformed image with a leading batch dimension of size 1. With
        the default ``loader`` this is a (1, 3, 224, 224) float tensor.
    """
    if transform is None:
        transform = loader
    # Image.open is lazy and holds the file handle open. The context manager
    # releases it once convert() has produced an independent, loaded image.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return transform(rgb_image).unsqueeze(0)


In [16]:
# Content image (its structure is kept) and style image (its textures are
# transferred), both loaded through load_image.
CONTENT_IMAGE_PATH = 'images/test/anna_hathaway.jpg'
STYLE_IMAGE_PATH = 'images/styles/acrylic_style.jpg'

original_image = load_image(CONTENT_IMAGE_PATH)
style_image = load_image(STYLE_IMAGE_PATH)

In [17]:
# The image being optimised starts as a copy of the content image. It is the
# only tensor handed to the optimizer, so it must track gradients.
generated = original_image.detach().clone()
generated.requires_grad_(True)

In [18]:
# Optimisation hyper-parameters.
total_steps = 2_000     # number of Adam updates applied to the image
learning_rate = 3e-4    # Adam step size
alpha = 1               # weight of the content (original-image) loss
beta = 0.01             # weight of the style (Gram-matrix) loss

In [None]:
# Optimise the pixels of `generated`. Its chosen VGG activations should match
# the content image (content loss), and their Gram matrices should match the
# style image's (style loss). The current image is written to disk
# periodically and once more when training ends.

output_path = "results/generated.png"
# save_image does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

optimizer = optim.Adam([generated], lr=learning_rate)


def gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of a single-image feature map.

    Assumes a batch size of 1: the view below fails for larger batches.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())  # mm = matrix multiplication


# The content and style images never change, so their features (and the style
# Gram matrices) are loop-invariant. Compute them once, without building
# autograd graphs.
with torch.no_grad():
    original_image_features = model(original_image)
    style_image_features = model(style_image)
    style_grams = [gram_matrix(feature) for feature in style_image_features]

for step in tqdm(range(total_steps), desc="Training"):
    generated_features = model(generated)

    style_loss = 0
    original_loss = 0
    for gen_feature, orig_feature, style_gram in zip(generated_features, original_image_features, style_grams):
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Un-normalised Gram matrices: their scale is absorbed by `beta`.
        style_loss += torch.mean((gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        tqdm.write("Total loss at step {}: {}".format(step, total_loss.item()))
        save_image(generated, output_path)

# The periodic save above does not fire on the last step, so persist the
# fully optimised image.
save_image(generated, output_path)
