In [4]:
import os

# NOTE(review): machine-specific absolute path. The working directory is changed
# at import time, presumably so that relative paths resolve against the MicroscoPy
# checkout — adjust (or remove) when running on another machine.
os.chdir("/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/MicroscoPy/")

import numpy as np
import torch
from tqdm import tqdm

from skimage import metrics as skimage_metrics
from skimage.util import img_as_ubyte
import copy
from skimage import io

# LPIPS metrics with AlexNet and VGG.
# The two models are created once at module level and reused by obtain_metrics()
# for every image, so the pretrained weights are loaded a single time.
import lpips
lpips_alex = lpips.LPIPS(net="alex", version="0.1")
lpips_vgg = lpips.LPIPS(net="vgg", version="0.1")

# Nanopyx metrics: Error map (RSE and RSP) and decorrelation analysis 
from nanopyx.core.transform.error_map import ErrorMap
from nanopyx.core.analysis.decorr import DecorrAnalysis

# ILNIQE (in a local file).
# NOTE(review): machine-specific paths appended to sys.path so that the local
# ILNIQE module can be imported; adjust when running elsewhere.
import sys
sys.path.append("/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/MicroscoPy")
sys.path.append("/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/MicroscoPy/microscopy/")
from ILNIQE import calculate_ilniqe
#from IL-NIQUE import calculate_ilniqe


def merge_lists(list1, list2):
    """Return a new list with the items of ``list1`` followed by those of ``list2``.

    Neither input is modified.
    """
    concatenated = list1 + list2
    return concatenated

def merge_dict(dict1, dict2):
    """
    Merge two dictionaries into a new one without modifying the inputs.

    For keys present in both dictionaries:
        - list values are concatenated (``dict1`` items first, via ``merge_lists``),
        - dict values are merged recursively with these same rules,
        - any other combination takes (a copy of) the value from ``dict2``.
    Keys present in only one dictionary are deep-copied into the result.

    Parameters:
        dict1 (dict): The first dictionary; its list items come first when concatenating.
        dict2 (dict): The second dictionary; its values win on non-mergeable conflicts.

    Returns:
        dict: A new merged dictionary. Keys follow the order of ``dict1``,
        followed by the keys that only appear in ``dict2``.
    """
    merged_dict = {}
    # dict.fromkeys keeps insertion order (unlike a set union), so the key order
    # of the result is deterministic across runs.
    for key in dict.fromkeys([*dict1, *dict2]):
        if key in dict1 and key in dict2:
            value1, value2 = dict1[key], dict2[key]
            if isinstance(value1, list) and isinstance(value2, list):
                merged_dict[key] = merge_lists(value1, value2)
            elif isinstance(value1, dict) and isinstance(value2, dict):
                # Recurse so nested dictionaries follow the same merge rules;
                # merge_dict already returns a fresh, copied dictionary.
                merged_dict[key] = merge_dict(value1, value2)
            else:
                merged_dict[key] = copy.deepcopy(value2)
        elif key in dict1:
            merged_dict[key] = copy.deepcopy(dict1[key])
        else:
            merged_dict[key] = copy.deepcopy(dict2[key])

    return merged_dict
def unint16touint8(im):
    """
    Convert a uint16 image to uint8 by linearly rescaling [0, 65535] to [0, 255].

    Parameters:
        im (ndarray): The image to convert (expected to hold uint16 values).

    Returns:
        ndarray: The rescaled image as uint8, with the same shape as ``im``.
        Fractional values are truncated by the final cast.
    """
    im = im.astype(np.float32)
    # Use ``**`` for the power: ``2^16`` is bitwise XOR in Python (== 18), which
    # made the divisor 17 and overflowed the uint8 range.
    im = (im * 255) / (2**16 - 1)
    return im.astype(np.uint8)

def evaluate_dataset(gt_dir, predict_dir, input_dir):
    """
    Compute the evaluation metrics for every image of a dataset.

    The files listed in ``gt_dir`` are used as the reference: each one is read
    with the same file name from ``predict_dir`` and ``input_dir``. Every image
    is min-max normalized before being passed to ``obtain_metrics``.

    Parameters:
        gt_dir (str): Directory with the ground-truth images.
        predict_dir (str): Directory with the predicted images.
        input_dir (str): Directory with the input (widefield) images.

    Returns:
        dict: Metric names mapped to lists with one value per image, plus a
        ``"files"`` key listing the processed file names in the same order.
        An empty dict is returned when ``gt_dir`` contains no image files.
    """
    # Sort for a deterministic order, and skip hidden entries (e.g. macOS
    # ".DS_Store") and sub-directories, which io.imread cannot read.
    im_list = sorted(
        name
        for name in os.listdir(gt_dir)
        if not name.startswith(".") and os.path.isfile(os.path.join(gt_dir, name))
    )

    # Start from an empty dict so an empty directory returns {} instead of
    # raising UnboundLocalError; merge_dict copies the first image's metrics.
    metrics_dict = {}
    for im_name in tqdm(im_list):
        gt_image = io.imread(os.path.join(gt_dir, im_name))
        gt_image = [min_max_normalization(gt_image)]

        predicted_image = io.imread(os.path.join(predict_dir, im_name))
        predicted_image = [min_max_normalization(predicted_image)]

        input_image = io.imread(os.path.join(input_dir, im_name))
        input_image = [min_max_normalization(input_image)]

        im_metrics = obtain_metrics(gt_image, predicted_image, input_image)
        im_metrics["files"] = [im_name]
        metrics_dict = merge_dict(metrics_dict, im_metrics)

    return metrics_dict
#
# Functions that define normalization techniques
# TODO: to import from utils
def min_max_normalization(data, desired_accuracy=np.float32):
    """
    Normalize the given data to the [0, 1] range using min-max normalization.

    Parameters:
        data (ndarray): The data to be normalized.
        desired_accuracy (type): The dtype of the normalized data. Defaults to np.float32.

    Returns:
        ndarray: The normalized data, cast to ``desired_accuracy``. A constant
        input yields all zeros (the small epsilon avoids a division by zero).

    Raises:
        ValueError: If ``data`` is empty (its min/max is undefined).
    """
    # Compute in float64: subtracting in the input dtype can overflow for
    # signed integer images (e.g. int8 max - min).
    data = np.asarray(data, dtype=np.float64)
    normalized = (data - data.min()) / (data.max() - data.min() + 1e-10)
    # Cast the whole result; previously only the scalar denominator was cast
    # due to operator precedence, so the output dtype was not guaranteed.
    return normalized.astype(desired_accuracy)
    

def obtain_metrics(gt_image_list, predicted_image_list, wf_image_list, test_metric_indexes=None):
    """
    Calculate various metrics for evaluating the performance of an image prediction model.

    Args:
        gt_image_list (List[np.ndarray]): A list of ground truth images.
        predicted_image_list (List[np.ndarray]): A list of predicted images, in the same order as the ground truth images.
        wf_image_list (List[np.ndarray]): A list of input (widefield) images, in the same order as the ground truth images.
        test_metric_indexes (Optional[List[int]]): Indexes of the images on which additional (expensive) metrics
            would be calculated. No additional metric is currently computed for them. Defaults to None (no image).

    Returns:
        dict: A dictionary with the metric names as keys and lists with one value per image as values.
        Keys without an implemented metric ("fsim", "gmsd", "vsi", "haarpsi", "mdsi", "pieapp",
        "dists", "brisqe", "fid") are returned as empty lists.

    Raises:
        AssertionError: If a widefield image contains negative values.

    Note:
        Computed metrics: MSE, SSIM and PSNR (skimage), LPIPS with AlexNet and VGG, GT RSE/RSP and
        Pred RSE/RSP (Nanopyx error map against the widefield image), decorrelation resolution of the
        prediction (Nanopyx) and ILNIQE of the prediction.
        SSIM uses ``data_range=1.0``, so the images are expected to be normalized to [0, 1]
        (e.g. with ``min_max_normalization``). When all the pixels of a predicted image are equal,
        Pred RSE/RSP, decorrelation and ILNIQE are stored as NaN for that image.
    """
    # Avoid a shared mutable default argument.
    if test_metric_indexes is None:
        test_metric_indexes = []

    metrics_dict = {
        "ssim": [],
        "psnr": [],
        "mse": [],
        "alex": [],
        "vgg": [],
        "ilniqe": [],
        "fsim": [],
        "gmsd": [],
        "vsi": [],
        "haarpsi": [],
        "mdsi": [],
        "pieapp": [],
        "dists": [],
        "brisqe": [],
        "fid": [],
        "gt_rse": [],
        "gt_rsp": [],
        "pred_rse": [],
        "pred_rsp": [],
        "decor": [],
    }

    test_data_length = len(gt_image_list)
    for i in tqdm(range(test_data_length)):

        # Load the widefield image, ground truth image, and predicted image
        gt_image = np.squeeze(gt_image_list[i])
        predicted_image = np.squeeze(predicted_image_list[i])
        wf_image = np.squeeze(wf_image_list[i])

        # Print info about the images
        print(
            f"gt_image shape: {gt_image.shape} - intensity range: {gt_image.min()} {gt_image.max()} - data type {gt_image.dtype}"
        )
        print(
            f"predicted_image shape: {predicted_image.shape} - intensity range: {predicted_image.min()} {predicted_image.max()} - data type {predicted_image.dtype}"
        )
        print(
            f"wf_image shape: {wf_image.shape} - intensity range: {wf_image.min()} {wf_image.max()} - data type {wf_image.dtype}"
        )

        # Pass the images into Pytorch format (1, 1, X, X)
        gt_image_piq = np.expand_dims(gt_image, axis=[0, 1])
        predicted_image_piq = np.expand_dims(predicted_image, axis=[0, 1])

        # Pytorch does not support uint16
        if gt_image_piq.dtype == np.uint16:
            gt_image_piq = unint16touint8(gt_image_piq)
        if predicted_image_piq.dtype == np.uint16:
            predicted_image_piq = unint16touint8(predicted_image_piq)

        # Convert the images into float Pytorch tensors
        gt_image_piq = torch.from_numpy(gt_image_piq).float()
        predicted_image_piq = torch.from_numpy(predicted_image_piq).float()

        # The widefield image must not contain negative values. Raised explicitly
        # (rather than with ``assert``) so the check also runs under ``python -O``.
        if not wf_image.min() >= 0.0:
            raise AssertionError("wf_image contains negative values")

        # In case all the predicted values are equal (all zeros for example), the
        # metrics that analyse the predicted image on its own are set to NaN.
        all_equals = np.all(predicted_image == np.ravel(predicted_image)[0])

        #####################################
        #
        # Calculate the skimage metrics
        print("Calculating standard pixel based metrics")
        print("__________________________________________")
        metrics_dict["mse"].append(
            skimage_metrics.mean_squared_error(gt_image, predicted_image)
        )
        metrics_dict["ssim"].append(
            skimage_metrics.structural_similarity(
                predicted_image, gt_image, data_range=1.0
            )
        )
        metrics_dict["psnr"].append(
            skimage_metrics.peak_signal_noise_ratio(gt_image, predicted_image)
        )

        #####################################
        #
        # Calculate the LPIPS metrics
        print("Calculating LIPIPS metric")
        print("______________________________")
        # NOTE(review): LPIPS documents its inputs as being in [-1, 1] (its forward
        # pass has a ``normalize=True`` flag to rescale [0, 1] inputs); here the
        # images are passed as they are. Confirm before changing, since doing so
        # alters the reported values.
        # no_grad: only the distance value is needed, so no autograd graph is
        # kept, which reduces memory use on large images.
        with torch.no_grad():
            metrics_dict["alex"].append(
                np.squeeze(
                    lpips_alex(gt_image_piq, predicted_image_piq).detach().numpy()
                )
            )
            metrics_dict["vgg"].append(
                np.squeeze(
                    lpips_vgg(gt_image_piq, predicted_image_piq).detach().numpy()
                )
            )

        #####################################
        #
        # Calculate the Nanopyx metrics
        print("Calculating Nanopyx metrics")
        print("______________________________")
        error_map = ErrorMap()
        error_map.optimise(wf_image, gt_image)
        metrics_dict["gt_rse"].append(error_map.getRSE())
        metrics_dict["gt_rsp"].append(error_map.getRSP())

        if all_equals:
            metrics_dict["pred_rse"].append(np.nan)
            metrics_dict["pred_rsp"].append(np.nan)
            metrics_dict["decor"].append(np.nan)
        else:
            error_map = ErrorMap()
            error_map.optimise(wf_image, predicted_image)
            metrics_dict["pred_rse"].append(error_map.getRSE())
            metrics_dict["pred_rsp"].append(error_map.getRSP())

            decorr_calculator_raw = DecorrAnalysis()
            decorr_calculator_raw.run_analysis(predicted_image)
            metrics_dict["decor"].append(decorr_calculator_raw.resolution)

        #####################################
        #
        # Calculate the ILNIQE
        print("Calculating IL-NIQUE metric")
        print("______________________________")
        # This metric is slow (around 83 seconds per image according to an earlier measurement).
        if all_equals:
            metrics_dict["ilniqe"].append(np.nan)
        else:
            metrics_dict["ilniqe"].append(
                calculate_ilniqe(
                    img_as_ubyte(predicted_image),
                    0,
                    input_order="HW",
                    resize=True,
                    version="python",
                )
            )

        #####################################

        if i in test_metric_indexes:
            # Placeholder: additional metrics for a reduced set of images (to avoid time issues).
            pass

    return metrics_dict

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /Users/esti/mambaforge/envs/microscopy/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /Users/esti/mambaforge/envs/microscopy/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth




In [None]:
# Directories of the toy MT-SMLM dataset: ground truth, widefield input and the
# RCAN predictions. Images are matched by the file names found in gt_dir.
# NOTE(review): machine-specific absolute paths — adjust when running elsewhere.
gt_dir = "/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/toy_data/MT-SMLM-registered/gt"
input_dir = "/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/toy_data/MT-SMLM-registered/wf/"
predict_dir = "/Users/esti/Documents/PROYECTOS/TFM-IVANHIDALGO/toy_data/MT-SMLM-registered/rcan/epc200_btch4_lr0.0001_optim-adam_lrsched-ReduceOnPlateau_seed666_1/prediction/"
# The returned metrics dict is not stored in a variable; assign it
# (e.g. ``metrics = evaluate_dataset(...)``) to reuse the results.
evaluate_dataset(gt_dir, predict_dir, input_dir)

  0%|                                                                                                                                                           | 0/12 [00:00<?, ?it/s]
  0%|                                                                                                                                                            | 0/1 [00:00<?, ?it/s][A

gt_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
predicted_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
wf_image shape: (256, 256) - intensity range: 0.0 1.0 - data type float32
Calculating standard pixel based metrics
__________________________________________
Calculating LIPIPS metric
______________________________
Calculating Nanopyx metrics
______________________________
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.06551862499327399 seconds
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.017144000012194738 seconds
Calculating IL-NIQUE metric
______________________________



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:54<00:00, 114.47s/it][A
  8%|████████████▏                                                                                                                                     | 1/12 [01:54<20:59, 114.51s/it]
  0%|                                                                                                                                                            | 0/1 [00:00<?, ?it/s][A

gt_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
predicted_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
wf_image shape: (256, 256) - intensity range: 0.0 1.0 - data type float32
Calculating standard pixel based metrics
__________________________________________
Calculating LIPIPS metric
______________________________
Calculating Nanopyx metrics
______________________________
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.018396249986835755 seconds
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.019851499993819743 seconds
Calculating IL-NIQUE metric
______________________________



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [02:10<00:00, 130.21s/it][A
 17%|████████████████████████▎                                                                                                                         | 2/12 [04:04<20:37, 123.78s/it]
  0%|                                                                                                                                                            | 0/1 [00:00<?, ?it/s][A

gt_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
predicted_image shape: (2048, 2048) - intensity range: 0.0 1.0 - data type float32
wf_image shape: (256, 256) - intensity range: 0.0 1.0 - data type float32
Calculating standard pixel based metrics
__________________________________________
Calculating LIPIPS metric
______________________________
Calculating Nanopyx metrics
______________________________
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.029432915995130315 seconds
Querying the Agent...
Agent: ShiftMagnify_CR using OpenCL_Apple M1 ran in 0.016967749994364567 seconds
Calculating IL-NIQUE metric
______________________________



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:43<00:00, 103.27s/it][A
 25%|████████████████████████████████████▌                                                                                                             | 3/12 [05:48<17:09, 114.44s/it]
  0%|                                                                                                                                                            | 0/1 [00:00<?, ?it/s][A

gt_image shape: (4000, 4000) - intensity range: 0.0 1.0 - data type float32
predicted_image shape: (4000, 4000) - intensity range: 0.0 1.0 - data type float32
wf_image shape: (500, 500) - intensity range: 0.0 1.0 - data type float32
Calculating standard pixel based metrics
__________________________________________
Calculating LIPIPS metric
______________________________
