In [None]:
import sys
# Make the project-local packages importable. `experiments` and `models` are imported
# below; presumably they live in the parent directory and shared helpers live under
# ../../../assets — TODO confirm the repository layout.
sys.path.extend(["../", "../../../assets"])

# Silence TensorFlow's C++ log chatter. This is set before `import tensorflow` below,
# because TF reads the variable when its runtime loads. Level '3' is the most
# restrictive documented level.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef
import pandas as pd

from experiments import NonStationaryDDMExperiment
from models import MixtureRandomWalkDDM, LevyFlightDDM, RegimeSwitchingDDM

In [None]:
# GPU setting and checking.
# Enable on-demand memory growth on every detected GPU, so TensorFlow does not reserve
# all device memory up front. The loop is a no-op on CPU-only machines; the original
# `physical_devices[0]` raised IndexError there and broke Restart & Run All.
# Setting the same growth value again is harmless, so re-running this cell is safe.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)
if not physical_devices:
    print("No GPU detected; TensorFlow will run on the CPU.")
print(physical_devices)

# Constants

In [None]:
# Per-model training switches. True runs `experiment.run(...)` in that model's training
# cell. False skips training and plots the loss history already held by the experiment's
# trainer (presumably restored from its checkpoint — confirm).
TRAIN_LEVY_FLIGHT_DDM = TRAIN_REGIME_SWITCHING_DDM = True
TRAIN_MIXTURE_RANDOM_WALK_DDM = False

In [None]:
def _make_experiment(model, checkpoint_path):
    """Wrap ``model`` in a ``NonStationaryDDMExperiment`` that uses ``checkpoint_path``."""
    return NonStationaryDDMExperiment(model, checkpoint_path=checkpoint_path)


# One model/experiment pair per non-stationary DDM variant, each with its own checkpoint
# path. Pairs are built in the original order: model, then its experiment.
mixture_random_walk_model = MixtureRandomWalkDDM()
mixture_random_walk_experiment = _make_experiment(
    mixture_random_walk_model, "../checkpoints/mixture_random_walk_ddm"
)

levy_flight_model = LevyFlightDDM()
levy_flight_experiment = _make_experiment(
    levy_flight_model, "../checkpoints/levy_flight_ddm"
)

regime_switching_model = RegimeSwitchingDDM()
regime_switching_experiment = _make_experiment(
    regime_switching_model, "../checkpoints/regime_switching_ddm"
)

# Training

In [None]:
%%time
# Train the mixture random walk DDM, or, when its switch is off, reuse the loss history
# already held by the experiment's trainer. Only the selected branch is evaluated.
mixture_random_walk_history = (
    mixture_random_walk_experiment.run(
        epochs=75,
        iterations_per_epoch=1000,
        batch_size=16,
    )
    if TRAIN_MIXTURE_RANDOM_WALK_DDM
    else mixture_random_walk_experiment.trainer.loss_history.get_plottable()
)

In [None]:
# Loss curves for the mixture random walk DDM. The model-specific title keeps this
# figure distinguishable from the other two loss plots in the notebook.
f = beef.diagnostics.plot_losses(mixture_random_walk_history)
f.suptitle("Mixture random walk DDM")
f.tight_layout()  # re-run layout so the new title does not overlap the panels

In [None]:
%%time
# Train the Levy flight DDM, or, when its switch is off, reuse the loss history
# already held by the experiment's trainer. Only the selected branch is evaluated.
levy_flight_history = (
    levy_flight_experiment.run(
        epochs=75,
        iterations_per_epoch=1000,
        batch_size=16,
    )
    if TRAIN_LEVY_FLIGHT_DDM
    else levy_flight_experiment.trainer.loss_history.get_plottable()
)

In [None]:
# Loss curves for the Levy flight DDM. The model-specific title keeps this figure
# distinguishable from the other two loss plots in the notebook.
f = beef.diagnostics.plot_losses(levy_flight_history)
f.suptitle("Levy flight DDM")
f.tight_layout()  # re-run layout so the new title does not overlap the panels

In [None]:
%%time
# Train the regime switching DDM, or, when its switch is off, reuse the loss history
# already held by the experiment's trainer. Only the selected branch is evaluated.
regime_switching_history = (
    regime_switching_experiment.run(
        epochs=75,
        iterations_per_epoch=1000,
        batch_size=16,
    )
    if TRAIN_REGIME_SWITCHING_DDM
    else regime_switching_experiment.trainer.loss_history.get_plottable()
)

In [None]:
# Loss curves for the regime switching DDM. The model-specific title keeps this figure
# distinguishable from the other two loss plots in the notebook.
f = beef.diagnostics.plot_losses(regime_switching_history)
f.suptitle("Regime switching DDM")
f.tight_layout()  # re-run layout so the new title does not overlap the panels