# Neural Style Transfer

Reference:
1. L. A. Gatys, A. S. Ecker and M. Bethge, "Image Style Transfer Using Convolutional Neural Networks," 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, NV, USA, 2016, pp. 2414-2423, doi: 10.1109/CVPR.2016.265. 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import numpy as np
from PIL import Image

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Data

In [None]:
from matplotlib import pyplot as plt

# Load and preprocess images
def load_image(image_path, size=512):
    """Read an image file and return it as a normalized, batched tensor.

    The file is opened with PIL and forced to RGB, then passed through
    ``Resize(size)``, ``ToTensor()`` and a per-channel ``Normalize`` using
    the mean/std constants below.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.
        size: Forwarded unchanged to ``transforms.Resize``.

    Returns:
        The transformed image with an extra leading batch dimension of 1.
    """
    rgb_image = Image.open(image_path).convert('RGB')

    pipeline = (
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )

    # Apply the steps one after another, exactly like transforms.Compose would.
    result = rgb_image
    for step in pipeline:
        result = step(result)

    # Add batch dimension
    return result.unsqueeze(0)

# De-normalize and convert tensor to image
def tensor_to_image(tensor):
    """Convert a normalized image tensor back into a PIL image.

    Undoes the ``Normalize(mean, std)`` step applied by ``load_image`` using
    the exact inverse ``x * std + mean`` (the same formula ``imshow`` uses)
    rather than rounded inverse constants, clamps the result to ``[0, 1]``
    and hands it to ``transforms.ToPILImage``.

    Args:
        tensor: Image tensor normalized with the mean/std below. A leading
            dimension of size 1 is squeezed away. The tensor may live on any
            device and may require grad; it is detached and moved to the CPU
            here, and the caller's tensor is not modified.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # detach(): optimised images still carry autograd state.
    image = tensor.detach().cpu().squeeze(0)
    # Out-of-place arithmetic, so the input tensor is left untouched.
    image = (image * std + mean).clamp(0, 1)
    return transforms.ToPILImage()(image)


def imshow(tensor, title=None):
    """Display a normalized, batched image tensor with matplotlib.

    Args:
        tensor: Image tensor with a leading batch dimension of size 1,
            normalized with the mean/std used below. It may require grad and
            may live on any device.
        title: Optional title drawn above the image.
    """
    # detach() before numpy(): tensors being optimised require grad.
    image = tensor.detach().cpu().numpy().squeeze(0)
    # Channels-first -> channels-last, which is what plt.imshow expects.
    image = image.transpose(1, 2, 0)
    # Undo the per-channel normalization: x * std + mean.
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Intermediate optimisation results routinely leave [0, 1]; clip here so
    # matplotlib does not emit a clipping warning on every call.
    image = np.clip(image, 0.0, 1.0)
    plt.imshow(image)
    plt.axis('off')
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
# Load images
# Both files go through load_image (resize, tensor conversion, normalization)
# and are then moved to the selected device.
style_img = load_image("datasets/style/Water Lily.jpeg").to(device)
content_img = load_image("datasets/photo/Mountain_1.jpg").to(device)
# Working copy of the content image; it is created afresh before the
# style-transfer run further down.
input_img = content_img.clone().to(device)
print("Content Image Shape:", content_img.shape)
print("Style Image Shape:", style_img.shape)


# Preview both source images.
imshow(style_img)
imshow(content_img)


## Model

In [None]:
# Load pre-trained VGG model
# The network is only used as a fixed feature extractor: put it in eval mode
# and freeze its parameters so that backward passes compute gradients for the
# image being optimised only, not for every network weight.
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
# `weights=` argument; it is kept so the notebook still runs on older versions.
cnn = models.vgg19(pretrained=True).to(device).eval().requires_grad_(False)

cnn

## Criterion

In [None]:
class ContentLoss(nn.Module):
    """Mean-squared error between a feature map and a fixed target.

    Args:
        target: Reference feature map. It is detached on construction, so it
            is treated as a constant and receives no gradient.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, x):
        """Return ``mse_loss(x, target)``."""
        return nn.functional.mse_loss(x, self.target)
    

class StyleLoss(nn.Module):
    """Weighted sum of Gram-matrix losses over several feature maps.

    The Gram matrix of every target feature map is computed once at
    construction time and stored detached, in the iteration order of
    ``target``.

    Args:
        target (dict): layer_name -> 4-D feature map ``(b, c, h, w)`` of the
            style image.
        weights: Optional per-layer weights, in the same order as ``target``.
            Defaults to ``torch.ones(len(target))``.
        loss: Callable ``(input_gram, target_gram) -> scalar``. Defaults to a
            fresh ``nn.MSELoss()`` per instance.

    Raises:
        ValueError: If ``weights`` does not have one entry per target layer.
    """

    def __init__(self, target, weights=None, loss=None):
        super().__init__()

        if weights is None:
            weights = torch.ones(len(target))
        elif len(weights) != len(target):
            raise ValueError(
                f"expected {len(target)} weights (one per target layer), got {len(weights)}"
            )
        self.weights = weights

        # Built here rather than in the signature so that instances do not
        # share one default module object.
        self.loss = nn.MSELoss() if loss is None else loss

        # Remember the layer order so forward() can also take a dict.
        self.layer_names = list(target)
        self.target = [self.gram_matrix(feature).detach() for feature in target.values()]

    def forward(self, x):
        """Return the weighted style loss for the given feature maps.

        Args:
            x: Either a sequence of feature maps in the same order as the
                targets, or a dict ``layer_name -> feature map`` containing
                every target layer (extra keys are ignored).

        Returns:
            ``sum(weight_i * loss(gram(x_i), target_gram_i))``; the integer
            ``0`` when there are no target layers.

        Raises:
            ValueError: If the number of feature maps differs from the number
                of targets (``zip`` would otherwise drop layers silently).
        """
        if isinstance(x, dict):
            features = [x[name] for name in self.layer_names]
        else:
            features = list(x)

        if len(features) != len(self.target):
            raise ValueError(
                f"expected {len(self.target)} feature maps, got {len(features)}"
            )

        loss = 0
        for feature, weight, target in zip(features, self.weights, self.target):
            loss += weight * self.loss(self.gram_matrix(feature), target)
        return loss

    def gram_matrix(self, x):
        """Return the batched Gram matrix of ``x`` divided by ``c * h * w``.

        Args:
            x: 4-D tensor ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)``.
        """
        b, c, h, w = x.size()
        # reshape() instead of view(): also works for non-contiguous inputs.
        features = x.reshape(b, c, h * w)
        G = torch.bmm(features, features.transpose(1, 2))
        return G.div(c * h * w)
    

def gram_matrix(x):
    """Return the batched Gram matrix of a 4-D feature map.

    ``x`` of size ``(b, c, h, w)`` is flattened over its two trailing
    dimensions, multiplied with its own transpose per batch entry and the
    result is divided by ``c * h * w``.
    """
    batch, channels, height, width = x.size()
    flat = x.view(batch, channels, height * width)
    scale = channels * height * width
    return torch.bmm(flat, flat.transpose(1, 2)).div(scale)

In [None]:
# Load VGG model and extract required layers
def get_features(cnn, image, layers: dict = None):
    """Collect intermediate outputs of ``cnn.features`` for an image.

    The image is pushed through every child of ``cnn.features`` in order;
    whenever a child's id is a key of ``layers`` the current activation is
    stored under the mapped name.

    NOTE(review): the stored object is the tensor itself, not a copy. If the
    following layer works in place (e.g. an in-place ReLU), the stored value
    will reflect that — confirm against the model definition.

    Args:
        cnn (nn.Module): a convolutional neural network exposing ``features``
        image (torch.Tensor): input image
        layers (dict): representation layers, layer_id (str) -> layer_name;
            when None, a built-in selection of six layers is used

    Returns:
        dict: layer_name -> activation, in the order the layers are visited
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',
            '28': 'conv5_1',
        }

    collected = {}
    activation = image

    # layer_idx: str - '0', '1', ...
    # module: e.g. Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    for layer_idx, module in cnn.features._modules.items():
        activation = module(activation)
        if layer_idx in layers:
            collected[layers[layer_idx]] = activation

    return collected

In [None]:
# layer_id (position inside cnn.features) -> layer_name, for the style loss
style_layers = {'0': 'conv1_1',
                '5': 'conv2_1',
                '10': 'conv3_1',
                '19': 'conv4_1',
                '28': 'conv5_1'}

# layer_id -> layer_name, for the content loss
content_layers = {'21': 'conv4_2',}


# The targets are constants, so there is no need to record an autograd graph
# while extracting them (both criteria detach their targets anyway).
with torch.no_grad():
    style_feature = get_features(cnn, style_img, style_layers)
    content_feature = get_features(cnn, content_img, content_layers)

s_criterion = StyleLoss(style_feature)
c_criterion = ContentLoss(content_feature['conv4_2'])

## Style Transfer

In [None]:
def run(cnn, c_criterion, s_criterion, input_img, num_steps=500, loss_weights=(1, 1e4), verbose=10):
    """Optimise ``input_img`` so that it minimises content + style loss.

    Uses the module-level ``style_layers`` and ``content_layers`` mappings to
    decide which activations are fed to the two criteria.

    Args:
        cnn: Network passed on to ``get_features``.
        c_criterion: Callable taking the content-layer activation.
        s_criterion: Callable taking a list of style-layer activations. The
            list is built in the order ``get_features`` returns them, which is
            the order targets have when they too were built from a
            ``get_features`` result.
        input_img: Image tensor to optimise. It is switched to
            ``requires_grad`` and updated in place.
        num_steps: Number of Adam steps.
        loss_weights: ``(content_weight, style_weight)``.
        verbose: Print the losses and show the image every ``verbose`` steps;
            ``0`` or ``None`` disables the reporting.

    Returns:
        ``input_img`` itself (the same tensor object, still requiring grad).
    """
    optimizer = optim.Adam([input_img.requires_grad_()], lr=0.1)

    content_weight, style_weight = loss_weights[0], loss_weights[1]
    wanted_layers = style_layers | content_layers
    content_name = next(iter(content_layers.values()))
    style_names = set(style_layers.values())

    for step in range(num_steps):
        optimizer.zero_grad()

        features = get_features(cnn, input_img, wanted_layers)

        c_loss = c_criterion(features[content_name])
        # Take the style activations in the order get_features produced them
        # (network order) rather than in style_layers' dict order, so they
        # line up with targets that were extracted the same way.
        s_loss = s_criterion([feat for name, feat in features.items() if name in style_names])

        loss = content_weight * c_loss + style_weight * s_loss

        loss.backward()

        optimizer.step()

        # No clamp to [0, 1] here: the image is stored in normalized form.

        # `verbose and ...` avoids a ZeroDivisionError when verbose is 0.
        if verbose and step % verbose == 0:
            print(
                f"Epoch {step}, style: {float(s_loss * style_weight):.4f}, "
                f"content: {float(c_loss * content_weight):.4f}, loss: {float(loss):.4f}"
            )
            imshow(input_img)

    return input_img

## Test

In [None]:
# Run style transfer
# Start from a copy of the content image (the commented line below would
# start from Gaussian noise instead).
# input_img = torch.randn_like(content_img).to(device)
input_img = content_img.clone().to(device)

# 10000 optimisation steps with run()'s default loss weights and verbosity.
output = run(cnn, c_criterion, s_criterion, input_img, 10000)

# Save and show result
imshow(output)
# tensor_to_image undoes the normalization and returns a PIL image.
output_image = tensor_to_image(output.cpu())
output_image.save("output.jpg")
output_image.show()

## Reconstruction of Style and Content

### Content

In [None]:
def reconstruct_content(cnn, content_img, layer_id, num_steps=1000):
    """Reconstruct an image from the activation of one layer of ``cnn.features``.

    A sub-network made of the children of ``cnn.features`` up to and including
    ``layer_id`` is built, the activation of ``content_img`` at its output is
    taken as target, and a random image is optimised with Adam so that its
    activation matches that target (MSE).

    Args:
        cnn: Network exposing a ``features`` container.
        content_img: Batched, normalized image whose content is reconstructed.
        layer_id: Id of the last layer to include (key of
            ``cnn.features._modules``, e.g. ``'21'``); ints are accepted too.
        num_steps: Number of optimisation steps.

    Returns:
        The optimised image tensor (a leaf tensor that requires grad).

    Raises:
        ValueError: If ``layer_id`` is not a layer of ``cnn.features``.
    """
    layer_id = str(layer_id)
    if layer_id not in cnn.features._modules:
        raise ValueError(f"unknown layer id {layer_id!r} for cnn.features")

    # Truncate the feature stack right after the requested layer. The loop
    # variable is compared here; comparing the builtin `id` would never match
    # and the whole stack would be used.
    model = nn.Sequential()
    for name, layer in cnn.features._modules.items():
        model.add_module(name, layer)
        if name == layer_id:
            break
    target = model(content_img).detach()

    # Initialize random image for reconstruction: uniform noise in [0, 1),
    # normalized with the same mean/std used for the input images.
    input_img = torch.rand_like(content_img)
    input_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_img).requires_grad_(True)
    optimizer = optim.Adam([input_img], lr=0.2)

    # Optimize
    print("Reconstructing content...")
    for step in range(num_steps):
        # Keep the image inside a bounded range before every step.
        input_img.data.clamp_(-3, 3)
        optimizer.zero_grad()
        content_features = model(input_img)
        loss = nn.functional.mse_loss(content_features, target)
        loss.backward()
        if step % 50 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")
            imshow(input_img)
        optimizer.step()

    return input_img
            

# Content Reconstruction
# '21' is the layer id that content_layers maps to 'conv4_2'.
content_reconstruction = reconstruct_content(cnn, content_img, '21')
# Convert the optimised tensor to a PIL image and write it to disk.
tensor_to_image(content_reconstruction.cpu()).save("content_reconstruction.jpg")

imshow(content_reconstruction)


### Style

In [None]:
def reconstruct_style(cnn, style_img, layer_id, num_steps=1000):
    """Reconstruct an image from the Gram matrix of one layer of ``cnn.features``.

    A sub-network made of the children of ``cnn.features`` up to and including
    ``layer_id`` is built, the Gram matrix of ``style_img``'s activation at its
    output is taken as target, and a random image is optimised with Adam so
    that its Gram matrix matches that target (MSE scaled by 1e6).

    Args:
        cnn: Network exposing a ``features`` container.
        style_img: Batched, normalized image whose style is reconstructed.
        layer_id: Id of the last layer to include (key of
            ``cnn.features._modules``, e.g. ``'5'``); ints are accepted too.
        num_steps: Number of optimisation steps.

    Returns:
        The optimised image tensor (a leaf tensor that requires grad).

    Raises:
        ValueError: If ``layer_id`` is not a layer of ``cnn.features``.
    """
    layer_id = str(layer_id)
    if layer_id not in cnn.features._modules:
        raise ValueError(f"unknown layer id {layer_id!r} for cnn.features")

    # Truncate the feature stack right after the requested layer. The loop
    # variable is compared here; comparing the builtin `id` would never match
    # and every layer_id would produce the same full-depth model.
    model = nn.Sequential()
    for name, layer in cnn.features._modules.items():
        model.add_module(name, layer)
        if name == layer_id:
            break
    target = gram_matrix(model(style_img)).detach()

    # Initialize random image for reconstruction: uniform noise in [0, 1),
    # normalized with the same mean/std used for the input images.
    input_img = torch.rand_like(style_img)
    input_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_img).requires_grad_(True)
    optimizer = optim.Adam([input_img], lr=0.2)

    # Optimize
    print("Reconstructing style...")
    for step in range(num_steps):
        # Keep the image inside a bounded range before every step.
        input_img.data.clamp_(-3, 3)
        optimizer.zero_grad()
        style_features = gram_matrix(model(input_img))
        # The 1e6 factor only rescales the loss (and therefore its gradient).
        loss = nn.functional.mse_loss(style_features, target) * 1e6
        loss.backward()
        if step % 50 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")
            imshow(input_img)
        optimizer.step()

    return input_img
            

# Style Reconstruction
# One reconstruction per style layer; iterating the dict yields its layer ids.
for layer_id in style_layers:
    style_reconstruction = reconstruct_style(cnn, style_img, layer_id)
    # Each result is saved under a file name that carries the layer id.
    tensor_to_image(style_reconstruction.cpu()).save(f"style_reconstruction_{layer_id}.jpg")

    imshow(style_reconstruction)