In [1]:
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from argparse import ArgumentParser
from pathlib import Path
import os

from src import transforms as T

from src.unet.unet import Unet
from src.unet.unet_module import UnetModule
from src.mri_module import MriModule
from src.subsample import create_mask_for_mask_type, RandomMaskFunc
from src.mri_data import CombinedSliceDataset, AnnotatedSliceDataset

In [2]:
def load_model_from_checkpoint(checkpoint_path, hparams_file=None):
    """Restore a ``UnetModule`` from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Location of the checkpoint, forwarded unchanged to
            ``UnetModule.load_from_checkpoint``.
        hparams_file: Optional hyper-parameter file, forwarded as the
            ``hparams_file`` keyword argument. Defaults to ``None``.

    Returns:
        The restored module, after ``.eval()`` has been called on it.
    """
    print(f'Loading model from checkpoint: {checkpoint_path}')
    restored = UnetModule.load_from_checkpoint(
        checkpoint_path,
        hparams_file=hparams_file,
    )
    # Evaluation mode is set before the module is handed back to the caller.
    restored.eval()
    print(f'Model loaded from checkpoint: {checkpoint_path}')
    return restored


def evaluate_model(model, dataloader):
    """Return the mean ``val_loss`` of ``model`` over the batches of ``dataloader``.

    Args:
        model: Object exposing ``validation_step_comparison(batch, batch_idx=...)``
            whose return value is a mapping with a ``'val_loss'`` entry that
            supports ``.item()``.
        dataloader: Iterable of batches. It does not need to implement
            ``__len__``: the mean is taken over the number of batches that
            were actually consumed.

    Returns:
        float: Sum of the per-batch ``val_loss`` values divided by the number
        of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches, so no mean exists.
    """
    from tqdm import tqdm

    total_loss = 0.0
    num_batches = 0
    print('Evaluating model...')
    for batch in tqdm(dataloader):
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            # batch_idx is fixed at 0 for every batch, as before.
            # NOTE(review): confirm validation_step_comparison does not rely
            # on the real batch index.
            output = model.validation_step_comparison(batch, batch_idx=0)
        total_loss += output['val_loss'].item()
        num_batches += 1

    # Count consumed batches instead of calling len(dataloader): len() is not
    # available on every iterable, and an empty loader would otherwise fail
    # with an unhelpful ZeroDivisionError.
    if num_batches == 0:
        raise ValueError('Cannot evaluate model: dataloader yielded no batches.')
    return total_loss / num_batches


# Evaluation settings read by the cells below.
configs = dict(
    data_dir='data/singlecoil_val',
    checkpoints=[
        'logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt',
        'logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt',
    ],
    batch_size=1,
    num_workers=4,
    challenge="singlecoil",
    mask_type="random",  # "random" or "equispaced_fraction"
    center_fractions=[0.08],  # center fraction(s) handed to the mask factory
    accelerations=[4],  # acceleration rate(s) handed to the mask factory
)


# Build the sub-sampling mask function from the three mask-related settings,
# passed positionally in this order.
mask = create_mask_for_mask_type(
    *(configs[key] for key in ('mask_type', 'center_fractions', 'accelerations'))
)

# Transform applied to every validation sample; it receives the mask function.
val_transform = T.UnetDataTransform(configs['challenge'], mask_func=mask)

# The validation data comes from AnnotatedSliceDataset; CombinedSliceDataset
# (also imported by this notebook) is not used here.
validation_dataset = AnnotatedSliceDataset(
    root=Path(configs['data_dir']),
    challenge=configs['challenge'],
    transform=val_transform,
    subsplit='knee',
    multiple_annotation_policy='all',
    raw_sample_filter=None,
    use_dataset_cache=False,
)

# shuffle=False keeps the batch order fixed from run to run.
validation_loader = DataLoader(
    validation_dataset,
    shuffle=False,
    batch_size=configs['batch_size'],
    num_workers=configs['num_workers'],
)

In [3]:
# Evaluate each configured checkpoint on the shared validation loader and
# record its metric under the checkpoint path.
results = {}
for checkpoint_path in configs['checkpoints']:
    model = load_model_from_checkpoint(checkpoint_path)
    results[checkpoint_path] = val_metric = evaluate_model(model, validation_loader)
    print(f"Model: {checkpoint_path}, Validation Metric: {val_metric}")

  return torch.load(f, map_location=map_location)


Loading model from checkpoint: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt
Model loaded from checkpoint: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt
Evaluating model...


  1%|          | 66/7135 [00:18<20:23,  5.78it/s]  

In [None]:
# Recap: one line per evaluated checkpoint, in the order they were evaluated.
print("\nSummary of all models:")
for checkpoint_path in results:
    val_metric = results[checkpoint_path]
    print(f"Model: {checkpoint_path}, Validation Metric: {val_metric}")