In [4]:
# ============================================================
# üé® Neural Style Transfer (Gatys et al.) - Google Colab Ready
# ============================================================
# Applies artistic style from one image onto another using VGG19.
# ============================================================

# --- Step 1: Setup Environment ---
!pip install torch torchvision pillow matplotlib tqdm --quiet

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm
import requests
import os

# --- Step 2: Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("‚úÖ Using device:", device)

# --- Step 3: Folder & file paths ---
base_dir = "/content/style_transfer_project"
content_dir = os.path.join(base_dir, "content")
style_dir = os.path.join(base_dir, "style")
output_dir = os.path.join(base_dir, "output")

os.makedirs(content_dir, exist_ok=True)
os.makedirs(style_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

content_path = os.path.join(content_dir, "mountains.jpg")
style_path   = os.path.join(style_dir, "van_gogh_starry_night.jpg")
output_path  = os.path.join(output_dir, "stylized_image.jpg")

# --- Step 4: Download sample images ---
content_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/The_Great_Wave_off_Kanagawa.jpg/1280px-The_Great_Wave_off_Kanagawa.jpg" # Using a thumbnail URL
style_url   = "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg" # Using a thumbnail URL


def download_image(url, path):
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise an HTTPError for bad responses (4xx or 5xx)
        img = Image.open(response.raw)
        img.save(path)
        print(f"Downloaded {url} to {path}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        return False
    except Exception as e:
        print(f"An error occurred while processing {url}: {e}")
        return False

content_downloaded = download_image(content_url, content_path)
style_downloaded = download_image(style_url, style_path)


print(f"üìÇ Files organized under: {base_dir}")
print(f"üñºÔ∏è Content Image Path: {content_path}")
print(f"üé® Style Image Path: {style_path}")
print(f"üíæ Output Path: {output_path}")

# --- Step 5: Helper Functions ---
imsize = 512 if torch.cuda.is_available() else 256

loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor()
])
unloader = transforms.ToPILImage()

def image_loader(path):
    image = Image.open(path).convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

def imshow(tensor, title=None):
    image = tensor.cpu().clone().squeeze(0)
    image = unloader(image)
    plt.figure(figsize=(6,6))
    if title:
        plt.title(title)
    plt.imshow(image)
    plt.axis('off')
    plt.show()

# --- Step 6: Load images ---
if content_downloaded and style_downloaded:
    content_img = image_loader(content_path)
    style_img   = image_loader(style_path)

    print("üñºÔ∏è Content Image:")
    imshow(content_img, title="Content Image")
    print("üé® Style Image:")
    imshow(style_img, title="Style Image")

    # --- Step 7: Define model & loss classes ---
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std  = torch.tensor([0.229, 0.224, 0.225]).to(device)

    class Normalization(nn.Module):
        def __init__(self, mean, std):
            super().__init__()
            self.mean = mean.clone().detach().view(-1,1,1)
            self.std = std.clone().detach().view(-1,1,1)
        def forward(self, img):
            return (img - self.mean) / self.std

    def gram_matrix(input_tensor):
        b, c, h, w = input_tensor.size()
        features = input_tensor.view(b * c, h * w)
        G = torch.mm(features, features.t())
        return G.div(b * c * h * w)

    class ContentLoss(nn.Module):
        def __init__(self, target):
            super().__init__()
            self.target = target.detach()
            self.loss = 0
        def forward(self, input):
            self.loss = nn.functional.mse_loss(input, self.target)
            return input

    class StyleLoss(nn.Module):
        def __init__(self, target_feature):
            super().__init__()
            self.target = gram_matrix(target_feature).detach()
            self.loss = 0
        def forward(self, input):
            G = gram_matrix(input)
            self.loss = nn.functional.mse_loss(G, self.target)
            return input

    def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                                   style_img, content_img,
                                   content_layers=['conv_4'],
                                   style_layers=['conv_1','conv_2','conv_3','conv_4','conv_5']):
        cnn = copy.deepcopy(cnn)
        normalization = Normalization(normalization_mean, normalization_std).to(device)
        content_losses = []
        style_losses = []
        model = nn.Sequential(normalization)

        i = 0
        for layer in cnn.children():
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = f'conv_{i}'
            elif isinstance(layer, nn.ReLU):
                name = f'relu_{i}'
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = f'pool_{i}'
            else:
                name = f'layer_{i}'
            model.add_module(name, layer)

            if name in content_layers:
                target = model(content_img).detach()
                content_loss = ContentLoss(target)
                model.add_module(f"content_loss_{i}", content_loss)
                content_losses.append(content_loss)

            if name in style_layers:
                target_feature = model(style_img).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module(f"style_loss_{i}", style_loss)
                style_losses.append(style_loss)

        for i in range(len(model) - 1, -1, -1):
            if isinstance(model[i], (ContentLoss, StyleLoss)):
                break
        model = model[:i+1]
        return model, style_losses, content_losses

    # --- Step 8: Run Style Transfer ---
    input_img = content_img.clone()
    style_weight = 1e6
    content_weight = 1e0

    def run_style_transfer(cnn, normalization_mean, normalization_std,
                           content_img, style_img, input_img, num_steps=200):
        print("üöÄ Starting style transfer...\n")
        model, style_losses, content_losses = get_style_model_and_losses(
            cnn, normalization_mean, normalization_std, style_img, content_img)
        optimizer = optim.LBFGS([input_img.requires_grad_()])

        run = [0]
        with tqdm(total=num_steps) as pbar:
            while run[0] <= num_steps:
                def closure():
                    input_img.data.clamp_(0, 1)
                    optimizer.zero_grad()
                    model(input_img)
                    style_score = 0
                    content_score = 0
                    for sl in style_losses:
                        style_score += sl.loss
                    for cl in content_losses:
                        content_score += cl.loss
                    loss = style_weight * style_score + content_weight * content_score
                    loss.backward()
                    pbar.set_description(f"Loss: {loss.item():.2e}")
                    return loss
                optimizer.step(closure)
                run[0] += 1
                pbar.update(1)

        input_img.data.clamp_(0, 1)
        return input_img

    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img, num_steps=200)

    # --- Step 9: Display & Save Result ---
    imshow(output, title="‚ú® Stylized Image")
    output_pil = unloader(output.cpu().squeeze(0))
    output_pil.save(output_path)

    print(f"‚úÖ Saved stylized image to: {output_path}")
else:
    print("‚ùå Image download failed. Please check the URLs and try again.")

‚úÖ Using device: cpu
Error downloading https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/The_Great_Wave_off_Kanagawa.jpg/1280px-The_Great_Wave_off_Kanagawa.jpg: 403 Client Error: Forbidden for url: https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/The_Great_Wave_off_Kanagawa.jpg/1280px-The_Great_Wave_off_Kanagawa.jpg
Error downloading https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg: 403 Client Error: Forbidden for url: https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Vincent_van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg
üìÇ Files organized under: /content/style_transfer_project
üñºÔ∏è Content Image Path: /content/style_transfer_project/content/mountains.jpg
üé® Style Image Path: /content/style_transfer_project/style/van_gogh_starry_night.jpg
üíæ Output Path: /co