In [None]:
import tensorflow as tf
import os
import os.path as op
import numpy as np
import nibabel as nib
from ae_keras import ReflectionPadding1D, IncrFeatStridedConvFCUpsampReflectPadAE
import torch

from dipy.io.stateful_tractogram import Space
from dipy.io.streamline import load_tractogram
from dipy.tracking.streamline import Streamlines  # same as nibabel.streamlines.ArraySequence

import aux_functions as af

: 

In [2]:
# Read some TRK data:
fibercup_path = "/home/teitxe/data/FiberCup/"
data_path = "/home/teitxe/data/FiberCup/fibercup_advanced_filtering_no_ushapes/"
f_trk_data = op.join(data_path, "ae_input_std_endpoints/train/fibercup_Simulated_prob_tracking_minL10_resampled256_plausibles_std_endpoints_train.trk")
f_img_data = op.join(fibercup_path, "Simulated_FiberCup.nii.gz")
streamlines = af.read_data(f_trk_data, f_img_data)
print(f"N of streamlines: {len(streamlines)}")
print(f"Example of a streamline: {streamlines[0][0]}")
print(f"N of points in the first streamline: {len(streamlines[0])}")

N of streamlines: 3112
Example of a streamline: [-62.621086  -17.029413    2.6032662]
N of points in the first streamline: 256


### Instantiate the Model

In [4]:
# Instantiate the Model
latent_space_dims = 32
model = IncrFeatStridedConvFCUpsampReflectPadAE(latent_space_dims)
input_shape = (1, 256, 3)  # Example input shape
input_streamline = np.array([streamlines[0]])
output = model.call(input_streamline)

print(f"Difference between input and output for a streamline point = {(input_streamline - output)[0][0]}")

Difference between input and output for a streamline point = [-62.62566   -16.969572    2.5094502]


### Create the dataset to fetch from it during the training loop

In [9]:
# Make a tensorflow dataset out of the streamlines
dataset = tf.data.Dataset.from_tensor_slices(streamlines)

### Define the Loss and the Optimizer

In [17]:
# Loss function: Mean squared error
loss_mse = tf.keras.losses.MeanSquaredError()

def loss(model, x, y):
    y_ = model(x)
    return loss_mse(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

# Optimizer
optimizer = tf.keras.optimizers.Adadelta(learning_rate=0.01)


In [28]:
train_mse_results = []

n_epochs = 5
batch_size = 1
dataset_train_batch = dataset.batch(batch_size)

for epoch in range(n_epochs):
    epoch_mse = tf.keras.metrics.MeanSquaredError()
    
    for x in dataset_train_batch:
        # Optimize the model
        loss_value, grads = grad(model, x, x)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        # Track progress
        epoch_mse.update_state(x, model(x))
    
    # End epoch
    train_mse_results.append(epoch_mse.result())
    
    print(f"Epoch {epoch}: Loss: {epoch_mse.result()}")

KeyboardInterrupt: 

In [25]:
# model.compile(optimizer='adam', loss='mse')
# model.fit(dataset, epochs=10)

Epoch 1/10


ValueError: Exception encountered when calling Sequential.call().

[1mInvalid input shape for input Tensor("data:0", shape=(256, 3), dtype=float32). Expected shape (1, 256, 3), but input has incompatible shape (256, 3)[0m

Arguments received by Sequential.call():
  • inputs=tf.Tensor(shape=(256, 3), dtype=float32)
  • training=None
  • mask=None

### Test a training loop iteration manually

In [18]:
loss_value, gradients = grad(model, input_streamline, input_streamline)
print(f"Step: {optimizer.iterations.numpy()}, Initial Loss: {loss_value.numpy()}")
print(f"Step: {optimizer.iterations.numpy()},         Loss: {loss(model, input_streamline, input_streamline).numpy()}")

Step: 0, Initial Loss: 6098.486328125
Step: 0,         Loss: 6098.486328125


### Define the training loop

In [18]:
# Adding a leading underscore to avoid function parameters shadowing these
# variables
_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_train_tractogram_fname = "strml_train.trk"
_valid_tractogram_fname = "strml_valid.trk"
_img_fname = "t1.nii.gz"
_trained_weights_fname = "already_available_model_weights.pt"
_training_weights_fname = "training_model_weights.pt"
# The following values were found to give best results
_lr = 6.68e-4
_weight_decay = 0.13
_epochs = 100
# resample_data()   # resample your tractogram to 256 points if needed
test_ae_model(
    _train_tractogram_fname, _img_fname, _device
)  # only does a forward pass, does not train the model
test_ae_model_loader(_train_tractogram_fname, _img_fname, _device)  # computes loss
_ = load_model_weights(_trained_weights_fname, _device, _lr, _weight_decay)  # load model weights
train_ae_model(
    _train_tractogram_fname, _valid_tractogram_fname, _img_fname, _device, _lr, _weight_decay, _epochs, _training_weights_fname
)  # computes loss and trains the model

(1, 256, 3)

### Try to export the model weights from PyTorch ---> ONNX ---> TensorFlow

In [2]:
import tractolearn.models.track_ae_cnn1d_incr_feat_strided_conv_fc_upsamp_reflect_pad_pytorch as AE_model

In [3]:
weights_path = "/home/teitxe/data/tractolearn_data/"
state_dict = torch.load(os.path.join(weights_path, "best_model_contrastive_tractoinferno_hcp.pt"), map_location=torch.device('cpu'))
net = AE_model.IncrFeatStridedConvFCUpsampReflectPadAE(32)
dummy_input = torch.randn(1, 3, 256)
net(dummy_input)[0][0][0:2]

tensor([0.0901, 0.0901], grad_fn=<SliceBackward0>)

#### Load the weights into the model and export them to ONNX

In [4]:
net.load_state_dict(state_dict["state_dict"])
onnx_file = os.path.join(weights_path, "best_model_contrastive_tractoinferno_hcp.onnx")
torch.onnx.export(net, dummy_input, onnx_file, input_names=["input"], output_names=["output"])

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


#### Load the ONNX file

In [18]:
import subprocess as sp
container = 'docker run --rm -v /home/teitxe:/workdir -w /workdir docker.io/pinto0309/onnx2tf:1.22.3 /bin/bash -c "'
tf_model_path = '/workdir/data/tractolearn_data/tf_model'
onnx_model_path = '/workdir/data/tractolearn_data/best_model_contrastive_tractoinferno_hcp.onnx'
command = f'mkdir -p {tf_model_path} && onnx2tf -i {onnx_model_path} -o {tf_model_path}"'
print(container + command)
sp.run(container + command, shell=True, check=True)

docker run --rm -v /home/teitxe:/workdir -w /workdir docker.io/pinto0309/onnx2tf:1.22.3 /bin/bash -c "mkdir -p /workdir/data/tractolearn_data/tf_model && onnx2tf -i /workdir/data/tractolearn_data/best_model_contrastive_tractoinferno_hcp.onnx -o /workdir/data/tractolearn_data/tf_model"

Simplifying...
Finish! Here is the difference:
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃            ┃ Original Model ┃ Simplified Model ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ Constant   │ 32             │ 32               │
│ Conv       │ 12             │ 12               │
│ Gemm       │ 2              │ 2                │
│ Pad        │ 12             │ 12               │
│ Relu       │ 10             │ 10               │
│ Reshape    │ 2              │ 2                │
│ Resize     │ 5              │ 5                │
│ Model Size │ 18.0MiB        │ 18.0MiB          │
└────────────┴────────────────┴──────────────────┘

Simplifying...
Finish! Here is the difference:
┏━━━━━

W0000 00:00:1716981887.040676       1 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1716981887.040735       1 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
W0000 00:00:1716981887.460485       1 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1716981887.460539       1 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.


[32mINFO:[0m [35mtf_op_type[0m: Pad
[32mINFO:[0m [34m input.1.x[0m: [34mname[0m: tf.compat.v1.gather_4/GatherV2:0 [34mshape[0m: (1, 256, 32) [34mdtype[0m: <dtype: 'float32'> 
[32mINFO:[0m [34m input.2.paddings[0m: [34mshape[0m: (3, 2) [34mdtype[0m: <dtype: 'int32'> 
[32mINFO:[0m [34m input.3.constant_value[0m: [34mval[0m: 0 
[32mINFO:[0m [34m input.4.mode[0m: [34mval[0m: reflect 
[32mINFO:[0m [34m input.5.tensor_rank[0m: [34mval[0m: 3 
[32mINFO:[0m [34m output.1.output[0m: [34mname[0m: tf.compat.v1.pad_11//decod_conv6/pad/Pad:0 [34mshape[0m: (1, 258, 32) [34mdtype[0m: <dtype: 'float32'> 

[32mINFO:[0m [32m44 / 44[0m
[32mINFO:[0m [35monnx_op_type[0m: Conv[35m onnx_op_name[0m: /decod_conv6/decod_conv6.1/Conv
[32mINFO:[0m [36m input_name.1[0m: /decod_conv6/pad/Pad_output_0 [36mshape[0m: [1, 32, 258] [36mdtype[0m: float32
[32mINFO:[0m [36m input_name.2[0m: decod_conv6.1.weight [36mshape[0m: [3, 32, 3] [36mdtype[0m:

CompletedProcess(args='docker run --rm -v /home/teitxe:/workdir -w /workdir docker.io/pinto0309/onnx2tf:1.22.3 /bin/bash -c "mkdir -p /workdir/data/tractolearn_data/tf_model && onnx2tf -i /workdir/data/tractolearn_data/best_model_contrastive_tractoinferno_hcp.onnx -o /workdir/data/tractolearn_data/tf_model"', returncode=0)

#### Try to load the model into TF

In [23]:
local_tf_model_path = '/home/teitxe/data/tractolearn_data/tf_model'
imported = tf.saved_model.load(local_tf_model_path)
f = imported.signatures

In [31]:
f

_SignatureMap({})