### Imports

In [None]:
from IPython.display import HTML, display

display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import glob
import json
import os
import pickle
import random
from copy import deepcopy
from typing import Literal, Optional, Union

import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset

from careamics.config.likelihood_model import NMLikelihoodConfig
from careamics.config.nm_model import MultiChannelNMConfig
from careamics.lightning import VAEModule
# from careamics.lvae_training.eval_utils import (
#     Calibration,
#     get_calibrated_factor_for_stdev,
    # get_dset_predictions,
#     get_eval_output_dir,
#     plot_calibration,
#     plot_error,
#     show_for_one,
#     stitch_predictions,
# )
# from careamics.models.lvae.noise_models import noise_model_factory
# from careamics.utils.metrics import (
#     # avg_psnr,
#     # avg_range_inv_psnr,
#     # avg_ssim,
#     scale_invariant_psnr,
#     # multiscale_ssim
# )

torch.multiprocessing.set_sharing_strategy('file_system')

In [3]:
def fix_seeds(seed: int = 0) -> None:
    """Seed every random number generator used in this notebook.

    Parameters
    ----------
    seed : int, optional
        Value used to seed Python's `random`, NumPy and PyTorch (CPU and CUDA).
        Defaults to 0, which matches the previously hard-coded value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed the generators of all CUDA devices, not only the current one.
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN algorithms.
    torch.backends.cudnn.deterministic = True

fix_seeds()

In [4]:
DATA_DIR = "/group/jug/federico/microsim/BIOSR_spectral_data/2410/v1/imgs/digital/"
OUT_ROOT = "/group/jug/federico/lambdasplit_training/"
DEBUG = False

In [5]:
ckpt_dir = os.path.join(OUT_ROOT, "2410/lambdasplit_no_LC/0")
assert os.path.exists(ckpt_dir)

### Set Evaluation Parameters

In [None]:
# Eval Parameters
mmse_count: int = 10
"""The number of predictions to average for MMSE evaluation."""
image_size_for_grid_centers: int = 32
"""The size of the portion of image we retain from inner padding/tiling."""
eval_patch_size: Optional[int] = 64
"""The actual patch size. If not specified data.image_size."""
psnr_type: Literal['simple', 'range_invariant'] = 'range_invariant'
"""The type of PSNR to compute."""
which_ckpt: Literal['best', 'last'] = 'best'
"""Which checkpoint to use for evaluation."""
enable_calibration: bool = False
"""Whether to enable calibration."""

In [None]:
# Optional other params
batch_size: int = 32
"""The batch size for training."""
num_workers: int = 4
"""The number of workers to use for data loading."""

### 1. Load configs

In [8]:
from careamics.config import (
    VAEAlgorithmConfig,
    TrainingConfig,
    DataConfig,
)
from careamics.utils.io_utils import load_config

In [9]:
# Rebuild the algorithm, training and data configurations through `load_config`.
# Fail here with an explicit message, rather than with a NameError in a later cell,
# when `ckpt_dir` exists but is not a directory.
if not os.path.isdir(ckpt_dir):
    raise NotADirectoryError(f"Expected a checkpoint directory, got: {ckpt_dir}")

algo_config = VAEAlgorithmConfig(**load_config(ckpt_dir, "algorithm"))
training_config = TrainingConfig(**load_config(ckpt_dir, "training"))
data_config = DataConfig(**load_config(ckpt_dir, "data"))

### 2. Load dataset

In [10]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from careamics.dataset import InMemoryDataset

In [11]:
# Collect the input .tif files. `glob.glob` already returns paths that include
# `DATA_DIR`, so they must not be joined with it a second time (the previous join
# only worked because `DATA_DIR` is an absolute path).
fnames = [Path(fname) for fname in glob.glob(os.path.join(DATA_DIR, "*.tif"))]
if not fnames:
    raise FileNotFoundError(f"No .tif files found in {DATA_DIR}")
# Order the files by the integer that follows the last underscore of the file stem.
fnames = sorted(fnames, key=lambda x: int(x.stem.split("_")[-1]))

In [None]:
# Build the dataset from the input files, using the data configuration loaded above.
train_dset = InMemoryDataset(
    data_config=data_config,
    inputs=fnames,
)
# Derive the validation set from the training set with `percentage=0.15`.
# NOTE(review): presumably `split_dataset` also removes the selected samples from
# `train_dset` — confirm against the CAREamics implementation.
val_dset = train_dset.split_dataset(percentage=0.15)

In [13]:
def unsupervised_collate_fn(batch: list[tuple]) -> tuple[torch.Tensor, None]:
    """Collate dataset items into one batch of inputs without targets.

    Parameters
    ----------
    batch : list of tuple
        Dataset items; only the first element of each item (a NumPy array) is used.

    Returns
    -------
    tuple of (torch.Tensor, None)
        The inputs stacked along a new leading batch dimension, and `None` as a
        placeholder for the missing target.
    """
    inputs = torch.stack([torch.from_numpy(item[0]) for item in batch], dim=0)
    return inputs, None

# Loader over the training set, shuffled.
# NOTE(review): `train_dloader` is not used by any other visible cell of this notebook.
train_dloader = DataLoader(
    train_dset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=unsupervised_collate_fn
)
# Loader over the validation set, in fixed order (`shuffle=False`); it is the one
# passed to `get_dset_predictions` in section 4.1.
val_dloader = DataLoader(
    val_dset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers, 
    collate_fn=unsupervised_collate_fn
)

In [14]:
data_mean = data_config.image_means
data_std = data_config.image_stds

##### GT dataset (i.e., split by fluorophore)

NOTE: we cannot straightforwardly use the optical image data since those images are larger (1004 vs. 251). We would need to implement a way to downscale them within the dataset.

In [15]:
GT_DATA_DIR = "/group/jug/federico/microsim/BIOSR_spectral_data/2410/v1/imgs/digital_pf/"

# Collect the ground-truth .tif files. `glob.glob` already returns paths that include
# `GT_DATA_DIR`, so they must not be joined with it a second time (the previous join
# only worked because `GT_DATA_DIR` is an absolute path).
gt_fnames = [Path(fname) for fname in glob.glob(os.path.join(GT_DATA_DIR, "*.tif"))]
if not gt_fnames:
    raise FileNotFoundError(f"No .tif files found in {GT_DATA_DIR}")
# Order the files by the integer that follows the last underscore of the file stem.
gt_fnames = sorted(gt_fnames, key=lambda x: int(x.stem.split("_")[-1]))

gt_data_config = deepcopy(data_config)
gt_data_config.set_means_and_stds(image_means=None, image_stds=None)

In [None]:
# Build the ground-truth dataset from the GT files, using the copy of the data
# configuration whose means and stds were reset to None.
gt_train_dset = InMemoryDataset(
    data_config=gt_data_config,
    inputs=gt_fnames,
)
# Same `percentage=0.15` as for the input dataset.
# NOTE(review): whether this selects the same samples as `val_dset` depends on how
# `split_dataset` chooses them — confirm before comparing the two sets index by index.
gt_val_dset = gt_train_dset.split_dataset(percentage=0.15)

In [None]:
import matplotlib.pyplot as plt
from careamics.utils.visualization import view3D

# Show, side by side, the first three slices along the leading axis of the first
# GT training item (presumably one slice per fluorophore — confirm).
_, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
for panel_idx, panel in enumerate(ax):
    panel.imshow(gt_train_dset[0][0][panel_idx])

In [None]:
view3D(train_dset[0][0], axis=0, jupyter=True)

### 3. Create model

Note: noise model and the associated likelihood are not saved in the config, hence we need to reinitialize them.

In [19]:
from torch import nn

from careamics.utils.io_utils import load_model_checkpoint

In [None]:
light_model = VAEModule(algorithm_config=algo_config)

In [None]:
checkpoint = load_model_checkpoint(ckpt_dir, which_ckpt)
# `strict=False` tolerates key mismatches between the checkpoint and the model, so
# report them explicitly instead of ignoring them silently.
incompatible_keys = light_model.load_state_dict(checkpoint['state_dict'], strict=False)
if incompatible_keys.missing_keys:
    print('Model keys missing from the checkpoint:', incompatible_keys.missing_keys)
if incompatible_keys.unexpected_keys:
    print('Checkpoint keys not used by the model:', incompatible_keys.unexpected_keys)
# Switch to evaluation mode and move the model to the GPU.
light_model.eval()
light_model.cuda()

print('Loading weights from epoch', checkpoint['epoch'])

In [None]:
def count_parameters(model: nn.Module) -> int:
    """Count the trainable parameters of `model`.

    Parameters
    ----------
    model : nn.Module
        The model to inspect.

    Returns
    -------
    int
        Total number of elements over all parameters with `requires_grad=True`.
    """
    # Use the argument, not the notebook-level `light_model`.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Model has {count_parameters(light_model)/1000_000:.3f}M parameters')

### 4. Evaluation

#### 4.1. Get predictions for patches

In [24]:
from careamics.utils.eval_utils import get_dset_predictions

In [None]:
pred_patches, pred_stds = get_dset_predictions(light_model, val_dloader, mmse_count)

In [None]:
pred_patches.shape, pred_stds.shape

In [None]:
view3D(pred_patches[0], axis=0, jupyter=True)

In [None]:
pred_patches

#### 4.2. Get full image predictions by stitching the predicted tiles

In [None]:
# NOTE(review): `pred_tiled`, `pred_std_tiled` and `stitch_predictions` are not defined
# or imported in the visible cells (the `stitch_predictions` import is commented out in
# the imports cell). Presumably the tiled arrays should come from `pred_patches` and
# `pred_stds` computed in section 4.1 — confirm before running this cell.

# When the last axis of the tiles differs from `val_dset.get_img_sz()`, zero-pad the
# last two axes symmetrically by half of the difference.
if pred_tiled.shape[-1] != val_dset.get_img_sz():
    pad = (val_dset.get_img_sz() - pred_tiled.shape[-1] )//2
    pred_tiled = np.pad(pred_tiled, ((0,0),(0,0),(pad,pad),(pad,pad)))

# Stitch tiled predictions
pred = stitch_predictions(
    pred_tiled,
    val_dset,
    smoothening_pixelcount=0
)

# # Stitch predicted tiled logvar
# if len(np.unique(logvar_tiled)) == 1:
#     logvar = None
# else:
#     logvar = stitch_predictions(logvar_tiled, val_dset, smoothening_pixelcount=0) # TODO: there's a bug here

# Stitch the std of the predictions (i.e., std computed on the mmse_count predictions)
pred_std = stitch_predictions(pred_std_tiled, val_dset, smoothening_pixelcount=0)

#### 4.3. Predictions Post-processing

Ignore (and remove) the pixels in the last few rows and columns (they are left over because the image size is not a multiple of patch_size):
1. They are not part of any batch, so in the prediction they are simply zeros. For this reason they are currently ignored.
2. For the border pixels at the top and on the left, overlapping yields worse performance. This is because there is nothing to overlap with on one side, so they are essentially zero-padded, which makes the performance worse.

In [None]:
def get_ignored_pixels(predictions: Optional[np.ndarray] = None) -> int:
    """Get the number of ignored pixels in the predictions.

    Analyze the first image of the predictions and find the side of the largest
    bottom-right corner square whose values have zero standard deviation, which is
    taken as the number of trailing rows and columns ignored in prediction.

    Parameters
    ----------
    predictions : np.ndarray, optional
        Predictions with the sample index on the first axis, followed by the two
        spatial axes. If None (default), the notebook-level `pred` is used.

    Returns
    -------
    int
        Number of trailing rows/columns with constant values, between 0 and the
        smaller spatial size of the first image.
    """
    arr = pred if predictions is None else predictions
    first_img = arr[0]
    # Bound the search by the image size: without it, a fully constant image makes
    # the loop run forever, because the slice stops growing once it covers the image.
    max_ignored = min(first_img.shape[0], first_img.shape[1])
    ignored_pixels = 0
    while (
        ignored_pixels < max_ignored
        and first_img[-(ignored_pixels + 1):, -(ignored_pixels + 1):].std() == 0
    ):
        ignored_pixels += 1
    print(f'In {arr.shape}, last {ignored_pixels} many rows and columns are all zero.')
    return ignored_pixels

actual_ignored_pixels = get_ignored_pixels()
print(f'Actual ignored pixels: {actual_ignored_pixels}')

In [None]:
# Pick the number of trailing rows/columns to discard, depending on the data type.
# NOTE(review): `DataType` is not defined or imported in the visible cells, and
# `data_config` is indexed like a dict here while other cells use attribute access
# (e.g., `data_config.image_means`) — confirm this cell still runs.
if data_config["data_type"] in [
    DataType.OptiMEM100_014,
    DataType.SemiSupBloodVesselsEMBL,
    DataType.Pavia2VanillaSplitting,
    DataType.ExpansionMicroscopyMitoTub,
    DataType.ShroffMitoEr,
    DataType.HTIba1Ki67
]:
    ignored_last_pixels = 32
elif data_config["data_type"] == DataType.BioSR_MRC:
    ignored_last_pixels = 44
elif data_config["data_type"] == DataType.NicolaData:
    ignored_last_pixels = 8
else:
    ignored_last_pixels = 0

# Number of leading rows/columns that `ignore_pixels` removes (none).
ignore_first_pixels = 0
# assert actual_ignored_pixels <= ignored_last_pixels, f'Set ignored_last_pixels={actual_ignored_pixels}' # TODO: check this once stitching is fixed
print(ignored_last_pixels)

In [None]:
# Evaluation target: the array held in the private `_data` attribute of the validation dataset.
tar = val_dset._data
"""Data used to do evaluation againts. Shape is (N, H, W, C).

NOTE: this is the original data (`dset._data`), hence not normalized!
"""

# In DEBUG mode only, keep the channels listed in `data_config.target_idx_list` (when set).
if DEBUG:
    if 'target_idx_list' in data_config and data_config.target_idx_list is not None:
        tar = tar[..., data_config.target_idx_list]

def ignore_pixels(
    arr: Union[np.ndarray, torch.Tensor],
    patch_size: int
) -> Union[np.ndarray, torch.Tensor]:
    """Crop away the border pixels that are excluded from the evaluation.

    Nothing is cropped when the size of axis 2 of `arr` is a multiple of `patch_size`.
    Otherwise, the first `ignore_first_pixels` and the last `ignored_last_pixels`
    entries (both notebook-level settings) are removed from axis 1 and axis 2.
    """
    evenly_tiled = arr.shape[2] % patch_size == 0
    if evenly_tiled:
        return arr

    cropped = arr
    if ignore_first_pixels:
        cropped = cropped[:, ignore_first_pixels:, ignore_first_pixels:]
    if ignored_last_pixels:
        cropped = cropped[:, :-ignored_last_pixels, :-ignored_last_pixels]
    return cropped

# Apply the same crop to predictions, targets and (when available) prediction stds.
# NOTE(review): the second parameter of `ignore_pixels` is named `patch_size`, but the
# value of `val_dset.get_img_sz()` is passed here — confirm this is intended.
pred = ignore_pixels(pred, val_dset.get_img_sz())
tar = ignore_pixels(tar, val_dset.get_img_sz())
if pred_std is not None:
    pred_std = ignore_pixels(pred_std, val_dset.get_img_sz())

print(pred.shape)

#### Visually compare Targets and Predictions

In [None]:
# One random crop (of shape [sz x sz]) of a target image vs. the matching prediction.
# Layout: one column per channel; targets on the first row, predictions on the second.
ncols = tar.shape[-1]
# `squeeze=False` keeps `ax` two-dimensional even when there is a single channel.
_, ax = plt.subplots(figsize=(ncols*5, 2*5), nrows=2, ncols=ncols, squeeze=False)
img_idx = 0
# Never request a crop larger than the image itself.
sz = min(800, tar.shape[1], tar.shape[2])
# `+ 1` because the upper bound of `randint` is exclusive.
hs = np.random.randint(tar.shape[1] - sz + 1)
ws = np.random.randint(tar.shape[2] - sz + 1)
for i in range(ncols):
    # The axes grid has shape (2, ncols): index it as [row, channel].
    ax[0, i].set_title(f'Target Channel {i+1}')
    ax[0, i].imshow(tar[img_idx, hs:hs+sz, ws:ws+sz, i])
    ax[1, i].set_title(f'Predicted Channel {i+1}')
    ax[1, i].imshow(pred[img_idx, hs:hs+sz, ws:ws+sz, i])

# plt.subplots_adjust(wspace=0.1, hspace=0.1)
# clean_ax(ax)

In [None]:
# For one randomly chosen image, draw one row per channel with four panels:
# target, prediction, error over the full image, and error over a random crop.
# NOTE(review): `tar_normalized` and `plot_error` are not defined or imported in the
# visible cells (the `plot_error` import is commented out in the imports cell) —
# confirm where they should come from before running this cell.
nrows = pred.shape[-1]
img_sz = 3
_,ax = plt.subplots(figsize=(4*img_sz,nrows*img_sz), ncols=4, nrows=nrows)
idx = np.random.randint(len(pred))
print(idx)
for ch_id in range(nrows):
    ax[ch_id,0].set_title(f'Target Channel {ch_id+1}')
    ax[ch_id,0].imshow(tar_normalized[idx,..., ch_id], cmap='magma')
    ax[ch_id,1].set_title(f'Predicted Channel {ch_id+1}')
    ax[ch_id,1].imshow(pred[idx,:,:,ch_id], cmap='magma')
    plot_error(
        tar_normalized[idx,...,ch_id],
        pred[idx,:,:,ch_id],
        cmap = matplotlib.cm.coolwarm,
        ax = ax[ch_id,2],
        max_val = None
    )

    # A new random crop is drawn for every channel.
    # NOTE(review): `np.random.randint` raises when a spatial size is <= `cropsz`.
    cropsz = 256
    h_s = np.random.randint(0, tar_normalized.shape[1] - cropsz)
    h_e = h_s + cropsz
    w_s = np.random.randint(0, tar_normalized.shape[2] - cropsz)
    w_e = w_s + cropsz

    plot_error(
        tar_normalized[idx,h_s:h_e,w_s:w_e, ch_id],
        pred[idx,h_s:h_e,w_s:w_e,ch_id],
        cmap = matplotlib.cm.coolwarm,
        ax = ax[ch_id,3],
        max_val = None
    )

    # Add rectangle to the region
    rect = patches.Rectangle((w_s, h_s), w_e-w_s, h_e-h_s, linewidth=1, edgecolor='r', facecolor='none')
    ax[ch_id,2].add_patch(rect)


#### Compute metrics between predicted data and high-SNR (ground truth) data

Prepare data:

In [None]:
# Undo the normalization channel by channel: `pred * std + mean`.
# The result is a list with one array per channel (last axis of `pred`).
# NOTE(review): `sep_mean` and `sep_std` are not defined in the visible cells —
# presumably they should be derived from `data_mean` / `data_std`; confirm.
pred_unnorm = []
for i in range(pred.shape[-1]):
    # A single entry on the last axis of the statistics is shared by all channels.
    if sep_std.shape[-1] == 1:
        temp_pred_unnorm = pred[...,i] * sep_std[...,0] + sep_mean[...,0]
    else:
        temp_pred_unnorm = pred[...,i] * sep_std[...,i] + sep_mean[...,i]
    pred_unnorm.append(temp_pred_unnorm)

In [None]:
# Get & process high-SNR data from previously loaded dataset
# NOTE(review): `highsnr_val_dset` and `data_t_list` are not defined in the visible
# cells — presumably `gt_val_dset` is the intended dataset; confirm.
highres_data = highsnr_val_dset._data
if highres_data is not None:
    # Crop the same border pixels that were removed from the predictions.
    highres_data = ignore_pixels(highres_data, highsnr_val_dset.get_img_sz()).copy()
    # Optionally keep only the entries selected by `data_t_list` along the first axis.
    if data_t_list is not None:
        highres_data = highres_data[data_t_list].copy()

    # Optionally keep only the channels listed in the `target_idx_list` config entry.
    if "target_idx_list" in data_config and data_config["target_idx_list"] is not None:
        highres_data = highres_data[..., data_config["target_idx_list"]]

Compute metrics:

In [None]:
def avg_range_inv_psnr(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Average `scale_invariant_psnr` over the first axis of the inputs.

    Entry `i` of `pred` is paired with entry `i` of `target`; both are forwarded, in
    this order, to `scale_invariant_psnr`, and the mean of the results is returned.

    NOTE(review): `scale_invariant_psnr` is not imported in the visible cells (its
    import is commented out in the imports cell) — restore it before use.
    """
    per_image_psnr = [
        scale_invariant_psnr(pred[idx], target[idx]) for idx in range(pred.shape[0])
    ]
    return np.mean(per_image_psnr)

In [None]:
# Per-channel PSNR between the high-SNR data and the un-normalized predictions.
# NOTE(review): `DataSplitType`, `eval_datasplit_type`, `sep_mean` and `sep_std` are not
# defined in the visible cells — confirm where they should come from.
if highres_data is not None:
    print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')
    psnr_list = [avg_range_inv_psnr(highres_data[...,k], pred_unnorm[k]) for k in range(len(pred_unnorm))]
    # Only read by the commented-out SSIM computation below.
    highres_norm = (highres_data - sep_mean) / sep_std
    # care_ssim_list = multiscale_ssim(highres_norm, pred)
    print(f"PSNR on Highres: {' '.join([str(x) for x in psnr_list])}, avg: {np.mean(psnr_list)}")
    # print(f"CARE-SSIM on Highres: {' '.join([str(np.round(x,3)) for x in care_ssim_list])}, avg: {np.mean(care_ssim_list)}")

In [None]:
# Collect, for every channel (last axis of `pred`), the RMSE, PSNR, range-invariant PSNR and SSIM.
# NOTE(review): `tar_normalized`, `avg_psnr` and `avg_ssim` are not defined or imported
# in the visible cells — confirm where they should come from.
rmse_arr = []
psnr_arr = []
rinv_psnr_arr = []
ssim_arr = []
for ch_id in range(pred.shape[-1]):
    # One RMSE value per image: the squared error is averaged over all pixels of each image.
    rmse =np.sqrt(((pred[...,ch_id] - tar_normalized[...,ch_id])**2).reshape(len(pred),-1).mean(axis=1))
    rmse_arr.append(rmse)
    # PSNR metrics compare normalized data; SSIM compares raw targets with un-normalized predictions.
    psnr = avg_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy())
    rinv_psnr = avg_range_inv_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy())
    ssim_mean, ssim_std = avg_ssim(tar[...,ch_id], pred_unnorm[ch_id])
    psnr_arr.append(psnr)
    rinv_psnr_arr.append(rinv_psnr)
    ssim_arr.append((ssim_mean,ssim_std))

In [None]:
# Print a summary of the metrics, one value per channel separated by ' <--> '.
# NOTE(review): `DataSplitType`, `eval_datasplit_type` and `rec_loss` are not defined
# in the visible cells — confirm where they should come from.
print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')
print('Rec Loss: ', np.round(rec_loss.mean(),3) )
print('RMSE: ', ' <--> '.join([str(np.mean(x).round(3)) for x in rmse_arr]))
print('PSNR: ', ' <--> '.join([str(x) for x in psnr_arr]))
print('RangeInvPSNR: ',' <--> '.join([str(x) for x in rinv_psnr_arr]))
print('SSIM: ',' <--> '.join([f'{round(x,3)}±{round(y,4)}' for (x,y) in ssim_arr]))
print()