## NETWORK TRAINING CODE ##

This notebook contains all code used for network training (apart from the helper and data-loading functions, which are imported from other modules) and for basic evaluations.

Author: antoniabhain@gmail.com

In [None]:
import numpy as np
import cv2

from randomised_filtering.classifier.streamline_loader import load_data
from randomised_filtering.classifier.model import get_binary_model, get_categorical_model
from randomised_filtering.classifier.training import training_cv

# Input locations for a single subject. Changing SUBJECT_ID points all three
# paths at another subject, provided that subject's data uses the same layout.
SUBJECT_ID = "599671"
SUBJECT_DIR = f"data/{SUBJECT_ID}"

TRK_PATH = f"{SUBJECT_DIR}/All_10M_corrected.trk"
POS_STREAMLINES_PATH = f"{SUBJECT_DIR}/json/pos_streamlines.json"
NEG_STREAMLINES_PATH = f"{SUBJECT_DIR}/json/neg_streamlines.json"

In [None]:
# Load the tractogram and unpack the three streamline groups returned by
# load_data: positive, negative and inconclusive. The two JSON files presumably
# list which streamlines are positive / negative (confirm against load_data).
# normalize=True is forwarded to load_data unchanged.
pos_streamlines, neg_streamlines, inc_streamlines = load_data(
    TRK_PATH, POS_STREAMLINES_PATH, NEG_STREAMLINES_PATH, normalize=True
)

In [None]:
# Shuffle each streamline group in place. The original calls were unseeded, so
# the sample order (and any fold split that depends on it) differed on every
# run. A fixed seed makes it reproducible.
# This seeds only the dedicated generator below. It does not seed NumPy's
# global RNG or any deep-learning framework used later.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
rng.shuffle(pos_streamlines)
rng.shuffle(neg_streamlines)
rng.shuffle(inc_streamlines)

In [None]:
# Resize data to fit the network input dimensions: every streamline is
# resampled to the same number of points so all samples share one input shape.

POINTS_PER_STREAMLINE = 23  # resample streamlines to this number of points

# cv2.resize takes dsize as (width, height). For a 2-D array the output has
# shape (height, width) = (POINTS_PER_STREAMLINE, 3): one row per resampled
# point, 3 columns (assumed x, y, z; confirm against load_data).
RESAMPLING_SHAPE = (3, POINTS_PER_STREAMLINE)
# Flattened length 3 * POINTS_PER_STREAMLINE plus a trailing axis of size 1.
# int() turns numpy's integer scalar into a plain Python int.
INPUT_SHAPE = (int(np.prod(RESAMPLING_SHAPE)), 1)  # shape of network input


def _resample_streamlines(streamlines, dsize=RESAMPLING_SHAPE, input_shape=INPUT_SHAPE):
    """Resample each streamline with cv2.resize and stack them into one array.

    Args:
        streamlines: Iterable of 2-D arrays that cv2.resize accepts
            (assumed (n_points, 3) float arrays; confirm against load_data).
        dsize: cv2.resize target size, given as (width, height).
        input_shape: Shape each resampled streamline is reshaped to.

    Returns:
        Array of shape (number of streamlines, *input_shape). An empty input
        yields an empty array of that shape rather than a shape-(0,) array.
    """
    resampled = [cv2.resize(s, dsize).reshape(input_shape) for s in streamlines]
    if not resampled:
        # np.array([]) would have shape (0,) and lose the per-sample dimensions.
        return np.empty((0, *input_shape))
    return np.array(resampled)


pos_resized = _resample_streamlines(pos_streamlines)
neg_resized = _resample_streamlines(neg_streamlines)
inc_resized = _resample_streamlines(inc_streamlines)

## PLAUSIBLE VS IMPLAUSIBLE ##

In [None]:
# Cross-validation settings for the binary (plausible vs implausible) model.
k = 5  # number of cross-validation folds
BATCH_SIZE = 50

model = get_binary_model(input_shape=INPUT_SHAPE)

# Two groups, in this order: negatives, then positives. To train with
# inconclusive streamlines, swap one of the two entries for inc_resized.
data = [neg_resized, pos_resized]

# Intent: train k models (one per fold) in k-fold cross-validation.
# NOTE(review): the same `model` object is handed over for every fold. If
# training_cv does not re-initialise or clone it per fold, weights carry over
# between folds. Verify in training_cv.
training_cv(
    base_path_to_model="model_binary",
    epochs=5,
    batch_size=BATCH_SIZE,
    nb_folds=k,
    model=model,
    data=data,
)

## MULTI-CLASS CLASSIFIER ##

In [None]:
# Cross-validation settings for the multi-class model.
k = 5  # number of cross-validation folds
BATCH_SIZE = 60

# Three groups, in this order: negatives, positives, inconclusive. The model
# gets one output class per group (presumably class i corresponds to data[i]).
data = [neg_resized, pos_resized, inc_resized]
model = get_categorical_model(input_shape=INPUT_SHAPE, num_classes=len(data))

# Intent: train k models (one per fold) in k-fold cross-validation.
# NOTE(review): the same `model` object is handed over for every fold. If
# training_cv does not re-initialise or clone it per fold, weights carry over
# between folds. Verify in training_cv.
training_cv(
    base_path_to_model=f"model_cat{len(data)}",
    epochs=2,
    batch_size=BATCH_SIZE,
    nb_folds=k,
    model=model,
    data=data,
)