# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import copy
import os

# Run on the first visible CUDA GPU when available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load and transform image
def load_image(img_path, max_size=400, shape=None):
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found at {img_path}")
        
    image = Image.open(img_path).convert('RGB')

    if shape is not None:
        size = shape  # (height, width)
    else:
        size = max(image.size)
        if size > max_size:
            size = max_size
        size = (size, size)  # make it square

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    image = in_transform(image).unsqueeze(0)
    return image.to(device)

# Convert tensor to displayable image
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.squeeze()
    image = image.numpy().transpose(1, 2, 0)
    image = image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
    image = image.clip(0, 1)
    return image

# Load images
content = load_image("content.jpg")
style = load_image("style.jpg", shape=(content.size(2), content.size(3)))  # match content size

# Preview the content and style inputs side by side (de-normalised by im_convert).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.imshow(im_convert(content))
ax1.set(title="Content Image")
ax1.set_axis_off()

ax2.imshow(im_convert(style))
ax2.set(title="Style Image")
ax2.set_axis_off()

plt.show()

# Load VGG19's convolutional trunk with ImageNet weights, in eval mode on
# `device`, and freeze it: only the target image is optimised.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum (IMAGENET1K_V1 is what `pretrained=True` selected); fall back
# to the legacy flag on older torchvision, which has no weight enums.
try:
    _vgg_kwargs = {"weights": models.VGG19_Weights.IMAGENET1K_V1}
except AttributeError:
    _vgg_kwargs = {"pretrained": True}
vgg = models.vgg19(**_vgg_kwargs).features.to(device).eval()
vgg.requires_grad_(False)

def get_features(image, model, layers=None):
    """Feed ``image`` through ``model`` child by child and keep selected outputs.

    Args:
        image: Input passed to the first child module of ``model``.
        model: Container whose ``_modules`` are applied in insertion order
            (e.g. an ``nn.Sequential`` such as ``vgg``).
        layers: Mapping from a child's key in ``model._modules`` to the label
            its output is stored under. When None, the default index→name
            mapping below is used (conv1_1 … conv5_1, plus conv4_2 for content).

    Returns:
        dict mapping each label to the corresponding child's output tensor.

    Note:
        The stored tensors are the objects returned by each child, not copies.
        If a later child modifies its input in place, the stored tensor shows
        that modification.
    """
    wanted = layers if layers is not None else {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '21': 'conv4_2',  # content representation
        '28': 'conv5_1',
    }

    collected = {}
    activation = image
    for key, module in model._modules.items():
        activation = module(activation)
        if key in wanted:
            collected[wanted[key]] = activation
    return collected

def gram_matrix(tensor):
    """Return the channel-by-channel Gram matrix of a single feature map.

    Args:
        tensor: 4-D tensor ``(1, C, H, W)``. The leading dimension is discarded
            by the ``view`` below, so it must have size 1.

    Returns:
        ``(C, C)`` tensor of inner products between flattened channels
        (not normalised by ``H * W``).
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)  # one row per channel
    return flat @ flat.t()

# The content and style images never change during optimisation, so their
# activations -- and one Gram matrix per captured style-image layer -- are
# computed once, up front.
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {name: gram_matrix(activation) for name, activation in style_features.items()}

# Initialise the target image as a copy of the content image.
# Move it to `device` *before* enabling gradients so it remains a leaf tensor:
# `.to()` on a grad-tracking tensor that needs a device copy returns a
# non-leaf tensor, and torch.optim refuses to optimise non-leaf tensors.
target = content.clone().to(device).requires_grad_(True)

# Per-layer multipliers for the style loss; earlier layers are weighted more
# heavily than later ones.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.5,
    conv4_1=0.25,
    conv5_1=0.1,
)

# Relative weight of the content term (alpha) and of the style term (beta).
content_weight = 10_000.0
style_weight = 1_000_000.0

# Adam updates only the target image's pixels.
optimizer = optim.Adam(params=[target], lr=3e-3)

# Optimisation schedule
steps = 1000
print_every = 100

for step in range(1, steps + 1):
    target_features = get_features(target, vgg)

    # Content term: mean squared difference of the conv4_2 activations.
    content_loss = torch.mean(
        (target_features['conv4_2'] - content_features['conv4_2']) ** 2
    )

    # Style term: for each weighted layer, the mean squared Gram-matrix
    # mismatch times the layer weight, divided by (number of channels) ** 2.
    style_loss = sum(
        style_weights[layer]
        * torch.mean((gram_matrix(target_features[layer]) - style_grams[layer]) ** 2)
        / target_features[layer].shape[1] ** 2
        for layer in style_weights
    )

    total_loss = content_weight * content_loss + style_weight * style_loss

    # One gradient step on the target pixels.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % print_every == 0:
        print(f"Step {step}/{steps}, Total loss: {total_loss.item():.2f}")

# Display the final stylised image.
final_img = im_convert(target)
plt.figure(figsize=(8, 8))
plt.imshow(final_img)
plt.title("Stylized Image")
plt.axis("off")
plt.show()

# Save the result. final_img is clipped to [0, 1]. Round (rather than
# truncate) when scaling to 8-bit; otherwise every value is biased down by up
# to one intensity level.
output = Image.fromarray((final_img * 255).round().astype('uint8'))
output.save("output_styled.jpg")
print("✅ Output image saved as 'output_styled.jpg'")
