# Calculating PSNR for denoised results

In [None]:
import os

import tifffile
from skimage.metrics import peak_signal_noise_ratio as PSNR
import numpy as np
import matplotlib.pyplot as plt

### Load denoised images

In [None]:
# Denoised outputs for this model live under results/<model_name>/.
model_name = "convallaria"
results_dir = os.path.join("results", model_name)

# Per-image MMSE estimates and the stack of posterior samples.
mmse_file = os.path.join(results_dir, "MMSEs.tif")
samples_file = os.path.join(results_dir, "samples.tif")

mmse_data = tifffile.imread(mmse_file)
samples_data = tifffile.imread(samples_file)

### Load ground truth data
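In this example, the ground truth is approximated by averaging all frames of the noisy dataset pixel by pixel. This single averaged image serves as the reference for every frame.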

In [None]:
data_path = "data/flower.tif"
crop_size = 512  # side length of the top-left square crop used for evaluation

# Crop first, then convert: casting only the cropped region avoids building a
# float64 copy of the entire uncropped stack (values are identical either way).
# A singleton axis is inserted at position 1. The three slicing indices require
# an at-least-3-D array; presumably (frames, H, W) — TODO confirm.
low_snr = tifffile.imread(data_path)[
    :, np.newaxis, :crop_size, :crop_size
].astype(float)

# Pseudo ground truth: the per-pixel mean over all noisy frames (axis 0),
# repeated along axis 0 so that gt_data[i] lines up with low_snr[i].
gt_data = low_snr.mean(axis=0, keepdims=True)
gt_data = gt_data.repeat(low_snr.shape[0], axis=0)

In [None]:
# Score each MMSE prediction against the ground-truth frame at the same index.
# The data range comes from the whole ground-truth stack, so every frame is
# scored on the same intensity scale.
if len(mmse_data) > len(gt_data):
    # Without this check the mismatch would surface as a bare IndexError.
    raise ValueError(
        f"Got {len(mmse_data)} MMSE images but only {len(gt_data)} "
        "ground-truth frames"
    )

data_range = float(np.max(gt_data) - np.min(gt_data))
# zip stops at the shorter sequence, so only predicted frames are scored.
psnrs = [
    PSNR(gt, test, data_range=data_range)
    for gt, test in zip(gt_data, mmse_data)
]

print(f"PSNR: {np.mean(psnrs)}")


### Visualize results

In [None]:
idx = 0  # which image of the stack to display
# Leading 0 selects the singleton channel; then a 100x100 region of interest.
img_patch = (0, slice(200, 300), slice(200, 300))

fig, ax = plt.subplots(2, 2)

# (grid position, title, full image) for each panel of the 2x2 comparison.
panels = [
    ((0, 0), "Low SNR", low_snr[idx]),
    ((0, 1), "Ground Truth", gt_data[idx]),
    ((1, 0), "MMSE", mmse_data[idx]),
    ((1, 1), "Sample", samples_data[0][idx]),
]

for position, title, image in panels:
    panel_ax = ax[position]
    panel_ax.imshow(image[img_patch], cmap="inferno")
    panel_ax.set_title(title)
    panel_ax.axis("off")

plt.tight_layout()