In [1]:
import os, sys

# Extend the module search path so that the imports below resolve:
#   * two directories up, which is what makes the `src.python.*` imports work
#   * presumably a local development checkout of BayesFlow -- TODO confirm
sys.path.append(os.path.abspath("../.."))
sys.path.append(os.path.abspath("../../../BayesFlow_dev/BayesFlow/"))

from src.python.helpers import MaskingConfigurator
from src.python.settings import summary_meta_diffusion, probability_meta_diffusion
from src.python.training import setup_network

import bayesflow as bf
import numpy as np
from functools import partial
import pickle
import shutil

  from tqdm.autonotebook import tqdm


# Fine-tuning with 900 trials per person

## Load in data

In [2]:
# Fine-tuning (training) data, unpickled from finetune.pkl.
# Pickle files can execute code on load, so only load files produced by this project.
train_path = os.path.abspath("../../data/03_levy_flight_application/simulated_data/finetune.pkl")
with open(train_path, mode="rb") as file:
    train_data = pickle.load(file)

In [3]:
# Validation data, unpickled from validate.pkl.
# Pickle files can execute code on load, so only load files produced by this project.
val_path = os.path.abspath("../../data/03_levy_flight_application/simulated_data/validate.pkl")
with open(val_path, mode="rb") as file:
    val_data = pickle.load(file)

## Training

In [5]:
# Fine-tuning hyperparameters, collected here so they can be tuned in one place.
N_EPOCHS = 30         # passed as `epochs` to trainer.train_offline
LEARNING_RATE = 5e-5  # passed as `default_lr` to bf.trainers.Trainer
BATCH_SIZE = 32       # passed as `batch_size` to trainer.train_offline

In [6]:
# Loss handed to setup_network: BayesFlow's log loss with label_smoothing fixed to None.
# NOTE(review): None presumably switches label smoothing off -- confirm against the
# bf.losses.log_loss signature of the BayesFlow version in use.
unsmoothed_log_loss = partial(bf.losses.log_loss, label_smoothing=None)

# setup_network returns the summary network, the probability network and the amortizer,
# configured from the meta-diffusion settings imported from src.python.settings.
summary_net, probability_net, amortizer = setup_network(
    summary_net_settings=summary_meta_diffusion,
    inference_net_settings=probability_meta_diffusion,
    loss_fun=unsmoothed_log_loss,
)

In [8]:
# Fine-tuning continues from the pretrained networks: the pretraining checkpoints are
# copied into a separate folder so that fine-tuning does not modify the originals.
pretrain_path = "checkpoints/pretrain"
checkpoint_path = "checkpoints/finetune"

# shutil.copytree raises FileExistsError when the destination already exists, which made
# this cell fail on every re-run (e.g. Restart Kernel -> Run All). Copy only when the
# fine-tuning folder is missing; an existing folder is reused as-is so checkpoints
# already written there are not overwritten by the pretraining ones.
if not os.path.exists(checkpoint_path):
    shutil.copytree(pretrain_path, checkpoint_path)

# The trainer reads from / writes to checkpoint_path and uses the masking configurator
# to prepare each batch before it is passed to the amortizer.
# NOTE(review): the log output recorded below this cell reports "Initialized networks
# from scratch." -- presumably that means no checkpoint was restored in that run.
# Verify after execution that the pretrained weights are actually picked up.
masking_configurator = MaskingConfigurator()
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    configurator=masking_configurator,
    checkpoint_path=checkpoint_path,
    default_lr=LEARNING_RATE,
)

INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.


In [9]:
# Offline fine-tuning on the pre-simulated training data; val_data is supplied as
# validation_sims, and the recorded log reports a validation loss after each epoch.
# sim_dataset_args is handed over as an ordinary keyword argument.
losses = trainer.train_offline(
    simulations_dict=train_data,
    validation_sims=val_data,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    sim_dataset_args={"batch_on_cpu": True},
)

Training epoch 1:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 1, Loss: 0.303


Training epoch 2:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 2, Loss: 0.267


Training epoch 3:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 3, Loss: 0.297


Training epoch 4:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 4, Loss: 0.368


Training epoch 5:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 5, Loss: 0.216


Training epoch 6:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 6, Loss: 0.204


Training epoch 7:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 7, Loss: 0.258


Training epoch 8:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 8, Loss: 0.284


Training epoch 9:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 9, Loss: 0.191


Training epoch 10:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 10, Loss: 0.227


Training epoch 11:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 11, Loss: 0.215


Training epoch 12:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 12, Loss: 0.231


Training epoch 13:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 13, Loss: 0.219


Training epoch 14:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 14, Loss: 0.259


Training epoch 15:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 15, Loss: 0.234


Training epoch 16:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 16, Loss: 0.241


Training epoch 17:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 17, Loss: 0.231


Training epoch 18:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 18, Loss: 0.198


Training epoch 19:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 19, Loss: 0.288


Training epoch 20:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 20, Loss: 0.235


Training epoch 21:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 21, Loss: 0.198


Training epoch 22:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 22, Loss: 0.205


Training epoch 23:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 23, Loss: 0.184


Training epoch 24:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 24, Loss: 0.234


Training epoch 25:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 25, Loss: 0.222


Training epoch 26:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 26, Loss: 0.226


Training epoch 27:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 27, Loss: 0.249


Training epoch 28:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 28, Loss: 0.228


Training epoch 29:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 29, Loss: 0.266


Training epoch 30:   0%|          | 0/250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 30, Loss: 0.261
