In [1]:
import tqdm
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Distribution as torchDist

from distributions import SamplableDistribution, GaussianMixture
from samplers.mala_ex2mcmc import mala as mala_old
from samplers.mala_modified import mala as mala_new
from samplers.fisher_mala import fisher_mala

from tools.benchmark import BenchmarkUtils, Benchmark

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [16]:
# Reproducible setup of the target Gaussian mixture and reference samples.
seed = 123
torch.manual_seed(seed)

# A single count drives every shape below: `true_means` is (count, count),
# `true_covs` is (count, count, count) and the weight vector has `count` entries.
mass_points_count = 200
sample_count = 5000

# Means are drawn uniformly from [-5, 5); every covariance is an identity matrix.
true_means = torch.rand((mass_points_count, mass_points_count)) * 10 - 5
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

# Equal weights; note they are float64 while the means/covariances use the
# default dtype.
mixture_weights = torch.full((mass_points_count,), 1/mass_points_count, dtype=torch.float64)

gm = GaussianMixture(true_means, true_covs, mixture_weights)
target_dist = gm

# The chains are started at the tensor of component means.
starting_points = true_means

# Reference draws from the target, used by the metric cells.
true_samples = gm.sample(sample_count)

In [17]:
# Run the fisher_mala sampler on the target mixture.
# The two burn-in settings add up to 5000 (4500 + 500), the same total as the
# `burn_in=5000` passed to the other samplers in this notebook.
fisher_mala_kwargs = dict(
    sample_count=sample_count,
    burn_in_prec=4500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala, starting_points, target_dist, **fisher_mala_kwargs
)

# Optional follow-ups, left disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

KeyboardInterrupt: 

In [None]:
# Score the detached first element of `mcmc_samples` against the reference
# `true_samples`. `mcmc_samples` is whatever the most recently executed
# sampling cell left in the kernel, so the result depends on execution order.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.0731672,
 'tv_mean': Array(0.6034001, dtype=float32),
 'tv_conf_sigma': Array(0.00240188, dtype=float32),
 'wasserstein': 1707.7763328802475}

In [None]:
# Run the mala_new sampler on the target mixture, then score its output
# against the reference samples in the same cell.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_new,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in=5000,
    sigma_init=1e-2,
    keep_graph=False,
)

# Optional plot, left disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# Run the mala_old sampler on the target mixture.
# Debug aid, left disabled: torch.autograd.set_detect_anomaly(True)
#
# NOTE(review): this is the only sampling cell that passes keep_graph=True;
# the other cells and the benchmark loop pass False, and the result is
# `.detach()`-ed before use anyway. Confirm whether True is still intended.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_old,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in=5000,
    step_size=1e-2,
    keep_graph=True,
)

# Optional follow-ups, left disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [18]:
# Metrics for the current `mcmc_samples` (first element, detached) versus
# `true_samples`. Which sampler produced `mcmc_samples` is not recorded here;
# it is whichever sampling cell ran last in the kernel.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.010139392,
 'tv_mean': Array(0.64891124, dtype=float32),
 'tv_conf_sigma': Array(0.00244415, dtype=float32),
 'wasserstein': 1750.0380741119382}

In [16]:
# Same metric call as the neighbouring cells, re-evaluated on whatever
# `mcmc_samples` currently holds. NOTE(review): the producing sampler cannot
# be determined from this cell alone.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.010078529,
 'tv_mean': Array(0.64356774, dtype=float32),
 'tv_conf_sigma': Array(0.00246596, dtype=float32),
 'wasserstein': 1746.3172492352303}

In [3]:
# Re-run of the fisher_mala sampler with the same settings as before:
# burn-in of 4500 + 500 steps, sigma_init 1e-2, damping 10, no graph kept.
fisher_mala_kwargs = dict(
    sample_count=sample_count,
    burn_in_prec=4500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala, starting_points, target_dist, **fisher_mala_kwargs
)

# Optional follow-ups, left disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [4]:
# Evaluate the detached first element of `mcmc_samples` against
# `true_samples`; the displayed result is a dict of metric values.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.07315685,
 'tv_mean': Array(0.6034165, dtype=float32),
 'tv_conf_sigma': Array(0.00240202, dtype=float32),
 'wasserstein': 1707.7999110211185}

In [6]:
# Metric call on the current kernel value of `mcmc_samples`.
# NOTE(review): execution counts in this notebook are out of order, so the
# sampler behind this result has to be confirmed by re-running top to bottom.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.011537474,
 'tv_mean': Array(0.728531, dtype=float32),
 'tv_conf_sigma': Array(0.00190675, dtype=float32),
 'wasserstein': 1729.3248153472905}

In [4]:
# Another evaluation of `mcmc_samples[0]` (detached) against `true_samples`;
# the inputs are read from kernel state left by earlier cells.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.012827393,
 'tv_mean': Array(0.7115489, dtype=float32),
 'tv_conf_sigma': Array(0.00197668, dtype=float32),
 'wasserstein': 1728.1137651416018}

In [20]:
# Display the first element of `mcmc_samples` in full.
# NOTE(review): this prints the whole tensor repr; a `.shape` or a small
# slice would be easier to read.
mcmc_samples[0]

tensor([[[-2.0497,  0.1764, -2.5829,  ...,  3.3046,  4.0692, -1.0022],
         [ 3.8289, -3.9299,  0.4494,  ..., -3.1968,  4.2702,  1.1870],
         [-1.3687,  0.3629,  1.6260,  ...,  0.0060,  3.4639,  1.7519],
         ...,
         [-3.1083,  3.6177,  2.0024,  ...,  2.4079,  0.9273, -2.4880],
         [-1.2598,  1.5646,  2.9307,  ...,  0.1115,  4.3825, -0.7696],
         [-4.4671,  0.8064,  1.8228,  ...,  2.4949, -2.3659,  1.5476]],

        [[-2.0431,  0.1697, -2.5927,  ...,  3.2980,  4.0705, -1.0041],
         [ 3.8175, -3.9269,  0.4373,  ..., -3.2005,  4.2864,  1.1818],
         [-1.3660,  0.3661,  1.6313,  ...,  0.0056,  3.4647,  1.7506],
         ...,
         [-3.1087,  3.6184,  1.9941,  ...,  2.3957,  0.9219, -2.5032],
         [-1.2584,  1.5583,  2.9274,  ...,  0.1129,  4.3873, -0.7634],
         [-4.4559,  0.8026,  1.8000,  ...,  2.5002, -2.3333,  1.5464]],

        [[-2.0437,  0.1651, -2.5960,  ...,  3.2923,  4.0604, -0.9996],
         [ 3.8085, -3.9371,  0.4386,  ..., -3

In [None]:
# Compare the detached first element of `mcmc_samples` with `true_samples`.
# Both names come from earlier cells; this cell has no recorded output.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [18]:
# Benchmark every sampler on Gaussian mixtures of increasing size.
# Results are accumulated as res_total[alg][metric] -> list, with one entry
# appended per value in `mass_points_counts`.
algs = {
    "mala_new": mala_new,
    "mala_old": mala_old,
    "fisher_mala": fisher_mala
}
res_total = {alg: {} for alg in algs}

# Each entry is used both as the number of mixture components and as the
# dimension passed to Benchmark.
mass_points_counts = [2, 5, 10, 25, 50, 100, 200]
radius = 2          # forwarded to Benchmark as distance_to_mass_points
cube = 5            # component means are drawn uniformly from [-cube, cube)
chain_count = 5
sample_count = 5000
burn_in = 5000

# Sampler-specific keyword arguments forwarded to Benchmark.run.
# NOTE(review): the standalone fisher_mala cells also pass damping=10, which
# is not set here; confirm whether the default is intended for the benchmark.
alg_params = {
    "mala_old": {
        "burn_in": burn_in,
        "step_size": 1e-2
    },
    "mala_new": {
        "burn_in": burn_in,
        "sigma_init": 1e-2,
    },
    "fisher_mala": {
        # The two phases together add up to `burn_in`.
        "burn_in_prec": burn_in - 500,
        "burn_in_sigma": 500,
        "sigma_init": 1e-2
    },
}

# The mixture is built inside the loop only. The earlier pre-loop construction
# was dead code: it relied on `mass_points_count` leaking from a previous cell,
# was overwritten on the first iteration, and consumed RNG draws for nothing.
for mass_points_count in tqdm.tqdm(mass_points_counts):
    # Use `cube` rather than a hard-coded 10 / 5 so the range has one source
    # of truth; for cube = 5 the generated values are identical.
    true_means = torch.rand((mass_points_count, mass_points_count)) * (2 * cube) - cube
    true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)
    gm = GaussianMixture(true_means, true_covs,
                         torch.full((mass_points_count,), 1/mass_points_count))

    for alg, sampler in algs.items():
        benchmark = Benchmark(
            target_dist=gm,
            target_dist_title="true samples",
            dimension=mass_points_count,
            sampling_algorithm=sampler,
            sample_count=sample_count,
            chain_count=chain_count,
            target_dist_mass_points=true_means,
            distance_to_mass_points=radius
        )

        cur_res = benchmark.run(keep_graph=False, **alg_params[alg])
        cur_res["dimension"] = mass_points_count

        # Append every reported value to the per-sampler, per-metric list.
        for key, value in cur_res.items():
            res_total[alg].setdefault(key, []).append(value)

 14%|█▍        | 1/7 [04:23<26:18, 263.13s/it]


KeyboardInterrupt: 

In [None]:
# Tabulate the accumulated mala_old results: each key of the inner dict
# becomes a column and each benchmarked size a row.
# NOTE(review): the saved output lists dimension 15, which is not in
# `mass_points_counts`, so it appears to come from an earlier run.
pd.DataFrame(res_total["mala_old"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.018432,0.53270483,0.013523146,20.566136,1.403349,2
1,0.017617,0.58164823,0.014839096,68.570832,1.412987,5
2,0.018589,0.6533207,0.012992664,165.212734,1.5251,10
3,0.018728,0.6890273,0.010654363,240.010562,1.509693,15


In [None]:
# Tabulate the accumulated mala_new results (one column per metric key,
# one row per benchmarked size).
# NOTE(review): the saved output includes dimension 15, which the current
# `mass_points_counts` does not contain; re-run to refresh it.
pd.DataFrame(res_total["mala_new"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.018756,0.8588349,0.004755634,25.745464,1.589163,2
1,0.015801,0.87675947,0.004756532,70.006221,1.617835,5
2,0.016033,0.8988648,0.003826085,150.16171,1.795222,10
3,0.015761,0.9101176,0.0039818287,245.844473,1.838375,15


In [None]:
# Tabulate the accumulated fisher_mala results (one column per metric key,
# one row per benchmarked size).
# NOTE(review): the saved output includes dimension 15, which the current
# `mass_points_counts` does not contain; re-run to refresh it.
pd.DataFrame(res_total["fisher_mala"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.016448,0.8733221,0.005014251,28.01263,1.910149,2
1,0.013958,0.8805882,0.0055737854,71.746634,1.996947,5
2,0.016154,0.906693,0.003734962,180.062975,2.064831,10
3,0.017291,0.9091048,0.0032813984,232.97214,2.109851,15


In [28]:
# Sanity check of a mixture log-density: log(sum_i weights_i * exp(logp_i))
# for two component log-probabilities with equal weights.
logp = torch.tensor([-90.0, -80.0])
weights = torch.full_like(logp, 1 / len(logp))

# Evaluate in log space. Exponentiating first underflows in float32 for
# strongly negative log-probabilities and then yields log(0) = -inf, whereas
# logsumexp subtracts the maximum before exponentiating.
log_mix = torch.logsumexp(logp + torch.log(weights), dim=0)
log_mix

tensor(-80.6931)

In [24]:
# Build a tensor shaped like `logp` in which every entry is 1/len(logp),
# i.e. equal weights. `logp` must already exist in the kernel from another cell.
torch.full_like(logp, 1/len(logp))

tensor([0.5000, 0.5000])