# Neural Style Transfer

In this project we trained a custom Neural Style Transfer (NST) model that can take a realisic content image and apply a style (painting).

Dependencies:

In [1]:
import torch
import torch.nn as nn
import torch.optim as optimization
import numpy as np

import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import save_image

from tqdm.notebook import tqdm
from PIL import Image
from pathlib import Path

We'll crop the original images (content and style) to be at the same dimenstions. We'll define a crop function that crops 512x512 pixels from the center of each image:

In [2]:
def crop_it(img):
    input_path = Path(img)
    image = Image.open(img)
    width, height = image.size
    left = (width-512)/2
    right = left+512
    top = (height-512)/2
    bottom = top+512
    cropped_img = image.crop((left, top, right, bottom))
    modified_path = input_path.with_name(f"{input_path.stem}_crop{input_path.suffix}")
    cropped_img.save(modified_path)
    return cropped_img.show

Let's apply it on each image:

In [3]:
crop_it('Path/to/style.jpg')
crop_it('Path/to/content.jpg')

<bound method Image.show of <PIL.Image.Image image mode=RGB size=512x512 at 0x1B7056462E0>>

## Loss functions
### Content loss function

We'll implement a function to calculate the content loss which is the squared error loss between the two feature vectors of the content image and the target image.

In [4]:
def get_content_loss(target_vec, content_vec):
  return torch.mean((target_vec-content_vec)**2)

### Style loss function

We'll implement a function to calculate the style loss by using Gram matrix. The total loss is the sum of every mean-squared distance (between two gram matrices of the style and the target images) for every layer times the weighted factor (the influence factor) of every layer.

In [5]:
def gram_matrix(input, c, h, w):
  #c-channels; h-height; w-width 
  input = input.view(c, h*w) 
  #matrix multiplication on its own transposed form
  G = torch.mm(input, input.t())
  return G
  
def get_style_loss(target, style):
  _, c, h, w = target.size()
  G = gram_matrix(target, c, h, w) #gram matrix for the target image
  S = gram_matrix(style, c, h, w) #gram matrix for the style image
  return torch.mean((G-S)**2)/(c*h*w)

### The model (loading a pretrained VGG19 and modifying it)
We will use only 5 layers from the model (conv layers) just for feature extraction. We remove other layers used for classification.

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class VGG(nn.Module):
  def __init__(self):
    super(VGG, self).__init__()
    self.select_features = ['0', '5', '10', '19', '28'] #conv layers
    self.vgg = models.vgg19(pretrained=True).features
  
  def forward(self, output):
    features = []
    for name, layer in self.vgg._modules.items():
      output = layer(output)
      if name in self.select_features:
        features.append(output)
    return features

#load the model
vgg = VGG().to(device).eval()

Load image function:

In [15]:

#preprocessing of the images
loader = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize image
    transforms.ToTensor(),          # Convert image to tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def random_noise(device='cuda'):
 height, width = content_img.shape[2:]
 output_img = np.random.randn(3, height, width) # (3, H, W)
 noise_tensor = torch.tensor(output_img, dtype=torch.float32).unsqueeze(0) # (1, 3, H, W)
 return noise_tensor

# Let's define a function to load images (style, content)
def load_img(path):
  img = Image.open(path)
  img = loader(img).unsqueeze(0)
  return img.to(device)

content_img = load_img('Path/to/content_crop.jpg')
style_img = load_img('Path/to/style_crop.jpg')
#we can start by copying the content image as a starting point
target_img = random_noise()
#target_img = torch.from_numpy(target_img)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target_img = target_img.to(device)
img_size = 512 if torch.cuda.is_available() else 128
target_img = torch.randn_like(content_img, device=device, requires_grad=True)

#initial model
model = models.vgg19(pretrained=True).features
#the optimizer: We use Adam since it's generally more adequate
optimizer = optimization.Adam([target_img], lr=0.001)

alpha = 50 #content wight
beta = 50 #style weight

#define the load_img function first
content_img = load_img('Path/to/content_crop.jpg')
style_img = load_img('Path/to/style_crop.jpg')
#we can start from a random noise generated image or 
#just copy the content image as a starting point


In [16]:
print(target_img.device)

cuda:0


In [17]:
print(target_img.shape)

torch.Size([1, 3, 512, 512])


Loss calculation:

In [18]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import transforms
from tqdm import tqdm

# Define denormalization (assuming you used these values for normalization)
denormalization = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))

# Function to save the target image
def save(target, i):
    img = target.clone().detach().squeeze()  # Remove gradients
    img = denormalization(img).clamp(0, 1)   # Apply denormalization and clamp values
    save_image(img, f'result_{i}.png')

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the target image (learnable)
target_img = torch.randn_like(content_img, device=device, requires_grad=True)

# Optimizer (Adam is smoother, LBFGS is used in the original NST paper)
optimizer = optim.Adam([target_img], lr=0.01)

# Compute content and style features once (no need to compute them every step)
content_feature = [f.detach() for f in vgg(content_img)]
style_feature = [f.detach() for f in vgg(style_img)]
#content_feature = vgg(content_img).detach()  # No need to track gradients
#style_feature = vgg(style_img).detach()

# Training loop
steps = 10000
for step in tqdm(range(steps)):
    # Compute target features
    target_feature = vgg(target_img)

    # Compute losses
    #content_loss = get_content_loss(target_feature, content_feature)
    content_loss = sum(get_content_loss(t, c) for t, c in zip(target_feature, content_feature))
    style_loss = sum(get_style_loss(t, c) for t, c in zip(target_feature, style_feature))


    #style_loss = get_style_loss(target_feature, style_feature)

    # Calculate total loss
    total_loss = alpha * content_loss + beta * style_loss

    # Zero out previous gradients
    optimizer.zero_grad()

    # Compute gradients
    total_loss.backward()

    # Update the target image
    optimizer.step()

    # Save progress every 500 steps
    if step % 500 == 0:
        save(target_img, step)
        print(f"Step {step}: Loss = {total_loss.item():.4f}")



  0%|          | 9/10000 [00:00<12:48, 13.00it/s]  

Step 0: Loss = 1947876.0000


  5%|▌         | 501/10000 [02:58<3:36:00,  1.36s/it]

Step 500: Loss = 20966.8379


 10%|█         | 1008/10000 [05:57<1:04:11,  2.33it/s]

Step 1000: Loss = 13274.2188


 15%|█▌        | 1501/10000 [08:21<2:36:36,  1.11s/it]

Step 1500: Loss = 8583.6123


 20%|██        | 2001/10000 [10:46<2:28:39,  1.12s/it]

Step 2000: Loss = 4453.6289


 25%|██▌       | 2501/10000 [13:11<2:18:58,  1.11s/it]

Step 2500: Loss = 2856.5024


 30%|███       | 3010/10000 [15:36<27:40,  4.21it/s]  

Step 3000: Loss = 2346.8604


 35%|███▌      | 3501/10000 [17:59<1:58:34,  1.09s/it]

Step 3500: Loss = 2053.1338


 40%|████      | 4001/10000 [20:23<1:49:29,  1.10s/it]

Step 4000: Loss = 1832.6255


 45%|████▌     | 4501/10000 [22:47<1:40:12,  1.09s/it]

Step 4500: Loss = 1651.9519


 50%|█████     | 5001/10000 [25:10<1:31:06,  1.09s/it]

Step 5000: Loss = 1498.0325


 55%|█████▌    | 5501/10000 [27:35<1:23:03,  1.11s/it]

Step 5500: Loss = 1361.8513


 60%|██████    | 6001/10000 [30:01<1:14:06,  1.11s/it]

Step 6000: Loss = 1240.9624


 65%|██████▌   | 6510/10000 [2:18:33<16:14,  3.58it/s]       

Step 6500: Loss = 1135.9574


 70%|███████   | 7010/10000 [2:21:26<14:08,  3.52it/s]  

Step 7000: Loss = 1049.2941


 73%|███████▎  | 7309/10000 [2:23:10<52:42,  1.18s/it]


KeyboardInterrupt: 