# 🏋️ Neural Style Transfer - Model Training Notebook

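This notebook demonstrates how to **train an image transformation network** (`TransformerNet`) for fast neural style transfer.
Unlike the original optimization-based method (Gatys et al.), which optimizes every output image separately, this approach (Johnson et al.) trains one **feedforward model** per style that stylizes new images in a single forward pass.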

By the end of this notebook, you will:
- Understand the training pipeline
- Learn about content/style/TV losses
- Monitor training progress with logs
- Save your own style transfer model

## 📦 Setup & Imports

In [None]:
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from models.definitions.transformer_net import TransformerNet
from models.definitions.perceptual_loss_net import PerceptualLossNet
import utils.utils as utils

## ⚙️ Configuration Setup

In [None]:
# Paths and hyper-parameters for training. The interactive cells below may
# override style_img_name, num_of_epochs, style_weight and subset_size.
base_dir = os.getcwd()
style_img_name = 'psychedelic.jpg'  # must exist in data/styles/
training_config = {
    'style_img_name': style_img_name,
    'content_weight': 1e0,
    'style_weight': 2e5,
    'tv_weight': 0,
    'num_of_epochs': 2,
    'subset_size': None,  # Limit data for quick test (10K) / Use None to train on entire dataset
    'enable_tensorboard': True,
    'image_log_freq': 100,
    'console_log_freq': 50,
    'checkpoint_freq': 200,
    'image_size': 256,
    'batch_size': 4,
    'dataset_path': os.path.join(base_dir, 'data', 'dataset'),
    'style_images_path': os.path.join(base_dir, 'data', 'styles'),
    'model_binaries_path': os.path.join(base_dir, 'models', 'binaries'),
    # One checkpoint folder per style: checkpoint file names only encode
    # epoch/batch, so a shared folder would let different styles overwrite
    # each other (previously hard-coded to a 'cubism' folder for every style).
    'checkpoints_path': os.path.join(
        base_dir, 'models', 'checkpoints', os.path.splitext(style_img_name)[0]
    ),
}

os.makedirs(training_config['model_binaries_path'], exist_ok=True)
os.makedirs(training_config['checkpoints_path'], exist_ok=True)

## 🧩 Interactive Config: Select Style and Parameters

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Only offer image files as styles. A raw listdir would also list hidden files
# such as macOS '.DS_Store', which cannot be loaded as a style image.
STYLE_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
style_options = sorted(
    fname for fname in os.listdir(training_config['style_images_path'])
    if not fname.startswith('.') and fname.lower().endswith(STYLE_IMG_EXTENSIONS)
)

# Dropdown rejects a `value` that is not one of its `options`, so fall back to
# the first available style if the configured default is missing.
default_style = training_config['style_img_name']
if default_style not in style_options:
    default_style = style_options[0] if style_options else None

# Style Image Selector
dropdown_style = widgets.Dropdown(
    options=style_options,
    value=default_style,
    description='Style Image:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Epochs input
slider_epochs = widgets.IntSlider(
    value=training_config['num_of_epochs'],
    min=1, max=10, step=1,
    description='Epochs:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Style weight input
style_weight_text = widgets.FloatText(
    value=training_config['style_weight'],
    description='Style Weight:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Subset size input (blank means full dataset)
subset_text = widgets.Text(
    value=str(training_config['subset_size']) if training_config['subset_size'] is not None else '',
    description='Subset Size (leave blank for all):',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Display all widgets together
display(dropdown_style, slider_epochs, style_weight_text, subset_text)

## 🧾 Final Training Configuration

In [None]:
# Copy the widget selections into training_config, then print the final config.
training_config['style_img_name'] = dropdown_style.value
training_config['num_of_epochs'] = slider_epochs.value
training_config['style_weight'] = style_weight_text.value

# Subset size: blank -> full dataset (None); otherwise it must be a positive int.
subset_raw = subset_text.value.strip()
if subset_raw == '':
    training_config['subset_size'] = None
else:
    try:
        subset_size = int(subset_raw)
    except ValueError:
        print("Invalid subset size. Using full dataset.")
        subset_size = None
    else:
        if subset_size <= 0:
            print("Subset size must be a positive integer. Using full dataset.")
            subset_size = None
    training_config['subset_size'] = subset_size

# The checkpoint folder set in the configuration cell does not track the style
# chosen in the dropdown. Checkpoint file names only encode epoch/batch, so
# re-point the leaf folder at the selected style to keep styles from
# overwriting each other's checkpoints.
style_stem = os.path.splitext(training_config['style_img_name'])[0]
checkpoints_root = os.path.dirname(os.path.normpath(training_config['checkpoints_path']))
training_config['checkpoints_path'] = os.path.join(checkpoints_root, style_stem)
os.makedirs(training_config['checkpoints_path'], exist_ok=True)

# Display final training config before starting
print("🧾 Final Training Configuration:")
for k, v in training_config.items():
    print(f"{k:20s}: {v}")

## 🧠 Initialize Model & Style

In [None]:
# Pick a compute backend, build the networks and optimizer, and precompute the
# style targets (Gram matrices of the style image's perceptual features).
if torch.backends.mps.is_available():
    backend = "mps"   # Apple-silicon GPU
elif torch.cuda.is_available():
    backend = "cuda"  # NVIDIA GPU
else:
    backend = "cpu"
device = torch.device(backend)

# Trainable stylization network (train mode) and the loss network, which is
# built with requires_grad=False.
transformer_net = TransformerNet()
transformer_net.train()
transformer_net = transformer_net.to(device)
perceptual_loss_net = PerceptualLossNet(requires_grad=False)
perceptual_loss_net = perceptual_loss_net.to(device)
optimizer = Adam(transformer_net.parameters())

# Style image prepared with target_shape=512 and the training batch_size
# (presumably replicated across the batch -- confirm in utils.prepare_img).
style_path = os.path.join(training_config['style_images_path'], training_config['style_img_name'])
style_img = utils.prepare_img(
    style_path,
    target_shape=512,
    device=device,
    batch_size=training_config['batch_size'],
)

# One Gram matrix per feature map returned by the loss network
style_features = perceptual_loss_net(style_img)
target_grams = list(map(utils.gram_matrix, style_features))

## 🖼️ Load Dataset

In [None]:
# Build the training DataLoader from training_config and report how many batches it yields.
train_loader = utils.get_training_data_loader(training_config)
batch_count = len(train_loader)
print("Data loaded:", batch_count, "batches")

## 🔁 Training Loop

In [None]:
# Train transformer_net. Through perceptual_loss_net its outputs are pushed to
# match the content batch (relu2_2 features) and the style image (Gram
# matrices), plus an optional total-variation penalty. Loss logging goes to the
# console and, if enabled, to TensorBoard; checkpoints are written every
# checkpoint_freq batches to training_config['checkpoints_path'].
writer = SummaryWriter() if training_config['enable_tensorboard'] else None
utils.print_header(training_config)

num_batches = len(train_loader)
start_time = time.time()

try:
    for epoch in range(training_config['num_of_epochs']):
        # Loss sums and batch count since the last console log. Reset per epoch
        # so the first log of an epoch does not mix in the previous epoch's tail.
        acc_content, acc_style, acc_tv, acc_count = 0, 0, 0, 0

        for batch_id, (content_batch, _) in enumerate(train_loader):
            content_batch = content_batch.to(device)
            batch_len = content_batch.shape[0]

            # Forward pass
            stylized_batch = transformer_net(content_batch)
            content_feats = perceptual_loss_net(content_batch)
            stylized_feats = perceptual_loss_net(stylized_batch)

            # Content loss (relu2_2)
            content_loss = training_config['content_weight'] * torch.nn.functional.mse_loss(
                stylized_feats.relu2_2, content_feats.relu2_2
            )

            # Style loss, averaged over feature layers. target_grams come from a
            # style image prepared with batch_size=training_config['batch_size'];
            # slicing to batch_len keeps a smaller final batch aligned and is a
            # no-op for full batches. Assumes Gram tensors are batch-first --
            # TODO confirm in utils.gram_matrix.
            current_grams = [utils.gram_matrix(x) for x in stylized_feats]
            style_loss = 0
            for g_target, g_current in zip(target_grams, current_grams):
                style_loss += torch.nn.functional.mse_loss(g_target[:batch_len], g_current)
            style_loss = style_loss * training_config['style_weight'] / len(target_grams)

            # Total variation (TV) loss
            tv_loss = training_config['tv_weight'] * utils.total_variation(stylized_batch)

            # Total loss
            total_loss = content_loss + style_loss + tv_loss

            # Backward
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Read each scalar once and reuse it for accumulation and logging
            content_val = content_loss.item()
            style_val = style_loss.item()
            tv_val = tv_loss.item()
            acc_content += content_val
            acc_style += style_val
            acc_tv += tv_val
            acc_count += 1

            global_step = epoch * num_batches + batch_id

            # TensorBoard logs
            if writer is not None:
                writer.add_scalar("Loss/Content", content_val, global_step)
                writer.add_scalar("Loss/Style", style_val, global_step)
                writer.add_scalar("Loss/TV", tv_val, global_step)

                if batch_id % training_config['image_log_freq'] == 0:
                    # NOTE(review): clamp(0, 1) assumes outputs are already in
                    # [0, 1]. If the network works in a normalized colour space,
                    # these previews need de-normalizing first -- confirm in utils.
                    grid = make_grid(stylized_batch[:4].detach().cpu().clamp(0, 1))
                    writer.add_image("Stylized", grid, global_step)

            # Console log: mean losses over the batches since the previous log
            # (acc_count >= 1 here because the current batch was just added).
            if batch_id % training_config['console_log_freq'] == 0:
                elapsed = (time.time() - start_time) / 60
                print(f"Epoch {epoch+1}/{training_config['num_of_epochs']} | Batch {batch_id}/{num_batches} | "
                      f"Elapsed: {elapsed:.2f} min\n"
                      f"Content: {acc_content / acc_count:.4f} | Style: {acc_style / acc_count:.4f} | "
                      f"TV: {acc_tv / acc_count:.4f} (mean over {acc_count} batches)")
                acc_content, acc_style, acc_tv, acc_count = 0, 0, 0, 0

            # Save checkpoints
            if training_config['checkpoint_freq'] and (batch_id + 1) % training_config['checkpoint_freq'] == 0:
                ckpt = utils.get_training_metadata(training_config)
                ckpt['state_dict'] = transformer_net.state_dict()
                ckpt['optimizer_state'] = optimizer.state_dict()
                fname = f"ckpt_{epoch+1}_{batch_id+1}.pth"
                torch.save(ckpt, os.path.join(training_config['checkpoints_path'], fname))
finally:
    # Flush pending events and release the event file, even if training is interrupted.
    if writer is not None:
        writer.close()

## 💾 Save Final Model

In [None]:
# Bundle the training metadata with the trained weights and the optimizer
# state, then write the final model into model_binaries_path.
final_model = utils.get_training_metadata(training_config)
final_model['state_dict'] = transformer_net.state_dict()
final_model['optimizer_state'] = optimizer.state_dict()

# splitext strips only the extension, so style names containing dots
# (e.g. 'starry.night.jpg') keep their full stem instead of being truncated.
style_stem = os.path.splitext(training_config['style_img_name'])[0]
model_name = f"style_{style_stem}_final.pth"
model_path = os.path.join(training_config['model_binaries_path'], model_name)
torch.save(final_model, model_path)
print(f"\n✅ Final model saved to: {model_path}")

## 🧪 What's Next?

- Use `stylization_script.py` or the demo notebook `Image_NST_Notebook` to apply your trained model
- Try with different style images