In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import save_image

In [2]:
# Convolutional feature stack of the pretrained VGG-19, loaded here only so the
# next cell can show its layer indices; `model` is rebound to the VGG wrapper later.
model = (
    models.vgg19(pretrained=True)
    .features
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [3]:
# Display the layer list; VGG.selected_layer refers to these same positional indices.
model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [4]:
class VGG(nn.Module):
    """Frozen VGG-19 feature extractor returning intermediate activations.

    ``forward(x)`` runs ``x`` through the convolutional part of the pretrained
    VGG-19 and returns the activations captured after each layer whose index is
    listed in ``selected_layer``, shallowest first. The caller uses them for both
    the content loss and the Gram-matrix style loss.

    Args:
        selected_layer: indices (ints or numeric strings) into
            ``vgg19().features`` to capture. Defaults to ``('5', '10', '19', '28')``.
            The network is truncated right after the deepest selected index, so
            later layers never run (with the default this is ``features[:29]``).

    Raises:
        ValueError: if ``selected_layer`` is empty.

    NOTE(review): torchvision documents that its pretrained models expect
    ImageNet mean/std-normalized input; no normalization is applied here —
    confirm whether that is intended for this pipeline.
    """

    def __init__(self, selected_layer=('5', '10', '19', '28')):
        super(VGG, self).__init__()

        # Kept as a list of strings, matching the attribute's original format.
        self.selected_layer = [str(index) for index in selected_layer]
        if not self.selected_layer:
            raise ValueError("selected_layer must name at least one layer")
        deepest = max(int(index) for index in self.selected_layer)

        self.model = models.vgg19(pretrained=True).features[:deepest + 1]

        # Only the input image is optimized. Freezing the VGG weights stops
        # autograd from computing (and, since no optimizer ever zeroes them,
        # accumulating) gradients for every VGG parameter on each backward pass.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the activations of ``x`` at every selected layer, in network order.

        NOTE(review): the printed model shows ``ReLU(inplace=True)`` right after
        conv layers 5 and 10 (presumably also 19). A tensor captured at such an
        index is overwritten in place by that ReLU, so it ends up post-ReLU.
        The deepest selected layer is the last one kept, so it is captured as-is.
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.selected_layer:
                features.append(x)
        return features
                
        

In [5]:
def load_image(image_name, transform=None, target_device=None):
    """Load an image file as a batched tensor ready for the VGG feature extractor.

    The image is forced to 3-channel RGB first. VGG-19's first layer is
    ``Conv2d(3, 64, ...)``, so grayscale, palette or RGBA files would otherwise
    produce a tensor with the wrong channel count. The RGB image is then passed
    through ``transform`` and given a leading batch dimension of 1.

    Args:
        image_name: path to the image file.
        transform: callable mapping a PIL image to a (C, H, W) tensor; defaults
            to the notebook-level ``loader``.
        target_device: device to move the result to; defaults to the
            notebook-level ``device``.

    Returns:
        Tensor of shape (1, C, H, W) on ``target_device`` (C == 3 with the
        default ``loader``).
    """
    if transform is None:
        transform = loader
    if target_device is None:
        target_device = device
    # PIL opens files lazily and keeps the handle open; the context manager
    # closes it once the pixels have been converted and transformed.
    with Image.open(image_name) as image:
        tensor = transform(image.convert('RGB'))
    return tensor.unsqueeze(0).to(device=target_device)

In [6]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [7]:
image_size = 224

# Every image is resized to a fixed image_size x image_size square (aspect ratio
# is not preserved) and converted to a float tensor with values in [0, 1].
loader = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
    ]
)

In [8]:
!wget "https://upload.wikimedia.org/wikipedia/commons/c/cd/Anne_Hathaway_at_MIFF_%28cropped%29.jpg" -O anna.jpg
!wget "https://pbs.twimg.com/media/DU5DVJ_WAAICIMg.jpg" -O style.jpg

--2021-06-14 18:12:37--  https://upload.wikimedia.org/wikipedia/commons/c/cd/Anne_Hathaway_at_MIFF_%28cropped%29.jpg
Resolving upload.wikimedia.org (upload.wikimedia.org)... 208.80.154.240, 2620:0:861:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|208.80.154.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115695 (1.1M) [image/jpeg]
Saving to: ‘anna.jpg’


2021-06-14 18:12:39 (923 KB/s) - ‘anna.jpg’ saved [1115695/1115695]

--2021-06-14 18:12:41--  https://pbs.twimg.com/media/DU5DVJ_WAAICIMg.jpg
Resolving pbs.twimg.com (pbs.twimg.com)... 184.31.10.237, 2600:1480:4000:e5::
Connecting to pbs.twimg.com (pbs.twimg.com)|184.31.10.237|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 98443 (96K) [image/jpeg]
Saving to: ‘style.jpg’


2021-06-14 18:12:42 (2.05 MB/s) - ‘style.jpg’ saved [98443/98443]



In [9]:
# Content and style references, each a (1, C, image_size, image_size) tensor on `device`.
original_image, style_image = load_image("./anna.jpg"), load_image('./style.jpg')

# Initial image: a copy of the content image; this leaf tensor is what the optimizer updates.
generated = original_image.clone()
generated.requires_grad = True

In [10]:
# Build the feature extractor, move it to the chosen device and put it in inference mode.
model = VGG()
model = model.to(device=device)
model = model.eval()

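## Hyperparameters

`alpha` weights the content loss and `beta` weights the style loss in the total objective; Adam updates the pixels of `generated` directly.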

In [13]:
total_steps = 600      # number of optimization iterations
learning_rate = 1e-3   # Adam step size, applied directly to the image pixels
alpha = 1              # weight of the content (original-image) loss
beta = 0.01            # weight of the style (Gram-matrix) loss
optimizer = optim.Adam(params=[generated], lr=learning_rate)


In [15]:
def gram_matrix(feature):
    """Return the (channels x channels) Gram matrix of a single-image feature map.

    The map is flattened to (channels, height * width) and multiplied by its
    transpose. The ``view`` requires a batch size of 1, which is what
    ``load_image`` produces via ``unsqueeze(0)``.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The reference images never change, so their features (and the style Gram
# matrices) are computed once, without building autograd graphs, instead of
# on every step.
with torch.no_grad():
    original_features = model(original_image)
    style_features = model(style_image)
    style_grams = [gram_matrix(style_feat) for style_feat in style_features]

for step in range(total_steps):
    generated_features = model(generated)

    style_loss = original_loss = 0

    for gen_feat, orig_feat, A in zip(generated_features, original_features, style_grams):
        # Content loss: mean squared difference between feature maps.
        original_loss += torch.mean((gen_feat - orig_feat) ** 2)
        # Style loss: mean squared difference between Gram matrices.
        G = gram_matrix(gen_feat)
        style_loss += torch.mean((G - A) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"At step : {step} Total loss: {total_loss}")

    if step % 200 == 0:
        print(f"Total loss: {total_loss}")
        save_image(generated, f"{step}_generated.png")

# The in-loop checkpoints stop at the last multiple of 200 below total_steps;
# also save the fully optimized result.
save_image(generated, f"{total_steps}_generated.png")

At step : 0 Total loss: 2450.09912109375
Total loss: 2450.09912109375
At step : 10 Total loss: 2433.60009765625
At step : 20 Total loss: 2418.9912109375
At step : 30 Total loss: 2403.1142578125
At step : 40 Total loss: 2389.49951171875
At step : 50 Total loss: 2375.083740234375
At step : 60 Total loss: 2361.882568359375
At step : 70 Total loss: 2348.4208984375
At step : 80 Total loss: 2335.579833984375
At step : 90 Total loss: 2323.47119140625
At step : 100 Total loss: 2311.444580078125
At step : 110 Total loss: 2299.803955078125
At step : 120 Total loss: 2288.2373046875
At step : 130 Total loss: 2276.32568359375
At step : 140 Total loss: 2265.840087890625
At step : 150 Total loss: 2254.78076171875
At step : 160 Total loss: 2243.88134765625
At step : 170 Total loss: 2233.658447265625
At step : 180 Total loss: 2223.55419921875
At step : 190 Total loss: 2212.70263671875
At step : 200 Total loss: 2201.451416015625
Total loss: 2201.451416015625
At step : 210 Total loss: 2192.608154296875
A

![Generated image saved at step 400](./400_generated.png)