- Downloads the COCO `test2017.zip` dataset using `wget`.
- Creates a new directory named `dataset` to store extracted files.
- Unzips the downloaded file into the `./dataset` folder quietly.

In [None]:
!wget http://images.cocodataset.org/zips/test2017.zip
!mkdir './dataset'
!unzip -q ./test2017.zip -d './dataset'

- Creates a directory named `checkpoints` to store model files.
- Downloads `best_model.pth` from Dropbox quietly and saves it with the same name.
- Moves the downloaded model file into the `./checkpoints` directory.

In [None]:
!mkdir ./checkpoints
!wget -q -O 'best_model.pth' https://www.dropbox.com/s/7xvmmbn1bx94exz/best_model.pth?dl=1
!mv best_model.pth ./checkpoints

- Creates a directory named `content` for storing content images.
- Creates a directory named `style` for storing style images.

In [None]:
!mkdir ./content
!mkdir ./style

- Imports essential libraries for deep learning, image processing, and visualization.
- Defines a `seed_everything()` function to set random seeds for reproducibility.
- Sets device to GPU if available and defines ImageNet normalization `mean` and `std`.


In [None]:
import torch
from torch.autograd import Variable
from collections import namedtuple
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import os
import sys
import random
from PIL import Image
import glob
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import cv2
def seed_everything(seed):
    """Seed every random number generator used here and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, ``PYTHONHASHSEED``,
            NumPy and PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick a different convolution algorithm on each
    # run, which defeats reproducibility, so it has to be switched off here.
    torch.backends.cudnn.benchmark = False

# Fix all random number generators up front so repeated runs are comparable.
seed_everything(42)

# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel normalization statistics consumed by the Normalize transforms
# and undone again by denormalize().
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

In [None]:
class VGG16(torch.nn.Module):
    """Pretrained VGG16 feature extractor exposing four intermediate activations.

    The convolutional part of torchvision's VGG16 is cut into four consecutive
    slices; ``forward`` returns the output of each slice so callers can build
    content and style losses from them.

    Args:
        requires_grad: When False (default) all parameters are frozen, because
            the network is only used to extract features.
    """

    # Built once instead of on every forward call; field names are the public
    # contract used by callers (e.g. ``features.relu2_2``).
    _VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])

    def __init__(self, requires_grad=False):
        super(VGG16, self).__init__()

        # Load the ImageNet-pretrained VGG16 and keep only its feature layers
        # (not the classifier). Newer torchvision releases use the ``weights``
        # argument; older ones only understand ``pretrained=True``.
        if hasattr(models, "VGG16_Weights"):
            backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        else:
            backbone = models.vgg16(pretrained=True)
        vgg_pretrained_features = backbone.features

        # Slices used to extract features at specific depths.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        # Layers 0-3
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])

        # Layers 4-8
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])

        # Layers 9-15
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])

        # Layers 16-22
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # Freeze the parameters when VGG16 is only a fixed feature extractor.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    # NOTE: this must be a method of the class. It was previously nested inside
    # __init__, which left the module without a usable forward pass.
    def forward(self, X):
        """Run ``X`` through the four slices and return every intermediate output.

        Args:
            X: Input tensor accepted by the first VGG16 convolution.

        Returns:
            A ``VggOutputs`` namedtuple ``(relu1_2, relu2_2, relu3_3, relu4_3)``
            holding the output of slice 1 to 4 respectively.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return self._VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)


In [None]:
class TransformerNet(torch.nn.Module):
    """Feed-forward image transformation network used to stylize images.

    The network is a single ``nn.Sequential`` stored in ``self.model``:
    one 9x9 convolution block, two stride-2 downsampling blocks, five
    residual blocks, two upsampling blocks and a final 9x9 convolution
    that maps back to 3 channels without normalization or ReLU.
    """

    def __init__(self):
        super(TransformerNet, self).__init__()

        self.model = nn.Sequential(
            # Initial convolution layer with a large kernel
            ConvBlock(3, 32, kernel_size=9, stride=1),

            # Downsampling: reduce resolution while increasing depth
            ConvBlock(32, 64, kernel_size=3, stride=2),
            ConvBlock(64, 128, kernel_size=3, stride=2),

            # Five residual blocks that keep the tensor shape unchanged
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128),

            # Upsampling: increase resolution while reducing depth
            ConvBlock(128, 64, kernel_size=3, upsample=True),
            ConvBlock(64, 32, kernel_size=3, upsample=True),

            # Final convolution back to 3 channels
            ConvBlock(32, 3, kernel_size=9, stride=1, normalize=False, relu=False),
        )

    # NOTE: this must be a method of the class. It was previously nested inside
    # __init__, so calling the network raised instead of running the model.
    def forward(self, x):
        """Return ``self.model`` applied to the input batch ``x``."""
        return self.model(x)



In [None]:
class ResidualBlock(torch.nn.Module):
    """Two stacked 3x3 ``ConvBlock``s wrapped in a skip connection.

    Args:
        channels: Number of input channels; the block outputs the same number.
    """

    def __init__(self, channels):
        super().__init__()

        # First stage ends in a ReLU, second stage stays linear so the sum
        # with the skip path is not clipped.
        first_stage = ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=True)
        second_stage = ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=False)
        self.block = nn.Sequential(first_stage, second_stage)

    def forward(self, x):
        """Return ``block(x) + x``."""
        transformed = self.block(x)
        return transformed + x

In [None]:
class ConvBlock(torch.nn.Module):
    """Optional 2x upsample, reflection-padded convolution, optional norm and ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        stride: Convolution stride.
        upsample: When True the input is upscaled by a factor of 2 with
            ``F.interpolate`` before the convolution.
        normalize: When True an affine ``InstanceNorm2d`` follows the convolution.
        relu: When True a ReLU is applied last.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 upsample=False, normalize=True, relu=True):
        super(ConvBlock, self).__init__()

        self.upsample = upsample

        # Images are 4-D (batch, channel, height, width) tensors, so the
        # padding has to be the 2-D variant to match Conv2d.
        self.block = nn.Sequential(
            nn.ReflectionPad2d(kernel_size // 2),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride),
        )

        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if normalize else None
        self.relu = relu

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        # Only the interpolation is conditional; everything below always runs.
        if self.upsample:
            x = F.interpolate(x, scale_factor=2)

        # Padding + convolution; the result must feed the following stages.
        x = self.block(x)

        if self.norm is not None:
            x = self.norm(x)

        if self.relu:
            x = F.relu(x)

        return x

gram_matrix → Computes the Gram matrix (channel-to-channel correlations of a feature map), which is used to represent style

train_transform, style_transform, test_transform → Preprocess input images

denormalize, deprocess → Convert model output back to viewable images



In [None]:
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    Style is compared through correlations between feature channels rather
    than through raw pixel values; the Gram matrix captures those correlations.

    Args:
        y: Tensor of size ``(b, c, h, w)``.

    Returns:
        Tensor of size ``(b, c, c)``: the channel-by-channel inner products of
        the flattened feature maps, divided by ``c * h * w``.
    """
    batch, channels, height, width = y.size()
    # Flatten every channel's spatial grid into a single row.
    flat = y.view(batch, channels, width * height)
    # Row-by-row inner products, computed per batch element.
    correlations = torch.bmm(flat, flat.transpose(1, 2))
    return correlations / (channels * height * width)

def train_transform(image_size):
    """Build the preprocessing pipeline for training (content) images.

    Args:
        image_size: Side length of the square crop fed to the network.

    Returns:
        A ``transforms.Compose`` that resizes to 1.15x ``image_size``, takes a
        random ``image_size`` crop, converts to a tensor and normalizes with
        the module-level ``mean`` / ``std``.
    """
    enlarged = int(image_size * 1.15)
    # The pipeline has to be returned: callers use the result as a callable.
    return transforms.Compose([
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

def style_transform(image_size=None):
    """Build the preprocessing pipeline for the style image.

    Args:
        image_size: When truthy, the image is first resized to a square of
            this side length; otherwise the original size is kept.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize`` with
        the module-level ``mean`` / ``std``.
    """
    resize = [transforms.Resize((image_size, image_size))] if image_size else []
    # The pipeline has to be returned: callers invoke the result on an image.
    return transforms.Compose(resize + [transforms.ToTensor(), transforms.Normalize(mean, std)])

def test_transform(image_size=None):
    """Build the preprocessing pipeline for inference images.

    Args:
        image_size: When truthy, passed as-is to ``transforms.Resize``;
            otherwise the image keeps its original size.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize`` with
        the module-level ``mean`` / ``std``.
    """
    resize = [transforms.Resize(image_size)] if image_size else []
    # The pipeline has to be returned: callers invoke the result on an image.
    return transforms.Compose(resize + [transforms.ToTensor(), transforms.Normalize(mean, std)])

def denormalize(tensors):
    """Undo the ``Normalize`` transform on a batch of images.

    Args:
        tensors: Batch tensor whose second dimension holds the 3 normalized
            channels.

    Returns:
        A new tensor with each of the 3 channels multiplied by ``std[c]`` and
        shifted by ``mean[c]``. The input is left untouched, so denormalizing
        the same tensor twice no longer compounds.
    """
    # Work on a copy so the caller's tensor is not modified as a side effect.
    restored = tensors.clone()
    for c in range(3):
        restored[:, c].mul_(std[c]).add_(mean[c])
    return restored

def deprocess(image_tensor):
    """Convert the first image of a normalized batch into a uint8 HWC array.

    Args:
        image_tensor: Normalized batch tensor; only element 0 is converted.
            May live on any device and may be attached to an autograd graph.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the channel axis moved last,
        values clipped to ``[0, 255]``.
    """
    # Detach and move to the CPU first (``.numpy()`` needs both), and copy so
    # that none of the steps below write into the caller's tensor.
    working = image_tensor.detach().cpu().clone()
    image = denormalize(working)[0]
    image = image * 255
    image_np = np.clip(image.numpy(), 0, 255).astype('uint8')
    return image_np.transpose(1, 2, 0)


During training:

Apply train_transform to content image

Apply style_transform to style image

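Feed the content batch through TransformerNet, then pass the original batch, the stylized batch, and the style image through the frozen VGG16 to extract features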

Compute the content loss from the relu2_2 features and the style loss from the Gram matrices (gram_matrix)

Backpropagate & update TransformerNet

During testing/inference:

Apply test_transform to input image

Stylize using trained TransformerNet

Use deprocess to convert back to image

Show or save result

In [None]:
def fast_trainer(
    style_image,             # Path to the style image
    style_name,              # Name for saving outputs/checkpoints
    dataset_path,            # Path to training dataset
    image_size=256,          # Size of content images
    style_size=448,          # Size of style image
    batch_size=8,            # Number of images per training batch
    lr=1e-5,                 # Learning rate
    epochs=1,                # Number of training epochs
    checkpoint_model=None,   # Path to resume training from a saved model
    checkpoint_interval=200, # Save model every N batches
    sample_interval=200,     # Save sample output every N batches
    lambda_style=10e10,      # Weight for style loss
    labda_content=10e5,      # Weight for content loss (misspelled name kept for existing callers)
    lambda_content=None,     # Correctly spelled alias; overrides `labda_content` when given
):
    """Train a TransformerNet to reproduce the style of one image.

    For every batch the content images are stylized, both the originals and
    the stylized versions are passed through a frozen VGG16, and the weighted
    sum of a content loss (MSE on ``relu2_2``) and a style loss (MSE between
    Gram matrices of all four feature maps) is minimized with Adam.

    Side effects: creates ``./images/outputs/{style_name}-training`` and
    ``./checkpoints``, writes sample grids, periodic checkpoints,
    ``last_checkpoint.pth`` and ``{style_name}_final.pth``, prints progress and
    shows a loss plot.

    Args:
        style_image: Path to the style image.
        style_name: Name used in output and checkpoint file names.
        dataset_path: Root folder readable by ``datasets.ImageFolder``.
        image_size: Side length of the training crops.
        style_size: Side length the style image is resized to.
        batch_size: Images per batch.
        lr: Adam learning rate.
        epochs: Number of passes over the dataset.
        checkpoint_model: Optional state-dict path to resume from.
        checkpoint_interval: Save a numbered checkpoint every N batches
            (values <= 0 disable it).
        sample_interval: Save a sample grid every N batches
            (values <= 0 disable it).
        lambda_style: Weight of the style loss.
        labda_content: Weight of the content loss.
        lambda_content: Optional alias for ``labda_content``; wins when not None.
    """
    if lambda_content is not None:
        labda_content = lambda_content

    # Create output directories for saving training progress and models
    os.makedirs(f"./images/outputs/{style_name}-training", exist_ok=True)
    os.makedirs("./checkpoints", exist_ok=True)

    # Load dataset with transforms (resizing, normalizing etc.)
    train_dataset = datasets.ImageFolder(dataset_path, train_transform(image_size))
    dataloader = DataLoader(train_dataset, batch_size=batch_size)
    num_batches = len(dataloader)

    # Style transfer model (trained) and frozen VGG16 feature extractor
    transformer = TransformerNet().to(device)
    vgg = VGG16(requires_grad=False).to(device)

    # Optionally resume; map_location lets a GPU-saved checkpoint load on CPU
    if checkpoint_model:
        transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))

    optimizer = Adam(transformer.parameters(), lr)
    l2_loss = torch.nn.MSELoss().to(device)

    # Preprocess the style image and replicate it across the batch.
    # convert("RGB") guards against grayscale / RGBA files, which would not
    # match the 3-channel normalization.
    style = style_transform(style_size)(Image.open(style_image).convert("RGB"))
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    # Style targets are constants, so no autograd bookkeeping is needed
    with torch.no_grad():
        features_style = vgg(style)
        gram_style = [gram_matrix(y) for y in features_style]

    # Sample up to 8 content images to periodically visualize progress
    # (fewer when the dataset is smaller than that).
    sample_paths = glob.glob(f"{dataset_path}/*/*.jpg")
    sample_paths = random.sample(sample_paths, min(8, len(sample_paths)))
    image_samples = None
    if sample_paths:
        sample_tf = style_transform(image_size)
        image_samples = torch.stack(
            [sample_tf(Image.open(path).convert("RGB")) for path in sample_paths]
        )

    def save_sample(batches_done):
        """Write a grid of the sample images above their stylized versions."""
        if image_samples is None:
            return
        transformer.eval()
        with torch.no_grad():
            output = transformer(image_samples.to(device))
        image_grid = denormalize(torch.cat((image_samples.cpu(), output.cpu()), 2))
        save_image(image_grid, f"./images/outputs/{style_name}-training/{batches_done}.jpg", nrow=4)
        transformer.train()

    # Losses across ALL iterations; created once so the final plot covers the
    # whole run and also works when `epochs` is 0.
    train_metrics = {"content": [], "style": [], "total": []}

    for epoch in range(epochs):
        # Losses of the current epoch only (used for the running means)
        epoch_metrics = {"content": [], "style": [], "total": []}

        for batch_i, (images, _) in enumerate(dataloader):
            optimizer.zero_grad()

            images_original = images.to(device)
            images_transformed = transformer(images_original)

            # VGG features for both original and stylized images
            features_original = vgg(images_original)
            features_transformed = vgg(images_transformed)

            # Content loss: MSE between relu2_2 features
            content_loss = labda_content * l2_loss(
                features_transformed.relu2_2, features_original.relu2_2
            )

            # Style loss: sum of MSE between Gram matrices. The style Grams are
            # sliced because the last batch can be smaller than batch_size.
            style_loss = 0
            for ft_y, gm_s in zip(features_transformed, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += l2_loss(gm_y, gm_s[:images.size(0), :, :])
            style_loss *= lambda_style

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Read each scalar once; every .item() call synchronizes with the device
            content_value = content_loss.item()
            style_value = style_loss.item()
            total_value = total_loss.item()

            for key, value in (("content", content_value), ("style", style_value), ("total", total_value)):
                epoch_metrics[key].append(value)
                train_metrics[key].append(value)

            # Progress line: the denominator is the number of batches, not images
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
                % (
                    epoch + 1,
                    epochs,
                    batch_i + 1,
                    num_batches,
                    content_value,
                    np.mean(epoch_metrics["content"]),
                    style_value,
                    np.mean(epoch_metrics["style"]),
                    total_value,
                    np.mean(epoch_metrics["total"]),
                )
            )

            # Save intermediate results and model checkpoints
            batches_done = epoch * num_batches + batch_i + 1
            if sample_interval > 0 and batches_done % sample_interval == 0:
                save_sample(batches_done)

            if checkpoint_interval > 0 and batches_done % checkpoint_interval == 0:
                torch.save(transformer.state_dict(), f"./checkpoints/{style_name}_{batches_done}.pth")

            # Save latest model checkpoint.
            # NOTE(review): this runs on every batch, as before; consider moving
            # it to the checkpoint interval if the disk writes slow training down.
            torch.save(transformer.state_dict(), "./checkpoints/last_checkpoint.pth")

    # Training complete: save final model
    print("Training Completed!")
    final_model_path = f"./checkpoints/{style_name}_final.pth"
    torch.save(transformer.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Plot training loss curves
    plt.plot(train_metrics["content"], label="Content Loss")
    plt.plot(train_metrics["style"], label="Style Loss")
    plt.plot(train_metrics["total"], label="Total Loss")
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.show()


In [None]:
def test_image(image_path, checkpoint_model, save_path):
    """Stylize one image with a trained TransformerNet, save it and display it.

    The result is written to ``{save_path}/results/{name}-output.jpg`` where
    ``name`` is the checkpoint file name up to its first dot.

    Args:
        image_path: Path of the image to stylize.
        checkpoint_model: Path of a TransformerNet state dict.
        save_path: Folder in which the ``results`` directory is created.
    """
    results_dir = os.path.join(save_path, "results")
    os.makedirs(results_dir, exist_ok=True)

    transform = test_transform()

    # Define model and load checkpoint; map_location lets a checkpoint saved on
    # a GPU be loaded on a CPU-only machine.
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))
    transformer.eval()

    # Prepare input. convert("RGB") guards against grayscale / RGBA files,
    # which would not match the 3-channel normalization.
    image_tensor = transform(Image.open(image_path).convert("RGB")).to(device)
    image_tensor = image_tensor.unsqueeze(0)

    # Stylize image
    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    # Save image
    fn = os.path.basename(checkpoint_model).split('.')[0]
    output_path = os.path.join(results_dir, f"{fn}-output.jpg")
    save_image(stylized_image, output_path)
    print("Image Saved!")

    # Read the saved file back and show it (OpenCV loads BGR, matplotlib wants RGB)
    plt.imshow(cv2.cvtColor(cv2.imread(output_path), cv2.COLOR_BGR2RGB))

In [None]:
""" Run this to train the model """
# NOTE: a smaller dataset is used here for demonstration purposes. For better
# results, train on the dataset downloaded at the start of this notebook and
# point `dataset_path` at it.

fast_trainer(
    style_image='/content/style/style1.jpg',
    style_name='Picasso_Selfportrait',
    dataset_path='/content/dataset',
    epochs=1,
)

In [None]:
# Stylize a single content image with a saved checkpoint; the checkpoint file
# must already exist under /content/checkpoints.
test_image(
    image_path='/content/content/Bagavati_IT_2ndYear.jpg',
    checkpoint_model='/content/checkpoints/Picasso_Selfportrait_5000.pth',
    save_path='./',
)