In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from argparse import ArgumentParser
from pathlib import Path
import os

from src import transforms as T

from src.unet.unet import Unet
from src.unet.unet_module import UnetModule
from src.mri_module import MriModule
from src.subsample import create_mask_for_mask_type, RandomMaskFunc
from src.mri_data import CombinedSliceDataset, AnnotatedSliceDataset

In [4]:
def load_model_from_checkpoint(checkpoint_path, hparams_file=None):
    print(f'Loading model from checkpoint: {checkpoint_path}')
    model = UnetModule.load_from_checkpoint(checkpoint_path, hparams_file=hparams_file, map_location='mps')
    model.eval()
    print(f'Model loaded from checkpoint: {checkpoint_path}')
    return model


def evaluate_model(model, dataloader):
    from tqdm import tqdm
    metric = 0
    metrics = dict()
    roi_len = 1
    metrics['val_loss'] = 0
    metrics['image_l1_loss'] = 0
    metrics['image_ssim_loss'] = 0
    metrics['roi_l1_loss'] = 0
    metrics['roi_ssim_loss'] = 0
    
    print('Evaluating model...')
    for batch in tqdm(dataloader):
        with torch.no_grad():
            output = model.validation_step_comparison(batch, batch_idx=0)
            metrics['val_loss'] += output['val_loss'].item()
            metrics['image_l1_loss'] += output['image_l1_loss'].item()
            metrics['image_ssim_loss'] += output['image_ssim_loss'].item()
            roi_l1_loss = output['roi_l1_loss'].item()
            roi_ssim_loss = output['roi_ssim_loss'].item()
            if not (roi_l1_loss == 0 and roi_ssim_loss == 0):
                roi_len += 1              
                metrics['roi_l1_loss'] += roi_l1_loss
                metrics['roi_ssim_loss'] += roi_ssim_loss
            metric += output['val_loss']
    metric /= len(dataloader)
    metrics['val_loss'] /= len(dataloader)
    metrics['image_l1_loss'] /= len(dataloader)
    metrics['image_ssim_loss'] /= len(dataloader)
    metrics['roi_l1_loss'] /= roi_len
    metrics['roi_ssim_loss'] /= roi_len
    metrics['roi_len'] = roi_len
    return metric, metrics


configs = {
    'data_dir': 'data/singlecoil_val',
    'checkpoints': [
        'logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt',
        'logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt',
    ],
    'batch_size': 1,
    'num_workers': 4,
    'challenge':"singlecoil",
    'mask_type':"random",  # "random" or "equispaced_fraction"
    'center_fractions':[0.08],  # number of center lines to use in the mask
    'accelerations':[4],  # acceleration rates to use for the mask
}


mask = create_mask_for_mask_type(
    configs['mask_type'], configs['center_fractions'], configs['accelerations']
)

val_transform = T.UnetDataTransform(configs['challenge'], mask_func=mask)

# validation_dataset = CombinedSliceDataset(
#     roots=[Path(configs['data_dir'])],
#     challenges=['singlecoil'],
#     transforms=[val_transform]
# )

validation_dataset = AnnotatedSliceDataset(
    root=Path(configs['data_dir']),
    transform=val_transform,
    challenge=configs['challenge'],
    use_dataset_cache=False,
    raw_sample_filter=None,
    subsplit='knee',
    multiple_annotation_policy='all',
    
)

gen = torch.Generator().manual_seed(42)
# Use the line below to get a smaller validation set
#samples = 150
#validation_dataset = random_split(validation_dataset, [samples, len(validation_dataset) - samples], generator=gen)[0]

validation_loader = DataLoader(
    validation_dataset,
    batch_size=configs['batch_size'],
    num_workers=configs['num_workers'],
    shuffle=False,
)

In [5]:
results = {}

for checkpoint_path in configs['checkpoints']:
    model = load_model_from_checkpoint(checkpoint_path)
    val_metric, metrics = evaluate_model(model, validation_loader)
    results[checkpoint_path] = val_metric
    print(f"Model: {checkpoint_path}, Validation Metric: {val_metric}, other metrics: {metrics}")

Loading model from checkpoint: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt


  return torch.load(f, map_location=map_location)


Model loaded from checkpoint: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt
Evaluating model...


100%|██████████| 7135/7135 [28:38<00:00,  4.15it/s]


Model: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt, Validation Metric: 0.0030482891015708447, other metrics: {'val_loss': 0.00304829346690883, 'image_l1_loss': 0.2920762070437036, 'image_ssim_loss': 0.6643564193668111, 'roi_l1_loss': 0.0, 'roi_ssim_loss': 0.0, 'roi_len': 1}
Loading model from checkpoint: logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt
Model loaded from checkpoint: logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt
Evaluating model...


  2%|▏         | 175/7135 [00:55<27:33,  4.21it/s] 

0.0036901235580444336

In [12]:
print("\nSummary of all models:")
for checkpoint_path, val_metric in results.items():
    print(f"Model: {checkpoint_path}, Validation Metric: {val_metric}")


Summary of all models:
Model: logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt, Validation Metric: 0.003733583688735962
Model: logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt, Validation Metric: 0.0036901235580444336


In [13]:
results

{'logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt': 0.003733583688735962,
 'logs/unet/unet_l1/checkpoints/epoch=9-step=347420.ckpt': 0.0036901235580444336}