# BIDS Loader


We use the data in `SpineGeneric`

In [36]:
from bids_neuropoly import bids
from medicaltorch import datasets as mt_datasets
from medicaltorch import transforms as mt_transforms


class BIDSSegPair2D(mt_datasets.SegmentationPair2D):
    def __init__(self, input_filename, gt_filename, metadata):
        super().__init__(input_filename, gt_filename)
        self.metadata = metadata
        self.metadata["input_filename"] = input_filename
        self.metadata["gt_filename"] = gt_filename

    def get_pair_slice(self, slice_index, slice_axis=2):
        dreturn = super().get_pair_slice(slice_index, slice_axis)
        self.metadata["slice_index"] = slice_index
        dreturn["input_metadata"]["bids_metadata"] = self.metadata
        return dreturn


class MRI2DBidsSegDataset(mt_datasets.MRI2DSegmentationDataset):
    def _load_filenames(self):
        for input_filename, gt_filename, bids_metadata in self.filename_pairs:
            segpair = BIDSSegPair2D(input_filename, gt_filename,
                                    bids_metadata)
            self.handlers.append(segpair)


class BidsDataset(MRI2DBidsSegDataset):
    def __init__(self, root_dir, slice_axis=2, cache=True,
                 transform=None, slice_filter_fn=None,
                 canonical=False, labeled=True):
        self.bids_ds = bids.BIDS(root_dir)
        self.filename_pairs = []
        self.metadata = {"FlipAngle": [], "RepetitionTime": [], "EchoTime": [], "Manufacturer": []}

        for subject in self.bids_ds.get_subjects():

            if not subject.has_derivative("labels"):
                print("Subject without derivative, skipping.")
                continue
            derivatives = subject.get_derivatives("labels")
            cord_label_filename = None

            for deriv in derivatives:
                if deriv.endswith("seg-manual.nii.gz"):
                    cord_label_filename = deriv

            if cord_label_filename is None:
                continue

            if not subject.has_metadata():
                print("Subject without metadata.")
                continue

            metadata = subject.metadata()
            if "FlipAngle" not in metadata:
                print("{} without FlipAngle, skipping.".format(subject))
                continue
            else:
                self.metadata["FlipAngle"].append(float(metadata["FlipAngle"]))

            if "EchoTime" not in metadata:
                print("{} without EchoTime, skipping.".format(subject))
                continue
            else:
                self.metadata["EchoTime"].append(float(metadata["EchoTime"]))

            if "RepetitionTime" not in metadata:
                print("{} without RepetitionTime, skipping.".format(subject))
                continue
            else:
                self.metadata["RepetitionTime"].append(float(metadata["RepetitionTime"]))

            if "Manufacturer" not in metadata:
                print("{} without Manufacturer, skipping.".format(subject))
                continue
            else:
                self.metadata["Manufacturer"].append(metadata["Manufacturer"])

            self.filename_pairs.append((subject.record.absolute_path,
                                        cord_label_filename, metadata))

        super().__init__(self.filename_pairs, slice_axis, cache,
                         transform, slice_filter_fn, canonical)


# Model

Here we define the architecture of the network in a PyTorch Module

In [37]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

class DownConv(nn.Module):
    def __init__(self, in_ch, out_ch, bn_momentum=0.1):
        super(DownConv, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        self.conv_bn = nn.BatchNorm2d(out_ch, momentum=bn_momentum)
    
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.conv_bn(x)
        return x

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class Classifier(nn.Module):
    def __init__(self, drop_rate=0.4, bn_momentum=0.1):
        super(Classifier, self).__init__()
        self.conv1 = DownConv(1, 32, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)
        
        self.conv2 = DownConv(32, 32, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)
        
        self.conv3 = DownConv(32, 64, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)       
        
        self.flat = Flatten()
        self.dense1 = nn.Linear(16384, 256)
        self.drop = nn.Dropout2d(drop_rate)
        self.dense2 = nn.Linear(256, 6)
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.mp1(x1)
        
        x3 = self.conv2(x2)
        x4 = self.mp2(x3)
        
        x5 = self.conv3(x4)
        x6 = self.mp3(x5)
        
        x7 = self.flat(x6)
        x8 = self.dense1(x7)
        x9 = self.drop(x8)
        x10 = self.dense2(x9)
        x11 = self.soft(x10)

        return(x11)

#torch tensors are of the format (batch_size, n_channels, shape_of_image)
a = torch.rand(18,1,128,128)
test = Classifier().forward(a)

# Training the model

In [75]:
from tensorboardX import SummaryWriter
import time
import shutil
import sys
import pickle
import nibabel as nib
import numpy as np
import json
import os
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from skimage.transform import resize

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, ConcatDataset

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"


def get_modality(batch):
    labels = []
    for acq in batch['input_metadata']:
        path = acq.__getitem__('bids_metadata')['input_filename']
        name = os.path.basename(path)
        if "acq-MToff_MTS" in name :
            labels.append(0)
            continue
        if "acq-MTon_MTS" in name :
            labels.append(1)
            continue
        if "acq-T1w_MTS" in name :
            labels.append(2)
            continue
        if "T1w" in name :
            labels.append(3)
            continue
        if "T2star" in name :
            labels.append(4)
            continue
        if "T2w" in name :
            labels.append(5) 
            continue
    return labels

def OneHotEncode(labels):
    ohe_labels = []
    for label in labels :
        ohe = [0 for i in range(6)]
        ohe[label] = 1 
        ohe_labels.append(ohe)
    return torch.cuda.FloatTensor(ohe_labels)


def cmd_train(context):
    """Main command do train the network.
    :param context: this is a dictionary with all data from the
                    configuration file:
                        - 'command': run the specified command (e.g. train, test)
                        - 'gpu': ID of the used GPU
                        - 'bids_path_train': list of relative paths of the BIDS folders of each training center
                        - 'bids_path_validation': list of relative paths of the BIDS folders of each validation center
                        - 'bids_path_test': list of relative paths of the BIDS folders of each test center
                        - 'batch_size'
                        - 'dropout_rate'
                        - 'batch_norm_momentum'
                        - 'num_epochs'
                        - 'initial_lr': initial learning rate
                        - 'log_directory': folder name where log files are saved
                        - 'debugging': allows extended verbosity and intermediate outputs
    """
    # Set the GPU
    gpu_number = int(0)
    torch.cuda.set_device(gpu_number)
    device = torch.device("cuda:0")

    # These are the training transformations
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # These are the validation/testing transformations
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # This code will iterate over the folders and load the data, filtering
    # the slices without labels and then concatenating all the datasets together
    train_datasets = []
    for bids_ds in tqdm_notebook(context["bids_path_train"], desc="Loading training set"):
        ds_train = BidsDataset(bids_ds,
                               transform=train_transform)
        train_datasets.append(ds_train)

    ds_train = ConcatDataset(train_datasets)
    print(f"Loaded {len(ds_train)} axial slices for the training set.")
    train_loader = DataLoader(ds_train, batch_size=context["batch_size"],
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=0)
    
    # Validation dataset ------------------------------------------------------
    validation_datasets = []
    for bids_ds in tqdm_notebook(context["bids_path_validation"], desc="Loading validation set"):
        ds_val = BidsDataset(bids_ds,
                             transform=val_transform)
        validation_datasets.append(ds_val)

    ds_val = ConcatDataset(validation_datasets)
    print(f"Loaded {len(ds_val)} axial slices for the validation set.")
    val_loader = DataLoader(ds_val, batch_size=context["batch_size"],
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=0)
    
    
    # Model definition ---------------------------------------------------------
    model = Classifier(drop_rate=context["dropout_rate"],
                       bn_momentum=context["batch_norm_momentum"])
    model.to(device)

    num_epochs = context["num_epochs"]
    initial_lr = context["initial_lr"]

    # Using Adam with cosine annealing learning rate
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=context["log_directory"])
    
    # Binary Cross Entropy Loss
    criterion = nn.BCELoss()
    
    
    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")
    for epoch in tqdm_notebook(range(1, num_epochs+1), desc="Training"):
        start_time = time.time()

        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0
    
        for i, batch in enumerate(train_loader):
            tbatch = time.time()
            t = time.time()
            input_samples = batch["input"]
            input_labels = get_modality(batch)
            #print(f"took {time.time() - t} to load batch")
            
            t = time.time()      
            var_input = input_samples.to(device)
            var_labels = OneHotEncode(input_labels).to(device)
            #print(f"took {time.time() - t} to load tensors on cuda")
            
            t = time.time()
            preds = model(var_input)
            #print(f"took {time.time() - t} to compute preds")
            
            
            #Here I tried BCELoss and BCEWithLogitsLoss 
            t = time.time()
            loss = criterion(preds, var_labels)
            #print(f"took {time.time() - t} to compute loss")
            train_loss_total += loss.item()

            t = time.time()
            optimizer.zero_grad()
            #print(f"took {time.time() - t} to compute gd")
            
            t = time.time()
            loss.backward()
            #print(f"took {time.time() - t} to update weights")

            optimizer.step()
            num_steps += 1
            
            #print(f"BATCH FINISHED IN {time.time() - tbatch}")
        train_loss_total_avg = train_loss_total / num_steps

        #tqdm.write(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")
        print(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")
        
        '''
        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(val_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)
            
            with torch.no_grad():
                var_input = input_samples.to(device)
                var_labels = OneHotEncode(input_labels).to(device)

                preds = model(var_input)

                loss = criterion(preds, var_labels)
                val_loss_total += loss.item()

            num_steps += 1

        val_loss_total_avg = val_loss_total / num_steps

        #tqdm.write(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")
        print(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")'''
        
        end_time = time.time()
        total_time = end_time - start_time
        #tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
        print("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
        
        '''
        if val_loss_total_avg < best_validation_loss:
            best_validation_loss = val_loss_total_avg
            torch.save(model, "./"+context["log_directory"]+"/best_model.pt")'''

        
    # save final model
    torch.save(model, "./"+context["log_directory"]+"/final_model.pt")
    
    return


def run_main(command):
    with open('config_small.json') as fhandle:
        context = json.load(fhandle)

    #command = context["command"]

    if command == 'train':
        cmd_train(context)
        shutil.copyfile(sys.argv[1], "./"+context["log_directory"]+"/config_file.json")
    elif command == 'test':
        cmd_test(context)


In [None]:
run_main('train')

HBox(children=(IntProgress(value=0, description='Loading training set', max=3, style=ProgressStyle(description…

Loaded 1940 axial slices for the training set.


HBox(children=(IntProgress(value=0, description='Loading validation set', max=1, style=ProgressStyle(descripti…

Loaded 996 axial slices for the validation set.


HBox(children=(IntProgress(value=0, description='Training', style=ProgressStyle(description_width='initial')),…

Epoch 1 training loss: 6.5695.
Epoch 1 took 62.70 seconds.
Epoch 2 training loss: 6.7001.
Epoch 2 took 23.62 seconds.
Epoch 3 training loss: 6.3778.
Epoch 3 took 23.11 seconds.


# Testing

In [None]:
run_main('test')

# Random testing

In [None]:
with open('config_small.json') as fhandle:
    context = json.load(fhandle)

command = context["command"]

if command == 'train':
    cmd_train(context)
    shutil.copyfile(sys.argv[1], "./"+context["log_directory"]+"/config_file.json")
elif command == 'test':
    cmd_test(context)

In [None]:
preds = torch.rand(18,6)
preds_norm = nn.Softmax()
labels = [4, 5, 1, 3, 2, 5, 2, 5, 0, 3, 1, 0, 4, 0, 3, 0, 5, 3]
var_labels = []
for l in labels :
    a = [0 for i in range(6)]
    a[l] = 1 
    var_labels.append(a)
var_labels = torch.FloatTensor(var_labels)



CS_loss = nn.BCELoss()
loss = CS_loss(var_labels, preds)

NLL = nn.NLLLoss()
loss = NLL(var_labels, preds)
loss

In [17]:
model = Classifier()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
next(model.parameters()).is_cuda

True

In [5]:
torch.cuda.is_available()

True

In [22]:
for i in range(100):
    a = torch.rand(18,1,128,128)
    a = a.cuda()
    t = time.time()
            
    preds = model(a)

    print(f"took {time.time() - t} to compute")
            

took 0.002450704574584961 to compute
took 0.0018796920776367188 to compute
took 0.014809608459472656 to compute
took 0.012682676315307617 to compute
took 0.013240814208984375 to compute
took 0.013932466506958008 to compute
took 0.014234304428100586 to compute
took 0.0015873908996582031 to compute
took 0.0015411376953125 to compute
took 0.0017554759979248047 to compute
took 0.0015392303466796875 to compute
took 0.001560211181640625 to compute
took 0.0015194416046142578 to compute
took 0.005422115325927734 to compute
took 0.0015501976013183594 to compute
took 0.0016944408416748047 to compute
took 0.001562356948852539 to compute
took 0.0015995502471923828 to compute
took 0.0015420913696289062 to compute
took 0.0016045570373535156 to compute
took 0.0017175674438476562 to compute
took 0.0016162395477294922 to compute
took 0.0015463829040527344 to compute
took 0.00157928466796875 to compute
took 0.0015652179718017578 to compute
took 0.001575469970703125 to compute
took 0.0015316009521484375 

In [69]:
import time
import shutil
import sys
import pickle
import nibabel as nib
import numpy as np
import json
import os
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from skimage.transform import resize

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, ConcatDataset

import torch.backends.cudnn as cudnn
cudnn.benchmark = True


def get_modality(batch):
    labels = []
    for acq in batch['input_metadata']:
        path = acq.__getitem__('bids_metadata')['input_filename']
        name = os.path.basename(path)
        if "acq-MToff_MTS" in name :
            labels.append(0)
            continue
        if "acq-MTon_MTS" in name :
            labels.append(1)
            continue
        if "acq-T1w_MTS" in name :
            labels.append(2)
            continue
        if "T1w" in name :
            labels.append(3)
            continue
        if "T2star" in name :
            labels.append(4)
            continue
        if "T2w" in name :
            labels.append(5) 
            continue
    return labels

def OneHotEncode(labels):
    ohe_labels = []
    for label in labels :
        ohe = [0 for i in range(6)]
        ohe[label] = 1 
        ohe_labels.append(ohe)
    return torch.cuda.FloatTensor(ohe_labels)


def cmd_train(context):
    """Main command do train the network.
    :param context: this is a dictionary with all data from the
                    configuration file:
                        - 'command': run the specified command (e.g. train, test)
                        - 'gpu': ID of the used GPU
                        - 'bids_path_train': list of relative paths of the BIDS folders of each training center
                        - 'bids_path_validation': list of relative paths of the BIDS folders of each validation center
                        - 'bids_path_test': list of relative paths of the BIDS folders of each test center
                        - 'batch_size'
                        - 'dropout_rate'
                        - 'batch_norm_momentum'
                        - 'num_epochs'
                        - 'initial_lr': initial learning rate
                        - 'log_directory': folder name where log files are saved
                        - 'debugging': allows extended verbosity and intermediate outputs
    """
    # Set the GPU
    gpu_number = int(0)
    torch.cuda.set_device(gpu_number)
    device = torch.device("cuda:0")

    # These are the training transformations
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # These are the validation/testing transformations
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # This code will iterate over the folders and load the data, filtering
    # the slices without labels and then concatenating all the datasets together
    train_datasets = []
    for bids_ds in tqdm_notebook(context["bids_path_train"], desc="Loading training set"):
        ds_train = BidsDataset(bids_ds,
                               transform=train_transform)
        train_datasets.append(ds_train)

    ds_train = ConcatDataset(train_datasets)
    print(f"Loaded {len(ds_train)} axial slices for the training set.")
    train_loader = DataLoader(ds_train, batch_size=context["batch_size"],
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=16)

    # Validation dataset ------------------------------------------------------
    validation_datasets = []
    for bids_ds in tqdm_notebook(context["bids_path_validation"], desc="Loading validation set"):
        ds_val = BidsDataset(bids_ds,
                             transform=val_transform)
        validation_datasets.append(ds_val)

    ds_val = ConcatDataset(validation_datasets)
    print(f"Loaded {len(ds_val)} axial slices for the validation set.")
    val_loader = DataLoader(ds_val, batch_size=context["batch_size"],
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = Classifier(drop_rate=context["dropout_rate"],
                       bn_momentum=context["batch_norm_momentum"])
    model.cuda()

    num_epochs = context["num_epochs"]
    initial_lr = context["initial_lr"]

    # Using Adam with cosine annealing learning rate
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=context["log_directory"])

    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")
    for epoch in tqdm_notebook(range(1, num_epochs+1), desc="Training"):
        start_time = time.time()

        scheduler.step()

        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        print(train_loader)
        print(train_loader.dataset)
        for i, batch in enumerate(train_loader):
            print(time.time())
    return

def run_main(command):
    with open('config_small.json') as fhandle:
        context = json.load(fhandle)

    #command = context["command"]

    if command == 'train':
        cmd_train(context)
        shutil.copyfile(sys.argv[1], "./"+context["log_directory"]+"/config_file.json")
    elif command == 'test':
        cmd_test(context)
        
run_main("train")

HBox(children=(IntProgress(value=0, description='Loading training set', max=3, style=ProgressStyle(description…

Loaded 1940 axial slices for the training set.


HBox(children=(IntProgress(value=0, description='Loading validation set', max=1, style=ProgressStyle(descripti…

Loaded 996 axial slices for the validation set.


HBox(children=(IntProgress(value=0, description='Training', style=ProgressStyle(description_width='initial')),…

<torch.utils.data.dataloader.DataLoader object at 0x7fc378413d30>
<torch.utils.data.dataset.ConcatDataset object at 0x7fc3ce68d630>
1558638177.7926917
1558638180.6633515
1558638180.6725476
1558638180.672866
1558638182.5215883
1558638182.5219839
1558638182.522162
1558638182.5223496
1558638182.5225306
1558638182.5227137
1558638182.5228922
1558638182.5230596
1558638182.524441
1558638182.5248587
1558638184.9174876
1558638184.918017
1558638211.8422
1558638211.880668
1558638211.8810496
1558638221.7064786
1558638221.7067747
1558638221.7069128
1558638221.70704
1558638221.707172
1558638221.7073019
1558638221.707436
1558638221.7075562
1558638221.7076678
1558638221.7077844
1558638221.7079196
1558638221.7082303
1558638224.0905116


KeyboardInterrupt: 