## NETWORK TRAINING CODE ##

This notebook contains all code used for network training (apart from the helper and data-loader functions) and for basic evaluation.

Author: antoniabhain@gmail.com

In [None]:
import numpy as np
import cv2

from randomised_filtering.classifier.streamline_loader import load_data
from randomised_filtering.classifier.generator import BalancedDataGen
from randomised_filtering.classifier.model import get_binary_model, get_categorical_model

# Directory holding the tractogram and label files for one dataset
# ("599671" is presumably a subject ID). Change it here to train on a
# different subject; the three paths below are derived from it.
SUBJECT_DIR = "data/599671"
TRK_PATH = SUBJECT_DIR + "/All_10M_corrected.trk"
POS_STREAMLINES_PATH = SUBJECT_DIR + "/json/pos_streamlines.json"
NEG_STREAMLINES_PATH = SUBJECT_DIR + "/json/neg_streamlines.json"

In [None]:
# Load the tractogram plus the positive/negative label files. load_data
# returns three streamline collections; the third (inc_*) is presumably the
# inconclusive set — confirm in streamline_loader.py.
pos_streamlines, neg_streamlines, inc_streamlines = load_data(
    TRK_PATH, POS_STREAMLINES_PATH, NEG_STREAMLINES_PATH, normalize=True
)

In [None]:
# Shuffle every class in place (the RNG is not seeded, so the order differs
# between runs). The cross-validation cells below cut contiguous slices,
# so this shuffle is what makes each fold a random subset.
for _streamlines in (pos_streamlines, neg_streamlines, inc_streamlines):
    np.random.shuffle(_streamlines)

In [None]:
# Resample every streamline to the fixed network input size. cv2.resize takes
# dsize as (width, height), so (3, 23) yields a 23x3 array — 23 points of 3
# values each, assuming streamlines are (n_points, 3) — which is then
# flattened into a (69, 1) column.
pos_resized, neg_resized, inc_resized = (
    np.array([cv2.resize(line, (3, 23)).reshape((69, 1)) for line in group])
    for group in (pos_streamlines, neg_streamlines, inc_streamlines)
)

## PLAUSIBLE VS IMPLAUSIBLE ##

In [None]:
BATCH_SIZE = 50
k = 5  # number of cross-validation folds

# Class order: index 0 = neg_resized, index 1 = pos_resized. To train with
# inconclusive streamlines, swap one of these entries for inc_resized.
data = [neg_resized, pos_resized]
fold_len = [len(d) // k for d in data]

# Train k models for k-fold cross-validation.
for fold in range(k):

    # Boolean masks selecting this fold's contiguous test slice of each class.
    # The last fold also takes the len(d) % k leftover samples, so every
    # sample is held out exactly once across the k folds.
    test_mask = []
    for d, f_len in zip(data, fold_len):
        mask = np.zeros(len(d), dtype=bool)  # np.bool was removed in NumPy 1.24
        stop = len(d) if fold == k - 1 else (fold + 1) * f_len
        mask[fold * f_len : stop] = True
        test_mask.append(mask)

    train_mask = [np.invert(m) for m in test_mask]

    print("MODEL", fold)
    traingen_pn = BalancedDataGen(
        data=[_data[_mask] for _data, _mask in zip(data, train_mask)],
        weights=[1] * len(data),
        batch_size=BATCH_SIZE,
    )
    testgen_pn = BalancedDataGen(
        data=[_data[_mask] for _data, _mask in zip(data, test_mask)],
        weights=[1] * len(data),
        batch_size=BATCH_SIZE,
    )

    model = get_binary_model()
    model.fit(traingen_pn, epochs=5, verbose=1)
    model.evaluate(testgen_pn, verbose=1)

    # Per-class accuracy on the held-out fold only. Rounding the prediction
    # gives a 0/1 label (assumes the binary model outputs one probability per
    # sample — TODO confirm in model.py), so the sum counts samples predicted
    # as class 1. The denominator must be the test-fold size: dividing by the
    # whole class size shrinks class-1 accuracy by ~1/k and pushes class-0
    # accuracy towards 1.
    test_0 = data[0][test_mask[0]]
    n_pred_1 = np.sum(np.round(model.predict(test_0)))
    correct_predictions_0 = 1 - n_pred_1 / len(test_0)
    print(correct_predictions_0)
    print(len(test_0))
    print()

    test_1 = data[1][test_mask[1]]
    n_pred_1 = np.sum(np.round(model.predict(test_1)))
    correct_predictions_1 = n_pred_1 / len(test_1)
    print(correct_predictions_1)
    print(len(test_1))
    print()

    model.save("model_binary_" + str(fold))

## MULTI-CLASS CLASSIFIER ##

In [None]:
BATCH_SIZE = 60
k = 5  # number of cross-validation folds

# Class order: neg, pos, inc. Presumably the list index is the label that
# BalancedDataGen assigns when categorical=True — confirm in generator.py.
data = [neg_resized, pos_resized, inc_resized]
fold_len = [len(d) // k for d in data]

# Train k three-class models for k-fold cross-validation.
for fold in range(k):

    # Boolean masks selecting this fold's contiguous test slice of each class.
    # The last fold also takes the len(d) % k leftover samples, so every
    # sample is held out exactly once across the k folds.
    test_mask = []
    for d, f_len in zip(data, fold_len):
        mask = np.zeros(len(d), dtype=bool)  # np.bool was removed in NumPy 1.24
        stop = len(d) if fold == k - 1 else (fold + 1) * f_len
        mask[fold * f_len : stop] = True
        test_mask.append(mask)

    train_mask = [np.invert(m) for m in test_mask]

    traingen_cat3 = BalancedDataGen(
        data=[_data[_mask] for _data, _mask in zip(data, train_mask)],
        weights=[1] * len(data),
        batch_size=BATCH_SIZE,
        categorical=True,
    )
    testgen_cat3 = BalancedDataGen(
        data=[_data[_mask] for _data, _mask in zip(data, test_mask)],
        weights=[1] * len(data),
        batch_size=BATCH_SIZE,
        categorical=True,
    )

    print("MODEL", fold)
    # One output per class in `data` (3 here).
    model_cat3 = get_categorical_model(len(data))

    model_cat3.fit(traingen_cat3, epochs=2, verbose=1)
    model_cat3.evaluate(testgen_cat3, verbose=1)

    model_cat3.save("model_cat3_" + str(fold))