## NETWORK TRAINING CODE ##

This module contains all code used for network training and basic evaluation, apart from the helper and data-loader functions.

Author: antoniabhain@gmail.com

In [None]:
import nibabel as nib
import json
import numpy as np
import cv2
import tensorflow.keras as keras
import tensorflow as tf

import streamline_loader

# Input files for data directory 599671: the tractogram (.trk) plus two JSON
# files that presumably list the positive and negative streamlines (confirm
# against streamline_loader). All three are passed to streamline_loader.load_data.
TRK_PATH = "data/599671/All_10M_corrected.trk"
POS_STREAMLINES_PATH = "data/599671/json/pos_streamlines.json"
NEG_STREAMLINES_PATH = "data/599671/json/neg_streamlines.json"

In [None]:
# Load the tractogram together with the positive/negative streamline files.
# The loader hands back three groups: positive, negative and inconclusive
# (presumably streamlines in neither JSON file -- confirm in streamline_loader).
pos_streamlines, neg_streamlines, inc_streamlines = streamline_loader.load_data(
    TRK_PATH, POS_STREAMLINES_PATH, NEG_STREAMLINES_PATH, normalize=True
)

In [None]:
# Shuffle every streamline group in place (in this fixed order). The
# cross-validation folds below are contiguous slices, so this shuffle is what
# makes each fold a random subset.
for _group in (pos_streamlines, neg_streamlines, inc_streamlines):
    np.random.shuffle(_group)

In [None]:
# resize data to fit network input dimensions
def _to_network_input(streamlines):
    """Resize every streamline with cv2 and flatten it to a (69, 1) column.

    cv2's dsize is (width, height), so each array becomes 23 rows x 3 columns
    (presumably 23 resampled points x 3 coordinates -- assumes the input is
    (n_points, 3)) before being flattened row-major to 69 values.
    """
    return np.array([cv2.resize(s, (3, 23)).reshape((69, 1)) for s in streamlines])


pos_resized, neg_resized, inc_resized = (
    _to_network_input(group)
    for group in (pos_streamlines, neg_streamlines, inc_streamlines)
)

In [None]:
class BalancedDataGen(keras.utils.Sequence):
    """Keras ``Sequence`` yielding class-balanced batches with per-sample weights.

    Each batch holds ``batch_size // len(data)`` samples from every class,
    stacked class by class (class 0 first). ``len(self)`` is the number of
    sub-batches the largest class can supply, so smaller classes are cycled
    (and re-shuffled each time they run out), i.e. oversampled.

    Args:
        data: one array of samples per class; samples in ``data[i]`` get label ``i``.
        weights: one sample weight per class, emitted for every sample of that class.
        batch_size: total samples per batch; must be divisible by ``len(data)``.
        categorical: if True, labels are one-hot encoded (float32) instead of
            being integer class indices.

    Raises:
        ValueError: if ``data`` is empty, ``weights`` has the wrong length,
            ``batch_size`` is not positive or not divisible by the number of
            classes, or some class has fewer samples than one sub-batch.
    """

    def __init__(self, data, weights, batch_size, categorical=False):
        # Initialise the Keras base class (Keras 3's PyDataset expects this;
        # it is a no-op for the Keras 2 Sequence).
        super().__init__()

        if len(data) == 0:
            raise ValueError("Please provide data for at least one class")

        # copy to avoid messing up original data order through np.random.shuffle
        self._data = [d.copy() for d in data]
        self._num_classes = len(data)

        if len(weights) != self._num_classes:
            raise ValueError("Please provide weight for each class")
        self._weights = weights

        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if batch_size % self._num_classes != 0:
            raise ValueError("Please make batch size dividable by number of classes")
        self._batch_size = batch_size
        self._subbatch_size = batch_size // self._num_classes

        # Number of complete sub-batches each class can supply per cycle; the
        # remaining len % subbatch_size samples are skipped in that cycle.
        self._batches = [len(d) // self._subbatch_size for d in self._data]
        for class_idx, n_batches in enumerate(self._batches):
            if n_batches == 0:
                # Fail fast: otherwise _get_subbatch would hit `idx % 0`.
                raise ValueError(
                    f"Class {class_idx} has {len(self._data[class_idx])} samples, "
                    f"fewer than the sub-batch size {self._subbatch_size}"
                )

        self._categorical = categorical

        self.on_epoch_end()

    def on_epoch_end(self):
        """Reshuffle every class in place (Keras calls this after each epoch)."""
        for d in self._data:
            np.random.shuffle(d)

    def _get_subbatch(self, class_idx, idx):
        """Return ``(x, y, w)`` for the part of batch ``idx`` drawn from class ``class_idx``.

        The class's sub-batches are cycled with ``idx % n_batches``. After the
        last sub-batch of a cycle is taken, the class is reshuffled so the next
        cycle (relevant for under-represented classes) sees a new order.
        """
        n_batches = self._batches[class_idx]
        data_idx = idx % n_batches
        start = data_idx * self._subbatch_size

        # .copy() is essential: the slice is a view into self._data[class_idx],
        # and the in-place shuffle below would otherwise overwrite the samples
        # returned for this sub-batch.
        x = self._data[class_idx][start:start + self._subbatch_size].copy()
        y = np.full(self._subbatch_size, class_idx)
        w = np.full(self._subbatch_size, self._weights[class_idx])

        # Shuffle the class once its data is exhausted. NOTE(review): if Keras
        # requests batches out of order (e.g. shuffle=True in fit), this can
        # happen mid-epoch, so some samples may repeat or be skipped within an
        # epoch; the in-place mutation is also not safe with parallel workers.
        if data_idx == n_batches - 1:
            np.random.shuffle(self._data[class_idx])

        return x, y, w

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, y, sample_weight)`` with classes stacked in order."""
        # compile parts of batch from all streamline classes
        parts = [self._get_subbatch(i, idx) for i in range(self._num_classes)]
        x, y, w = (np.concatenate(p, axis=0) for p in zip(*parts))

        if self._categorical:
            y = tf.keras.utils.to_categorical(y, num_classes=self._num_classes, dtype='float32')

        return x, y, w

    def __len__(self):
        """Number of batches per epoch: enough to exhaust the largest class once."""
        return max(self._batches)

In [None]:
def _conv_feature_layers(input_shape=(69, 1)):
    """Return a fresh list of the feature-extraction layers shared by both models.

    Input -> Conv1D(8, k=5) -> MaxPool(2) -> Conv1D(16, k=3) -> MaxPool(2)
    -> Flatten -> Dropout(0.5). New layer objects are created on every call,
    so models built from this never share weights.
    """
    return [
        # keras.Input works in both Keras 2 and 3; InputLayer(input_shape=...)
        # is deprecated in Keras 3.
        keras.Input(shape=input_shape),
        keras.layers.Conv1D(8, kernel_size=5, padding="same", activation="relu"),
        keras.layers.MaxPooling1D(pool_size=2),
        keras.layers.Conv1D(16, kernel_size=3, padding="same", activation="relu"),
        keras.layers.MaxPooling1D(pool_size=2),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
    ]


# categorical crossentropy model
def get_categorical_model(num_classes, input_shape=(69, 1)):
    """Build and compile the multi-class CNN.

    Args:
        num_classes: number of output classes (softmax units).
        input_shape: shape of one sample; defaults to the (69, 1) vectors
            produced by the resizing step.

    Returns:
        A compiled ``keras.Sequential`` (categorical cross-entropy, Adam,
        accuracy metric). Targets must be one-hot encoded.
    """
    model = keras.Sequential(
        _conv_feature_layers(input_shape)
        + [keras.layers.Dense(num_classes, activation="softmax")]
    )
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model


def get_binary_model(input_shape=(69, 1)):
    """Build and compile the binary CNN (single sigmoid output).

    Args:
        input_shape: shape of one sample; defaults to the (69, 1) vectors
            produced by the resizing step.

    Returns:
        A compiled ``keras.Sequential`` (binary cross-entropy, Adam, accuracy
        metric). Targets are 0/1 labels.
    """
    model = keras.Sequential(
        _conv_feature_layers(input_shape)
        + [keras.layers.Dense(1, activation="sigmoid")]
    )
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

## PLAUSIBLE VS IMPLAUSIBLE ##

In [None]:
BATCH_SIZE = 50
k = 5

# The position in `data` is the label: data[0] (neg) -> 0, data[1] (pos) -> 1.
# exchange one of these for inc_resized when wanting to train with inconclusive streamlines
data = [neg_resized, pos_resized]
# per-class fold size; the last len(d) % k samples of each class are never tested
fold_len = [len(d) // k for d in data]

# train k models for cross-validation
for fold in range(k):

    # boolean masks that select this fold's contiguous slice of each class
    test_mask = []
    for i, d in enumerate(data):
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        tmp_mask = np.zeros(len(d), dtype=bool)
        tmp_mask[fold * fold_len[i]:(fold + 1) * fold_len[i]] = True
        test_mask.append(tmp_mask)

    train_mask = [np.invert(m) for m in test_mask]

    print("MODEL", fold)
    traingen_pn = BalancedDataGen([data[0][train_mask[0]],
                                   data[1][train_mask[1]]],
                                  [1, 1], BATCH_SIZE)
    testgen_pn = BalancedDataGen([data[0][test_mask[0]],
                                  data[1][test_mask[1]]],
                                 [1, 1], BATCH_SIZE)

    model = get_binary_model()
    model.fit(traingen_pn, epochs=5, verbose=1)
    model.evaluate(testgen_pn, verbose=1)

    # Per-class accuracy on the held-out fold itself (not the balanced
    # generator). The rounded sigmoid output is a 0/1 prediction, so its sum
    # counts samples predicted as class 1. Normalise by the fold size, not the
    # whole class.
    test_0 = data[0][test_mask[0]]
    test_1 = data[1][test_mask[1]]

    correct_predictions_0 = 1 - np.sum(np.round(model.predict(test_0))) / len(test_0)
    print(correct_predictions_0)
    print(len(test_0))
    print()

    correct_predictions_1 = np.sum(np.round(model.predict(test_1))) / len(test_1)
    print(correct_predictions_1)
    print(len(test_1))
    print()

    # NOTE(review): Keras 3 requires a ".keras"/".h5" suffix here; this
    # extension-less path only works with the TF2 SavedModel format.
    model.save("model_binary_"+str(fold))

## MULTI-CLASS CLASSIFIER ##

In [None]:
BATCH_SIZE = 60
k = 5

# The position in `data` is the label: neg -> 0, pos -> 1, inc -> 2.
data = [neg_resized, pos_resized, inc_resized]
# per-class fold size; the last len(d) % k samples of each class are never tested
fold_len = [len(d) // k for d in data]

for fold in range(k):

    # boolean masks that select this fold's contiguous slice of each class
    test_mask = []
    for i, d in enumerate(data):
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        tmp_mask = np.zeros(len(d), dtype=bool)
        tmp_mask[fold * fold_len[i]:(fold + 1) * fold_len[i]] = True
        test_mask.append(tmp_mask)

    train_mask = [np.invert(m) for m in test_mask]

    traingen_cat3 = BalancedDataGen([data[0][train_mask[0]],
                                     data[1][train_mask[1]],
                                     data[2][train_mask[2]]],
                                    [1, 1, 1], BATCH_SIZE, categorical=True)
    # Evaluate on the held-out fold. The original built this from train_mask,
    # so it scored the model on its own training data.
    testgen_cat3 = BalancedDataGen([data[0][test_mask[0]],
                                    data[1][test_mask[1]],
                                    data[2][test_mask[2]]],
                                   [1, 1, 1], BATCH_SIZE, categorical=True)

    print("MODEL", fold)
    model_cat3 = get_categorical_model(3)

    model_cat3.fit(traingen_cat3, epochs=2, verbose=1)
    model_cat3.evaluate(testgen_cat3, verbose=1)

    # NOTE(review): Keras 3 requires a ".keras"/".h5" suffix here; this
    # extension-less path only works with the TF2 SavedModel format.
    model_cat3.save("model_cat3_"+str(fold))