In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as tt
import torchvision.models as models
from torchvision.utils import save_image

In [10]:
# Inspect VGG19's convolutional trunk to pick the layer indices used by the feature
# extractor defined below. The module repr shows only the architecture, never the
# weights, so no pretrained checkpoint is loaded for this. A distinct name is used so
# this throwaway module can't be mistaken for the `model` that is trained later.
vgg19_layers = models.vgg19().features
vgg19_layers

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [3]:
class VGG(nn.Module):
    """Fixed (non-trainable) VGG19 feature extractor for neural style transfer.

    ``forward`` returns a list with one activation tensor per index in
    ``chosen_features``, ordered from the shallowest to the deepest layer.
    """

    def __init__(self):
        super().__init__()
        # Indices (as strings, compared against str(layer_num)) into torchvision's
        # vgg19().features at which activations are recorded. 0, 5 and 10 are Conv2d
        # layers per the printed trunk; 19 and 28 are presumably the first convs of
        # blocks 4 and 5 -- verify against the model if changing the slice below.
        self.chosen_features = {'0', '5', '10', '19', '28'}
        # Keep layers 0..28 only: nothing after the deepest chosen index is needed.
        self.model = models.vgg19(pretrained=True).features[:29]
        # Only the generated image is optimized in this notebook. Freezing the weights
        # stops every backward pass from also computing and accumulating gradients for
        # all VGG parameters, which are never read and never zeroed.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated VGG19 and collect the chosen activations.

        Args:
            x: input image batch; presumably (N, 3, H, W) as built by ``load_image``.

        Returns:
            list of tensors, one per chosen layer index, in layer order.

        Note: VGG19's ReLUs run in place (``ReLU(inplace=True)``), so a recorded
        conv output that is followed by a ReLU inside the slice is overwritten with
        its activated values by the time this returns. Index 28 is the slice's last
        layer, so that activation stays pre-ReLU.
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features

In [4]:
def load_image(path):
    """Load an image file as a batched, 3-channel tensor on ``DEVICE``.

    Uses the module-level ``transform`` and ``DEVICE`` (looked up at call time, so
    they only need to exist before the first call).

    Args:
        path: path to an image file readable by PIL.

    Returns:
        ``transform(image)`` with a leading batch dimension of size 1 added,
        moved to ``DEVICE``.
    """
    # Image.open reads lazily and holds the file handle open; the context manager
    # closes it, and convert() forces the pixel data to load before that happens.
    # Converting to RGB also turns grayscale / palette / RGBA / CMYK inputs into the
    # 3 channels VGG's first Conv2d(3, 64, ...) expects.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    image = transform(rgb).unsqueeze(0)
    return image.to(DEVICE)

In [5]:
    # Use the GPU only when one is actually present. is_available must be *called*:
    # the bare function object is always truthy, which would select 'cuda' even on
    # CPU-only machines and crash at the first .to(DEVICE).
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    IMAGE_SIZE = 512        # every input is resized to IMAGE_SIZE x IMAGE_SIZE
    TOTAL_STEPS = 300       # number of optimization steps on the generated image
    LEARNING_RATE = 1e-3    # Adam learning rate for the generated image's pixels
    ALPHA = 1               # weight of the content loss
    BETA = 1e-2             # weight of the style loss

In [6]:
# Preprocessing shared by every input image: resize to a fixed IMAGE_SIZE x IMAGE_SIZE
# square so the content, style and generated images produce matching feature-map
# shapes, then convert the PIL image to a float tensor.
# ImageNet mean/std normalization is deliberately not applied. The generated image is
# written with save_image, which maps [0, 1] floats to 8-bit pixels, so adding
# tt.Normalize here would also require undoing it before saving.
transform = tt.Compose(
    [
        tt.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        tt.ToTensor(),
    ]
)

In [7]:
# Content and style inputs, each with a leading batch dimension of 1 (see load_image).
original_img = load_image('corgi.jpg')
style_img = load_image('style_1.jpg')

# Feature extractor, moved to DEVICE and switched to eval mode.
model = VGG().to(DEVICE).eval()

# The image being optimized starts as a copy of the content image and is marked
# as requiring gradients so its pixels can be updated directly.
generated_img = original_img.clone()
generated_img.requires_grad_(True)

In [8]:
# Adam updates the pixels of generated_img; no network weights are handed to it.
optimizer = optim.Adam(params=[generated_img], lr=LEARNING_RATE)

In [10]:
# Content/style targets depend only on the fixed input images, so compute them once,
# before the loop and without building an autograd graph, instead of running all three
# images through VGG on every step (three forward passes per step instead of one).
with torch.no_grad():
    original_features = model(original_img)
    style_features = model(style_img)

SAVE_EVERY = 50  # log the loss and write generated.png every this many steps


def gram_matrix(feature):
    """Return the (channels x channels) Gram matrix of a single-image feature map.

    Assumes a batch of one image, i.e. ``feature`` of shape (1, C, H, W) as produced
    via ``load_image``'s ``unsqueeze(0)``; the ``view`` raises for larger batches.
    The result is not normalized by C * H * W.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


for step in range(TOTAL_STEPS):
    generated_features = model(generated_img)

    style_loss = original_loss = 0
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_features, style_features
    ):
        # Content term: mean squared error between feature maps.
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Style term: mean squared error between Gram matrices.
        G = gram_matrix(gen_feature)
        A = gram_matrix(style_feature)
        style_loss += torch.mean((G - A) ** 2)

    total_loss = ALPHA * original_loss + BETA * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Also checkpoint on the final step: with e.g. TOTAL_STEPS = 300 the periodic
    # save alone last fires at step 250, discarding the remaining updates.
    if step % SAVE_EVERY == 0 or step == TOTAL_STEPS - 1:
        print("Total loss: {loss:9.3f}".format(loss=total_loss.item()))
        save_image(generated_img, 'generated.png')

Total loss: 200268.531
Total loss: 171947.953
Total loss: 150253.484
Total loss: 133325.844
Total loss: 120067.844
Total loss: 109597.211
