In [4]:
import os
import os.path as op
import time

import tensorflow as tf
import keras

import numpy as np
import nibabel as nib

import tractoencoder_gsoc.utils as utils

In [5]:
os.environ["KERAS_BACKEND"] = "tensorflow"

In [3]:
# Read some TRK data:
fibercup_path = "/home/teitxe/data/FiberCup/"
data_path = "/home/teitxe/data/FiberCup/fibercup_advanced_filtering_no_ushapes/"
f_trk_data = op.join(data_path, "ae_input_std_endpoints/train/fibercup_Simulated_prob_tracking_minL10_resampled256_plausibles_std_endpoints_train.trk")
f_img_data = op.join(fibercup_path, "Simulated_FiberCup.nii.gz")
streamlines = utils.read_data(f_trk_data, f_img_data)
print(f"N of streamlines: {len(streamlines)}")
print(f"Example of a streamline point: {streamlines[0][0]}")
print(f"N of points in the first streamline: {len(streamlines[0])}")

N of streamlines: 3112
Example of a streamline point: [-62.621086  -17.029413    2.6032662]
N of points in the first streamline: 256


### Create the dataset to fetch from it during the training loop

In [9]:
# Make a tensorflow dataset out of the streamlines
dataset = tf.data.Dataset.from_tensor_slices(streamlines)

### Define the Loss and the Optimizer

In [17]:
# Loss function: Mean squared error
loss_mse = tf.keras.losses.MeanSquaredError()

def loss(model, x, y):
    y_ = model(x)
    return loss_mse(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

# Optimizer
optimizer = tf.keras.optimizers.Adadelta(learning_rate=0.01)


In [28]:
train_mse_results = []

n_epochs = 5
batch_size = 1
dataset_train_batch = dataset.batch(batch_size)

for epoch in range(n_epochs):
    epoch_mse = tf.keras.metrics.MeanSquaredError()
    
    for x in dataset_train_batch:
        # Optimize the model
        loss_value, grads = grad(model, x, x)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        # Track progress
        epoch_mse.update_state(x, model(x))
    
    # End epoch
    train_mse_results.append(epoch_mse.result())
    
    print(f"Epoch {epoch}: Loss: {epoch_mse.result()}")

KeyboardInterrupt: 

In [25]:
# model.compile(optimizer='adam', loss='mse')
# model.fit(dataset, epochs=10)

Epoch 1/10


ValueError: Exception encountered when calling Sequential.call().

[1mInvalid input shape for input Tensor("data:0", shape=(256, 3), dtype=float32). Expected shape (1, 256, 3), but input has incompatible shape (256, 3)[0m

Arguments received by Sequential.call():
  • inputs=tf.Tensor(shape=(256, 3), dtype=float32)
  • training=None
  • mask=None

### Test a training loop iteration manually

In [18]:
loss_value, gradients = grad(model, input_streamline, input_streamline)
print(f"Step: {optimizer.iterations.numpy()}, Initial Loss: {loss_value.numpy()}")
print(f"Step: {optimizer.iterations.numpy()},         Loss: {loss(model, input_streamline, input_streamline).numpy()}")

Step: 0, Initial Loss: 6098.486328125
Step: 0,         Loss: 6098.486328125


### Define the training loop

In [18]:
# Adding a leading underscore to avoid function parameters shadowing these
# variables
_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_train_tractogram_fname = "strml_train.trk"
_valid_tractogram_fname = "strml_valid.trk"
_img_fname = "t1.nii.gz"
_trained_weights_fname = "already_available_model_weights.pt"
_training_weights_fname = "training_model_weights.pt"
# The following values were found to give best results
_lr = 6.68e-4
_weight_decay = 0.13
_epochs = 100
# resample_data()   # resample your tractogram to 256 points if needed
test_ae_model(
    _train_tractogram_fname, _img_fname, _device
)  # only does a forward pass, does not train the model
test_ae_model_loader(_train_tractogram_fname, _img_fname, _device)  # computes loss
_ = load_model_weights(_trained_weights_fname, _device, _lr, _weight_decay)  # load model weights
train_ae_model(
    _train_tractogram_fname, _valid_tractogram_fname, _img_fname, _device, _lr, _weight_decay, _epochs, _training_weights_fname
)  # computes loss and trains the model

(1, 256, 3)

### Try to export the model weights from PyTorch ---> ONNX ---> TensorFlow (Didn't work well...)

In [2]:
import tractolearn.models.track_ae_cnn1d_incr_feat_strided_conv_fc_upsamp_reflect_pad_pytorch as AE_model

In [3]:
weights_path = "/home/teitxe/data/tractolearn_data/"
state_dict = torch.load(os.path.join(weights_path, "best_model_contrastive_tractoinferno_hcp.pt"), map_location=torch.device('cpu'))
net = AE_model.IncrFeatStridedConvFCUpsampReflectPadAE(32)
dummy_input = torch.randn(1, 3, 256)
net(dummy_input)[0][0][0:2]

tensor([-0.0569, -0.0569], grad_fn=<SliceBackward0>)

#### Load the weights into the model and export them to ONNX

In [8]:
import torch
net.load_state_dict(state_dict["state_dict"])
onnx_file = os.path.join(weights_path, "best_model_contrastive_tractoinferno_hcp.onnx")
torch.onnx.export(net, dummy_input, onnx_file, input_names=["input"], output_names=["output"], 
                  export_params=True)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


#### Load the ONNX file

In [25]:
import subprocess as sp
if not os.path.exists('/home/teitxe/data/tractolearn_data/tf_model'):
    container = 'docker run --rm -v /home/teitxe:/workdir -w /workdir docker.io/pinto0309/onnx2tf:1.22.3 /bin/bash -c "'
    tf_model_path = '/workdir/data/tractolearn_data/tf_model'
    onnx_model_path = '/workdir/data/tractolearn_data/best_model_contrastive_tractoinferno_hcp.onnx'
    command = f'mkdir -p {tf_model_path} && onnx2tf -i {onnx_model_path} -o {tf_model_path}"'
    print(container + command)
    sp.run(container + command, shell=True, check=True)

#### Try to load the model into TF

In [5]:
local_tf_model_path = '/home/teitxe/data/tractolearn_data/tf_model'
model = tf.saved_model.load(local_tf_model_path)


print((model(input_streamline))[0][0])

tf.Tensor([-65.18617   -18.12808     2.1061761], shape=(3,), dtype=float32)


In [6]:
output_torch[0][0]

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.02813642,  0.13867192,  0.03355828], dtype=float32)>

In [66]:
outputs = []
for streamline in streamlines:
    outputs.append(model(streamline.reshape(1, 256, 3))[0])

In [50]:
# Copy the header of MNI152
mni152 = nib.load("/home/teitxe/data/tractolearn_data/mni_masked.nii.gz")
tractogram = nib.streamlines.Tractogram(outputs, affine_to_rasmm=mni152.affine)
nib.streamlines.save(tractogram, "test.trk", header=mni152.header)

ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)

In [65]:
tractogram.streamlines

ArraySequence([array([[[-6.51861725e+01, -1.81280804e+01,  2.10617614e+00],
        [-6.51380234e+01, -1.81901741e+01,  2.10030055e+00],
        [-6.50476379e+01, -1.83469982e+01,  2.09295368e+00],
        [-6.49414368e+01, -1.85465355e+01,  2.08224583e+00],
        [-6.48481216e+01, -1.87630825e+01,  2.06499243e+00],
        [-6.47774200e+01, -1.89976387e+01,  2.04580307e+00],
        [-6.47690887e+01, -1.92527981e+01,  2.03363013e+00],
        [-6.48385086e+01, -1.95248356e+01,  2.03694749e+00],
        [-6.50052795e+01, -1.98126793e+01,  2.05857325e+00],
        [-6.52664032e+01, -2.01171722e+01,  2.09566951e+00],
        [-6.55915604e+01, -2.04336700e+01,  2.13942409e+00],
        [-6.59629745e+01, -2.07625751e+01,  2.18123603e+00],
        [-6.63543320e+01, -2.10907021e+01,  2.21733999e+00],
        [-6.67639618e+01, -2.14156990e+01,  2.24915624e+00],
        [-6.71870880e+01, -2.17433033e+01,  2.27869654e+00],
        [-6.76241913e+01, -2.20781727e+01,  2.30541134e+00],
        [

In [64]:
streamlines

ArraySequence([array([[ -62.621086  ,  -17.029413  ,    2.6032662 ],
       [ -62.899376  ,  -17.509941  ,    2.6986322 ],
       [ -63.31436   ,  -17.877901  ,    2.7416296 ],
       [ -63.63633   ,  -18.323187  ,    2.6293426 ],
       [ -64.022484  ,  -18.71436   ,    2.5888877 ],
       [ -64.37239   ,  -19.153662  ,    2.5698748 ],
       [ -64.79024   ,  -19.519556  ,    2.4586916 ],
       [ -65.30597   ,  -19.729292  ,    2.5384636 ],
       [ -65.681915  ,  -20.156317  ,    2.5701723 ],
       [ -66.04933   ,  -20.573917  ,    2.6530647 ],
       [ -66.44974   ,  -20.910442  ,    2.8476071 ],
       [ -66.9426    ,  -21.188663  ,    2.9057546 ],
       [ -67.392685  ,  -21.534174  ,    2.9384232 ],
       [ -67.90479   ,  -21.760218  ,    2.9267936 ],
       [ -68.43143   ,  -21.909626  ,    2.7741146 ],
       [ -68.82687   ,  -22.238876  ,    2.5391746 ],
       [ -69.10892   ,  -22.715961  ,    2.4079037 ],
       [ -69.45478   ,  -23.144743  ,    2.3422744 ],
       [ -69.