In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import cv2
from plotting import imshow


In [None]:
class CustomDataset(Dataset):
    """Dataset pairing flat, channel-planar RGB rows with their filenames.

    Each entry of ``images_data`` is read as 3072 values laid out as three
    consecutive 1024-value planes (R, then G, then B), every plane being one
    32x32 image.

    Args:
        filenames: Sequence of image names; its length defines the dataset size.
        images_data: Indexable collection of flat pixel rows, indexed with the
            same positions as ``filenames``.
        transform: Optional callable applied to the (32, 32, 3) uint8 array
            before it is returned.
    """

    def __init__(self, filenames, images_data, transform=None):
        self.transform = transform
        self.images_data = images_data  # image data
        self.filenames = filenames  # image name

    def __len__(self):
        """Return the number of filenames held by the dataset."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, filename)`` for position ``idx``."""
        flat_pixels = np.array(self.images_data[idx], dtype=np.uint8)

        # View the row as (channel, height, width), then move the channel axis
        # last; the copy keeps the result contiguous like a freshly built array.
        planes = flat_pixels.reshape(3, 32, 32)
        image = np.ascontiguousarray(np.moveaxis(planes, 0, -1))

        if self.transform:
            image = self.transform(image)

        return image, self.filenames[idx]

# Image pipeline: ToTensor is the only step (no separate Normalize transform is applied)
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
def unpickle(file):
    """Load one pickled batch file and return the object stored in it.

    ``encoding="bytes"`` keeps the dictionary keys as bytes, which is what the
    ``b'data'`` / ``b'filenames'`` lookups below rely on.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    dataset files from a trusted source.
    """
    with open(file, "rb") as fo:
        return pickle.load(fo, encoding="bytes")


# 'train' is resolved relative to the notebook's working directory.
train_set = unpickle('train')

images_data = train_set[b'data']

# Filenames are stored as bytes; decode them to str for readability.
filenames = [f.decode('utf-8') for f in train_set[b'filenames']]

dataset = CustomDataset(filenames=filenames, images_data=images_data, transform=transform)
# shuffle=False keeps the batches in file order.
train_loader = DataLoader(dataset, batch_size=8, shuffle = False)

In [None]:
# Preview only the first batch: each image is shown beside its 224x224 bilinear upsample.
for batch in train_loader:
    image_tensor, filename = batch  # filename is unpacked but not used below
    for idx, image_rgb in enumerate(image_tensor):
        # interpolate needs a batch dimension, so add one and drop it again afterwards
        resized_img = F.interpolate(
            image_rgb.unsqueeze(0),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # matplotlib takes channels-last NumPy arrays
        resized_img_numpy = resized_img.permute(1, 2, 0).cpu().numpy()
        panels = [
            (image_rgb.permute(1, 2, 0).cpu().numpy(), "Original 32x32 Image"),
            (resized_img_numpy, "Resized 224x224 Image"),
        ]

        # One figure per image: original on the left, resized on the right
        fig, ax = plt.subplots(1, 2, figsize=(8, 4))
        for panel_ax, (panel_pixels, panel_title) in zip(ax, panels):
            panel_ax.imshow(panel_pixels)
            panel_ax.set_title(panel_title)
            panel_ax.axis("off")

        plt.show()

    break  # stop after the first batch

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ColorizationUNet(nn.Module):
    """Pretrained MobileNetV2 encoder with a small decoder predicting 2 AB channels.

    The encoder (``mobilenet.features``) reduces the spatial size by a factor
    of 32 and each transposed convolution doubles it, so the decoder alone
    yields H/8 x W/8 (4x4 for a 32x32 input, 28x28 for a 224x224 input). To
    actually deliver the intended 32x32 colour map, the decoder features are
    bilinearly resized to ``output_size`` before the final convolution.

    Despite the name, there are no encoder-to-decoder skip connections.

    Args:
        output_size: ``(height, width)`` of the predicted AB map. Defaults to
            ``(32, 32)``. Pass ``None`` to skip the resize and return the raw
            decoder resolution.
    """

    def __init__(self, output_size=(32, 32)):
        super().__init__()

        # Use a pretrained MobileNetV2 as an encoder
        mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
        self.encoder = mobilenet.features  # Extract feature layers (1280 output channels)

        # Decoder: each transposed conv (kernel 4, stride 2, padding 1) doubles H and W
        self.upconv1 = nn.ConvTranspose2d(1280, 256, kernel_size=4, stride=2, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.final_conv = nn.Conv2d(128, 2, kernel_size=3, padding=1)  # Output AB channels (Lab space)

        self.output_size = None if output_size is None else tuple(output_size)

    def forward(self, x):
        """Map a 3-channel image batch to a 2-channel AB prediction.

        Args:
            x: Tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 2, *output_size) with values in [-1, 1], or
            (N, 2, H/8, W/8) when ``output_size`` is ``None``.
        """
        x = self.encoder(x)
        x = F.relu(self.upconv1(x))
        x = F.relu(self.upconv2(x))

        # Two x2 upsamplings cannot undo the encoder's x32 reduction, so bring
        # the features to the requested resolution explicitly.
        if self.output_size is not None and tuple(x.shape[-2:]) != self.output_size:
            x = F.interpolate(x, size=self.output_size, mode="bilinear", align_corners=False)

        x = torch.tanh(self.final_conv(x))  # Output in range [-1,1] for AB channels
        return x

# Instantiate model
model = ColorizationUNet()

# Smoke test on a random input (batch_size=1, channels=3, height=32, width=32)
dummy_input = torch.randn(1, 3, 32, 32)  # MobileNet expects 3-channel input

# Run the check in eval mode and without autograd: a freshly built module is in
# training mode, and with a single 32x32 sample the deepest encoder feature maps
# are 1x1, which BatchNorm rejects while training ("Expected more than 1 value
# per channel when training").
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train(was_training)  # restore the previous mode for later cells

# Print the shape so the output resolution can be checked against the 32x32 target
print("Output shape:", output.shape)