# MRI NIH : Data loading

All input volumes are expected to share the same orientation, resolution and matrix size, i.e. a single common header for the whole dataset.

## Imports

In [23]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from msct_image import Image as msct_Image
from PIL import Image as PIL_Image

In [62]:
class MRI2DSegmentationDataset(Dataset):
    """Generic 2D (slice-wise) segmentation dataset built from 3D volumes.

    Each input volume listed in the txt file is cut into 2D slices along
    ``slice_axis``. Item ``i`` of the dataset pairs one input slice with the
    matching slice of every ground-truth mask of the same subject.

    All input volumes must share the same orientation, in-plane resolution
    and in-plane matrix size, and every line must list the same ground-truth
    classes; otherwise a ``RuntimeError`` is raised during construction.

    :param txt_path_file: path to a txt file with one subject per line, made
        of whitespace-separated ``<key> <path>`` pairs. The ``input`` key gives
        the input image; every other key is a ground-truth class name mapped
        to its mask path, e.g. ``input t2.nii.gz cord t2_seg.nii.gz``. Lines
        that do not contain the substring ``input`` are skipped.
    :param slice_axis: axis (0, 1 or 2) along which volumes are sliced
        (default 2, described as "axial" by the original author).
    :param cache: stored as ``self.cache`` but not read by this class; all
        slices are always loaded into memory at construction time.
    :param transform: optional callable applied to every PIL image of a
        sample (the input and each gt mask) separately.
    :raises ValueError: if ``slice_axis`` is not 0, 1 or 2.
    """

    def __init__(self, txt_path_file, slice_axis=2, cache=True, transform=None):
        if slice_axis not in (0, 1, 2):
            raise ValueError("slice_axis must be 0, 1 or 2, got %r" % (slice_axis,))
        self.filenames = []    # (input_path, {gt_class: gt_path}) per subject
        self.header = {}       # header shared by all inputs, set by _load_files
        self.class_names = []  # sorted gt class names shared by all subjects
        self.transform = transform
        self.cache = cache
        self.slice_axis = slice_axis
        self.handlers = []     # one stacked array [input, gt_1, ..., gt_k] per slice
        self.read_filenames(txt_path_file)
        self._load_files()

    def __len__(self):
        """Return the total number of 2D slices over all subjects."""
        return len(self.handlers)

    def __getitem__(self, index):
        """Return slice ``index`` as ``{'input': image, 'gt': [image, ...]}``.

        Every 2D array is converted to a 32-bit float PIL image (mode ``'F'``).
        ``self.transform``, when set, is then called on each image on its own;
        if that transform is random, each call may draw different parameters
        for the input and its masks.
        """
        stacked = self.handlers[index]
        # PIL needs contiguous float32 pixels for mode 'F'; without the cast,
        # float64/integer slices would have their raw bytes misread as float32.
        images = [PIL_Image.fromarray(np.ascontiguousarray(plane, dtype=np.float32))
                  for plane in stacked]
        if self.transform:
            images = [self.transform(image) for image in images]
        return {
            'input': images[0],
            'gt': images[1:],
        }

    @staticmethod
    def _slice_volume(data, axis, index):
        """Return the 2D slice number ``index`` of ``data`` taken along ``axis``."""
        return np.take(data, index, axis=axis)

    def _load_files(self):
        """Load every listed volume and append its 2D slices to ``self.handlers``.

        :raises RuntimeError: if an input header differs from the first one,
            if the gt classes differ between subjects, or if a gt mask does
            not have the same shape as its input volume.
        """
        in_plane = [ax for ax in range(3) if ax != self.slice_axis]
        for input_filename, gt_dict in self.filenames:
            input_3D = msct_Image(input_filename)
            dim = input_3D.dim
            # NOTE(review): assumes msct_Image.dim holds sizes at [0:3] and
            # voxel spacings at [4:7], as the original dim[0:2]/dim[4:6]
            # indexing implied. With slice_axis=2 this is exactly the original
            # header (sizes dim[0:2], spacings dim[4:6]).
            input_header = {
                "orientation": input_3D.orientation,
                "resolution": list(np.around([dim[4 + ax] for ax in in_plane], 2)),
                "matrix_size": tuple(dim[ax] for ax in in_plane),
            }

            gt_class_names = sorted(gt_dict)
            gt_3D = [msct_Image(gt_dict[gt_class]) for gt_class in gt_class_names]

            if not self.header:
                self.header = input_header
            # sanity check for consistent header
            elif self.header != input_header:
                raise RuntimeError('Inconsistent header in input files.')

            if not self.class_names:
                self.class_names = gt_class_names
            # sanity check for consistent gt classes
            elif self.class_names != gt_class_names:
                raise RuntimeError('Inconsistent classes in gt files.')

            input_data = input_3D.data
            # Check whole volumes up front: a gt with fewer slices used to
            # raise IndexError mid-loop, and extra gt slices were ignored.
            for gt in gt_3D:
                if gt.data.shape != input_data.shape:
                    raise RuntimeError('Input and ground truth with different dimensions.')

            for i in range(input_data.shape[self.slice_axis]):
                seg_item = [self._slice_volume(input_data, self.slice_axis, i)]
                seg_item.extend(self._slice_volume(gt.data, self.slice_axis, i)
                                for gt in gt_3D)
                self.handlers.append(np.array(seg_item))

    @staticmethod
    def _parse_filenames_line(line):
        """Split one txt line into ``(input_path, {gt_class: gt_path})``.

        :raises RuntimeError: if the tokens do not form ``<key> <path>`` pairs
            or if no ``input`` key is present.
        """
        tokens = line.split()
        if len(tokens) % 2:
            raise RuntimeError('Error in filenames txt file parsing.')
        input_path = None
        gt_dict = {}
        for key, path in zip(tokens[0::2], tokens[1::2]):
            if key == "input":
                input_path = path
            else:
                gt_dict[key] = path
        if input_path is None:
            # Previously yielded (None, ...) and crashed later in msct_Image(None).
            raise RuntimeError('Error in filenames txt file parsing.')
        return input_path, gt_dict

    def read_filenames(self, txt_path_file):
        """Parse ``txt_path_file`` and append one entry per subject to ``self.filenames``.

        Only lines containing the substring ``"input"`` are parsed. Each listed
        path is probed by opening it with ``msct_Image``; a failing path is
        reported on stdout but parsing continues (best effort).

        :raises RuntimeError: on a malformed line (see ``_parse_filenames_line``).
        """
        with open(txt_path_file, 'r') as txt_file:
            for line in txt_file:
                if "input" not in line:
                    continue
                input_path, gt_dict = self._parse_filenames_line(line)
                for path in line.split()[1::2]:
                    try:
                        msct_Image(path)
                    except Exception:
                        print("Invalid path in filenames txt file: %s" % path)
                self.filenames.append((input_path, gt_dict))
        
                