In [None]:
import random
import glob
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
import matplotlib.pyplot as plt

from generator import UNet
from gan_utils_new import *

In [None]:
# Ask PyTorch whether a CUDA device is usable; as the last expression of the
# cell, the resulting boolean is shown as the cell output.
torch.cuda.is_available()

In [30]:
# Image Preprocessing
def preprocess_image(image_path, size=(256, 256)):
    """Load an image and split it into normalized Lab tensors.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: ``(width, height)`` the image is resized to before conversion.
            Defaults to ``(256, 256)``.

    Returns:
        A tuple ``(L, ab, img_lab)`` where

        * ``L`` is a float32 tensor of shape ``(1, 1, height, width)`` holding
          the lightness channel mapped as ``L / 50 - 1``,
        * ``ab`` is a float32 tensor of shape ``(1, 2, height, width)`` holding
          the two color channels divided by 110 (intended to land in
          ``[-1, 1]``),
        * ``img_lab`` is the unnormalized ``(height, width, 3)`` float32 Lab
          array produced by ``rgb2lab``.
    """
    # Open through a context manager so the underlying file handle is released
    # deterministically, including for multi-frame formats and when convert()
    # raises. convert() yields a fully loaded copy that outlives the handle.
    with Image.open(image_path) as source:
        img = source.convert("RGB").resize(size)

    # Convert to LAB color space
    img_lab = rgb2lab(np.array(img)).astype("float32")

    # Lightness: [0, 100] -> [-1, 1]. The 0:1 slice keeps the channel axis.
    L = img_lab[..., 0:1] / 50.0 - 1.0

    # Color channels scaled by the same constant the de-normalization uses
    ab = img_lab[..., 1:] / 110.0

    # HWC -> CHW, then add a leading batch dimension of 1
    L = torch.tensor(L).permute(2, 0, 1).unsqueeze(0)
    ab = torch.tensor(ab).permute(2, 0, 1).unsqueeze(0)

    return L, ab, img_lab

# Inference on L channel
def run_inference_on_L(model, L):
    """Run ``model`` on the lightness tensor ``L`` with autograd disabled.

    Args:
        model: Callable applied to ``L`` (the generator in this notebook).
        L: Input tensor passed to ``model`` unchanged.

    Returns:
        Whatever ``model(L)`` returns, computed under ``torch.no_grad()``.
    """
    # Pure inference: no gradient bookkeeping for the forward pass.
    with torch.no_grad():
        return model(L)

# Recompile the LAB image and convert back to RGB
def reassemble_and_convert_to_rgb(L, ab_pred):
    """Merge a normalized L tensor with predicted ab channels and return RGB.

    Args:
        L: Lightness tensor of shape ``(1, 1, H, W)`` normalized as
            ``L / 50 - 1``.
        ab_pred: Color tensor of shape ``(1, 2, H, W)`` normalized by 110.

    Returns:
        The result of ``lab2rgb`` applied to the reassembled ``(H, W, 3)`` Lab
        array.
    """
    # detach() drops any autograd history: Tensor.numpy() raises on tensors
    # that require grad, which happens when the prediction was produced
    # outside a no_grad context.
    lightness = (L.detach().squeeze(0).squeeze(0).cpu().numpy() + 1.0) * 50.0  # back to [0, 100]
    chroma = ab_pred.detach().squeeze(0).cpu().numpy() * 110.0  # undo the /110 scaling

    # (H, W) -> (H, W, 1) and (2, H, W) -> (H, W, 2), then stack on the channel axis
    lab_pred = np.concatenate(
        [lightness[..., np.newaxis], chroma.transpose(1, 2, 0)], axis=-1
    )

    # Convert LAB to RGB
    return lab2rgb(lab_pred)

# Visualize the images
def visualize_images(original_img, reconstructed_img):
    """Show the original and the reconstructed image side by side.

    Args:
        original_img: Image data accepted by ``Axes.imshow`` (left panel).
        reconstructed_img: Image data accepted by ``Axes.imshow`` (right panel).
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (original_img, "Original Image"),
        (reconstructed_img, "Reconstructed Image"),
    )
    for axis, (image, title) in zip(ax, panels):
        axis.imshow(image)
        axis.set_title(title)
    plt.show()
    # This function is called once per image in a loop; close the figure
    # explicitly so open figures do not pile up on backends that keep them
    # alive after show().
    plt.close(fig)

In [35]:
# Load the pretrained generator weights.
model = UNet()
path_to_weights = "/home/massone.m/image_enhancement/training_runs/test_17/model_weights/generator_weights.pth"
# map_location="cpu" lets a checkpoint saved from GPU tensors be restored on a
# machine without CUDA; the model is not moved to another device in the cells
# shown here, so CPU is also where inference runs.
checkpoint = torch.load(path_to_weights, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()  # Put the module in evaluation mode for inference

# Collect every .jpg directly inside the sample directory
coco_path = "/home/massone.m/image_enhancement/train_sample"
paths = glob.glob(coco_path + "/*.jpg")

# Split into train/validation path lists; only val_paths is used below.
# NOTE(review): select_images comes from gan_utils_new — presumably it picks
# num_imgs paths and splits them by the given ratio; confirm, and confirm
# whether it needs a random seed for a reproducible selection.
num_imgs = 1000
split = 0.8
train_paths, val_paths = select_images(paths, num_imgs, split)


In [None]:
# Colorize each validation image and display it next to the original
for image_path in val_paths:

    # Normalized network input plus the unnormalized Lab pixels
    L, ab, original_lab = preprocess_image(image_path)

    # Reference picture: the untouched Lab array mapped back to RGB
    original_rgb = lab2rgb(original_lab)

    # Predict the color channels from L alone
    ab_pred = run_inference_on_L(model, L)

    # Combine the input L with the predicted ab channels and convert to RGB
    reconstructed_img = reassemble_and_convert_to_rgb(L, ab_pred)

    # Side-by-side comparison
    visualize_images(original_rgb, reconstructed_img)


