In [1]:
import os
os.chdir('..')

import numpy as np
import json
from systems import LotkaVolterra
from lightning_module import PreTrainLightning
from models import TSMVAE

from inference_utils import load_pretrained_model

In [2]:
model = load_pretrained_model(
    model_class=TSMVAE,
    lightning_class=PreTrainLightning,
    checkpoint_substr="TSMVAE",
    in_chans=2,
    folder_name='/home/jp4474/latent-abc-smc/lotka_d64_ed32_6_4_4_4_ae_mask_0.15_noise_0.0',)

Successfully loaded model


In [3]:
model.to('cuda')

PreTrainLightning(
  (model): TSMVAE(
    (embedder): Linear(in_features=2, out_features=64, bias=True)
    (blocks): ModuleList(
      (0-5): 6 x Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=64, out_features=192, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop2): Dropout

In [None]:
lotka_abc = LotkaVolterra(model=model,)

INFO:viaABC:Initializing ViaABC class
INFO:viaABC:Model updated
INFO:viaABC:Initialization complete
INFO:viaABC:LatentABCSMC class initialized with the following parameters:
INFO:viaABC:num_parameters: 2
INFO:viaABC:Mu: [0 0]
INFO:viaABC:Sigma: [10 10]
INFO:viaABC:t0: 0
INFO:viaABC:tmax: 15
INFO:viaABC:time_space: [ 1.1  2.4  3.9  5.6  7.5  9.6 11.9 14.4]
INFO:viaABC:pooling_method: no_cls
INFO:viaABC:metric: pairwise_cosine


In [6]:
lotka_abc.run(num_particles=1000, k=10)

INFO:viaABC:Starting ABC PMC run with Q Threshold: 0.99
INFO:viaABC:Initialization (generation 0) started
INFO:viaABC:Initialization completed in 36.11 seconds
INFO:viaABC:Mean: [3.2584414  1.97469638]
INFO:viaABC:Median: [1.73838092 1.66731045]
INFO:viaABC:Variance: [6.46374809 1.62934223]
INFO:viaABC:Generation 1 started
100%|██████████| 1000/1000 [02:46<00:00,  5.99it/s]
INFO:viaABC:ABC-SMC: Epsilon : 0.11074
INFO:viaABC:ABC-SMC: Quantile : 0.87116
INFO:viaABC:ABC-SMC: Simulations : 4286
INFO:viaABC:Mean: [3.27587416 1.89151777]
INFO:viaABC:Median: [1.79121897 1.59452189]
INFO:viaABC:Variance: [6.11477868 1.23286953]
INFO:viaABC:Generation 1 completed in 202.09 seconds
INFO:viaABC:Generation 2 started
100%|██████████| 1000/1000 [02:58<00:00,  5.61it/s]
INFO:viaABC:ABC-SMC: Epsilon : 0.10846
INFO:viaABC:ABC-SMC: Quantile : 0.76539
INFO:viaABC:ABC-SMC: Simulations : 4551
INFO:viaABC:Mean: [2.91435163 1.82857775]
INFO:viaABC:Median: [1.57745099 1.6091929 ]
INFO:viaABC:Variance: [5.7404

KeyboardInterrupt: 