In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from sklearn.preprocessing import LabelEncoder
import numpy as np
from skimage.color import rgb2lab

In [None]:
# Pick the compute device once for the whole notebook: the CUDA GPU when
# PyTorch reports one as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)


In [None]:
import numpy as np
import torch
from skimage.color import rgb2lab

# Function to scale images to [0, 255]
def scale_to_255(x):
    """Return ``x`` multiplied by 255.

    Defined as a named module-level function (rather than an inline lambda)
    so it can be handed to ``transforms.Lambda`` in the dataset pipeline.
    """
    factor = 255
    return x * factor

def rgb_to_ab(images):
    """Convert a batch of RGB images to normalized CIE-Lab ``a`` and ``b`` planes.

    Args:
        images: RGB batch in channels-first layout, i.e. a tensor of shape
            (N, 3, H, W), or an iterable of equally sized (3, H, W) tensors.
            The training loop in this notebook passes float values in [0, 1].

    Returns:
        A tuple ``(a, b)`` of float32 tensors, each of shape (N, H, W), placed
        on the notebook-level ``device``. Each plane is the corresponding Lab
        channel mapped through ``(value + 128) / 255``.
    """
    # Accept an iterable of per-image tensors as well as a ready-made batch.
    if not torch.is_tensor(images):
        images = torch.stack(list(images))

    # rgb2lab works on NumPy arrays with the channel axis last, so detach from
    # autograd (``.numpy()`` refuses tensors that require grad), reorder
    # NCHW -> NHWC and copy to host memory. The whole batch is converted in a
    # single call instead of one call per image.
    batch_hwc = images.detach().permute(0, 2, 3, 1).cpu().numpy()
    lab_batch = rgb2lab(batch_hwc)

    # Only the two chroma planes are needed; the lightness plane is not used
    # by any caller, so it is no longer computed.
    a_plane = (lab_batch[..., 1] + 128) / 255.0
    b_plane = (lab_batch[..., 2] + 128) / 255.0

    return (
        torch.tensor(a_plane, dtype=torch.float32).to(device),
        torch.tensor(b_plane, dtype=torch.float32).to(device),
    )

In [None]:
# CIFAR10 Dataset Loader
def load_cifar10_dataset(batch_size=8, num_workers=2, root='./data/cifar10', image_size=(224, 224)):
    """Build training and test ``DataLoader`` objects for CIFAR-10.

    Every image is resized to ``image_size``, converted to a tensor by
    ``ToTensor`` and then multiplied by 255 via ``scale_to_255``, so batches
    carry pixel values on a 0-255 scale. (The training loop in this notebook
    divides by 255 again before using them.)

    Args:
        batch_size: Number of images per batch for both loaders.
        num_workers: Worker processes used by each loader.
        root: Directory the dataset is downloaded to / read from.
        image_size: ``(height, width)`` every image is resized to.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles its
        samples; the test loader keeps dataset order.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Lambda(scale_to_255),
    ])

    # download=True fetches the archive on first use and reuses it afterwards.
    train_set = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class ColorizationModel(nn.Module):
    """Dual-encoder network that maps an image batch to a 2-channel prediction.

    Two ImageNet-pretrained backbones (ResNet50 and DenseNet121) encode the
    same input. Their feature maps are projected to 1024 channels each,
    averaged element-wise, passed through four 1x1 "fusion" blocks and then
    through four decoder blocks that each double the spatial size. A final
    3x3 convolution followed by ``Tanh`` produces two output channels.

    The spatial size of the output is 16x the size of the fused feature map
    (four 2x upsamplings). The DenseNet features are pooled to 7x7, so the
    ResNet features must also be 7x7 for the average to work; with the
    224x224 inputs used in this notebook that yields a 112x112 output.
    """

    @staticmethod
    def _fusion_block(in_channels, out_channels=256):
        """Return a 1x1 conv -> batch norm -> ReLU block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    @staticmethod
    def _decoder_block(channels=256):
        """Return a 3x3 conv -> batch norm -> ReLU -> 2x upsample block."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2)
        )

    def __init__(self):
        super().__init__()

        # Pre-trained ResNet50 with its last two child modules dropped, so it
        # returns a spatial feature map instead of class scores.
        self.encoder_resnet = nn.Sequential(
            *list(models.resnet50(weights='IMAGENET1K_V1').children())[:-2]
        )

        # Pre-trained DenseNet121 with its last child module (the classifier)
        # dropped.
        self.encoder_densenet = nn.Sequential(
            *list(models.densenet121(weights='IMAGENET1K_V1').children())[:-1]
        )

        # Pool the DenseNet feature map to a fixed 7x7 grid.
        self.downsample_densenet = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 projections so both branches carry 1024 channels before fusion:
        # ResNet goes 2048 -> 1024, DenseNet stays at 1024.
        self.resnet_conv1x1 = nn.Conv2d(2048, 1024, kernel_size=1)
        self.densenet_conv1x1 = nn.Conv2d(1024, 1024, kernel_size=1)

        # Fusion blocks: the first reduces the averaged 1024-channel map to
        # 256 channels, the remaining three keep 256 channels.
        self.fusion_block1 = self._fusion_block(1024)
        self.fusion_block2 = self._fusion_block(256)
        self.fusion_block3 = self._fusion_block(256)
        self.fusion_block4 = self._fusion_block(256)

        # Decoder blocks: each keeps 256 channels and doubles height/width.
        self.decoder_block1 = self._decoder_block()
        self.decoder_block2 = self._decoder_block()
        self.decoder_block3 = self._decoder_block()
        self.decoder_block4 = self._decoder_block()

        # Output head: 256 -> 2 channels, squashed into [-1, 1] by Tanh.
        # NOTE(review): rgb_to_ab in this notebook scales its targets with
        # (value + 128) / 255, i.e. into [0, 1], so only half of the Tanh
        # range is used during training - confirm this is intended.
        self.decoder_block5 = nn.Sequential(
            nn.Conv2d(256, 2, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Run both encoders on ``x``, fuse their features and decode.

        Args:
            x: Input batch; it is passed unchanged to both encoders.

        Returns:
            Tensor with two channels and values in [-1, 1].
        """
        # Encoders
        x_resnet = self.encoder_resnet(x)
        x_densenet = self.encoder_densenet(x)
        x_densenet = self.downsample_densenet(x_densenet)

        # Project both branches to the same channel count
        x_resnet = self.resnet_conv1x1(x_resnet)
        x_densenet = self.densenet_conv1x1(x_densenet)

        # Fuse the two branches by element-wise averaging
        fused = (x_resnet + x_densenet) / 2

        # The later fusion blocks have a single input, so each simply consumes
        # the previous block's output (averaging a tensor with itself would be
        # a no-op).
        fused = self.fusion_block1(fused)
        fused = self.fusion_block2(fused)
        fused = self.fusion_block3(fused)
        fused = self.fusion_block4(fused)

        # Decoder: four 2x upsampling stages followed by the output head
        decoded = self.decoder_block1(fused)
        decoded = self.decoder_block2(decoded)
        decoded = self.decoder_block3(decoded)
        decoded = self.decoder_block4(decoded)

        return self.decoder_block5(decoded)


In [None]:
import torch.nn.functional as F

def train(model, train_loader, criterion, optimizer, num_epochs, log_every=500):
    """Train ``model`` to predict the Lab ``a``/``b`` planes of each image.

    Args:
        model: Network called as ``model(images)``; its output is compared
            against a 2-channel target.
        train_loader: Iterable of batches whose first element is an RGB image
            tensor with values on a 0-255 scale.
        criterion: Loss called as ``criterion(outputs, target)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.
        log_every: Print the current loss every ``log_every`` batches
            (batch 0 included). ``0`` or ``None`` disables logging.

    NOTE(review): the model receives the same RGB images the targets are
    derived from, so it is not forced to infer color from lightness alone.
    Confirm whether a grayscale / L-channel input was intended.
    """
    model.train()  # Enable training behavior (batch-norm statistics updates, etc.)
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(train_loader):
            # Bring the 0-255 pixel values back to [0, 1].
            images = batch[0] / 255.0

            # Build the target before uploading the batch: rgb_to_ab converts
            # through NumPy on the CPU, so calling it first avoids copying the
            # images to the GPU and straight back.
            a_channel, b_channel = rgb_to_ab(images)
            target = torch.stack((a_channel, b_channel), dim=1)  # (N, 2, H, W)

            images = images.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Resize the target to whatever spatial size the model produced
            # instead of assuming a fixed 112x112 output.
            target_resized = F.interpolate(
                target, size=tuple(outputs.shape[-2:]), mode='bilinear', align_corners=False
            )

            loss = criterion(outputs, target_resized)
            loss.backward()
            optimizer.step()

            # Periodic progress report
            if log_every and batch_idx % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')  # Print loss for monitoring

In [None]:
# Example usage
if __name__ == "__main__":
    # Hyperparameters for this run, gathered in one place.
    batch_size = 8
    learning_rate = 0.01
    num_epochs = 2

    # Data: CIFAR-10 train/test loaders
    train_loader, test_loader = load_cifar10_dataset(batch_size=batch_size)

    # Network, placed on the selected device (GPU when available)
    model = ColorizationModel()
    model = model.to(device)

    # Objective (mean squared error) and Adam optimizer over all parameters
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Run training
    train(model, train_loader, criterion, optimizer, num_epochs=num_epochs)

Files already downloaded and verified
Files already downloaded and verified
Epoch [1/2], Batch [0/6250], Loss: 0.3217
Epoch [1/2], Batch [500/6250], Loss: 0.0064
Epoch [1/2], Batch [1000/6250], Loss: 0.0045
Epoch [1/2], Batch [1500/6250], Loss: 0.0063
Epoch [1/2], Batch [2000/6250], Loss: 0.0016
Epoch [1/2], Batch [2500/6250], Loss: 0.0012
Epoch [1/2], Batch [3000/6250], Loss: 0.0022
Epoch [1/2], Batch [3500/6250], Loss: 0.0025
Epoch [1/2], Batch [4000/6250], Loss: 0.0024
Epoch [1/2], Batch [4500/6250], Loss: 0.0004
Epoch [1/2], Batch [5000/6250], Loss: 0.0016
Epoch [1/2], Batch [5500/6250], Loss: 0.0006
Epoch [1/2], Batch [6000/6250], Loss: 0.0022
Epoch [2/2], Batch [0/6250], Loss: 0.0010
Epoch [2/2], Batch [500/6250], Loss: 0.0006
Epoch [2/2], Batch [1000/6250], Loss: 0.0005
Epoch [2/2], Batch [1500/6250], Loss: 0.0020
Epoch [2/2], Batch [2000/6250], Loss: 0.0010
Epoch [2/2], Batch [2500/6250], Loss: 0.0026
Epoch [2/2], Batch [3000/6250], Loss: 0.0011
Epoch [2/2], Batch [3500/6250], 

In [None]:
import torch

# Persist the trained weights. Only the state_dict (parameters and buffers)
# is written, not the pickled module object.
model_save_path = "colorization_model_avg.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_save_path)

print(f"Model saved to {model_save_path}")

Model saved to colorization_model_avg.pth
