In [1]:
import os

In [2]:
import sys
sys.path.insert(0, f'../../')
sys.path.insert(0, f'../')

In [3]:
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.functional import (
    structural_similarity_index_measure,
    peak_signal_noise_ratio,
)

import clip

import pandas as pd
from PIL import Image

from tqdm import tqdm
%matplotlib inline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# clip model setup
model, preprocess = clip.load("ViT-B/32", device=device)

In [5]:
from waternet.data import transform as preprocess_transform
from waternet.training_utils import arr2ten
from waternet.net import WaterNet
from configs.constants import contrastive_pairs

In [6]:
def arr2ten_noeinops(arr):
    """Converts (N)HWC numpy array into torch Tensor:
    1. Divide by 255
    2. Rearrange dims: HWC -> 1CHW or NHWC -> NCHW
    """
    ten = torch.from_numpy(arr) / 255
    if len(ten.shape) == 3:
        # ten = rearrange(ten, "h w c -> 1 c h w")
        ten = torch.permute(ten, (2, 0, 1))
        ten = torch.unsqueeze(ten, dim=0)
    elif len(ten.shape) == 4:
        # ten = rearrange(ten, "n h w c -> n c h w")
        ten = torch.permute(ten, (0, 3, 1, 2))
    return ten

def pre_process(rgb_arr, ref):
    wb, gc, he = preprocess_transform(rgb_arr)
    rgb_ten = arr2ten_noeinops(rgb_arr)
    wb_ten = arr2ten_noeinops(wb)
    gc_ten = arr2ten_noeinops(gc)
    he_ten = arr2ten_noeinops(he)
    ref_ten = arr2ten_noeinops(ref)
    return rgb_ten, wb_ten, he_ten, gc_ten, ref_ten
    
def post_process(ten):
    arr = ten.cpu().detach().numpy()
    arr = np.clip(arr, 0, 1)
    # arr = arr - np.min(arr)
    # arr = arr / np.max(arr)
    arr = (arr * 255).astype(np.uint8)
    # arr = rearrange(arr, "n c h w -> n h w c")
    arr = np.transpose(arr, (0, 2, 3, 1))
    return arr

In [8]:
flatten_pairs = np.ravel(contrastive_pairs)

kinds = {
    "base": "weights/pretrained/waternet.pt",
    "vivid_mid": "weights/color-enhanced.pt",
    "color_cast": "weights/wb-enhanced.pt",
    "exposure": "weights/expo-enhanced.pt",
    "all": "weights/all-enhanced.pt",
}

waternets = []
for _, key in enumerate(kinds):
    waternet = WaterNet()
    check_point = torch.load(f'../{kinds[key]}')
    waternet.load_state_dict(check_point)
    waternet.eval()
    waternet = waternet.to(device)
    waternets.append(waternet)

In [9]:
# need setup lsui data set first: get_data("lsui")
lsui_files = os.listdir("./lsui/GT")
lsui_files.sort(key=lambda x:int(x[:-4]))
lsui_gts = [os.path.join("./lsui/GT", _) for _ in lsui_files]
lsui_raws = [os.path.join("./lsui/input", _) for _ in lsui_files]

In [10]:
def clip_process_one(output_image_path):
    # clip evaluations:
    image = preprocess(Image.open(output_image_path)).unsqueeze(0).to(device)
    text = clip.tokenize(flatten_pairs).to(device)

    logits_per_image, _ = model(image, text)
    logits_per_image = logits_per_image.view(len(contrastive_pairs), 2)
    prob = logits_per_image.softmax(dim=-1).cpu().numpy()
    return prob.flatten()

def process_one(raw, gt):
    basename = os.path.basename(raw)
    # preprocessing
    im = cv2.imread(raw)
    rgb_im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    # rgb_im = cv2.resize(rgb_im, (320, 240))
    
    # gt load
    gt_im = cv2.imread(gt)
    gt_im = cv2.cvtColor(gt_im, cv2.COLOR_BGR2RGB)
    # Resize image
    # gt_im = cv2.resize(gt_im, (320, 240))
    
    # preprocess:
    rgb_ten, wb_ten, he_ten, gc_ten, ref_ten = pre_process(rgb_im, gt_im)
    rgb_ten, wb_ten, he_ten, gc_ten, ref_ten = rgb_ten.to(device), wb_ten.to(device), he_ten.to(device), gc_ten.to(device), ref_ten.to(device)
    
    # generations: psnr + ssim
    enhanced_imgs = []
    psnrs = []
    ssims = []
    probs = []
    for ind, key in enumerate(kinds):
        waternet = waternets[ind]
        out_ten = waternet(rgb_ten, wb_ten, he_ten, gc_ten)
        
        ssim = structural_similarity_index_measure(preds=out_ten, target=ref_ten).item()
        ssims.append(ssim)
        
        psnr = peak_signal_noise_ratio(preds=out_ten, target=ref_ten, data_range=1 - 0).item()
        psnrs.append(psnr)

        out_im = post_process(out_ten)[0]
        path = f"./lsui/{key}"
        if not os.path.exists(path):
            os.makedirs(path)
            print(f"make dir path: {path}")
            
        output_image_path = f"{path}/{basename}"
        cv2.imwrite(output_image_path, cv2.cvtColor(out_im, cv2.COLOR_RGB2BGR))
    
        prob = clip_process_one(output_image_path)
        probs.append(prob)
        
    raw_prob = clip_process_one(raw)
    probs.append(raw_prob)
        
    gt_prob = clip_process_one(gt)
    probs.append(gt_prob)
    
    return psnrs, ssims, probs

In [11]:
def process_all():
    # Create empty DataFrames
    columns = ['image', 'psnr', 'ssim'] + [i for i in flatten_pairs]
    dfs = {kind: pd.DataFrame(columns=columns) for kind in kinds}
    dfs["RAW"] = pd.DataFrame(columns=columns)
    dfs["GT"] = pd.DataFrame(columns=columns)

    # Process each image pair
    with torch.no_grad():
        for raw_path, ref_path in tqdm(zip(lsui_raws, lsui_gts), total=len(lsui_raws)):
            psnrs, ssims, probs = process_one(raw_path, ref_path)
            
            # Extract image file name or identifier from raw_path
            image_name = os.path.basename(raw_path)

            # Fill DataFrames for each kind
            for kind_index, kind in enumerate(kinds):
                data = {
                    'image': image_name,
                    'psnr': psnrs[kind_index],
                    'ssim': ssims[kind_index]
                }
                
                # Add probability values to the data dictionary
                for pair_index, pair in enumerate(flatten_pairs):
                    data[pair] = probs[kind_index][pair_index]

                # Append the data as a new row in the corresponding DataFrame
                dfs[kind] = dfs[kind].append(data, ignore_index=True)

                
            # Assuming the last element in probs is for GT
            raw_data = {
                'image': image_name,
                'psnr': None,  # or appropriate value
                'ssim': None,  # or appropriate value
            }
            for pair_index, pair in enumerate(flatten_pairs):
                raw_data[pair] = probs[-2][pair_index]
            dfs["RAW"] = dfs["RAW"].append(raw_data, ignore_index=True)
            
            # Assuming the last element in probs is for GT
            gt_data = {
                'image': image_name,
                'psnr': None,  # or appropriate value
                'ssim': None,  # or appropriate value
            }
            for pair_index, pair in enumerate(flatten_pairs):
                gt_data[pair] = probs[-1][pair_index]
            dfs["GT"] = dfs["GT"].append(gt_data, ignore_index=True)

    return dfs

In [12]:
dfs = process_all()



make dir path: ./lsui/all0_best


  0%|          | 1/4279 [00:01<1:16:27,  1.07s/it]

make dir path: ./lsui/all0_last


100%|██████████| 4279/4279 [12:28<00:00,  5.72it/s]


In [13]:
def save_dfs_to_csv(dfs, output_directory):
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    for kind, df in dfs.items():
        csv_file_path = os.path.join(output_directory, f"{kind}_results.csv")
        df.to_csv(csv_file_path, index=False)
        print(f"Saved DataFrame for {kind} to {csv_file_path}")

# Usage example
output_dir = './lsui/'
save_dfs_to_csv(dfs, output_dir)

Saved DataFrame for all0_best to ./lsui/all0_best_results.csv
Saved DataFrame for all0_last to ./lsui/all0_last_results.csv
Saved DataFrame for RAW to ./lsui/RAW_results.csv
Saved DataFrame for GT to ./lsui/GT_results.csv


In [14]:
wbs = ['Accurate Color Representation']
colors = ['Vibrant and vivid', 'Bright, Colorful Underwater Scene',]
exps = ['Crystal-clear and Unobstructed Scene', 'Balanced and Well-Lit view']
cols = ["psnr", "ssim"]

In [15]:
base_df = dfs["base"]
cl_df = dfs["vivid_mid"]
ex_df = dfs["exposure"]
wb_df = dfs["color_cast"]
selected_dfs = {
    "base": base_df,
    "cl": cl_df,
    "wb": wb_df,
    "exp": ex_df,
    "all_last": dfs['all0_best'],
    "all_best": dfs['all0_last'],
}

In [16]:
# Initialize a dictionary to hold the aggregated data
aggregated_data = {
    "model": [],
    "psnr": [],
    "ssim": [],
    "colorfulness": [],
    "white balance": [],
    "exposure": []
}

# Iterate over each model to compute the scores
for model_name, df in selected_dfs.items():
    aggregated_data["model"].append(model_name)
    aggregated_data["psnr"].append(df["psnr"].mean())
    aggregated_data["ssim"].append(df["ssim"].mean())
    aggregated_data["colorfulness"].append(df[colors].mean(axis=1).mean())
    aggregated_data["white balance"].append(df[wbs].mean(axis=1).mean())
    aggregated_data["exposure"].append(df[exps].mean(axis=1).mean())

# Create the final DataFrame
summary_df = pd.DataFrame(aggregated_data)
# summary_df = summary_df.set_index("model")

In [17]:
# Performances: models trained with all prompts for 50 epochs
summary_df.round(3)

Unnamed: 0,model,psnr,ssim,colorfulness,white balance,exposure
0,all_last,18.985,0.771,0.62207,0.704102,0.573242
1,all_best,18.918,0.772,0.618164,0.690918,0.576172


In [56]:
# Performances: models trained with cl/wb/exp prompts for 20 epochs
summary_df.round(3)

Unnamed: 0,model,psnr,ssim,colorfulness,white balance,exposure
0,base,22.154,0.85,0.489014,0.354004,0.547852
1,cl,20.421,0.822,0.682129,0.310059,0.558105
2,wb,20.49,0.839,0.332031,0.548828,0.559082
3,exp,21.202,0.83,0.406982,0.22998,0.618164


In [21]:
def calculate_means(df, columns, skip_columns=None):
    skip_columns = skip_columns or []
    means = {}
    for col in columns:
        if col in skip_columns:
            means[col] = "--"
        else:
            means[col] = df[col].mean()
    return means

# Columns to include PSNR and SSIM
extended_columns = ['psnr', 'ssim'] + ['Accurate Color Representation',
                                       # 'Accurate and Natural Color Representation', 
                                       # 'Harmonious, Suitable, and Aesthetically Pleasing Color Representation',
                                       'Vibrant and vivid', 'Bright, Colorful Underwater Scene',
                                       'Crystal-clear and Unobstructed Scene',
                                       'Balanced and Well-Lit view', 
                                       # 'Richly detailed',
                                       # 'Sharp Aquatic Details'
                                      ]

# Initialize an empty DataFrame for aggregated means
aggregated_means = pd.DataFrame()

for kind, df in dfs.items():
    # Specify columns to skip for RAW and GT
    skip_cols = ['psnr', 'ssim'] if kind in ['RAW', 'GT'] else []

    # Calculate means for the specified columns
    means = calculate_means(df, extended_columns, skip_columns=skip_cols)
    means_row = pd.Series(means, name=kind)

    # Append the data as a new row in the DataFrame
    aggregated_means = aggregated_means.append(means_row)

# Reset the index to have 'kind' as a column
aggregated_means.reset_index(inplace=True)
aggregated_means.rename(columns={'index': 'kind'}, inplace=True)

In [22]:
# # Performances: models trained with all prompts for 50 epochs
aggregated_means

Unnamed: 0,kind,psnr,ssim,Accurate Color Representation,Vibrant and vivid,"Bright, Colorful Underwater Scene",Crystal-clear and Unobstructed Scene,Balanced and Well-Lit view
0,all0_best,18.984898,0.770592,0.70459,0.634277,0.610352,0.859863,0.285645
1,all0_last,18.918308,0.772224,0.690918,0.625488,0.609863,0.855469,0.295898
2,RAW,--,--,0.236816,0.509766,0.537598,0.813477,0.219604
3,GT,--,--,0.393555,0.542969,0.560059,0.779297,0.303711


In [18]:
# # Performances: models trained with individual prompts for 20 epochs
aggregated_means

Unnamed: 0,kind,psnr,ssim,Accurate Color Representation,Accurate and Natural Color Representation,"Harmonious, Suitable, and Aesthetically Pleasing Color Representation",Vibrant and vivid,"Bright, Colorful Underwater Scene",Crystal-clear and Unobstructed Scene,Balanced and Well-Lit view,Richly detailed,Sharp Aquatic Details
0,base,22.153838,0.849867,0.354248,0.784668,0.570801,0.460449,0.516602,0.838867,0.256836,0.271484,0.869141
1,vivid_mid,20.421405,0.822183,0.310059,0.763184,0.578613,0.658203,0.705566,0.859863,0.256348,0.337402,0.885742
2,vivid_final,18.965349,0.794694,0.178711,0.688965,0.476074,0.649902,0.711914,0.86084,0.227905,0.318359,0.847656
3,color_comfort,20.858517,0.832259,0.341064,0.765137,0.5625,0.335449,0.439697,0.875488,0.226196,0.324463,0.918945
4,color_cast,20.489699,0.83875,0.548828,0.711914,0.505371,0.323486,0.340332,0.855469,0.261963,0.278564,0.904785
5,color_final,21.753921,0.867708,0.360107,0.827148,0.69043,0.417236,0.504395,0.838867,0.279785,0.342773,0.919434
6,exposure,21.202268,0.829584,0.229614,0.731445,0.589355,0.372803,0.441162,0.904297,0.333008,0.425537,0.938965
7,vivid_last,20.357266,0.829908,0.266602,0.773926,0.598633,0.621582,0.663086,0.872559,0.261719,0.322021,0.893555
8,sharpness,21.010939,0.841984,0.243652,0.726074,0.626953,0.359131,0.395752,0.861816,0.266846,0.507812,0.945801
9,all,19.58167,0.808633,0.160278,0.769531,0.670898,0.448975,0.491943,0.88916,0.423584,0.569336,0.957031
