In [None]:
!pip install lpips
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
import torchvision.datasets as datasets
from torchvision import datasets
from torch.utils.data import DataLoader
from skimage import color, io
import numpy as np
import torch.nn.functional as F
import os
from glob import glob
from lpips import LPIPS  # Learned Perceptual Image Patch Similarity
from skimage.metrics import peak_signal_noise_ratio, structural_similarity



In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# Define a single Fusion Block
class FusionBlock(nn.Module):
    """Merge a ResNet and a DenseNet feature map of equal spatial size.

    The two inputs are stacked along the channel axis and projected to
    `out_channels` with a 1x1 convolution, followed by BatchNorm and ReLU.

    Args:
        res_channels: channel count of the ResNet feature map.
        dense_channels: channel count of the DenseNet feature map.
        out_channels: channel count of the fused output.
    """

    def __init__(self, res_channels, dense_channels, out_channels):
        super().__init__()
        stacked_channels = res_channels + dense_channels
        self.conv = nn.Conv2d(stacked_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, resnet_features, densenet_features):
        """Return relu(bn(conv1x1(cat(resnet_features, densenet_features))))."""
        stacked = torch.cat((resnet_features, densenet_features), dim=1)
        projected = self.bn(self.conv(stacked))
        return self.relu(projected)

# Define Encoder Model using ResNet50 and DenseNet121
class Encoder(nn.Module):
    """Dual ResNet50 / DenseNet121 encoder returning four stride-aligned feature pairs.

    Both backbones are cut at their natural stage boundaries, so that
    ``res_features[i]`` and ``dense_features[i]`` have the same spatial size
    (1/4, 1/8, 1/16 and 1/32 of the input for i = 0..3). Each pair can
    therefore be concatenated directly by a FusionBlock. The per-stage
    channel counts are given in RESNET_CHANNELS / DENSENET_CHANNELS.

    Args:
        pretrained: load ImageNet (IMAGENET1K_V1) weights for both backbones.
    """

    # Output channels of each stage, index-aligned with the lists from forward().
    RESNET_CHANNELS = (256, 512, 1024, 2048)
    DENSENET_CHANNELS = (256, 512, 1024, 1024)

    def __init__(self, pretrained=True):
        super(Encoder, self).__init__()
        # `weights=` replaces torchvision's deprecated `pretrained=` flag;
        # IMAGENET1K_V1 is what `pretrained=True` used to load.
        self.resnet = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
        self.densenet = models.densenet121(
            weights=models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None)
        # The classification heads are never used; drop their parameters.
        self.resnet.fc = nn.Identity()
        self.densenet.classifier = nn.Identity()

        r = self.resnet
        self.resnet_stages = nn.ModuleList([
            nn.Sequential(r.conv1, r.bn1, r.relu, r.maxpool, r.layer1),  # 1/4
            r.layer2,                                                     # 1/8
            r.layer3,                                                     # 1/16
            r.layer4,                                                     # 1/32
        ])
        d = self.densenet.features
        self.densenet_stages = nn.ModuleList([
            nn.Sequential(d.conv0, d.norm0, d.relu0, d.pool0, d.denseblock1),  # 1/4
            nn.Sequential(d.transition1, d.denseblock2),                        # 1/8
            nn.Sequential(d.transition2, d.denseblock3),                        # 1/16
            # DenseNet's own forward() applies a ReLU after norm5; mirror that.
            nn.Sequential(d.transition3, d.denseblock4, d.norm5, nn.ReLU()),   # 1/32
        ])

    def forward(self, x):
        """Return ``(res_features, dense_features)``: two lists of 4 stage outputs.

        Each stage consumes the previous stage's output, so every backbone
        layer runs exactly once per call.
        """
        res_features, dense_features = [], []
        res_x, dense_x = x, x
        for res_stage, dense_stage in zip(self.resnet_stages, self.densenet_stages):
            res_x = res_stage(res_x)
            dense_x = dense_stage(dense_x)
            res_features.append(res_x)
            dense_features.append(dense_x)
        return res_features, dense_features

# Define Decoder Block with variable input channels
class DecoderBlock(nn.Module):
    """Upsample x2 (bilinear), then apply a 3x3 conv + BatchNorm + ReLU."""

    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = self.upsample(x)
        x = self.relu(self.bn(self.conv(x)))
        return x

# Define Full Encoder-Decoder Model with Fusion Blocks
class ColorizationModel(nn.Module):
    """Predict the ab chroma channels of an image from its lightness.

    A dual-backbone encoder produces four stride-aligned feature pairs. Each
    pair is fused, and a U-Net-style decoder upsamples from the deepest
    fusion, concatenating the fusion output at each shallower resolution.

    forward(x):
        x: (B, 1, H, W) L channel, or (B, 3, H, W). A single channel is
           repeated to 3 because the ImageNet backbones expect 3-channel input.
        returns: (B, 2, H, W) ab prediction at the input resolution.

    Args:
        pretrained: forwarded to Encoder (ImageNet weights for both backbones).
    """

    def __init__(self, pretrained=True):
        super(ColorizationModel, self).__init__()
        self.encoder = Encoder(pretrained=pretrained)
        res_ch, dense_ch = Encoder.RESNET_CHANNELS, Encoder.DENSENET_CHANNELS

        # One fusion block per encoder stage; both inputs share a resolution.
        self.fusion1 = FusionBlock(res_ch[0], dense_ch[0], 128)   # 1/4
        self.fusion2 = FusionBlock(res_ch[1], dense_ch[1], 256)   # 1/8
        self.fusion3 = FusionBlock(res_ch[2], dense_ch[2], 512)   # 1/16
        self.fusion4 = FusionBlock(res_ch[3], dense_ch[3], 1024)  # 1/32 (bottleneck)

        # Each decoder block upsamples x2. Its input is the previous decoder
        # output concatenated with the fusion output at that resolution.
        self.decoder1 = DecoderBlock(1024, 512)       # fusion4:   1/32 -> 1/16
        self.decoder2 = DecoderBlock(512 + 512, 256)  # + fusion3: 1/16 -> 1/8
        self.decoder3 = DecoderBlock(256 + 256, 128)  # + fusion2: 1/8  -> 1/4
        self.decoder4 = DecoderBlock(128 + 128, 64)   # + fusion1: 1/4  -> 1/2
        self.final_conv = nn.Conv2d(64, 2, kernel_size=3, padding=1)  # ab output

    @staticmethod
    def _resize_like(x, reference):
        """Bilinearly resize `x` to `reference`'s H, W when they differ.

        They can differ when the input size is not a multiple of 32.
        """
        if x.shape[-2:] != reference.shape[-2:]:
            x = F.interpolate(x, size=reference.shape[-2:], mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        input_size = x.shape[-2:]
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # NOTE(review): the pretrained backbones were trained on ImageNet-
        # normalized RGB; inputs are passed through unnormalized here.

        res_features, dense_features = self.encoder(x)

        # Fuse same-resolution ResNet / DenseNet features at every stage.
        fusion1_output = self.fusion1(res_features[0], dense_features[0])
        fusion2_output = self.fusion2(res_features[1], dense_features[1])
        fusion3_output = self.fusion3(res_features[2], dense_features[2])
        fusion4_output = self.fusion4(res_features[3], dense_features[3])

        # Decode with skip connections, deepest first.
        x = self.decoder1(fusion4_output)
        x = self.decoder2(torch.cat([self._resize_like(x, fusion3_output), fusion3_output], dim=1))
        x = self.decoder3(torch.cat([self._resize_like(x, fusion2_output), fusion2_output], dim=1))
        x = self.decoder4(torch.cat([self._resize_like(x, fusion1_output), fusion1_output], dim=1))

        # decoder4 ends at 1/2 resolution; restore the input size before predicting ab.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return self.final_conv(x)

# Instantiate the model and print a compact summary. The full module tree of
# the two ImageNet backbones is a very long repr that buries the notebook.
model = ColorizationModel()
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{type(model).__name__}: {num_params:,} parameters ({num_trainable:,} trainable)")
print("Top-level modules:", ", ".join(name for name, _ in model.named_children()))




ColorizationModel(
  (encoder): Encoder(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU

In [None]:
# Function to convert image to range [0, 255]
def scale_to_255(x):
    """Return `x` multiplied by 255, e.g. to map a [0, 1] image tensor onto [0, 255]."""
    factor = 255
    return x * factor

In [None]:
# Preprocessing function to convert RGB to Lab
def rgb2lab(image):
    """Convert an RGB image (H, W, 3) to normalized Lab components.

    Returns:
        (L, ab): L is the Lab lightness divided by 100 (shape (H, W));
        ab is the a/b channels divided by 128 (shape (H, W, 2)).
    """
    lab = color.rgb2lab(image)
    lightness = lab[:, :, 0] / 100.0
    chroma = lab[:, :, 1:] / 128.0
    return lightness, chroma

# Convert Lab back to RGB
def lab2rgb(L, ab):
    """Inverse of ``rgb2lab``: rebuild an RGB image from normalized L and ab.

    Args:
        L: (H, W) lightness scaled by 1/100.
        ab: (H, W, 2) a/b channels scaled by 1/128.

    Returns:
        The RGB image produced by ``skimage.color.lab2rgb`` on the rescaled
        float64 Lab array.
    """
    height, width = L.shape[0], L.shape[1]
    # Every channel is overwritten below, so the buffer needs no initialization.
    lab = np.empty((height, width, 3))
    lab[..., 0] = L * 100
    lab[..., 1:] = ab * 128
    return color.lab2rgb(lab)

# Preprocess the input image for training
def preprocess_image(image, size=224):
    """
    Preprocess input image: resize and convert to Lab color space.

    Args:
        image: RGB image as a PIL image or an (H, W, 3) array.
            - Integer arrays are assumed to span their dtype's full range
              (e.g. uint8 0-255).
            - Float arrays are assumed to be in [0, 1].
            - A 2-D (grayscale) array is replicated to 3 channels, and any
              alpha channel is dropped.
        size: target side length. The image is bilinearly resized to
            (size, size) when it is not already that size. Pass None to keep
            the original resolution.

    Returns:
        (L, ab) float32 tensors of shape (1, S, S) and (2, S, S), normalized
        as by ``rgb2lab``.
    """
    rgb = np.asarray(image)
    if rgb.ndim == 2:
        rgb = np.stack([rgb] * 3, axis=-1)
    rgb = rgb[..., :3]

    # Convert to float in [0, 1], mirroring skimage's uint8 -> float convention.
    if np.issubdtype(rgb.dtype, np.integer):
        rgb = rgb.astype(np.float64) / np.iinfo(rgb.dtype).max
    else:
        rgb = rgb.astype(np.float64)

    if size is not None and rgb.shape[:2] != (size, size):
        tensor = torch.from_numpy(np.ascontiguousarray(rgb)).permute(2, 0, 1).unsqueeze(0)
        tensor = F.interpolate(tensor, size=(size, size), mode='bilinear', align_corners=False)
        # Bilinear interpolation stays in range, but clamp guards float rounding.
        rgb = tensor.clamp_(0.0, 1.0).squeeze(0).permute(1, 2, 0).numpy()

    L, ab = rgb2lab(rgb)
    return torch.from_numpy(L).unsqueeze(0).float(), torch.from_numpy(ab).permute(2, 0, 1).float()  # Return L and ab channels as tensors

class LabColorizationDataset(torch.utils.data.Dataset):
    """Wrap an ``(image, label)`` dataset so each item is an ``(L, ab)`` pair.

    The label is discarded, and the image is converted by ``preprocess_image``
    at side length `size`. Defined at module level so DataLoader workers can
    pickle it.
    """

    def __init__(self, base_dataset, size=224):
        self.base_dataset = base_dataset
        self.size = size

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        image, _ = self.base_dataset[index]
        return preprocess_image(image, size=self.size)


# CIFAR10 Dataset Loader
def load_cifar10_dataset(batch_size=8, num_workers=2, image_size=224, root='./data/cifar10'):
    """Build CIFAR-10 train/test loaders that yield ``(L, ab)`` colorization pairs.

    Each batch is ``(L, ab)``:
        - L: float32 tensor of shape (B, 1, S, S).
        - ab: float32 tensor of shape (B, 2, S, S).
    Both are normalized as by ``rgb2lab``, with S = image_size. Class labels
    are dropped. The dataset is downloaded to `root` when not already present.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    # Resize the PIL image only; no ToTensor/x255 scaling here, because
    # preprocess_image needs the RGB image in its natural range for rgb2lab.
    transform = transforms.Resize((image_size, image_size))  # match ResNet/DenseNet input size

    train_set = LabColorizationDataset(
        datasets.CIFAR10(root=root, train=True, download=True, transform=transform),
        size=image_size)
    test_set = LabColorizationDataset(
        datasets.CIFAR10(root=root, train=False, download=True, transform=transform),
        size=image_size)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader




In [None]:
# PSNR Calculation
def calculate_psnr(original, generated):
    """Peak signal-to-noise ratio (dB) of `generated` against reference `original`.

    A thin wrapper around skimage's ``peak_signal_noise_ratio``; ``data_range``
    is left at skimage's default.
    """
    return peak_signal_noise_ratio(original, generated)

# SSIM Calculation
def calculate_ssim(original, generated, data_range=None):
    """Structural similarity between two channel-last (H, W, C) images.

    Args:
        original, generated: images with the channel axis last.
        data_range: value range of the images. When None and the images are
            floating point, they are assumed to be in [0, 1] (as produced by
            ``lab2rgb``). Otherwise skimage's default is used.

    Notes:
        - ``channel_axis=-1`` replaces skimage's deprecated
          ``multichannel=True``.
        - Newer scikit-image releases reject floating-point input without an
          explicit ``data_range``.
    """
    if data_range is None and np.issubdtype(np.asarray(original).dtype, np.floating):
        data_range = 1.0
    return structural_similarity(original, generated, channel_axis=-1, data_range=data_range)

# LPIPS (Learned Perceptual Image Patch Similarity)
def calculate_lpips(original, generated, lpips_model):
    """LPIPS perceptual distance between two RGB images (lower = more similar).

    Args:
        original, generated: (H, W, 3) float arrays with values in [0, 1]
            (as returned by ``lab2rgb``).
        lpips_model: an LPIPS module. Inputs are moved to the device of its
            parameters (CPU if it has none).

    Returns:
        The distance as a Python float.
    """
    def _to_lpips_input(image):
        # (H, W, 3) -> (1, 3, H, W), rescaled from [0, 1] to the [-1, 1]
        # range LPIPS expects.
        tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).unsqueeze(0).float()
        return tensor * 2.0 - 1.0

    try:
        device = next(lpips_model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    with torch.no_grad():
        lpips_value = lpips_model(_to_lpips_input(original).to(device),
                                  _to_lpips_input(generated).to(device))
    return lpips_value.item()

# MAE Calculation
def calculate_mae(true_ab, predicted_ab):
    """Mean absolute error between two equally shaped NumPy arrays, as a float."""
    target = torch.from_numpy(true_ab)
    prediction = torch.from_numpy(predicted_ab)
    return torch.nn.functional.l1_loss(target, prediction).item()

# LPIPS Model for Perceptual Similarity
LPIPS_BACKBONE = 'vgg'  # feature trunk used by the LPIPS perceptual metric
lpips_model = LPIPS(net=LPIPS_BACKBONE)


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [None]:
# Define Loss Function
def colorization_loss(predicted_ab, true_ab):
    """Mean squared error between predicted and ground-truth ab tensors.

    Returns a scalar tensor (mean reduction over all elements).
    """
    return F.mse_loss(predicted_ab, true_ab)

# Training Function
def train_model(model, dataloader, optimizer, num_epochs=25, device=None):
    """Train `model` to regress ab channels from L with a mean-squared-error loss.

    Args:
        model: module mapping an L batch to an ab prediction shaped like the targets.
        dataloader: iterable of ``(L, ab)`` tensor batches.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, rather than relying on a global.

    Returns:
        List with the mean training loss of each epoch (also printed).
        An epoch with no batches reports 0.0.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            optimizer.zero_grad()
            predicted_ab = model(L)
            loss = torch.nn.functional.mse_loss(predicted_ab, ab)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # max() guards against an empty dataloader (would divide by zero).
        avg_loss = running_loss / max(num_batches, 1)
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss}")
    return epoch_losses



In [None]:
def test_and_evaluate(model, dataloader, lpips_model):
    """
    Test the model and evaluate using PSNR, SSIM, LPIPS, and MAE.

    Every image of every batch is scored, not only the first one per batch.
    Predicted and true ab are recombined with L and converted to RGB for
    PSNR/SSIM/LPIPS. MAE is measured on the normalized ab channels.

    Args:
        model: colorization network mapping an L batch to an ab batch.
        dataloader: iterable of ``(L, ab_true)`` batches, shaped
            (B, 1, H, W) and (B, 2, H, W) and normalized as by ``rgb2lab``.
        lpips_model: LPIPS module forwarded to ``calculate_lpips``.

    Returns:
        dict with the average 'psnr', 'ssim', 'lpips' and 'mae' (also printed).

    Raises:
        ValueError: if the dataloader yields no images.
    """
    model.eval()  # Set the model to evaluation mode
    # Send inputs to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    psnr_scores, ssim_scores, lpips_scores, mae_scores = [], [], [], []

    with torch.no_grad():  # Disable gradient calculations for testing
        for L, ab_true in dataloader:
            predicted_ab = model(L.to(device)).cpu().numpy()  # Predict ab channels
            ab_true = ab_true.cpu().numpy()
            L = L.cpu().numpy()

            for idx in range(L.shape[0]):
                # lab2rgb takes channel-last ab, so (2, H, W) -> (H, W, 2).
                true_rgb = lab2rgb(L[idx, 0], ab_true[idx].transpose(1, 2, 0))
                generated_rgb = lab2rgb(L[idx, 0], predicted_ab[idx].transpose(1, 2, 0))

                psnr_scores.append(calculate_psnr(true_rgb, generated_rgb))
                ssim_scores.append(calculate_ssim(true_rgb, generated_rgb))
                lpips_scores.append(calculate_lpips(true_rgb, generated_rgb, lpips_model))
                mae_scores.append(calculate_mae(ab_true[idx], predicted_ab[idx]))

    if not psnr_scores:
        raise ValueError("test_and_evaluate: dataloader produced no images")

    # Compute average scores across all test images
    averages = {
        'psnr': sum(psnr_scores) / len(psnr_scores),
        'ssim': sum(ssim_scores) / len(ssim_scores),
        'lpips': sum(lpips_scores) / len(lpips_scores),
        'mae': sum(mae_scores) / len(mae_scores),
    }

    print(f'Average PSNR: {averages["psnr"]:.4f}')
    print(f'Average SSIM: {averages["ssim"]:.4f}')
    print(f'Average LPIPS: {averages["lpips"]:.4f}')
    print(f'Average MAE: {averages["mae"]:.4f}')
    return averages


In [None]:
# Set device
# Use CUDA when a GPU is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load dataset
BATCH_SIZE = 8
NUM_WORKERS = 2
train_loader, test_loader = load_cifar10_dataset(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
from torchvision import models

# Build the colorization network on the selected device and attach Adam.
LEARNING_RATE = 1e-4
model = ColorizationModel()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the model
NUM_EPOCHS = 25
epoch_losses = train_model(model, train_loader, optimizer, num_epochs=NUM_EPOCHS)

RuntimeError: Given groups=1, weight of size [512, 768, 1, 1], expected input[8, 128, 112, 112] to have 768 channels, but got 128 channels instead