In [1]:
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import os
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader, Dataset

In [2]:
# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet34 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 206MB/s]


In [3]:
def build_res_unet(n_input=1, n_output=2, size=256):
    body = create_body(resnet34, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size))
    return net_G

In [4]:
IMG_DIM = 256
model_path = hf_hub_download(repo_id="dhairya-1105/image-colorization", filename="net_G_epoch_20.pth")
net_G = build_res_unet(n_input=1, n_output=2, size=IMG_DIM)
net_G.load_state_dict(torch.load(model_path))  # Adjust path
net_G.eval().to(device)

net_G_epoch_20.pth:   0%|          | 0.00/165M [00:00<?, ?B/s]

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [5]:
# Preprocessing
transform = transforms.Compose([
    transforms.Resize((IMG_DIM, IMG_DIM)),
])

In [6]:
# Dataset
class CocoSubset(Dataset):
    def __init__(self, img_dir, transform=None, limit=1000):
        self.paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')]
        self.paths = random.sample(self.paths, min(limit, len(self.paths)))
        self.transform = transform

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transform(img)
        lab = rgb2lab(np.array(img).astype(np.float32) / 255.0)
        L = (lab[:, :, 0] / 50.0 - 1.0)[np.newaxis, ...]
        ab = (lab[:, :, 1:] / 128.0).transpose(2, 0, 1)
        return torch.tensor(L).float(), torch.tensor(ab).float(), lab

# Dataloader
dataset = CocoSubset('/kaggle/input/coco25k/images', transform=transform, limit=5000)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [7]:
# Evaluation
psnr_list, ssim_list, delta_e_list = [], [], []

for L, ab_gt, lab_gt in tqdm(dataloader):
    L = L.to(device)
    with torch.no_grad():
        ab_pred = net_G(L).cpu()

    # Denormalize
    L_denorm = (L.cpu().numpy()[0, 0] + 1.0) * 50.0
    ab_gt_denorm = ab_gt.numpy()[0].transpose(1, 2, 0) * 128.0
    ab_pred_denorm = ab_pred.numpy()[0].transpose(1, 2, 0) * 128.0

    lab_pred = np.concatenate([L_denorm[..., None], ab_pred_denorm], axis=2)
    lab_true = lab_gt[0].numpy()

    rgb_pred = lab2rgb(lab_pred.clip(0, 100))
    rgb_true = lab2rgb(lab_true.clip(0, 100))

    psnr_list.append(psnr(rgb_true, rgb_pred, data_range=1.0))
    ssim_list.append(ssim(rgb_true, rgb_pred, channel_axis=2, data_range=1.0))
    delta_e_list.append(np.mean(np.linalg.norm(lab_true - lab_pred, axis=2)))


  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
100%|██████████| 5000/5000 [06:12<00:00, 13.44it/s]


In [8]:
# Results
print({
    "PSNR": np.mean(psnr_list),
    "SSIM": np.mean(ssim_list),
    "DeltaE": np.mean(delta_e_list)
})

{'PSNR': 26.59493642863726, 'SSIM': 0.92430437, 'DeltaE': 12.664919}
