In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import vgg16
# Optional-import shim: when this torchvision exposes the `VGG16_Weights` enum,
# select its IMAGENET1K_V1 entry; when the import fails, `vgg_weights` is None.
# NOTE(review): neither `vgg16` nor `vgg_weights` is referenced by the visible
# cells (USE_PERCEPTUAL is False) — confirm whether this shim is still needed.
try:
    from torchvision.models import VGG16_Weights
    vgg_weights = VGG16_Weights.IMAGENET1K_V1
except ImportError:
    vgg_weights = None

import numpy as np
from PIL import Image
import os
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
from colormath.color_objects import LabColor, sRGBColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
import matplotlib.pyplot as plt
import pandas as pd
import warnings

# LAB conversion edge cases emit RuntimeWarnings; silence that whole category.
warnings.filterwarnings('ignore', category=RuntimeWarning)


def _asscalar(value):
    """Stand-in for ``numpy.asscalar``: unwrap a NumPy scalar/0-d array via ``.item()``."""
    return value.item()


# ``numpy.asscalar`` was removed in numpy 1.23+; install the stand-in only when
# the attribute is missing (presumably needed by the colormath delta-E code —
# verify against the installed colormath version).
if not hasattr(np, 'asscalar'):
    np.asscalar = _asscalar

In [None]:
# Configuration
# Everything a reader might want to tune lives in this cell.
MODEL_ID = 'H1'
# NOTE(review): the USE_* switches are not consulted by the visible training
# loop, which always optimises GAN + L1 — confirm whether they are meant to gate
# the loss terms.
USE_GAN = True
USE_L1 = True
USE_PERCEPTUAL = False
LAMBDA_L1 = 100.0   # weight of the L1 term relative to the adversarial term

BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 2e-4
IMAGE_SIZE = 256
NUM_WORKERS = 4
SEED = 42

DATA_DIR = '../data/colorize_dataset/data'
MODEL_DIR = f'../models/{MODEL_ID}'
RESULTS_DIR = '../results'

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Seed the torch and NumPy global RNGs so weight initialisation, shuffling and
# augmentation start from a fixed state on every run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Device configuration
# `torch.backends.mps` does not exist on older torch builds; look it up
# defensively so this cell does not raise AttributeError there.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
class ColorizeDataset(Dataset):
    """Colour images from a directory, served as normalised LAB (L, ab) tensor pairs.

    Args:
        color_dir: directory containing the colour images (.jpg / .jpeg / .png).
        size: images are resized to ``size x size``.
        split: ``'train'`` enables random augmentation; any other value gives a
            deterministic resize + tensor conversion.

    Each item is ``(L, ab)``: ``L`` has shape ``(1, size, size)`` and is scaled
    as ``L / 50 - 1``; ``ab`` has shape ``(2, size, size)`` and is scaled as
    ``ab / 110``.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, color_dir, size=256, split='train'):
        self.color_dir = color_dir
        self.size = size
        self.split = split
        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> image mapping stable across runs and machines. The extension
        # match is case-insensitive so files such as 'IMG_01.JPG' are kept.
        self.images = sorted(
            f for f in os.listdir(color_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        if split == 'train':
            # Augmentation runs on the RGB image before the LAB split, so L and
            # ab always come from the same augmented picture.
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                transforms.ToTensor()
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor()
            ])

    def __len__(self):
        """Number of image files found in ``color_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its normalised ``(L, ab)`` tensors."""
        img_path = os.path.join(self.color_dir, self.images[idx])
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)

        # Convert to LAB (rgb2lab expects channels-last)
        img_np = img.permute(1, 2, 0).numpy()
        lab = rgb2lab(img_np)

        # Normalize
        L = lab[:, :, 0] / 50.0 - 1.0  # [-1, 1]
        ab = lab[:, :, 1:] / 110.0      # roughly [-1, 1]

        L = torch.from_numpy(L).float().unsqueeze(0)
        ab = torch.from_numpy(ab).float().permute(2, 0, 1)

        return L, ab


In [None]:
class ResidualBlock(nn.Module):
    """Residual block for ResNet generator"""
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels)
        )
    
    def forward(self, x):
        return x + self.conv_block(x)

class ResNetGenerator(nn.Module):
    """ResNet-style generator: 7x7 stem, two stride-2 downsamples, a stack of
    residual blocks (9 by default), two stride-2 upsamples and a 7x7 Tanh head."""

    def __init__(self, in_channels=1, out_channels=2, num_residual_blocks=9):
        super().__init__()

        base_width = 64

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base_width, 7),
            nn.BatchNorm2d(base_width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step halves the resolution and doubles the width
        width = base_width
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, 3, stride=2, padding=1),
                nn.BatchNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Bottleneck of residual blocks at the widest setting
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Decoder: each step doubles the resolution and halves the width
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, narrower, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Head: Tanh bounds the output to [-1, 1]
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(base_width, out_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full encoder / residual / decoder stack."""
        return self.model(x)

In [None]:
class PatchGANDiscriminator(nn.Module):
    """Convolutional discriminator that scores the channel-wise concatenation
    of an L map and an ab map, producing a grid of single-channel scores."""

    def __init__(self, in_channels=3):
        super().__init__()

        # (input channels, output channels, apply batch norm?)
        stage_specs = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]

        layers = []
        for c_in, c_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # One zero row/column on the top/left edge before the final 4x4 conv
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_L, img_ab):
        """Concatenate ``img_L`` and ``img_ab`` along dim 1 and score the result."""
        stacked = torch.cat([img_L, img_ab], dim=1)
        return self.model(stacked)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss that builds the all-ones / all-zeros target internally.

    Args:
        use_lsgan: when True the criterion is ``nn.MSELoss`` (least-squares GAN);
            otherwise ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, use_lsgan=True):
        super().__init__()
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    # Implemented as `forward` rather than by overriding `__call__`, so that
    # nn.Module's own `__call__` (and therefore any registered hooks) still runs.
    def forward(self, prediction, target_is_real):
        """Compare ``prediction`` with a same-shaped target of ones or zeros.

        Args:
            prediction: discriminator output tensor.
            target_is_real: True for a target of ones, False for zeros.

        Returns:
            The scalar loss tensor produced by the selected criterion.
        """
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.loss(prediction, target)

In [None]:
# Networks (generator is built first, then the discriminator)
generator = ResNetGenerator().to(device)
discriminator = PatchGANDiscriminator().to(device)

# Objectives
criterion_L1 = nn.L1Loss().to(device)
gan_loss = GANLoss().to(device)

# Both networks use Adam with identical hyper-parameters
adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

# Datasets and loaders: only the training loader shuffles
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_dataset = ColorizeDataset(os.path.join(DATA_DIR, 'train_color'), IMAGE_SIZE, split='train')
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = ColorizeDataset(os.path.join(DATA_DIR, 'val_color'), IMAGE_SIZE, split='val')
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Summary of dataset sizes and model capacity
generator_param_count = sum(p.numel() for p in generator.parameters())
discriminator_param_count = sum(p.numel() for p in discriminator.parameters())

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Generator parameters: {generator_param_count:,}')
print(f'Discriminator parameters: {discriminator_param_count:,}')


In [None]:
def validate_epoch(generator, val_loader, device):
    """Compute the mean weighted L1 loss of ``generator`` over ``val_loader``.

    Only the L1 term (scaled by the module-level ``LAMBDA_L1``) is evaluated; no
    adversarial or perceptual term is involved. The generator is switched to
    eval mode and left there.

    Args:
        generator: model mapping an L batch to a predicted ab batch.
        val_loader: iterable yielding ``(L, ab)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        The sample-weighted average of ``LAMBDA_L1 * L1(fake_ab, ab)`` as a float.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    generator.eval()
    weighted_loss_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for L, ab in val_loader:
            L, ab = L.to(device), ab.to(device)
            fake_ab = generator(L)

            # Only L1 loss for validation (no GAN, no perceptual)
            batch_loss = criterion_L1(fake_ab, ab) * LAMBDA_L1

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_count = L.size(0)
            weighted_loss_sum += batch_loss.item() * batch_count
            num_samples += batch_count

    if num_samples == 0:
        raise ValueError('val_loader yielded no samples; cannot compute a validation loss')

    return weighted_loss_sum / num_samples

def visualize_validation_samples(generator, val_loader, device, epoch, results_dir):
    """Save a grayscale / colourised / ground-truth grid for the first validation batch.

    Up to 8 images from the first batch of ``val_loader`` are shown, one column
    each. The figure is written to ``<results_dir>/val_samples_epoch_<epoch>.png``
    and then closed. The generator is switched to eval mode and left there.

    Args:
        generator: model mapping an L batch to a predicted ab batch.
        val_loader: iterable yielding ``(L, ab)`` batches.
        device: device used for the forward pass.
        epoch: epoch number used in the output file name.
        results_dir: directory the PNG is written to.
    """
    max_samples = 8

    generator.eval()
    with torch.no_grad():
        sample_L, sample_ab = next(iter(val_loader))
        sample_L = sample_L[:max_samples].to(device)
        sample_ab = sample_ab[:max_samples].to(device)
        fake_ab = generator(sample_L)

        # Convert to RGB
        def lab_to_rgb_viz(L, ab):
            # Undo the dataset normalisation before handing off to lab2rgb
            L = (L + 1.0) * 50.0
            ab = ab * 110.0
            Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
            return np.array([lab2rgb(lab_img) for lab_img in Lab])

        pred_rgb = lab_to_rgb_viz(sample_L, fake_ab)
        target_rgb = lab_to_rgb_viz(sample_L, sample_ab)
        grayscale = sample_L.cpu().numpy()

    # The first batch may hold fewer than `max_samples` images (small validation
    # set or small batch size); size the grid to what is actually available.
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    num_shown = sample_L.shape[0]
    fig, axes = plt.subplots(3, num_shown, figsize=(2.5 * num_shown, 8), squeeze=False)

    row_titles = ('Grayscale', 'Colorized', 'Ground Truth')
    for i in range(num_shown):
        axes[0, i].imshow(grayscale[i, 0], cmap='gray')
        axes[1, i].imshow(np.clip(pred_rgb[i], 0, 1))
        axes[2, i].imshow(np.clip(target_rgb[i], 0, 1))

        for row, title in enumerate(row_titles):
            axes[row, i].axis('off')
            if i == 0:
                axes[row, i].set_title(title, fontsize=10)

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f'val_samples_epoch_{epoch}.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


In [None]:
# Training loop
# Each iteration updates the discriminator first, then the generator, from a
# single generator forward pass. Per-epoch averages are logged to `history`,
# which is rewritten to CSV after every epoch.
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'gen_loss': [], 'disc_loss': []}
best_val_loss = float('inf')


def _checkpoint_state(epoch_number, val_loss):
    """Snapshot both networks and both optimizers, tagged with epoch and validation loss."""
    return {
        'epoch': epoch_number,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'loss': val_loss,
    }


for epoch in range(NUM_EPOCHS):
    # validate_epoch / visualisation leave the generator in eval mode, so
    # training mode is re-enabled at the start of every epoch.
    generator.train()
    discriminator.train()
    epoch_g_loss = 0.0
    epoch_l1_loss = 0.0
    epoch_gan_loss = 0.0
    epoch_d_loss = 0.0

    for L, ab in train_loader:
        L, ab = L.to(device), ab.to(device)

        # One generator forward pass serves both updates. The generator weights
        # do not change until optimizer_G.step(), so a second pass would give
        # the same output while doubling the compute and updating the BatchNorm
        # running statistics twice per iteration.
        fake_ab = generator(L)

        # Train Discriminator
        optimizer_D.zero_grad()

        # detach(): the discriminator update must not back-propagate into the generator
        fake_pred = discriminator(L, fake_ab.detach())
        real_pred = discriminator(L, ab)

        d_loss_fake = gan_loss(fake_pred, False)
        d_loss_real = gan_loss(real_pred, True)
        d_loss = (d_loss_fake + d_loss_real) * 0.5

        d_loss.backward()
        optimizer_D.step()

        # Train Generator (scored by the freshly updated discriminator)
        optimizer_G.zero_grad()

        fake_pred = discriminator(L, fake_ab)

        # L1 loss
        loss_l1 = criterion_L1(fake_ab, ab)

        # GAN loss
        loss_gan_g = gan_loss(fake_pred, True)

        # Total generator loss (GAN + L1 only, no Perceptual)
        g_loss = loss_gan_g + LAMBDA_L1 * loss_l1

        g_loss.backward()
        optimizer_G.step()

        epoch_g_loss += g_loss.item()
        epoch_l1_loss += loss_l1.item()
        epoch_gan_loss += loss_gan_g.item()
        epoch_d_loss += d_loss.item()

    # Average losses
    num_batches = len(train_loader)
    avg_g_loss = epoch_g_loss / num_batches
    avg_l1_loss = epoch_l1_loss / num_batches
    avg_gan_loss = epoch_gan_loss / num_batches
    avg_d_loss = epoch_d_loss / num_batches

    # Validation
    avg_val_loss = validate_epoch(generator, val_loader, device)

    # Update history ('train_loss' and 'gen_loss' both hold the generator loss)
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(avg_g_loss)
    history['val_loss'].append(avg_val_loss)
    history['gen_loss'].append(avg_g_loss)
    history['disc_loss'].append(avg_d_loss)

    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Train_Loss: {avg_g_loss:.4f} Val_Loss: {avg_val_loss:.4f} (GAN: {avg_gan_loss:.4f}, L1: {avg_l1_loss:.4f}) D_Loss: {avg_d_loss:.4f}')

    # Save best model based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(_checkpoint_state(epoch + 1, avg_val_loss), os.path.join(MODEL_DIR, 'best_model.pt'))
        print(f'Saved best model at epoch {epoch+1} with val_loss {avg_val_loss:.4f}')

    # Visualize validation samples and write a periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        visualize_validation_samples(generator, val_loader, device, epoch + 1, RESULTS_DIR)
        torch.save(_checkpoint_state(epoch + 1, avg_val_loss), os.path.join(MODEL_DIR, f'checkpoint_epoch_{epoch+1}.pt'))

    # Save training history after each epoch
    history_df = pd.DataFrame(history)
    history_df.to_csv(os.path.join(RESULTS_DIR, f'training_history_{MODEL_ID}.csv'), index=False)

print('Training completed!')


In [None]:
class ColorimetricEvaluator:
    """Computes PSNR, SSIM, a mean-colour CIEDE2000 and LPIPS for a colourisation generator.

    Args:
        device: device used for the generator forward pass and the LPIPS network.
    """

    def __init__(self, device):
        self.device = device
        self.lpips_fn = lpips.LPIPS(net='alex').to(device)

    def lab_to_rgb(self, L, ab):
        """Undo the dataset normalisation and convert a batch of LAB tensors to RGB.

        Args:
            L: tensor scaled as ``L / 50 - 1``, concatenated with ``ab`` on dim 1.
            ab: tensor scaled as ``ab / 110``.

        Returns:
            NumPy array of channels-last RGB images, one per batch entry.
        """
        L = (L + 1.0) * 50.0
        ab = ab * 110.0
        Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
        return np.array([lab2rgb(lab_img) for lab_img in Lab])

    def calculate_psnr(self, pred, target):
        """Mean per-image PSNR after clipping both inputs to [0, 1]."""
        pred = np.clip(pred, 0, 1)
        target = np.clip(target, 0, 1)
        psnr_values = [psnr(t, p, data_range=1.0) for p, t in zip(pred, target)]
        return np.mean(psnr_values)

    def calculate_ssim(self, pred, target):
        """Mean per-image SSIM (channels on axis 2) after clipping both inputs to [0, 1]."""
        pred = np.clip(pred, 0, 1)
        target = np.clip(target, 0, 1)
        ssim_values = [ssim(t, p, data_range=1.0, channel_axis=2) for p, t in zip(pred, target)]
        return np.mean(ssim_values)

    def calculate_ciede2000(self, pred, target):
        """Mean CIEDE2000 distance between the *average* LAB colours of each image pair.

        Both inputs are quantised to 8 bits first. This compares one mean colour
        per image; it is not a per-pixel delta-E.
        """
        pred = np.clip(pred * 255, 0, 255).astype(np.uint8)
        target = np.clip(target * 255, 0, 255).astype(np.uint8)

        delta_e_values = []
        for p, t in zip(pred, target):
            p_lab = rgb2lab(p / 255.0)
            t_lab = rgb2lab(t / 255.0)

            # Collapse each image to its mean (L, a, b) triple
            p_mean = np.mean(p_lab, axis=(0, 1))
            t_mean = np.mean(t_lab, axis=(0, 1))

            p_color = LabColor(p_mean[0], p_mean[1], p_mean[2])
            t_color = LabColor(t_mean[0], t_mean[1], t_mean[2])

            delta_e_values.append(delta_e_cie2000(p_color, t_color))

        return np.mean(delta_e_values)

    def calculate_lpips(self, pred, target, batch_size=32):
        """Mean LPIPS distance between paired channels-last RGB arrays in [0, 1].

        The images are sent to the device in chunks of ``batch_size`` so that a
        large array does not have to fit on the device at once.

        Args:
            pred: NumPy array of predicted RGB images, channels last.
            target: NumPy array of reference RGB images, same shape as ``pred``.
            batch_size: maximum number of image pairs per LPIPS forward pass.

        Returns:
            The mean distance as a float (NaN when ``pred`` is empty).
        """
        num_images = len(pred)
        if num_images == 0:
            return float('nan')

        def to_lpips_input(images):
            # channels-first float tensor rescaled from [0, 1] to [-1, 1]
            return torch.from_numpy(images).permute(0, 3, 1, 2).float().to(self.device) * 2 - 1

        distance_sum = 0.0
        with torch.no_grad():
            for start in range(0, num_images, batch_size):
                stop = start + batch_size
                distances = self.lpips_fn(to_lpips_input(pred[start:stop]),
                                          to_lpips_input(target[start:stop]))
                distance_sum += distances.sum().item()

        return distance_sum / num_images

    def evaluate(self, generator, dataloader):
        """Run ``generator`` over ``dataloader`` and return the averaged metrics.

        Metrics are computed batch by batch and combined as image-count-weighted
        means, so the RGB images of the whole dataset are never held in memory
        together. The generator is switched to eval mode and left there.

        Args:
            generator: model mapping an L batch to a predicted ab batch.
            dataloader: iterable yielding ``(L, ab)`` batches.

        Returns:
            Dict with keys ``'PSNR'``, ``'SSIM'``, ``'CIEDE2000'`` and ``'LPIPS'``.

        Raises:
            ValueError: if ``dataloader`` yields no images.
        """
        generator.eval()
        metric_fns = {
            'PSNR': self.calculate_psnr,
            'SSIM': self.calculate_ssim,
            'CIEDE2000': self.calculate_ciede2000,
            'LPIPS': self.calculate_lpips,
        }
        weighted_sums = {name: 0.0 for name in metric_fns}
        num_images = 0

        with torch.no_grad():
            for L, ab in dataloader:
                L, ab = L.to(self.device), ab.to(self.device)
                fake_ab = generator(L)

                pred_rgb = self.lab_to_rgb(L, fake_ab)
                target_rgb = self.lab_to_rgb(L, ab)

                # Each calculate_* returns a per-batch mean; weight it by the
                # batch size so the final value is the mean over all images.
                batch_count = len(pred_rgb)
                for name, metric_fn in metric_fns.items():
                    weighted_sums[name] += metric_fn(pred_rgb, target_rgb) * batch_count
                num_images += batch_count

        if num_images == 0:
            raise ValueError('dataloader yielded no images; nothing to evaluate')

        metrics = {name: total / num_images for name, total in weighted_sums.items()}

        return metrics

In [None]:
# Evaluation
# split='test' selects the deterministic resize-only transform. Leaving the
# split at its default ('train') would apply the random flip / rotation / colour
# jitter / affine augmentation to the test images and make the metrics noisy
# and non-repeatable.
test_dataset = ColorizeDataset(os.path.join(DATA_DIR, 'test_color'), IMAGE_SIZE, split='test')
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

evaluator = ColorimetricEvaluator(device)

# Load best model (the checkpoint with the lowest validation loss seen in training)
checkpoint = torch.load(os.path.join(MODEL_DIR, 'best_model.pt'), map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])

print('Evaluating model...')
metrics = evaluator.evaluate(generator, test_loader)

print('\nEvaluation Results:')
for metric, value in metrics.items():
    print(f'{metric}: {value:.4f}')

In [None]:
# Persist the test metrics as a one-row CSV, tagged with the model id and the
# best validation loss reached during training.
metrics_path = os.path.join(RESULTS_DIR, f'metrics_{MODEL_ID}.csv')
metrics_df = pd.DataFrame([metrics]).assign(model=MODEL_ID, Best_Val_Loss=best_val_loss)
metrics_df.to_csv(metrics_path, index=False)
print(f'Metrics saved to {RESULTS_DIR}/metrics_{MODEL_ID}.csv')


In [None]:
# Plot training history: generator train/val curves on the left, discriminator
# curve on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
ax_gen, ax_disc = axes
epochs = history['epoch']

# Generator loss (train and val)
for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    ax_gen.plot(epochs, history[key], label=label)
ax_gen.set_title('Generator Loss (Train vs Val)')

# Discriminator loss
ax_disc.plot(epochs, history['disc_loss'], label='Discriminator Loss')
ax_disc.set_title('Discriminator Loss')

# Shared axis decoration
for ax in (ax_gen, ax_disc):
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
history_plot_path = os.path.join(RESULTS_DIR, f'training_history_{MODEL_ID}.png')
plt.savefig(history_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f'Training history plot saved to {RESULTS_DIR}/training_history_{MODEL_ID}.png')


In [None]:
# Visualize results: grayscale input, colourised output and ground truth for up
# to 8 images from the first test batch.
generator.eval()
with torch.no_grad():
    sample_L, sample_ab = next(iter(test_loader))
    sample_L, sample_ab = sample_L[:8].to(device), sample_ab[:8].to(device)
    fake_ab = generator(sample_L)

    # Reuse the evaluator built in the evaluation cell; constructing another
    # ColorimetricEvaluator would load the LPIPS network a second time only to
    # call lab_to_rgb.
    evaluator_viz = evaluator
    pred_rgb = evaluator_viz.lab_to_rgb(sample_L, fake_ab)
    target_rgb = evaluator_viz.lab_to_rgb(sample_L, sample_ab)
    grayscale = sample_L.cpu().numpy()

# The first batch may hold fewer than 8 images; size the grid to what is there.
# squeeze=False keeps `axes` two-dimensional even for a single column.
num_shown = sample_L.shape[0]
fig, axes = plt.subplots(3, num_shown, figsize=(2.5 * num_shown, 8), squeeze=False)
row_titles = ('Grayscale', 'Colorized', 'Ground Truth')
for i in range(num_shown):
    axes[0, i].imshow(grayscale[i, 0], cmap='gray')
    axes[1, i].imshow(np.clip(pred_rgb[i], 0, 1))
    axes[2, i].imshow(np.clip(target_rgb[i], 0, 1))

    for row, title in enumerate(row_titles):
        axes[row, i].axis('off')
        if i == 0:
            axes[row, i].set_title(title, fontsize=10)

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, f'colorization_results_{MODEL_ID}.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f'Colorization results saved to {RESULTS_DIR}/colorization_results_{MODEL_ID}.png')