In [1]:
import os, sys
sys.path.append(os.path.abspath(os.path.join('../..'))) # access sibling directories
sys.path.append("C:\\Users\\lasse\\Documents\\GitHub\\BayesFlow")

from src.python.settings import summary_meta_diffusion, evidence_meta_diffusion
from src.python.helpers import load_simulated_rt_data, mask_inputs
from src.python.networks import HierarchicalInvariantNetwork, EvidentialNetwork

import numpy as np
from tensorflow.keras.experimental import CosineDecayRestarts
from tensorflow.keras.optimizers import Adam
from functools import partial

from bayesflow.trainers import ModelComparisonTrainer
from bayesflow.amortizers import MultiModelAmortizer 
from bayesflow.losses import log_loss

# Fine-tuning with 900 trials per person

## Load in data

In [2]:
# Folder holding the simulated Levy-flight training data (uniform alpha prior run).
levy_sims_folder = "c:\\Users\\lasse\\documents\\hierarchical model comparison project\\data\\03_levy_flight_application\\uniform_alpha_prior"

# .npy files with the model indices and the simulated datasets (900 trials per person).
indices_900_filename, datasets_900_filename = (
    "train_indices_900_trials.npy",
    "train_datasets_900_trials.npy",
)

# Load both arrays via the project helper; the names are reused by later cells.
indices_900, datasets_900 = load_simulated_rt_data(
    levy_sims_folder,
    indices_900_filename,
    datasets_900_filename,
)

## Training

When running only the fine-tuning step, first manually move or delete any existing fine-tuning checkpoints from the checkpoints folder so that training resumes from the pretrained network.

In [3]:
# Networks: the summary network and the evidential network are combined by the
# multi-model amortizer (construction order kept as-is).
summary_net = HierarchicalInvariantNetwork(summary_meta_diffusion)
evidence_net = EvidentialNetwork(evidence_meta_diffusion)
amortizer = MultiModelAmortizer(evidence_net, summary_net)

# Learning-rate schedule: cosine decay with warm restarts.
initial_lr = 0.00005     # LR shrunk by a factor of 10 for fine-tuning
first_decay_steps = 250  # length (in optimizer steps) of the first decay period
t_mul = 2                # factor applied to the period length after each restart
m_mul = 0.8              # factor applied to the starting LR of each restart
alpha = 0.2              # minimum LR as a fraction of initial_lr
lr_schedule_restart = CosineDecayRestarts(
    initial_lr,
    first_decay_steps,
    t_mul=t_mul,
    m_mul=m_mul,
    alpha=alpha,
)

# Checkpoint folder: the trainer loads an existing checkpoint from here on
# construction and is intended to save the fine-tuned network here as well.
checkpoint_path = "c:\\Users\\lasse\\documents\\hierarchical model comparison project\\checkpoints\\03_levy_flight_application\\uniform_alpha_prior\\fine-tuned_net"

# Loss and optimizer are passed as partials; the trainer supplies the rest.
trainer = ModelComparisonTrainer(
    network=amortizer,
    loss=partial(log_loss, kl_weight=0.25),
    optimizer=partial(Adam, lr_schedule_restart),
    checkpoint_path=checkpoint_path,
    skip_checks=True,
)

TRAINER INITIALIZATION: No generative model provided. Only offline learning mode is available!
Networks loaded from c:\Users\lasse\documents\hierarchical model comparison project\checkpoints\03_levy_flight_application\uniform_alpha_prior\fine-tuned_net\ckpt-64


In [4]:
# Mask some training data so that training leads to a robust net that can handle
# missing data. The masking is redone before every epoch, so (assuming
# mask_inputs draws its missing pattern at random) each epoch sees a new pattern.
epochs = 32
missings_mean = 28.5  # forwarded to mask_inputs(missings_mean=...)
missings_sd = 13.5    # forwarded to mask_inputs(missings_sd=...)

# Keep the losses returned for every epoch instead of only the last one.
# `losses` still refers to the final epoch's losses after the loop.
loss_history = []
for epoch in range(epochs):
    datasets_900_masked = mask_inputs(
        datasets_900,
        missings_mean=missings_mean,
        missings_sd=missings_sd,
        missing_rts_equal_mean=True,
    )
    # One epoch per call so that the freshly masked data is used each time.
    losses = trainer.train_offline(epochs=1, batch_size=32,
                                   model_indices=indices_900, sim_data=datasets_900_masked)
    loss_history.append(losses)
    print(f"epoch {epoch} finished.")

Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 0 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 1 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 2 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 3 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 4 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 5 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 6 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 7 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 8 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 9 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 10 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 11 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 12 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 13 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 14 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 15 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 16 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 17 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 18 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 19 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 20 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 21 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 22 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 23 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 24 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 25 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 26 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 27 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 28 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 29 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 30 finished.
Converting 8000 simulations to a TensorFlow data set...


Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

epoch 31 finished.


Results for the run with the truncated-normal alpha prior:
- training time: 79 minutes
- running loss after 32 epochs: 0.371

Results for the run with the uniform alpha prior:
- training time: 81 minutes
- running loss after 32 epochs: 0.184