# Fast Neural Style Transfer

## Implementation of [Johnson et al.](https://cs.stanford.edu/people/jcjohns/eccv16/) (ECCV 2016)

## Define the style loss

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
def gram(x):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    The input is unpacked as ``(batch, channel, height, width)``. Every
    (batch, channel) pair is flattened into one row vector of length
    ``height * width``. The pairwise inner products of those rows are
    returned, divided by the total number of elements in ``x``.

    Args:
        x: 4-D tensor of feature maps (assumed NCHW layout — TODO confirm
            against the feature extractor feeding this).

    Returns:
        Tensor of shape ``(batch * channel, batch * channel)``.

    Note:
        Batch and channel axes are flattened together, so for batch > 1
        the result also contains cross-sample channel correlations. With a
        batch size of 1 this is the usual per-image Gram matrix.
    """
    batch, channel, height, width = x.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g.
    # permuted/transposed feature maps, copying only when a view is impossible
    features = x.reshape(batch * channel, height * width)
    # inner products between every pair of flattened feature rows
    G = torch.mm(features, features.t())
    # normalize by the element count so the scale does not grow with the
    # size of the feature maps
    return G.div(batch * channel * height * width)

In [6]:
# Style-loss layer: a pass-through nn.Module that records its loss.
class Gram(nn.Module):
    """Style-loss layer comparing Gram matrices against a fixed target.

    The module is transparent: ``forward`` returns its input unchanged. As
    a side effect it stores, in ``self.loss``, the MSE between the input's
    Gram matrix and the target's Gram matrix (both computed with ``gram``).

    Args:
        target: feature tensor whose Gram matrix becomes the fixed style
            target.
    """

    def __init__(self, target):
        super().__init__()
        # detach so the target is a constant. Otherwise gradients of
        # self.loss would flow back into the graph that produced the target
        # features, and a second backward pass would fail.
        self.target = gram(target).detach()

    def forward(self, x):
        # Record the style loss for the caller to collect. (input, target)
        # order follows the F.mse_loss convention.
        self.loss = F.mse_loss(gram(x), self.target)
        # Return x untouched so inserting this layer leaves activations
        # unchanged.
        return x