In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys, os, torch
import torch.nn as nn
import torch.optim as optim
import itertools
from collections import defaultdict 

from tqdm import tqdm
from ista_unet import *
from ista_unet.models import ista_unet
from ista_unet.evaluate import *
from ista_unet.load_fastmri_dataset import get_dataloaders_fastmri
from ista_unet.utils import crop_center_2d
from ista_unet import model_save_dir, dataset_dir
from pathlib import Path

from dival.util.plot import plot_images
import dival
import torch.multiprocessing
from fastmri import save_reconstructions


import h5py
from fastmri.data import transforms
from runstats import Statistics
from skimage.metrics import structural_similarity, peak_signal_noise_ratio


# Share tensors between DataLoader workers through the file system.
torch.multiprocessing.set_sharing_strategy('file_system')

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
guid = '53df389a-8054-402b-83f7-e230e8a22670'  # run identifier passed to load_ista_unet_model

# Load the trained network together with its configuration dictionary
# (config_dict['saved_path'] is used later as the output directory).
model, config_dict = load_ista_unet_model(
    guid=guid,
    dataset='fastmri',
    return_config_dict=True,
)

model.to(device);

In [3]:
# Batch size 1 is required downstream: save_prediction_fastmri indexes fname[0]
# and squeezes each output to a single slice.
loaders_bs1 = get_dataloaders_fastmri(
    batch_size=1,
    include_test=True,
)


In [4]:
def save_prediction_fastmri(model, loader, phase, save_dir=None):
    """Run ``model`` over one split and save the de-normalized reconstructions per volume.

    Each batch is unpacked as ``(obs, gt, mean, std, fname, slice_num, max_value)``.
    Only ``fname[0]`` is used and outputs are squeezed, so the loader is assumed to use
    batch size 1. The network output is clamped to [-6, 6] in the normalized domain.
    It is then mapped back with ``reco * std + mean``, grouped by file name, ordered by
    slice index and written with ``fastmri.save_reconstructions``.

    The model is switched to ``eval()`` for the pass and restored to its previous
    train/eval mode afterwards. It runs on the notebook-global ``device``.

    Args:
        model: network applied to the normalized input ``obs.unsqueeze(1)``
            (presumably shape ``(1, 1, H, W)`` — TODO confirm against the loader).
        loader: mapping from phase name to an iterable of batches, e.g. the dict
            returned by ``get_dataloaders_fastmri``.
        phase: key into ``loader`` (e.g. ``'validation'`` or ``'test'``). It also names
            the output folder ``<phase>_reconstructions``.
        save_dir: parent directory of the output folder. Defaults to the
            notebook-global ``config_dict['saved_path']``, as before.

    Returns:
        str: path of the directory the reconstructions were written to.
    """
    if save_dir is None:
        save_dir = config_dict['saved_path']

    # Inference must not update/use training-time layer behaviour (dropout,
    # batch-norm running stats); remember the caller's mode so we can restore it.
    was_training = model.training
    model.eval()

    slices_by_volume = defaultdict(list)
    try:
        with torch.no_grad(), tqdm(loader[phase]) as pbar:
            for obs, _gt, mean, std, fname, slice_num, _max_value in pbar:
                obs = obs.unsqueeze(1)
                # (B,) -> (B, 1, 1) so the statistics broadcast over the image
                # (correct for B == 1, which this function assumes).
                mean = mean.unsqueeze(1).unsqueeze(2)
                std = std.unsqueeze(1).unsqueeze(2)

                reco = model(obs.to(device)).cpu().clamp(-6, 6)

                # Undo the instance normalization of the network output.
                trans_reco = (reco * std + mean).numpy().squeeze()

                # Collect slices into the volume they belong to; an int key keeps
                # the later sort from ever comparing image arrays.
                slices_by_volume[fname[0]].append((int(slice_num), trans_reco))
    finally:
        model.train(was_training)

    volumes = {
        name: np.stack([reco for _, reco in sorted(slices, key=lambda item: item[0])])
        for name, slices in slices_by_volume.items()
    }

    save_to_path = Path(save_dir) / (phase + '_reconstructions')
    save_reconstructions(volumes, save_to_path)

    print('saved to ', str(save_to_path))
    return str(save_to_path)

def nmse(gt, pred):
    """Normalized mean squared error, ``||gt - pred||^2 / ||gt||^2``.

    Both norms are the 2-norm of the flattened array, so inputs of any rank work.
    """
    error_norm = np.linalg.norm(gt - pred)
    reference_norm = np.linalg.norm(gt)
    return error_norm ** 2 / reference_norm ** 2


def psnr(gt, pred):
    """Peak signal-to-noise ratio of ``pred`` against ``gt``.

    The peak (``data_range``) is taken as the maximum value of the ground truth.
    """
    peak = gt.max()
    return peak_signal_noise_ratio(gt, pred, data_range=peak)


def ssim(gt, pred):
    """Compute Structural Similarity Index Metric (SSIM) of a volume.

    SSIM is computed on each 2-D slice along axis 0 and the scores are averaged.
    Every slice uses the same ``data_range`` (``gt.max()`` over the whole volume).
    This is what the previous ``multichannel=True`` call on the transposed volume
    computed. That keyword is deprecated in scikit-image (since 0.19, in favor of
    ``channel_axis``), so the per-slice form keeps working across versions.

    Args:
        gt: ground-truth volume of shape ``(slices, H, W)``. A single ``(H, W)``
            slice is also accepted.
        pred: reconstruction with the same shape as ``gt``.

    Returns:
        float: mean SSIM over the slices.

    Raises:
        ValueError: if ``gt`` and ``pred`` have different shapes.
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    if gt.shape != pred.shape:
        # zip() below would otherwise silently drop unmatched slices.
        raise ValueError(f"Shape mismatch: gt {gt.shape} vs pred {pred.shape}")
    if gt.ndim == 2:
        gt, pred = gt[np.newaxis], pred[np.newaxis]

    data_range = gt.max()
    scores = [
        structural_similarity(gt_slice, pred_slice, data_range=data_range)
        for gt_slice, pred_slice in zip(gt, pred)
    ]
    return float(np.mean(scores))


# Metric name -> callable(gt, pred) evaluated per volume by Metrics.
METRIC_FUNCS = {
    'NMSE': nmse,
    'PSNR': psnr,
    'SSIM': ssim,
}

class Metrics:
    """
    Maintains running statistics for a given collection of metrics.

    Each call to ``push`` evaluates every metric function on one
    (target, reconstruction) pair. The resulting value is fed into a
    ``runstats.Statistics`` accumulator, so ``means`` / ``stddevs`` summarize
    each metric over all pushed pairs.
    """

    def __init__(self, metric_funcs):
        """
        Args:
            metric_funcs (dict): A dict where the keys are metric names and the
                values are Python functions for evaluating that metric. Each
                function is called as ``func(target, recons)``.
        """
        # Keep the callables themselves. Previously only the names were stored
        # and ``push`` used the module-level METRIC_FUNCS, silently ignoring the
        # functions passed in here.
        self.metric_funcs = dict(metric_funcs)
        self.metrics = {metric: Statistics() for metric in self.metric_funcs}

    def push(self, target, recons):
        """Evaluate every metric on ``(target, recons)`` and record the results."""
        for metric, func in self.metric_funcs.items():
            self.metrics[metric].push(func(target, recons))

    def means(self):
        """Return ``{metric_name: running mean}``."""
        return {metric: stat.mean() for metric, stat in self.metrics.items()}

    def stddevs(self):
        """Return ``{metric_name: running standard deviation}`` as computed by runstats."""
        return {metric: stat.stddev() for metric, stat in self.metrics.items()}

    def __repr__(self):
        """Format each metric as ``NAME = mean +/- 2*stddev``, sorted by name."""
        means = self.means()
        stddevs = self.stddevs()
        metric_names = sorted(means)
        return " ".join(
            f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}"
            for name in metric_names
        )
    
def evaluate_saved_fastmri(target_path, predictions_path, acquisition, challenge = 'singlecoil', acceleration = None):
    """Compute NMSE / PSNR / SSIM between saved reconstructions and fastMRI targets.

    Every ``*.h5`` file in ``target_path`` is paired with the file of the same
    name in ``predictions_path``. Target volumes are filtered by their
    ``acquisition`` / ``acceleration`` attributes before the prediction file is
    opened, so excluded volumes need no reconstruction on disk. Both arrays are
    center-cropped to ``(W, W)``, where ``W`` is the target's last dimension,
    before the metrics are computed.

    Args:
        target_path: directory holding the ground-truth ``.h5`` volumes.
        predictions_path: directory holding same-named ``.h5`` files with a
            ``'reconstruction'`` dataset.
        acquisition: if truthy, only volumes whose ``attrs['acquisition']``
            equals it are evaluated.
        challenge: ``'multicoil'`` reads the ``'reconstruction_rss'`` target.
            Anything else reads ``'reconstruction_esc'``.
        acceleration: if truthy, only volumes whose ``attrs['acceleration']``
            equals it are evaluated.

    Returns:
        Metrics: running statistics over all evaluated volumes.
    """
    target_path = Path(target_path)
    predictions_path = Path(predictions_path)

    recons_key = 'reconstruction_rss' if challenge == 'multicoil' else 'reconstruction_esc'

    metrics = Metrics(METRIC_FUNCS)

    # Sorted for a deterministic evaluation order; only HDF5 volumes are considered.
    for tgt_file in sorted(target_path.glob('*.h5')):
        with h5py.File(tgt_file, 'r') as target_h5:
            if acquisition and acquisition != target_h5.attrs['acquisition']:
                continue
            if acceleration and target_h5.attrs['acceleration'] != acceleration:
                continue
            target = target_h5[recons_key][()]

        with h5py.File(predictions_path / tgt_file.name, 'r') as recons_h5:
            recons = recons_h5['reconstruction'][()]

        crop_size = (target.shape[-1], target.shape[-1])
        target = transforms.center_crop(target, crop_size)
        recons = transforms.center_crop(recons, crop_size)
        metrics.push(target, recons)
    return metrics

In [5]:
saved_to_path = save_prediction_fastmri(model, loaders_bs1, 'validation')  # run + save validation split

100%|██████████| 7135/7135 [09:41<00:00, 12.27it/s]


saved to  /home/liu0003/Desktop/projects/ista_unet/saved_model/ista/fastmri/53df389a-8054-402b-83f7-e230e8a22670/validation_reconstructions


In [6]:
saved_to_path = save_prediction_fastmri(model, loaders_bs1, 'test')  # run + save test split

100%|██████████| 3903/3903 [05:21<00:00, 12.13it/s]


saved to  /home/liu0003/Desktop/projects/ista_unet/saved_model/ista/fastmri/53df389a-8054-402b-83f7-e230e8a22670/test_reconstructions


In [7]:
# Validation metrics restricted to volumes whose 'acquisition' attribute is CORPD_FBK.
evaluate_saved_fastmri(
    target_path=os.path.join(dataset_dir, 'fastmri/knee/singlecoil_val'),
    predictions_path=os.path.join(config_dict['saved_path'], 'validation_reconstructions'),
    acquisition='CORPD_FBK',
)

NMSE = 0.01572 +/- 0.01537 PSNR = 33.91 +/- 5.267 SSIM = 0.8115 +/- 0.1685

In [10]:
# Validation metrics restricted to volumes whose 'acquisition' attribute is CORPDFS_FBK.
evaluate_saved_fastmri(
    target_path=os.path.join(dataset_dir, 'fastmri/knee/singlecoil_val'),
    predictions_path=os.path.join(config_dict['saved_path'], 'validation_reconstructions'),
    acquisition='CORPDFS_FBK',
)

NMSE = 0.05269 +/- 0.04528 PSNR = 29.94 +/- 5.395 SSIM = 0.6321 +/- 0.2127