# CNN Model

This notebook builds and evaluates CNN models for image colorization: given the lightness (L) channel of an image, each model predicts the two chrominance (ab) channels of the Lab color space.

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.dataset import CocoDataset
from utils.models import BaselineCNN, save_model, load_model
from utils.plots import plot_l, plot_a, plot_b, plot_rgb, reconstruct_lab, plot_predicted_image, plot_ab

In [2]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

## Data import

In [3]:
# Preprocessing applied to every image: resize to 64x64, then convert to a float tensor
# laid out as C x H x W with values scaled to [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)

In [5]:
# Root folder of the COCO train2014 images. Point the COCO_ROOT environment variable at your
# local copy (it may be an absolute path on your machine); otherwise the repository-relative
# path below is used. This keeps contributor-specific absolute paths out of the shared notebook.
COCO_ROOT = os.environ.get("COCO_ROOT", "coco/images/train2014")
dataset = CocoDataset(root=COCO_ROOT, transform=transform)

Found 82783 images.


## Data preparation

We split the dataset into a training set (80% of the images) and a test set (20%).

In [8]:
# Hold out 20% of the images (rounded down) for testing; the remaining images are for training.
num_images = len(dataset)
test_size = int(0.2 * num_images)
train_size = num_images - test_size

In [9]:
torch.manual_seed(42)  # seeds the global RNG, which drives the shuffling of train_loader
# A dedicated, seeded generator makes the train/test partition itself reproducible.
train, test = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train, batch_size=64, shuffle=True)
# No shuffling for evaluation: order is irrelevant for scoring, and a fixed order keeps
# batch-level metrics (and anything computed per batch) identical across runs.
test_loader = DataLoader(test, batch_size=64, shuffle=False)

## CNN Model

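We now build a network of 9 convolutional layers, each followed by a ReLU activation and a BatchNorm layer. It is trained with Adam and a cosine-annealing learning-rate schedule, which decays the learning rate from 0.01 towards zero over 100 epochs.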

### Training

In [None]:
class CNN(nn.Module):
    """Fully convolutional colorization network: one L channel in, two ab channels out.

    Nine 5x5 convolutions (stride 1, padding 2, so the spatial size is preserved), each
    followed by ReLU and BatchNorm, widen the features 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 64
    -> 128 -> 128 -> 128. A 1x1 convolution then projects the 128 feature channels to the
    2 output channels at every pixel.

    The previous head flattened the 128x64x64 feature map into
    ``nn.Linear(128 * 64 * 64, 2 * 64 * 64)``. That is ~4.3 billion parameters (~17 GB of
    float32 weights before gradients and Adam state), impractical to train, and it only
    worked for 64x64 inputs. The 1x1 convolution head has 258 parameters and works for any
    input size.

    NOTE: checkpoints saved with the old Linear head cannot be loaded into this class.

    Input:  (batch, 1, H, W)
    Output: (batch, 2, H, W)
    """

    # Channel widths of the 9 convolutional stages, input channel first.
    CHANNELS = (1, 4, 8, 16, 32, 64, 64, 128, 128, 128)

    def __init__(self):
        super(CNN, self).__init__()
        layers = []
        for in_ch, out_ch in zip(self.CHANNELS[:-1], self.CHANNELS[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=1, padding=2),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
            ]
        # Same module order as the hand-written version, so `features.*` state-dict keys match.
        self.features = nn.Sequential(*layers)
        # Per-pixel linear projection from the last feature width to the 2 ab channels.
        self.head = nn.Conv2d(self.CHANNELS[-1], 2, kernel_size=1)

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

In [None]:
model = CNN().to(device)

# Loss functions must be *instances*. Without the parentheses, `nn.L1Loss` is the class, and
# calling it with (outputs, targets) would try to build a new loss module from the tensors
# instead of computing a loss.
criterion1 = nn.MSELoss()
criterion2 = nn.L1Loss()
criterion3 = nn.SmoothL1Loss()

optimizer = optim.Adam(model.parameters(), lr=0.01)

# Cosine-anneal the learning rate over 100 scheduler steps (the training cells step once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

### MSE Loss

In [None]:
NUM_EPOCHS = 100
train_losses = []
test_losses = []

# Start from a fresh model, optimizer and LR schedule so this cell is idempotent: re-running it
# retrains from scratch instead of continuing from whatever state `model` was left in.
model = CNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Cosine-anneal the learning rate over exactly NUM_EPOCHS epochs (one scheduler step per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    test_loss = 0.0

    # Training loop
    model.train()
    # NOTE(review): batches are unpacked as (rgb, L, a, b, ab), matching the evaluation cells and
    # the "Eval" section (img[0] is plotted as RGB, img[1] is fed to the model).
    # TODO confirm against CocoDataset.__getitem__.
    for _, l_channels, _, _, ab_channels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', leave=True):
        l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)

        optimizer.zero_grad()
        outputs = model(l_channels)
        loss = criterion(outputs, ab_channels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()  # one LR-schedule step per epoch
    train_losses.append(running_loss / len(train_loader))

    # Testing loop
    model.eval()
    with torch.no_grad():
        for _, l_channels, _, _, ab_channels in tqdm(test_loader, desc='Testing', leave=True):
            l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)
            outputs = model(l_channels)
            loss = criterion(outputs, ab_channels)
            test_loss += loss.item()

    test_losses.append(test_loss / len(test_loader))

    # Print training and validation losses
    print(f"Epoch {epoch + 1}, Train Loss: {train_losses[-1]}, Validation Loss: {test_losses[-1]}")

print('Finished Training')

In [10]:
# Learning curves for the MSE-trained model (mean loss per epoch on each split).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set(title='CNN trained with MSE loss', xlabel='Epoch', ylabel='MSE loss')
ax.grid(linestyle="--")
ax.legend()
plt.show()

NameError: name 'train_losses' is not defined

<Figure size 1000x600 with 0 Axes>

In [None]:
# Qualitative check: colorize one fixed image with the freshly trained model.
# NOTE(review): 76543 indexes the full dataset, so this image may belong to the training split.
sample_image = dataset[76543]
plot_predicted_image(model, sample_image)

In [None]:
# Save the MSE-trained model as "CNN_MSELoss" (reloaded below from models/CNN_MSELoss.pth).
checkpoint_name = "CNN_MSELoss"
save_model(model, checkpoint_name)

In [None]:
# Rebuild the architecture, restore the MSE-trained weights, then move the model to `device`.
loaded_model1 = load_model(CNN(), "models/CNN_MSELoss.pth")
loaded_model1 = loaded_model1.to(device)

In [None]:
# Sanity check: the restored weights should live on `device`.
first_param = next(loaded_model1.parameters())
first_param.device

In [None]:
loaded_model1.eval()

# Peak signal value used by PSNR.
# NOTE(review): assumes the ab targets are scaled to [0, 1]. If CocoDataset returns raw Lab
# a/b values, use their actual range here instead. TODO confirm.
max_pixel_value = 1.0

total_mse = 0.0
total_psnr = 0.0
num_samples = 0

with torch.no_grad():
    for _, l_channels, _, _, ab_channels in test_loader:
        l_channels = l_channels.to(device)
        ab_channels = ab_channels.to(device)

        ab_preds = loaded_model1(l_channels)

        # Per-image MSE, so the averages do not depend on how images are grouped into batches.
        per_image_mse = F.mse_loss(ab_preds, ab_channels, reduction="none").flatten(1).mean(dim=1)
        # PSNR = 10 * log10(MAX^2 / MSE). The previous 20 * log10(MAX^2 / MSE) doubled the dB value.
        # The clamp avoids an infinite PSNR for a pixel-perfect prediction.
        per_image_psnr = 10 * torch.log10(max_pixel_value ** 2 / per_image_mse.clamp_min(1e-10))

        total_mse += per_image_mse.sum().item()
        total_psnr += per_image_psnr.sum().item()
        num_samples += ab_channels.size(0)

avg_mse = total_mse / num_samples
avg_psnr = total_psnr / num_samples

print(f"Average MSE: {avg_mse:.4f}")
print(f"Average PSNR: {avg_psnr:.4f}")

### L1 Loss

In [None]:
NUM_EPOCHS = 100
train_losses = []
test_losses = []

# Fresh model, optimizer and LR schedule. Previously this cell kept training the MSE-trained
# `model` with the already-used scheduler, which made the L1-vs-MSE comparison unfair and the
# cell non-idempotent.
model = CNN().to(device)
criterion = nn.L1Loss()  # an instance: the bare class `nn.L1Loss` cannot compute a loss
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Cosine-anneal the learning rate over exactly NUM_EPOCHS epochs (one scheduler step per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    test_loss = 0.0

    # Training loop
    model.train()
    # NOTE(review): batches are unpacked as (rgb, L, a, b, ab), matching the evaluation cells and
    # the "Eval" section (img[0] is plotted as RGB, img[1] is fed to the model).
    # TODO confirm against CocoDataset.__getitem__.
    for _, l_channels, _, _, ab_channels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', leave=True):
        l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)

        optimizer.zero_grad()
        outputs = model(l_channels)
        loss = criterion(outputs, ab_channels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()  # one LR-schedule step per epoch
    train_losses.append(running_loss / len(train_loader))

    # Testing loop
    model.eval()
    with torch.no_grad():
        for _, l_channels, _, _, ab_channels in tqdm(test_loader, desc='Testing', leave=True):
            l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)
            outputs = model(l_channels)
            loss = criterion(outputs, ab_channels)
            test_loss += loss.item()

    test_losses.append(test_loss / len(test_loader))

    # Print training and validation losses
    print(f"Epoch {epoch + 1}, Train Loss: {train_losses[-1]}, Validation Loss: {test_losses[-1]}")

print('Finished Training')

In [None]:
# Learning curves for the L1-trained model (mean loss per epoch on each split).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set(title='CNN trained with L1 loss', xlabel='Epoch', ylabel='L1 loss')
ax.grid(linestyle="--")
ax.legend()
plt.show()

In [None]:
# Qualitative check: colorize one fixed image with the freshly trained model.
# NOTE(review): 76543 indexes the full dataset, so this image may belong to the training split.
sample_image = dataset[76543]
plot_predicted_image(model, sample_image)

In [None]:
# Save the L1-trained model under the checkpoint name "CNN_L1Loss".
checkpoint_name = "CNN_L1Loss"
save_model(model, checkpoint_name)

In [None]:
# Reload the L1-trained weights. The file name must match the name given to save_model
# ("CNN_L1Loss"), following the "models/<name>.pth" pattern used for the other two models;
# the previous "models/L1Loss.pth" pointed at a file that is never written.
loaded_model2 = load_model(CNN(), "models/CNN_L1Loss.pth").to(device)

In [None]:
# Sanity check: the restored weights should live on `device`.
first_param = next(loaded_model2.parameters())
first_param.device

In [None]:
loaded_model2.eval()

# Peak signal value used by PSNR.
# NOTE(review): assumes the ab targets are scaled to [0, 1]. If CocoDataset returns raw Lab
# a/b values, use their actual range here instead. TODO confirm.
max_pixel_value = 1.0

total_mse = 0.0
total_psnr = 0.0
num_samples = 0

with torch.no_grad():
    for _, l_channels, _, _, ab_channels in test_loader:
        l_channels = l_channels.to(device)
        ab_channels = ab_channels.to(device)

        ab_preds = loaded_model2(l_channels)

        # Per-image MSE, so the averages do not depend on how images are grouped into batches.
        per_image_mse = F.mse_loss(ab_preds, ab_channels, reduction="none").flatten(1).mean(dim=1)
        # PSNR = 10 * log10(MAX^2 / MSE). The previous 20 * log10(MAX^2 / MSE) doubled the dB value.
        # The clamp avoids an infinite PSNR for a pixel-perfect prediction.
        per_image_psnr = 10 * torch.log10(max_pixel_value ** 2 / per_image_mse.clamp_min(1e-10))

        total_mse += per_image_mse.sum().item()
        total_psnr += per_image_psnr.sum().item()
        num_samples += ab_channels.size(0)

avg_mse = total_mse / num_samples
avg_psnr = total_psnr / num_samples

print(f"Average MSE: {avg_mse:.4f}")
print(f"Average PSNR: {avg_psnr:.4f}")

### Smooth L1 Loss

In [None]:
NUM_EPOCHS = 100
train_losses = []
test_losses = []

# Fresh model, optimizer and LR schedule. Previously this cell kept training the model left over
# from the earlier runs with the already-used scheduler, which made the loss comparison unfair
# and the cell non-idempotent.
model = CNN().to(device)
criterion = nn.SmoothL1Loss()  # an instance: the bare class `nn.SmoothL1Loss` cannot compute a loss
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Cosine-anneal the learning rate over exactly NUM_EPOCHS epochs (one scheduler step per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    test_loss = 0.0

    # Training loop
    model.train()
    # NOTE(review): batches are unpacked as (rgb, L, a, b, ab), matching the evaluation cells and
    # the "Eval" section (img[0] is plotted as RGB, img[1] is fed to the model).
    # TODO confirm against CocoDataset.__getitem__.
    for _, l_channels, _, _, ab_channels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', leave=True):
        l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)

        optimizer.zero_grad()
        outputs = model(l_channels)
        loss = criterion(outputs, ab_channels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()  # one LR-schedule step per epoch
    train_losses.append(running_loss / len(train_loader))

    # Testing loop
    model.eval()
    with torch.no_grad():
        for _, l_channels, _, _, ab_channels in tqdm(test_loader, desc='Testing', leave=True):
            l_channels, ab_channels = l_channels.to(device), ab_channels.to(device)
            outputs = model(l_channels)
            loss = criterion(outputs, ab_channels)
            test_loss += loss.item()

    test_losses.append(test_loss / len(test_loader))

    # Print training and validation losses
    print(f"Epoch {epoch + 1}, Train Loss: {train_losses[-1]}, Validation Loss: {test_losses[-1]}")

print('Finished Training')

In [None]:
# Learning curves for the Smooth-L1-trained model (mean loss per epoch on each split).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='train')
ax.plot(test_losses, label='test')
ax.set(title='CNN trained with Smooth L1 loss', xlabel='Epoch', ylabel='Smooth L1 loss')
ax.grid(linestyle="--")
ax.legend()
plt.show()

In [None]:
# Qualitative check: colorize one fixed image with the freshly trained model.
# NOTE(review): 76543 indexes the full dataset, so this image may belong to the training split.
sample_image = dataset[76543]
plot_predicted_image(model, sample_image)

In [None]:
# Save the Smooth-L1-trained model as "CNN_SmoothL1" (reloaded below from models/CNN_SmoothL1.pth).
checkpoint_name = "CNN_SmoothL1"
save_model(model, checkpoint_name)

In [None]:
# Rebuild the architecture, restore the Smooth-L1-trained weights, then move the model to `device`.
loaded_model3 = load_model(CNN(), "models/CNN_SmoothL1.pth")
loaded_model3 = loaded_model3.to(device)

In [None]:
# Sanity check: the restored weights should live on `device`.
first_param = next(loaded_model3.parameters())
first_param.device

In [None]:
loaded_model3.eval()

# Peak signal value used by PSNR.
# NOTE(review): assumes the ab targets are scaled to [0, 1]. If CocoDataset returns raw Lab
# a/b values, use their actual range here instead. TODO confirm.
max_pixel_value = 1.0

total_mse = 0.0
total_psnr = 0.0
num_samples = 0

with torch.no_grad():
    for _, l_channels, _, _, ab_channels in test_loader:
        l_channels = l_channels.to(device)
        ab_channels = ab_channels.to(device)

        ab_preds = loaded_model3(l_channels)

        # Per-image MSE, so the averages do not depend on how images are grouped into batches.
        per_image_mse = F.mse_loss(ab_preds, ab_channels, reduction="none").flatten(1).mean(dim=1)
        # PSNR = 10 * log10(MAX^2 / MSE). The previous 20 * log10(MAX^2 / MSE) doubled the dB value.
        # The clamp avoids an infinite PSNR for a pixel-perfect prediction.
        per_image_psnr = 10 * torch.log10(max_pixel_value ** 2 / per_image_mse.clamp_min(1e-10))

        total_mse += per_image_mse.sum().item()
        total_psnr += per_image_psnr.sum().item()
        num_samples += ab_channels.size(0)

avg_mse = total_mse / num_samples
avg_psnr = total_psnr / num_samples

print(f"Average MSE: {avg_mse:.4f}")
print(f"Average PSNR: {avg_psnr:.4f}")

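## Eval

Side-by-side comparison of the colorizations produced by the three trained models (MSE, L1 and Smooth L1 loss) for the same image.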

# Compare the three reloaded models on the same image.
# NOTE(review): assumes dataset items are (rgb, L, a, b, ab), as in the evaluation loops. TODO confirm.
img = dataset[76543]  # `img` was previously never defined
l_single = img[1].cpu()
# BatchNorm2d and the network's layers need a leading batch dimension: (1, 1, H, W).
l_batch = img[1].unsqueeze(0).to(device)

ab_preds = []
with torch.no_grad():
    for loaded_model in (loaded_model1, loaded_model2, loaded_model3):
        loaded_model.eval()  # use BatchNorm running statistics, not single-image batch statistics
        # Drop the batch dimension again so the prediction has the per-sample ab shape.
        # TODO confirm reconstruct_lab expects an unbatched (2, H, W) tensor.
        ab_preds.append(loaded_model(l_batch)[0].cpu())

plot_rgb(img[0])
reconstruct_lab(l_single, ab_preds[0])
reconstruct_lab(l_single, ab_preds[1])
reconstruct_lab(l_single, ab_preds[2])