In [1]:
import time
import logging

import numpy as np
import pandas as pd

from src.mcmc import RandomWalk, Barker, SMBarker, MALA, SMMALA, MMALA
from src.distribution import Rosenbrock, SmoothGeneralNormal
from src.metrics import ks_distance, ad_distance, ess

In [15]:
def run_experiment(mcmc, target, X_true, path, n_trial=10):
    stats = pd.DataFrame(
        columns=[
            "accept_rate",  # actually mean acceptance probability
            "ESS_min",
            "ESS_median",
            "ESS_max",
            "time",
            "time/ESS_min",
            "log_KS_max",
            "log_AD_max",
        ]
    )

    i = 0
    while i < n_trial:
        try:
            start_time = time.process_time_ns()

            res = mcmc.sample(
                target, n_main_iter=1000000, adapter_method="batch", seed=i
            )

            end_time = time.process_time_ns()

            np.save(path + str(i), res["trace_main"])

            accept_rate = sum(res["accept_prob_main"]) / len(res["accept_prob_main"])

            res_ess = ess(X_mcmc=res["trace_main"], target=target)

            cpu_time = end_time - start_time

            ks = ks_distance(X_mcmc=res["trace_main"], target=target, X_true=X_true)

            ad = ad_distance(X_mcmc=res["trace_main"], target=target, X_true=X_true)

            stats.loc[i] = [
                accept_rate,
                min(res_ess),
                np.median(res_ess),
                cpu_time,
                cpu_time / min(res_ess),
                max(res_ess),
                np.log(max(ks)),
                np.log(max(ad)),
            ]

            i += 1
        except Exception as e:
            logging.warning(f"Attempt {i} failed: {e}")

    stats.to_pickle(path + "summary" + ".pkl")

    return stats

In [16]:
target = Rosenbrock(n_var=3, n1=2, n2=2)

X_true = np.load("data/rosenbrock_3d/X_true.npy")

filepath = "data/rosenbrock_3d/"

samplers = [RandomWalk(), Barker(), MALA(), SMBarker(), SMMALA()]
name = ["rwm", "barker", "mala", "smbarker", "smmala"]

for i in range(len(samplers)):
    df = run_experiment(samplers[i], target, X_true, path=filepath + name[i] + "/")
    print(df)

  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)


   accept_rate    ESS_min  ESS_median       ESS_max          time  \
0     0.209306   5.891009    6.365634  9.354872e+09  1.587991e+09   
1     0.112613   9.665748   10.309302  9.111134e+09  9.426207e+08   
2     0.188683   3.333311    3.345312  9.142828e+09  2.742866e+09   
3     0.147398   7.576435    7.587720  9.063187e+09  1.196234e+09   
4     0.169875  11.629413   11.638541  9.234743e+09  7.940850e+08   
5     0.398450  15.873026   26.683346  9.623278e+09  6.062661e+08   
6     0.086206  10.444130   10.560062  9.075453e+09  8.689525e+08   
7     0.074642  16.936259   23.720885  9.047149e+09  5.341882e+08   
8     0.267737   8.520590    8.547070  9.180096e+09  1.077401e+09   
9     0.174708   4.813462   10.054049  9.113526e+09  1.893341e+09   

   time/ESS_min  log_KS_max  log_AD_max  
0      6.369132   -0.842515   13.058288  
1     10.322892   -1.669819   11.312935  
2      6.136353   -1.883668   11.078097  
3      7.765622   -1.345685   12.191766  
4     11.800416   -2.015997   

  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)


   accept_rate    ESS_min  ESS_median       ESS_max          time  \
0     0.733787   1.952591    2.199987  5.132604e+10  2.628612e+10   
1     0.279544   9.138968    9.150064  5.039313e+10  5.514094e+09   
2     0.388252   5.951852   28.998060  5.059234e+10  8.500268e+09   
3     0.225758   1.711568    1.994682  5.034112e+10  2.941228e+10   
4     0.476492  24.989234   25.220937  5.136691e+10  2.055562e+09   
5     0.281007   2.116172    2.503385  5.099105e+10  2.409589e+10   
6     0.203253   2.307758    2.309702  5.099701e+10  2.209808e+10   
7     0.295040   5.852296    5.861237  5.059475e+10  8.645282e+09   
8     0.322735  23.203894   40.081286  5.064304e+10  2.182524e+09   
9     0.333249   3.883420    7.807319  5.049913e+10  1.300378e+10   

   time/ESS_min  log_KS_max  log_AD_max  
0      2.203549   -1.678200   11.714202  
1      9.278637   -1.691743   11.522975  
2     29.252523   -1.954684   10.443775  
3      1.994930   -1.073347   12.424966  
4     28.819133   -2.169589   

  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)
  ad.append(anderson_ksamp([X_mcmc[i, :], X_true[i, :]]).statistic)


KeyboardInterrupt: 