# Fast Image Colorization with Pix2Pix GAN
**Optimized for Google Colab T4 GPU - Complete training in 2-3 hours**

This notebook implements a lightweight Pix2Pix GAN for fast image colorization using:
- 5,000 image subset from Kaggle dataset
- Lightweight U-Net generator (5 blocks, 32 base filters)
- PatchGAN discriminator (4 layers)
- Mixed precision training
- Speed optimizations for Colab

## 🚀 Automatic Setup for Google Colab

In [5]:
# Report whether PyTorch can see a CUDA GPU and, if so, which one and how much memory it has.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️ No GPU detected. This will be very slow on CPU!")
else:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; show it in decimal gigabytes
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

CUDA available: False
⚠️ No GPU detected. This will be very slow on CPU!


In [6]:
# Install required packages
!pip install -q kaggle scikit-image tqdm matplotlib seaborn

# Import all necessary libraries
import json
import os
import random
import shutil
import time
import zipfile
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image, make_grid

# Set random seeds for reproducibility
def set_seed(seed=42, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random`, and the
            PyTorch CPU and CUDA generators.
        deterministic: When True (default), cuDNN is restricted to
            deterministic algorithms and autotuning (`benchmark`) is turned
            off, because autotuning may pick a different kernel on each run.
            Pass False to trade run-to-run reproducibility for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # benchmark and deterministic pull in opposite directions; never enable both
    torch.backends.cudnn.benchmark = not deterministic

set_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


## 📁 Automatic Kaggle Dataset Download

In [7]:
# Kaggle API setup: in Colab, upload kaggle.json and install it where the Kaggle CLI looks for it.
print("🔑 Setting up Kaggle authentication...")
print("Please upload your kaggle.json file when prompted, or create it manually:")
print("1. Go to https://www.kaggle.com/account")
print("2. Click 'Create New Token'")
print("3. Upload the downloaded kaggle.json file")

try:
    from google.colab import files
    uploaded = files.upload()

    # Setup Kaggle credentials
    os.makedirs('/root/.kaggle', exist_ok=True)
    shutil.move('kaggle.json', '/root/.kaggle/kaggle.json')
    # Octal literal: owner read/write only. A decimal 600 would set unrelated permission bits.
    os.chmod('/root/.kaggle/kaggle.json', 0o600)
    print("✅ Kaggle authentication successful!")
except ImportError:
    print("ℹ️ Not running in Colab. Please ensure kaggle.json is in ~/.kaggle/")
except Exception as exc:
    # Best-effort setup: keep the notebook running, but say why the automatic path failed.
    print("⚠️ Manual setup: Please ensure kaggle.json is properly configured")
    print(f"   Reason: {type(exc).__name__}: {exc}")

🔑 Setting up Kaggle authentication...
Please upload your kaggle.json file when prompted, or create it manually:
1. Go to https://www.kaggle.com/account
2. Click 'Create New Token'
3. Upload the downloaded kaggle.json file
ℹ️ Not running in Colab. Please ensure kaggle.json is in ~/.kaggle/


In [9]:
# Download and extract dataset
dataset_name = "shravankumar9892/image-colorization"
data_dir = Path("/content/colorization_data")

if not data_dir.exists():
    print(f"📥 Downloading dataset: {dataset_name}")
    !kaggle datasets download -d {dataset_name} -p /content/
    
    print("📦 Extracting dataset...")
    with zipfile.ZipFile('/content/image-colorization.zip', 'r') as zip_ref:
        zip_ref.extractall('/content/')
    
    # Find the extracted folder
    extracted_folders = [f for f in os.listdir('/content/') if 'colorization' in f.lower() or 'image' in f.lower()]
    if extracted_folders:
        os.rename(f'/content/{extracted_folders[0]}', data_dir)
    
    print(f"✅ Dataset ready at: {data_dir}")
else:
    print(f"✅ Dataset already exists at: {data_dir}")

# List dataset contents
if data_dir.exists():
    subfolders = [f for f in data_dir.iterdir() if f.is_dir()]
    print(f"\n📂 Dataset structure:")
    for folder in subfolders:
        img_count = len(list(folder.glob('*.jpg')) + list(folder.glob('*.png')))
        print(f"  {folder.name}: {img_count} images")

📥 Downloading dataset: shravankumar9892/image-colorization
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.11/bin/kaggle", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/kaggle/cli.py", line 68, in main
    out = args.func(**command_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/kaggle/api/kaggle_api_extended.py", line 1741, in dataset_download_cli
    with self.build_kaggle_client() as kaggle:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/kaggle/api/kaggle_api_extended.py", line 688, in build_kaggle_client
    username=self.config_values['username'],
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'username'
📦 Extracting dataset...


FileNotFoundError: [Errno 2] No such file or directory: '/content/image-colorization.zip'

## 🏗️ Lightweight Model Architecture

In [None]:
class UNetBlock(nn.Module):
    """Conv -> BatchNorm -> activation block that halves or doubles resolution.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        down: True builds a stride-2 `Conv2d` + LeakyReLU(0.2) (halves H and W);
            False builds a stride-2 `ConvTranspose2d` + ReLU (doubles H and W).
        use_dropout: When True, applies Dropout(0.5) after the activation.
    """

    def __init__(self, in_channels, out_channels, down=True, use_dropout=False):
        super().__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
            activation = nn.LeakyReLU(0.2)
        else:
            # A stride-1 Conv2d here would shrink the map by one pixel instead of
            # upsampling; the decoder counterpart is a transposed convolution.
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
            activation = nn.ReLU()
        # bias=False above because the following BatchNorm supplies its own shift
        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5) if use_dropout else None

    def forward(self, x):
        """Apply the block to `x` and return the resampled feature map."""
        x = self.conv(x)
        if self.use_dropout and self.dropout is not None:
            x = self.dropout(x)
        return x

class Generator(nn.Module):
    """U-Net style generator with skip connections.

    Five stride-2 encoder stages and a stride-2 bottleneck are mirrored by six
    stride-2 transposed convolutions. Each decoder stage except the last is
    concatenated with the encoder output of matching resolution. The output
    passes through tanh, so values lie in [-1, 1].

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Channel width of the first encoder stage.
    """

    def __init__(self, in_channels=1, out_channels=3, features=32):
        super().__init__()
        f = features

        # Encoder. Size comments assume a 128x128 input.
        self.down1 = nn.Conv2d(in_channels, f, 4, 2, 1)   # 128 -> 64
        self.down2 = UNetBlock(f, f * 2, down=True)       # 64 -> 32
        self.down3 = UNetBlock(f * 2, f * 4, down=True)   # 32 -> 16
        self.down4 = UNetBlock(f * 4, f * 8, down=True)   # 16 -> 8
        self.down5 = UNetBlock(f * 8, f * 8, down=True)   # 8 -> 4

        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 8, f * 8, 4, 2, 1),             # 4 -> 2
            nn.ReLU(),
        )

        # Decoder. From up2 onward the input width is doubled by the skip concat.
        self.up1 = nn.ConvTranspose2d(f * 8, f * 8, 4, 2, 1)    # 2 -> 4
        self.up2 = nn.ConvTranspose2d(f * 16, f * 8, 4, 2, 1)   # 4 -> 8
        self.up3 = nn.ConvTranspose2d(f * 16, f * 4, 4, 2, 1)   # 8 -> 16
        self.up4 = nn.ConvTranspose2d(f * 8, f * 2, 4, 2, 1)    # 16 -> 32
        self.up5 = nn.ConvTranspose2d(f * 4, f, 4, 2, 1)        # 32 -> 64

        self.final = nn.ConvTranspose2d(f * 2, out_channels, 4, 2, 1)  # 64 -> 128
        self.tanh = nn.Tanh()

        # One BatchNorm per decoder stage, applied before the ReLU
        self.bn_up1 = nn.BatchNorm2d(f * 8)
        self.bn_up2 = nn.BatchNorm2d(f * 8)
        self.bn_up3 = nn.BatchNorm2d(f * 4)
        self.bn_up4 = nn.BatchNorm2d(f * 2)
        self.bn_up5 = nn.BatchNorm2d(f)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def _decode(self, x, upsample, norm, skip, drop=False):
        """One decoder stage: upsample -> BatchNorm -> ReLU -> optional dropout -> concat skip."""
        x = self.relu(norm(upsample(x)))
        if drop:
            x = self.dropout(x)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Map an input batch to a generated batch of the same spatial size."""
        # Encoder: keep every stage output for the skip connections
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        y = self.bottleneck(d5)

        # Decoder: dropout only on the three deepest stages
        y = self._decode(y, self.up1, self.bn_up1, d5, drop=True)
        y = self._decode(y, self.up2, self.bn_up2, d4, drop=True)
        y = self._decode(y, self.up3, self.bn_up3, d3, drop=True)
        y = self._decode(y, self.up4, self.bn_up4, d2)
        y = self._decode(y, self.up5, self.bn_up5, d1)

        return self.tanh(self.final(y))

# Smoke test: push one random single-channel 128x128 image through a fresh generator
gen = Generator().to(device)
test_input = torch.randn(1, 1, 128, 128).to(device)
test_output = gen(test_input)
gen_param_count = sum(map(torch.numel, gen.parameters()))
print(f"Generator output shape: {test_output.shape}")
print(f"Generator parameters: {gen_param_count:,}")

In [None]:
class Discriminator(nn.Module):
    """PatchGAN discriminator conditioned on the grayscale input.

    The grayscale image and the candidate colour image are concatenated along
    the channel axis and passed through four stride-2 convolutions followed by
    a stride-1 convolution and a sigmoid. The result is a map of per-patch
    probabilities. The last convolution (kernel 4, stride 1, padding 1)
    shrinks each spatial dimension by one, so a 128x128 input gives a 7x7 map.

    Args:
        in_channels: Total channels after concatenation (grayscale + colour).
        features: Channel width of the first convolution; doubled three times.
    """

    def __init__(self, in_channels=4, features=32):  # 1 (grayscale) + 3 (RGB) = 4
        super().__init__()

        # Stage 1 has no BatchNorm: 128x128 -> 64x64
        layers = [nn.Conv2d(in_channels, features, 4, 2, 1), nn.LeakyReLU(0.2)]

        # Stages 2-4 halve the resolution and double the width each time
        width = features
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2),
            ]
            width *= 2

        # Head: one probability per patch
        layers += [nn.Conv2d(width, 1, 4, 1, 1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, grayscale, rgb):
        """Return the patch probability map for a (grayscale, colour) pair."""
        return self.model(torch.cat([grayscale, rgb], dim=1))

# Smoke test: score one random (grayscale, colour) pair with a fresh discriminator
disc = Discriminator().to(device)
test_gray = torch.randn(1, 1, 128, 128).to(device)
test_rgb = torch.randn(1, 3, 128, 128).to(device)
test_output = disc(test_gray, test_rgb)
disc_param_count = sum(map(torch.numel, disc.parameters()))
print(f"Discriminator output shape: {test_output.shape}")
print(f"Discriminator parameters: {disc_param_count:,}")

# Combined size of the two smoke-test models
total_params = sum(map(torch.numel, gen.parameters())) + disc_param_count
print(f"\n🎯 Total model parameters: {total_params:,} (Lightweight for fast training!)")

## 📊 Smart Data Preprocessing Pipeline

In [None]:
class ColorizationDataset(Dataset):
    """Dataset yielding (L, AB) tensor pairs in LAB colour space.

    Each item is loaded from disk as RGB, optionally transformed, converted to
    LAB and split into a 1-channel lightness tensor (the model input) and a
    2-channel AB tensor (the colour target), both scaled to roughly [-1, 1].

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable applied to the PIL image before conversion.
        max_images: Only the first `max_images` paths are kept.
    """

    def __init__(self, image_paths, transform=None, max_images=5000):
        self.transform = transform
        self.image_paths = image_paths[:max_images]  # cap the dataset size

        print(f"📸 Dataset initialized with {len(self.image_paths)} images")

    def __len__(self):
        """Number of images kept after the `max_images` cap."""
        return len(self.image_paths)

    def rgb_to_lab(self, rgb_img):
        """Convert one RGB image to normalised (L, AB) tensors.

        Args:
            rgb_img: A (3, H, W) tensor or an (H, W, 3) array. Values above 1.0
                are taken to mean a 0-255 scale and are divided by 255.

        Returns:
            Tuple of float tensors: L with shape (1, H, W) and AB with shape (2, H, W).
        """
        if torch.is_tensor(rgb_img):
            # channels-first tensor -> channels-last array, as OpenCV expects
            rgb_img = rgb_img.permute(1, 2, 0).numpy()

        if rgb_img.max() > 1.0:
            rgb_img = rgb_img / 255.0

        lab = cv2.cvtColor(rgb_img.astype(np.float32), cv2.COLOR_RGB2LAB)

        lightness = lab[..., 0] / 50.0 - 1.0   # L: [0, 100] -> [-1, 1]
        chroma = lab[..., 1:] / 128.0          # AB: [-128, 127] -> [-1, 1]

        L_tensor = torch.FloatTensor(lightness).unsqueeze(0)
        AB_tensor = torch.FloatTensor(chroma).permute(2, 0, 1)
        return L_tensor, AB_tensor

    def __getitem__(self, idx):
        """Return the (L, AB) pair for image `idx`.

        Any exception while loading or converting is printed, and an all-zero
        (1, 128, 128) / (2, 128, 128) pair is returned instead of raising.
        """
        img_path = self.image_paths[idx]

        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            # L (grayscale) is the input, AB (colour) the target
            return self.rgb_to_lab(image)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Placeholder pair so one unreadable file does not abort a whole batch
            return torch.zeros(1, 128, 128), torch.zeros(2, 128, 128)

# Preprocessing for every image: resize to 128x128, then convert the PIL image to a tensor
preprocessing_steps = [
    transforms.Resize((128, 128)),  # small fixed resolution keeps training fast
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

print("✅ Dataset class ready!")

In [None]:
# Find all image files
def find_images(data_dir, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    """Recursively collect image files under `data_dir`.

    Args:
        data_dir: Root folder to search (str or Path).
        extensions: File-name endings to accept; matched case-insensitively.

    Returns:
        A sorted list of unique `Path` objects. Sorting makes the result
        independent of filesystem enumeration order, so a seeded shuffle of the
        list is reproducible.
    """
    endings = tuple(ext.lower() for ext in extensions)
    root = Path(data_dir)
    # One walk of the tree. Matching on the lower-cased name covers mixed case
    # such as '.Jpeg' and yields each file once, even on case-insensitive filesystems.
    return sorted(
        path for path in root.rglob('*')
        if path.is_file() and path.name.lower().endswith(endings)
    )

# Get all image paths
all_image_paths = find_images(data_dir)
print(f"📁 Found {len(all_image_paths)} total images")
if not all_image_paths:
    # Fail here with a clear message rather than later inside the DataLoader
    raise RuntimeError(
        f"No image files found under {data_dir}. "
        "Check that the dataset was downloaded and extracted."
    )

# Randomly sample 5000 images for speed
random.shuffle(all_image_paths)
selected_images = all_image_paths[:5000]
print(f"🎯 Selected {len(selected_images)} images for training")

# Train/Val split (80/20)
split_idx = int(0.8 * len(selected_images))
train_paths = selected_images[:split_idx]
val_paths = selected_images[split_idx:]

print(f"🚂 Training images: {len(train_paths)}")
print(f"🔍 Validation images: {len(val_paths)}")

# Create datasets
train_dataset = ColorizationDataset(train_paths, transform=transform)
val_dataset = ColorizationDataset(val_paths, transform=transform)

# Create data loaders with optimizations
batch_size = 64  # Larger batch size for speed
num_workers = 2  # Optimized for Colab
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU
pin_memory = device.type == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=True  # keep every training batch the same size
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory
)

print(f"\n⚡ DataLoaders ready!")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# With drop_last=True, fewer images than one batch means no training batches at all
if len(train_loader) == 0:
    raise RuntimeError(
        f"Only {len(train_dataset)} training images, fewer than batch_size={batch_size}; "
        "lower batch_size or provide more images."
    )

# Test data loading
print("\n🧪 Testing data loading...")
sample_L, sample_AB = next(iter(train_loader))
print(f"L (grayscale) batch shape: {sample_L.shape}")
print(f"AB (color) batch shape: {sample_AB.shape}")
print(f"L range: [{sample_L.min():.3f}, {sample_L.max():.3f}]")
print(f"AB range: [{sample_AB.min():.3f}, {sample_AB.max():.3f}]")

## 🏋️‍♂️ Training Setup with Speed Optimizations

In [None]:
# Loss functions
def gan_loss(output, is_real):
    """Binary cross-entropy between discriminator probabilities and a constant label.

    Args:
        output: Tensor of probabilities in [0, 1] (the discriminator ends in a sigmoid).
        is_real: True to score against all-ones targets, False against all-zeros.

    Returns:
        Scalar float32 loss tensor (mean over all elements).
    """
    # F.binary_cross_entropy raises inside CUDA autocast regions because its
    # gradients may not be representable in float16. This function is called
    # from autocast blocks, so autocast is switched off locally and the loss is
    # evaluated in float32.
    with torch.autocast(device_type='cuda', enabled=False):
        probs = output.float()
        target = torch.ones_like(probs) if is_real else torch.zeros_like(probs)
        return F.binary_cross_entropy(probs, target)

def l1_loss(output, target):
    """Mean absolute error between `output` and `target` (pixel-wise reconstruction term)."""
    mean_abs_error = F.l1_loss(output, target, reduction='mean')
    return mean_abs_error

# Initialize models
# The dataset yields a 1-channel L input and a 2-channel AB target, so the
# generator must predict 2 channels and the discriminator must accept
# 1 + 2 = 3 channels. The class defaults (3 and 4) assume an RGB target and
# would fail on shape against these batches.
generator = Generator(in_channels=1, out_channels=2).to(device)
discriminator = Discriminator(in_channels=3).to(device)

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Learning rate schedulers: halve the LR after 5 epochs without improvement in the tracked loss
g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(g_optimizer, 'min', patience=5, factor=0.5)
d_scheduler = optim.lr_scheduler.ReduceLROnPlateau(d_optimizer, 'min', patience=5, factor=0.5)

# Mixed precision training setup (gradient scaling only applies on CUDA)
scaler = GradScaler(enabled=(device.type == 'cuda'))

# Loss weight
lambda_l1 = 100  # Weight for L1 loss

print("✅ Training setup complete!")
print(f"Generator LR: {g_optimizer.param_groups[0]['lr']}")
print(f"Discriminator LR: {d_optimizer.param_groups[0]['lr']}")
print(f"L1 Loss weight: {lambda_l1}")
print(f"Mixed precision: {'Enabled' if scaler.is_enabled() else 'Disabled (no CUDA device)'}")

## 📊 Visualization and Metrics Functions

In [None]:
def lab_to_rgb(L, AB):
    """Convert a batch of normalised LAB tensors to RGB for visualisation.

    Args:
        L: Lightness batch scaled to [-1, 1]; concatenated with `AB` on dim 1.
        AB: Colour batch scaled to [-1, 1].

    Returns:
        CPU float tensor of RGB images, channels-first, clipped to [0, 1].
    """
    # Undo the dataset normalisation: L -> [0, 100], AB -> [-128, 127]
    lab_batch = torch.cat([(L + 1.0) * 50.0, AB * 128.0], dim=1)

    converted = []
    for lab_chw in lab_batch:
        # OpenCV wants a channels-last float32 array
        lab_hwc = lab_chw.permute(1, 2, 0).cpu().numpy().astype(np.float32)
        rgb_hwc = np.clip(cv2.cvtColor(lab_hwc, cv2.COLOR_LAB2RGB), 0, 1)
        converted.append(torch.FloatTensor(rgb_hwc).permute(2, 0, 1))

    return torch.stack(converted)

def show_results(generator, val_loader, device, num_samples=8, epoch=0):
    """Plot grayscale input, generated colourisation and ground truth for one validation batch.

    Args:
        generator: Generator model; switched to eval for the forward pass and
            back to train mode afterwards.
        val_loader: DataLoader yielding (L, AB) batches; only the first batch is used.
        device: Device on which the generator runs.
        num_samples: Maximum number of images per row.
        epoch: Epoch number shown in the figure title.
    """
    generator.eval()

    try:
        with torch.no_grad():
            # Get a batch of validation data
            L_batch, AB_batch = next(iter(val_loader))
            L_batch = L_batch[:num_samples].to(device)
            AB_batch = AB_batch[:num_samples].to(device)
            shown = L_batch.size(0)  # may be smaller than num_samples

            # Generate colorized images
            with autocast():
                fake_AB = generator(L_batch)

            # Convert to RGB for visualization
            real_rgb = lab_to_rgb(L_batch, AB_batch)
            fake_rgb = lab_to_rgb(L_batch, fake_AB)
            # L is stored in [-1, 1]; map it to [0, 1] so it shares the RGB rows' range
            grayscale = ((L_batch.float() + 1.0) / 2.0).clamp(0, 1).repeat(1, 3, 1, 1)

            # Create comparison grid: one row per kind of image
            comparison = torch.cat([
                grayscale.cpu(),
                fake_rgb.cpu(),
                real_rgb.cpu()
            ], dim=0)

            # All rows are already in [0, 1]. normalize=True would rescale by the
            # global min/max and distort the colours.
            grid = make_grid(comparison, nrow=shown, normalize=False, padding=2)

            # Plot
            plt.figure(figsize=(15, 6))
            plt.imshow(grid.permute(1, 2, 0))
            plt.title(f'Epoch {epoch} - Top: Grayscale, Middle: Generated, Bottom: Ground Truth')
            plt.axis('off')
            plt.tight_layout()
            plt.show()
    finally:
        # Restore training mode even if plotting or conversion fails
        generator.train()

def calculate_metrics(generator, val_loader, device, num_samples=100):
    """Average PSNR and SSIM between generated and ground-truth RGB images.

    Args:
        generator: Generator model; set to eval during scoring and to train mode afterwards.
        val_loader: DataLoader yielding (L, AB) batches.
        device: Device on which the generator runs.
        num_samples: Scoring stops once this many images have been evaluated.

    Returns:
        Tuple `(mean_psnr, mean_ssim)` over the evaluated images.
    """
    generator.eval()

    psnr_scores, ssim_scores = [], []

    with torch.no_grad():
        for L_batch, AB_batch in val_loader:
            # One score pair per image, so the list length doubles as the counter
            if len(psnr_scores) >= num_samples:
                break

            L_batch = L_batch.to(device)
            AB_batch = AB_batch.to(device)

            with autocast():
                fake_AB = generator(L_batch)

            # Metrics are computed in RGB space
            real_rgb = lab_to_rgb(L_batch, AB_batch)
            fake_rgb = lab_to_rgb(L_batch, fake_AB)

            for real_chw, fake_chw in zip(real_rgb, fake_rgb):
                if len(psnr_scores) >= num_samples:
                    break

                # channels-last arrays for scikit-image
                real_img = real_chw.permute(1, 2, 0).cpu().numpy()
                fake_img = fake_chw.permute(1, 2, 0).cpu().numpy()

                psnr_scores.append(psnr(real_img, fake_img, data_range=1.0))
                ssim_scores.append(ssim(real_img, fake_img, data_range=1.0, channel_axis=2))

    generator.train()

    return np.mean(psnr_scores), np.mean(ssim_scores)

print("✅ Visualization functions ready!")

## 🚀 Memory-Efficient Training Loop

In [None]:
# Training configuration
num_epochs = 35  # 30-40 epochs for fast training
save_interval = 10  # Save checkpoints every 10 epochs
display_interval = 5  # Show results every 5 epochs

# Create output directories
os.makedirs('/content/checkpoints', exist_ok=True)
os.makedirs('/content/results', exist_ok=True)

# Training history: losses are recorded every epoch, PSNR/SSIM every display_interval epochs
history = {
    'g_loss': [],
    'd_loss': [],
    'psnr': [],
    'ssim': []
}

print(f"🚀 Starting training for {num_epochs} epochs...")
print(f"📊 Training on {len(train_loader)} batches per epoch")
print(f"⚡ Using mixed precision training")
print(f"💾 Saving checkpoints every {save_interval} epochs")

if len(train_loader) == 0:
    # The per-epoch averages below divide by the number of batches
    raise RuntimeError("train_loader has no batches; cannot train.")


def _now():
    """Wall-clock time in seconds; waits for queued GPU work first when running on CUDA."""
    # CUDA kernels run asynchronously, so synchronize before reading the clock.
    # On CPU there is nothing to wait for, and torch.cuda.* timing calls would fail.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    return time.perf_counter()


# Training loop
training_start = _now()

for epoch in range(num_epochs):
    epoch_start = _now()

    generator.train()
    discriminator.train()

    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    # Progress bar
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

    for L_batch, AB_batch in pbar:
        L_batch = L_batch.to(device)
        AB_batch = AB_batch.to(device)

        # ---------------------
        # Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        with autocast():
            # Real images
            real_output = discriminator(L_batch, AB_batch)
            d_real_loss = gan_loss(real_output, True)

            # Fake images (detached so this step does not backpropagate into the generator)
            fake_AB = generator(L_batch)
            fake_output = discriminator(L_batch, fake_AB.detach())
            d_fake_loss = gan_loss(fake_output, False)

            d_loss = (d_real_loss + d_fake_loss) * 0.5

        scaler.scale(d_loss).backward()
        scaler.step(d_optimizer)

        # -----------------
        # Train Generator
        # -----------------
        g_optimizer.zero_grad()

        with autocast():
            # GAN loss: the generator wants the discriminator to call its output real
            fake_output = discriminator(L_batch, fake_AB)
            g_gan_loss = gan_loss(fake_output, True)

            # L1 loss
            g_l1_loss = l1_loss(fake_AB, AB_batch)

            # Total generator loss
            g_loss = g_gan_loss + lambda_l1 * g_l1_loss

        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        # One scale update per iteration, after both optimizers have stepped
        scaler.update()

        # Update progress
        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

        # Update progress bar
        pbar.set_postfix({
            'G_loss': f'{g_loss.item():.4f}',
            'D_loss': f'{d_loss.item():.4f}',
            'L1': f'{g_l1_loss.item():.4f}'
        })

    # Calculate average losses
    avg_g_loss = epoch_g_loss / len(train_loader)
    avg_d_loss = epoch_d_loss / len(train_loader)

    # Update learning rate
    g_scheduler.step(avg_g_loss)
    d_scheduler.step(avg_d_loss)

    # Calculate epoch time (seconds)
    epoch_time = _now() - epoch_start

    # Estimate remaining time from the duration of this epoch
    remaining_epochs = num_epochs - epoch - 1
    eta = remaining_epochs * epoch_time / 60  # Convert to minutes

    print(f'Epoch [{epoch+1}/{num_epochs}] - G_loss: {avg_g_loss:.4f}, D_loss: {avg_d_loss:.4f} - Time: {epoch_time:.1f}s - ETA: {eta:.1f}min')

    # Store history
    history['g_loss'].append(avg_g_loss)
    history['d_loss'].append(avg_d_loss)

    # Show results every few epochs
    if (epoch + 1) % display_interval == 0:
        print(f"\n📊 Epoch {epoch+1} Results:")
        show_results(generator, val_loader, device, num_samples=8, epoch=epoch+1)

        # Calculate metrics
        psnr_score, ssim_score = calculate_metrics(generator, val_loader, device, num_samples=50)
        history['psnr'].append(psnr_score)
        history['ssim'].append(ssim_score)
        print(f"PSNR: {psnr_score:.2f} dB, SSIM: {ssim_score:.4f}")

    # Save checkpoint
    if (epoch + 1) % save_interval == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'g_optimizer': g_optimizer.state_dict(),
            'd_optimizer': d_optimizer.state_dict(),
            'history': history
        }
        torch.save(checkpoint, f'/content/checkpoints/checkpoint_epoch_{epoch+1}.pth')
        print(f"💾 Checkpoint saved at epoch {epoch+1}")

# Calculate total training time (minutes)
total_time = (_now() - training_start) / 60

print(f"\n🎉 Training completed in {total_time:.1f} minutes!")

# Save final model
torch.save({
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict(),
    'history': history
}, '/content/final_model.pth')

print("💾 Final model saved as 'final_model.pth'")

## 📈 Results Analysis and Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Losses are recorded once per epoch. Use 1-based epoch numbers so the x-axis
# matches the metric panels below.
loss_epochs = range(1, len(history['g_loss']) + 1)

# Loss curves
axes[0, 0].plot(loss_epochs, history['g_loss'], label='Generator Loss', color='blue')
axes[0, 0].plot(loss_epochs, history['d_loss'], label='Discriminator Loss', color='red')
axes[0, 0].set_title('Training Losses')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Generator loss only
axes[0, 1].plot(loss_epochs, history['g_loss'], color='blue')
axes[0, 1].set_title('Generator Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True)

# PSNR / SSIM: one point every display_interval epochs. The x-axis is computed
# per metric so neither panel depends on the other having data.
metric_panels = [
    (axes[1, 0], 'psnr', 'PSNR Score', 'PSNR (dB)', 'green'),
    (axes[1, 1], 'ssim', 'SSIM Score', 'SSIM', 'orange'),
]
for ax, key, title, ylabel, color in metric_panels:
    values = history[key]
    if values:
        epochs_with_metrics = [display_interval * (i + 1) for i in range(len(values))]
        ax.plot(epochs_with_metrics, values, color=color, marker='o')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True)

plt.tight_layout()
plt.savefig('/content/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

# Print final metrics
print("\n📊 Final Training Results:")
if history['g_loss']:
    print(f"Final Generator Loss: {history['g_loss'][-1]:.4f}")
    print(f"Final Discriminator Loss: {history['d_loss'][-1]:.4f}")
if history['psnr']:
    print(f"Best PSNR: {max(history['psnr']):.2f} dB")
    print(f"Final PSNR: {history['psnr'][-1]:.2f} dB")
if history['ssim']:
    print(f"Best SSIM: {max(history['ssim']):.4f}")
    print(f"Final SSIM: {history['ssim'][-1]:.4f}")
print(f"Total Training Time: {total_time:.1f} minutes")

In [None]:
# Generate final comparison images
print("🎨 Generating final comparison images...")

generator.eval()
num_final_samples = 10

with torch.no_grad():
    # Get validation samples
    L_batch, AB_batch = next(iter(val_loader))
    L_batch = L_batch[:num_final_samples].to(device)
    AB_batch = AB_batch[:num_final_samples].to(device)
    # The first validation batch may hold fewer images than requested
    num_shown = L_batch.size(0)

    # Generate colorized images
    with autocast():
        fake_AB = generator(L_batch)

    # Convert to RGB
    real_rgb = lab_to_rgb(L_batch, AB_batch)
    fake_rgb = lab_to_rgb(L_batch, fake_AB)
    grayscale = (L_batch + 1.0) / 2.0  # Convert to [0, 1] grayscale
    grayscale = grayscale.repeat(1, 3, 1, 1)

    # Create a comprehensive comparison. squeeze=False keeps axes 2-D even for one column.
    fig, axes = plt.subplots(3, num_shown, figsize=(20, 6), squeeze=False)

    rows = [
        (grayscale, 'Input (Grayscale)'),
        (fake_rgb, 'Generated'),
        (real_rgb, 'Ground Truth'),
    ]
    for row_idx, (images, row_title) in enumerate(rows):
        for i in range(num_shown):
            axes[row_idx, i].imshow(images[i].permute(1, 2, 0).cpu().numpy())
            # Label only the first column of each row
            axes[row_idx, i].set_title(row_title if i == 0 else '')
            axes[row_idx, i].axis('off')

    plt.suptitle(f'Final Results - Fast Image Colorization with Pix2Pix GAN', fontsize=16)
    plt.tight_layout()
    plt.savefig('/content/final_results.png', dpi=300, bbox_inches='tight')
    plt.show()

# Calculate final comprehensive metrics
print("\n📊 Calculating final metrics on validation set...")
final_psnr, final_ssim = calculate_metrics(generator, val_loader, device, num_samples=200)

print(f"\n🎯 Final Model Performance:")
print(f"PSNR: {final_psnr:.2f} dB")
print(f"SSIM: {final_ssim:.4f}")
print(f"Model Size: {sum(p.numel() for p in generator.parameters()):,} parameters")
print(f"Training Time: {total_time:.1f} minutes")
print(f"Training Dataset: {len(train_dataset):,} images")
print(f"Image Resolution: 128x128 pixels")

## 🔮 Inference Function for New Images

In [None]:
def colorize_image(image_path, generator, device, output_path=None):
    """
    Colorize a single grayscale image

    The network predicts colour (AB) at 128x128. Only that colour map is
    upsampled to the original size; the lightness channel is taken from the
    full-resolution input, so the result keeps the input's detail.

    Args:
        image_path: Path to input image
        generator: Trained generator model (must output the 2 AB channels);
            it is left in eval mode
        device: Device to run inference on
        output_path: Optional path to save colorized image

    Returns:
        Colorized image as PIL Image, at the input image's original size
    """
    generator.eval()

    # Load image
    image = Image.open(image_path).convert('RGB')
    original_width, original_height = image.size  # PIL reports (width, height)

    # Same preprocessing as training for the network input
    to_tensor = transforms.ToTensor()
    model_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        to_tensor,
    ])

    # Reuse the dataset's RGB->LAB normalisation without running __init__,
    # which would print a "Dataset initialized" line on every call.
    converter = ColorizationDataset.__new__(ColorizationDataset)
    L_full, _ = converter.rgb_to_lab(to_tensor(image))          # (1, H, W), full resolution
    L_small, _ = converter.rgb_to_lab(model_transform(image))   # (1, 128, 128)
    L_small = L_small.unsqueeze(0).to(device)

    with torch.no_grad():
        # Generate colorization at network resolution
        with autocast():
            fake_AB = generator(L_small)

        # Upsample only the colour channels to the original size
        AB_full = F.interpolate(
            fake_AB.float(),
            size=(original_height, original_width),
            mode='bilinear',
            align_corners=False,
        ).cpu()

        # Combine full-resolution lightness with predicted colour, convert back to RGB
        colorized_rgb = lab_to_rgb(L_full.unsqueeze(0), AB_full)

    # Convert to PIL Image
    colorized_np = colorized_rgb[0].permute(1, 2, 0).cpu().numpy()
    colorized_np = np.clip(colorized_np, 0, 1)
    colorized_pil = Image.fromarray((colorized_np * 255).astype(np.uint8))

    # Save if output path provided
    if output_path:
        colorized_pil.save(output_path)
        print(f"💾 Colorized image saved to: {output_path}")

    return colorized_pil

def batch_colorize(input_folder, output_folder, generator, device):
    """
    Colorize all images in a folder

    Only files directly inside `input_folder` are processed (no recursion).
    A failure on one image is reported and does not stop the batch.

    Args:
        input_folder: Path to folder containing input images
        output_folder: Path to folder to save colorized images (created,
            including missing parent folders, if it does not exist)
        generator: Trained generator model
        device: Device to run inference on
    """
    input_path = Path(input_folder)
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # Find all image files. Matching on the lower-cased suffix handles mixed
    # case and lists each file once, even on case-insensitive filesystems.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    image_files = sorted(
        entry for entry in input_path.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_extensions
    )

    print(f"🎨 Found {len(image_files)} images to colorize")

    # Process each image
    failures = 0
    for img_file in tqdm(image_files, desc="Colorizing images"):
        output_file = output_path / f"colorized_{img_file.name}"
        try:
            colorize_image(img_file, generator, device, output_file)
        except Exception as e:
            failures += 1
            print(f"❌ Error processing {img_file}: {e}")

    print(f"✅ Batch colorization complete! Results saved to: {output_folder}")
    if failures:
        print(f"⚠️ {failures} of {len(image_files)} images could not be processed")

# Example usage function
def demo_inference():
    """Print usage examples for the inference helpers; returns nothing."""
    messages = (
        "🔮 Inference Functions Ready!",
        "\n📝 Usage Examples:",
        "\n1. Colorize a single image:",
        "   colorized = colorize_image('path/to/image.jpg', generator, device, 'output.jpg')",
        "\n2. Batch colorize a folder:",
        "   batch_colorize('input_folder/', 'output_folder/', generator, device)",
        "\n3. Load saved model for inference:",
        "   checkpoint = torch.load('final_model.pth')",
        "   generator.load_state_dict(checkpoint['generator'])",
        "\n💡 The model works best on images similar to the training data",
        "   (natural scenes, people, objects)",
    )
    # One print per message, so the emitted text is line-for-line what a
    # sequence of individual print calls would produce.
    for message in messages:
        print(message)

demo_inference()

# Run the single-image inference helper on one validation image and show a side-by-side comparison
print("\n🧪 Testing inference on a validation image...")
try:
    if val_paths:
        # First validation path serves as the demo input
        test_image_path = val_paths[0]
        colorized = colorize_image(test_image_path, generator, device, '/content/test_colorized.jpg')

        original = Image.open(test_image_path)
        grayscale = original.convert('L')

        # (image, panel title, extra imshow keyword arguments)
        panels = [
            (original, 'Original', {}),
            (grayscale, 'Grayscale Input', {'cmap': 'gray'}),
            (colorized, 'AI Colorized', {}),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (panel_image, panel_title, imshow_kwargs) in zip(axes, panels):
            ax.imshow(panel_image, **imshow_kwargs)
            ax.set_title(panel_title)
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('/content/inference_demo.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("✅ Inference test successful!")
except Exception as e:
    # Demo only: report the problem and let the notebook continue
    print(f"⚠️ Inference test failed: {e}")

## 📥 Download Results

In [None]:
# Create a zip file with all results
import zipfile

print("📦 Creating results package...")


def _checkpoint_epoch(filename):
    """Epoch number parsed from a 'checkpoint_epoch_<N>.pth' file name, or -1 if there is none."""
    tail = Path(filename).stem.rsplit('_', 1)[-1]
    return int(tail) if tail.isdigit() else -1


# (path on disk, name inside the archive); each is added only if it exists
result_files = [
    ('/content/final_model.pth', 'final_model.pth'),            # model file
    ('/content/final_results.png', 'final_results.png'),        # result images
    ('/content/training_curves.png', 'training_curves.png'),    # training curves
    ('/content/inference_demo.png', 'inference_demo.png'),      # inference demo
    ('/content/test_colorized.jpg', 'test_colorized.jpg'),      # test colorized image
]

with zipfile.ZipFile('/content/fast_colorization_results.zip', 'w') as zipf:
    for disk_path, archive_name in result_files:
        if os.path.exists(disk_path):
            zipf.write(disk_path, archive_name)

    # Add latest checkpoint
    checkpoint_dir = '/content/checkpoints/'
    checkpoint_files = []
    if os.path.isdir(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if checkpoint_files:
        # Compare by epoch number: a plain string sort would rank epoch_100 before epoch_20
        latest_checkpoint = max(checkpoint_files, key=_checkpoint_epoch)
        zipf.write(f'/content/checkpoints/{latest_checkpoint}', f'checkpoints/{latest_checkpoint}')

print("✅ Results package created: fast_colorization_results.zip")

# Download files (for Colab)
try:
    from google.colab import files
    print("📥 Downloading results...")
    files.download('/content/fast_colorization_results.zip')
    if os.path.exists('/content/final_model.pth'):
        files.download('/content/final_model.pth')
    print("✅ Downloads complete!")
except ImportError:
    print("ℹ️ Not in Colab environment. Files saved locally.")

# Print summary
print(f"\n🎉 Fast Image Colorization Training Complete!")
print(f"\n📊 Summary:")
print(f"• Training Time: {total_time:.1f} minutes")
print(f"• Dataset Size: {len(train_dataset):,} training images")
print(f"• Model Size: {sum(p.numel() for p in generator.parameters()):,} parameters")
print(f"• Final PSNR: {final_psnr:.2f} dB")
print(f"• Final SSIM: {final_ssim:.4f}")
print(f"• Resolution: 128x128 pixels")
print(f"\n🚀 Ready for inference on new images!")
print(f"\n📁 Files created:")
print(f"• final_model.pth - Trained model weights")
print(f"• final_results.png - Sample colorization results")
print(f"• training_curves.png - Training progress plots")
print(f"• fast_colorization_results.zip - Complete results package")