In [1]:
class RunDataTransform:
    """
    Data Transformer for running U-Net models on a test dataset.

    Instances are callables that turn one k-space slice into a normalized,
    zero-filled magnitude image, together with the statistics needed to undo
    the normalization afterwards.
    """

    def __init__(self, resolution, which_challenge, mask_func=None):
        """
        Args:
            resolution (int): Resolution of the image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape. When None, the k-space is used as provided.

        Raises:
            ValueError: If `which_challenge` is neither "singlecoil" nor "multicoil".
        """
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError('Challenge should either be "singlecoil" or "multicoil"')
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.mask_func = mask_func

    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used by this transform)
            attrs (dict): Acquisition related information stored in the HDF5 object
                (not used by this transform)
            fname (pathlib.Path or str): Path to (or name of) the input file
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled input image
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (pathlib.Path or str): The `fname` argument, passed through unchanged
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)
        if self.mask_func is not None:
            # Seed the mask from the file name so that every slice of a file
            # gets the same seed. str() makes this work for pathlib.Path
            # objects (which are not iterable) as well as for plain strings.
            seed = tuple(map(ord, str(fname)))
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image to given resolution if larger
        # NOTE(review): assumes the last axis holds the complex (real, imag)
        # pair, so axes -3 and -2 are height and width - confirm against
        # transforms.ifft2.
        smallest_width = min(self.resolution, image.shape[-2])
        smallest_height = min(self.resolution, image.shape[-3])
        image = transforms.complex_center_crop(image, (smallest_height, smallest_width))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image)
        # Limit the normalized values to the range [-6, 6].
        image = image.clamp(-6, 6)
        return image, mean, std, fname, slice

In [None]:
from models.unet.unet_model import UnetModel
from collections import defaultdict
from data.mri_data import SliceData
from torch.utils.data import DataLoader



def create_inference_data_loaders(data_path, batch_size=1, resolution=320):
    """
    Build the DataLoader used for single-coil test-set inference.

    Args:
        data_path (pathlib.Path): Directory containing the
            'singlecoil_test_v2' folder (joined with the `/` operator).
        batch_size (int): Number of slices per batch.
        resolution (int): Resolution handed to RunDataTransform.

    Returns:
        torch.utils.data.DataLoader: Loader over the single-coil test slices,
            created with 4 workers and pinned memory.
    """
    challenge = 'singlecoil'
    test_root = data_path / 'singlecoil_test_v2'
    # No mask function is passed, so the transform uses the k-space as stored.
    transform = RunDataTransform(resolution, challenge, None)
    dataset = SliceData(
        root=test_root,
        transform=transform,
        sample_rate=1.,
        challenge=challenge,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )


def load_model(checkpoint_file, device, num_channels=32, num_pools=4, dropout_prob=0.0):
    """
    Build a UnetModel on `device` and load its weights from a checkpoint.

    Args:
        checkpoint_file (str or path-like): Checkpoint readable by torch.load;
            it must contain the model state dict under the key 'model'.
        device (torch.device): Device the model is created on and the
            checkpoint tensors are mapped to.
        num_channels (int): Forwarded to UnetModel.
        num_pools (int): Forwarded to UnetModel.
        dropout_prob (float): Forwarded to UnetModel.

    Returns:
        UnetModel: The model with the checkpoint weights loaded.
    """
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (infer() falls back to the CPU when CUDA is unavailable).
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    # The two leading 1s are presumably the input/output channel counts -
    # confirm against UnetModel's signature.
    model = UnetModel(1, 1, num_channels, num_pools, dropout_prob).to(device)
    model.load_state_dict(checkpoint['model'])
    return model


def run_unet(model, device, data_loader):
    """
    Run the model over every batch and collect the reconstructions per file.

    Args:
        model (torch.nn.Module): Model to evaluate; it is switched to eval mode.
        device (torch.device): Device the input batches are moved to.
        data_loader (iterable): Yields (input, mean, std, fnames, slices)
            batches, as produced by RunDataTransform through a DataLoader.

    Returns:
        dict: Maps each file name to a numpy array holding that file's
            de-normalized reconstructions stacked in ascending slice order.
    """
    # Imported locally so the function does not depend on a module-level
    # `np` binding being present.
    import numpy as np

    model.eval()
    reconstructions = defaultdict(list)
    with torch.no_grad():
        for batch, mean, std, fnames, slices in data_loader:
            # Insert a singleton axis at position 1 for the model, and drop it
            # again from the model output.
            batch = batch.unsqueeze(1).to(device)
            recons = model(batch).to('cpu').squeeze(1)
            for recon, mu, sigma, fname, slice_idx in zip(recons, mean, std, fnames, slices):
                # Undo the per-slice normalization applied by the transform.
                pred = recon * sigma + mu
                reconstructions[fname].append((int(slice_idx), pred.numpy()))

    # Sort on the slice index only: sorting the raw (index, image) tuples
    # would compare the image arrays whenever two indices are equal, which
    # raises ValueError for arrays with more than one element.
    return {
        fname: np.stack([pred for _, pred in sorted(slice_preds, key=lambda item: item[0])])
        for fname, slice_preds in reconstructions.items()
    }

In [9]:
import torch
from common.utils import save_reconstructions
from pathlib import Path



def infer(data_path='../data', checkpoint_file='unet/best_model.pt', out_dir='unet/inference'):
    """
    Run U-Net inference on the single-coil test set and save the results.

    Uses CUDA when it is available and the CPU otherwise.

    Args:
        data_path (str or pathlib.Path): Directory passed to
            create_inference_data_loaders.
        checkpoint_file (str or pathlib.Path): Checkpoint passed to load_model.
        out_dir (str or pathlib.Path): Directory passed to save_reconstructions.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_loader = create_inference_data_loaders(Path(data_path))
    model = load_model(checkpoint_file, device)
    reconstructions = run_unet(model, device, data_loader)
    save_reconstructions(reconstructions, Path(out_dir))

    
# Run the full inference pipeline when this cell is executed.
infer()

NameError: name 'UnetModel' is not defined