In [1]:
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image



In [2]:
# Take the `.features` sub-module of torchvision's VGG-19, requested with
# pretrained weights. The VGG wrapper class below reads this global when it is
# instantiated, and the name `model` is later re-bound to that wrapper instance
# further down in this cell.
# NOTE(review): newer torchvision releases reportedly deprecate `pretrained=`
# in favour of `weights=` — confirm against the installed version before changing.
model = models.vgg19(pretrained=True).features
class VGG(nn.Module):
    """Feature extractor that returns intermediate activations of a layer stack.

    The wrapped network is run layer by layer and the output of every layer
    whose name appears in ``self.select`` is collected.

    Args:
        backbone: Container of layers to run in order (anything exposing an
            ordered ``_modules`` mapping, e.g. ``nn.Sequential``). When omitted,
            the module-level ``model`` (the pretrained VGG-19 feature stack) is
            used, which is the original behaviour of ``VGG()``.
        select: Names of the layers whose outputs are returned. Defaults to
            ``'0', '5', '10', '19', '28'``.

    Note:
        The backbone's parameters are frozen (``requires_grad=False``) in
        place. This notebook only optimises the generated image, so gradients
        with respect to the network weights are never used; leaving them
        trainable makes every ``backward()`` compute and accumulate them.
    """

    def __init__(self, backbone=None, select=('0', '5', '10', '19', '28')):
        super().__init__()
        self.select = list(select)
        # Fall back to the module-level backbone to stay compatible with `VGG()`.
        self.vgg = model if backbone is None else backbone

        # Freeze the weights: only the input image is being optimised.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through every layer and return the selected activations.

        Args:
            x: Input tensor accepted by the first layer of the backbone.

        Returns:
            list: One tensor per selected layer, in network order.
        """
        features = []
        # NOTE(review): the loop intentionally runs past the last selected
        # layer. torchvision's VGG is understood to use in-place ReLUs, which
        # overwrite a captured conv output after it has been appended;
        # stopping early would therefore change the last returned feature.
        # Confirm before truncating the backbone.
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

def load_image(image_name):
    """Load an image file as a batched tensor on the active device.

    The image is converted to 3-channel RGB before preprocessing so that
    grayscale, palette or RGBA files still match the 3-channel input the
    network expects. The file is opened in a context manager so the
    underlying handle is released once the pixel data has been read.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The output of the module-level ``loader`` transform with a leading
        batch dimension of size 1 added, moved to the module-level ``device``.
    """
    with Image.open(image_name) as img:
        # convert() forces the pixel data to be read while the file is open.
        image = loader(img.convert("RGB")).unsqueeze(0)
    return image.to(device)

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length in pixels; images are resized to image_size x image_size.
image_size = 356

# Build the feature extractor, move it to the chosen device and switch it to
# evaluation mode. This re-binds `model` from the raw backbone to the wrapper.
model = (
    VGG()
    .to(device)
    .eval()
)



In [6]:
# Preprocessing pipeline used by load_image(): resize to a square of
# image_size pixels, then convert the PIL image to a tensor.
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # Normalization is currently disabled:
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Content image and style image.
original_img = load_image("images/fotoNatyyyo.jpg")
style_img = load_image("images/jardin-giverny-monet.jpg")

# The image being optimised starts as a copy of the content image; it is the
# only tensor here that tracks gradients.
generated = original_img.clone()
generated.requires_grad_(True)


In [7]:
# --- Hyperparameters ---------------------------------------------------------
total_steps = 6000      # number of optimisation iterations
learning_rate = 1e-3    # Adam step size
alpha = 1               # weight of the content (original image) loss term
beta = 1e-2             # weight of the style loss term

# Adam updates the pixels of `generated` directly; it is the only parameter.
optimizer = torch.optim.Adam(params=[generated], lr=learning_rate)


In [10]:
# The content and style targets depend only on the fixed input images and the
# network, neither of which changes during optimisation. Compute them once
# here instead of on every step; no_grad() stops autograd from recording a
# graph for these constants.
with torch.no_grad():
    original_img_features = model(original_img)
    style_features = model(style_img)

    # Gram matrix (channel x channel) of each selected style activation.
    style_grams = []
    for style_feature in style_features:
        _, channel, height, width = style_feature.shape
        style_flat = style_feature.view(channel, height * width)
        style_grams.append(style_flat.mm(style_flat.t()))

for step in range(total_steps):
    # Only the generated image changes between steps, so it is the only
    # forward pass needed inside the loop.
    generated_features = model(generated)

    style_loss = original_loss = 0

    for gen_feature, orig_feature, A in zip(generated_features, original_img_features, style_grams):
        # view(channel, height * width) below is only valid for batch size 1.
        batch_size, channel, height, width = gen_feature.shape

        # Content term: mean squared difference between activations.
        original_loss += torch.mean((gen_feature - orig_feature)**2)

        # Style term: mean squared difference between Gram matrices.
        gen_flat = gen_feature.view(channel, height * width)
        G = gen_flat.mm(gen_flat.t())
        style_loss += torch.mean((G - A)**2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Every 5th step: report the loss and save the current image.
    if step % 5 == 0:
        print(total_loss)
        save_image(generated, "images/output"+str(step)+".jpg")
    print(step)

tensor(14474.7402, grad_fn=<AddBackward0>)
0
1
2
3
4
tensor(14030.7773, grad_fn=<AddBackward0>)
5
6
7
8
9
tensor(13617.7148, grad_fn=<AddBackward0>)
10
11
12
13
14
tensor(13231.9473, grad_fn=<AddBackward0>)
15
16
17
18
19
tensor(12870.2568, grad_fn=<AddBackward0>)
20
21


KeyboardInterrupt: 