In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
def gram_matrix(features):
    """Channel-wise Gram matrix of a batch of feature maps, normalized by c*h*w.

    Args:
        features: tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c). Entry [n, i, j] is the inner product of
        channels i and j of sample n over all h*w spatial positions, divided
        by c * h * w.
    """
    b, c, h, w = features.size()
    # reshape (not view): also works when the feature map is non-contiguous,
    # e.g. after a transpose/permute; it only copies when it has to.
    flat = features.reshape(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (c * h * w)
# Pre-trained VGG19, keeping only its convolutional trunk, in eval mode on `device`.
# NOTE(review): nothing in the cells shown reads `vgg` -- VGGFeatures loads its
# own copy of the same weights -- so this extra model may be removable.
vgg = (
    models.vgg19(weights='DEFAULT')
    .features
    .to(device)
    .eval()
)

Using device: cpu


In [2]:
class VGGFeatures(nn.Module):
    """Run VGG19's convolutional trunk and collect activations at chosen layers.

    Args:
        layers: iterable of indices into ``vgg19().features``, as strings
            (e.g. ``['0', '5']``) or ints; a dict keyed by such strings also
            works, since only its keys are iterated.
    """

    def __init__(self, layers):
        super().__init__()
        self.vgg = models.vgg19(weights='DEFAULT').features.to(device).eval()
        self.layers = layers

    def forward(self, x):
        """Return the activations at the requested layers, in network order.

        Evaluation stops once the deepest requested layer has run, so the
        remaining layers of the trunk are never computed.
        """
        wanted = {str(i) for i in self.layers}
        # Only purely numeric ids can ever match an enumerate() index.
        numeric = [int(i) for i in wanted if i.isdigit()]
        if not numeric:
            return []
        last = max(numeric)

        features = []
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if str(idx) in wanted:
                features.append(x)
            if idx >= last:
                break
        return features

# VGG19 `features` indices of the layers to extract
# (conv1_1=0, relu1_1=1, conv1_2=2, relu1_2=3, ..., conv5_1=28).
layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1', '28': 'conv5_1'}
# Derive the id list from the dict (insertion-ordered) so the two can't drift apart.
layer_ids = list(layers)

# Extractor over VGG19 `features` that collects the activations listed in `layer_ids`.
feature_extractor = VGGFeatures(layers=layer_ids)

In [3]:
# Load images
# Resize to 224x224 and normalize with the ImageNet mean/std used by
# torchvision's pre-trained VGG weights.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

STYLE_PATH = 'pictures/style/tar58.png'
STYLIZED_PATH = 'pictures/results/result58.png'


def _open_rgb(path):
    """Open an image file and return a fully loaded 3-channel RGB copy.

    PNGs are often RGBA, palette ('P') or grayscale. Both the 3-value Normalize
    above and VGG's first conv need exactly 3 channels, so force RGB.
    `convert` returns a loaded copy, so the file handle can be closed here.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


style_image = _open_rgb(STYLE_PATH)
stylized_image = _open_rgb(STYLIZED_PATH)

# Add a batch dimension and move to the compute device.
style = transform(style_image).unsqueeze(0).to(device)
stylized = transform(stylized_image).unsqueeze(0).to(device)

In [4]:
# This only measures a metric (no optimization), so disable autograd. Nothing
# above freezes the VGG weights, so without no_grad every activation would
# keep a backward graph alive.
with torch.no_grad():
    style_features = feature_extractor(style)
    stylized_features = feature_extractor(stylized)

    original_grams = [gram_matrix(feat) for feat in style_features]
    stylized_grams = [gram_matrix(feat) for feat in stylized_features]


In [5]:
mse_loss = nn.MSELoss()

# Style mismatch score: per-layer MSE between the Gram matrices of the style
# image and of the stylized result, summed over the extracted layers.
# Both lists come from the same extractor and must pair up one-to-one;
# zip() would otherwise silently drop the unmatched tail.
assert len(original_grams) == len(stylized_grams), "Gram lists differ in length"
layer_losses = [mse_loss(g_stylized, g_orig)
                for g_orig, g_stylized in zip(original_grams, stylized_grams)]
# A tensor start value keeps `.item()` valid even if no layers were extracted;
# sum() adds left to right, matching a running total.
gram_loss = sum(layer_losses, torch.tensor(0.0))

print(f"Gram Loss: {gram_loss.item():.6f}")


Gram Loss: 0.000184
