In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import matplotlib.pyplot as plt
from tqdm import tqdm
import lpips
import os
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torch.serialization

# --- Global settings ---
# Number of equal-width intensity bins used by the accuracy metric.
NUM_BINS = 2

# Seed NumPy and PyTorch so repeated runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# U-Net style colorization network
class UNetColorization(nn.Module):
    """Encoder/decoder network mapping a 1-channel image to a 3-channel image.

    Three convolutional encoder stages (the last two halve the resolution with
    stride 2) feed two transposed-convolution decoder stages and a final
    convolution. Each decoder input after the first is concatenated with the
    encoder feature map of the same resolution. The output goes through
    ``tanh``, so every value lies in [-1, 1].

    Height and width of the input must be divisible by 4; otherwise the
    upsampled tensors and the skip tensors differ in size and ``torch.cat``
    fails.
    """

    def __init__(self):
        super().__init__()

        def encoder_stage(in_ch, out_ch, stride=1):
            # conv -> batch norm -> ReLU -> dropout(0.4)
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout(0.4),
            )

        def decoder_stage(in_ch, out_ch):
            # transposed conv doubling the resolution -> batch norm -> ReLU
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            )

        # Encoder: 1 -> 128 -> 256 -> 512 channels.
        self.enc1 = encoder_stage(1, 128)
        self.enc2 = encoder_stage(128, 256, stride=2)
        self.enc3 = encoder_stage(256, 512, stride=2)

        # Decoder: input widths of dec2/dec3 include the concatenated skips
        # (256 + 256 and 128 + 128 channels respectively).
        self.dec1 = decoder_stage(512, 256)
        self.dec2 = decoder_stage(512, 128)
        self.dec3 = nn.Conv2d(256, 3, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return the 3-channel prediction for the 1-channel batch ``x``."""
        skip_full = self.enc1(x)
        skip_half = self.enc2(skip_full)
        bottleneck = self.enc3(skip_half)

        up_half = self.dec1(bottleneck)
        up_full = self.dec2(torch.cat((up_half, skip_half), dim=1))
        rgb = self.dec3(torch.cat((up_full, skip_full), dim=1))
        return self.tanh(rgb)

# Allowlist the model class and every layer type it contains so that
# whole-model checkpoints can be loaded through torch's safe-globals mechanism.
_CHECKPOINT_CLASSES = [
    UNetColorization,
    nn.Sequential,
    nn.Conv2d,
    nn.BatchNorm2d,
    nn.Dropout,
    nn.ReLU,
    nn.ConvTranspose2d,
    nn.Tanh,
]
torch.serialization.add_safe_globals(_CHECKPOINT_CLASSES)

# Create ./models (relative to the working directory) if it is missing.
os.makedirs("models", exist_ok=True)

# CIFAR-10 images are only converted to tensors; no augmentation is applied.
transform = transforms.Compose([transforms.ToTensor()])

_cifar_kwargs = dict(root='../data', download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Both loaders use batches of 64 and two workers; only training data is shuffled.
_loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Grayscale conversion
def rgb_to_gray(img):
    """Average ``img`` over dimension 1, keeping that dimension with size 1.

    This is an unweighted mean of the entries along dim 1 (the channel axis of
    the batches used in this file), not a luminance-weighted conversion.
    """
    return torch.mean(img, dim=1, keepdim=True)

# LPIPS perceptual distance with the VGG backbone, placed on the active device.
# NOTE(review): LPIPS.forward has a `normalize` flag for inputs scaled to
# [0, 1]; the callers in this file pass ToTensor-scaled images without it.
# Confirm the intended input range against the lpips documentation.
loss_fn_vgg = lpips.LPIPS(net='vgg')
loss_fn_vgg = loss_fn_vgg.to(device)

# Training function
def train_model(model, train_loader, loss_type, epochs=30):
    """Train ``model`` to predict colour images from their grayscale version.

    Args:
        model: Network called on the grayscale batch; it is moved to ``device``.
        train_loader: Iterable yielding ``(images, labels)`` batches. Labels are
            ignored; ``images`` is both the source of the grayscale input and
            the regression target.
        loss_type: ``'mse'`` for pixel-wise mean squared error or
            ``'perceptual'`` for the mean LPIPS distance (``loss_fn_vgg``).
        epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one entry per epoch: the sample-weighted mean loss.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names, or if
            an epoch yields no samples.
    """
    # Validate up front: previously an unknown loss_type only failed inside the
    # loop (unbound ``loss``), after data loading and a forward pass.
    if loss_type == 'mse':
        criterion = nn.MSELoss()
    elif loss_type == 'perceptual':
        def criterion(prediction, target):
            # LPIPS returns one distance per sample; reduce to a scalar.
            return loss_fn_vgg(prediction, target).mean()
    else:
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected 'mse' or 'perceptual'"
        )

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0
        for i, (images, _) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')):
            images = images.to(device)
            # Already on ``device`` because it is derived from ``images``.
            grayscale_images = rgb_to_gray(images)

            optimizer.zero_grad()
            outputs = model(grayscale_images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

        if samples_seen == 0:
            raise ValueError('train_loader yielded no samples')

        # Divide by the samples actually processed so the mean stays correct
        # with samplers or drop_last, where it differs from len(dataset).
        epoch_loss = running_loss / samples_seen
        train_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}, {loss_type.upper()} Loss: {epoch_loss:.4f}')

    return train_losses

# Evaluation function with per-channel accuracy
def _bin_indices(values, num_bins):
    """Map values to equal-width bin ids ``0 .. num_bins - 1`` over [0, 1]."""
    # Only the interior edges are used so that 1.0 falls in the last bin
    # instead of an extra overflow bin, and ids start at 0.
    inner_edges = np.linspace(0, 1, num_bins + 1)[1:-1]
    return np.digitize(np.clip(values, 0, 1), inner_edges)


def _weighted_prf(cm):
    """Support-weighted precision, recall and F1 from a (true x predicted) matrix."""
    cm = cm.astype(np.float64)
    true_pos = np.diag(cm).copy()
    true_count = cm.sum(axis=1)
    pred_count = cm.sum(axis=0)
    total = true_count.sum()
    if total == 0:
        return 0.0, 0.0, 0.0

    # Classes with an empty denominator score 0 (same as zero_division=0).
    precision_c = np.divide(true_pos, pred_count,
                            out=np.zeros_like(true_pos), where=pred_count > 0)
    recall_c = np.divide(true_pos, true_count,
                         out=np.zeros_like(true_pos), where=true_count > 0)
    pr_sum = precision_c + recall_c
    f1_c = np.divide(2 * precision_c * recall_c, pr_sum,
                     out=np.zeros_like(true_pos), where=pr_sum > 0)

    weights = true_count / total
    return (float(precision_c @ weights),
            float(recall_c @ weights),
            float(f1_c @ weights))


def evaluate_model(model, test_loader, loss_type):
    """Evaluate a colorization model over every batch of ``test_loader``.

    Args:
        model: Network called on the grayscale version of each image batch.
        test_loader: Iterable yielding ``(images, labels)``; labels are ignored.
        loss_type: Only used to label the progress bar.

    Returns:
        Tuple ``(avg_mse, avg_perceptual, avg_psnr, avg_ssim, cm, precision,
        recall, f1, avg_accuracy)``:

        * ``avg_mse`` / ``avg_perceptual``: sample-weighted means of the batch
          MSE and mean LPIPS distance.
        * ``avg_psnr`` / ``avg_ssim``: per-image means, computed on the raw
          (unclipped) outputs with ``data_range=1.0``.
        * ``cm``: ``NUM_BINS x NUM_BINS`` confusion matrix of binned pixel
          intensities over the whole loader and all three channels; rows are
          ground-truth bins, columns are predicted bins.
        * ``precision`` / ``recall`` / ``f1``: support-weighted scores derived
          from ``cm``.
        * ``avg_accuracy``: mean over the three channels of each channel's
          binned-pixel accuracy.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    mse_loss = nn.MSELoss()
    total_mse = 0.0
    total_perceptual = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_samples = 0

    # One confusion matrix per colour channel, accumulated over ALL batches
    # (the previous code reported statistics of the final batch only).
    bin_labels = np.arange(NUM_BINS)
    channel_cms = np.zeros((3, NUM_BINS, NUM_BINS), dtype=np.int64)

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f'Evaluating {loss_type}'):
            images = images.to(device)
            outputs = model(rgb_to_gray(images))
            batch_size = images.size(0)

            # Batch-level losses, weighted by batch size for the final mean.
            total_mse += mse_loss(outputs, images).item() * batch_size
            total_perceptual += loss_fn_vgg(outputs, images).mean().item() * batch_size

            # Channel-last copies on the CPU for the skimage / sklearn metrics.
            images_np = images.cpu().numpy().transpose(0, 2, 3, 1)
            outputs_np = outputs.cpu().numpy().transpose(0, 2, 3, 1)
            for target_img, pred_img in zip(images_np, outputs_np):
                total_psnr += psnr(target_img, pred_img, data_range=1.0)
                total_ssim += ssim(target_img, pred_img, data_range=1.0, channel_axis=-1, win_size=3)
            num_samples += batch_size

            # Binned-intensity confusion per channel (R, G, B). Fixed labels
            # keep the matrix shape stable even if a bin is absent in a batch.
            for c in range(3):
                target_bins = _bin_indices(images_np[..., c].ravel(), NUM_BINS)
                pred_bins = _bin_indices(outputs_np[..., c].ravel(), NUM_BINS)
                channel_cms[c] += confusion_matrix(target_bins, pred_bins, labels=bin_labels)

    if num_samples == 0:
        raise ValueError('test_loader yielded no samples')

    avg_mse = total_mse / num_samples
    avg_perceptual = total_perceptual / num_samples
    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples
    avg_accuracy = float(np.mean([np.trace(m) / m.sum() for m in channel_cms]))

    cm = channel_cms.sum(axis=0)
    precision, recall, f1 = _weighted_prf(cm)

    return avg_mse, avg_perceptual, avg_psnr, avg_ssim, cm, precision, recall, f1, avg_accuracy

# Visualize samples
def visualize_samples(model, test_loader, num_samples=5, model_name='model'):
    """Plot and save grayscale / ground-truth / predicted triptychs.

    Takes the first batch of ``test_loader``, runs ``model`` on its grayscale
    version and, for each of the first ``num_samples`` images, shows a
    three-panel figure and saves it as
    ``models/sample_{i}_{model_name}.png``.

    Args:
        model: Network called on the grayscale batch.
        test_loader: Iterable yielding ``(images, labels)``; only the first
            batch is used.
        num_samples: Maximum number of images to plot; capped at the batch size.
        model_name: Suffix used in the saved file names.
    """
    model.eval()
    with torch.no_grad():
        images, _ = next(iter(test_loader))
        grayscale = rgb_to_gray(images).to(device)
        outputs = model(grayscale)

    gray_np = grayscale[:, 0].cpu().numpy()
    truth_np = images.cpu().numpy().transpose(0, 2, 3, 1)
    # The network ends in tanh, so predictions can be negative; clip to the
    # displayable [0, 1] range explicitly instead of relying on imshow.
    pred_np = np.clip(outputs.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

    os.makedirs('models', exist_ok=True)

    # Never index past the end of the batch.
    count = min(num_samples, images.size(0))
    for i in range(count):
        fig = plt.figure(figsize=(12, 3))

        plt.subplot(1, 3, 1)
        plt.imshow(gray_np[i], cmap='gray')
        plt.title('Grayscale')
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(truth_np[i])
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(pred_np[i])
        plt.title('Predicted')
        plt.axis('off')

        plt.savefig(f'models/sample_{i}_{model_name}.png')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Check if models exist, load or train
def load_or_train_model(model, model_path, weights_path, train_loader, loss_type):
    """Return a ready-to-use model, loading a checkpoint when one exists.

    If ``model_path`` exists, the whole pickled model is loaded from it; when
    that fails, the state dict at ``weights_path`` is loaded into ``model``
    instead. Otherwise ``model`` is trained with ``train_model`` and both the
    state dict and the whole model are saved.

    Args:
        model: Freshly constructed network (used for training or as the target
            of the state-dict fallback).
        model_path: File holding / receiving the whole pickled model.
        weights_path: File holding / receiving the state dict.
        train_loader: Passed through to ``train_model``.
        loss_type: Passed through to ``train_model`` and used in log messages.

    Returns:
        ``(model, losses)`` where ``losses`` is the per-epoch training loss
        list, or ``[]`` when a checkpoint was loaded. The model is on
        ``device`` in every case.
    """
    if os.path.exists(model_path):
        print(f'Loading existing {loss_type.upper()} model...')
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        try:
            # map_location lets a checkpoint saved on CUDA load on a CPU host.
            model = torch.load(model_path, map_location=device)
        except Exception as e:
            print(f'Failed to load model: {e}. Loading weights instead...')
            model.load_state_dict(torch.load(weights_path, map_location=device))
            # The fresh model was built on the CPU; move it so evaluation on
            # ``device`` does not hit a device mismatch.
            model = model.to(device)
        return model, []

    # Create the target directories before the (long) training run so that
    # saving cannot fail afterwards on a missing folder.
    for path in (weights_path, model_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    print(f'Training {loss_type.upper()} model...')
    losses = train_model(model, train_loader, loss_type)
    torch.save(model.state_dict(), weights_path)
    torch.save(model, model_path)
    return model, losses

# Checkpoint locations. These resolve to ../models (one level above the working
# directory), while the figures below are written to ./models.
mse_model_path = '../models/model_mse.pth'
mse_weights_path = '../models/model_mse_weights.pth'
perceptual_model_path = '../models/model_perceptual.pth'
perceptual_weights_path = '../models/model_perceptual_weights.pth'


def _report_metrics(label, metrics):
    """Print the scalar entries of an evaluate_model result tuple."""
    mse, perceptual, psnr_value, ssim_value, _, precision, recall, f1, accuracy = metrics
    print(f'{label} Model - MSE Loss: {mse:.4f}, Perceptual Loss: {perceptual:.4f}, PSNR: {psnr_value:.4f}, SSIM: {ssim_value:.4f}')
    print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')


def _show_confusion_matrix(matrix, title, out_path):
    """Draw ``matrix`` as a heat map, save it to ``out_path`` and display it."""
    plt.figure(figsize=(8, 6))
    plt.imshow(matrix, cmap='Blues')
    plt.title(title)
    plt.colorbar()
    plt.savefig(out_path)
    plt.show()


# --- Model optimised with MSE: load or train, evaluate, report, visualise ---
model_mse = UNetColorization()
model_mse, mse_losses = load_or_train_model(
    model_mse, mse_model_path, mse_weights_path, train_loader, 'mse'
)
mse_metrics = evaluate_model(model_mse, test_loader, 'mse')
_report_metrics('MSE', mse_metrics)
visualize_samples(model_mse, test_loader, model_name='mse')

# --- Model optimised with the perceptual loss: same pipeline ---
model_perceptual = UNetColorization()
model_perceptual, perceptual_losses = load_or_train_model(
    model_perceptual, perceptual_model_path, perceptual_weights_path, train_loader, 'perceptual'
)
perceptual_metrics = evaluate_model(model_perceptual, test_loader, 'perceptual')
_report_metrics('Perceptual', perceptual_metrics)
visualize_samples(model_perceptual, test_loader, model_name='perceptual')

# Training curves exist only for models trained in this run (loaded
# checkpoints come back with an empty loss list).
if mse_losses or perceptual_losses:
    plt.figure(figsize=(10, 5))
    for _curve, _curve_label in ((mse_losses, 'MSE Loss'), (perceptual_losses, 'Perceptual Loss')):
        if _curve:
            plt.plot(_curve, label=_curve_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Comparison')
    plt.legend()
    plt.savefig('models/loss_comparison.png')
    plt.show()

# Confusion matrices (element 4 of each metrics tuple), one figure per model.
_show_confusion_matrix(mse_metrics[4], 'Confusion Matrix - MSE Model', 'models/cm_mse.png')
_show_confusion_matrix(perceptual_metrics[4], 'Confusion Matrix - Perceptual Model', 'models/cm_perceptual.png')

# Closing summary printed with the results.
for _line in (
    'Comparison of Loss Functions:',
    'MSE Loss: Optimizes pixel-wise accuracy, leading to smoother outputs but potentially less vibrant colors.',
    'Perceptual Loss: Emphasizes high-level features, improving visual quality but possibly sacrificing pixel-wise accuracy.',
    'U-Net with batch normalization and dropout enhances detail preservation. PSNR/SSIM are key for quality assessment.',
):
    print(_line)


  3%|▎         | 5.57M/170M [00:04<02:13, 1.24MB/s]


KeyboardInterrupt: 