# MRI NIH : Data loading

All input volumes are expected to share the same orientation, resolution and matrix size (i.e. a common header for the whole dataset).

## Imports

In [55]:
import torch
import random
import json
import math
import numbers
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as F
from msct_image import Image as msct_Image
from PIL import Image as PIL_Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates

## Hyperparameters

In [102]:
# Load the augmentation hyperparameters; the context manager closes the file
# handle that json.load(open(...)) used to leak.
with open('/Users/frpau_local/Documents/nih/data/luisa_with_gt/parameters.json') as parameters_file:
    parameters = json.load(parameters_file)
print(json.dumps(parameters, indent=4, sort_keys=True))

{
    "alpha_range": [
        4, 
        6
    ], 
    "crop_size": [
        200, 
        150
    ], 
    "elastic_rate": 0.5, 
    "flip_rate": 0.5, 
    "max_angle": 20, 
    "ratio_range": [
        0.75, 
        1.25
    ], 
    "scale_range": [
        0.5, 
        1
    ], 
    "sigma_range": [
        10, 
        30
    ]
}


## Transforms

In [136]:
class ElasticTransform(object):
    """Randomly apply one elastic deformation to a sample's input and all its gt masks.

    With probability ``p``, ``alpha`` and ``sigma`` are drawn uniformly from
    ``alpha_range`` / ``sigma_range``. A random displacement field (uniform noise
    in [-1, 1) smoothed by a Gaussian of std ``sigma``, scaled by ``alpha``) is
    then built once and used to warp ``sample['input']`` and every mask in
    ``sample['gt']``, so they stay aligned. Warped masks are re-binarized at 0.5.

    The input and masks are converted with ``np.array`` and must therefore all
    have the same 2D shape. Results are stored back as PIL mode 'F' images.
    """
    def __init__(self, alpha_range, sigma_range, p=0.5):
        self.alpha_range = alpha_range
        self.sigma_range = sigma_range
        self.p = p

    @staticmethod
    def get_params(alpha_range, sigma_range):
        """Draw (alpha, sigma) uniformly from their respective (low, high) ranges."""
        alpha = np.random.uniform(alpha_range[0], alpha_range[1])
        sigma = np.random.uniform(sigma_range[0], sigma_range[1])
        return alpha, sigma

    @staticmethod
    def _displacement_field(shape, alpha, sigma):
        """Return a random smooth displacement field (dx, dy) for an image of ``shape``."""
        dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
        return dx, dy

    @staticmethod
    def _warp(image, dx, dy):
        """Bilinearly resample 2D ``image`` at its pixel grid displaced by (dx, dy).

        ``dx`` shifts along axis 0 (rows) and ``dy`` along axis 1 (columns);
        samples falling outside the image read 0.
        """
        x, y = np.meshgrid(np.arange(image.shape[0]), np.arange(image.shape[1]), indexing='ij')
        coords = np.array([x + dx, y + dy])
        return map_coordinates(image, coords, order=1)

    @staticmethod
    def elastic_transform(image, alpha, sigma):
        """Deform a single 2D ``image`` with a freshly drawn random field.

        Kept for backward compatibility. Each call draws a NEW field, so do not
        call it separately on an image and its masks; ``__call__`` shares one
        field between them instead.
        """
        dx, dy = ElasticTransform._displacement_field(image.shape, alpha, sigma)
        return ElasticTransform._warp(image, dx, dy)

    def __call__(self, sample):
        if np.random.random() < self.p:
            alpha, sigma = self.get_params(self.alpha_range, self.sigma_range)

            input_data = np.array(sample['input'])
            # One displacement field shared by the input and every mask so the
            # segmentation stays registered to the deformed image.
            dx, dy = self._displacement_field(input_data.shape, alpha, sigma)

            warped_input = self._warp(input_data, dx, dy).astype(np.float32)
            sample['input'] = PIL_Image.fromarray(warped_input, mode='F')

            warped_gts = []
            for gt in sample['gt']:
                warped = self._warp(np.array(gt), dx, dy)
                # Bilinear interpolation blurs mask edges; snap back to {0, 1}.
                binary = np.where(warped >= 0.5, 1.0, 0.0).astype(np.float32)
                warped_gts.append(PIL_Image.fromarray(binary, mode='F'))
            sample['gt'] = warped_gts

        return sample

class ToPIL(object):
    """Convert the input and every gt mask of a sample to PIL mode 'F' images.

    The sample dict is modified in place and returned.
    """
    def __call__(self, sample):
        sample['input'] = self._to_float_image(sample['input'])
        sample['gt'] = [self._to_float_image(gt) for gt in sample['gt']]
        return sample

    @staticmethod
    def _to_float_image(data):
        # PIL mode 'F' holds 32-bit floats. Copy and cast explicitly so that
        # volumes of another dtype (e.g. float64) are converted instead of
        # depending on how PIL interprets a mismatched buffer.
        return PIL_Image.fromarray(np.array(data, dtype=np.float32), mode='F')
    
class ToTensor(object):
    """Turn the input and every gt mask of a sample into ``torch.Tensor`` objects.

    Uses the legacy ``torch.Tensor`` constructor, i.e. the default tensor type.
    The sample dict is updated in place and returned.
    """

    def __call__(self, sample):
        def as_tensor(img):
            return torch.Tensor(np.array(img))

        sample['input'] = as_tensor(sample['input'])
        sample['gt'] = list(map(as_tensor, sample['gt']))
        return sample

class RandomRotation(object):
    """Rotate a sample's input and all its gt masks by the same random angle.

    :param degrees: a non-negative number ``d`` (angle drawn uniformly from
        [-d, d]) or a length-2 sequence ``(min, max)``.
    :param resample, expand, center: passed as the 3rd-5th positional
        arguments of ``torchvision.transforms.functional.rotate``.
    :raises ValueError: if ``degrees`` is a negative number or a sequence whose
        length is not 2.
    """
    def __init__(self, degrees, resample=False, expand=False, center=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center

    @staticmethod
    def get_params(degrees):
        """Draw a rotation angle uniformly from ``[degrees[0], degrees[1])``."""
        return np.random.uniform(degrees[0], degrees[1])

    def __call__(self, sample):
        """Return a new dict with the rotated 'input' and list of rotated 'gt'."""
        angle = self.get_params(self.degrees)

        def rotate(img):
            # Same angle and options for every image keeps input and masks aligned.
            return F.rotate(img, angle, self.resample, self.expand, self.center)

        return {
            'input': rotate(sample['input']),
            'gt': [rotate(gt) for gt in sample['gt']],
        }
    

class RandomResizedCrop(object):
    """Crop a sample's input and gt masks with one random size/aspect ratio, then resize.

    A crop covering a random fraction (``scale``) of the original area, with a
    random aspect ratio (``ratio``), is taken at a random position and resized
    to ``size``. The same crop is applied to the input and to every gt mask.
    The resized masks are re-binarized at 0.5.

    Args:
        size: expected output size as a 2-sequence (h, w), or an int for a
            square output.
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=PIL_Image.BILINEAR):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = (size[0], size[1])
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = img.size[0], img.size[1]
        area = width * height  # loop-invariant
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback: centered square crop with the image's shorter side.
        w = min(width, height)
        i = (height - w) // 2
        j = (width - w) // 2
        return i, j, w, w

    def __call__(self, sample):
        """Return a new dict with the cropped/resized 'input' and binarized 'gt' list."""
        i, j, h, w = self.get_params(sample['input'], self.scale, self.ratio)

        def crop(img):
            return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

        input_data = crop(sample['input'])

        gt_data = []
        for gt in sample['gt']:
            mask = np.array(crop(gt))
            # Interpolation blurs mask borders; snap back to {0, 1} (float32 for mode 'F').
            mask = np.where(mask >= 0.5, 1.0, 0.0).astype(np.float32)
            gt_data.append(PIL_Image.fromarray(mask, mode='F'))

        return {'input': input_data, 'gt': gt_data}
    
    
class RandomVerticalFlip(object):
    """Flip a sample's input and all its gt masks upside down together, with probability ``p``.

    Args:
        p (float): probability of flipping the sample. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        flip = random.random() < self.p
        if not flip:
            return sample
        sample['input'] = F.vflip(sample['input'])
        sample['gt'] = [F.vflip(mask) for mask in sample['gt']]
        return sample

        

In [137]:
toTensor = ToTensor()
toPIL = ToPIL()
# Honour the configured flip probability (previously ignored); 0.5 matches
# RandomVerticalFlip's own default when the key is absent.
randomVFlip = RandomVerticalFlip(p=parameters.get("flip_rate", 0.5))
randomResizedCrop = RandomResizedCrop(
    parameters["crop_size"],
    scale=parameters["scale_range"],
    ratio=parameters["ratio_range"],
)
randomRotation = RandomRotation(parameters["max_angle"])
elasticTransform = ElasticTransform(
    parameters["alpha_range"],
    parameters["sigma_range"],
    parameters["elastic_rate"],
)

# Joint input/gt augmentation pipeline: PIL conversion first (the geometric
# transforms operate on PIL images), tensor conversion last.
composed = transforms.Compose([
    toPIL,
    randomVFlip,
    randomRotation,
    randomResizedCrop,
    elasticTransform,
    toTensor,
])

## Dataset

In [10]:
class MRI2DSegDataset(Dataset):
    """Generic dataset for 2D (slice-wise) segmentation of 3D volumes.

    Every input volume listed in ``txt_path_file`` is cut into 2D slices along
    ``slice_axis``. Each item is a dict ``{'input': slice, 'gt': [mask, ...]}``
    with one gt mask per class, classes in sorted name order.

    All inputs must share orientation, in-plane resolution and in-plane matrix
    size, and every listed sample must define the same gt classes; otherwise a
    RuntimeError is raised while loading.

    :param txt_path_file: path to a txt file. Every line containing "input"
        holds whitespace-separated ``key path`` pairs: the key ``input`` names
        the input volume, every other key names a gt class.
    :param slice_axis: axis along which slices are taken (0, 1 or 2; default 2).
        Any other value is replaced by 2 with a printed warning.
    :param cache: stored as ``self.cache`` but not read by this class; all
        slices are always kept in memory.
    :param transform: optional callable applied to each sample dict.
    """
    def __init__(self, txt_path_file, slice_axis=2, cache=True, transform=None):
        self.filenames = []
        self.header = {}
        self.class_names = []
        self.read_filenames(txt_path_file)
        self.transform = transform
        self.cache = cache
        if slice_axis not in (0, 1, 2):
            print("Invalid slice axis given, replaced by default value of 2.")
            slice_axis = 2
        self.slice_axis = slice_axis
        self.handlers = []

        self._load_files()

    def __len__(self):
        return len(self.handlers)

    def __getitem__(self, index):
        """Return ``{'input': ..., 'gt': [...]}`` for slice ``index``, transformed if configured."""
        sample = self.handlers[index]
        data_dict = {
            'input': sample[0],
            'gt': list(sample[1:]),
        }

        if self.transform:
            # The transforms in this module act jointly on 'input' and 'gt'.
            data_dict = self.transform(data_dict)

        return data_dict

    def _slice_geometry(self, input_3D):
        """Return (resolution, matrix_size) of the slice plane for ``self.slice_axis``.

        NOTE(review): assumes ``msct_Image.dim`` holds matrix sizes at indices
        0-2 and voxel sizes at indices 4-6, as the indexing here implies.
        """
        if self.slice_axis == 0:
            resolution = list(np.around(input_3D.dim[5:7], 2))
            matrix_size = input_3D.dim[1:3]
        elif self.slice_axis == 1:
            resolution = list(np.around([input_3D.dim[4], input_3D.dim[6]], 2))
            matrix_size = (input_3D.dim[0], input_3D.dim[2])
        else:
            resolution = list(np.around(input_3D.dim[4:6], 2))
            matrix_size = input_3D.dim[0:2]
        return resolution, matrix_size

    def _append_slices(self, input_data, gt_data):
        """Slice ``input_data`` and each gt volume along ``self.slice_axis``.

        Appends one stacked array ``[input, gt_0, gt_1, ...]`` per slice to
        ``self.handlers``.

        :raises RuntimeError: if a gt volume's shape differs from the input's.
        """
        axis = self.slice_axis
        input_data = np.asarray(input_data)
        gt_data = [np.asarray(gt) for gt in gt_data]
        for gt in gt_data:
            if gt.shape != input_data.shape:
                print("input dimensions : {}".format(input_data.shape))
                print("gt dimensions : {}".format(gt.shape))
                raise RuntimeError('Input and ground truth with different dimensions.')

        volumes = [input_data] + gt_data
        # Iterate over the slice axis itself, not always axis 2.
        for i in range(input_data.shape[axis]):
            self.handlers.append(np.array([np.take(volume, i, axis=axis) for volume in volumes]))

    def _load_files(self):
        """Load every (input, gt) pair from ``self.filenames`` and slice it into ``self.handlers``.

        :raises RuntimeError: on inconsistent headers or gt classes across files.
        """
        for input_filename, gt_dict in self.filenames:
            input_3D = msct_Image(input_filename)
            resolution, matrix_size = self._slice_geometry(input_3D)
            input_header = {"orientation": input_3D.orientation, "resolution": resolution, "matrix_size": matrix_size}

            gt_class_names = sorted(gt_dict.keys())
            gt_3D = [msct_Image(gt_dict[gt_class]) for gt_class in gt_class_names]

            if not self.header:
                self.header = input_header
            # sanity check for consistent header
            elif self.header != input_header:
                print(self.header)
                print(input_header)
                raise RuntimeError('Inconsistent header in input files.')

            if not self.class_names:
                self.class_names = gt_class_names
            # sanity check for consistent gt classes
            elif self.class_names != gt_class_names:
                raise RuntimeError('Inconsistent classes in gt files.')

            self._append_slices(input_3D.data, [gt.data for gt in gt_3D])

    def read_filenames(self, txt_path_file):
        """Parse ``txt_path_file`` into ``self.filenames`` as (input_path, {class: gt_path}) tuples.

        Only lines containing the substring "input" are parsed. Every path is
        opened once with ``msct_Image`` to validate it.

        :raises RuntimeError: on an odd number of tokens, a line without an
            ``input`` key, or a path ``msct_Image`` cannot open.
        """
        with open(txt_path_file, 'r') as txt_file:
            for line in txt_file:
                if "input" not in line:
                    continue
                tokens = line.split()
                if len(tokens) % 2:
                    raise RuntimeError('Error in filenames txt file parsing.')
                input_path = None
                gt_paths = {}
                for key, path in zip(tokens[::2], tokens[1::2]):
                    try:
                        msct_Image(path)
                    except Exception:
                        raise RuntimeError("Invalid path in filenames txt file.")
                    if key == "input":
                        input_path = path
                    else:
                        gt_paths[key] = path
                # "input" may only have matched inside a path; require the key.
                if input_path is None:
                    raise RuntimeError('Error in filenames txt file parsing.')
                self.filenames.append((input_path, gt_paths))
                

In [138]:
# Slice-wise dataset built from the filenames list, augmented by the pipeline.
test = MRI2DSegDataset(
    "/Users/frpau_local/Documents/nih/data/luisa_with_gt/filenames_csf_gm_nawm.txt",
    transform=composed,
)

In [140]:
# Display one sample; since `composed` contains random transforms, the shown
# augmentation differs between evaluations of this cell.
test[5]

{'gt': [
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
         ...          ⋱          ...       
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
  [torch.FloatTensor of size 200x150], 
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
         ...          ⋱          ...       
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
  [torch.FloatTensor of size 200x150], 
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
         ...          ⋱          ...       
      0     0     0  ...      0     0     0
      0     0     0  ...      0     0     0
      0     0     0  ...      0