# Neural Style Transfer

## 1. Imports

In [None]:
import sys
import os

# Add the parent directory to path so that we can import the src module
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import vgg19, VGG19_Weights

from src.config import CFG
from src.style_transfer import StyleTransfer
from src.model import VGGFeatures
from src.utils import view_image, load_image, save_image, get_white_noise_image

## 2. Configuration

In [None]:
print(f"Using device: {CFG.device}")

# Notebook-specific overrides applied onto the shared CFG object, in order.
# Image paths are relative to the current working directory.
for _cfg_attr, _cfg_value in {
    'content_image_path': '../data/content.jpg',
    'style_image_path': '../data/style.jpg',
    # Optimisation schedule settings.
    'num_iterations': 1000,
    'print_interval': 100,
}.items():
    setattr(CFG, _cfg_attr, _cfg_value)

## 3. Load and View Images

In [None]:
# Load both input images through the project loader, passing CFG along.
content_img = load_image(CFG.content_image_path, CFG)
style_img = load_image(CFG.style_image_path, CFG)

# Show the two inputs side by side as a quick sanity check.
plt.figure(figsize=(10, 5))
for _panel, (_img, _caption) in enumerate(
    [(content_img, "Content Image"), (style_img, "Style Image")], start=1
):
    plt.subplot(1, 2, _panel)
    view_image(_img, title=_caption, cfg=CFG)
plt.tight_layout()
plt.show()

## 4. Initialize Style Transfer Model

In [None]:
# Build the high-level pipeline from the configured content/style paths.
_style_transfer_kwargs = dict(
    content_path=CFG.content_image_path,
    style_path=CFG.style_image_path,
    cfg=CFG,
)
style_transfer = StyleTransfer(**_style_transfer_kwargs)

## 5. Run Style Transfer and Display Results

In [None]:
# Run the optimisation for the iteration count configured in the Configuration
# section (previously hard-coded here, which silently ignored CFG.num_iterations).
output_img = style_transfer.run(iterations=CFG.num_iterations)

In [None]:
# Show the stylized result using the pipeline's own display helper
# (layout is defined by StyleTransfer.display_results, not visible here).
style_transfer.display_results(output_img)

## 6. Save Results

In [None]:
# Write the stylized image to disk. The target directory is created first so
# the save does not depend on '../outputs' already existing.
output_path = '../outputs/stylized_output.jpg'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
style_transfer.save_image(output_img, output_path)

## 7. Experiment with Different Parameters

In [None]:
# Run a second style transfer with a few experimental settings.
# The overrides are applied to the shared CFG only for the duration of this
# cell and restored afterwards, so re-running earlier cells (or later cells
# that read CFG) is not silently affected by the experimental values.
# NOTE(review): if StyleTransfer reads CFG lazily, calling methods on
# `style_transfer_experiment` after this cell will see the restored settings.
_experiment_overrides = {
    'style_weight': 1e7,
    'optimizer': 'lbfgs',
}
_MISSING = object()  # sentinel: attribute did not exist before the override
_saved_cfg = {
    _attr: getattr(CFG, _attr, _MISSING) for _attr in _experiment_overrides
}
for _attr, _value in _experiment_overrides.items():
    setattr(CFG, _attr, _value)

try:
    style_transfer_experiment = StyleTransfer(
        content_path=CFG.content_image_path,
        style_path=CFG.style_image_path,
        cfg=CFG
    )

    output_img_experiment = style_transfer_experiment.run(iterations=500)

    style_transfer_experiment.display_results(output_img_experiment)

    experiment_output_path = '../outputs/experimental_output.jpg'
    # Ensure the output directory exists before saving.
    os.makedirs(os.path.dirname(experiment_output_path), exist_ok=True)
    style_transfer_experiment.save_image(output_img_experiment, experiment_output_path)
finally:
    # Restore CFG exactly as it was, removing attributes the experiment added.
    for _attr, _value in _saved_cfg.items():
        if _value is _MISSING:
            delattr(CFG, _attr)
        else:
            setattr(CFG, _attr, _value)

## 8. Direct Implementation (Optimisation Loop Without the `StyleTransfer` Class)

In [None]:
# Implement style transfer directly in this notebook without using the StyleTransfer class.
# Names appear to follow Gatys et al.: `a` is the style image, `p` the content
# image and `x` the image being optimised (initialised from white noise).
a = load_image(CFG.style_image_path, CFG)
p = load_image(CFG.content_image_path, CFG)
x = get_white_noise_image(CFG)
x.requires_grad_(True)

vgg = VGGFeatures(CFG.device)

content_layer = 'conv4_2'
style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
weights = [0.2, 0.2, 0.2, 0.2, 0.2]  # per-layer style weights (sum to 1)

alpha = 1.0  # content-loss weight
beta = 1e6   # style-loss weight

# Only `x` is optimised; `a` and `p` never change, so their targets are computed
# once here instead of re-running VGG on them in every iteration. `.clone()` is
# kept (as in the per-iteration version) in case compute_feat_map returns a
# tensor that is reused internally.
with torch.no_grad():
    style_grams = {
        layer: vgg.compute_gram_matrix(vgg.compute_feat_map(a, layer).clone())
        for layer in style_layers
    }
    content_target = vgg.compute_feat_map(p, content_layer).clone()

optimizer = optim.Adam([x], lr=1e-2)

# Use the schedule configured in the Configuration section.
num_iterations = CFG.num_iterations
print_interval = CFG.print_interval
for step in range(num_iterations):
    style_loss = 0
    for layer_weight, layer in zip(weights, style_layers):
        input_feat = vgg.compute_feat_map(x, layer)
        input_gram = vgg.compute_gram_matrix(input_feat)

        # NOTE(review): Gatys et al. normalise by 4 * N_l^2 * M_l^2. Here the
        # product of the first two feature dimensions is used; whether that is
        # equivalent depends on the shape returned by compute_feat_map and on
        # whether compute_gram_matrix already normalises — TODO confirm.
        size = input_feat.size(0) * input_feat.size(1)
        style_loss += layer_weight * (
            F.mse_loss(input_gram, style_grams[layer], reduction='sum') / (4 * size)
        )

    input_content_feat = vgg.compute_feat_map(x, content_layer)
    content_loss = 0.5 * F.mse_loss(input_content_feat, content_target, reduction='sum')

    loss = alpha * content_loss + beta * style_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Keep pixel values in [0, 1]; in-place edit of the optimised leaf tensor
    # must not be recorded by autograd.
    with torch.no_grad():
        x.clamp_(0, 1)

    if print_interval and (step + 1) % print_interval == 0:
        print(f'Iteration: {step + 1} | Loss: {loss.item():.4f}')

plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
# Detach so plotting never has to deal with a tensor that requires grad.
view_image(x.detach(), title="Style-Transferred Image", cfg=CFG)
plt.subplot(1, 3, 2)
view_image(p, title="Content Image", cfg=CFG)
plt.subplot(1, 3, 3)
view_image(a, title="Style Image", cfg=CFG)
plt.tight_layout()
plt.show()