# Neural Style Transfer with Ray and PyTorch

This notebook demonstrates distributed neural style transfer using PyTorch and Ray for batch processing of images.

## Overview
- Uses VGG19 features for style and content representation
- Implements Gram matrices for style loss computation
- Distributed batch processing with Ray
- GPU acceleration support

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import ray
from ray import data
import numpy as np
import io
import base64
import matplotlib.pyplot as plt
import os

# Report the installed library versions and which accelerators PyTorch can see.
print(f"PyTorch version: {torch.__version__}")
print(f"Ray version: {ray.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# The hasattr guard reports False when torch.backends has no `mps` attribute
# instead of raising AttributeError.
print(f"MPS available: {torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else False}")

## Initialize Ray

In [None]:
# Initialize Ray
# Only call ray.init() when Ray is not already initialized, so re-running this
# cell does not attempt a second initialization.
if not ray.is_initialized():
    ray.init()
    
# Show the resources Ray reports for the cluster it is connected to.
print(f"Ray cluster resources: {ray.cluster_resources()}")

## Model Components

### VGG Feature Extractor

In [None]:
class VGGFeatures(nn.Module):
    """Frozen VGG19 feature stack that returns every intermediate activation.

    The first 29 modules of ``vgg19().features`` (indices 0-28) are kept, so
    the deepest retained module is index 28 (conv5_1 in the standard VGG19
    layout), not conv4_4. Every parameter has ``requires_grad`` disabled, so
    only tensors fed into the network can receive gradients.

    NOTE(review): torchvision's VGG is believed to use in-place ReLU modules.
    If so, the tensor stored for a conv layer is overwritten by the ReLU that
    follows it and the "conv" entries actually hold post-ReLU values - confirm.
    """

    def __init__(self):
        super().__init__()
        # Prefer the ``weights=`` API: ``pretrained=True`` is deprecated in
        # newer torchvision releases. Older releases have no VGG19_Weights
        # enum, so fall back to the legacy flag there.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        else:
            vgg = models.vgg19(pretrained=True).features
        self.layers = nn.ModuleList(vgg[:29])  # indices 0..28, i.e. up to conv5_1

        # Freeze parameters: the network is used as a fixed feature extractor
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the retained layers in order.

        Args:
            x: Input tensor accepted by the first VGG layer.

        Returns:
            list: One entry per retained layer; entry ``i`` is the output of
            layer ``i``.
        """
        features = []
        for layer in self.layers:
            x = layer(x)
            features.append(x)
        return features

### Gram Matrix for Style Representation

In [None]:
class GramMatrix(nn.Module):
    """Compute a normalized Gram matrix of a 4-D activation tensor.

    The batch and channel axes are flattened together, so the result has shape
    ``(batch * channels, batch * channels)``; for a single-image batch this is
    the usual ``(channels, channels)`` style matrix.
    """

    def forward(self, x):
        """Return ``F @ F.T / (batch * channels * height * width)``.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            torch.Tensor: The normalized Gram matrix described on the class.
        """
        batch, channels, height, width = x.size()
        # reshape (rather than view) also accepts non-contiguous activations,
        # e.g. permuted or channels_last tensors, for which view() raises.
        features = x.reshape(batch * channels, height * width)
        gram = torch.mm(features, features.t())
        return gram.div(batch * channels * height * width)

### Complete Style Transfer Model

In [None]:
class StyleTransferModel(nn.Module):
    """Bundle the VGG feature extractor with Gram-matrix style statistics.

    ``style_layers`` and ``content_layer`` are indices into the list returned
    by the feature extractor.
    """

    def __init__(self):
        super().__init__()
        self.vgg = VGGFeatures()
        self.gram = GramMatrix()

        # Indices used for the style loss: conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
        self.style_layers = [0, 5, 10, 19, 28]
        # Index used for the content loss: conv4_2
        self.content_layer = 21

    def get_style_features(self, style_image):
        """Return one Gram matrix per style layer for ``style_image``."""
        activations = self.vgg(style_image)
        return [self.gram(activations[idx]) for idx in self.style_layers]

    def get_content_features(self, content_image):
        """Return the activation of the content layer for ``content_image``."""
        activations = self.vgg(content_image)
        return activations[self.content_layer]

## Image Processing Functions

In [None]:
def preprocess_image(image_path, size=512):
    """Load an image and turn it into a normalized batch tensor.

    Args:
        image_path: A filesystem path (``str`` or ``os.PathLike``), encoded
            image data (``bytes``, ``bytearray`` or ``memoryview``), or an
            object exposing ``convert('RGB')`` such as a PIL image.
        size: Edge length of the square the image is resized to.

    Returns:
        torch.Tensor: Tensor of shape ``(1, 3, size, size)`` normalized with
        the mean/std constants below.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    if isinstance(image_path, (str, os.PathLike)):
        # Image.open keeps the file open until the image is closed; convert()
        # returns an independent in-memory image, so the handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    elif isinstance(image_path, (bytes, bytearray, memoryview)):
        with Image.open(io.BytesIO(image_path)) as opened:
            image = opened.convert('RGB')
    else:
        # Assume it is already a PIL Image
        image = image_path.convert('RGB')

    return transform(image).unsqueeze(0)

def deprocess_image(tensor):
    """Undo the normalization of ``preprocess_image`` and return a PIL image.

    Args:
        tensor: Normalized image tensor of shape ``(1, 3, H, W)``. It may live
            on any device and may still require gradients.

    Returns:
        The image produced by ``transforms.ToPILImage`` after denormalizing
        and clamping the values to ``[0, 1]``.
    """
    # Work on a graph-free view so the arithmetic below is not tracked by autograd
    tensor = tensor.detach()

    # Denormalize; build the constants on the tensor's device so that GPU
    # tensors do not trigger a device-mismatch error
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    tensor = tensor * std + mean
    tensor = torch.clamp(tensor, 0, 1)

    # Convert to PIL (drop the batch axis and hand over a CPU tensor)
    tensor = tensor.squeeze(0).cpu()
    transform = transforms.ToPILImage()
    return transform(tensor)

## Style Transfer Optimization

In [None]:
def style_transfer_step(target, content_features, style_features, model, 
                       content_weight=1, style_weight=1000, iterations=300, verbose=True):
    """Optimize a copy of ``target`` so it matches content and style statistics.

    Args:
        target: Starting image tensor; it is cloned, the caller's tensor is
            left untouched.
        content_features: Reference activation for ``model.content_layer``.
        style_features: Reference Gram matrices, one per ``model.style_layers``
            entry, in the same order.
        model: Object providing ``vgg``, ``gram``, ``content_layer`` and
            ``style_layers``.
        content_weight: Multiplier of the content loss.
        style_weight: Multiplier of the summed style loss.
        iterations: Number of ``optimizer.step`` calls.
        verbose: Print the latest loss every 50th step when true.

    Returns:
        tuple: ``(optimized_tensor, losses)`` where ``losses`` holds the total
        loss of every closure evaluation.

    NOTE(review): LBFGS may evaluate the closure several times per ``step``,
    so ``len(losses)`` can exceed ``iterations`` - confirm against the
    torch.optim.LBFGS documentation.
    """
    target = target.clone().requires_grad_(True)
    optimizer = torch.optim.LBFGS([target])
    losses = []

    def closure():
        # LBFGS re-evaluates the objective through this callback
        optimizer.zero_grad()
        activations = model.vgg(target)

        content_term = F.mse_loss(activations[model.content_layer], content_features)
        style_term = sum(
            F.mse_loss(model.gram(activations[layer_idx]), style_features[position])
            for position, layer_idx in enumerate(model.style_layers)
        )

        objective = content_weight * content_term + style_weight * style_term
        objective.backward()

        losses.append(objective.item())
        return objective

    for i in range(iterations):
        optimizer.step(closure)
        if not verbose:
            continue
        if i % 50 == 0:
            print(f"Iteration {i}, Loss: {losses[-1]:.4f}")

    return target, losses

## Ray Predictor for Distributed Processing

In [None]:
class StyleTransferPredictor:
    """Callable that stylizes batches of content images with one fixed style.

    The model and the style Gram matrices are built once in ``__init__`` and
    reused for every batch passed to ``__call__``.
    """

    def __init__(self, style_image_path):
        """Load the model and precompute the style features.

        Args:
            style_image_path: Style image in any form accepted by
                ``preprocess_image``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = StyleTransferModel().to(self.device)
        self.model.eval()

        # Load and preprocess style image
        self.style_image = preprocess_image(style_image_path).to(self.device)
        # The style targets are constants of the optimization, so no autograd
        # graph is needed for them
        with torch.no_grad():
            self.style_features = self.model.get_style_features(self.style_image)

    def _load_content_image(self, item):
        """Turn one batch item into a preprocessed tensor on ``self.device``."""
        # A data URI is also a str, so it has to be recognized before the
        # generic handling; otherwise it would be opened as a file path.
        if isinstance(item, str) and item.startswith('data:image'):
            image_data = base64.b64decode(item.split(',', 1)[1])
            return preprocess_image(image_data).to(self.device)
        # File paths, raw bytes and PIL images are handled by preprocess_image
        return preprocess_image(item).to(self.device)

    def __call__(self, batch):
        """Stylize every entry of ``batch["image"]``.

        Args:
            batch: Mapping whose ``"image"`` entry is an iterable of file
                paths, ``data:image`` base64 URIs, raw image bytes or PIL
                images.

        Returns:
            dict: ``"stylized_image"`` (PNG data URI, or ``None`` on failure)
            and ``"status"`` (``"success"`` or ``"error: <message>"``), each a
            list aligned with the input order.
        """
        stylized_images = []
        statuses = []

        for item in batch["image"]:
            try:
                content_image = self._load_content_image(item)

                # Extract content features
                with torch.no_grad():
                    content_features = self.model.get_content_features(content_image)

                # Initialize target as content image
                target = content_image.clone()

                # The optimization needs gradients even if the caller
                # disabled them
                with torch.enable_grad():
                    stylized, _ = style_transfer_step(
                        target, content_features, self.style_features, self.model,
                        content_weight=1, style_weight=1000, iterations=100, verbose=False
                    )

                stylized_pil = deprocess_image(stylized.detach().cpu())

                # Encode as a PNG data URI for storage/transmission
                buffer = io.BytesIO()
                stylized_pil.save(buffer, format='PNG')
                img_str = base64.b64encode(buffer.getvalue()).decode()

                stylized_images.append(f"data:image/png;base64,{img_str}")
                statuses.append("success")

            except Exception as e:
                # Deliberate best-effort: one bad image must not fail the
                # whole batch, so the error is reported per item instead.
                stylized_images.append(None)
                statuses.append(f"error: {str(e)}")

        return {"stylized_image": stylized_images,
                "status": statuses}

## Single Image Style Transfer (Interactive)

In [None]:
def single_style_transfer(content_path, style_path, iterations=300, content_weight=1, style_weight=1000):
    """Stylize one image, show the result next to its inputs and plot the loss.

    Args:
        content_path: Content image, in any form ``preprocess_image`` accepts.
        style_path: Style image, in any form ``preprocess_image`` accepts.
        iterations: Forwarded to ``style_transfer_step``.
        content_weight: Forwarded to ``style_transfer_step``.
        style_weight: Forwarded to ``style_transfer_step``.

    Returns:
        The stylized image as returned by ``deprocess_image``.
    """
    # Pick the device and build the model on it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = StyleTransferModel().to(device)
    model.eval()

    # Inputs as normalized tensors on the chosen device
    content_image = preprocess_image(content_path).to(device)
    style_image = preprocess_image(style_path).to(device)

    # The reference features are fixed targets, so no gradients are tracked
    with torch.no_grad():
        content_features = model.get_content_features(content_image)
        style_features = model.get_style_features(style_image)

    # Optimization starts from a copy of the content image
    print("Starting style transfer...")
    stylized, losses = style_transfer_step(
        content_image.clone(), content_features, style_features, model,
        content_weight=content_weight, style_weight=style_weight,
        iterations=iterations, verbose=True
    )

    # Back to PIL for display
    content_pil = deprocess_image(content_image.cpu())
    style_pil = deprocess_image(style_image.cpu())
    stylized_pil = deprocess_image(stylized.cpu())

    # Side-by-side comparison: content, style, result
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ('Content Image', content_pil),
        ('Style Image', style_pil),
        ('Stylized Result', stylized_pil),
    )
    for axis, (title, picture) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    plt.show()

    # Loss curve of the optimization
    plt.figure(figsize=(10, 4))
    plt.plot(losses)
    plt.title('Loss During Optimization')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()

    return stylized_pil

## Batch Processing with Ray

In [None]:
def run_batch_style_transfer(content_image_paths, style_image_path, output_dir="./outputs"):
    """Stylize several images through Ray Data and write the results as PNGs.

    Args:
        content_image_paths: Iterable of content images; each becomes one
            dataset row under the ``"image"`` key.
        style_image_path: Style image handed to ``StyleTransferPredictor``.
        output_dir: Directory for the PNG files; created if missing.

    Returns:
        list: Paths of the files written, in result order. Rows whose status
        is not ``"success"`` are reported on stdout and skipped.
    """
    print("Creating dataset...")
    rows = [{"image": path} for path in content_image_paths]
    dataset = ray.data.from_items(rows)

    print("Setting up style transfer predictor...")
    # The predictor instance is built here, in the calling process, and then
    # handed to Ray Data as the batch function
    predictor = StyleTransferPredictor(style_image_path)

    print("Running batch inference...")
    results = dataset.map_batches(
        predictor,
        batch_size=2,  # two images per batch
        num_gpus=1 if torch.cuda.is_available() else 0,
        num_cpus=2
    )

    print("Collecting results...")
    output_data = results.take_all()

    os.makedirs(output_dir, exist_ok=True)

    saved_images = []
    for i, result in enumerate(output_data):
        succeeded = result["status"] == "success" and result["stylized_image"]
        if not succeeded:
            print(f"Failed to process image {i}: {result['status']}")
            continue

        # Strip the "data:image/png;base64," prefix and decode the payload
        img_data = base64.b64decode(result["stylized_image"].split(',')[1])
        output_path = os.path.join(output_dir, f"stylized_image_{i}.png")
        with open(output_path, 'wb') as f:
            f.write(img_data)
        print(f"Saved stylized image: {output_path}")
        saved_images.append(output_path)

    return saved_images

## Example Usage

### Single Image Style Transfer

Replace the paths below with your actual image paths:

In [None]:
# Example inputs - point these at your own files
content_path = "assets/content1.jpg"  # Your content image
style_path = "assets/starry_night.jpg"  # Your style image (e.g., Van Gogh's Starry Night)

# Run the transfer only when both input files are present; otherwise explain
# what is missing
if not (os.path.exists(content_path) and os.path.exists(style_path)):
    print(f"Images not found. Please add:")
    print(f"- Content image: {content_path}")
    print(f"- Style image: {style_path}")
    print("To the assets/ directory")
else:
    print("Running single image style transfer...")
    result = single_style_transfer(content_path, style_path, iterations=200)

    # Keep a copy of the stylized image on disk
    result.save("single_style_result.png")
    print("Result saved as 'single_style_result.png'")

### Batch Processing Example

In [None]:
# Example batch run over several content images with one shared style
content_images = [
    "assets/content1.jpg",
    "assets/content2.jpg",
    "assets/content3.jpg"
]

style_image = "assets/starry_night.jpg"

# Keep only the content images that are actually present on disk
existing_images = list(filter(os.path.exists, content_images))

if not (existing_images and os.path.exists(style_image)):
    print("Please add sample images to the assets/ directory:")
    print("- Content images (content1.jpg, content2.jpg, etc.)")
    print("- Style image (starry_night.jpg or any style image)")
else:
    print(f"Running batch style transfer on {len(existing_images)} images...")

    saved_images = run_batch_style_transfer(
        content_image_paths=existing_images,
        style_image_path=style_image,
        output_dir="./batch_outputs"
    )

    print(f"\nBatch processing complete! Saved {len(saved_images)} images.")

    # Preview at most the first four results side by side
    if saved_images:
        fig, axes = plt.subplots(1, len(saved_images[:4]), figsize=(16, 4))
        # With a single column, subplots returns one Axes rather than a
        # sequence; wrap it so the indexing below works
        if len(saved_images) == 1:
            axes = [axes]

        for i, img_path in enumerate(saved_images[:4]):
            img = Image.open(img_path)
            axes[i].imshow(img)
            axes[i].set_title(f'Result {i+1}')
            axes[i].axis('off')

        plt.tight_layout()
        plt.show()

## Cleanup

In [None]:
# Shutdown Ray when done
# Guarded so that running this cell when Ray is not initialized is a no-op.
if ray.is_initialized():
    ray.shutdown()
    print("Ray shutdown complete.")

## Notes

1. **Image Requirements**: Place your images in the `assets/` directory
2. **GPU Support**: The notebook will automatically use CUDA if available
3. **Memory**: Style transfer can be memory-intensive. Reduce image size if needed
4. **Parameters**: Experiment with `content_weight` and `style_weight` for different effects
5. **Iterations**: More iterations generally give better results but take longer

### Parameter Guidelines:
- **content_weight**: Controls content preservation (default: 1)
- **style_weight**: Controls style transfer strength (default: 1000)
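- **iterations**: Number of optimizer steps (default: 300 in `style_transfer_step()` and `single_style_transfer()`; the Ray predictor uses 100 and the single-image example above passes 200)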

### Troubleshooting:
- If you get memory errors, try reducing the image size in `preprocess_image()`
- For faster results, reduce the number of iterations
- Images are converted to RGB automatically by `preprocess_image()`; make sure each file is a readable image in a format Pillow can open