# UNet Polygon Coloring - Inference Demo

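This notebook demonstrates how to use the trained conditional UNet model to fill polygon images with a color given by name (e.g. `'blue'`). It expects a trained checkpoint at `checkpoints/best_model.pth` and, optionally, the dataset under `dataset/` (synthetic polygons are used when the dataset is missing).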

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import os
import sys

# Add scripts directory to path
sys.path.append('scripts')

from model import ConditionalUNet
from utils import load_model, preprocess_image, postprocess_output, visualize_prediction, generate_synthetic_polygon
from dataset import PolygonDataset

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Load Trained Model

In [None]:
# Load the best trained model
checkpoint_path = 'checkpoints/best_model.pth'

if os.path.exists(checkpoint_path):
    # load_model returns the model together with the color-name -> index mapping
    # stored alongside it.
    model, color_to_idx = load_model(checkpoint_path, device)
    # Reverse lookup: index -> color name.
    idx_to_color = {idx: color for color, idx in color_to_idx.items()}

    print("Model loaded successfully!")
    print(f"Available colors: {list(color_to_idx.keys())}")
else:
    # Leave `model` undefined so the "model loaded?" guards in later cells skip
    # inference. Still define the mappings, because the validation-dataset cell
    # reads `color_to_idx` and would otherwise fail with NameError on a fresh kernel.
    color_to_idx = {}
    idx_to_color = {}
    print(f"Checkpoint not found at {checkpoint_path}")
    print("Please train the model first using the training script.")

## Test with Validation Dataset

In [None]:
# Load validation dataset for testing
data_dir = 'dataset'  # Update this path to your dataset location

# Mapping produced by the checkpoint cell. It may be missing (cell not run) or
# empty (no checkpoint found), so read it defensively instead of by bare name.
training_color_to_idx = globals().get('color_to_idx')

if os.path.exists(data_dir):
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    val_dataset = PolygonDataset(data_dir, 'validation', val_transform)
    if training_color_to_idx:
        # Use the same color mapping as training so color indices agree with what
        # the model was trained on.
        # NOTE(review): assigned after construction; confirm PolygonDataset reads
        # `color_to_idx` when samples are fetched rather than only in __init__.
        val_dataset.color_to_idx = training_color_to_idx
    else:
        print("No training color mapping available; keeping the dataset's own mapping.")

    print(f"Validation dataset loaded with {len(val_dataset)} samples")
else:
    print(f"Dataset not found at {data_dir}")
    print("We'll use synthetic polygons for demonstration.")
    val_dataset = None

## Inference Function

In [None]:
def predict_colored_polygon(model, input_image, color_name, color_to_idx, device, image_size=256):
    """
    Predict a colored polygon given an input polygon image and a color name.

    Args:
        model: Trained conditional UNet, called as ``model(images, color_indices)``.
        input_image: PIL Image, NumPy array accepted by ``Image.fromarray``, or a
            tensor. A 3-D tensor (C, H, W) gets a batch dimension added; a 4-D
            tensor is passed through as a batch.
        color_name: Color name; must be a key of ``color_to_idx``.
        color_to_idx: Mapping from color name to the integer index fed to the model.
        device: torch device to run inference on.
        image_size: Side length PIL/NumPy inputs are resized to. Defaults to 256,
            matching the validation transform used in this notebook.

    Returns:
        The result of ``postprocess_output`` on the model output (displayed as an
        image in this notebook), or None if ``color_name`` is not in ``color_to_idx``.
    """
    # Reject unknown colors before touching the model or the image.
    if color_name not in color_to_idx:
        print(f"Color '{color_name}' not found. Available colors: {list(color_to_idx.keys())}")
        return None

    model.eval()

    with torch.no_grad():
        # NumPy arrays have no `.to(device)`; route them through the PIL path.
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        if isinstance(input_image, Image.Image):
            # ToTensor keeps the source channel count, so normalise to RGB (as the
            # custom-input path does) to avoid 1- or 4-channel tensors.
            if input_image.mode != 'RGB':
                input_image = input_image.convert('RGB')
            transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
            ])
            input_tensor = transform(input_image).unsqueeze(0)
        else:
            input_tensor = input_image
            if input_tensor.dim() == 3:
                # Single (C, H, W) image: add the batch dimension the model needs.
                input_tensor = input_tensor.unsqueeze(0)
        input_tensor = input_tensor.to(device)

        # One color index per batch element (a single index for a single image).
        color_idx = torch.full(
            (input_tensor.size(0),), color_to_idx[color_name],
            dtype=torch.long, device=device,
        )

        # Predict
        output = model(input_tensor, color_idx)

        # Convert to an image
        prediction = postprocess_output(output)

    return prediction

## Test with Validation Samples

In [None]:
# Compare predictions against ground truth on the first few validation samples.
# globals().get(...) handles both "cell not run" (undefined) and an explicit None.
if globals().get('val_dataset') is not None and globals().get('model') is not None:
    num_samples = min(5, len(val_dataset))
    to_pil = transforms.ToPILImage()  # one converter reused for every sample

    for i in range(num_samples):
        sample = val_dataset[i]

        # Get input/target images as PIL
        input_pil = to_pil(sample['input'])
        target_pil = to_pil(sample['output'])
        color_name = sample['color_name']

        # Predict
        prediction_pil = predict_colored_polygon(
            model, input_pil, color_name, color_to_idx, device
        )

        print(f"\nSample {i+1}: Color = {color_name}")
        if prediction_pil is None:
            # Color not in the training mapping: predict_colored_polygon already
            # printed why; skip plotting instead of passing None to the visualiser.
            continue

        # Visualize
        visualize_prediction(input_pil, target_pil, prediction_pil, color_name)
else:
    print("Validation dataset not available or model not loaded.")

## Test with Synthetic Polygons

In [None]:
# Generate synthetic polygons and test coloring with a few trained colors.
if globals().get('model') is not None:
    shapes = ['triangle', 'square', 'pentagon', 'hexagon']
    test_colors = list(color_to_idx.keys())[:3]  # Test with first 3 available colors
    n_panels = len(test_colors) + 1  # input panel + one panel per color

    print(f"Testing with colors: {test_colors}")

    for shape in shapes:
        print(f"\n=== Testing {shape.upper()} ===")

        # Generate white polygon as input
        input_polygon = generate_synthetic_polygon(shape, size=256, color='white')

        # squeeze=False keeps `axes` 2-D even with a single panel (empty color list),
        # so the unpacking below works for any number of colors.
        fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
        input_ax, *color_axes = axes[0]

        # Show input
        input_ax.imshow(input_polygon)
        input_ax.set_title('Input')
        input_ax.axis('off')

        # Test each color
        for ax, color in zip(color_axes, test_colors):
            prediction = predict_colored_polygon(
                model, input_polygon, color, color_to_idx, device
            )

            if prediction is not None:
                ax.imshow(prediction)
                ax.set_title(f'{color.capitalize()}')
            else:
                # Axes coordinates put the label at the panel centre regardless of data limits.
                ax.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{color.capitalize()} (Error)')

            ax.axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; one is created per shape in this loop
else:
    print("Model not loaded. Please load the model first.")

## Interactive Testing

In [None]:
if 'model' in locals():
    # Interactive function to test custom inputs
    def test_custom_input(image_path, color_name):
        """
        Test model with custom input image and color
        """
        if not os.path.exists(image_path):
            print(f"Image not found: {image_path}")
            return
        
        # Load and preprocess image
        input_image = Image.open(image_path).convert('RGB')
        
        # Predict
        prediction = predict_colored_polygon(
            model, input_image, color_name, color_to_idx, device
        )
        
        if prediction is not None:
            # Visualize
            visualize_prediction(input_image, None, prediction, color_name)
        else:
            print(f"Failed to generate prediction for color: {color_name}")
    
    # Example usage (uncomment and modify paths as needed)
    # test_custom_input('path/to/your/polygon.png', 'blue')
    
    print("Use the test_custom_input function to test with your own images:")
    print("test_custom_input('path/to/image.png', 'color_name')")
    print(f"Available colors: {list(color_to_idx.keys())}")
else:
    print("Model not loaded.")

## Model Performance Analysis

In [None]:
# Evaluate the model over the whole validation set and report aggregate metrics.
if globals().get('val_dataset') is not None and globals().get('model') is not None:
    from utils import calculate_metrics
    from torch.utils.data import DataLoader

    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    prediction_batches = []
    target_batches = []

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['input'].to(device)
            color_indices = batch['color_idx'].to(device)

            outputs = model(inputs, color_indices)

            # Accumulate on the CPU so a large validation set does not keep every
            # batch's output resident in GPU memory at once.
            prediction_batches.append(outputs.cpu())
            target_batches.append(batch['output'].cpu())

    if prediction_batches:
        # Concatenate all predictions and targets
        all_predictions = torch.cat(prediction_batches, dim=0)
        all_targets = torch.cat(target_batches, dim=0)

        # Calculate metrics
        metrics = calculate_metrics(all_predictions, all_targets)

        print("\n=== Model Performance on Validation Set ===")
        for metric, value in metrics.items():
            print(f"{metric}: {value:.6f}")
    else:
        # torch.cat([]) would raise; report the empty split instead.
        print("Validation dataset is empty; no metrics to report.")
else:
    print("Cannot evaluate performance without validation dataset and model.")

## Conclusion

This notebook demonstrates the inference capabilities of the trained UNet model for polygon coloring. The model can:

1. Take a polygon image and color name as input
2. Generate a colored version of the polygon
3. Handle various polygon shapes and colors

Key observations (confirm against the outputs above):
- The model learns to associate color names with RGB values
- It preserves the shape and structure of input polygons
- Performance is quantified with the metrics returned by `calculate_metrics` in the performance section (e.g. MSE, MAE, PSNR)

For production use, consider:
- Adding more diverse training data
- Implementing data augmentation strategies
- Fine-tuning hyperparameters
- Adding more sophisticated loss functions (e.g., perceptual loss)