In [1]:
import os, sys
# Add the directory two levels above the current working directory to the
# import path (presumably the project root that holds the `src` package used
# by the imports below -- confirm against the repository layout).
sys.path.append(os.path.abspath(os.path.join('../..'))) # access sibling directories
# NOTE(review): machine-specific absolute path to a local BayesFlow checkout.
# It only exists on the original author's Windows machine; adjust it (or
# install BayesFlow into the environment) before running elsewhere.
sys.path.append("C:\\Users\\lasse\\Documents\\GitHub\\BayesFlow")

from src.python.settings import summary_meta_diffusion, probability_meta_diffusion
from src.python.helpers import load_simulated_rt_data, mask_inputs
from src.python.networks import HierarchicalInvariantNetwork, ModelProbabilityNetwork
from src.python.losses import softmax_loss

import numpy as np
from tensorflow.keras.experimental import CosineDecay
from tensorflow.keras.optimizers import Adam
from functools import partial

from bayesflow.trainers import ModelComparisonTrainer
from bayesflow.amortizers import MultiModelAmortizer 

# Fine-tuning with 900 trials per person

## Load in data

In [4]:
# Folder that holds the simulated training data for the 900-trials setting.
# NOTE(review): absolute, machine-specific Windows path -- must be adapted
# when the notebook is run on another machine.
levy_sims_folder = "c:\\Users\\lasse\\documents\\hierarchical_model_comparison_project\\data\\03_levy_flight_application\\truncnormal_alpha_prior"

# File names (inside `levy_sims_folder`) of the model indices and of the
# simulated data sets used for fine-tuning.
indices_900_filename = "train_indices_900_trials.npy"
datasets_900_filename = "train_datasets_900_trials.npy"

# Load both arrays via the project helper. `indices_900` is later passed to the
# trainer as `model_indices` and `datasets_900` (after masking) as `sim_data`.
indices_900, datasets_900 = load_simulated_rt_data(levy_sims_folder, indices_900_filename, datasets_900_filename)

## Training

When running only the fine-tuning stage: first manually move or delete any existing fine-tuning checkpoints from the checkpoints folder, so that training resumes from the pretrained network.

In [5]:
# Build the two networks from the settings objects imported from
# `src.python.settings`, then wrap them in a single amortizer that is handed
# to the trainer below.
summary_net = HierarchicalInvariantNetwork(summary_meta_diffusion)
probability_net = ModelProbabilityNetwork(probability_meta_diffusion)
# Argument order is (probability network, summary network).
amortizer = MultiModelAmortizer(probability_net, summary_net)

In [16]:
# Training steps
epochs = 30
n_datasets = datasets_900.shape[0]
batch_size = 32
# Optimizer steps per epoch as an integer. Rounding up also counts a final,
# partially filled batch when n_datasets is not a multiple of batch_size
# (assumes the trainer does not drop that remainder batch -- TODO confirm).
iterations_per_epoch = int(np.ceil(n_datasets / batch_size))

# CAREFUL: cosine decay will take previous training steps into account,
# because the optimizer's step counter is restored together with the
# pretrained checkpoint.
# -> add the number of pretraining steps so the decay reaches zero exactly at
#    the end of fine-tuning: epochs * (data sets / batch size)
# Check the starting lr with trainer.optimizer._decayed_lr(tf.float32)
# (requires `import tensorflow as tf`, which this notebook does not import).
pretraining_epochs = 20
pretraining_n_datasets = 40000
pretraining_batch_size = 32
pretraining_steps = pretraining_epochs * int(np.ceil(pretraining_n_datasets / pretraining_batch_size))

# Cosine decaying learning rate
initial_lr = 0.0005
# Total schedule length = steps already taken in pretraining + fine-tuning steps.
decay_steps = epochs * iterations_per_epoch + pretraining_steps
alpha = 0  # final learning rate as a fraction of initial_lr (0 -> decays to zero)
lr_schedule = CosineDecay(
    initial_lr, decay_steps, alpha=alpha
    )

# Checkpoint path for loading pretrained network and saving the final network
# NOTE(review): absolute, machine-specific Windows path.
checkpoint_path = "c:\\Users\\lasse\\documents\\hierarchical_model_comparison_project\\checkpoints\\03_levy_flight_application\\truncnormal_alpha_prior\\pre-trained_net"

# No generative model is passed, so only offline training is available.
# `partial(softmax_loss)` binds no arguments and therefore calls through to
# softmax_loss unchanged; `partial(Adam, lr_schedule)` defers optimizer
# construction to the trainer, with the schedule as the learning rate.
trainer = ModelComparisonTrainer(
    network=amortizer,
    loss=partial(softmax_loss),
    optimizer=partial(Adam, lr_schedule),
    checkpoint_path=checkpoint_path,
    skip_checks=True
    )

TRAINER INITIALIZATION: No generative model provided. Only offline learning mode is available!
Networks loaded from c:\Users\lasse\documents\hierarchical_model_comparison_project\checkpoints\03_levy_flight_application\truncnormal_alpha_prior\pre-trained_net\ckpt-20


In [7]:
# Mask some training data so that training leads to a robust net that can handle missing data

# Loss history of every fine-tuning epoch; `losses` alone would only keep the
# history returned by the last train_offline call.
epoch_losses = []

for epoch in range(epochs):
    # Re-mask the unmasked data at the start of every epoch (presumably the
    # mask is drawn randomly from missings_mean / missings_sd, so each epoch
    # sees a different missingness pattern -- confirm in src.python.helpers).
    datasets_900_masked = mask_inputs(datasets_900, missings_mean=28.5, missings_sd=13.5, missing_rts_equal_mean=True)
    # Train one epoch at a time so the data can be re-masked in between.
    # Use the shared `batch_size` so the number of steps taken stays in sync
    # with the step count assumed by the cosine learning-rate schedule.
    losses = trainer.train_offline(epochs=1, batch_size=batch_size,
                                   model_indices=indices_900, sim_data=datasets_900_masked)
    epoch_losses.append(losses)
    print(f"epoch {epoch+1} finished.")

tf.Tensor(6.287231e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(5.8921072e-05, shape=(), dtype=float32)
epoch 1 finished.
tf.Tensor(5.8921072e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(5.50814e-05, shape=(), dtype=float32)
epoch 2 finished.
tf.Tensor(5.50814e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(5.135557e-05, shape=(), dtype=float32)
epoch 3 finished.
tf.Tensor(5.135557e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(4.7745736e-05, shape=(), dtype=float32)
epoch 4 finished.
tf.Tensor(4.7745736e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(4.4254022e-05, shape=(), dtype=float32)
epoch 5 finished.
tf.Tensor(4.4254022e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(4.0882456e-05, shape=(), dtype=float32)
epoch 6 finished.
tf.Tensor(4.0882456e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(3.7633003e-05, shape=(), dtype=float32)
epoch 7 finished.
tf.Tensor(3.7633003e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(3.4507575e-05, shape=(), dtype=float32)
epoch 8 finished.
tf.Tensor(3.4507575e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(3.1507985e-05, shape=(), dtype=float32)
epoch 9 finished.
tf.Tensor(3.1507985e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.8635981e-05, shape=(), dtype=float32)
epoch 10 finished.
tf.Tensor(2.8635981e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.5893272e-05, shape=(), dtype=float32)
epoch 11 finished.
tf.Tensor(2.5893272e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.3281396e-05, shape=(), dtype=float32)
epoch 12 finished.
tf.Tensor(2.3281396e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.0801963e-05, shape=(), dtype=float32)
epoch 13 finished.
tf.Tensor(2.0801963e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.8456325e-05, shape=(), dtype=float32)
epoch 14 finished.
tf.Tensor(1.8456325e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.6245947e-05, shape=(), dtype=float32)
epoch 15 finished.
tf.Tensor(1.6245947e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.4172033e-05, shape=(), dtype=float32)
epoch 16 finished.
tf.Tensor(1.4172033e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.2235881e-05, shape=(), dtype=float32)
epoch 17 finished.
tf.Tensor(1.2235881e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.0438532e-05, shape=(), dtype=float32)
epoch 18 finished.
tf.Tensor(1.0438532e-05, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(8.781106e-06, shape=(), dtype=float32)
epoch 19 finished.
tf.Tensor(8.781106e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(7.26454e-06, shape=(), dtype=float32)
epoch 20 finished.
tf.Tensor(7.26454e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(5.889729e-06, shape=(), dtype=float32)
epoch 21 finished.
tf.Tensor(5.889729e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(4.6574623e-06, shape=(), dtype=float32)
epoch 22 finished.
tf.Tensor(4.6574623e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(3.5684855e-06, shape=(), dtype=float32)
epoch 23 finished.
tf.Tensor(3.5684855e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.6233943e-06, shape=(), dtype=float32)
epoch 24 finished.
tf.Tensor(2.6233943e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.8227846e-06, shape=(), dtype=float32)
epoch 25 finished.
tf.Tensor(1.8227846e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(1.1670888e-06, shape=(), dtype=float32)
epoch 26 finished.
tf.Tensor(1.1670888e-06, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(6.567091e-07, shape=(), dtype=float32)
epoch 27 finished.
tf.Tensor(6.567091e-07, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(2.9194356e-07, shape=(), dtype=float32)
epoch 28 finished.
tf.Tensor(2.9194356e-07, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(7.300079e-08, shape=(), dtype=float32)
epoch 29 finished.
tf.Tensor(7.300079e-08, shape=(), dtype=float32)
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

tf.Tensor(0.0, shape=(), dtype=float32)
epoch 30 finished.


Results with the softmax loss

30 epochs of fine-tuning:
- training time = 74 minutes
- running loss after 30 epochs = 0.153 (0.28 after the first epoch)