In [3]:
from torchvision import transforms
from torchvision import models
from PIL import Image

import matplotlib.pyplot as plt

import torch.nn.functional as F 
import torch.optim as optim
import torch.nn as nn
import torch

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines or plain CPUs.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# VGG19 convolutional trunk (`.features`) with ImageNet weights, in eval mode.
# `weights=` replaces the deprecated `pretrained=True` and loads the same
# IMAGENET1K_V1 checkpoint. The extractor is frozen: style transfer only
# optimizes the input image, so gradients for these parameters are wasted work.
cnn = (
    models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
    .features.to(device)
    .eval()
    .requires_grad_(False)
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/bahk_insung/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [4]:
class ContentLoss(nn.Module):
    """Transparent layer that records the content loss of its input.

    On every forward pass the module stores the MSE between the incoming
    activations and a fixed target in ``self.loss``. It then returns the
    activations unchanged, so it can sit inside an ``nn.Sequential``.

    Args:
        target: reference activations (same shape as the inputs this module
            will see). A detached copy is stored so the loss treats it as a
            constant rather than as part of an autograd graph.
    """

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach so gradients never flow into (and repeated backward passes never
        # re-traverse) the graph that produced the target.
        self.target = target.detach()

    def forward(self, x):
        self.loss = F.mse_loss(x, self.target)
        # Pass the input through untouched; returning anything else breaks the
        # following layers in the Sequential chain.
        return x

In [5]:
def gram_matrix(x):
    """Return the Gram matrix of ``x``, normalized by its element count.

    Args:
        x: 4-D tensor of size (a, b, c, d); the unpacking below requires rank 4.
            Presumably (batch, channels, height, width) - TODO confirm with callers.

    Returns:
        An (a*b, a*b) tensor. Entry (i, j) is the inner product of flattened
        maps i and j, divided by a*b*c*d.
    """
    a, b, c, d = x.size()
    # reshape (unlike view) also accepts non-contiguous inputs.
    features = x.reshape(a * b, c * d)
    gram = torch.mm(features, features.t())
    # Tensor.div takes a single divisor (rounding_mode is keyword-only), so the
    # full element count must be passed as one product.
    return gram.div(a * b * c * d)

In [7]:
class StyleLoss(nn.Module):
    """Pass-through layer that records a Gram-matrix style loss.

    Args:
        target_features: reference activations. Their Gram matrix (from
            ``gram_matrix``) is computed once at construction and detached.

    After each forward call, ``self.loss`` holds the MSE between the Gram matrix
    of the current input and the stored target. The input is returned as-is.
    """

    def __init__(self, target_features):
        super(StyleLoss, self).__init__()
        reference_gram = gram_matrix(target_features)
        self.target = reference_gram.detach()

    def forward(self, x):
        current_gram = gram_matrix(x)
        self.loss = F.mse_loss(current_gram, self.target)
        return x

class Normalization(nn.Module):
    """Channel-wise ``(img - mean) / std`` normalization as an ``nn.Module``.

    Args:
        mean: per-channel means, as a tensor or any sequence accepted by
            ``torch.as_tensor``.
        std: per-channel standard deviations, in the same form as ``mean``.

    Both are reshaped to (C, 1, 1). They broadcast against the last three
    dimensions of the image, e.g. an (N, C, H, W) batch.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Register as buffers (not plain attributes) so that `.to(device)` on
        # the module moves them too. They are non-persistent, so state_dict is
        # unchanged.
        self.register_buffer("mean", torch.as_tensor(mean).view(-1, 1, 1), persistent=False)
        self.register_buffer("std", torch.as_tensor(std).view(-1, 1, 1), persistent=False)

    def forward(self, img):
        return (img - self.mean) / self.std

In [8]:
def getStyleModelAndLosses(cnn, styleImg, contentImg,
                           content_layers=('conv_4',),
                           style_layers=('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5')):
    """Rebuild ``cnn`` as an nn.Sequential with content/style loss probes inserted.

    Children of ``cnn`` are copied in order and named ``conv_k`` / ``relu_k`` /
    ``maxpool_k`` / ``bn_k``, where k is the number of convolutions seen so far.
    - A ``ContentLoss`` is inserted after every layer named in ``content_layers``.
    - A ``StyleLoss`` is inserted after every layer named in ``style_layers``.
    - Their targets are the activations of ``contentImg`` / ``styleImg`` at that
      point.
    - Layers after the last loss module are dropped, since they cannot affect
      any loss.

    Args:
        cnn: module whose children are Conv2d / ReLU / MaxPool2d / BatchNorm2d
            layers (e.g. ``vgg19().features``). It must live on
            ``contentImg``'s device.
        styleImg: style image tensor. Presumably (1, 3, H, W) - TODO confirm.
        contentImg: content image tensor. Its device is used for the
            normalization constants.
        content_layers: layer names after which a ContentLoss is inserted.
        style_layers: layer names after which a StyleLoss is inserted.

    Returns:
        (model, styleLoss, contentLoss): the truncated nn.Sequential, and the
        lists of StyleLoss and ContentLoss modules inside it, in insertion order.

    Raises:
        RuntimeError: if ``cnn`` contains a child of any other layer type.
    """
    target_device = contentImg.device
    # ImageNet per-channel statistics used by torchvision's pretrained VGG weights.
    normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=target_device)
    normalization_std = torch.tensor([0.229, 0.224, 0.225], device=target_device)
    normalization = Normalization(normalization_mean, normalization_std).to(target_device)
    contentLoss, styleLoss = [], []

    model = nn.Sequential(normalization)
    layerIndex = 0

    for layer in cnn.children():
        # Name each supported layer after the index of the most recent conv.
        if isinstance(layer, nn.Conv2d):
            layerIndex += 1
            name = 'conv_{}'.format(layerIndex)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(layerIndex)
            # Out-of-place ReLU: an in-place one would overwrite the tensor that a
            # preceding loss module just passed through and used for its loss.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'maxpool_{}'.format(layerIndex)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(layerIndex)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are fixed references, so they are detached from the graph.
            target = model(contentImg).detach()
            contentLossItem = ContentLoss(target)
            model.add_module("content_loss_{}".format(layerIndex), contentLossItem)
            contentLoss.append(contentLossItem)

        if name in style_layers:
            targetFeature = model(styleImg).detach()
            styleLossItem = StyleLoss(targetFeature)
            model.add_module("style_loss_{}".format(layerIndex), styleLossItem)
            styleLoss.append(styleLossItem)

    # Once every layer has been added, trim everything after the last loss
    # module. If there is none, only the normalization layer (index 0) is kept.
    lastLossIndex = 0
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            lastLossIndex = i
            break

    model = model[:(lastLossIndex + 1)]
    return model, styleLoss, contentLoss


In [None]:
def runStyleTransfer(cnn, contentImg, styleImg, num_steps=300, style_weight=100000, content_weight=1):
    inputImg = contentImg.clone().detach().require_grad_(True)
    model, styleLosses, contentLosses = getStyleModelAndLosses(cnn, styleImg, contentImg)
    optimizer = optim.LBFGS()