In [None]:
import os
import os.path as op
import time

import tensorflow as tf
import keras

import numpy as np
import nibabel as nib

import tractoencoder_gsoc.utils as utils

In [None]:
# Keras reads KERAS_BACKEND only once, when `keras` is first imported. The
# import cell above has already imported keras, so this assignment cannot
# switch the active backend; check the backend actually in use instead of
# silently assuming it is TensorFlow (the training loop relies on tf.GradientTape).
os.environ["KERAS_BACKEND"] = "tensorflow"
if keras.backend.backend() != "tensorflow":
    raise RuntimeError(
        f"Keras is using the '{keras.backend.backend()}' backend, but this notebook "
        "requires 'tensorflow'. Set KERAS_BACKEND=tensorflow before importing keras "
        "and restart the kernel."
    )

In [None]:
# Read the FiberCup training tractogram (TRK) and its reference image.
# All paths derive from a single root so they cannot drift apart; the default is
# the author's machine and can be overridden with the FIBERCUP_PATH env variable.
fibercup_path = os.environ.get("FIBERCUP_PATH", "/home/teitxe/data/FiberCup/")
# Trailing "" keeps the trailing path separator of the original literal.
data_path = op.join(fibercup_path, "fibercup_advanced_filtering_no_ushapes", "")
f_trk_data = op.join(
    data_path,
    "ae_input_std_endpoints",
    "train",
    "fibercup_Simulated_prob_tracking_minL10_resampled256_plausibles_std_endpoints_train.trk",
)
f_img_data = op.join(fibercup_path, "Simulated_FiberCup.nii.gz")
streamlines = utils.read_data(f_trk_data, f_img_data)

# Fail with a clear message instead of an IndexError on streamlines[0] below.
if len(streamlines) == 0:
    raise ValueError(f"No streamlines were read from {f_trk_data}")

print(f"N of streamlines: {len(streamlines)}")
print(f"Example of a streamline point: {streamlines[0][0]}")
print(f"N of points in the first streamline: {len(streamlines[0])}")

### Create a `tf.data` dataset to iterate over during the training loop

In [None]:
# Build a tf.data pipeline over the streamlines: slicing along the first axis
# makes every dataset element one streamline, i.e. streamlines[i].
dataset = tf.data.Dataset.from_tensor_slices(tensors=streamlines)

### Define the Loss and the Optimizer

In [None]:
# Loss function: mean squared error between the reconstruction and the target.
loss_mse = tf.keras.losses.MeanSquaredError()


def loss(model, x, y, training=None):
    """Return the mean squared error between ``model(x)`` and ``y``.

    Args:
        model: Keras model to evaluate.
        x: model input batch.
        y: reconstruction target (for this autoencoder, the input itself).
        training: forwarded to ``model`` when given. ``None`` (default) keeps
            the plain ``model(x)`` call; ``True`` makes layers with distinct
            training behaviour (e.g. dropout, batch normalization) run in
            training mode.

    Returns:
        Scalar loss tensor.
    """
    y_ = model(x) if training is None else model(x, training=training)
    return loss_mse(y_true=y, y_pred=y_)


def grad(model, inputs, targets):
    """Compute the loss and its gradients w.r.t. ``model.trainable_variables``.

    The forward pass is recorded with ``training=True`` so that layers whose
    behaviour differs between training and inference behave as during training
    while the gradients are computed.

    Returns:
        ``(loss_value, gradients)``, with gradients ordered like
        ``model.trainable_variables`` (ready for ``optimizer.apply_gradients``).
    """
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, training=True)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)


# Optimizer
optimizer = tf.keras.optimizers.Adadelta(learning_rate=0.01)


In [None]:
# Per-epoch mean training loss (MSE), one entry per epoch.
train_mse_results = []

n_epochs = 5
batch_size = 1
dataset_train_batch = dataset.batch(batch_size)

for epoch in range(n_epochs):
    # Running mean of the per-batch MSE losses seen during this epoch.
    epoch_mse = tf.keras.metrics.Mean()

    for x in dataset_train_batch:
        # Autoencoder: the input is also the reconstruction target.
        loss_value, grads = grad(model, x, x)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Track progress with the loss already computed for this step rather
        # than running a second forward pass after the weights were updated.
        epoch_mse.update_state(loss_value)

    # End epoch: store a numpy scalar so the history is easy to plot/serialize.
    epoch_loss = epoch_mse.result().numpy()
    train_mse_results.append(epoch_loss)

    print(f"Epoch {epoch}: Loss: {epoch_loss}")

In [None]:
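# Alternative to the custom loop above: Keras' built-in training loop.
# NOTE: for an autoencoder, `model.fit` needs the dataset to be batched and to
# yield (input, target) pairs, e.g. `dataset.map(lambda s: (s, s)).batch(batch_size)`.
# model.compile(optimizer='adam', loss='mse')
# model.fit(dataset, epochs=10)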

### Manually test a single training-loop iteration

In [None]:
# Take one streamline from the dataset, batched to shape (1, ...) like the
# training loop's inputs, as the test input.
input_streamline = next(iter(dataset.batch(1)))

# One manual optimization step: compute loss and gradients, apply them, then
# recompute the loss to see the effect of the step (the step counter advances).
loss_value, gradients = grad(model, input_streamline, input_streamline)
print(f"Step: {optimizer.iterations.numpy()}, Initial Loss: {loss_value.numpy()}")
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
print(f"Step: {optimizer.iterations.numpy()},         Loss: {loss(model, input_streamline, input_streamline).numpy()}")

### Run the PyTorch-based training driver (test, load weights, train)

In [None]:
# `torch` is not imported in the notebook's import cell; without this the cell
# fails with a NameError on a fresh kernel.
import torch

# NOTE(review): test_ae_model, test_ae_model_loader, load_model_weights and
# train_ae_model are not defined or imported anywhere in this notebook — TODO
# bring them into scope (confirm which module provides them) before running.

# Adding a leading underscore to avoid function parameters shadowing these
# variables
_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_train_tractogram_fname = "strml_train.trk"
_valid_tractogram_fname = "strml_valid.trk"
_img_fname = "t1.nii.gz"
_trained_weights_fname = "already_available_model_weights.pt"
_training_weights_fname = "training_model_weights.pt"
# The following values were found to give best results
_lr = 6.68e-4
_weight_decay = 0.13
_epochs = 100
# resample_data()   # resample your tractogram to 256 points if needed
test_ae_model(
    _train_tractogram_fname, _img_fname, _device
)  # only does a forward pass, does not train the model
test_ae_model_loader(_train_tractogram_fname, _img_fname, _device)  # computes loss
_ = load_model_weights(_trained_weights_fname, _device, _lr, _weight_decay)  # load model weights
train_ae_model(
    _train_tractogram_fname, _valid_tractogram_fname, _img_fname, _device, _lr, _weight_decay, _epochs, _training_weights_fname
)  # computes loss and trains the model