# Neural Style Transfer

**Reference:** [Gatys, L. A., Ecker, A. S., & Bethge, M. (2016). Image Style Transfer Using Convolutional Neural Networks. CVPR.](https://doi.org/10.1109/CVPR.2016.265)

**Objective:** Implement the neural style transfer algorithm described in the reference paper using the Jittor deep learning framework.

## Setup code
Some boilerplate code to set up the environment before getting started.

In [2]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10.0, 8.0)  # set default size of plots

## Load pre-trained model
We use a pre-trained VGG-16 network to extract the features that drive the style transfer.

In [None]:
import jittor as jt
from jittor.models.vgg import vgg16

# Load VGG-16 with pre-trained weights
model = vgg16(pretrained=True)

# Sanity check: run a random batch through the network and inspect the output shape
x = jt.rand(1, 3, 224, 224)
y = model(x)
print(y.shape)

Run the cell below to print the detailed model structure.

In [None]:
# print vgg16 model structure
print(model.features)

## Neural style transfer algorithm

**Target Feature Computation**

A custom `VGGFeatureExtractor` class is used to obtain feature maps from specific layers of the VGG network.  

By passing an image through this extractor, we obtain its **content features** and **style features**.
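The internals of `VGGFeatureExtractor` are not shown in this notebook. As a rough, framework-agnostic sketch of the idea, the hypothetical helper `collect_features` below walks a sequence of layers and keeps the outputs at the requested indices. The function name, the early exit, and the toy layers are assumptions made for illustration; only the `(content, style)` return order follows the way the extractor is called in this notebook.

```python
def collect_features(layers, x, content_layer, style_layers):
    """Run x through layers in order, keeping outputs at the requested indices."""
    wanted_style = set(style_layers)
    last_needed = max([content_layer, *style_layers])
    content_feature = None
    style_features = {}
    for index, layer in enumerate(layers):
        x = layer(x)
        if index == content_layer:
            content_feature = x
        if index in wanted_style:
            style_features[index] = x
        if index == last_needed:
            break  # deeper layers are not needed
    # Return style features in the order the caller listed them.
    return content_feature, [style_features[i] for i in style_layers]


# Toy stand-ins for network layers: each one just adds 1 to its input.
toy_layers = [lambda v: v + 1 for _ in range(6)]
content, styles = collect_features(toy_layers, 0, content_layer=4, style_layers=[1, 3])
print(content, styles)  # 5 [2, 4]
```

With a real network, `layers` would be the convolutional part of the model and `x` an image tensor; stopping at the deepest requested layer avoids running layers whose outputs are never used.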

In [None]:
from neural_style_transfer import VGGFeatureExtractor

# Define the default layers to extract content and style features:
# - Content features are extracted from layer relu4_2 (index 20) to capture high-level image structure
# - Style features are extracted from layers relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
#   (indices 1, 6, 11, 18, 25) to capture multi-scale texture and style information
content_layer = 20 
style_layers = [1, 6, 11, 18, 25] 

# Initialize feature extractor with specific layers
extractor = VGGFeatureExtractor(model, content_layer, style_layers)
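
The index-to-name mapping in the comments above can be checked without loading any weights. The sketch below rebuilds the layer names from the standard VGG-16 configuration, assuming `model.features` is a flat sequence of convolution, ReLU, and max-pooling layers with no batch normalization. `VGG16_CFG` and `vgg16_layer_names` are hypothetical names introduced here; if `print(model.features)` shows a different layout, the indices need adjusting.

```python
# Standard VGG-16 configuration: numbers are conv output channels, "M" is max-pooling.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_layer_names(cfg=VGG16_CFG):
    """Return one name per layer, assuming each conv is followed by a ReLU."""
    names = []
    block, conv = 1, 0
    for item in cfg:
        if item == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 0
        else:
            conv += 1
            names.append(f"conv{block}_{conv}")
            names.append(f"relu{block}_{conv}")
    return names


names = vgg16_layer_names()
print(names[20])                               # relu4_2
print([names[i] for i in [1, 6, 11, 18, 25]])  # ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
```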

**Image Synthesis Initialization**

The synthesis image is initialized as **a clone of the content image**.  
It is set to require gradients so that it can be iteratively optimized using backpropagation.
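
To make the idea of optimizing the image itself concrete, here is a toy NumPy sketch. It is not the notebook's method: it swaps in plain gradient descent with a hand-written gradient in place of Adam and automatic differentiation, and uses a simple squared-error objective with random arrays standing in for images. It only illustrates that the pixels are the optimization variables and that starting from a copy leaves the content image untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
content = rng.random((3, 4, 4))   # stand-in for the content image
target = rng.random((3, 4, 4))    # stand-in for whatever the loss pulls the image toward

synth = content.copy()            # start from a copy, so `content` is never modified
lr = 0.1
for _ in range(200):
    grad = 2.0 * (synth - target)  # gradient of sum((synth - target) ** 2)
    synth -= lr * grad             # plain gradient descent step on the pixels

print(np.allclose(synth, target))      # True
print(np.array_equal(content, synth))  # False
```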

**Optimization Loop**  

We minimize the total loss, a weighted sum of the **content loss** computed by `compute_content_loss`, the **style loss** computed by `compute_style_loss`, and the **total variation (TV) loss** computed by `compute_tv_loss`. The weights are `content_weight`, `style_weight`, and `tv_weight`.

In neural style transfer, the content loss and style loss measure how well the generated image preserves the original content and matches the target style, respectively. In addition to these two objectives, we also include a total variation (TV) loss, which acts as a regularizer that encourages smoothness in the generated image by reducing unnecessary noise and artifacts.
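
The three helpers come from `neural_style_transfer` and their exact definitions are not shown here. As a hedged illustration of what such losses typically look like, the NumPy sketch below uses mean squared error for content, mean squared error between Gram matrices for style, and squared neighbor differences for TV. The function names, the normalization constants, and the single-image `(C, H, W)` layout without a batch dimension are assumptions, not the module's actual implementation.

```python
import numpy as np


def content_loss(synth, target):
    """Mean squared error between two feature maps of identical shape."""
    return float(np.mean((synth - target) ** 2))


def gram_matrix(feat):
    """Channel-by-channel correlations of a (C, H, W) feature map, scaled by its size."""
    c, h, w = feat.shape
    flat = feat.reshape(c, h * w)
    return flat @ flat.T / (c * h * w)


def style_loss(synth_feats, target_feats):
    """Sum over layers of the MSE between Gram matrices."""
    return float(sum(np.mean((gram_matrix(s) - gram_matrix(t)) ** 2)
                     for s, t in zip(synth_feats, target_feats)))


def tv_loss(img):
    """Sum of squared differences between neighboring pixels of a (C, H, W) image."""
    dh = img[:, 1:, :] - img[:, :-1, :]
    dw = img[:, :, 1:] - img[:, :, :-1]
    return float(np.sum(dh ** 2) + np.sum(dw ** 2))


rng = np.random.default_rng(0)
img = rng.random((3, 8, 8))
feats = [rng.random((4, 8, 8)), rng.random((6, 4, 4))]
print(content_loss(feats[0], feats[0]))  # 0.0
print(style_loss(feats, feats))          # 0.0
print(tv_loss(np.ones((3, 8, 8))))       # 0.0
print(tv_loss(img) > 0)                  # True
```

The Gram matrix discards spatial layout and keeps only which channels fire together, which is why it captures texture rather than structure. A constant image has zero TV loss, while a noisy one is penalized.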

**Post-processing and Visualization**

Denormalize the final image to convert it back to the standard RGB range.  

Display the final stylized image. The loss terms printed during optimization can be used to monitor convergence.
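
The notebook's `deprocess` helper handles denormalization, and its code is not shown. A minimal sketch of the usual approach follows, assuming the image was normalized with the ImageNet channel mean and standard deviation commonly paired with pre-trained VGG weights. The constants, the `(3, H, W)` layout, and the `denormalize` name are assumptions for illustration.

```python
import numpy as np

# ImageNet channel statistics commonly used with pre-trained VGG models (assumed here).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def denormalize(chw):
    """Map a normalized (3, H, W) array back to an (H, W, 3) uint8 RGB array."""
    rgb = chw * IMAGENET_STD + IMAGENET_MEAN   # undo (x - mean) / std
    rgb = np.clip(rgb, 0.0, 1.0)               # optimization can push values out of range
    return np.round(rgb * 255).astype(np.uint8).transpose(1, 2, 0)


# Round trip: normalize a uniform image, then recover it.
flat_image = np.full((3, 2, 2), 0.2)
normalized = (flat_image - IMAGENET_MEAN) / IMAGENET_STD
print(denormalize(normalized)[0, 0])  # [51 51 51]
```

Clipping matters because nothing in the optimization constrains pixel values to the valid range.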


In [None]:
import jittor as jt
from jittor import optim
import matplotlib.pyplot as plt
from neural_style_transfer import (
    preprocess, deprocess,
    compute_content_loss, compute_style_loss, compute_tv_loss,
)

def style_transfer(
    content_image_path,
    style_image_path,
    saved_path=None,       # do not save if None
    content_weight=1e0,
    style_weight=1e7,
    tv_weight=1e-6,
    learning_rate=0.3,
    num_steps=500,
    max_width=512,
):
    """
    Perform neural style transfer from style_image onto content_image.
    
    Args:
        content_image_path (str): Path to content image.
        style_image_path (str): Path to style image.
        saved_path (str or None): Path to save the final image. If None, do not save.
        content_weight (float): Weight for content loss.
        style_weight (float): Weight for style loss.
        tv_weight (float): Weight for total variation loss.
        learning_rate (float): Learning rate for optimizer.
        num_steps (int): Number of iterations.
        max_width (int): Width passed to preprocess for both the content and style images (aspect ratio maintained).
    """

    # Use the GPU when CUDA is available, otherwise fall back to the CPU
    jt.flags.use_cuda = 1 if jt.has_cuda else 0

    # ---------------------------
    # Load images
    # ---------------------------
    content_image = preprocess(content_image_path, max_width)
    style_image = preprocess(style_image_path, max_width)
    
    # Extract features
    target_content_rep, _ = extractor(content_image)
    _, target_style_rep = extractor(style_image)

    # Initialize synthesis image
    image_synthesis = content_image.clone().stop_grad()
    image_synthesis.requires_grad = True

    optimizer = optim.Adam([image_synthesis], lr=learning_rate)

    for step in range(num_steps):
        synth_content_rep, synth_style_rep = extractor(image_synthesis)

        # Compute losses
        content_loss = compute_content_loss(synth_content_rep, target_content_rep)
        style_loss = compute_style_loss(synth_style_rep, target_style_rep)
        tv_loss = compute_tv_loss(image_synthesis)

        total_loss = content_weight * content_loss + style_weight * style_loss + tv_weight * tv_loss

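        # Backpropagate and update the image; Jittor's optimizer.step(loss)
        # computes the gradients and applies the update in one call
        optimizer.step(total_loss)

        # Print a log line every 50 steps and at the final step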
        if step % 50 == 0 or step == num_steps - 1:
            print(f"Step {step}: Total Loss {total_loss.item():.4f}, "
                  f"Content {content_loss.item() * content_weight:.4f}, "
                  f"Style {style_loss.item() * style_weight:.4f}, "
                  f"TV {tv_loss.item() * tv_weight:.4f}")

    final_image = deprocess(image_synthesis)
    plt.imshow(final_image)
    plt.axis("off")
    plt.show()

    # ---------------------------
    # Save the final image (optional)
    # ---------------------------
    if saved_path is not None:
        final_image.save(saved_path)
        print(f"Final image saved to {saved_path}")



## Generate some pictures

In [None]:
parameters = {
    'content_image_path': 'images/tubingen.jpg',
    'style_image_path': 'styles/starry_night.jpg',
    'saved_path': None,
    'content_weight': 1e0, 
    'style_weight': 1e7,
    'tv_weight': 1e-5,
    'learning_rate': 0.1,
    'num_steps': 500,
    'max_width': 512
}

style_transfer(**parameters)