# Neural Style Transfer Demonstration

This notebook demonstrates the neural style transfer technique implemented in our project. We'll showcase:
1. Basic style transfer using the VGG19 CNN model
2. The effect of different style thresholds
3. Comparison with Vision Transformer (ViT) implementation
4. Analysis of loss functions during optimization

Let's begin by importing the necessary modules and setting up our environment.

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Make the project root importable. When Jupyter is launched from the notebook's
# folder, '..' is the repository root that holds config/, utils/, models/, losses/.
# Insert at the front of sys.path instead of appending: those package names are
# generic, and an installed third-party package with the same name would
# otherwise shadow the project's own modules.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.insert(0, module_path)

# Import project modules
from config.config import *
from utils.image_utils import image_loader, save_image, display_images, get_file_paths
from utils.optimizer import run_optimization
from models.model_factory import get_model
from losses.content_loss import ContentLoss
from losses.style_loss import StyleLoss
from losses.tv_loss import TotalVariationLoss

# Pick the compute device once for the whole notebook: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 1. Loading the Images

Let's load a content image and a style image to use for our demonstration.

In [None]:
# Discover the content and style images available to the project
content_paths, style_paths = get_file_paths(CONTENT_DIR, STYLE_DIR)

# List every candidate together with the index used to select it in the next cell
for heading, candidates in (("Available content images:", content_paths),
                            ("\nAvailable style images:", style_paths)):
    print(heading)
    for idx, candidate in enumerate(candidates):
        print(f"[{idx}] {os.path.basename(candidate)}")

In [None]:
# Select indices of content and style images
content_idx = 0  # Change this to select a different content image
style_idx = 0    # Change this to select a different style image

# Validate the selection up front. An empty image folder or a mistyped index
# would otherwise surface as a bare IndexError with no hint about the cause.
# Negative indices keep their usual Python meaning (counting from the end).
if not -len(content_paths) <= content_idx < len(content_paths):
    raise IndexError(
        f"content_idx={content_idx} is out of range: found "
        f"{len(content_paths)} content image(s) in {CONTENT_DIR}"
    )
if not -len(style_paths) <= style_idx < len(style_paths):
    raise IndexError(
        f"style_idx={style_idx} is out of range: found "
        f"{len(style_paths)} style image(s) in {STYLE_DIR}"
    )

# Resolve the selected file paths
content_path = content_paths[content_idx]
style_path = style_paths[style_idx]

print(f"Selected content image: {os.path.basename(content_path)}")
print(f"Selected style image: {os.path.basename(style_path)}")

# Load images
content_img = image_loader(content_path, IMAGE_SIZE, device)
style_img = image_loader(style_path, IMAGE_SIZE, device)

# Display the original images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# Convert tensors to numpy arrays for display
def tensor_to_np_img(tensor):
    """Convert an image tensor to an ``(H, W, C)`` numpy array clipped to ``[0, 1]``.

    Args:
        tensor: Either a batched ``(1, C, H, W)`` tensor or an unbatched
            ``(C, H, W)`` tensor. It may live on any device and may require grad.

    Returns:
        A new ``(H, W, C)`` numpy array, suitable for ``plt.imshow``, that never
        shares memory with ``tensor``.

    Raises:
        ValueError: If a 4-D tensor has a batch size other than 1.
    """
    # Detach before moving to the CPU so no autograd op is recorded for the copy.
    img = tensor.detach().cpu().numpy()
    if img.ndim == 4:
        img = img.squeeze(0)  # raises ValueError unless the batch size is 1
    img = img.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # np.clip allocates a fresh array, so the caller cannot mutate the tensor through it.
    return np.clip(img, 0, 1)

# Fill the two panels created above with the loaded content and style images.
for ax, img, title in ((ax1, content_img, 'Content Image'),
                       (ax2, style_img, 'Style Image')):
    ax.imshow(tensor_to_np_img(img))
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
# End with plt.show(), like every other plotting cell, so the figure is rendered
# here regardless of the active matplotlib backend.
plt.show()

## 2. Basic Style Transfer with VGG19

Now, let's perform style transfer using the VGG19-based model with default parameters.

In [None]:
def perform_style_transfer(content_img, style_img, model_type='vgg', style_threshold=0.7,
                          num_steps=300, style_weight=1e6, content_weight=1, tv_weight=1,
                          content_layers=None, style_layers=None):
    """Run neural style transfer and return the stylized image with its loss history.

    Args:
        content_img: Content image tensor as produced by ``image_loader``.
        style_img: Style image tensor as produced by ``image_loader``.
        model_type: Backbone name passed to ``get_model`` (this notebook uses
            ``'vgg'`` and ``'vit'``).
        style_threshold: Forwarded to every ``StyleLoss`` and to ``run_optimization``.
        num_steps: Number of optimization steps, forwarded to ``run_optimization``.
        style_weight: Style-loss weight, forwarded to ``run_optimization``.
        content_weight: Content-loss weight, forwarded to ``run_optimization``.
        tv_weight: Total-variation weight, given to ``TotalVariationLoss`` and
            forwarded to ``run_optimization``.
        content_layers: Layers used for the content loss; defaults to ``CONTENT_LAYERS``.
        style_layers: Layers used for the style loss; defaults to ``STYLE_LAYERS``.

    Returns:
        The ``(output_img, loss_history)`` tuple returned by ``run_optimization``.
    """
    # Use default layers if not specified
    if content_layers is None:
        content_layers = CONTENT_LAYERS
    if style_layers is None:
        style_layers = STYLE_LAYERS

    # Initialize model
    model = get_model(model_type, content_layers, style_layers, device)

    # Optimization starts from a copy of the content image (not from white noise);
    # this copy is the tensor whose pixels are optimized.
    input_img = content_img.clone().requires_grad_(True)

    # Target features are fixed constants. Computing them without autograd avoids
    # building computation graphs that would never be backpropagated.
    with torch.no_grad():
        content_features, _ = model(content_img)
        _, style_features = model(style_img)

    # Create loss modules
    content_losses = [ContentLoss(f.detach()).to(device) for f in content_features]

    style_losses = []
    for i, f in enumerate(style_features):
        # Quadratic per-layer weight (1, 4, 9, ...) by position in style_features.
        # NOTE(review): presumably StyleLoss scales that layer's loss by it — confirm.
        layer_weight = (i + 1)**2
        style_loss = StyleLoss(f.detach(), layer_weight, style_threshold).to(device)
        style_losses.append(style_loss)

    # Add total variation loss.
    # NOTE(review): tv_weight is passed both here and to run_optimization below
    # (likewise style_threshold); confirm it is not applied twice.
    tv_loss = TotalVariationLoss(weight=tv_weight).to(device)

    # Run optimization
    output_img, loss_history = run_optimization(
        model, input_img, content_losses, style_losses,
        tv_loss, num_steps, style_weight, content_weight, tv_weight, style_threshold
    )

    return output_img, loss_history

In [None]:
# Perform style transfer with VGG19
print("Running style transfer with VGG19...")
output_img_vgg, loss_history_vgg = perform_style_transfer(
    content_img, style_img,
    model_type='vgg',
    style_threshold=0.7,
    num_steps=300,  # Using fewer steps for demonstration
    style_weight=1e6,
    content_weight=1,
    tv_weight=1
)

# Display the resulting image alongside the originals.
# plt.suptitle titles the whole figure; plt.title would only retitle the current
# Axes (one of the subplots drawn by display_images).
fig = display_images(content_img, style_img, output_img_vgg)
plt.suptitle("VGG19-based Style Transfer")
plt.show()

## 3. Analyzing the Effect of Style Threshold

Let's explore how different style thresholds affect the output. The style threshold controls how much of the style is applied to the content image.

In [None]:
# Test different style thresholds
thresholds = [0.2, 0.4, 0.6, 0.8, 1.0]
outputs = []

for threshold in thresholds:
    print(f"Running style transfer with threshold = {threshold}...")
    output_img, _ = perform_style_transfer(
        content_img, style_img,
        style_threshold=threshold,
        num_steps=200  # Fewer steps for faster demonstration
    )
    outputs.append(output_img)

# Show the two reference images followed by one panel per threshold. The grid is
# sized from the number of panels, so a result can never overwrite the content or
# style panel, and any leftover cells are hidden.
panels = [(content_img, 'Content Image'), (style_img, 'Style Image')]
panels += [(output, f'Threshold = {t}') for t, output in zip(thresholds, outputs)]

n_cols = 4
n_rows = (len(panels) + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

for ax in axes.flat:
    ax.axis('off')  # hides unused trailing cells as well
for ax, (img, title) in zip(axes.flat, panels):
    ax.imshow(tensor_to_np_img(img))
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 4. Using Vision Transformer (ViT) [Bonus]

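Let's try style transfer using a Vision Transformer model instead of the VGG19 CNN. This section requires the optional `transformers` library; if it is not installed, the cell prints installation instructions and skips the comparison.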

In [None]:
# The ViT path depends on the optional `transformers` package. Only the import is
# guarded: an ImportError raised later (e.g. from inside perform_style_transfer)
# is a genuine bug and must not be misreported as "transformers is not installed".
try:
    import transformers  # noqa: F401  (imported only to check availability)
    transformers_available = True
except ImportError:
    transformers_available = False

if transformers_available:
    print("Running style transfer with Vision Transformer...")
    output_img_vit, loss_history_vit = perform_style_transfer(
        content_img, style_img,
        model_type='vit',
        style_threshold=0.7,
        num_steps=300,  # Using fewer steps for demonstration
        style_weight=1e5,  # Different weights may work better for ViT
        content_weight=10,
        tv_weight=1
    )

    # Compare VGG and ViT results in a 2x2 grid
    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    comparison_panels = [
        (axes[0, 0], content_img, 'Content Image'),
        (axes[0, 1], style_img, 'Style Image'),
        (axes[1, 0], output_img_vgg, 'VGG19 Result'),
        (axes[1, 1], output_img_vit, 'Vision Transformer Result'),
    ]
    for ax, img, title in comparison_panels:
        ax.imshow(tensor_to_np_img(img))
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("The 'transformers' library is not installed. Install it with:")
    print("pip install transformers")

## 5. Loss Analysis

Let's analyze the loss values during the optimization process to understand the convergence behavior.

In [None]:
# Plot each recorded loss series from the VGG run against the optimization step.
# The x-axis length is taken from the 'total' series.
loss_series = [
    ('content', 'Content Loss'),
    ('style', 'Style Loss'),
    ('tv', 'Total Variation Loss'),
    ('total', 'Total Loss'),
]

plt.figure(figsize=(10, 6))
steps = range(1, len(loss_history_vgg['total']) + 1)
for key, label in loss_series:
    plt.plot(steps, loss_history_vgg[key], label=label)

plt.xlabel('Optimization Steps')
plt.ylabel('Loss Value')
plt.title('Loss History during Style Transfer Optimization')
plt.legend()
# Log scale so small late-stage changes remain visible.
# NOTE(review): non-positive values cannot be drawn on a log axis; if a series can
# be exactly 0 (e.g. content loss while the input still equals the content image),
# those points will be clipped — confirm what run_optimization records.
plt.yscale('log')
plt.grid(True, which="both", ls="--", c='.75')
plt.show()

## 6. Saving the Results

Let's save the VGG19 result and a side-by-side comparison figure (content, style, output) to the results directory.

In [None]:
# Derive output file names from the input images. os.path.splitext strips only the
# final extension, so "my.photo.jpg" keeps the stem "my.photo" (split('.')[0]
# would have produced "my" and could make different inputs collide).
content_name = os.path.splitext(os.path.basename(content_path))[0]
style_name = os.path.splitext(os.path.basename(style_path))[0]

# Make sure the results directory exists, e.g. on a fresh checkout.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Save the VGG result
output_path = os.path.join(RESULTS_DIR, f"{content_name}_{style_name}_vgg.jpg")
save_image(output_img_vgg, output_path)
print(f"VGG result saved to {output_path}")

# Save the comparison figure
fig = display_images(content_img, style_img, output_img_vgg)
comparison_path = os.path.join(RESULTS_DIR, f"{content_name}_{style_name}_comparison.jpg")
plt.savefig(comparison_path)
print(f"Comparison image saved to {comparison_path}")