In [None]:
# NST (Neural Style Transfer) - It's all about applying the style of a style image to a content image
#
# Implementation:
#=============================================================================================================================
# 1. Load and preprocess image
# 2. Extract features (FeatureMaps) from both content image and style image (VGG19 etc)
# 3. Compute Gram matrices for style (the Gram matrix is responsible for capturing stylistic patterns)
#
#         A Gram matrix captures the correlations between the feature maps of a layer (stylistic features).
#
#             G = F x F^T   (matrix product of F with its transpose)
#
#         where,
#          F is the feature map reshaped to shape (channels, height * width)
#
# 4. Initialize Target Image (Initialize with Content image or Noise) --- In my example I initialized with Content image
#
# 5. Define loss function
#
#      Content Loss - MSE
#      Style Loss - MSE
#
#       Total Loss = alpha * (Content Loss) + beta * (Style Loss)
#
# Where,
#          alpha and beta are hyperparameters that weight the two loss terms ---- (Analogy: Similar to Learning Rate)
#          They need not lie between 0 and 1; this notebook uses alpha = 1e5 and beta = 1e10.
#
# 6. optimize the generated target image
#
# 7. Post Process image ( Convert image from tensor to np array to visualize or save image)

: 

In [None]:
#Applications of NST
# 1. Art and Design ----> Generating art work
# 2. Gaming Industry ---> Action figure camouflaging in Background
# 3. NFTs
# 4. Social Media Filters

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Helper function to load and preprocess images
def load_image(image_path, max_size=400, shape=None):
    """Load an image file and return it as a normalized 4-D tensor on `device`.

    Args:
        image_path: Path of the image file to open.
        max_size: Size handed to ``transforms.Resize`` when ``shape`` is None.
            Given as an int, torchvision resizes the *shorter* edge to this
            value and keeps the aspect ratio.
        shape: Optional ``(width, height)`` pair. When given, the image is
            resized to exactly that size and ``max_size`` is ignored.

    Returns:
        Tensor of shape ``(1, 3, H, W)``, normalized with the per-channel
        mean/std below and moved to `device`.
    """
    # Open through a context manager so the underlying file handle is closed
    # deterministically, even if decoding fails; convert() returns an
    # independent in-memory RGB copy that outlives the `with` block.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    if shape is not None:
        # Callers pass (width, height); transforms.Resize wants (height, width).
        resize_to = (shape[1], shape[0])
    else:
        # No explicit shape: resize while maintaining the aspect ratio.
        resize_to = max_size

    # Single pipeline for both cases; only the resize target differs.
    loader = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    image = loader(image).unsqueeze(0)  # Add batch dimension
    return image.to(device)


In [None]:
# Content image, resized by load_image's default settings
content = load_image("old_apartment.jpg")

# Style image, forced to the content image's spatial size.
# load_image takes shape as (width, height): the last tensor axis is the width,
# the second-to-last is the height.
style = load_image(
    "lux_apartment.jpg",
    shape=(content.size(-1), content.size(-2)),
)

In [None]:
# Helper function to convert tensor to image
def im_convert(tensor):
    """Turn a normalized (1, C, H, W) image tensor into a displayable H x W x C array.

    The input is copied and detached first, so it is left untouched and may
    require gradients. The per-channel normalization is undone and the
    result is clipped to the [0, 1] range.
    """
    channel_std = [0.229, 0.224, 0.225]
    channel_mean = [0.485, 0.456, 0.406]

    # Work on a CPU copy without the batch dimension
    chw = tensor.clone().detach().cpu().squeeze(0)
    # Channels-first -> channels-last
    hwc = chw.numpy().transpose(1, 2, 0)
    # Denormalize, then clamp to the valid display range
    return (hwc * channel_std + channel_mean).clip(0, 1)

In [None]:
# Show the content and style images side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = [(content, "Content Image"), (style, "Style Image")]
for axis, (picture, caption) in zip(ax, panels):
    axis.imshow(im_convert(picture))
    axis.set_title(caption)
plt.show()

In [None]:
# Define VGG network
class VGGFeatures(nn.Module):
    """Frozen VGG19 feature extractor that returns activations of selected layers.

    ``forward`` runs the input through the first 29 modules of the
    pretrained VGG19 ``features`` stack and collects the output of every
    module whose index is listed in ``selected_layers``.
    """

    def __init__(self):
        super().__init__()
        self.selected_layers = ['0', '5', '10', '19', '28']  # Conv layers from VGG19
        # Nothing after module 28 is ever read, so the rest of the stack is dropped.
        self.vgg = models.vgg19(pretrained=True).features[:29]
        # Only the input image is optimized. Freezing the weights stops
        # autograd from computing (and accumulating) gradients for them on
        # every backward pass; gradients still flow through to the input.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the list of activations at ``selected_layers``, in network order."""
        # NOTE(review): torchvision builds VGG with nn.ReLU(inplace=True), so a
        # stored conv output is presumably overwritten by the ReLU that follows
        # it (all selected layers except the last) - confirm this is intended.
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.selected_layers:
                features.append(x)
        return features

In [None]:

# Build the feature extractor, move it to the compute device and switch it
# to inference mode (Module.to and Module.eval both act on the module itself).
vgg = VGGFeatures()
vgg.to(device)
vgg.eval()

In [None]:
# Function to compute Gram matrix for style

# 3. Compute Gram matrices for style (the Gram matrix is responsible for capturing stylistic patterns)
#
#         A Gram matrix captures the correlations between the feature maps of a layer (stylistic features).
#
#             G = F x F^T   (matrix product of F with its transpose)
#
#         where,
#          F is the feature map reshaped to shape (channels, height * width)

def gram_matrix(tensor):
    """Return the (unnormalized) Gram matrix ``F @ F.T`` of one feature map.

    Args:
        tensor: Feature map of shape ``(1, channels, height, width)``. The
            batch dimension must be 1; a larger batch makes the reshape fail.

    Returns:
        Tensor of shape ``(channels, channels)`` holding the inner products
        between the flattened channels.
    """
    _, n_filters, h, w = tensor.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # channels_last or transposed feature maps; for contiguous inputs it is
    # still a zero-copy view.
    features = tensor.reshape(n_filters, h * w)
    return torch.mm(features, features.t())


In [None]:
# Get style features and their Gram matrices. The style targets are constants,
# so they are computed under no_grad: no autograd graph is recorded, and
# neither style_features nor style_grams keeps one alive.
with torch.no_grad():
    style_features = vgg(style)
    style_grams = [gram_matrix(feat) for feat in style_features]

In [None]:
# Get content features. They are fixed targets, so compute them under no_grad
# instead of building an autograd graph and detaching afterwards.
with torch.no_grad():
    content_features = vgg(content)

In [None]:
# Initialize target image to optimize (clone content).
# requires_grad_ is applied last so that `target` is always a leaf tensor:
# a .to(device) that actually copies *after* requires_grad_ would return a
# non-leaf tensor, which the optimizer cannot update.
target = content.detach().clone().to(device).requires_grad_(True)

In [None]:
# Per-layer style weights: 1e3 divided by the square of each layer's size
# (presumably the channel counts of the five selected VGG19 layers - confirm).
style_weights = [1e3 / (channels * channels) for channels in (64, 128, 256, 512, 512)]

In [None]:
# L-BFGS optimizer whose only parameter is the target image
optimizer = optim.LBFGS(params=[target])

In [None]:
# Loss-term weights: alpha scales the content loss, beta scales the style loss
alpha, beta = 1e5, 1e10

In [None]:
# Optimization Loop
epochs = 1000
run = [0]  # one-element list so closure() can update the counter in place


def closure():
    """Evaluate the style-transfer objective for the current `target`.

    Clears the gradients, computes the weighted content + style loss,
    backpropagates it and returns it, as ``optimizer.step`` requires.
    """
    optimizer.zero_grad()

    target_features = vgg(target)

    # Content loss: MSE between the third selected layer's activations
    content_diff = target_features[2] - content_features[2]
    content_loss = torch.mean(content_diff ** 2)

    # Style loss: weighted sum over layers of the MSE between Gram matrices
    style_loss = 0
    per_layer = zip(target_features, style_grams, style_weights)
    for target_feat, style_gram, layer_weight in per_layer:
        gram_diff = gram_matrix(target_feat) - style_gram
        style_loss += layer_weight * torch.mean(gram_diff ** 2)

    # Weighted total, then backward pass
    total_loss = alpha * content_loss + beta * style_loss
    total_loss.backward()

    # The counter tracks closure evaluations; report every 50th one
    run[0] += 1
    if run[0] % 50 == 0:
        print(f"Epoch {run[0]}, Total Loss: {total_loss.item():.2f}")

    return total_loss


# The optimizer can call closure() several times within one step, so the
# counter may run past `epochs` before the loop condition is checked again.
while run[0] <= epochs:
    optimizer.step(closure)

In [None]:
# Convert the optimized tensor once, then show it in its own figure
stylized_image = im_convert(target)
plt.figure(figsize=(8, 8))
plt.imshow(stylized_image)
plt.title("Stylized Image")
plt.show()