In [None]:
import sys
import argparse
import numpy as np
from torch.utils.data.dataloader import DataLoader

sys.path.insert(1, "/labs/gevaertlab/users/yyhhli/code/vae/")


def reconstruction_eval(version, log_name, ds_name, *, split='val',
                        batch_size=36, num_workers=4):
    """Evaluate VAE3D reconstructions on one split of a patch dataset.

    Builds the patch dataset selected by ``ds_name``, wraps it in a
    non-shuffling DataLoader, and asks ``MetricEvaluator`` to compute the
    SSIM, MSE and PSNR metrics for the model identified by
    ``log_name``/``version``. The mean of each metric is printed.

    Args:
        version: Model version, forwarded to ``MetricEvaluator``.
        log_name: Log/experiment name, forwarded to ``MetricEvaluator``.
        ds_name: Short dataset key: ``"stf"``, ``"lidc"`` or ``"lndb"``.
        split: Dataset split to evaluate (default ``'val'``).
        batch_size: DataLoader batch size (default 36).
        num_workers: Number of DataLoader worker processes (default 4).

    Returns:
        The dict returned by ``MetricEvaluator.calc_metrics`` (metric name
        -> values), not the printed means.

    Raises:
        KeyError: If ``ds_name`` is not one of the supported keys.
    """
    from datasets import PATCH_DATASETS
    from datasets.utils import sitk2tensor
    from evaluations.evaluator import MetricEvaluator

    # Short CLI-style names -> class names registered in PATCH_DATASETS.
    ds_dict = {"stf": "StanfordRadiogenomicsPatchDataset",
               "lidc": "LIDCPatchAugDataset",
               "lndb": "LNDbPatch32Dataset"}
    try:
        ds_class_name = ds_dict[ds_name]
    except KeyError:
        # Keep KeyError for compatibility, but say what the valid keys are.
        raise KeyError(
            f"unknown ds_name {ds_name!r}; expected one of {sorted(ds_dict)}"
        ) from None

    patch_ds = PATCH_DATASETS[ds_class_name](root_dir=None,
                                             transform=sitk2tensor,
                                             split=split)
    print(f"length of {ds_class_name} dataset",
          len(patch_ds))
    # shuffle=False gives a deterministic order; drop_last=False keeps the
    # final partial batch so every sample of the split is yielded.
    dataloader = DataLoader(dataset=patch_ds,
                            batch_size=batch_size,
                            shuffle=False,
                            drop_last=False,
                            num_workers=num_workers,
                            pin_memory=True)
    me = MetricEvaluator(metrics=['SSIM', 'MSE', 'PSNR'],
                         log_name=log_name,
                         version=version,
                         base_model_name='VAE3D')
    metrics_dict = me.calc_metrics(dataloader=dataloader)
    # NOTE(review): assumes each value is a sequence of per-sample (or
    # per-batch) scores that np.mean can reduce -- TODO confirm against
    # MetricEvaluator.calc_metrics.
    for k, v in metrics_dict.items():
        print(f"{k}: mean value = {np.mean(v)}")
    # NOTE(review): only the raw metrics are returned; the means above are
    # printed, not returned. Kept as-is so existing callers keep working.
    return metrics_dict