In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
# Download the Kaggle dataset and keep the local directory path that kagglehub returns.
_mmsample_handle = 'bardiaardakanian/mmsample'
bardiaardakanian_mmsample_path = kagglehub.dataset_download(_mmsample_handle)
print('Data source import complete.')


# **IMAGE COLOURIZATION**

### **Here is the outline of the project**



1. Download the dataset

2. Explore & analyse the dataset

3. Prepare the dataset for ML training

4. Train the hardcoded & baseline models

5. Make the predictions

6. Perform feature engineering

7. Train & evaluate different models

8. Train on a GPU with the entire dataset

9. Document & publish the project online




## 1. Download the Dataset

##### Steps:


*   install required libraries
*   Download data from kaggle
*   view dataset files
*   load the training set with pandas
*   load test set with pandas



In [None]:
import os

# Prefer the Kaggle-mounted input directory; outside Kaggle (Colab / local runs) it does
# not exist, so fall back to the directory returned by kagglehub.dataset_download above.
data_dir = "/kaggle/input/mmsample"
if not os.path.isdir(data_dir):
    data_dir = bardiaardakanian_mmsample_path

### View dataset files
Let's list the files and folders in the dataset directory.

In [None]:
import os
# Show the top-level entries (files and folders) of the dataset directory.
dataset_entries = os.listdir(data_dir)
print(dataset_entries)

### ✅ 1. Install & Import Libraries

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from skimage.color import rgb2lab


### ✅ 2. Check for GPU

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)


### ✅ 3. Load Image Paths

In [None]:
# This should point to the folder containing 'train2017' and 'val2017'
dataset_path = data_dir

# Collect every .jpg from both folders. glob returns files in arbitrary
# (filesystem-dependent) order, so sort each list: otherwise the seeded random split
# below selects different images on different machines.
image_paths = sorted(glob.glob(os.path.join(dataset_path, "train2017", "*.jpg")))
image_paths += sorted(glob.glob(os.path.join(dataset_path, "val2017", "*.jpg")))

print(f"Total images found: {len(image_paths)}")  # Should be 10,000


### ✅ 4. Split into Train and Validation

In [None]:
# Shuffle and pick up to MAX_IMAGES images: TRAIN_FRACTION of them for training, the rest
# for validation. With 10,000 images found this is exactly the original 8000 / 2000 split
# (same np.random.choice call), but it no longer raises ValueError when fewer exist.
MAX_IMAGES = 10000
TRAIN_FRACTION = 0.8

if len(image_paths) == 0:
    raise FileNotFoundError(f"No .jpg images found under {dataset_path!r}")

np.random.seed(42)
n_chosen = min(MAX_IMAGES, len(image_paths))
chosen_paths = np.random.choice(image_paths, n_chosen, replace=False)

n_train = int(n_chosen * TRAIN_FRACTION)
train_paths = chosen_paths[:n_train]
val_paths = chosen_paths[n_train:]

print("Train images:", len(train_paths))
print("Validation images:", len(val_paths))


### ✅ 5. Visualize Sample Images

In [None]:
# Plot the first (up to) 4 training images side by side in a single row.
n_preview = min(4, len(train_paths))
if n_preview == 0:
    print("No training images to display.")
else:
    fig, axes = plt.subplots(1, n_preview, figsize=(3 * n_preview, 3), squeeze=False)
    for ax, path in zip(axes[0], train_paths[:n_preview]):
        # Close the file right after reading; convert to RGB so grayscale photos are
        # shown in gray instead of being false-coloured by the default colormap.
        with Image.open(path) as img:
            ax.imshow(np.asarray(img.convert("RGB")))
        ax.set_title(os.path.basename(path), fontsize=8)
        ax.axis("off")
    fig.tight_layout()
    plt.show()


### 2. Prepare the Dataset for the ML Model

### ✅ 1. Define Custom Dataset Class

In [None]:
# Desired size to resize all images to
IMAGE_SIZE = 256

class ImageColorizationDataset(Dataset):
    """
    Custom PyTorch Dataset for image colorization tasks.

    - Loads RGB images from a list of file paths
    - Resizes them to IMAGE_SIZE x IMAGE_SIZE (and randomly flips them when training)
    - Converts them to LAB color space
    - Normalizes the L channel (input) and the ab channels (target) to roughly [-1, 1]

    Each item is a dict ``{'L': float32 tensor [1, H, W], 'ab': float32 tensor [2, H, W]}``.
    """

    def __init__(self, image_list, is_train=True):
        # Build the preprocessing pipeline. The random flip is appended only for training,
        # instead of using a `transforms.Lambda(lambda x: x)` no-op for evaluation: lambdas
        # cannot be pickled, which breaks DataLoader workers started with the "spawn"
        # method (the default on Windows and macOS).
        pipeline = [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))]
        if is_train:
            pipeline.append(transforms.RandomHorizontalFlip())
        self.transforms = transforms.Compose(pipeline)
        self.image_list = image_list      # List of file paths to images
        self.is_train = is_train          # True if training mode, False for validation/test

    def __len__(self):
        # Return total number of images in dataset
        return len(self.image_list)

    def __getitem__(self, idx):
        # Open the file in a context manager so its handle is released immediately;
        # convert() returns a new, fully loaded image that outlives the `with` block.
        with Image.open(self.image_list[idx]) as src:
            img = src.convert("RGB")

        # Apply resizing (and, in training, random flipping)
        img = self.transforms(img)

        # Convert PIL Image to NumPy array (H x W x 3)
        img = np.array(img)

        # Convert RGB to LAB: brightness (L) is separated from color (a, b),
        # so the model can predict color from brightness alone.
        lab_img = rgb2lab(img).astype("float32")  # Shape: [H, W, 3]

        # Convert to PyTorch tensor with channels first
        lab_img = torch.from_numpy(lab_img).permute(2, 0, 1)  # Shape: [3, H, W]

        # L range [0, 100] -> (L / 50) - 1 maps it to [-1, 1]
        L = lab_img[[0]] / 50.0 - 1.0  # Shape: [1, H, W]

        # ab range roughly [-110, 110] -> divide by 110
        ab = lab_img[1:] / 110.0       # Shape: [2, H, W]

        return {'L': L, 'ab': ab}


### ✅ 2. Create DataLoaders

In [None]:
from torch.utils.data import DataLoader

def get_dataloader(image_list, is_train=True, batch_size=16, num_workers=2):
    """
    Creates and returns a DataLoader for the ImageColorizationDataset.

    Parameters:
    - image_list (list): List of image file paths to load.
    - is_train (bool): Whether the loader is for training (enables shuffling and augmentation).
    - batch_size (int): Number of samples per batch.
    - num_workers (int): Number of worker processes used for loading (default 2);
      use 0 to load in the main process, e.g. while debugging.

    Returns:
    - DataLoader: PyTorch DataLoader object that yields batches of {'L': ..., 'ab': ...}.
    """
    dataset = ImageColorizationDataset(image_list, is_train=is_train)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,                      # Only shuffle during training
        num_workers=num_workers,
        # Pinned host memory only speeds up copies to a CUDA device; without one,
        # PyTorch ignores the flag and emits a warning, so enable it only with CUDA.
        pin_memory=torch.cuda.is_available(),
    )


In [None]:
# Create training and validation DataLoaders using the helper function
# These will load batches of L and ab values from the dataset
# Training loader shuffles and flips; validation loader keeps a fixed order and no augmentation.
train_dl, val_dl = (
    get_dataloader(train_paths, is_train=True),
    get_dataloader(val_paths, is_train=False),
)


In [None]:
# 🧪 Smoke-test the pipeline: fetch the first training batch and check tensor shapes.
data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']  # grayscale input and color target

# Expected: torch.Size([16, 1, 256, 256]) for L and torch.Size([16, 2, 256, 256]) for ab
for label, tensor in (("L channel shape:", Ls), ("ab channel shape:", abs_)):
    print(label, tensor.shape)

# Number of batches per epoch in each loader
print("Batches: train =", len(train_dl), ", val =", len(val_dl))


## 3. U-Net Generator

In [None]:
import torch
import torch.nn as nn

class UNetBlock(nn.Module):
    """One level of a recursive U-Net (encoder step + nested submodule + decoder step).

    Down path: [LeakyReLU] -> Conv2d(k=4, s=2, p=1) -> [BatchNorm];
    up path:   ReLU -> ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm (or Tanh when outermost).
    The stride-2 conv halves H and W and the transposed conv doubles them back
    (assuming even spatial sizes).

    Args:
        outer_channels: channels produced by this block's up path (before any skip concat).
        inner_channels: channels produced by the down conv and fed to ``submodule``.
        submodule: the next-deeper UNetBlock (unused when ``innermost``). Its
            ``outer_channels`` must equal this block's ``inner_channels`` so that it
            returns ``inner_channels * 2`` channels, which the up conv expects.
        input_channels: channels of the incoming tensor; defaults to ``outer_channels``.
        use_dropout: append Dropout(0.5) to the up path (intermediate blocks only).
        innermost: bottleneck level: no submodule and no BatchNorm on the down path.
        outermost: top level: no activation/norm before the down conv, Tanh output,
            and no skip concatenation.

    Forward:
        outermost: returns ``self.model(x)`` (``outer_channels`` channels).
        otherwise: returns ``cat([x, self.model(x)], dim=1)``
        (``input_channels + outer_channels`` channels).
    """

    def __init__(self, outer_channels, inner_channels, submodule=None, input_channels=None,
                 use_dropout=False, innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost

        if input_channels is None:
            input_channels = outer_channels

        # bias=False because BatchNorm (where used) supplies the shift.
        downconv = nn.Conv2d(input_channels, inner_channels, kernel_size=4, stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, inplace=True)
        downnorm = nn.BatchNorm2d(inner_channels)

        uprelu = nn.ReLU(inplace=True)
        upnorm = nn.BatchNorm2d(outer_channels)

        if outermost:
            # Input is the submodule output: its input + its up-path output = inner_channels * 2.
            upconv = nn.ConvTranspose2d(inner_channels * 2, outer_channels, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule: the up conv consumes the down conv output directly.
            upconv = nn.ConvTranspose2d(inner_channels, outer_channels, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_channels * 2, outer_channels, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                up.append(nn.Dropout(0.5))
            model = down + [submodule] + up

        # NOTE: the order of modules in this Sequential defines state_dict keys
        # (model.0, model.1, ...); reordering breaks loading saved checkpoints.
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # NOTE(review): for non-outermost blocks the first layer of self.model is
        # LeakyReLU(inplace=True), which overwrites `x` while self.model(x) runs. The list
        # holds a reference to that same tensor, so the skip half of the concat contains
        # LeakyReLU(x), not the raw x. This appears to mirror the reference pix2pix code;
        # switching to inplace=False would change the network's outputs — confirm before changing.
        return torch.cat([x, self.model(x)], dim=1)


### 🔧 `UNetBlock`: Recursive U-Net Encoder-Decoder Unit

This class defines a **modular encoder-decoder block** with skip connections. It's the building block of the full U-Net model.

#### 🧠 Key Features:
- Uses **Conv2D + BatchNorm + LeakyReLU** for downsampling.
- Uses **ConvTranspose2D + BatchNorm + ReLU** for upsampling.
- **Recursive definition** allows nesting submodules to form the U-Net.
- Uses **`torch.cat(...)`** to apply skip connections between encoder and decoder layers.

#### 🔁 Modes:
- `outermost=True`: Last layer, applies `Tanh` activation, no skip connection.
- `innermost=True`: Deepest bottleneck layer, no submodule inside.
- Intermediate: Includes `Dropout` optionally, has submodules and skip connections.

#### 🔁 Skip Connections:
Skip connections are implemented via:
```python
torch.cat([x, self.model(x)], dim=1)
```


In [None]:



### ✅ **🧠 U-Net Generator**

class UNet(nn.Module):
    """U-Net generator assembled by nesting UNetBlocks from the bottleneck outwards.

    Args:
        input_channels: channels of the input image (1 for the L channel).
        output_channels: channels to predict (2 for the ab channels).
        num_downs: total number of stride-2 downsampling levels. ``num_downs - 5`` of
            them are full-width levels with dropout; values below 5 behave like 5.
            Input height and width presumably need to be divisible by
            ``2 ** num_downs`` (256 for the default 8) so the skip concatenations
            line up — TODO confirm for other sizes.
        base_filters: channels of the outermost level; the deepest levels use 8x this.
    """

    def __init__(self, input_channels=1, output_channels=2, num_downs=8, base_filters=64):
        super().__init__()
        widest = base_filters * 8

        # Bottleneck level (no submodule inside).
        block = UNetBlock(widest, widest, innermost=True)

        # Extra full-width levels with dropout until the requested depth is reached.
        for _ in range(num_downs - 5):
            block = UNetBlock(widest, widest, submodule=block, use_dropout=True)

        # Three tapering levels; each one halves the channel count towards the output.
        for inner in (widest, widest // 2, widest // 4):
            block = UNetBlock(inner // 2, inner, submodule=block)

        # Outermost level: input_channels in, output_channels out (Tanh), no skip concat.
        self.model = UNetBlock(output_channels, widest // 8, input_channels=input_channels,
                               submodule=block, outermost=True)

    def forward(self, x):
        return self.model(x)


### 🧠 `UNet`: Full U-Net Generator for Image-to-Image Tasks

Constructs a **U-Net architecture** by stacking multiple `UNetBlock`s.

#### 🧱 Parameters:
- `input_channels`: Input image channels (e.g., 1 for grayscale).
- `output_channels`: Output image channels (e.g., 2 for ab color components).
- `num_downs`: Total number of downsampling steps (controls depth of network).
- `base_filters`: Number of filters in the first layer (default: 64).

#### 🧬 Architecture Flow:
1. Starts with an **innermost bottleneck block**.
2. Adds several **intermediate blocks with dropout**.
3. Adds **outer blocks** with reducing filter size.
4. Finally wraps everything inside an **outermost block** (no skip connection).

#### 📝 Example:
```python
netG = UNet(input_channels=1, output_channels=2, num_downs=8).to(device)
output = netG(grayscale_image)
```


[for depth understanding](https://www.youtube.com/watch?v=EHuACSjijbI&list=PLyMom0n-MBroupZiLfVSZqK5asX8KfoHL&index=5)

In [None]:
import torch
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of real/fake logits, one per patch.

    Args:
        input_channels: channels of the input tensor (3 here: L concatenated with ab).
        base_filters: channels of the first conv layer (default 64).
        num_downs: number of normalized conv layers after the first one; each doubles the
            channel count, and all but the last use stride 2. ``num_downs=0`` is allowed
            and yields just the first layer plus the output head.

    The output head is a single-channel conv with no normalization or activation, so the
    output holds raw logits (e.g. [B, 1, 30, 30] for 256x256 inputs with the defaults).
    """

    def __init__(self, input_channels, base_filters=64, num_downs=3):
        super().__init__()

        layers = []

        # Initial layer (no normalization)
        layers.append(self._block(input_channels, base_filters, normalize=False))

        # Channels entering the output head. Initialized here so num_downs == 0 works
        # (previously `out_filters` was only assigned inside the loop -> NameError).
        out_filters = base_filters

        # Downsampling layers
        for i in range(num_downs):
            in_filters = base_filters * (2 ** i)
            out_filters = base_filters * (2 ** (i + 1))
            stride = 1 if i == (num_downs - 1) else 2  # No downsampling on the last one
            layers.append(self._block(in_filters, out_filters, stride=stride))

        # Final output layer (1 channel, no activation or normalization)
        layers.append(self._block(out_filters, 1, stride=1, normalize=False, activate=False))

        self.model = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, normalize=True, activate=True):
        """Conv2d -> optional BatchNorm2d -> optional LeakyReLU(0.2).

        The conv has no bias when BatchNorm follows, since BatchNorm's shift replaces it.
        """
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not normalize)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        if activate:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


### 🔍 PatchGAN Discriminator — `PatchDiscriminator`

This is a convolutional discriminator based on **PatchGAN**, which classifies overlapping patches (instead of the entire image) as real or fake.

#### 📐 Architecture Overview:
- Input: A tensor with `input_channels` channels. In this project it has 3: the L channel concatenated with the two ab channels (real or generated).
- Output: A **feature map** (not a scalar) where each value represents whether the corresponding patch is real or fake.

#### ⚙️ Constructor Parameters:
- `input_channels`: Number of input channels.
- `base_filters`: Number of filters in the first layer (default: 64).
- `num_downs`: Number of downsampling layers (default: 3).

#### 🧱 Block Structure (`_block` function):
- **Conv2D** with 4x4 kernel, stride (usually 2), and padding 1.
- Optional **BatchNorm2d** (except first and last layers).
- Optional **LeakyReLU** (disabled for last output layer).

#### 🧠 Forward Method:
```python
def forward(self, x):
    return self.model(x)
```


In [None]:
# Smoke-test the discriminator on a random batch shaped like (L + ab) images.
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256)  # batch_size, channels, size, size
# Shape check only: skip building the autograd graph (saves a lot of memory at 256x256).
with torch.no_grad():
    out = discriminator(dummy_input)
# 256 -> 128 -> 64 -> 32 -> 31 -> 30: one logit per overlapping patch.
assert out.shape == (16, 1, 30, 30), out.shape
out.shape

In [None]:
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Adversarial loss that builds its own all-real / all-fake target tensors.

    Args:
        gan_mode: 'vanilla' -> BCEWithLogitsLoss (predictions are raw logits);
                  'lsgan'   -> MSELoss (least-squares GAN).
        real_label: target value for predictions that should be judged real.
        fake_label: target value for predictions that should be judged fake.

    Raises:
        NotImplementedError: for any other ``gan_mode``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()

        # Buffers, not parameters: they follow .to(device) but are never trained.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))

        supported = (('vanilla', nn.BCEWithLogitsLoss), ('lsgan', nn.MSELoss))
        criterion_cls = next((cls for name, cls in supported if gan_mode == name), None)
        if criterion_cls is None:
            raise NotImplementedError(f"GAN mode '{gan_mode}' is not supported")
        self.loss = criterion_cls()

    def get_labels(self, preds, target_is_real):
        """Broadcast the real or fake label to the shape of ``preds`` (an expanded view)."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Loss of ``preds`` against an all-real (True) or all-fake (False) target."""
        return self.loss(preds, self.get_labels(preds, target_is_real))


### 🎯 Custom GAN Loss Class — `GANLoss`

This is a reusable loss module designed for training GANs (Generative Adversarial Networks). It supports two types of loss functions:

#### 💡 Supported Modes:
- `'vanilla'`: Uses **Binary Cross Entropy with Logits** (BCEWithLogitsLoss).
- `'lsgan'`: Uses **Least Squares Error** (MSELoss) — typically more stable.

#### ⚙️ Key Features:
- Automatically handles label generation for real vs. fake predictions.
- Registers `real_label` and `fake_label` as buffers, so they're correctly moved across devices (e.g., CPU ↔ GPU).
- Can be plugged directly into a GAN training loop for both generator and discriminator losses.

#### 🧠 How It Works:
1. **Initialization (`__init__`)**: Sets loss type and stores real/fake label values.
2. **`get_labels(...)`**: Creates a label tensor matching prediction shape.
3. **`forward(...)`**: Applies the selected loss function between predictions and ground-truth labels (real or fake).

#### 📝 Example Usage:
```python
loss_fn = GANLoss(gan_mode='lsgan')
pred_fake = discriminator(fake_img)
loss_D_fake = loss_fn(pred_fake, target_is_real=False)
```


In [None]:
import torch.nn as nn

def init_weights(net, init_type='norm', gain=0.02):
    """
    Initializes weights of a PyTorch model using the specified method.

    Parameters:
    - net (nn.Module): The model whose weights are to be initialized.
    - init_type (str): The initialization strategy ('norm', 'xavier', or 'kaiming').
    - gain (float): Std for 'norm', gain for 'xavier', and the std of BatchNorm scales
      around 1.0. Not used for Conv layers under 'kaiming'.

    Returns:
    - The model with initialized weights (the same object, modified in place).

    Raises:
    - NotImplementedError: if ``init_type`` is unsupported. This is checked up front, so
      a typo fails even for models without Conv layers instead of being silently ignored.
    """
    if init_type not in ('norm', 'xavier', 'kaiming'):
        raise NotImplementedError(f"Initialization method '{init_type}' not supported.")

    def init_func(m):
        classname = m.__class__.__name__
        # getattr(..., None): containers have no `weight` at all, and layers such as
        # BatchNorm2d(affine=False) expose weight/bias as None (previously an AttributeError).
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        # Conv layers (Conv2d, ConvTranspose2d, ...)
        if 'Conv' in classname and weight is not None:
            if init_type == 'norm':
                nn.init.normal_(weight, mean=0.0, std=gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(weight, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(weight, a=0, mode='fan_in')

            if bias is not None:
                nn.init.constant_(bias, 0.0)

        # BatchNorm2d: scale ~ N(1.0, gain), shift = 0 (skipped when the layer is not affine)
        elif 'BatchNorm2d' in classname:
            if weight is not None:
                nn.init.normal_(weight, 1.0, gain)
            if bias is not None:
                nn.init.constant_(bias, 0.0)

    net.apply(init_func)
    print(f"✅ Model initialized with '{init_type}' weights.")
    return net


### 🧠 `init_weights()`: Initialize Model Weights

This function initializes weights of a PyTorch model layer by layer.

#### 🔧 Supported Initialization Methods:
- `'norm'`: Normal distribution (mean=0, std=gain)
- `'xavier'`: Xavier (Glorot) normal initialization
- `'kaiming'`: Kaiming He initialization

#### ⚙️ Layer-wise Logic:
- **Conv Layers**: Initialized based on the selected method
- **BatchNorm2d**: Weights ~ N(1.0, gain), Bias = 0

#### 📝 Usage:
Call `init_weights(model)` after defining your model architecture but before training.


In [None]:
def init_model(model, device, init_type='norm', gain=0.02):
    """
    Moves the model to the specified device and initializes its weights.

    Parameters:
    - model (nn.Module): The model to initialize.
    - device (torch.device): The device (CPU or GPU) to move the model to.
    - init_type (str): Strategy forwarded to init_weights ('norm', 'xavier' or
      'kaiming'); defaults to 'norm', as before.
    - gain (float): Scale forwarded to init_weights (default 0.02).

    Returns:
    - Initialized model on the specified device.

    Note: every Conv / BatchNorm2d layer is re-initialized, so do not call this on a
    pretrained network; its learned weights would be overwritten.
    """
    model = model.to(device)
    return init_weights(model, init_type=init_type, gain=gain)


### ⚙️ `init_model()`: Prepare Model for Training

Combines two tasks in one:

1. **Moves the model to the specified device** (CPU or GPU).
2. **Initializes weights** using the default `'norm'` strategy.

#### 🔁 Returns:
- A ready-to-train PyTorch model, correctly initialized and placed on the chosen device.

#### 📝 Example:
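```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet(input_channels=1, output_channels=2)
net = init_model(net, device)
```

> ⚠️ `init_model` re-initializes every Conv and BatchNorm layer. Don't call it on a pretrained network (e.g. one from `build_res_unet`), because that would discard the pretrained weights.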


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class MainModel(nn.Module):
    """Pix2pix-style colorization GAN: generator (L -> ab) plus PatchGAN discriminator.

    Args:
        net_G: optional pre-built generator. If given, it is only moved to the device
            (weights untouched, e.g. a pretrained ResNet U-Net); if None, a fresh
            ``UNet(1 -> 2)`` is created and weight-initialized.
        lr_G, lr_D: Adam learning rates for the generator / discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 reconstruction term in the generator loss.

    After ``optimize()`` the attributes ``loss_D_fake``, ``loss_D_real``, ``loss_D``,
    ``loss_G_GAN``, ``loss_G_L1`` and ``loss_G`` hold the losses of the last batch.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        # Generator: a freshly initialized U-Net, or the caller's network moved as-is.
        self.net_G = (
            init_model(UNet(input_channels=1, output_channels=2), self.device)
            if net_G is None
            else net_G.to(self.device)
        )

        # Discriminator judges (L, ab) pairs, hence 1 + 2 = 3 input channels.
        self.net_D = init_model(PatchDiscriminator(input_channels=3), self.device)

        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Freeze (False) or unfreeze (True) every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def setup_input(self, data):
        """Store the batch's 'L' input and 'ab' target on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Predict ab channels from the stored L input into ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Discriminator loss (fake pairs -> fake label, real pairs -> real label); backpropagates it."""
        # detach(): the discriminator step must not push gradients into the generator.
        pred_fake = self.net_D(torch.cat([self.L, self.fake_color.detach()], dim=1))
        pred_real = self.net_D(torch.cat([self.L, self.ab], dim=1))

        self.loss_D_fake = self.GANcriterion(pred_fake, False)
        self.loss_D_real = self.GANcriterion(pred_real, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss = adversarial term + lambda_L1 * L1(fake ab, real ab); backpropagates it."""
        pred_fake = self.net_D(torch.cat([self.L, self.fake_color], dim=1))

        self.loss_G_GAN = self.GANcriterion(pred_fake, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """One training step on the stored batch: update D first, then G."""
        self.forward()

        schedule = (
            (True, self.opt_D, self.backward_D),   # D trainable for its own update
            (False, self.opt_G, self.backward_G),  # D frozen: no grads accumulate in net_D
        )
        for d_trainable, optimizer, backward in schedule:
            self.set_requires_grad(self.net_D, d_trainable)
            optimizer.zero_grad()
            backward()
            optimizer.step()


## 🧠 MainModel Class: Full GAN Training Logic

This class handles everything needed to train our GAN-based image colorization model.

### ✅ Key Components:

- **Generator (`net_G`)**: A U-Net that takes in grayscale (L channel) and predicts ab color channels.
- **Discriminator (`net_D`)**: A PatchGAN that evaluates whether colorized images are real or fake.
- **Loss Functions**:
  - `GANLoss`: Measures how well the generator is fooling the discriminator.
  - `L1Loss`: Measures how close the predicted colors are to the ground truth.
- **Optimizers**:
  - `opt_G` for the generator.
  - `opt_D` for the discriminator.

### 🔁 Training Steps:

- `setup_input(data)`: Loads one batch of data (L and ab) to the device.
- `forward()`: Uses the generator to predict ab from L.
- `backward_D()`: Updates the discriminator by comparing real vs. fake colorizations.
- `backward_G()`: Updates the generator using adversarial + L1 losses.
- `optimize()`: Runs both backward passes and steps the optimizers.

Each batch, we:
1. Train the **discriminator** to classify real vs. fake images.
2. Train the **generator** to fool the discriminator and produce realistic colorization.

This setup follows the [Pix2Pix](https://arxiv.org/abs/1611.07004) framework and is optimized for fast convergence and stable training.


In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb

# 📊 Keeps track of the average of any value (used for losses)
class AverageMeter:
    """Running weighted mean of a scalar, e.g. a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the sample count, running sum and average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Record ``val`` as observed ``count`` times and refresh ``avg``."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count


# 🧰 Creates AverageMeters for each loss
def create_loss_meters():
    """Return a fresh AverageMeter for each loss attribute set by MainModel.optimize()."""
    loss_names = ('loss_D_fake', 'loss_D_real', 'loss_D',
                  'loss_G_GAN', 'loss_G_L1', 'loss_G')
    return {name: AverageMeter() for name in loss_names}


# 🔄 Updates meters with current loss values from the model
def update_losses(model, loss_meter_dict, count):
    """Push each loss tensor on ``model`` (looked up by meter name) into its meter, weighted by ``count``."""
    for name in loss_meter_dict:
        loss_meter_dict[name].update(getattr(model, name).item(), count)


# 🔁 Converts LAB image tensors back to RGB for visualization
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L / ab tensors back to RGB numpy images.

    Args:
        L: tensor [B, 1, H, W] normalized to [-1, 1] (as produced by the dataset).
        ab: tensor [B, 2, H, W] normalized by 110 (as produced by the dataset).

    Returns:
        numpy array [B, H, W, 3] of RGB images from skimage's ``lab2rgb``.
    """
    L = (L + 1.) * 50.       # Denormalize L to [0, 100]
    ab = ab * 110.           # Denormalize ab to about [-110, 110]
    # detach() so tensors that still track gradients (e.g. raw generator output)
    # can be converted; .numpy() raises a RuntimeError on them otherwise.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)


# 🖼️ Visualizes 5 examples: input, fake output, and ground truth
def visualize(model, data, save=True, n_images=5):
    """Plot grayscale input, generated colorization and ground truth side by side.

    Runs the generator in eval mode under ``torch.no_grad()`` on ``data``, then shows up
    to ``n_images`` examples in three rows: L input (top), generated colors (middle),
    real colors (bottom).

    Args:
        model: object exposing ``net_G``, ``setup_input``, ``forward``, ``fake_color``,
            ``L`` and ``ab`` (e.g. MainModel).
        data: batch dict with 'L' and 'ab' tensors.
        save: also write the figure to ``colorization_<unix time>.png``.
        n_images: maximum number of examples to show (default 5); capped at the batch size.

    Note: the generator is always switched back to train mode afterwards.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L

    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    # Never index past the batch (smaller batches used to raise IndexError).
    n_cols = min(n_images, L.size(0))

    fig = plt.figure(figsize=(3 * n_cols, 8))
    for i in range(n_cols):
        # Grayscale input
        ax = fig.add_subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")

        # Fake colorized output
        ax = fig.add_subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")

        # Ground truth color image
        ax = fig.add_subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")

    fig.tight_layout()
    # Save *before* plt.show(): once shown (and possibly closed by the backend),
    # savefig can write an empty image.
    if save:
        fig.savefig(f"colorization_{int(time.time())}.png")
    plt.show()
    # Release the figure explicitly; this is called repeatedly during training.
    plt.close(fig)


# 📝 Prints out average values of all tracked losses
def log_results(loss_meter_dict):
    """Print the running average of every tracked loss, one per line."""
    print("🔔 Average Losses:")
    for name in loss_meter_dict:
        average = loss_meter_dict[name].avg
        print(f"{name}: {average:.5f}")


## 📉 Loss Tracking and 🖼️ Image Visualization Utilities

To effectively monitor and evaluate our GAN-based image colorization model during training, we define several helper functions and classes. Here's what each one does:

---

### ✅ 1. `AverageMeter`

Tracks the **running average of any metric**, commonly used for losses.

- `reset()` – Resets all values to zero.
- `update(val, count)` – Updates the internal sum and average.
- Useful for displaying average loss at the end of each epoch.

---

### 🧮 2. `create_loss_meters()`

Returns a dictionary of `AverageMeter`s for:
- Discriminator losses:
  - `loss_D_fake`: Loss for fake samples.
  - `loss_D_real`: Loss for real samples.
  - `loss_D`: Combined average.
- Generator losses:
  - `loss_G_GAN`: GAN loss (how well it fools the discriminator).
  - `loss_G_L1`: L1 loss (difference from true image).
  - `loss_G`: Total generator loss (GAN + L1).

---

### 🔄 3. `update_losses(model, loss_meter_dict, count)`

Automatically pulls loss values from the model and updates the corresponding average meters.

- Used during each batch to accumulate loss statistics.
- `model.loss_G`, `model.loss_D_fake`, etc. are accessed dynamically.

---

### 🎨 4. `lab_to_rgb(L, ab)`

Converts LAB color space tensors back to **RGB images** for visualization.

- The model works in LAB format for better color separation.
- We denormalize the `L` and `ab` channels to their original range.
- Then convert them to RGB using `skimage.color.lab2rgb()`.

---

### 🖼️ 5. `visualize(model, data, save=True)`

Displays **side-by-side comparison** of model output and ground truth:

- Top row: Grayscale input (L)
- Middle row: Generated color image
- Bottom row: Real color image

If `save=True`, it also saves the visualization as a PNG image with a timestamp.

---

### 📝 6. `log_results(loss_meter_dict)`

Prints out all tracked average losses (from the `AverageMeter`s) at the end of an epoch.

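Example output (loss values are illustrative):

```text
🔔 Average Losses:
loss_D_fake: 0.45310
loss_D_real: 0.47130
loss_D: 0.46220
loss_G_GAN: 1.21093
loss_G_L1: 8.93210
loss_G: 10.14303
```
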

In [None]:
from tqdm import tqdm

def train_model(model, train_dl, epochs, display_every=200, val_dl=None):
    """Train ``model`` (e.g. MainModel) for ``epochs`` passes over ``train_dl``.

    Args:
        model: object exposing ``setup_input()``, ``optimize()`` and the loss attributes
            tracked by ``create_loss_meters()``.
        train_dl: iterable of batch dicts with 'L' and 'ab' tensors.
        epochs: number of passes over ``train_dl``.
        display_every: print running losses (and visualize, if ``val_dl`` is given)
            every this many steps; 0, a negative value or None disables the report
            (0 used to raise ZeroDivisionError).
        val_dl: optional loader; its first batch is reused for every visualization.

    Returns:
        List with one dict per epoch mapping loss name -> average value over that epoch.
    """
    # One fixed validation batch, so successive visualizations are comparable.
    val_data = next(iter(val_dl)) if val_dl is not None else None
    report = bool(display_every) and display_every > 0

    epoch_loss_log = []  # ⏺️ List of dicts: one per epoch

    for epoch in range(epochs):
        loss_meter_dict = create_loss_meters()  # fresh averages every epoch

        for step, data in enumerate(tqdm(train_dl), start=1):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))

            if report and step % display_every == 0:
                print(f"\nEpoch {epoch+1}/{epochs} | Step {step}/{len(train_dl)}")
                log_results(loss_meter_dict)
                if val_data is not None:
                    visualize(model, val_data, save=False)

        # After epoch: log average loss values
        epoch_losses = {key: meter.avg for key, meter in loss_meter_dict.items()}
        epoch_loss_log.append(epoch_losses)
        print(f"\n✅ Epoch {epoch+1} completed. Average losses:")
        for k, v in epoch_losses.items():
            print(f"  {k}: {v:.4f}")

    return epoch_loss_log  # <-- for saving loss logs


## 🚀 Training Loop: `train_model()`

This function handles the **entire GAN training process**. It runs the model for multiple epochs, logs losses, and periodically visualizes the outputs. Here's how it works:

---

### 🔁 Parameters

- `model`: The `MainModel` instance (contains Generator + Discriminator).
- `train_dl`: PyTorch DataLoader for training data.
- `epochs`: Number of full passes through the dataset.
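- `display_every`: How often (in steps) to show progress and visualize output.
- `val_dl` (optional): Validation DataLoader. Its first batch is used for the periodic visualizations.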

---

### 🧠 Workflow Explained

#### 1. `val_data = next(iter(val_dl))`

- Takes a single batch from the validation set (only when `val_dl` is provided; otherwise nothing is visualized).
- This batch is used repeatedly for visualizing model performance during training.

#### 2. `loss_meter_dict = create_loss_meters()`

- Initializes tracking meters for all loss components (`D_fake`, `D_real`, `G_GAN`, `G_L1`, etc.).
- Tracks **average loss per epoch**.

#### 3. `model.setup_input(data)`  
- Moves the batch's `L` (grayscale input) and `ab` (color target) tensors to the model's device (the GPU when available).

#### 4. `model.optimize()`

- Calls:
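  - `model.forward()` → Generates fake color
  - `model.backward_D()` → Computes the discriminator loss and its gradients; `opt_D.step()` then updates the Discriminator
  - `model.backward_G()` → Computes the generator loss (GAN + L1) and its gradients; `opt_G.step()` then updates the Generator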

#### 5. `update_losses(...)`

- Updates all tracked loss meters using the latest loss values from the model.

#### 6. Every `display_every` steps:

- 📢 Prints progress and current epoch/iteration.
- 📊 Calls `log_results()` to print average losses.
- 🖼️ Calls `visualize()` to show:
  - Grayscale input
  - Generated color image
  - Real ground truth

---

This function is essential for:
- Repeatedly improving the model through training
- Logging the generator and discriminator performance
- Visually monitoring how well the model is learning to colorize images


In [None]:
# ✅ Install FastAI if not already installed
# !pip install fastai==2.4

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from torchvision.models.resnet import resnet18

# ✅ Create a U-Net Generator from an ImageNet-pretrained ResNet-18 encoder.
# create_body's `pretrained=True` only controls how the existing first-layer weights are
# adapted to n_in=1; the weights themselves come from the torchvision model passed in.
# A bare `resnet18()` is randomly initialized, so request the ImageNet weights explicitly
# (the same call build_res_unet uses).
encoder = create_body(resnet18(pretrained=True), n_in=1, pretrained=True)
net_G = DynamicUnet(encoder, n_out=2, img_size=(256, 256))  # ✅ use n_out instead of out_channels


## ⚙️ Building the U-Net Generator with FastAI's DynamicUnet

To simplify the creation of a powerful U-Net architecture, we use FastAI’s `DynamicUnet`, which builds a U-Net using any backbone like ResNet, EfficientNet, etc.

Here’s what each line of code does:

---

### 🔽 1. `from fastai.vision.learner import create_body`

- Extracts the **encoder part** (feature extractor) from a pretrained model like ResNet.
- Removes the classification head and keeps convolutional layers.

---

### 🧠 2. `from torchvision.models.resnet import resnet18`

- Loads the **ResNet-18 architecture** from PyTorch’s model zoo.
- We use it as the encoder (downsampling path) of our U-Net.
- Lightweight yet effective for image colorization tasks.

---

### 🔁 3. `from fastai.vision.models.unet import DynamicUnet`

- `DynamicUnet` automatically builds the **decoder path** for U-Net.
- It adds:
  - Upsampling layers
  - Skip connections from encoder
  - A final convolution layer for output

---

### 🧱 4. `create_body(resnet18(), n_in=1, pretrained=True)`

- Creates a ResNet-18 backbone.
- `n_in=1` means we’re using grayscale (L channel) images as input (1 channel).
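- `pretrained=True` tells fastai to adapt the *existing* first-layer weights when the input is changed to 1 channel. It does not download ImageNet weights itself: those come from the torchvision model you pass in, which must be created with them (e.g. `resnet18(pretrained=True)`, as in `build_res_unet` below). A bare `resnet18()` is randomly initialized.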

---

### 🎯 5. `DynamicUnet(encoder, n_out=2, img_size=(256, 256))`

- Builds a U-Net by attaching a decoder to the encoder.
- `n_out=2`: We predict the a and b channels of the LAB color space.
- `img_size=(256, 256)`: Necessary for building the correct skip connections and upsampling shapes.

---

💡 Using FastAI here saves you time and complexity while still leveraging powerful pretrained networks for colorization!


### ⚙️ Create U-Net Generator with Pretrained ResNet-18 Encoder (FastAI)

- We use FastAI’s `DynamicUnet` to create a U-Net generator with a pretrained ResNet-18 encoder.
- `create_body(resnet18(), n_in=1, pretrained=True)` builds the encoder with 1 input channel (grayscale).
- `DynamicUnet(...)` wraps the encoder into a U-Net decoder:
  - `n_out=2`: output channels (e.g., a and b in LAB color space).
  - `img_size=(256, 256)`: target input size.

> ⚠️ Make sure to use `n_out` instead of `out_channels`, as `DynamicUnet` does not accept `out_channels`.


In [None]:
import torch
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from torchvision.models.resnet import resnet18

def build_res_unet(n_input=1, n_output=2, size=256, device=None):
    """Build a U-Net generator whose encoder is an ImageNet-pretrained ResNet-18.

    Args:
        n_input: number of input channels fed to the encoder (1 for the L channel).
        n_output: number of output channels produced by the decoder (2 for a/b).
        size: side length of the square input images; DynamicUnet uses it to
            size the decoder and its skip connections.
        device: where to place the model. Defaults to CUDA when available,
            otherwise CPU (the previous, hard-wired behavior).

    Returns:
        The `DynamicUnet` module, already moved to `device`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # torchvision >= 0.13 replaces the deprecated `pretrained=True` flag with
        # an explicit weights enum; IMAGENET1K_V1 is the set that flag loaded.
        from torchvision.models import ResNet18_Weights
    except ImportError:
        backbone = resnet18(pretrained=True)
    else:
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    # Create encoder (feature extractor): n_in sets the input channel count and
    # cut=-2 drops ResNet's final pooling + fully connected layers.
    body = create_body(backbone, n_in=n_input, cut=-2)

    # Attach a decoder (upsampling path + skip connections) to make a full U-Net.
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)

    return net_G


## 🏗️ `build_res_unet`: Constructing a ResNet-based U-Net Generator

This function creates a U-Net model for image colorization using FastAI and PyTorch. It builds the generator by combining a **pretrained ResNet18 encoder** with a **decoder** generated by `DynamicUnet`.

---

### 🧠 What Each Parameter Does

- `n_input`: Number of input channels (default is 1 for grayscale **L channel**).
- `n_output`: Number of output channels (default is 2 for predicted **a and b channels** in LAB space).
- `size`: Image size (used to define output shape and skip connections).

---

### ⚙️ Step-by-Step Explanation

#### 1. `device = torch.device(...)`

- Automatically selects GPU (if available), otherwise uses CPU.

---

#### 2. `create_body(...)`

- Loads the convolutional layers from a pretrained `resnet18`.
- `n_in=n_input` lets you control the number of input channels.
  - Set to `1` for grayscale input.
- `cut=-2` removes the final two layers. For ResNet-18 these are the average pooling and the fully connected classifier, so only the convolutional feature extractor remains.

---

#### 3. `DynamicUnet(...)`

- Takes the encoder (`body`) and attaches a **decoder path** to build a complete U-Net.
- Automatically adds skip connections and upsampling layers.
- `n_output`: Sets the number of output channels (usually 2 for a, b channels).
- `(size, size)`: Required to build the decoder with the correct spatial dimensions.

---

#### 4. `.to(device)`

- Moves the model to the appropriate device (GPU or CPU).

---

### ✅ Return

- Returns a ready-to-train U-Net model (`net_G`) that takes in grayscale images and predicts color components.

---

💡 This function makes your generator modular and reusable — perfect for training GANs or using separately for pretraining!


In [None]:
# Standard library
import json
import os

# Third-party
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# NOTE: the Colab-only `files` helper is intentionally not imported here.

# 📂 Output folder for every pretraining artifact (left untouched if it exists)
save_path = "/kaggle/working/saved_model"
os.makedirs(save_path, exist_ok=True)

# ✅ AverageMeter to track loss
class AverageMeter:
    """Track the most recent value plus a count-weighted running mean.

    `update(val, n)` stores `val` as the latest value and folds it into the
    mean as though it had been observed `n` times (e.g. a batch-mean loss
    weighted by the batch size).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# ✅ Training function with saving and logging
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Pretrain the generator with a pixel-wise loss before adversarial training.

    All artifacts are written under the module-level `save_path` directory:
    a checkpoint every 5 epochs, the final generator and optimizer state dicts,
    the per-epoch loss history, a small config file and the number of epochs.

    Args:
        net_G: generator mapping the L channel to predicted ab channels.
        train_dl: iterable of dict batches with 'L' and 'ab' tensors.
        opt: optimizer over `net_G`'s parameters.
        criterion: loss comparing predictions with the ground-truth ab channels.
        epochs: number of full passes over `train_dl`.

    Returns:
        List with the sample-weighted mean training loss of each epoch.
    """
    # Send batches to the device the generator's weights live on, rather than
    # relying on a `device` global defined in another cell.
    model_device = next(net_G.parameters()).device
    loss_name = type(criterion).__name__
    net_G.train()
    loss_history = []

    for epoch in range(epochs):
        loss_meter = AverageMeter()

        for data in tqdm(train_dl):
            L = data['L'].to(model_device)
            ab = data['ab'].to(model_device)

            preds = net_G(L)
            loss = criterion(preds, ab)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Weight by batch size so the epoch mean is per-sample, not per-batch.
            loss_meter.update(loss.item(), L.size(0))

        avg_loss = loss_meter.avg
        loss_history.append(avg_loss)
        print(f"Epoch {epoch + 1}/{epochs} - {loss_name}: {avg_loss:.5f}")

        # ✅ Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(net_G.state_dict(), f"{save_path}/checkpoint_epoch_{epoch+1}.pth")

    # ✅ Save final model weights (+ optimizer state so training can be resumed)
    torch.save(net_G.state_dict(), f"{save_path}/res18-unet.pt")
    torch.save(opt.state_dict(), f"{save_path}/opt_res18-unet.pth")

    # ✅ Save loss history
    with open(f"{save_path}/loss_log.json", "w") as f:
        json.dump(loss_history, f)

    # ✅ Save config, with training fields taken from the objects actually used
    # NOTE(review): the shape fields assume build_res_unet(1, 2, 256).
    config = {
        "architecture": "ResNet18-UNet",
        "input_channels": 1,
        "output_channels": 2,
        "img_size": [256, 256],
        "loss": loss_name,
        "optimizer": type(opt).__name__,
        "learning_rate": opt.param_groups[0]["lr"],
        "epochs_trained": epochs
    }
    with open(f"{save_path}/config.json", "w") as f:
        json.dump(config, f)

    with open(f"{save_path}/last_epoch.txt", "w") as f:
        f.write(str(epochs))

    # google.colab's `files.download` is not imported in this environment, so
    # the artifacts simply remain in `save_path`.
    print(f"✅ All files saved to {save_path}")

    return loss_history


## 🧪 Pretraining the Generator with L1 Loss (Before GAN Training)

Before using the generator (`net_G`) in a GAN setup, we pretrain it using **only L1 loss** (also called pixel-wise loss). This helps the model learn **basic colorization** before introducing adversarial training with the discriminator.

---

### 🎯 Objective

- Minimize the **difference between predicted (fake) and true (real)** ab channels.
- Use **L1 loss** for better sharpness and stable convergence.
- Prevents the GAN from starting with a randomly initialized generator.

---

### 🧠 Function: `pretrain_generator(...)`

#### Inputs:
- `net_G`: The U-Net generator (ResNet18 + decoder).
- `train_dl`: Dataloader for training images.
- `opt`: Optimizer (Adam in this case).
- `criterion`: Loss function (L1).
- `epochs`: How many full passes over the dataset.

#### Steps:
1. Loop over each epoch.
2. For each batch:
   - Move L and ab channels to GPU/CPU.
   - Get predictions from the generator.
   - Compute L1 loss: `loss = criterion(preds, ab)`
   - Backpropagate and update weights using Adam.
   - Update the running average loss using `AverageMeter`.

#### Output:
Prints the average L1 loss at the end of each epoch.

---

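### 💾 Saving the Pretrained Model

At the end of pretraining, the generator weights are written into `save_path`:

```python
torch.save(net_G.state_dict(), f"{save_path}/res18-unet.pt")  # -> /kaggle/working/saved_model/res18-unet.pt
```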


### 🧠 Build and Load the Generator (U-Net based on ResNet-18)

- `build_res_unet(...)` creates a U-Net model using ResNet-18 as the encoder backbone.
  - `n_input=1`: grayscale image input (L channel in LAB color space).
  - `n_output=2`: outputs two channels (a and b from LAB color space).
  - `size=256`: input image size is 256x256.

- The pretrained weights are loaded from `res18-unet.pt` using `load_state_dict(...)`.
- `map_location=device` ensures the model loads correctly on the current hardware (CPU or GPU).


### 🧩 Wrap Generator into MainModel

- `MainModel` is a custom training wrapper (likely includes loss functions, optimizers, etc.).
- We pass the prebuilt U-Net generator (`net_G`) to it.
- This abstraction helps us train, validate, and test using one interface.


### 🎯 Train the Model for 20 Epochs

- `train_model(...)` begins training the model.
  - `model`: instance of `MainModel` containing the generator.
  - `train_dl`: the DataLoader providing batches of training data.
  - `20`: the number of training epochs.
- This step will optimize the model weights to colorize grayscale images.


### 💾 Save Pretrained ResNet-UNet Generator (Optional)

- This step saves the pretrained generator weights to a file (`res18-unet.pt`) for reuse later.
- It also optionally saves the optimizer state (`opt_res18-unet.pth`), which is helpful for resuming training.
- Use this step if you’re training the generator separately before integrating it into the full GAN model.


### 🧠 Save Final Trained Generator and Discriminator

- After full GAN training, save the final state of both models:
  - `final_generator.pth` — trained U-Net generator.
  - `final_discriminator.pth` — trained PatchGAN discriminator.
- This lets us use the trained models later for inference or fine-tuning.


### 💡 Save Full MainModel (Convenience)

- Saves the full `MainModel` wrapper's state dict (`final_model.pth`), which includes:
  - Generator
  - Discriminator
  - Loss functions
  - Any additional internal state
- Convenient when restoring the full training setup directly without separately loading G & D.


In [None]:
# Pretraining hyperparameters (L1 loss + Adam)
PRETRAIN_LR = 1e-4
PRETRAIN_EPOCHS = 20

net_G = build_res_unet(n_input=1, n_output=2, size=256).to(device)
criterion = nn.L1Loss()
opt = optim.Adam(net_G.parameters(), lr=PRETRAIN_LR)
pretrain_loss_history = pretrain_generator(net_G, train_dl, opt, criterion, epochs=PRETRAIN_EPOCHS)


In [None]:
import json
import os

import torch
from tqdm import tqdm

# NOTE: the Colab-only `files` helper is not imported; artifacts stay on disk.

# 📂 Save folder
gan_save_path = "/kaggle/working/saved_gan"
os.makedirs(gan_save_path, exist_ok=True)

# Number of adversarial epochs; reused for training, the config and the epoch
# marker so the three values cannot disagree.
GAN_EPOCHS = 20

# ✅ Load pretrained generator (weights written by the L1 pretraining step)
net_G = build_res_unet(n_input=1, n_output=2, size=256).to(device)
net_G.load_state_dict(torch.load("/kaggle/working/saved_model/res18-unet.pt", map_location=device))

# ✅ Build MainModel and train
model = MainModel(net_G=net_G)
gan_loss_log = train_model(model, train_dl, epochs=GAN_EPOCHS, val_dl=val_dl)

# ✅ Save final GAN weights
torch.save(model.net_G.state_dict(), f"{gan_save_path}/final_generator.pth")
torch.save(model.net_D.state_dict(), f"{gan_save_path}/final_discriminator.pth")
torch.save(model.state_dict(), f"{gan_save_path}/final_model.pth")  # optional: whole wrapper

# ✅ Save GAN loss history (whatever `train_model` returned; must be JSON-serializable)
with open(f"{gan_save_path}/gan_loss_log.json", "w") as f:
    json.dump(gan_loss_log, f)

# ✅ Save GAN config
gan_config = {
    "type": "Conditional GAN",
    "epochs_trained": GAN_EPOCHS,
    "losses": ["GANLoss", "L1Loss"],
    "generator": "ResNet18-UNet",
    "discriminator": "PatchGAN"
}
with open(f"{gan_save_path}/gan_config.json", "w") as f:
    json.dump(gan_config, f)

# ✅ Save epoch info
with open(f"{gan_save_path}/gan_last_epoch.txt", "w") as f:
    f.write(str(GAN_EPOCHS))

# 📥 `files.download` (google.colab) is unavailable here, so list what was
# written instead; the files remain in `gan_save_path`.
print(f"✅ GAN artifacts saved to {gan_save_path}:")
for artifact in sorted(os.listdir(gan_save_path)):
    print(f"   - {artifact}")


In [None]:
# Both the pretraining and GAN artifact folders live under this root.
output_root = "/kaggle/working/"
print(f"✅ All model weights saved to {output_root}")
