In [1]:
from skimage import color
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


In [2]:
import os
import glob
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [3]:
# Dataset of RGB image files converted to (L, ab) Lab-space pairs for colorization
class ColorizationDataset(Dataset):
    """Map image file paths to ``(L, ab)`` tensor pairs for colorization.

    Each image is loaded as RGB, resized to ``image_size`` x ``image_size``,
    converted to CIE Lab with ``skimage.color.rgb2lab`` and rescaled:

    * L channel:  ``L / 50 - 1``  (L in [0, 100] -> [-1, 1])
    * ab channels: ``ab / 128``   (roughly [-1, 1])

    Args:
        image_list: sequence of image file paths.
        image_size: side length of the square resize (default 224, the size
            the original pipeline was written for).

    Returns (per item):
        ``(L, ab)`` float32 tensors of shape ``1xHxW`` and ``2xHxW``.
    """

    def __init__(self, image_list, image_size=224):
        self.image_list = image_list
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),  # square resize
            transforms.ToTensor(),                        # 3xHxW float in [0, 1]
        ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new image that stays valid
        # after the source file is closed.
        with Image.open(self.image_list[idx]) as src:
            img = src.convert("RGB")

        img = self.transform(img)

        # rgb2lab expects an HxWx3 array, so move channels last first.
        img_lab = color.rgb2lab(img.permute(1, 2, 0).numpy()).astype(np.float32)

        img_lab[:, :, 0] = img_lab[:, :, 0] / 50.0 - 1  # L: [0, 100] -> [-1, 1]
        img_lab[:, :, 1:] = img_lab[:, :, 1:] / 128.0   # a, b: ~[-1, 1]

        L = img_lab[:, :, 0:1]  # model input
        ab = img_lab[:, :, 1:]  # regression target

        L = torch.from_numpy(L).permute(2, 0, 1)    # HxWx1 -> 1xHxW
        ab = torch.from_numpy(ab).permute(2, 0, 1)  # HxWx2 -> 2xHxW

        return L, ab


In [4]:
import os
import random
from torch.utils.data import DataLoader

def load_data(batch_size, train_path=None, test_path=None, limit_per_folder=50, seed=None):
    """Build train/test DataLoaders of ``ColorizationDataset`` from PNG folder trees.

    Up to ``limit_per_folder`` PNG files are sampled from every directory that
    ``os.walk`` visits under each root (the root itself and all subfolders).
    Directories and files are visited in sorted order, so a fixed seed always
    selects the same images.

    Args:
        batch_size: batch size for both loaders.
        train_path: root of the training images; defaults to the original
            Colab Drive ``CIFAR10_train`` folder.
        test_path: root of the test images; defaults to the original Colab
            Drive ``CIFAR10_test`` folder.
        limit_per_folder: maximum number of images sampled per directory.
        seed: if given, sampling uses a private ``random.Random(seed)``;
            if ``None``, the global ``random`` module is used (the original
            behaviour, so a prior ``random.seed`` still applies).

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles and the
        test loader does not.

    Raises:
        FileNotFoundError: if ``train_path`` or ``test_path`` is not a directory.
    """
    if train_path is None:
        train_path = "/content/drive/MyDrive/Colab Notebooks/Machine Learning Lab/Project/CIFAR10_train"
    if test_path is None:
        test_path = "/content/drive/MyDrive/Colab Notebooks/Machine Learning Lab/Project/CIFAR10_test"

    rng = random if seed is None else random.Random(seed)

    def get_limited_images(root_path, limit=limit_per_folder):
        # Fail loudly: os.walk on a missing path silently yields nothing, which
        # only surfaces later as an empty-sampler or division-by-zero error.
        if not os.path.isdir(root_path):
            raise FileNotFoundError(f"Image directory not found: {root_path}")
        all_images = []
        for subdir, dirs, files in os.walk(root_path):
            dirs.sort()  # in-place sort fixes os.walk's traversal order
            png_files = sorted(
                os.path.join(subdir, f) for f in files if f.lower().endswith(".png")
            )
            all_images.extend(rng.sample(png_files, min(len(png_files), limit)))
        return all_images

    train_images = get_limited_images(train_path)
    test_images = get_limited_images(test_path)

    print(f"Number of training images: {len(train_images)}")
    print(f"Number of testing images: {len(test_images)}")

    train_data = ColorizationDataset(train_images)
    test_data = ColorizationDataset(test_images)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


In [5]:

# Fusion block: merge one ResNet and one DenseNet feature map of equal spatial size
class FusionBlock(nn.Module):
    """Fuse two same-resolution feature maps into a single 256-channel map.

    Each input is projected to 256 channels by its own 1x1 convolution. The two
    projections are concatenated along the channel axis (512 channels), and a
    final 1x1 convolution mixes them back to 256 channels. No normalisation or
    activation is applied.

    Args:
        in_channels_1: channel count of the first input ``x1``.
        in_channels_2: channel count of the second input ``x2``.
    """

    def __init__(self, in_channels_1, in_channels_2):
        super().__init__()
        width = 256
        # Attribute names (and creation order) are part of the saved
        # state_dict layout, so they are kept stable.
        self.conv1 = nn.Conv2d(in_channels_1, width, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels_2, width, kernel_size=1)
        self.reduce_channels = nn.Conv2d(2 * width, width, kernel_size=1)

    def forward(self, x1, x2):
        """Project both inputs, stack them on dim 1 and reduce back to 256 channels."""
        stacked = torch.cat([self.conv1(x1), self.conv2(x2)], dim=1)
        return self.reduce_channels(stacked)

# Decoder stage: conv-BN-ReLU, 2x bilinear upsample, optional additive skip connection
class DecoderBlock(nn.Module):
    """One decoder stage that doubles the spatial resolution.

    Pipeline: 3x3 conv (``in_channels -> out_channels``), BatchNorm, ReLU, then
    2x bilinear upsampling. If ``skip`` is given, it is bilinearly resized to
    the upsampled spatial size when needed and added element-wise. Because the
    merge is an addition, ``skip`` must have ``out_channels`` channels.

    Args:
        in_channels: channels of the input ``x`` (default 256).
        out_channels: channels produced by the block (default 256).
    """

    def __init__(self, in_channels=256, out_channels=256):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, skip=None):
        out = self.upsample(self.relu(self.bn(self.conv(x))))
        if skip is None:
            return out

        target_hw = out.shape[2:]
        # Resize the skip tensor only when its height/width differ from ours.
        if skip.shape[2:] != target_hw:
            skip = F.interpolate(skip, size=target_hw, mode='bilinear', align_corners=False)
        return out + skip

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision import models



In [7]:

# Colorization model: dual ResNet50 + DenseNet121 encoder with a fused, skip-connected decoder
class ColorizationModel(nn.Module):
    """Predict the two ab chroma channels from a single-channel L image.

    Encoder: ImageNet-pretrained ResNet50 and DenseNet121. Four stages are
    extracted from each network (ResNet ``layer1..4``, DenseNet
    ``denseblock1..4``). For a 224x224 input their resolutions are
    56/28/14/7. Each same-resolution pair is fused by a ``FusionBlock``
    into 256 channels.

    Decoder: four ``DecoderBlock`` stages, each doubling resolution
    (7 -> 14 -> 28 -> 56 -> 112). The fused features of the next-finer level
    are added as skip connections; DecoderBlock resizes them as needed. A
    3x3 conv to 2 channels and a final 2x upsample follow, so a 224x224
    input yields a ``(N, 2, 224, 224)`` output.

    NOTE(review): the encoders are ImageNet-pretrained but receive the L
    channel (scaled to [-1, 1] by the dataset) tiled to three channels, not
    ImageNet-normalised RGB. Confirm this is intended.
    NOTE(review): ``pretrained=True`` is deprecated in newer torchvision in
    favour of ``weights=``. It is kept here so older installs still work.
    """

    def __init__(self):
        super().__init__()

        # The full backbones stay registered: their names are part of the saved
        # state_dict keys, although forward() only uses the extractors below.
        self.resnet = models.resnet50(pretrained=True)
        self.densenet = models.densenet121(pretrained=True)

        levels = ("56", "28", "14", "7")
        resnet_nodes = ("layer1", "layer2", "layer3", "layer4")
        densenet_nodes = tuple(f"features.denseblock{k}" for k in range(1, 5))

        self.resnet_extractor = create_feature_extractor(
            self.resnet,
            return_nodes={node: f"resnet_feats_{lvl}" for node, lvl in zip(resnet_nodes, levels)},
        )
        self.densenet_extractor = create_feature_extractor(
            self.densenet,
            return_nodes={node: f"densenet_feats_{lvl}" for node, lvl in zip(densenet_nodes, levels)},
        )

        # (ResNet channels, DenseNet channels) per stage; every FusionBlock outputs 256.
        # Registration order (56, 28, 14, 7) matches the original module layout.
        stage_channels = {"56": (256, 256), "28": (512, 512), "14": (1024, 1024), "7": (2048, 1024)}
        for lvl, (c_res, c_dense) in stage_channels.items():
            setattr(self, f"fusion_{lvl}", FusionBlock(c_res, c_dense))

        # Every decoder stage consumes and produces 256 channels.
        for lvl in ("7", "14", "28", "56"):
            setattr(self, f"decoder_{lvl}", DecoderBlock(256))

        # Head: 256 -> 2 (a, b) followed by the last 2x upsample.
        self.final_conv = nn.Conv2d(256, 2, kernel_size=3, padding=1)
        self.upsample_final = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Map an ``(N, 1, H, W)`` L batch to an ``(N, 2, H, W)`` ab prediction."""
        # The backbones expect 3 channels: tile the single L channel.
        rgb_like = x.repeat(1, 3, 1, 1)

        res = self.resnet_extractor(rgb_like)
        dense = self.densenet_extractor(rgb_like)

        fused = {
            lvl: getattr(self, f"fusion_{lvl}")(res[f"resnet_feats_{lvl}"], dense[f"densenet_feats_{lvl}"])
            for lvl in ("56", "28", "14", "7")
        }

        out = self.decoder_7(fused["7"])
        out = self.decoder_14(out, fused["14"])
        out = self.decoder_28(out, fused["28"])
        out = self.decoder_56(out, fused["56"])

        return self.upsample_final(self.final_conv(out))


In [8]:
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

In [9]:
def train_model(model, train_loader, test_loader, num_epochs=2, lr=0.001,
                device=None, log_every=8):
    """Train ``model`` with Adam on the MSE between predicted and true ab channels.

    ``validate_model(model, test_loader)`` is called at the end of every epoch.

    Args:
        model: module mapping an L batch to predicted ab channels.
        train_loader: sized iterable (e.g. a DataLoader) of ``(L, ab)`` pairs.
        test_loader: passed through to ``validate_model`` after each epoch.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: where batches are moved. Defaults to the device of the
            model's first parameter (CPU if it has none) rather than a global
            ``device`` variable.
        log_every: print the mean training loss every ``log_every`` steps.
            The final, possibly shorter, window of each epoch is reported too.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    num_batches = len(train_loader)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        steps_in_window = 0
        for i, (L, ab) in enumerate(tqdm(train_loader)):
            L = L.to(device)
            ab = ab.to(device)

            ab_pred = model(L)
            loss = criterion(ab_pred, ab)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            steps_in_window += 1

            step = i + 1
            # Divide by the actual window length so a short final window is
            # averaged correctly instead of being silently dropped.
            if step % log_every == 0 or step == num_batches:
                avg_loss = running_loss / steps_in_window
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{num_batches}], Loss: {avg_loss:.4f}')
                running_loss = 0.0
                steps_in_window = 0

        validate_model(model, test_loader)

def validate_model(model, test_loader, device=None):
    """Compute, print and return the MSE of ``model`` over ``test_loader``.

    Per-batch mean losses are weighted by batch size, so the result is the MSE
    over the whole test set when all samples share the same shape. An
    unweighted mean of batch means over-weights a smaller final batch.

    The model is evaluated in ``eval()`` mode under ``torch.no_grad()``. Its
    previous train/eval mode is restored afterwards, even if an exception is
    raised.

    Args:
        model: module mapping an L batch to predicted ab channels.
        test_loader: iterable of ``(L, ab)`` tensor pairs, batch dimension first.
        device: where batches are moved. Defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        The test-set MSE as a float, or ``nan`` if ``test_loader`` yields no
        samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.MSELoss()
    was_training = model.training
    model.eval()

    total_loss = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for L, ab in test_loader:
                L = L.to(device)
                ab = ab.to(device)

                ab_pred = model(L)

                batch = L.size(0)
                total_loss += criterion(ab_pred, ab).item() * batch
                n_samples += batch
    finally:
        model.train(was_training)

    if n_samples == 0:
        print('Validation Loss: n/a (empty test set)')
        return float('nan')

    avg_loss = total_loss / n_samples
    print(f'Validation Loss: {avg_loss:.4f}')
    return avg_loss


In [10]:
if __name__ == "__main__":
    # Seed every RNG the pipeline touches so Restart & Run All is reproducible:
    # load_data samples images with the global `random` module, while DataLoader
    # shuffling and the randomly initialised fusion/decoder layers use torch.
    # (cuDNN kernels on GPU may still introduce small nondeterminism.)
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    # Kept as a module-level global because the training/validation helpers
    # may read `device` directly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the CIFAR-10 colorization data (train and test splits)
    batch_size = 16  # You can adjust this based on your system
    train_loader, test_loader = load_data(batch_size)

    # Initialize the model
    model = ColorizationModel().to(device)

    # Train the model
    train_model(model, train_loader, test_loader, num_epochs=3, lr=0.001)


Number of training images: 500
Number of testing images: 500


 25%|██▌       | 8/32 [00:41<02:02,  5.10s/it]

Epoch [1/3], Step [8/32], Loss: 0.8121


 50%|█████     | 16/32 [01:22<01:14,  4.66s/it]

Epoch [1/3], Step [16/32], Loss: 0.0358


 75%|███████▌  | 24/32 [02:04<00:40,  5.08s/it]

Epoch [1/3], Step [24/32], Loss: 0.0242


100%|██████████| 32/32 [02:39<00:00,  4.98s/it]

Epoch [1/3], Step [32/32], Loss: 0.0259





Validation Loss: 0.2419


 25%|██▌       | 8/32 [00:05<00:15,  1.55it/s]

Epoch [2/3], Step [8/32], Loss: 0.0234


 50%|█████     | 16/32 [00:10<00:09,  1.62it/s]

Epoch [2/3], Step [16/32], Loss: 0.0166


 75%|███████▌  | 24/32 [00:15<00:05,  1.56it/s]

Epoch [2/3], Step [24/32], Loss: 0.0127


100%|██████████| 32/32 [00:19<00:00,  1.60it/s]


Epoch [2/3], Step [32/32], Loss: 0.0118
Validation Loss: 0.0280


 25%|██▌       | 8/32 [00:05<00:14,  1.64it/s]

Epoch [3/3], Step [8/32], Loss: 0.0115


 50%|█████     | 16/32 [00:09<00:09,  1.63it/s]

Epoch [3/3], Step [16/32], Loss: 0.0129


 75%|███████▌  | 24/32 [00:15<00:05,  1.56it/s]

Epoch [3/3], Step [24/32], Loss: 0.0127


100%|██████████| 32/32 [00:19<00:00,  1.65it/s]


Epoch [3/3], Step [32/32], Loss: 0.0122
Validation Loss: 0.0157


In [11]:
# Persist only the learned parameters and buffers (state_dict), not the pickled module object.
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/Machine Learning Lab/Project/final_intnskip_cat_cifar10_clarity.pth'
torch.save(model.state_dict(), checkpoint_path)