# Fast Neural Style Transfer

## implementation of [Johnson et al.](https://cs.stanford.edu/people/jcjohns/eccv16/)

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import scipy.misc

In [8]:
# Compute device used by the notebook: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## define style and content loss

In [3]:
def gram(x):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    The input is flattened to a ``(batch * channel, height * width)`` matrix
    ``F``; the result is ``F @ F.T`` divided by the total number of elements
    in ``x``.

    Args:
        x: 4-D tensor (exactly four sizes are unpacked from ``x.size()``);
            presumably laid out as (batch, channel, height, width).

    Returns:
        Tensor of shape ``(batch * channel, batch * channel)``.
    """
    batch, channel, height, width = x.size()
    # flatten features; reshape (unlike view) also accepts non-contiguous
    # inputs such as transposed or channels_last tensors, copying only when
    # it has to
    features = x.reshape(batch * channel, height * width)
    # gram matrix
    G = torch.mm(features, features.t())
    # normalize by the element count so the scale does not grow with the
    # size of the feature map
    return G.div(batch * channel * height * width)

In [4]:
# override loss as a module
class StyleLoss(nn.Module):
    """Pass-through module that records a style loss for its input.

    The Gram matrix of the target features is computed once at construction.
    ``forward`` stores the MSE between that matrix and the Gram matrix of its
    input in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target):
        """Store the Gram matrix of ``target`` as a fixed reference.

        Args:
            target: feature tensor accepted by ``gram``.
        """
        super(StyleLoss, self).__init__()
        # detach: the target is a constant, not part of the graph being
        # optimized; keeping its graph would make every backward pass try to
        # propagate through it as well
        self.target = gram(target).detach()

    def forward(self, x):
        """Record the style loss of ``x`` in ``self.loss`` and return ``x``."""
        # (input, target) argument order, as F.mse_loss documents it
        self.loss = F.mse_loss(gram(x), self.target)
        return x

In [5]:
class ContentLoss(nn.Module):
    """Pass-through module that records a content loss for its input.

    ``forward`` stores the MSE between its input and the fixed target in
    ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target):
        """Store ``target`` as a fixed reference.

        Args:
            target: tensor the input of ``forward`` is compared against.
        """
        super(ContentLoss, self).__init__()
        # detach: the target is a constant, not part of the graph being
        # optimized; keeping its graph would make every backward pass try to
        # propagate through it as well
        self.target = target.detach()

    def forward(self, x):
        """Record the content loss of ``x`` in ``self.loss`` and return ``x``."""
        # (input, target) argument order, as F.mse_loss documents it
        self.loss = F.mse_loss(x, self.target)
        return x

## define loss network

In [14]:
# Side length, in pixels, of the square every loaded image is resized to.
# save_image also uses it to reshape tensors back to (3, imsize, imsize).
imsize = 256

# Preprocessing applied by image_loader: resize to imsize x imsize, then
# convert the image to a tensor.
loader = transforms.Compose([
             transforms.Resize((imsize, imsize)),
             transforms.ToTensor()
         ])

# Converts a tensor back to a PIL image; used by save_image.
unloader = transforms.ToPILImage()

def image_loader(image_name):
    """Load an image file and return it as a batched tensor.

    The image is opened with PIL, forced to RGB, passed through the
    module-level ``loader`` transform, and given a leading dimension of
    size 1.

    Args:
        image_name: path (or file object) accepted by ``Image.open``.

    Returns:
        The transformed image tensor with an extra dimension inserted at
        position 0.
    """
    # The context manager closes the underlying file once the pixels have
    # been read. convert("RGB") guarantees three channels, so grayscale,
    # palette and RGBA files do not yield tensors with 1 or 4 channels.
    with Image.open(image_name) as img:
        image = loader(img.convert("RGB"))
    return image.unsqueeze(0)
  
def save_image(input, path):
    """Write a single-image tensor to ``path`` as an image file.

    Args:
        input: tensor holding ``3 * imsize * imsize`` values; it is viewed as
            ``(3, imsize, imsize)`` before conversion.
        path: destination passed to PIL's ``Image.save``, which picks the
            file format from the extension.
    """
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    image = input.detach().cpu().clone()
    image = image.view(3, imsize, imsize)
    # Saturate to [0, 1] before the 8-bit conversion; an optimized image can
    # drift outside that range and would otherwise not map to valid pixels.
    image = image.clamp(0, 1)
    image = unloader(image)
    # scipy.misc.imsave has been removed from SciPy; unloader already returns
    # a PIL image, so save it directly.
    image.save(path)

In [19]:
class LossNet():
    
    def __init__(self, x, content, style, style_weight=100000, content_weight=1):
        """Store the images, loss-layer names, weights, network and optimizer.

        Args:
            x: accepted but not read by this constructor; the image being
                optimized is initialized from ``content`` instead.
            content: content image tensor.
            style: style image tensor.
            style_weight: weight stored for the style term.
            content_weight: weight stored for the content term.
        """
        # content image
        self.content = content
        # style image
        self.style = style
        # final output to be saved; it must be a leaf tensor that requires
        # grad, otherwise LBFGS receives no gradient for it and the image is
        # never updated
        self.x = self.content.detach().clone().requires_grad_(True)
        # where to insert loss; the content layer uses the same 'conv_N'
        # naming scheme as the style layers (it was spelled 'conv4')
        self.content_layer = ['conv_4']
        self.style_layer = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        # weight to update style and content, should be huge on style, little on content
        self.style_weight = style_weight
        self.content_weight = content_weight
        # pre-trained net
        self.vgg = models.vgg19(pretrained=True)
        self.loss = nn.MSELoss().to(device)
        # the optimizer updates the image itself, not the network weights
        self.optimizer = optim.LBFGS([self.x])
        
    def train():
        