In [1]:
import copy
import os
import pickle
import re
import subprocess
import sys
import warnings
from datetime import datetime
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from src.config.core import Config
from src.config.sampler import Sampler
from src.config.data import DatasetType
import src.dataset as ds
from src.models.tabular import FCN
import src.training.utils as train_utils
import src.inference.utils as inf_utils
import src.visualization as viz
from src.config.data import Task
from src.inference.evaluation import evaluate_bde

# Kernel working directory. `train.py` and the relative `experiments/` and
# `results/` paths used below are all resolved against it.
DIR = os.fspath(Path.cwd())
DIR

2025-06-17 21:20:52,968 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:20:52,970 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:20:53,371 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:20:53,372 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:20:53,373 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


'/home/bszh/MILE'

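### Config template

Base configuration for the bike-sharing regression experiments. `get_config` derives one config per run from it and writes each run's config to YAML.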

In [2]:
# Base experiment template for the bike-sharing regression runs.
# NOTE: `get_config` replaces the whole `training.sampler` section (as well as
# experiment_name, rng and saving_dir), so the sampler values below only take
# effect if this template is used directly.
CONFIG_DICT = dict(
    saving_dir='results/',
    experiment_name='bike',
    data=dict(
        path='data/bikesharing.data',
        source='local',
        data_type='tabular',
        task='regr',
        target_column=None,
        target_len=1,
        features=None,
        datapoint_limit=None,
        normalize=True,
        # train / validation / test fractions of the dataset
        train_split=0.7,
        valid_split=0.1,
        test_split=0.2,
    ),
    model=dict(
        model='FCN',
        hidden_structure=[16, 16, 16, 2],
        activation='relu',
        use_bias=True,
    ),
    training=dict(
        # Warm-start is disabled; its optimizer entry is a meaningless placeholder.
        warmstart=dict(
            include=False,
            optimizer_config=dict(name="sgd", parameters={}),
        ),
        sampler=dict(
            name='sgld',
            warmup_steps=0,
            n_chains=4,
            n_samples=24000,  # total steps
            batch_size=512,
            step_size_init=2.0e-6,  # step_size_explore
            n_thinning=1,
            keep_warmup=False,
            optimizer_name='sgd',
            prior_config=dict(name='StandardNormal'),
            scheduler_config=dict(
                name='Cyclical',
                n_samples_per_cycle=200,
                parameters=dict(n_cycles=4),
            ),
        ),
    ),
    rng=1446,
    logging=False,
)

In [3]:
def get_config(
        exp_name: str = 'bike',
        n_chains: int = 4,
        n_cycles: int = 4,
        n_steps_per_cycle: int = 2000,
        n_samples_per_cycle: int = 200,
        n_thinning: int = 1,
        optimizer_name: str = 'adam',
        scheduler_name: str = 'Cyclical',
        step_size_init: float = 2.0e-6,
        step_size_sampling: float | None = None,
        seed: int = 0,
        batch_size: int = 512,
        config_dir: str | Path = 'experiments/csgld',
    ) -> Path:
    """Derive a per-experiment config from ``CONFIG_DICT`` and write it as YAML.

    The template is deep-copied, so calling this function never mutates
    ``CONFIG_DICT``. The experiment name, the whole ``training.sampler``
    section, ``rng`` and ``saving_dir`` are overridden; every other entry
    (data, model, warm-start) is taken from the template unchanged.

    Args:
        exp_name: Experiment name. Also used as the YAML file stem relative to
            ``config_dir``; it may contain ``/`` to create sub-directories.
        n_chains: Number of sampler chains.
        n_cycles: Number of scheduler cycles (``scheduler_config.parameters.n_cycles``).
        n_steps_per_cycle: Sampler steps per cycle. The total number of steps
            written as ``n_samples`` is ``n_cycles * n_steps_per_cycle``.
        n_samples_per_cycle: Value for ``scheduler_config.n_samples_per_cycle``.
        n_thinning: Thinning factor passed to the sampler.
        optimizer_name: Optimizer name passed to the sampler.
        scheduler_name: Step-size scheduler name (e.g. ``'Cyclical'``, ``'Constant'``).
        step_size_init: Initial ("explore") step size.
        step_size_sampling: Step size for the sampling phase. It is written
            as-is, including ``None``, into the scheduler parameters.
        seed: Value for the top-level ``rng`` entry.
        batch_size: Mini-batch size for the sampler (defaults to the previous
            hard-coded value of 512).
        config_dir: Directory the YAML is written under.

    Returns:
        Path of the written YAML file (``config_dir / f'{exp_name}.yaml'``).
    """
    n_samples = n_cycles * n_steps_per_cycle  # total sampler steps

    # Deep copy: a shallow ``.copy()`` shares the nested 'training' dict with
    # the template, so replacing 'sampler' below would silently mutate CONFIG_DICT.
    new_config_dict = copy.deepcopy(CONFIG_DICT)
    new_config_dict['experiment_name'] = exp_name
    new_config_dict['training']['sampler'] = {
        'name': 'sgld',
        'warmup_steps': 0,
        'keep_warmup': False,
        'n_chains': n_chains,
        'n_samples': n_samples,  # total steps
        'batch_size': batch_size,
        'step_size_init': step_size_init,  # step_size_explore
        'n_thinning': n_thinning,
        'optimizer_name': optimizer_name,
        'prior_config': {
            'name': 'StandardNormal'
        },
        'scheduler_config': {
            'name': scheduler_name,
            'n_samples_per_cycle': n_samples_per_cycle,
            'parameters': {
                'n_cycles': n_cycles,
                'step_size_sampling': step_size_sampling
            }
        }
    }
    new_config_dict['rng'] = seed
    new_config_dict['saving_dir'] = 'results/'

    config_path = Path(config_dir) / f'{exp_name}.yaml'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    config_path.parent.mkdir(parents=True, exist_ok=True)
    Config.from_dict(new_config_dict).to_yaml(config_path)

    return config_path

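### Constant step-size schedule

Parallel vs. sequential sampling with a constant step size. The parallel setting runs `n_chains` chains for a single cycle; the sequential setting runs one chain for `n_cycles` cycles. Each setting is tried with 2, 4, 8 and 12 chains/cycles and 3 seeds.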

In [4]:
# Parallel setting: `n_chains` chains run side by side, each for a single
# constant-step-size cycle of 2500 steps; N_SEEDS seeds per chain count.
N_DEVICES = 12  # forwarded to train.py as `-d`
N_SEEDS = 3

chains_cycles = [2, 4, 8, 12]
config_paths = []
for n_chains in chains_cycles:
    for seed in range(N_SEEDS):
        exp_name = f'bike/parallel_constant_{n_chains}_seed{seed}'
        config_path = get_config(
            exp_name=exp_name,
            n_chains=n_chains,
            n_cycles=1,
            n_steps_per_cycle=2500,
            n_samples_per_cycle=500,
            n_thinning=10,
            optimizer_name='adam',
            scheduler_name='Constant',
            step_size_init=0.01,
            step_size_sampling=1.0e-8,
            seed=seed
        )
        config_paths.append(config_path)

# Run each experiment in a fresh interpreter. `sys.executable` ensures the
# kernel's own environment is used (a bare 'python' depends on PATH), and
# non-zero exit codes are collected instead of being silently ignored.
failed_runs = []
for config_path in config_paths:
    print("=" * 50)
    print(f'Running training for config: {config_path}')
    result = subprocess.run(
        [sys.executable, 'train.py', '-c', str(config_path), '-d', str(N_DEVICES)]
    )
    if result.returncode != 0:
        failed_runs.append(config_path)
        print(f'Training FAILED (exit code {result.returncode}): {config_path}')

print(f'{len(config_paths) - len(failed_runs)}/{len(config_paths)} runs succeeded')
for config_path in failed_runs:
    print(f'  failed: {config_path}')

Running training for config: experiments/csgld/bike/parallel_constant_2_seed0.yaml
2025-06-17 21:20:54,104 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:20:54,639 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:20:54,640 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:20:55,029 - __main__ - INFO - > Running experiment: bike/parallel_constant_2_seed0
2025-06-17 21:20:55,039 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:20:55,039 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:20:55,040 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:20:55,087 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:24<00:00, 101.74it/s]


Running training for config: experiments/csgld/bike/parallel_constant_2_seed1.yaml
2025-06-17 21:21:37,971 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:21:38,572 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:21:38,572 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:21:39,026 - __main__ - INFO - > Running experiment: bike/parallel_constant_2_seed1
2025-06-17 21:21:39,036 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:21:39,036 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:21:39,036 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:21:39,081 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:23<00:00, 104.27it/s]


Running training for config: experiments/csgld/bike/parallel_constant_2_seed2.yaml
2025-06-17 21:22:21,085 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:22:21,609 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:22:21,609 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:22:21,982 - __main__ - INFO - > Running experiment: bike/parallel_constant_2_seed2
2025-06-17 21:22:21,992 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:22:21,992 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:22:21,992 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:22:22,037 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:24<00:00, 103.84it/s]


Running training for config: experiments/csgld/bike/parallel_constant_4_seed0.yaml
2025-06-17 21:23:04,344 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:23:04,864 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:23:04,864 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:23:05,233 - __main__ - INFO - > Running experiment: bike/parallel_constant_4_seed0
2025-06-17 21:23:05,243 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:23:05,243 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:23:05,243 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:23:05,292 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:30<00:00, 83.09it/s]


Running training for config: experiments/csgld/bike/parallel_constant_4_seed1.yaml
2025-06-17 21:23:56,203 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:23:56,736 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:23:56,737 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:23:57,120 - __main__ - INFO - > Running experiment: bike/parallel_constant_4_seed1
2025-06-17 21:23:57,130 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:23:57,130 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:23:57,131 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:23:57,177 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:29<00:00, 83.70it/s]


Running training for config: experiments/csgld/bike/parallel_constant_4_seed2.yaml
2025-06-17 21:24:47,630 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:24:48,181 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:24:48,182 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:24:48,558 - __main__ - INFO - > Running experiment: bike/parallel_constant_4_seed2
2025-06-17 21:24:48,568 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:24:48,568 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:24:48,568 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:24:48,613 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:29<00:00, 83.56it/s]


Running training for config: experiments/csgld/bike/parallel_constant_8_seed0.yaml
2025-06-17 21:25:38,966 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:25:39,504 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:25:39,505 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:25:39,871 - __main__ - INFO - > Running experiment: bike/parallel_constant_8_seed0
2025-06-17 21:25:39,880 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:25:39,881 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:25:39,881 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:25:39,925 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:40<00:00, 62.24it/s]


Running training for config: experiments/csgld/bike/parallel_constant_8_seed1.yaml
2025-06-17 21:26:44,724 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:26:45,247 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:26:45,248 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:26:45,633 - __main__ - INFO - > Running experiment: bike/parallel_constant_8_seed1
2025-06-17 21:26:45,642 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:26:45,642 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:26:45,643 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:26:45,688 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:40<00:00, 62.05it/s]


Running training for config: experiments/csgld/bike/parallel_constant_8_seed2.yaml
2025-06-17 21:27:50,426 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:27:50,958 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:27:50,959 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:27:51,346 - __main__ - INFO - > Running experiment: bike/parallel_constant_8_seed2
2025-06-17 21:27:51,357 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:27:51,357 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:27:51,358 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:27:51,409 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:40<00:00, 62.22it/s]


Running training for config: experiments/csgld/bike/parallel_constant_12_seed0.yaml
2025-06-17 21:28:56,272 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:28:56,792 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:28:56,793 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:28:57,159 - __main__ - INFO - > Running experiment: bike/parallel_constant_12_seed0
2025-06-17 21:28:57,168 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:28:57,168 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:28:57,169 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:28:57,215 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:50<00:00, 49.28it/s]


Running training for config: experiments/csgld/bike/parallel_constant_12_seed1.yaml
2025-06-17 21:30:40,598 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:30:41,131 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:30:41,132 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:30:41,512 - __main__ - INFO - > Running experiment: bike/parallel_constant_12_seed1
2025-06-17 21:30:41,522 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:30:41,522 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:30:41,523 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:30:41,567 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:26<00:00, 93.14it/s]


Running training for config: experiments/csgld/bike/parallel_constant_12_seed2.yaml
2025-06-17 21:31:36,971 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:31:37,492 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:31:37,492 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:31:37,863 - __main__ - INFO - > Running experiment: bike/parallel_constant_12_seed2
2025-06-17 21:31:37,873 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:31:37,873 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:31:37,874 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:31:37,916 - src.training.trainer - INFO - > Setting up directories...


Sampling: 100%|██████████| 2500/2500 [00:50<00:00, 49.93it/s]


In [5]:
# Sequential setting: a single chain runs `n_cycles` consecutive
# constant-step-size cycles of 2500 steps each; N_SEEDS seeds per cycle count.
# NOTE(review): the saved output shows these runs crashing with
# `NameError: ... 'logpost'` in src/training/sampling_batch.py (_log_metrics).
# That is a bug in the training code, not in this cell.
N_DEVICES = 12  # forwarded to train.py as `-d`
N_SEEDS = 3

chains_cycles = [2, 4, 8, 12]
config_paths = []
for n_cycles in chains_cycles:
    for seed in range(N_SEEDS):
        exp_name = f'bike/sequential_constant_{n_cycles}_seed{seed}'
        config_path = get_config(
            exp_name=exp_name,
            n_chains=1,
            n_cycles=n_cycles,
            n_steps_per_cycle=2500,
            n_samples_per_cycle=500,
            n_thinning=10,
            optimizer_name='adam',
            scheduler_name='Constant',
            step_size_init=0.01,
            step_size_sampling=1.0e-8,
            seed=seed
        )
        config_paths.append(config_path)

# Run each experiment in a fresh interpreter. `sys.executable` ensures the
# kernel's own environment is used (a bare 'python' depends on PATH), and
# non-zero exit codes are collected instead of being silently ignored.
failed_runs = []
for config_path in config_paths:
    print("=" * 50)
    print(f'Running training for config: {config_path}')
    result = subprocess.run(
        [sys.executable, 'train.py', '-c', str(config_path), '-d', str(N_DEVICES)]
    )
    if result.returncode != 0:
        failed_runs.append(config_path)
        print(f'Training FAILED (exit code {result.returncode}): {config_path}')

print(f'{len(config_paths) - len(failed_runs)}/{len(config_paths)} runs succeeded')
for config_path in failed_runs:
    print(f'  failed: {config_path}')

Running training for config: experiments/csgld/bike/sequential_constant_2_seed0.yaml
2025-06-17 21:58:09,852 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:58:10,464 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:58:10,464 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:58:11,110 - __main__ - INFO - > Running experiment: bike/sequential_constant_2_seed0
2025-06-17 21:58:11,120 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:58:11,120 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:58:11,121 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:58:11,172 - src.training.trainer - INFO - > Setting up directories...


Sampling:   0%|          | 0/5000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/bszh/MILE/train.py", line 133, in <module>
    train_bde(cfg, n_devices)
  File "/home/bszh/MILE/train.py", line 19, in train_bde
    trainer.train_bde()
  File "/home/bszh/MILE/src/training/trainer.py", line 177, in train_bde
    self.start_sampling()
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/bszh/MILE/src/training/trainer.py", line 599, in start_sampling
    inference_loop_batch(
  File "/home/bszh/MILE/src/training/sampling_batch.py", line 266, in inference_loop_batch
    _log_metrics(step_count, state, batch)
  File "/home/bszh/MILE/src/training/sampling_batch.py", line 217, in _log_metrics
    curr_logpost = logpost(state.params, X, y)
                   ^^^^^^^
NameError: cannot access free variable 'logpost' where it is not associated with a value in enclosing scope


Running training for config: experiments/csgld/bike/sequential_constant_2_seed1.yaml
2025-06-17 21:58:15,172 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:58:15,701 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:58:15,702 - datasets - INFO - JAX version 0.4.28 available.
2025-06-17 21:58:16,116 - __main__ - INFO - > Running experiment: bike/sequential_constant_2_seed1
2025-06-17 21:58:16,125 - jax._src.xla_bridge - INFO - Unable to initialize backend 'cuda': 
2025-06-17 21:58:16,126 - jax._src.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2025-06-17 21:58:16,126 - jax._src.xla_bridge - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-06-17 21:58:16,169 - src.training.trainer - INFO - > Setting up directories...


Sampling:   0%|          | 0/5000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/bszh/MILE/train.py", line 133, in <module>
    train_bde(cfg, n_devices)
  File "/home/bszh/MILE/train.py", line 19, in train_bde
    trainer.train_bde()
  File "/home/bszh/MILE/src/training/trainer.py", line 177, in train_bde
    self.start_sampling()
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/bszh/MILE/src/training/trainer.py", line 599, in start_sampling
    inference_loop_batch(
  File "/home/bszh/MILE/src/training/sampling_batch.py", line 266, in inference_loop_batch
    _log_metrics(step_count, state, batch)
  File "/home/bszh/MILE/src/training/sampling_batch.py", line 217, in _log_metrics
    curr_logpost = logpost(state.params, X, y)
                   ^^^^^^^
NameError: cannot access free variable 'logpost' where it is not associated with a value in enclosing scope


Running training for config: experiments/csgld/bike/sequential_constant_2_seed2.yaml
2025-06-17 21:58:20,163 - __main__ - INFO - Loaded 1 Experiment(s)
2025-06-17 21:58:20,695 - datasets - INFO - PyTorch version 2.2.2+cpu available.
2025-06-17 21:58:20,696 - datasets - INFO - JAX version 0.4.28 available.


Traceback (most recent call last):
  File "/home/bszh/MILE/train.py", line 133, in <module>
    train_bde(cfg, n_devices)
  File "/home/bszh/MILE/train.py", line 13, in train_bde
    from src.training.trainer import BDETrainer  # noqa
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bszh/MILE/src/training/trainer.py", line 29, in <module>
    from src.inference.reporting import generate_html_report
  File "/home/bszh/MILE/src/inference/reporting.py", line 7, in <module>
    from nbconvert import HTMLExporter
  File "/home/bszh/MILE/venv/lib/python3.12/site-packages/nbconvert/__init__.py", line 6, in <module>
    from . import filters, postprocessors, preprocessors, writers
  File "/home/bszh/MILE/venv/lib/python3.12/site-packages/nbconvert/filters/__init__.py", line 8, in <module>
    from .markdown import (
  File "/home/bszh/MILE/venv/lib/python3.12/site-packages/nbconvert/filters/markdown.py", line 12, in <module>
    from .markdown_mistune import markdown2html_mistune


KeyboardInterrupt: 