In [2]:
import time
import os

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import torchvision
from torchvision import transforms

from PIL import Image
from collections import OrderedDict

# VGG-19 architecture (convolutional layers only, average pooling)

In [3]:
class VGG(nn.Module):
  """VGG-19 convolutional feature extractor (no fully-connected classifier).

  The 16 conv layers are grouped into five blocks of 2, 2, 4, 4 and 4 layers.
  Every conv is followed by a ReLU, and every block ends with a 2x2 average
  pooling layer. state_dict keys are derived from the attribute names below
  (conv1_1 ... conv5_4), so renaming them breaks loading saved weights.
  """

  def __init__(self):
    # nn.Module.__init__ must run before any submodule is assigned as an
    # attribute; otherwise PyTorch raises AttributeError on the first one.
    super().__init__()

    # Block 1
    self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

    # Block 2
    self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

    # Block 3
    self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
    self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

    # Block 4
    self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
    self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

    # Block 5
    self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
    self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

    # Average pooling instead of VGG's usual max pooling (the style-transfer
    # paper reports better results with it).
    self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
    self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
    self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
    self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
    self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)

  def forward(self, x, output_params):
    """Run `x` through every layer and return the requested activations.

    Args:
      x: input batch; conv1_1 expects 3 input channels, so presumably
        (batch, 3, H, W). H and W must be at least 32 so that the five 2x2
        poolings still leave a non-empty feature map.
      output_params: iterable of layer keys such as 'conv_layer4_2' or
        'pooling_layer3' (keys are 'conv_layer{block}_{i}' and
        'pooling_layer{block}').

    Returns:
      A list of activation tensors, in the same order as `output_params`.

    Raises:
      KeyError: if a name in `output_params` is not one of the layer keys.
    """
    # Activations (not weights) of every layer, keyed by layer name.
    activations = {}

    # Block 1
    activations['conv_layer1_1']  = F.relu(self.conv1_1(x))
    activations['conv_layer1_2']  = F.relu(self.conv1_2(activations['conv_layer1_1']))
    activations['pooling_layer1'] = self.pool1(activations['conv_layer1_2'])

    # Block 2
    activations['conv_layer2_1']  = F.relu(self.conv2_1(activations['pooling_layer1']))
    activations['conv_layer2_2']  = F.relu(self.conv2_2(activations['conv_layer2_1']))
    activations['pooling_layer2'] = self.pool2(activations['conv_layer2_2'])

    # Block 3
    activations['conv_layer3_1']  = F.relu(self.conv3_1(activations['pooling_layer2']))
    activations['conv_layer3_2']  = F.relu(self.conv3_2(activations['conv_layer3_1']))
    activations['conv_layer3_3']  = F.relu(self.conv3_3(activations['conv_layer3_2']))
    activations['conv_layer3_4']  = F.relu(self.conv3_4(activations['conv_layer3_3']))
    activations['pooling_layer3'] = self.pool3(activations['conv_layer3_4'])

    # Block 4
    activations['conv_layer4_1']  = F.relu(self.conv4_1(activations['pooling_layer3']))
    activations['conv_layer4_2']  = F.relu(self.conv4_2(activations['conv_layer4_1']))
    activations['conv_layer4_3']  = F.relu(self.conv4_3(activations['conv_layer4_2']))
    activations['conv_layer4_4']  = F.relu(self.conv4_4(activations['conv_layer4_3']))
    activations['pooling_layer4'] = self.pool4(activations['conv_layer4_4'])

    # Block 5 -- uses its own conv5_* layers
    activations['conv_layer5_1']  = F.relu(self.conv5_1(activations['pooling_layer4']))
    activations['conv_layer5_2']  = F.relu(self.conv5_2(activations['conv_layer5_1']))
    activations['conv_layer5_3']  = F.relu(self.conv5_3(activations['conv_layer5_2']))
    activations['conv_layer5_4']  = F.relu(self.conv5_4(activations['conv_layer5_3']))
    activations['pooling_layer5'] = self.pool5(activations['conv_layer5_4'])

    return [activations[param] for param in output_params]

# Define Gram matrix and loss

In [4]:
class GramMatrix(nn.Module):
  """Batched Gram matrix of a feature map, normalised by its spatial size.

  For an input of shape (batch, channel, height, width), returns a tensor of
  shape (batch, channel, channel). Entry [b, i, j] is the inner product of
  channels i and j of sample b over all spatial positions, divided by
  height * width.
  """
  def forward(self, input):
    batch, channel, height, width = input.size()
    # Flatten the spatial dims: (batch, channel, height*width). reshape (unlike
    # view) also accepts non-contiguous inputs.
    features = input.reshape(batch, channel, height * width)
    # bmm multiplies only the last two dims per batch element:
    # (b, c, hw) x (b, hw, c) -> (b, c, c)
    gram = torch.bmm(features, features.transpose(1, 2))
    return gram / (height * width)


#compute the style loss
class GramMatrixStyleLoss(nn.Module):
  """Style loss between a feature map and a target Gram matrix.

  forward(input, target) builds the Gram matrix of `input` with GramMatrix and
  returns the mean squared error (nn.MSELoss) between it and `target`.
  `target` is presumably the precomputed Gram matrix of the style image's
  features at the same layer -- TODO confirm at the call site.
  """
  def forward(self, input, target):
    mean_squared_error = nn.MSELoss()
    input_gram = GramMatrix()(input)
    return mean_squared_error(input_gram, target)