# In [None]:
import cv2
import numpy as np
import torch
import numpy as np
from skimage.metrics import structural_similarity as ssim

try:
    import lpips
    from skimage.metrics import structural_similarity as cal_ssim
except:
    lpips = None
    cal_ssim = None


def rescale(x):
    """Linearly map ``x`` onto the range [-1, 1].

    The minimum element maps to -1 and the maximum element maps to +1.
    Works on any array-like exposing ``.min()`` / ``.max()`` (NumPy arrays,
    torch tensors).

    A constant input (max == min) has no spread to normalise by; instead of
    dividing by zero (which produced NaNs) it is mapped to all zeros, the
    midpoint of the target range.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # Degenerate case: every element is equal.
        return (x - lo) * 0.0
    # Shift so the minimum is 0, scale to [0, 1], then stretch to [-1, 1].
    return (x - lo) / span * 2 - 1

def _threshold(x, y, t):
    t = np.greater_equal(x, t).astype(np.float32)
    p = np.greater_equal(y, t).astype(np.float32)
    is_nan = np.logical_or(np.isnan(x), np.isnan(y))
    t = np.where(is_nan, np.zeros_like(t, dtype=np.float32), t)
    p = np.where(is_nan, np.zeros_like(p, dtype=np.float32), p)
    return t, p

def MSE_with_error(pred, true, spatial_norm=False):
    """Mean squared error together with an error bar.

    The squared error is averaged over axes 0 and 1 and then summed over all
    remaining axes. In other words, it is the mean over ``[i, j]`` of each
    slice's summed squared error. Input is presumably (B, T, C, H, W); confirm
    with callers.

    Args:
        pred, true: NumPy arrays of identical shape. Integer inputs (e.g.
            uint8 images) are promoted to floating point before subtracting;
            floating inputs keep their own precision.
        spatial_norm: if True, divide the squared error by the product of the
            last three axis sizes, turning each slice's sum into a mean.

    Returns:
        ``(mse, standard_error)``. ``standard_error`` is the per-element std
        over axes (0, 1) (ddof=0), summed over the remaining axes, divided by
        ``sqrt(pred.shape[0])``.
    """
    # Promote integer dtypes: uint8 subtraction wraps around (0 - 20 -> 236)
    # and squaring then overflows. Floating dtypes are left unchanged.
    dtype = np.result_type(pred, true, np.float32)
    sq_err = (np.asarray(pred, dtype=dtype) - np.asarray(true, dtype=dtype)) ** 2
    if spatial_norm:
        sq_err = sq_err / (pred.shape[-1] * pred.shape[-2] * pred.shape[-3])

    mse = np.mean(sq_err, axis=(0, 1)).sum()
    # NOTE(review): summing per-element stds is an upper bound on the std of
    # the per-slice sum (triangle inequality). Also, the std is taken over
    # shape[0] * shape[1] entries but divided by sqrt(shape[0]) only. Kept
    # as-is so reported numbers stay comparable.
    standard_error = np.std(sq_err, axis=(0, 1)).sum() / np.sqrt(pred.shape[0])
    return mse, standard_error




def PSNR_with_error(pred, true, min_max_norm=True):
    """Peak Signal-to-Noise Ratio with error bar, computed per frame.

    The MSE is taken separately for every ``[b, t]`` slice (averaged over all
    axes after the first two). Each slice's MSE is converted to PSNR, and the
    PSNRs are averaged. This is the same per-frame granularity as the SSIM
    helper in this module.

    Args:
        pred, true: NumPy arrays of identical shape, indexed ``[b, t, ...]``
            (presumably (B, T, C, H, W); confirm against callers).
        min_max_norm: if True the peak value is 1.0 (data scaled to [0, 1]);
            otherwise the peak is 255.0 (data in [0, 255]).

    Returns:
        ``(psnr_mean, standard_error)``, where ``standard_error`` is the std
        (ddof=0) of the per-frame PSNRs divided by ``sqrt(pred.shape[0])``.
        If any frame is reproduced exactly (MSE == 0), returns ``(inf, 0)``.
    """
    err = pred.astype(np.float32) - true.astype(np.float32)
    # Reduce over everything except the leading (batch, time) axes so each
    # frame gets its own MSE. Averaging over axes (0, 1) instead yields a
    # per-pixel PSNR, which becomes inf as soon as a single pixel location
    # is exact in every frame (e.g. a static background).
    frame_axes = tuple(range(2, err.ndim))
    mse = np.mean(err ** 2, axis=frame_axes)

    if np.any(mse == 0):
        return float('inf'), 0

    # Peak signal: 1.0 for [0, 1]-scaled data, 255.0 for [0, 255] pixel values.
    peak = 1. if min_max_norm else 255.
    psnr_values = 20. * np.log10(peak / np.sqrt(mse))

    psnr_mean = np.mean(psnr_values)
    standard_error = np.std(psnr_values) / np.sqrt(pred.shape[0])
    return psnr_mean, standard_error


def calculate_ssim_with_error(pred, true, data_range=None):
    """Mean SSIM over every ``[b, i]`` slice, with an error bar.

    Each ``pred[b, i]`` / ``true[b, i]`` pair (for all ``b`` in axis 0 and
    ``i`` in axis 1) is scored with the module-level ``ssim`` function
    (scikit-image ``structural_similarity``).

    Args:
        pred, true: arrays of identical shape indexed ``[b, i, ...]``
            (presumably (batch, time, ...); confirm against callers).
        data_range: forwarded to ``ssim``. Recent scikit-image releases may
            require it for floating-point images. ``None`` keeps the library
            default, i.e. the previous behavior.

    Returns:
        ``(average_ssim, standard_error)``, where ``standard_error`` is the
        sample std (ddof=1) of all per-slice SSIM values divided by
        ``sqrt(pred.shape[0])``. With a single slice the std is undefined
        and NumPy returns NaN.

    Raises:
        ValueError: if there are no slices to compare.
    """
    ssim_values = np.array([
        ssim(pred[b, i], true[b, i], data_range=data_range)
        for b in range(pred.shape[0])
        for i in range(pred.shape[1])
    ])
    if ssim_values.size == 0:
        raise ValueError("calculate_ssim_with_error: inputs contain no frames")

    average_ssim = ssim_values.mean()
    # NOTE(review): the std spans shape[0] * shape[1] values but is divided
    # by sqrt(shape[0]) only; kept for comparability with the other metrics.
    standard_error = np.std(ssim_values, ddof=1) / np.sqrt(pred.shape[0])
    return average_ssim, standard_error