In [2]:
import torch

import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import VGG19_Weights
from torchvision.utils import save_image
import time


class VGGFeatures(nn.Module):
    """Convolutional feature extractor taken from an ImageNet-pretrained VGG19.

    Wraps ``models.vgg19(weights=VGG19_Weights.DEFAULT).features[:37]``,
    i.e. the first 37 modules of VGG19's ``features`` stack. For a 224x224
    RGB input the output is expected to be ``(N, 512, 7, 7)``.
    """

    def __init__(self):
        super().__init__()
        self.model = models.vgg19(weights=VGG19_Weights.DEFAULT).features[:37]

    def forward(self, x):
        """Return the VGG19 convolutional feature map for the image batch ``x``."""
        # Invoke the submodule (not ``self.model.forward``) so that forward
        # hooks registered on ``self.model`` are run as usual.
        return self.model(x)


class ContentToViews(nn.Module):
    """VGG19 features followed by a small trainable conv + fully-connected head.

    Takes an image batch in the same format as the VGG input and returns a
    ``(N, 512)`` tensor. The head assumes the VGG feature map is 7x7, which
    corresponds to ~224x224 input images.
    NOTE(review): confirm whether a scalar output was intended; the head
    currently ends in ``Linear(1024, 512)``.
    """

    def __init__(self):
        super().__init__()
        self.vgg = VGGFeatures()  # expected output shape: (N, 512, 7, 7)
        self.other_layers = self._build_head()

    @staticmethod
    def _build_head() -> nn.Sequential:
        """Build the head that maps a (N, 512, 7, 7) feature map to (N, 512)."""
        return nn.Sequential(
            # padding=1 stops the 3x3 convs from shrinking the already small
            # 7x7 map; without it the map drops below the second conv's kernel.
            nn.Conv2d(512, 256, 3, padding=1),  # 256 out, matching the next conv
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7 -> 3x3
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 3x3 -> 1x1
            nn.Flatten(),  # (N, 256, 1, 1) -> (N, 256)
            nn.Linear(256 * 1 * 1, 4096),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(4096, 1024),
            nn.Dropout(0.4),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )

    def forward(self, input):
        """Map an image batch to a (N, 512) tensor.

        The input should be the same as the VGG input. The 4-D feature map is
        passed to the conv head unflattened; flattening happens inside
        ``other_layers`` after the convolutions.
        """
        out = self.vgg(input)
        return self.other_layers(out)