In [1]:
import tensorflow as tf
import os
import os.path as op
import numpy as np
import nibabel as nib
from ae_keras import ReflectionPadding1D, IncrFeatStridedConvFCUpsampReflectPadAE
import torch

from dipy.io.stateful_tractogram import Space
from dipy.io.streamline import load_tractogram
from dipy.tracking.streamline import Streamlines  # same as nibabel.streamlines.ArraySequence

import aux_functions as af

2024-05-26 13:14:57.380862: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-26 13:14:57.381143: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-26 13:14:57.383139: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-26 13:14:57.408348: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read some TRK data:
fibercup_path = "/home/teitxe/data/FiberCup/"
data_path = "/home/teitxe/data/FiberCup/fibercup_advanced_filtering_no_ushapes/"
f_trk_data = op.join(data_path, "ae_input_std_endpoints/train/fibercup_Simulated_prob_tracking_minL10_resampled256_plausibles_std_endpoints_train.trk")
f_img_data = op.join(fibercup_path, "Simulated_FiberCup.nii.gz")
streamlines = af.read_data(f_trk_data, f_img_data)
print(f"N of streamlines: {len(streamlines)}")
print(f"Example of a streamline: {streamlines[0][0]}")
print(f"N of points in the first streamline: {len(streamlines[0])}")

N of streamlines: 3112
Example of a streamline: [-62.621086  -17.029413    2.6032662]
N of points in the first streamline: 256


### Instantiate the Model

In [12]:
# Instantiate the Model
latent_space_dims = 128
model = IncrFeatStridedConvFCUpsampReflectPadAE(latent_space_dims)
input_shape = (1, 256, 3)  # Example input shape
input_streamline = np.array([streamlines[0]])
output = model.call(input_streamline)

print(f"Difference between input and output for a streamline point = {(input_streamline - output)[0][0]}")

Difference between input and output for a streamline point = [-62.090878  -17.037579    3.1329408]


### Create the dataset to fetch from it during the training loop

In [9]:
# Make a tensorflow dataset out of the streamlines
dataset = tf.data.Dataset.from_tensor_slices(streamlines)

### Define the Loss and the Optimizer

In [17]:
# Loss function: Mean squared error
loss_mse = tf.keras.losses.MeanSquaredError()

def loss(model, x, y):
    y_ = model(x)
    return loss_mse(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

# Optimizer
optimizer = tf.keras.optimizers.Adadelta(learning_rate=0.01)


In [28]:
train_mse_results = []

n_epochs = 5
batch_size = 1
dataset_train_batch = dataset.batch(batch_size)

for epoch in range(n_epochs):
    epoch_mse = tf.keras.metrics.MeanSquaredError()
    
    for x in dataset_train_batch:
        # Optimize the model
        loss_value, grads = grad(model, x, x)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        # Track progress
        epoch_mse.update_state(x, model(x))
    
    # End epoch
    train_mse_results.append(epoch_mse.result())
    
    print(f"Epoch {epoch}: Loss: {epoch_mse.result()}")

KeyboardInterrupt: 

In [25]:
# model.compile(optimizer='adam', loss='mse')
# model.fit(dataset, epochs=10)

Epoch 1/10


ValueError: Exception encountered when calling Sequential.call().

[1mInvalid input shape for input Tensor("data:0", shape=(256, 3), dtype=float32). Expected shape (1, 256, 3), but input has incompatible shape (256, 3)[0m

Arguments received by Sequential.call():
  • inputs=tf.Tensor(shape=(256, 3), dtype=float32)
  • training=None
  • mask=None

### Test a training loop iteration manually

In [18]:
loss_value, gradients = grad(model, input_streamline, input_streamline)
print(f"Step: {optimizer.iterations.numpy()}, Initial Loss: {loss_value.numpy()}")
print(f"Step: {optimizer.iterations.numpy()},         Loss: {loss(model, input_streamline, input_streamline).numpy()}")

Step: 0, Initial Loss: 6098.486328125
Step: 0,         Loss: 6098.486328125


### Define the training loop

In [18]:
# Adding a leading underscore to avoid function parameters shadowing these
# variables
_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_train_tractogram_fname = "strml_train.trk"
_valid_tractogram_fname = "strml_valid.trk"
_img_fname = "t1.nii.gz"
_trained_weights_fname = "already_available_model_weights.pt"
_training_weights_fname = "training_model_weights.pt"
# The following values were found to give best results
_lr = 6.68e-4
_weight_decay = 0.13
_epochs = 100
# resample_data()   # resample your tractogram to 256 points if needed
test_ae_model(
    _train_tractogram_fname, _img_fname, _device
)  # only does a forward pass, does not train the model
test_ae_model_loader(_train_tractogram_fname, _img_fname, _device)  # computes loss
_ = load_model_weights(_trained_weights_fname, _device, _lr, _weight_decay)  # load model weights
train_ae_model(
    _train_tractogram_fname, _valid_tractogram_fname, _img_fname, _device, _lr, _weight_decay, _epochs, _training_weights_fname
)  # computes loss and trains the model

(1, 256, 3)