In [2]:
import os, sys

sys.path.append(os.path.abspath(os.path.join("../..")))  # access sibling directories
sys.path.append(os.path.abspath(os.path.join("../../../BayesFlow_dev/BayesFlow/")))

from src.python.settings import summary_meta_diffusion, probability_meta_diffusion
from src.python.training import setup_network
import bayesflow as bf
import numpy as np
from functools import partial
import pickle

# Pre-training with 100 trials per person

## Load the pre-simulated training and validation data

In [3]:
# Training data: pickled pre-simulated datasets used for offline pre-training.
# os.path.abspath resolves the relative path against the current working
# directory, so this presumably assumes the kernel runs from the notebook's folder.
# NOTE(review): pickle.load can execute arbitrary code — only load files that
# this project generated itself.
train_path = os.path.abspath("../../data/03_levy_flight_application/simulated_data/pretrain.pkl")
with open(train_path, "rb") as file:
    train_data = pickle.load(file)

In [4]:
# Validation data: pickled simulations passed as `validation_sims` to the trainer,
# which reports a validation loss after every epoch.
# os.path.abspath resolves the relative path against the current working
# directory, so this presumably assumes the kernel runs from the notebook's folder.
# NOTE(review): pickle.load can execute arbitrary code — only load files that
# this project generated itself.
val_path = os.path.abspath("../../data/03_levy_flight_application/simulated_data/validate.pkl")
with open(val_path, "rb") as file:
    val_data = pickle.load(file)

## Training

In [5]:
# Hyperparameters for offline pre-training.
N_EPOCHS = 20          # number of training epochs passed to train_offline
BATCH_SIZE = 32        # batch size passed to train_offline
LEARNING_RATE = 5e-4   # handed to the Trainer as its default learning rate

In [6]:
# Build the networks via the project helper `setup_network`, configured by the
# settings imported from src.python.settings. It returns three objects, unpacked
# here as the summary network, the inference (probability) network and the amortizer.
# The loss is BayesFlow's log loss with label_smoothing=None (presumably no smoothing).
summary_net, probability_net, amortizer = setup_network(
    inference_net_settings=probability_meta_diffusion,
    summary_net_settings=summary_meta_diffusion,
    loss_fun=partial(bf.losses.log_loss, label_smoothing=None),
)

In [7]:
# Trainer without a generative model, so only offline training is possible.
# Checkpoints live under a CWD-relative directory.
# NOTE(review): if that directory already holds checkpoints, the Trainer may
# resume from them instead of starting fresh. The log of this run reports
# "Initialized networks from scratch". Confirm before re-running end to end.
checkpoint_path = "checkpoints/pretrain"
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    default_lr=LEARNING_RATE,
    checkpoint_path=checkpoint_path,
)

INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.


In [8]:
# Offline training on the pre-simulated data. The validation simulations are
# evaluated after each epoch (see the "Validation, Epoch" log lines below).
# `losses` holds whatever train_offline returns (presumably the loss history).
# `sim_dataset_args` is passed as a plain keyword argument. This is equivalent to
# the previous `**{"sim_dataset_args": ...}` literal-dict unpacking, but readable.
losses = trainer.train_offline(
    simulations_dict=train_data,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    validation_sims=val_data,
    # forwarded to BayesFlow's simulation dataset; presumably keeps batches on the CPU
    sim_dataset_args={"batch_on_cpu": True},
)

Training epoch 1:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 1, Loss: 1.175


Training epoch 2:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 2, Loss: 0.917


Training epoch 3:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 3, Loss: 0.650


Training epoch 4:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 4, Loss: 1.073


Training epoch 5:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 5, Loss: 0.545


Training epoch 6:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 6, Loss: 0.593


Training epoch 7:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 7, Loss: 0.741


Training epoch 8:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 8, Loss: 0.462


Training epoch 9:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 9, Loss: 0.412


Training epoch 10:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 10, Loss: 0.340


Training epoch 11:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 11, Loss: 0.372


Training epoch 12:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 12, Loss: 0.490


Training epoch 13:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 13, Loss: 0.341


Training epoch 14:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 14, Loss: 0.230


Training epoch 15:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 15, Loss: 0.428


Training epoch 16:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 16, Loss: 0.264


Training epoch 17:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 17, Loss: 0.324


Training epoch 18:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 18, Loss: 0.342


Training epoch 19:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 19, Loss: 0.313


Training epoch 20:   0%|          | 0/1250 [00:00<?, ?it/s]

INFO:root:Validation, Epoch: 20, Loss: 0.310
