In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy
from tqdm import tqdm

In [25]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [26]:
# Default layers to extract style vectors from; names follow the conv_N scheme
# that get_style_rep assigns while walking the VGG feature stack.
style_layers_default = ['conv_{}'.format(k) for k in range(1, 7)]

In [27]:
def gram_matrix(input):
    """Per-sample Gram matrix of a feature map, flattened to one row per sample.

    Args:
        input: feature map of shape (a, b, c, d), where a is the batch size,
            b the number of feature maps (channels) and (c, d) the spatial
            dimensions of each map.

    Returns:
        Tensor of shape (a, b*b). Row k is F_k @ F_k^T / (b*c*d), where F_k is
        sample k's feature map reshaped to (b, c*d).
    """
    a, b, c, d = input.size()

    # reshape (not view) so non-contiguous feature maps are accepted too
    features = input.reshape(a, b, c * d)

    # Batched product keeps samples separate. A flat (a*b, c*d) @ (c*d, a*b)
    # product would mix channels of different samples and produce a*b*b values
    # per row instead of b*b.
    G = torch.bmm(features, features.transpose(1, 2))

    # Normalize by the number of elements per sample (equals the former
    # a*b*c*d divisor when a == 1).
    return G.div(b * c * d).reshape(a, -1)

In [28]:
class StylePasser(nn.Module):
    """Adapt a single-input layer to the ``(styles, input)`` stage protocol.

    The wrapped layer transforms the feature tensor. The running ``styles``
    list is handed back untouched, so StyleRep stages further down can keep
    appending to it.

    Args:
        layer: any module that maps one tensor to one tensor.
    """

    def __init__(self, layer):
        super().__init__()
        # attribute name kept as ``main``: it determines the state_dict keys
        self.main = layer

    def forward(self, styles, input):
        transformed = self.main(input)
        return styles, transformed

In [29]:
class StyleRep(nn.Module):
    """Project a feature map's flattened Gram matrix to a 64-dim style vector.

    The vector is appended to ``styles`` (the list is mutated in place). The
    feature map is passed through unchanged, matching StylePasser's
    ``(styles, input) -> (styles, input)`` signature.

    Args:
        input_size: length of the per-sample flattened Gram matrix, i.e.
            channels * channels of the layer this stage follows.
    """

    def __init__(self, input_size):
        super().__init__()
        self.main = nn.Linear(input_size, 64)

    def forward(self, styles, input):
        gram = gram_matrix(input)
        styles.append(self.main(gram))
        return styles, input

In [30]:
class Normalization(nn.Module):
    """Normalize an image batch with per-channel mean and std.

    Packaging the normalization as a module lets it sit at the head of an
    nn.Sequential.

    Args:
        mean: per-channel means (tensor or sequence of floats).
        std: per-channel standard deviations, same length as ``mean``.
    """

    def __init__(self, mean, std):
        super().__init__()
        # .view to [C x 1 x 1] so the statistics broadcast against image
        # tensors of shape [B x C x H x W].
        # Registered as buffers so .to(device) / .double() move them together
        # with the module (plain attributes are skipped by Module.to, causing
        # a device mismatch on CUDA). persistent=False keeps them out of
        # state_dict, so saved state_dict keys are unchanged.
        self.register_buffer(
            "mean",
            torch.as_tensor(mean).clone().detach().view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.as_tensor(std).clone().detach().view(-1, 1, 1),
            persistent=False,
        )

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std

In [37]:
def get_style_rep(style_layers=style_layers_default):
    """Build a truncated pretrained VGG-11 feature stack with StyleRep heads.

    Each layer of ``models.vgg11(pretrained=True).features`` is renamed
    (``conv_N``, ``relu_N``, ``pool_N``, ``bn_N``, where N counts convs) and
    wrapped in a StylePasser. A StyleRep head is inserted after every layer
    whose name appears in ``style_layers``. Max pooling is replaced by 2x2
    average pooling and ReLUs are made out-of-place. Layers after the last
    StyleRep are dropped.

    Args:
        style_layers: names of the layers after which to insert a StyleRep.

    Returns:
        nn.Sequential whose first module is a Normalization (one tensor in,
        one tensor out), followed by StylePasser / StyleRep stages that take
        and return ``(styles, input)``.

    Raises:
        RuntimeError: if the VGG feature stack contains an unexpected layer type.
    """
    cnn = models.vgg11(pretrained=True).features.eval()
    # presumably the ImageNet statistics the pretrained weights expect
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

    # normalization module
    normalization = Normalization(cnn_normalization_mean, cnn_normalization_std).to(device)

    # Normalization consumes a bare tensor; every stage added below uses the
    # (styles, input) -> (styles, input) protocol.
    model = nn.Sequential(normalization)

    i = 0  # incremented at every conv; relu/pool/bn layers reuse the current index
    channels = None  # out_channels of the most recent conv (relu/pool/bn keep that width)
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
            channels = layer.out_channels
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # An in-place ReLU would overwrite the conv output that a preceding
            # StyleRep used for its Gram matrix, which autograd needs for backward.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
            # Replace max pooling with 2x2 average pooling
            layer = nn.AvgPool2d(2)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, StylePasser(layer))

        if name in style_layers:
            # The Gram matrix of a C-channel map flattens to C*C values.
            style_rep = StyleRep(channels * channels)
            rep_name = 'style_rep_{}'.format(i)
            if rep_name in dict(model.named_children()):
                # e.g. both conv_2 and relu_2 selected: add_module would silently
                # replace the earlier head, so disambiguate by layer name.
                rep_name = 'style_rep_{}'.format(name)
            model.add_module(rep_name, style_rep)

    # Drop everything after the last StyleRep (keep only Normalization if none).
    last_rep = max(
        (idx for idx, stage in enumerate(model) if isinstance(stage, StyleRep)),
        default=0,
    )
    return model[:last_rep + 1]

In [38]:
class StyleModel(nn.Module):
    """Classify an image into 10 classes from its multi-layer style vectors.

    ``self.main`` (built by get_style_rep) is expected to produce one 64-dim
    style vector per entry of ``style_layers_default``. These are concatenated
    and fed to a linear layer followed by a softmax.

    Args:
        style_layers_default: names of the feature layers to take style
            vectors from (e.g. ``['conv_1', 'conv_2']``).
    """

    def __init__(self, style_layers_default):
        super().__init__()
        self.main = get_style_rep(style_layers_default)

        self.classifier = nn.Sequential(
            nn.Linear(64 * len(style_layers_default), 10),
            # dim=1 is what the implicit-dim Softmax picks for 2-D input;
            # stating it explicitly avoids the deprecation warning.
            # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself --
            # confirm an extra Softmax here is intended for training.
            nn.Softmax(dim=1),
        )

    def forward(self, input):
        """Return class probabilities of shape (batch, 10)."""
        # nn.Sequential.forward accepts a single argument and Normalization
        # takes a bare tensor, so the two-argument (styles, input) stages are
        # driven explicitly rather than by calling self.main directly.
        styles = []
        features = input
        for stage in self.main:
            if isinstance(stage, (StylePasser, StyleRep)):
                styles, features = stage(styles, features)
            else:
                features = stage(features)

        return self.classifier(torch.cat(styles, dim=1))
        
    
    

In [39]:
# Build the style classifier on the default conv layers (downloads VGG-11 weights).
model = StyleModel(style_layers_default=style_layers_default)

In [40]:
# Display the module hierarchy as this cell's output.
model

StyleModel(
  (main): Sequential(
    (0): Normalization()
    (conv_1): StylePasser(
      (main): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (style_rep_1): StyleRep(
      (main): Linear(in_features=4096, out_features=64, bias=True)
    )
    (relu_1): StylePasser(
      (main): ReLU()
    )
    (pool_1): StylePasser(
      (main): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (conv_2): StylePasser(
      (main): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (style_rep_2): StyleRep(
      (main): Linear(in_features=16384, out_features=64, bias=True)
    )
    (relu_2): StylePasser(
      (main): ReLU()
    )
    (pool_2): StylePasser(
      (main): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (conv_3): StylePasser(
      (main): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (style_rep_3): StyleRep(
      (main): Linear(in_features=65536, out_features=64, bias=True)
    