In [3]:
import os
# Work from the notebook's parent directory (presumably the project root, since
# later cells import `style_transfer` and read `images/...` relative to it).
# The starting directory is recorded only the first time this cell runs in a
# kernel, so re-running the cell does not climb one directory further each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))


In [4]:
import torch
import torch.nn as nn
from tsalib import dim_vars
from torchvision.models import vgg19
from torchvision import transforms

from PIL import Image

from style_transfer import losses

# Declare the named tensor dimensions used for shape annotations. `exists_ok=True`
# is passed so that re-running this cell is accepted when the names were already
# declared (per the flag's name — see the tsalib documentation for details).
B, C, H, W = dim_vars(
    "Batch(B) Channel(C) Height(H) Width(W)",
    exists_ok=True,
)

In [5]:
# Convolutional part of the pretrained VGG-19, used only as a fixed feature
# extractor:
#   * eval() puts the module in inference mode;
#   * requires_grad_(False) freezes the weights, so backward passes through the
#     network do not compute (or store) gradients for the VGG parameters.
#     Gradients with respect to the *input* still flow as usual.
feature_extractor: nn.Sequential = (
    vgg19(pretrained=True).features.eval().requires_grad_(False)
)

In [6]:
# List every (name, module) pair of the feature extractor, one tuple per line.
for child_name, child_module in feature_extractor.named_children():
    print((child_name, child_module))

('0', Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('1', ReLU(inplace=True))
('2', Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('3', ReLU(inplace=True))
('4', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
('5', Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('6', ReLU(inplace=True))
('7', Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('8', ReLU(inplace=True))
('9', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
('10', Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('11', ReLU(inplace=True))
('12', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('13', ReLU(inplace=True))
('14', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('15', ReLU(inplace=True))
('16', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('17', ReLU(inplace=True))
('18', MaxPool2d(kernel_si

In [7]:
# Per-channel statistics, in RGB order and for inputs scaled to [0, 1], that
# torchvision's ImageNet-pretrained models (including the vgg19 loaded above)
# expect. The previous values were the Caffe VGG means in BGR order with no
# std scaling, which do not match RGB tensors produced from PIL images.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# preprocessing transform: PIL image -> normalized float tensor (C, H, W)
preprocess = transforms.Compose([
    transforms.Resize(512),   # shorter edge -> 512 px, aspect ratio preserved
    transforms.ToTensor(),    # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# postprocessing transform: exact inverse of the normalization above.
# Normalize computes (x - mean) / std, so undoing it (x * s + m) is expressed
# as a Normalize with mean = -m / s and std = 1 / s.
postprocess = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    ),
    # values pushed outside the displayable range are clipped before conversion
    transforms.Lambda(lambda x: torch.clamp(x, min=0, max=1)),
    transforms.ToPILImage(),
])

In [31]:
# Layers after which a loss module is attached. Names follow the scheme used in
# the loop below: conv_N / relu_N / pool_N / batchnorm_N, where N is the number
# of convolutions seen so far.
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

# Load both images as 3-channel RGB (grayscale / RGBA files would not match the
# 3-channel normalization) and add a leading batch dimension. The context
# managers close the underlying files once the pixel data has been read.
with Image.open("images/content.jpg") as content_file:
    content_image = preprocess(content_file.convert("RGB")).unsqueeze(0)
with Image.open("images/style.jpg") as style_file:
    style_image = preprocess(style_file.convert("RGB")).unsqueeze(0)

# `model` is the network that will be optimized against: VGG layers interleaved
# with loss modules. `trunk` holds the same layer objects without any loss
# module and is used only to compute the target feature maps, so the style
# image is never pushed through a content-loss module (and vice versa).
# NOTE(review): this assumes ContentLoss / StyleLoss return their input
# unchanged, so that trunk(x) equals model(x) — confirm in style_transfer.losses.
model: nn.Sequential = nn.Sequential()
trunk: nn.Sequential = nn.Sequential()
conv_count: int = 0

# Requested layers that have not been attached yet; the copy stops as soon as
# this is empty, independent of how many names each list contains.
pending_layers = set(content_layers) | set(style_layers)

for layer in feature_extractor:
    if not pending_layers:
        break

    if isinstance(layer, nn.Conv2d):
        conv_count += 1
        name = f"conv_{conv_count}"
    elif isinstance(layer, nn.ReLU):
        name = f"relu_{conv_count}"
        # VGG's ReLUs are in-place; an in-place ReLU would overwrite the
        # activation tensor that the preceding loss module was given.
        layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
        name = f"pool_{conv_count}"
    elif isinstance(layer, nn.BatchNorm2d):
        name = f"batchnorm_{conv_count}"
    else:
        raise RuntimeError(f"Unknown layer: {layer}")

    model.add_module(name, layer)
    trunk.add_module(name, layer)

    if name in content_layers:
        # feature map of the CONTENT image up to this layer is the content target
        with torch.no_grad():
            content_features = trunk(content_image)
        model.add_module(f"content_loss_{conv_count}", losses.ContentLoss(content_features))

    if name in style_layers:
        # feature map of the STYLE image up to this layer is the style target
        with torch.no_grad():
            style_features = trunk(style_image)
        model.add_module(f"style_loss_{conv_count}", losses.StyleLoss(style_features))

    pending_layers.discard(name)

if pending_layers:
    # a requested name never occurred in the feature extractor
    raise ValueError(f"Requested layers not found in feature extractor: {sorted(pending_layers)}")

    

In [32]:
# Show the assembled network: one (name, module) tuple per line.
for child_name, child_module in model.named_children():
    print((child_name, child_module))

('conv_1', Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('style_loss_1', StyleLoss())
('relu_1', ReLU())
('conv_2', Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('style_loss_2', StyleLoss())
('relu_2', ReLU())
('pool_2', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
('conv_3', Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('style_loss_3', StyleLoss())
('relu_3', ReLU())
('conv_4', Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('content_loss_4', ContentLoss())
('style_loss_4', StyleLoss())
('relu_4', ReLU())
('pool_4', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
('conv_5', Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
('style_loss_5', StyleLoss())
