In [None]:
import os

In [None]:
import sys
# Put the parent and grandparent directories at the front of the import path
# (parent first), presumably so the project-local packages imported further
# down resolve when the notebook is run from its own folder.
sys.path[:0] = ['../', '../../']

In [None]:
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.functional import (
    structural_similarity_index_measure,
    peak_signal_noise_ratio,
)

import clip

import pandas as pd
from PIL import Image

from tqdm import tqdm
%matplotlib inline
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# CLIP model setup: load the "ViT-B/32" checkpoint onto `device`.
# `clip.load` returns the model together with the preprocessing callable that
# is applied to PIL images before they are passed to the model.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Zero-shot CLIP check on one saved output image: score it against a pair of
# contrasting captions and print the softmax over the two captions.
image = preprocess(Image.open("./outputs/base_ScubaDiver.jpg"))
image = image.unsqueeze(0).to(device)  # add the batch dimension
text = clip.tokenize(["Vibrant and vivid", "Dull and washed-out"]).to(device)

with torch.no_grad():
    # Stand-alone embeddings; they are not used by the probability computation below.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # The joint forward pass returns image->text and text->image logits.
    logits_per_image, logits_per_text = model(image, text)
    probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()

print("Label probs:", probs)

In [None]:
from waternet.data import transform as preprocess_transform
from waternet.training_utils import arr2ten
from waternet.net import WaterNet
from configs.constants import contrastive_pairs

In [None]:
def arr2ten_noeinops(arr):
    """Turn an (N)HWC numpy image array into a channels-first torch tensor.

    The values are divided by 255 and the axes are rearranged without einops:
    an HWC array becomes 1CHW and an NHWC array becomes NCHW. An array of any
    other rank is only scaled; its axes are left untouched.
    """
    scaled = torch.from_numpy(arr) / 255
    rank = scaled.dim()
    if rank == 3:
        # h w c -> 1 c h w
        return scaled.permute(2, 0, 1).unsqueeze(0)
    if rank == 4:
        # n h w c -> n c h w
        return scaled.permute(0, 3, 1, 2)
    return scaled

def pre_process(rgb_arr, ref):
    """Build the tensors for one image and its reference.

    `preprocess_transform` yields three derived arrays, unpacked here as
    (wb, gc, he). Each array, plus the input and `ref`, is converted with
    `arr2ten_noeinops`.

    Returns:
        (rgb, wb, he, gc, ref) tensors. Note that `he` comes before `gc` in the
        returned tuple, unlike the order produced by the transform.
    """
    wb, gc, he = preprocess_transform(rgb_arr)
    rgb_ten, wb_ten, gc_ten, he_ten, ref_ten = (
        arr2ten_noeinops(array) for array in (rgb_arr, wb, gc, he, ref)
    )
    return rgb_ten, wb_ten, he_ten, gc_ten, ref_ten
    
def post_process(ten):
    """Convert an NCHW tensor into an NHWC uint8 numpy array.

    Values are clipped to [0, 1] (not rescaled), multiplied by 255 and cast to
    uint8, then the channel axis is moved last.
    """
    values = ten.detach().cpu().numpy()
    clipped = np.clip(values, 0, 1)
    as_bytes = (clipped * 255).astype(np.uint8)
    # n c h w -> n h w c
    return as_bytes.transpose(0, 2, 3, 1)

def process(img, waternet):
    """Run one image through a WaterNet model and return the enhanced image.

    Args:
        img: HWC numpy image; it is handed to `pre_process` as the RGB input
            (and also as the reference, whose tensor is discarded here).
        waternet: callable model invoked as `waternet(rgb, wb, he, gc)`.

    Returns:
        The first (only) image of the batch produced by `post_process`.
    """
    rgb_ten, wb_ten, he_ten, gc_ten, _ = pre_process(img, img)
    inputs = [ten.to(device) for ten in (rgb_ten, wb_ten, he_ten, gc_ten)]
    # Pure inference: disabling autograd avoids building a graph that would
    # only be thrown away, which saves memory on large images.
    with torch.no_grad():
        out_ten = waternet(*inputs)
    return post_process(out_ten)[0]

In [None]:
# Index groups into the list of selected prompts.
# NOTE(review): WB / CL / LIT presumably stand for white balance, color and
# lighting — confirm against configs.constants.
WB, CL, LIT = [0, 1, 2], [3, 4], [5, 6]

# Flatten the prompt pairs and keep every even position, which this notebook
# treats as the "good" prompt of each pair.
flatten_pairs = np.ravel(contrastive_pairs)
good_prompts = flatten_pairs[0::2]
WB_prompts, CL_prompts, LIT_prompts = good_prompts[WB], good_prompts[CL], good_prompts[LIT]

print(WB_prompts, CL_prompts, LIT_prompts, sep="\n")

In [None]:
# Model variants to compare: name -> checkpoint path (relative to '../').
# The names double as the per-model folder names under ./lsui/ and as the
# prefixes of the saved result files.
kinds = {
    "base": "weights/pretrained/waternet.pt",
    "vivid_mid": "weights/color-enhanced.pt",
    "color_cast": "weights/wb-enhanced.pt",
    "exposure": "weights/expo-enhanced.pt",
    "all": "weights/all-enhanced.pt",
}

# One model per entry, in the same order as `kinds`; later cells look models
# up by the position of their name in `kinds`.
waternets = []
for key, weight_path in kinds.items():
    waternet = WaterNet()
    # map_location lets a checkpoint that was saved on a GPU load on a
    # CPU-only machine too (torch.load otherwise restores to the saved device).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    check_point = torch.load(f'../{weight_path}', map_location=device)
    waternet.load_state_dict(check_point)
    waternet.eval()
    waternet = waternet.to(device)
    waternets.append(waternet)

## LSUI Dataset Evaluation

In [None]:
# The LSUI dataset must be set up first: get_data("lsui")
# Only files whose stem is a plain number are kept, so stray entries in the
# folder (e.g. ".ipynb_checkpoints") do not break the numeric sort below.
lsui_files = [
    name for name in os.listdir("./lsui/GT")
    if os.path.splitext(name)[0].isdecimal()
]
# Sort numerically by stem (so "10.jpg" follows "9.jpg"), whatever the extension length.
lsui_files.sort(key=lambda name: int(os.path.splitext(name)[0]))
# Ground-truth and raw-input paths share file names and therefore line up index by index.
lsui_gts = [os.path.join("./lsui/GT", name) for name in lsui_files]
lsui_raws = [os.path.join("./lsui/input", name) for name in lsui_files]

In [None]:
# Read the saved results table "<kind>_results.csv" from ./lsui/ for every
# model variant in `kinds`, keyed by the variant name.
output_directory = './lsui/'
dfs = {}
for kind in kinds:
    dfs[kind] = pd.read_csv(os.path.join(output_directory, f"{kind}_results.csv"))

In [None]:
def load_image(kind, raw_path):
    """Load the saved output of model `kind` for the given raw image.

    The file is looked up at ./lsui/<kind>/<basename of raw_path>.

    Args:
        kind: model variant name (a key of `kinds`).
        raw_path: path of the raw input image; only its file name is used.

    Returns:
        The image as an RGB array.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    basename = os.path.basename(raw_path)
    path = f'./lsui/{kind}/{basename}'
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than with an opaque cvtColor error.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def _read_rgb(path):
    """Read an image with OpenCV and return it in RGB channel order.

    Raises FileNotFoundError when the file cannot be read (cv2.imread returns
    None in that case instead of raising).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def display_images(raw_path, gt_path=None, _load_image=True, _save_image=False):
    """Plot the source image, an optional reference, and every model's result.

    Args:
        raw_path: path of the raw input image.
        gt_path: optional path of the reference image; shown as "ref" when given.
        _load_image: if True, show the saved result of each model (via
            `load_image`); if False, run each model on the raw image now.
        _save_image: if True, also write each shown result to
            ./outputs/<kind>_<basename>.

    Raises:
        FileNotFoundError: if the raw or reference image cannot be read.
    """
    basename = os.path.basename(raw_path)
    raw = _read_rgb(raw_path)

    # 3x4 grid; hide every axis up front so panels that end up unused stay
    # blank instead of showing empty frames with ticks.
    fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(25, 18))
    axs = axs.flatten()  # flat indexing across the grid
    for ax in axs:
        ax.axis("off")

    axs[0].imshow(raw)
    axs[0].set_title("Source")

    # Index of the first panel available for model results.
    offset = 1

    if gt_path is not None:
        axs[1].imshow(_read_rgb(gt_path))
        axs[1].set_title("ref")
        offset += 1

    # Create the output folder once, before the loop, and only when saving.
    path = f"./outputs/"
    if _save_image and not os.path.exists(path):
        os.makedirs(path)
        print(f"make dir path: {path}")

    for index, kind in enumerate(kinds):
        if _load_image:
            img = load_image(kind, raw_path)
        else:
            # `waternets` is ordered like `kinds`, so the position selects the model.
            img = process(raw, waternets[index])
        panel = axs[index + offset]
        panel.imshow(img)
        panel.set_title(kind)

        if _save_image:
            output_image_path = os.path.join(path, f"{kind}_{basename}")
            # Results are RGB; OpenCV writes BGR.
            cv2.imwrite(output_image_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    plt.tight_layout()
    plt.show()
    
def save_image(raw_path, _load_image=True):
    """Write the raw image and every model's result for it to ./outputs/.

    Files are named raw_<basename> and <kind>_<basename>.

    Args:
        raw_path: path of the raw input image.
        _load_image: if True, copy the saved result of each model (via
            `load_image`); if False, run each model on the raw image now.

    Raises:
        FileNotFoundError: if the raw image cannot be read.
    """
    basename = os.path.basename(raw_path)
    path = f"./outputs/"
    # The folder must exist before the first write: cv2.imwrite does not
    # create missing directories.
    if not os.path.exists(path):
        os.makedirs(path)
        print(f"make dir path: {path}")

    raw_bgr = cv2.imread(raw_path)
    if raw_bgr is None:
        raise FileNotFoundError(f"Could not read image: {raw_path}")
    cv2.imwrite(os.path.join(path, f"raw_{basename}"), raw_bgr)

    # The models take RGB input. Keep the raw frame in its own variable so
    # every model sees the original image, never a previous model's output.
    raw_rgb = cv2.cvtColor(raw_bgr, cv2.COLOR_BGR2RGB)

    for index, kind in enumerate(kinds):
        if _load_image:
            img = load_image(kind, raw_path)
        else:
            img = process(raw_rgb, waternets[index])

        output_image_path = os.path.join(path, f"{kind}_{basename}")
        # Results are RGB; OpenCV writes BGR.
        cv2.imwrite(output_image_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

In [None]:
# Hand-picked sample ids; each id maps to ./lsui/input/<id>.jpg.
to_show_images = [
    1211, 1311, 1395, 1636,
    2902, 3090, 3151, 3198,
    3222, 3331, 3357, 3456,
    3645, 3937, 4008, 4992,
]
# A second, shorter list of sample ids.
limited = [
    2054,
    2055,
    2057,
]

In [None]:
# Export the raw frame and every model's result for each selected sample.
for ind in to_show_images:
    save_image(raw_path=f"./lsui/input/{ind}.jpg")

In [None]:
# Visualize one sample: run every model on the raw frame and plot the results
# side by side. Pass gt_path=gt_path to include the reference image as well.
picked_ = 4992
raw_path = f"./lsui/input/{picked_}.jpg"
gt_path = f"./lsui/GT/{picked_}.jpg"
display_images(raw_path, gt_path=None, _load_image=False, _save_image=False)