In [None]:
import cv2
import numpy as np
import torch
from skimage.color import lab2rgb, rgb2gray
import torchvision.transforms as T
from basic_model import ColorizationNet
from PIL import Image
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Hard-coded inputs: the model checkpoint and the image to colorize.
    # Forward slashes are accepted by Python on every OS (Windows included),
    # unlike the previous '.\\' form, which only worked on Windows.
    model_checkpoint, gray_image = './vae-best-model.pth', './Places365_val_00000091.jpg'
    # Spatial size fed to the network; presumably the training resolution — TODO confirm.
    input_size = (224, 224)

    # Choose the execution device once (GPU if available, else CPU) and move the model there.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load model from basic_model.py by calling the class constructor
    model = ColorizationNet().to(device)

    # Load checkpoint. map_location remaps tensors saved on a GPU so the
    # checkpoint also loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    checkpoint = torch.load(model_checkpoint, map_location=device)
    # strict=False tolerates key mismatches, but would also silently accept a
    # checkpoint that matches nothing (leaving the model untrained), so report
    # whatever was skipped instead of hiding it.
    incompatible = model.load_state_dict(checkpoint, strict=False)
    if incompatible.missing_keys:
        print(f'Warning: {len(incompatible.missing_keys)} model keys not found in checkpoint: '
              f'{incompatible.missing_keys}')
    if incompatible.unexpected_keys:
        print(f'Warning: {len(incompatible.unexpected_keys)} checkpoint keys not used by model: '
              f'{incompatible.unexpected_keys}')

    # Open the image file once and detach it from the file handle; every other
    # view below is derived from this in-memory copy.
    with Image.open(gray_image) as img_file:
        o = img_file.copy()            # original, shown by the plotting cell
    x = o.convert("L")                 # 8-bit luminance, for display only
    arr = np.asarray(x)
    # Force 3-channel RGB so rgb2gray receives an (H, W, 3) array even when the
    # file is grayscale ('L'), palette ('P') or RGBA; np.array() on those modes
    # would yield a 2-D, index-valued or 4-channel array respectively.
    image = o.convert("RGB")
    gray = rgb2gray(np.array(image))   # float luminance in [0, 1], shape (H, W)
    gray = T.ToTensor()(gray).float()  # (1, H, W); float input is not rescaled by ToTensor
    gray = T.Resize(size=input_size)(gray)
    gray = gray.unsqueeze(0)           # add batch dimension -> (1, 1, *input_size)

    # Evaluate the image from the model
    model.eval()

    with torch.no_grad():
        preds = model(gray.to(device))

    # Assumes preds is (1, 2, *input_size) holding the predicted ab channels — TODO confirm.
    ab_output = preds[0].cpu()
    # Stack L (1 channel) with ab (2 channels) and move channels last for skimage.
    color_image = torch.cat((gray[0], ab_output), 0).numpy()
    color_image = color_image.transpose((1, 2, 0))
    # Undo the normalisation: L [0, 1] -> [0, 100]; the *255-128 mapping implies
    # ab was normalised to [0, 1] in training — verify against the training code.
    color_image[:, :, 0] = color_image[:, :, 0] * 100
    color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
    # NOTE: `final` is input_size, whereas `arr` and `o` keep the file's original size.
    final = lab2rgb(color_image)

In [None]:
# Side-by-side comparison: grayscale input, model output, original file.
# A wide figure suits the 1x3 grid (a square figsize wastes vertical space).
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    ('Grayscale Image', arr, dict(cmap='gray', vmin=0, vmax=255)),
    ('Colorized Image', final, {}),
    # convert('RGB') so a grayscale or palette original is displayed faithfully
    # instead of being false-coloured by matplotlib's default colormap.
    ('Original Image', o.convert('RGB'), {}),
]
for ax, (title, panel_img, imshow_kwargs) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(panel_img, **imshow_kwargs)
    ax.axis('off')  # pixel-index ticks carry no information here
fig.tight_layout()
plt.show()  # render the figure and suppress the AxesImage repr