In [None]:
"""
Nelson Farrell & Michael Massone
Image Enhancement: Colorization - cGAN
CS 7180 Advanced Perception
Bruce Maxwell, PhD.
09-28-2024
"""

In [1]:
import random
import glob
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
import matplotlib.pyplot as plt

import sys
import os
from pathlib import Path

In [2]:
# Make the project root (the parent of the notebook's working directory)
# importable, so the `src.*` modules imported in the next cell can be found.
path = Path.cwd()
path_to_project_home = str(path.parent)
print(path_to_project_home)
sys.path.insert(1, path_to_project_home)

/Users/nelsonfarrell/Documents/Northeastern/7180/projects/color-GAN


In [3]:
#from src.utils.pretrain_utils import *
from src.utils.gan_utils import *

In [4]:
# Report whether a CUDA device is visible to PyTorch (displayed as the cell output).
cuda_available = torch.cuda.is_available()
cuda_available

False

In [5]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

def build_res_unet(n_input=1, n_output=2, size=256, pretrained=True, device=None):
    """
    Builds a ResNet18-based U-Net generator (fastai ``DynamicUnet``).

    The ResNet18 body, cut before its last two children (``cut=-2``), is used as
    the encoder. ``DynamicUnet`` builds the decoder for ``size x size`` inputs.

    Args:
     * n_input: (int) number of input channels passed to ``create_body`` as ``n_in``.
     * n_output: (int) number of output channels of the U-Net.
     * size: (int) side length of the square input the U-Net is built for.
     * pretrained: (bool) start the encoder from torchvision's pretrained ResNet18
       weights. Pass False when a full checkpoint will be loaded right afterwards;
       this avoids fetching weights that would be overwritten anyway.
     * device: (str | torch.device | None) device to place the model on. None
       (default) selects CUDA when available, otherwise CPU.

    Returns:
     * net_G: the ``DynamicUnet`` generator, moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet_model = resnet18(pretrained=pretrained)
    body = create_body(resnet_model, pretrained=pretrained, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G

In [6]:
# Image Preprocessing
def preprocess_image(image_path, size=256):
    """
    Loads an image and splits it into normalized L and ab tensors for the generator.

    Args:
     * image_path: (str | path-like) image file to read with PIL.
     * size: (int) the image is resized to ``size x size`` (default 256, matching
       the default input size of ``build_res_unet``).

    Returns:
     * L: (torch.Tensor) float32, shape (1, 1, size, size). The L channel is mapped
       from skimage's [0, 100] range to [-1, 1].
     * ab: (torch.Tensor) float32, shape (1, 2, size, size). The ab channels are
       divided by 110.
     * img_lab: (np.ndarray) float32, shape (size, size, 3). The un-normalized LAB image.
    """
    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(image_path) as src:
        img = src.convert("RGB").resize((size, size))

    # Convert to LAB color space
    img_lab = rgb2lab(np.array(img)).astype("float32")

    # Normalize L channel from [0, 100] to [-1, 1]
    L = img_lab[..., 0:1] / 50.0 - 1.0

    # Scale ab channels by 1/110 so they fall roughly within [-1, 1]
    ab = img_lab[..., 1:] / 110.0

    # HWC -> CHW, then add a batch dimension
    L = torch.tensor(L).permute(2, 0, 1).unsqueeze(0)    # (1, 1, size, size)
    ab = torch.tensor(ab).permute(2, 0, 1).unsqueeze(0)  # (1, 2, size, size)

    return L, ab, img_lab

# Inference on L channel
def run_inference_on_L(model, L):
    """
    Runs ``model`` on the L-channel tensor with gradient tracking disabled.

    ``build_res_unet`` places the model on CUDA when it is available, while
    ``preprocess_image`` always produces CPU tensors. The input is therefore moved
    to the device of the model's parameters first. Callables without parameters
    receive ``L`` unchanged.

    Args:
     * model: callable (typically an ``nn.Module``) mapping L -> ab prediction.
     * L: (torch.Tensor) input tensor for the model.

    Returns:
     * ab_pred: the model output, computed under ``torch.no_grad()`` and left on
       the model's device.
    """
    params = getattr(model, "parameters", None)
    if callable(params):
        first_param = next(iter(params()), None)
        if first_param is not None:
            L = L.to(first_param.device)
    with torch.no_grad():
        ab_pred = model(L)
    return ab_pred

# Recompile the LAB image and convert back to RGB
def reassemble_and_convert_to_rgb(L, ab_pred):
    """
    Undoes the L / ab normalization and converts the resulting LAB image to RGB.

    Args:
     * L: (torch.Tensor) normalized L channel; expected shape (1, 1, H, W) with
       values in [-1, 1].
     * ab_pred: (torch.Tensor) normalized ab prediction; expected shape (1, 2, H, W).

    Returns:
     * The RGB image returned by ``skimage.color.lab2rgb`` for the (H, W, 3) LAB array.
    """
    lightness = L.squeeze(0).squeeze(0).cpu().numpy()
    lightness = (lightness + 1.0) * 50.0                 # [-1, 1] -> [0, 100]
    chroma = ab_pred.squeeze(0).cpu().numpy() * 110.0    # undo the 1/110 scaling

    # Channels-first ab -> channels-last, then stack L in front of it
    chroma_hwc = np.transpose(chroma, (1, 2, 0))
    lab_image = np.concatenate((np.expand_dims(lightness, -1), chroma_hwc), axis=-1)

    return lab2rgb(lab_image)

# Visualize the images
def visualize_images(original_img, reconstructed_img):
    """
    Shows the original and the predicted image side by side in a 1 x 2 grid.

    Args:
     * original_img: ground-truth image, anything ``Axes.imshow`` accepts.
     * reconstructed_img: model output image, anything ``Axes.imshow`` accepts.

    Returns:
     * None. The figure is displayed with ``plt.show()``.
    """
    panels = (
        (original_img, "Original Image"),
        (reconstructed_img, "Predicted Image"),
    )
    _fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image)
        axis.set_title(title)
    plt.show()


# Visualize lists of images
def visualize_images_2(original_imgs:list, reconstructed_imgs:list, inputs:list, save_path:str, figsize:tuple=(10, 10)) -> None:
    """
    Saves a grid figure of example images, one row per example:
    original | input (grayscale) | generated.

    Args:
     * original_imgs: (list) ground-truth images, anything ``Axes.imshow`` accepts.
     * reconstructed_imgs: (list) generated images; needs at least
       ``len(original_imgs)`` entries.
     * inputs: (list) grayscale inputs as numpy arrays or torch tensors; needs at
       least ``len(original_imgs)`` entries. Singleton dimensions (e.g. a
       (1, 1, H, W) tensor) are squeezed away before display.
     * save_path: (str) output file; a missing parent directory is created.
     * figsize: (tuple) matplotlib figure size. The default (10, 10) keeps the
       previous behavior; increase the height for longer lists.

    Returns:
     * None

    Raises:
     * ValueError: if ``original_imgs`` is empty or another list is shorter than it.
    """
    n = len(original_imgs)  # Number of rows
    if n == 0:
        raise ValueError("visualize_images_2 needs at least one image to plot")
    if len(reconstructed_imgs) < n or len(inputs) < n:
        raise ValueError(
            "reconstructed_imgs and inputs must each contain at least "
            f"{n} images (one per original image)"
        )

    # squeeze=False keeps `axs` 2-D even for a single row. With the default
    # squeeze=True, n == 1 would return a 1-D array and axs[0, j] would fail.
    fig, axs = plt.subplots(n, 3, figsize=figsize, squeeze=False)
    column_titles = ("Original Image", "Input Image", "Generated Image")
    try:
        for i in range(n):
            row = (original_imgs[i], _to_grayscale_2d(inputs[i]), reconstructed_imgs[i])
            for j, (ax, img) in enumerate(zip(axs[i], row)):
                if j == 1:
                    ax.imshow(img, cmap="gray")  # input column is single-channel
                else:
                    ax.imshow(img)
                if i == 0:
                    ax.set_title(column_titles[j], weight="bold")
                ax.axis("off")

        # tight_layout computes the inter-subplot spacing itself; a preceding
        # subplots_adjust(wspace=..., hspace=...) would just be overridden.
        fig.tight_layout()

        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)  # savefig does not create directories
        fig.savefig(save_path)
    finally:
        # Close this specific figure, even on error, so repeated calls don't pile up open figures.
        plt.close(fig)


def _to_grayscale_2d(img):
    """Converts a grayscale image (numpy array or torch tensor) into a squeezed numpy array for imshow."""
    if hasattr(img, "detach"):
        # torch tensor: drop autograd history and move to host memory before converting
        img = img.detach().cpu().numpy()
    return np.squeeze(np.asarray(img))



In [7]:
# The generator is built and its trained checkpoint is loaded together in the next cell.
# Building it here as well only constructed a model (loading ResNet18 ImageNet weights
# along the way) that was immediately replaced, so it is not built here.



In [8]:
# Load the pretrained model.
# Paths are derived from the project root computed at the top of the notebook
# (`path_to_project_home`) and from the user's home directory, instead of being
# hard-coded to one machine.
model = build_res_unet()
path_to_weights = os.path.join(path_to_project_home, "model_checkpoints", "checkpoint_pretrain_gen.pth")
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
checkpoint = torch.load(path_to_weights, map_location=torch.device('cpu'))
checkpoint = checkpoint["generator_state_dict"]
model.load_state_dict(checkpoint)
model.eval()  # Set the model to inference mode

# Load data (fastai COCO sample, stored under ~/.fastai/data)
coco_path = os.path.join(os.path.expanduser("~"), ".fastai", "data", "coco_sample", "train_sample")
paths = glob.glob(os.path.join(coco_path, "*.jpg"))

# Get val data
# NOTE(review): no random seed is set before `select_images`. If it samples
# randomly, this split is not reproducible across runs; confirm in gan_utils.
num_imgs = 10000
split = 0.8
train_paths, val_paths = select_images(paths, num_imgs, split)


In [9]:
# Hand-picked COCO sample images, identified by file id. Judging by the list names,
# these are examples where the colorization looked good / bad.
# The paths are built from `coco_path` (set when the data was loaded) instead of being
# hard-coded, so the lists resolve on any machine that has the fastai COCO sample.
good_results_list = [
    os.path.join(coco_path, f"{image_id}.jpg")
    for image_id in (
        "000000319579",
        "000000100271",
        "000000107846",
        "000000064121",
        "000000547471",
        "000000411138",
    )
]

bad_results = [
    os.path.join(coco_path, f"{image_id}.jpg")
    for image_id in (
        "000000092602",
        "000000450649",
        "000000367853",
    )
]

In [10]:
# Run the pipeline on `bad_results` and save an original | input | generated grid.
original_image_list = []
reconstructed_img_list = []
grey_image_list = []
save_path = "figs/resGAN_bad.png"

# savefig does not create directories; make sure the output folder exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

for image_path in bad_results:

    # Preprocess the image: normalized L / ab tensors plus the raw LAB image
    L, ab, original_lab = preprocess_image(image_path)

    # Predict the ab channels from L
    ab_pred = run_inference_on_L(model, L)

    # Reassemble L + predicted ab and convert to RGB
    reconstructed_img = reassemble_and_convert_to_rgb(L, ab_pred)

    # Convert original LAB back to RGB for comparison
    original_rgb = lab2rgb(original_lab)

    original_image_list.append(original_rgb)
    reconstructed_img_list.append(reconstructed_img)
    grey_image_list.append(L)

    # Log which image was processed
    print(image_path)


visualize_images_2(original_image_list, reconstructed_img_list, grey_image_list, save_path)

/Users/nelsonfarrell/.fastai/data/coco_sample/train_sample/000000092602.jpg
/Users/nelsonfarrell/.fastai/data/coco_sample/train_sample/000000450649.jpg
/Users/nelsonfarrell/.fastai/data/coco_sample/train_sample/000000367853.jpg
