# TorchIO wrapper
- See https://github.com/fepegar/torchio/issues/221

In [22]:
import torchio

In [9]:
from torchio.transforms import (
    RandomFlip,
    RandomElasticDeformation,
    RandomMotion,
    RandomGhosting,
    RandomSpike,
    RandomBiasField,
    RandomBlur,
    RandomNoise
)

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [11]:
class EisenTransformWrapper:
    """Apply a TorchIO transform to the entries of a plain data dict.

    The arrays stored under ``fields`` are wrapped as TorchIO ``ScalarImage``
    objects and the array stored under ``label`` as a ``LabelMap``. They are
    grouped into a ``Subject`` and passed through ``transform``. The
    transformed arrays are then written back into the same dict as NumPy
    arrays.

    Args:
        transform: Callable that takes a TorchIO ``Subject`` and returns a
            transformed one (e.g. ``RandomFlip``, ``RandomAffine``).
        fields: Keys of the data dict holding intensity images.
        label: Key of the data dict holding the label map, or ``None`` when
            there is no label to transform.
    """

    def __init__(self, transform, fields, label):
        super().__init__()
        self.fields = fields
        self.transform = transform
        self.label = label

    def __call__(self, data):
        """Transform the ``fields``/``label`` entries of ``data`` in place.

        Args:
            data: Mutable mapping from key to array. Every key in ``fields``
                (and ``label``, if set) must be present.
                NOTE(review): each array is assumed to be 4-D
                (channels, W, H, D) as TorchIO images expect — TODO confirm.

        Returns:
            The same ``data`` object. Its ``fields``/``label`` entries are
            replaced by the transformed NumPy arrays; other keys are
            untouched.
        """
        # Imported here: names imported inside __init__ are local to
        # __init__ and would be undefined in this method.
        from torchio import LabelMap, ScalarImage, Subject

        images = {
            key: ScalarImage(tensor=data[key])
            for key in self.fields
            if key != self.label
        }
        if self.label is not None:
            # Register the label under its real key so the lookup after the
            # transform uses the same name it was stored under.
            images[self.label] = LabelMap(tensor=data[self.label])

        transformed = self.transform(Subject(**images))

        for key in images:
            data[key] = transformed[key].numpy()
        return data

In [None]:
from torchio import RandomAffine, ScalarImage, Subject

# NOTE(review): image_t1, image_t1c, image_t2 and image_flair are not defined
# in any visible cell; they must exist before this cell is run.
subject = Subject(
    **{
        't1': ScalarImage(tensor=image_t1),
        't1c': ScalarImage(tensor=image_t1c),
        't2': ScalarImage(tensor=image_t2),
        'flair': ScalarImage(tensor=image_flair),
    }
)
transform = RandomAffine()
transformed = transform(subject)
# Key lookup is equivalent to attribute access (transformed.t1).
image_t1_transformed = transformed['t1'].numpy()

In [59]:
# Take the 'ct' image and the 'label_task1' entry from a sample loaded
# elsewhere (input_sample is not defined in the visible cells).
sample, label = input_sample['ct'], input_sample['label_task1']
type(label), type(sample)

(numpy.ndarray, numpy.ndarray)

In [60]:
# Prepend a channel axis to the CT volume, since TorchIO works on 4-D arrays
# (channels, W, H, D). The label already has a leading channel axis (see the
# shapes displayed below), so it is used as-is.
sample = sample[None, :, :, :]
dict_only = {'ct': sample, 'label_task1': label}

sample.shape, label.shape

((1, 128, 128, 128), (5, 128, 128, 128))

In [61]:
# Flip along all three spatial axes, restricted to the entries of dict_only
# (iterating a dict yields its keys, so list(dict_only) lists them).
# NOTE(review): newer TorchIO releases may name this argument `include`
# instead of `keys` — confirm against the installed version.
rFlip = RandomFlip(axes=(0, 1, 2), keys=list(dict_only))

In [62]:
# Apply the random flip directly to the plain dict; the cell then displays
# the keys of the returned mapping.
out = rFlip(dict_only)
out.keys()

dict_keys(['ct', 'label_task1'])

In [63]:
# Although NumPy arrays went in, the flipped entries come back as torch
# tensors (note the torch.Size values in the displayed result).
transformed_x, transformed_y = out['ct'], out['label_task1']

transformed_x.shape, transformed_y.shape

(torch.Size([1, 128, 128, 128]), torch.Size([5, 128, 128, 128]))