# AdaIN Style Transfer

In [23]:
import torch
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import requests
from torchvision.models import VGG19_Weights
import matplotlib.pyplot as plt



In [4]:
# Pretrained (ImageNet) VGG19 convolutional stack (``.features``), in eval mode.
# NOTE(review): the name ``vgg`` is rebound to ``VGGFeatures(VGG_LAYERS)`` in a
# later cell, which loads its own copy of the weights, so this instance is not
# what the AdaIN pipeline ends up using.
vgg = (
    models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
    .features
    .eval()
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/maxfrischknecht/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [01:45<00:00, 5.45MB/s] 


In [38]:
# Indices into torchvision's ``vgg19().features`` whose outputs are read.
# ReLU outputs in that stack:
#   1 = relu1_1, 6 = relu2_1, 11 = relu3_1, 20 = relu4_1, 22 = relu4_2
# Index 21 is conv4_2, the pre-activation layer that feeds relu4_2.
# NOTE: the original AdaIN paper (Huang & Belongie, 2017) encodes up to
# relu4_1 (index 20); switch both entries to 20 to follow it exactly.
VGG_LAYERS = {
    "content": 22,  # relu4_2 (captures content structure)
    "style": 22,  # style statistics are taken from the same layer
}

# Feature extractor: runs an image through a VGG19 feature stack and returns the
# activations at the indices configured in a {"content": int, "style": int} map.
class VGGFeatures(torch.nn.Module):
    """Return VGG activations at the configured content and style layers.

    Args:
        layers: Mapping with integer ``"content"`` and ``"style"`` indices into
            the feature stack (for torchvision VGG19: 20 = relu4_1, 22 = relu4_2).
        vgg: Optional pre-built, iterable feature stack. When omitted, the
            ImageNet-pretrained ``vgg19().features`` is loaded. The given module
            is put in eval mode and its ``ReLU`` layers are made non-in-place.
        normalize: When True (default), the input is standardised with the
            ImageNet mean/std before the first layer; inputs are then assumed to
            be RGB tensors in ``[0, 1]`` such as those produced by ``ToTensor``.
    """

    # Channel statistics the torchvision ImageNet weights were trained with.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, layers, vgg=None, *, normalize=True):
        super().__init__()
        if vgg is None:
            vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        self.vgg = vgg.eval()
        # torchvision's VGG uses ReLU(inplace=True): an activation stored from a
        # conv layer would be overwritten by the ReLU that follows it.
        for module in self.vgg.modules():
            if isinstance(module, torch.nn.ReLU):
                module.inplace = False
        self.layers = layers
        self.normalize = normalize
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        self.register_buffer(
            "mean", torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Run ``x`` through the stack and collect the requested activations.

        Returns:
            dict with keys ``"style"`` and ``"content"`` mapping to the outputs
            of the configured layer indices (the same tensor if they coincide).

        Raises:
            ValueError: If a configured index is beyond the end of the stack.
        """
        targets = {"style": self.layers["style"], "content": self.layers["content"]}
        # Stop at the deepest requested layer, whichever key it belongs to.
        last = max(targets.values())

        if self.normalize:
            x = (x - self.mean) / self.std

        features = {}
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            for key, target in targets.items():
                if idx == target:
                    features[key] = x
            if idx >= last:
                break  # deeper layers are not needed

        missing = targets.keys() - features.keys()
        if missing:
            raise ValueError(
                f"layer index out of range for this feature stack: {sorted(missing)}"
            )
        return features

In [34]:
# load from url
# def load_image(url):
#     response = requests.get(url, stream=True)
#     image = Image.open(response.raw)
#     transform = transforms.Compose([
#         transforms.Resize((256, 256)),
#         transforms.ToTensor(),
#     ])
#     return transform(image).unsqueeze(0)

# load images locally
def load_image(image_path, size=(256, 256)):
    """Load an image from disk as a batched float tensor.

    Args:
        image_path: Path to the image file.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(256, 256)``.

    Returns:
        Tensor of shape ``(1, 3, H, W)``; ``ToTensor`` scales 8-bit RGB
        pixels to ``[0, 1]``.

    Raises:
        ValueError: If the file cannot be opened or decoded (the underlying
            exception is chained as ``__cause__``).
    """
    try:
        # ``with`` closes the file handle deterministically; ``convert`` forces
        # a full decode into a new in-memory RGB image first.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    except Exception as e:
        raise ValueError(f"Error opening image: {e}") from e

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # (C, H, W) float tensor
    ])

    return transform(image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)


content_image = load_image("./data/moss-forest.jpg")
style_image = load_image("./data/hokusai-fuji.jpg")

In [39]:
# Adaptive Instance Normalization (AdaIN): re-normalise each content channel to
# the mean and standard deviation of the matching style channel.
def adain(content, style, eps=1e-7):
    """Blend style into content by matching per-channel feature statistics.

    Each channel of ``content`` is standardised over dims 2 and 3 (the spatial
    dims), then rescaled and shifted to the std and mean of the corresponding
    ``style`` channel. ``Tensor.std`` here uses torch's default unbiased
    estimator.

    Args:
        content: Feature map, assumed ``(N, C, H, W)``.
        style: Feature map expected to share ``N`` and ``C`` with ``content``;
            its spatial size may differ, since only its statistics are used.
        eps: Added to the content std so constant channels do not divide by
            zero. Defaults to ``1e-7``.

    Returns:
        Tensor with the same shape as ``content``.
    """
    # Per-channel statistics over the spatial dims; keepdim so they broadcast.
    mean_content = content.mean([2, 3], keepdim=True)
    std_content = content.std([2, 3], keepdim=True)
    mean_style = style.mean([2, 3], keepdim=True)
    std_style = style.std([2, 3], keepdim=True)

    normalized_content = (content - mean_content) / (std_content + eps)
    return normalized_content * std_style + mean_style

vgg = VGGFeatures(VGG_LAYERS)

# Run BOTH images through the encoder: AdaIN keeps the content image's features
# but takes its per-channel mean/std from the style image's features.
with torch.no_grad():
    content_features = vgg(content_image)["content"]
    style_features = vgg(style_image)["style"]
    stylized_image = adain(content_features, style_features)

# Debug: the AdaIN output is still a deep feature map, not an RGB image; for the
# 256x256 inputs above this prints torch.Size([1, 512, 32, 32]).
print("Stylized Image Shape:", stylized_image.shape)

# Projection head from the AdaIN feature map back to image size.
# NOTE(review): the weights are randomly initialised and never trained here, so
# the output is not a meaningful stylised image. AdaIN needs a decoder trained
# to invert the VGG encoder (e.g. with content + style losses).
class Decoder(torch.nn.Module):
    """Map a feature map to an image-shaped tensor.

    A 3x3 convolution (padding 1, so spatial size is kept) reduces
    ``in_channels`` to ``out_channels``; bilinear upsampling then resizes the
    result to ``size``.

    Args:
        in_channels: Channels of the incoming feature map (default 512).
        out_channels: Channels of the output (default 3, i.e. RGB).
        size: Output ``(height, width)`` (default ``(256, 256)``).
    """

    def __init__(self, in_channels=512, out_channels=3, size=(256, 256)):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.upsample = torch.nn.Upsample(size=size, mode="bilinear", align_corners=False)

    def forward(self, x):
        """Return ``upsample(conv(x))`` with shape ``(N, out_channels, *size)``."""
        return self.upsample(self.conv(x))

decoder = Decoder()

with torch.no_grad():
    stylized_image = decoder(stylized_image)  # 3 channels at 256x256

# Remove the batch dimension -> (C, H, W).
stylized_image = stylized_image.squeeze(0)

# PIL expects channels-last (H, W, C).
if stylized_image.shape[0] == 3:
    stylized_image = stylized_image.permute(1, 2, 0)

# Clamp to [0, 1], scale to 8-bit and round to the nearest level; a bare
# astype("uint8") truncates, darkening pixels by up to one level.
# detach()/cpu() keep .numpy() working for grad-tracking or GPU tensors.
output_image = (
    stylized_image.detach().cpu().clamp(0, 1).mul(255).round().to(torch.uint8).numpy()
)

# Convert to a PIL image and display it.
output_image = Image.fromarray(output_image)
output_image.show()

Stylized Image Shape: torch.Size([1, 512, 32, 32])
