In [1]:
import torch
import torch.nn as nn
from torch.nn.functional import mse_loss
import torch.optim as optim
from torchvision import transforms
import torchvision.models as models
from torchvision.models import VGG19_Weights

In [None]:
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, and
# finally the CPU, so the notebook still runs on machines with neither.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Tensors created after this call without an explicit `device=` argument
# are allocated on `device`.
torch.set_default_device(device)

### Content Loss

In [2]:
class ContentLoss(nn.Module):
    """Pass-through layer that records the distance of a feature map to a fixed target.

    Each forward pass stores the mean-squared error between the incoming
    tensor and ``target`` in ``self.loss``. It then returns the incoming
    tensor untouched, so the layer can sit inside a sequential network.

    Args:
        target: tensor to compare against. It is detached, so it acts as a
            constant and no gradient flows back into whatever produced it.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a constant, not part of the graph being optimised.
        self.target = target.detach()
        # Filled in by forward(); stays None until the first call.
        self.loss = None

    def forward(self, input):
        """Store MSE(input, target) in ``self.loss`` and return ``input`` unchanged."""
        distance = mse_loss(input, self.target)
        self.loss = distance
        return input

### Style Loss

In [None]:
class StyleLoss(nn.Module):
    """Pass-through layer that records the Gram-matrix (style) distance to a target.

    Each forward pass computes the Gram matrix of the incoming feature map.
    It stores the mean-squared error against the target's Gram matrix in
    ``self.loss`` and returns the incoming tensor unchanged.

    Args:
        target_feature: 4-D feature map ``(a, b, c, d)``. The dims are
            presumably (batch, channels, height, width); confirm against the
            caller. Its Gram matrix is computed once and detached, so it acts
            as a constant.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.target = self.gram_matrix(target_feature).detach()
        # Filled in by forward(); stays None until the first call.
        self.loss = None

    def gram_matrix(self, input):
        """Return the Gram matrix of a 4-D tensor, normalised by its element count.

        The first two dims are flattened into rows and the last two into
        columns, giving ``F`` of shape ``(a*b, c*d)``. The result is
        ``F @ F.T / (a*b*c*d)``, of shape ``(a*b, a*b)``.
        """
        a, b, c, d = input.size()
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed or channels_last feature maps. reshape returns a
        # zero-copy view whenever the memory layout allows it.
        features = input.reshape(a * b, c * d)
        gram = torch.mm(features, features.t())
        return gram.div(a * b * c * d)

    def forward(self, input):
        """Store MSE(gram(input), target gram) in ``self.loss`` and return ``input``."""
        self.loss = mse_loss(self.gram_matrix(input), self.target)
        return input

### Normalization

In [None]:
class Normalization(nn.Module):
    """Channel-wise image normalisation: ``(img - mean) / std``.

    The defaults are the ImageNet mean/std that torchvision's pretrained
    models are documented to expect.

    Args:
        mean: per-channel means. Defaults to ``(0.485, 0.456, 0.406)``.
        std: per-channel standard deviations. Defaults to
            ``(0.229, 0.224, 0.225)``.

    Both statistics are reshaped to ``(C, 1, 1)`` so they broadcast over the
    trailing two dims of a ``(C, H, W)`` or ``(B, C, H, W)`` input.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        # as_tensor + detach().clone() yields an independent tensor. It also
        # avoids the redundant copy (and UserWarning) of torch.tensor(<tensor>).
        mean_t = torch.as_tensor(mean, dtype=torch.get_default_dtype()).detach().clone()
        std_t = torch.as_tensor(std, dtype=torch.get_default_dtype()).detach().clone()
        # Buffers follow the module through .to()/.cuda()/.double().
        # persistent=False keeps them out of state_dict, as before.
        self.register_buffer("mean", mean_t.view(-1, 1, 1), persistent=False)
        self.register_buffer("std", std_t.view(-1, 1, 1), persistent=False)

    def forward(self, img):
        """Return ``(img - mean) / std``, broadcasting the stats per channel."""
        return (img - self.mean) / self.std

In [None]:
# Pretrained VGG19: keep only the convolutional `features` trunk and put it in
# evaluation mode. torchvision downloads the weights on first use.
vgg19 = models.vgg19(weights=VGG19_Weights.DEFAULT).features
vgg19.eval()

# Names of the conv layers where content / style losses are taken. These are
# presumably the conv_1, conv_2, ... labels assigned when the loss model is
# assembled; confirm against that code.
content_layers_default = ['conv_4']
style_layers_default = [f'conv_{i}' for i in range(1, 6)]