In [49]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from skimage import io, color
from skimage.color import rgb2lab

In [50]:
# transforming RGB into L*a*b* color space: we use the L channel as input and the a, b channels as target output
class colorize_dataset(Dataset):
    def __init__(self, transforms=None):
        self.transforms = transforms
        
    def __getitem__(self, index):
        img_rgb = Image.open('IMG_3965.JPG').convert("RGB") #self.paths[index]
        
        img_rgb = self.transforms(img_rgb)
        img_rgb = np.asarray(img_rgb)
        
        img_lab = rgb2lab(img_rgb)
        
        L = img_lab[:, :, 0] / 100.0  # Normalize L channel to [0,1]
        ab = (img_lab[:, :, 1:3] + 128) / 255.0  # Normalize a and b to [0,1]

        L = torch.from_numpy(L).unsqueeze(0).float()         # shape: 1 x H x W
        ab = torch.from_numpy(ab).permute(2, 0, 1).float()   # shape: 2 x H x W

        return {'L': L, 'ab': ab}

In [51]:
#While the model works on any size image, we trained it on 224x224 pixel images and thus it works best on small images. 
#Note that you can process a small imageto obtain the chrominance map and then rescale it and combine it with the original 
#grayscale image for higher quality.
#Larger image sizes can give uneven colorings (limited by spatial support of the network).

In [65]:
# loading image paths
import os
import glob #glob finds all the file paths that match a specified pattern
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image

#train_path = "\C:\Users\batpot\Desktop\RI\datasets\coco\train\/"
#test_path = "/RI/datasets/coco/test/"
#val_path = "/RI/datasets/coco/val/"

#def get_images_paths(directory):
#    print(os.path.join(directory, "*.jpg"))
#    return glob(os.path.join(directory, "*.jpg"))

train_paths = glob.glob(r"C:\Users\batpot\Desktop\RI\datasets\coco\train\*.jpg")
val_paths = glob.glob(r"C:\Users\batpot\Desktop\RI\datasets\coco\val\*.jpg")

#print(train_paths[:5])
    
#train_paths = get_images_paths(train_path)
#val_paths = get_images_paths(val_path)
print(len(train_paths), len(val_paths))

#_, axes = plt.subplots(4, 4, figsize=(10, 10))
#for ax, img_path in zip(axes.flatten(), train_paths):
#    ax.imshow(Image.open(img_path))
#    ax.axis("off")

37 37


In [66]:
# preprocess images
import PIL
import PIL.Image as PILI
from skimage.color import rgb2lab, lab2rgb

def preprocess_image(image_path, size=(224, 224)):
    img_path = Path(image_path)
    #print(img_path.exists())
    img = PIL.Image.open(img_path).convert('RGB')
      
    img = img.resize(size, Image.BICUBIC)

    img_rgb = np.asarray(img)
    img_lab = rgb2lab(img)

    L = img_lab[:, :, 0] / 100.0  # Normalize L channel to [0,1]
    ab = (img_lab[:, :, 1:3] + 128) / 255.0  # Normalize a and b to [0,1]

    L = torch.from_numpy(L).unsqueeze(0).float()         # shape: 1 x H x W
    ab = torch.from_numpy(ab).permute(2, 0, 1).float()   # shape: 2 x H x W

    #print(img)
    #print(L)
    #print(ab)
    return L, ab

In [67]:
import numpy as np
def load_dataset(image_paths, max_images=None):
    L_list, ab_list = [], []
    for i, path in enumerate(image_paths):
        if max_images and i >= max_images:
            break
        L, ab = preprocess_image(path)
        L_list.append(L)
        ab_list.append(ab)
    return np.array(L_list), np.array(ab_list)

In [68]:
X_train, Y_train = load_dataset(train_paths, max_images=10000)
X_val, Y_val = load_dataset(val_paths, max_images=2000)

In [69]:
X_train, Y_train

(array([[[[0.51636994, 0.5215943 , 0.5515568 , ..., 0.54420596,
           0.5176947 , 0.49350175],
          [0.46812505, 0.4678316 , 0.48585814, ..., 0.59008056,
           0.5981047 , 0.61107075],
          [0.40980378, 0.4571173 , 0.48410675, ..., 0.5247138 ,
           0.5272336 , 0.59627753],
          ...,
          [0.72327995, 0.47078276, 0.46527138, ..., 0.62756383,
           0.6200157 , 0.59252214],
          [0.63378745, 0.5806199 , 0.56597507, ..., 0.78682524,
           0.6598563 , 0.6431522 ],
          [0.55554867, 0.54866076, 0.5954175 , ..., 0.7843784 ,
           0.697508  , 0.7629302 ]]],
 
 
        [[[0.11628729, 0.12981842, 0.13115443, ..., 0.20624226,
           0.20648056, 0.20401245],
          [0.12279387, 0.12208544, 0.12178253, ..., 0.20929235,
           0.19975048, 0.19667488],
          [0.11900228, 0.10458957, 0.10171323, ..., 0.21726531,
           0.21089233, 0.20205145],
          ...,
          [0.00313747, 0.00274173, 0.00274173, ..., 0.3023552 ,
