In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
from skimage.color import rgb2lab
from skimage.transform import resize

# Custom Dataset for CIFAR10
class CIFAR10Colorization(Dataset):
    """Wrap an image dataset so each item becomes a (luminance, chrominance) pair.

    Every image is resized to a square of ``input_size`` pixels, converted from
    RGB to CIE Lab, and split into the network input (L) and the regression
    target (ab).

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.
            The image is passed to ``transforms.Resize`` and then ``np.array``;
            presumably a PIL RGB image as produced by torchvision's CIFAR10
            without a transform -- verify against the caller.
        input_size: Side length (in pixels) of the square the image is resized
            to. Defaults to 224, the target size for the pre-trained models.
    """

    def __init__(self, dataset, input_size=224):
        self.dataset = dataset
        self.input_size = input_size  # Target size for pre-trained models

        # Define the transformation to resize images
        self.resize_transform = transforms.Resize((self.input_size, self.input_size))
        # Not used by __getitem__ (tensors are built from NumPy arrays there);
        # kept so code that reads this attribute keeps working.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of images in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(L, ab)`` for the image at ``idx``.

        Returns:
            L: float32 tensor of shape (1, H, W); Lab lightness divided by 100.
            ab: float32 tensor of shape (2, H, W); Lab a/b channels mapped
                through ``(ab + 128) / 255 * 2 - 1``.
        """
        # The class label is irrelevant for colorization and is discarded
        img, _ = self.dataset[idx]

        # Resize first, then hand an HxWx3 array to the color conversion
        img_resized = np.array(self.resize_transform(img))

        # RGB -> CIE Lab, stored as float32 to match PyTorch's default dtype
        lab_img = rgb2lab(img_resized).astype(np.float32)

        # Lightness (nominally 0..100) scaled to [0, 1]
        L = lab_img[:, :, 0] / 100.0

        # Chrominance shifted/scaled from a nominal [-128, 127] to [-1, 1]
        ab = (lab_img[:, :, 1:] + 128) / 255.0 * 2.0 - 1.0

        # Channel-first layout expected by PyTorch
        L = torch.from_numpy(L).unsqueeze(0)  # HxW -> 1xHxW
        ab = torch.from_numpy(ab).permute(2, 0, 1)  # HxWx2 -> 2xHxW

        return L, ab

# Raw CIFAR10 training split (downloaded into ./data when not already present)
train_dataset = datasets.CIFAR10(root='./data', download=True, train=True)

# Wrap it so every item is served as an (L, ab) pair
colorization_dataset = CIFAR10Colorization(train_dataset)

# Shuffled mini-batches of 32 samples
data_loader = DataLoader(colorization_dataset, shuffle=True, batch_size=32)

# Sanity check: pull a single batch and report the tensor shapes
L, ab = next(iter(data_loader))
print(f"Luminance (L) shape: {L.shape}, Chrominance (ab) shape: {ab.shape}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48927347.51it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Luminance (L) shape: torch.Size([32, 1, 224, 224]), Chrominance (ab) shape: torch.Size([32, 2, 224, 224])


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# Define the encoder model with max fusion and 1x1 convolutions
class EnsembleEncoder(nn.Module):
    """Two-backbone encoder that fuses ResNet50 and DenseNet121 features.

    Both pre-trained backbones receive the same 3-channel replica of the
    grayscale input. Feature maps are tapped at four depths of each backbone,
    projected to a common channel count with 1x1 convolutions, and merged
    level by level with an element-wise maximum.

    For a 224x224 input the returned list holds, in order:
    128x56x56, 256x28x28, 512x14x14 and 1024x7x7 feature maps.
    """

    def __init__(self):
        super(EnsembleEncoder, self).__init__()

        # Load pre-trained ResNet50 and DenseNet121
        resnet = models.resnet50(pretrained=True)
        densenet = models.densenet121(pretrained=True)

        # ResNet50 children: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc.
        # Dropping the last two leaves the convolutional trunk (indices 0..7).
        self.resnet50 = nn.Sequential(*list(resnet.children())[:-2])

        # DenseNet121 has only two children (features, classifier). Keep the
        # `features` Sequential itself so its sub-layers can be iterated:
        # conv0, norm0, relu0, pool0, denseblock1, transition1, ..., denseblock4, norm5.
        self.densenet121 = densenet.features

        # 1x1 convolution layers to match the number of channels before fusion
        self.conv_1x1_resnet = nn.ModuleList([
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1),
            nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1)
        ])

        self.conv_1x1_densenet = nn.ModuleList([
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1)
        ])

        # Indices (into the Sequentials above) whose outputs are tapped.
        # ResNet50: layer1..layer4 -> 256/512/1024/2048 channels at 56/28/14/7 px.
        self.resnet_layers = [4, 5, 6, 7]
        # DenseNet121: denseblock1..denseblock4 -> 256/512/1024/1024 channels at 56/28/14/7 px.
        self.densenet_layers = [4, 6, 8, 10]

    @staticmethod
    def _extract_features(backbone, x, tap_indices):
        """Run ``x`` through ``backbone`` layer by layer and collect the outputs at ``tap_indices``."""
        wanted = set(tap_indices)
        features = []
        for i, layer in enumerate(backbone):
            x = layer(x)
            if i in wanted:
                features.append(x)
        return features

    def forward(self, x):
        """Encode a grayscale batch into four fused multi-scale feature maps.

        Args:
            x: Tensor of shape (batch, 1, H, W).

        Returns:
            List of four tensors ordered from finest to coarsest resolution.
        """
        # Replicate grayscale image to 3 channels
        x = x.repeat(1, 3, 1, 1)

        # Each backbone must see the *input image*; feeding one backbone's
        # output into the other would hand DenseNet a 2048-channel 7x7 map.
        resnet_features = self._extract_features(self.resnet50, x, self.resnet_layers)
        densenet_features = self._extract_features(self.densenet121, x, self.densenet_layers)

        # Apply 1x1 convolutions and fuse the features with an element-wise max
        fused_features = []
        for i, (resnet_f, densenet_f) in enumerate(zip(resnet_features, densenet_features)):
            # Apply 1x1 convolution to match the channel dimensions
            resnet_f_conv = self.conv_1x1_resnet[i](resnet_f)
            densenet_f_conv = self.conv_1x1_densenet[i](densenet_f)

            # Max fusion of the feature maps at this level
            fused_f = torch.max(resnet_f_conv, densenet_f_conv)
            fused_features.append(fused_f)

        return fused_features

# Example usage
encoder = EnsembleEncoder()
# input_image = torch.randn(1, 1, 224, 224)  # Example grayscale input (batch_size=1, channels=1, H=224, W=224)
# fused_features = encoder(input_image)

# Check feature sizes
# for i, feature in enumerate(fused_features):
#     print(f"Fused feature map {i+1} shape: {feature.shape}")


In [None]:
import torch
import torch.nn as nn

# Define the decoder network
class Decoder(nn.Module):
    """Decoder that turns four fused encoder feature maps into ab predictions.

    Expects ``fused_features`` ordered from finest to coarsest with
    128, 256, 512 and 1024 channels respectively. Each decoder block doubles
    the spatial resolution; from the second block on, the previous output is
    concatenated with the encoder feature map of the same resolution, so the
    block's input channel count is the sum of both.

    For 56/28/14/7-pixel encoder features the output is (batch, 2, 112, 112).
    """

    def __init__(self):
        super(Decoder, self).__init__()

        # Decoder block 1: receives the 1024-channel 7x7 features from the encoder
        self.decoder_block1 = self._make_decoder_block(1024, 512)

        # Decoder block 2: previous output (512) + 14x14 encoder skip (512)
        self.decoder_block2 = self._make_decoder_block(512 + 512, 256)

        # Decoder block 3: previous output (256) + 28x28 encoder skip (256)
        self.decoder_block3 = self._make_decoder_block(256 + 256, 128)

        # Decoder block 4: previous output (128) + 56x56 encoder skip (128)
        self.decoder_block4 = self._make_decoder_block(128 + 128, 64)

        # Final output layer to predict ab channels
        self.final_conv = nn.Conv2d(64, 2, kernel_size=1)  # 2 channels for ab output (CIE-Lab color space)

    def _make_decoder_block(self, in_channels, out_channels):
        """
        Helper function to create a decoder block.
        Each block is two (3x3 Conv2D -> BatchNorm -> ReLU) stages followed by
        2x bilinear upsampling.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )

    def forward(self, fused_features):
        """
        Forward pass of the decoder.
        Fused features are a list containing the multi-level fused features from the encoder,
        ordered from finest (index 0) to coarsest (index -1).
        Returns a tensor with 2 channels at twice the resolution of fused_features[0].
        """
        # Starting from the coarsest features (7x7) and moving up
        x = self.decoder_block1(fused_features[-1])  # 7x7 -> 14x14

        # Decoder block 2 with skip connection (fused_features[-2] is 14x14)
        x = torch.cat((x, fused_features[-2]), dim=1)  # 512 + 512 channels
        x = self.decoder_block2(x)  # 14x14 -> 28x28

        # Decoder block 3 with skip connection (fused_features[-3] is 28x28)
        x = torch.cat((x, fused_features[-3]), dim=1)  # 256 + 256 channels
        x = self.decoder_block3(x)  # 28x28 -> 56x56

        # Decoder block 4 with skip connection (fused_features[-4] is 56x56)
        x = torch.cat((x, fused_features[-4]), dim=1)  # 128 + 128 channels
        x = self.decoder_block4(x)  # 56x56 -> 112x112

        # Final output layer to predict ab channels
        ab_output = self.final_conv(x)

        return ab_output

# Example Usage
decoder = Decoder()
# Assuming fused_features from the encoder, which are a list of feature maps
# Here we simulate with random tensors for illustration (shape matching feature maps)
# fused_features = [
#     torch.randn(1, 128, 56, 56),  # 56x56
#     torch.randn(1, 256, 28, 28),  # 28x28
#     torch.randn(1, 512, 14, 14),  # 14x14
#     torch.randn(1, 1024, 7, 7),   # 7x7
# ]

# # Forward pass through decoder
# ab_output = decoder(fused_features)
# print(f"Decoder output shape (ab channels): {ab_output.shape}")  # Should be (1, 2, 112, 112)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
from skimage import color

# Custom dataset to convert RGB to Lab and process only L as input
class CIFAR10_LabColorization(CIFAR10):
    """CIFAR10 variant that yields (L, ab) pairs in CIE Lab space instead of (image, label).

    Works both when the configured ``transform`` ends in ``ToTensor`` (the item
    is a channel-first float tensor) and when it leaves a PIL image (the item
    is converted to a channel-last uint8 array).
    """

    def __getitem__(self, index):
        """Return ``(L, ab)`` for the image at ``index``.

        Returns:
            L: float32 tensor of shape (1, H, W); Lab lightness mapped by ``L / 50 - 1``.
            ab: float32 tensor of shape (2, H, W); Lab a/b channels divided by 128.
        """
        # The class label is not needed for colorization
        img, _ = super(CIFAR10_LabColorization, self).__getitem__(index)

        # rgb2lab requires the color channels on the last axis. A tensor coming
        # out of the transform pipeline is CxHxW, so move channels last.
        if torch.is_tensor(img):
            rgb = img.permute(1, 2, 0).numpy()
        else:
            rgb = np.asarray(img)

        # Only 8-bit data needs rescaling; ToTensor output is already in [0, 1]
        # and dividing it by 255 again would make every image almost black.
        if rgb.dtype == np.uint8:
            rgb = rgb / 255.0

        # Convert the RGB image to Lab color space
        img_lab = color.rgb2lab(rgb).astype(np.float32)

        # Normalize L channel to range [-1, 1]
        L = img_lab[:, :, 0] / 50.0 - 1.0  # L channel range: 0-100 normalized to [-1, 1]

        # Normalize ab channels to range [-1, 1]
        ab = img_lab[:, :, 1:] / 128.0  # a and b channels range: -128 to 127 normalized to [-1, 1]

        # Return L as input (1 channel), ab as target (2 channels)
        return torch.from_numpy(L).unsqueeze(0), torch.from_numpy(ab).permute(2, 0, 1)

# Hyper-parameters for the training run
batch_size, learning_rate, num_epochs = 64, 0.001, 25

# Pre-processing pipeline: upscale each image to the encoder's 224x224 input
# size, then convert it to a tensor
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Training split, reshuffled every epoch
train_dataset = CIFAR10_LabColorization(root='./data', download=True, train=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Held-out test split, served in a fixed order
test_dataset = CIFAR10_LabColorization(root='./data', download=True, train=False, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

# Fresh encoder/decoder pair for this training run
encoder = EnsembleEncoder()
decoder = Decoder()

# Combine encoder and decoder into a single model
class ColorizationModel(nn.Module):
    """End-to-end colorization network: an encoder followed by a decoder.

    Both sub-modules are stored as attributes, so their parameters are
    registered with this module.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x):
        """Feed ``x`` to the encoder and return the decoder's output for the encoder's result."""
        return self.decoder(self.encoder(x))

# Initialize the complete colorization model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ColorizationModel(encoder, decoder).to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()  # Mean Squared Error (MSE) loss for ab channel prediction
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (L, ab) in enumerate(train_loader):
        # Move inputs and targets to GPU if available
        L, ab = L.to(device), ab.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass: grayscale L channel -> ab color channels
        ab_pred = model(L)

        # The decoder ends at 112x112 for a 224x224 input, while the targets
        # are full resolution. MSELoss needs matching shapes, so bring the
        # prediction up to the target size before comparing.
        if ab_pred.shape[-2:] != ab.shape[-2:]:
            ab_pred = nn.functional.interpolate(
                ab_pred, size=ab.shape[-2:], mode='bilinear', align_corners=False
            )

        # Compute the loss
        loss = criterion(ab_pred, ab)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Print statistics (mean loss over the last 100 mini-batches)
        running_loss += loss.item()
        if batch_idx % 100 == 99:  # Print every 100 mini-batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
            running_loss = 0.0

# Save the trained model
# torch.save(model.state_dict(), 'colorization_model.pth')

print('Finished Training')


Files already downloaded and verified
Files already downloaded and verified


ValueError: the input array must have size 3 along `channel_axis`, got (3, 224, 224)