In [5]:
from datetime import datetime

import os
from pathlib import Path
import subprocess
import warnings
import re
import pickle
import pandas as pd
import functools

import jax
import jax.numpy as jnp

from src.config.core import Config
from src.config.sampler import Sampler
from src.config.data import DatasetType
import src.dataset as ds
from src.models.tabular import FCN
import src.training.utils as train_utils
import src.inference.utils as inf_utils
import src.visualization as viz
from src.config.data import Task
from src.inference.evaluation import evaluate_bde

from matplotlib import pyplot as plt
import numpy as np

In [6]:
# Root directory of the experiment results; the evaluation cells below read
# their run folders from here.
DIR = 'results/bike0'

# Display the entries of DIR in sorted order to see which runs are available.
np.sort(os.listdir(DIR))

array(['parallel_constant_10_seed0', 'parallel_constant_10_seed1',
       'parallel_constant_10_seed2', 'parallel_constant_10_seed3',
       'parallel_constant_10_seed4', 'parallel_constant_12_seed0',
       'parallel_constant_12_seed1', 'parallel_constant_12_seed2',
       'parallel_constant_12_seed3', 'parallel_constant_12_seed4',
       'parallel_constant_2_seed0', 'parallel_constant_2_seed1',
       'parallel_constant_2_seed2', 'parallel_constant_2_seed3',
       'parallel_constant_2_seed4', 'parallel_constant_4_seed0',
       'parallel_constant_4_seed1', 'parallel_constant_4_seed2',
       'parallel_constant_4_seed3', 'parallel_constant_4_seed4',
       'parallel_constant_6_seed0', 'parallel_constant_6_seed1',
       'parallel_constant_6_seed2', 'parallel_constant_6_seed3',
       'parallel_constant_6_seed4', 'parallel_constant_8_seed0',
       'parallel_constant_8_seed1', 'parallel_constant_8_seed2',
       'parallel_constant_8_seed3', 'parallel_constant_8_seed4',
       'paralle

In [7]:
# def evaluate_bde_from_file(path: Path):
def evaluate_bde_from_file(path: Path):
    """
    Evaluate the Bayesian Deep Ensemble (BDE) from a given path.

    Reads the run configuration, the sampling info and the stored samples
    from ``path``, rebuilds the model and the tabular data loader from the
    configuration, and runs ``evaluate_bde`` on the loader's test split.

    Args:
        path: Run directory. It must contain ``config.yaml`` and a ``samples``
            directory with ``info.pkl``; ``path / 'tree'`` is forwarded to
            ``load_samples_from_dir`` as ``tree_path``.

    Returns:
        Tuple ``(logits, metrics, info)``: the two values returned by
        ``evaluate_bde`` and the unpickled content of ``samples/info.pkl``.
    """
    tree_path = path / 'tree'
    sample_path = path / 'samples'

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this function at result directories you trust.
    with open(sample_path / 'info.pkl', 'rb') as f:
        info = pickle.load(f)
    # Seed the metrics with the recorded sampling time; this dict is handed to
    # evaluate_bde through `metrics_dict` below.
    metrics = {'total_time': info['total_time']}

    config = Config.from_yaml(path / 'config.yaml')
    samples = train_utils.load_samples_from_dir(sample_path, tree_path=tree_path)
    loader = ds.TabularLoader(
        config.data,
        rng=config.jax_rng,
        target_len=config.data.target_len
    )
    module = config.get_flax_model()

    features = loader.test_x # (B x F)
    labels = loader.test_y # (B x T)
    print("Test Set: Feature and Label have shapes: ", features.shape, labels.shape)

    # Heuristic memory budget (8 GiB) used to choose the evaluation batch size.
    mem_cap = 8 * 1024 ** 3
    # Rough per-batch footprint: a tenth of the samples' size plus the test data.
    mem_usage = (
        inf_utils.get_mem_size(samples) / 10
        + inf_utils.get_mem_size(features)
        + inf_utils.get_mem_size(labels)
    )
    overhead_unit = mem_usage * 10 # estimation

    # `mem_usage` comes from a true division, so the floor division below
    # yields a float. Cast to int so the batch size is a proper integer.
    batch_size = max(1, int((mem_cap - overhead_unit) // mem_usage))
    if batch_size > labels.shape[0]:
        # The whole test set fits in one batch.
        # NOTE(review): presumably None makes evaluate_bde skip batching - confirm.
        batch_size = None
    print(f'> Batch size for evaluation: {batch_size}\n')

    logits, metrics = evaluate_bde(
        params=samples,
        module=module,
        features=features,
        labels=labels,
        task=config.data.task,
        batch_size=batch_size,
        verbose=True,
        metrics_dict=metrics,
        nominal_coverages=[0.5, 0.75, 0.9, 0.95]
    )
    return logits, metrics, info

### Constant Schedule

In [8]:
from tqdm import tqdm

# Run folders handled by this cell; the two captured groups are the cycle
# count and the seed.
pattern_sequential_constant = r"sequential_constant_(\d+)_seed(\d+)"
for folder in tqdm(np.sort(os.listdir(DIR))):
    match = re.search(pattern_sequential_constant, folder)
    if match is None:
        # Not a sequential/constant run folder.
        continue

    n_cycles, seed = (int(group) for group in match.groups())
    path = Path(DIR) / folder
    print("=" * 50)
    print(path)

    # Both result files act as a cache marker: evaluate only when one is missing.
    if all((path / name).exists() for name in ('eval_metrics.pkl', 'eval_logits.pkl')):
        print("Evaluation already done, skipping...")
        continue

    logits, metrics, info = evaluate_bde_from_file(path)

    # Persist the results next to the run so later executions can skip it.
    with open(path / 'eval_metrics.pkl', 'wb') as f:
        pickle.dump(metrics, f)
    with open(path / 'eval_logits.pkl', 'wb') as f:
        pickle.dump(logits, f)

100%|██████████| 106/106 [00:00<00:00, 11811.49it/s]

results/bike0/sequential_constant_10_seed0
Evaluation already done, skipping...
results/bike0/sequential_constant_10_seed1
Evaluation already done, skipping...
results/bike0/sequential_constant_10_seed2
Evaluation already done, skipping...
results/bike0/sequential_constant_10_seed3
Evaluation already done, skipping...
results/bike0/sequential_constant_10_seed4
Evaluation already done, skipping...
results/bike0/sequential_constant_12_seed0
Evaluation already done, skipping...
results/bike0/sequential_constant_12_seed1
Evaluation already done, skipping...
results/bike0/sequential_constant_12_seed2
Evaluation already done, skipping...
results/bike0/sequential_constant_12_seed3
Evaluation already done, skipping...
results/bike0/sequential_constant_12_seed4
Evaluation already done, skipping...
results/bike0/sequential_constant_2_seed0
Evaluation already done, skipping...
results/bike0/sequential_constant_2_seed1
Evaluation already done, skipping...
results/bike0/sequential_constant_2_seed2



