In [1]:
import tqdm
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Distribution as torchDist

from distributions import SamplableDistribution, GaussianMixture
from samplers.mala_ex2mcmc import mala as mala_old
from samplers.mala_modified import mala as mala_new
from samplers.fisher_mala import fisher_mala

from tools.benchmark import BenchmarkUtils, Benchmark

In [2]:
# Reference problem shared by the sampling cells: an equally weighted Gaussian
# mixture plus a batch of exact draws from it.
seed = 123
torch.manual_seed(seed)

mass_points_count = 200
sample_count = 5000

# Means are drawn uniformly from [-5, 5); each component gets an identity covariance.
true_means = torch.rand((mass_points_count, mass_points_count)) * 10 - 5
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

# Uniform weights, explicitly created as float64.
mixture_weights = torch.full(
    (mass_points_count,),
    1 / mass_points_count,
    dtype=torch.float64,
)
gm = GaussianMixture(true_means, true_covs, mixture_weights)

# Names the sampling cells pass to BenchmarkUtils.sample_mcmc.
target_dist = gm
starting_points = true_means

# Exact draws used as the reference when computing metrics.
true_samples = gm.sample(sample_count)

In [17]:
# Draw `sample_count` samples with fisher_mala, starting from `starting_points`.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in_prec=4500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)

# Optional follow-ups, currently disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

KeyboardInterrupt: 

In [None]:
# Compare the first entry of `mcmc_samples` (detached) with the reference draws.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.0731672,
 'tv_mean': Array(0.6034001, dtype=float32),
 'tv_conf_sigma': Array(0.00240188, dtype=float32),
 'wasserstein': 1707.7763328802475}

In [None]:
# Draw `sample_count` samples with mala_new, then report metrics for them.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_new,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in=5000,
    keep_graph=False,
    sigma_init=1e-2,
)

# Optional plot, currently disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

In [None]:
# Debug aid, currently disabled:
# torch.autograd.set_detect_anomaly(True)

# Draw `sample_count` samples with mala_old; unlike the other sampler cells,
# this call passes keep_graph=True.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_old,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in=5000,
    step_size=1e-2,
    keep_graph=True,
)

# Optional follow-ups, currently disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [18]:
# Metrics for whichever sampler last populated `mcmc_samples`.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.010139392,
 'tv_mean': Array(0.64891124, dtype=float32),
 'tv_conf_sigma': Array(0.00244415, dtype=float32),
 'wasserstein': 1750.0380741119382}

In [16]:
# Re-evaluate the current `mcmc_samples` against the reference draws.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.010078529,
 'tv_mean': Array(0.64356774, dtype=float32),
 'tv_conf_sigma': Array(0.00246596, dtype=float32),
 'wasserstein': 1746.3172492352303}

In [3]:
# Same fisher_mala configuration as the earlier sampling cell, run to completion.
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala,
    starting_points,
    target_dist,
    sample_count=sample_count,
    burn_in_prec=4500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)

# Optional follow-ups, currently disabled:
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [4]:
# Metrics for the samples held in `mcmc_samples` at the time this cell runs.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.07315685,
 'tv_mean': Array(0.6034165, dtype=float32),
 'tv_conf_sigma': Array(0.00240202, dtype=float32),
 'wasserstein': 1707.7999110211185}

In [6]:
# Score the detached first entry of `mcmc_samples` against `true_samples`.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.011537474,
 'tv_mean': Array(0.728531, dtype=float32),
 'tv_conf_sigma': Array(0.00190675, dtype=float32),
 'wasserstein': 1729.3248153472905}

In [4]:
# Repeat the metric computation for the current contents of `mcmc_samples`.
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

{'ess': 0.012827393,
 'tv_mean': Array(0.7115489, dtype=float32),
 'tv_conf_sigma': Array(0.00197668, dtype=float32),
 'wasserstein': 1728.1137651416018}

In [None]:
# Metric computation cell (not executed in the saved notebook).
BenchmarkUtils.compute_metrics(
    mcmc_samples[0].detach(),
    true_samples,
)

In [13]:
# Build a random symmetric 5x5 matrix as the Gram product of a uniform matrix
# with its own transpose, then display it.
A = torch.rand(5, 5)
A = A.matmul(A.T)
A

tensor([[0.9488, 0.5682, 0.5890, 1.3748, 0.4605],
        [0.5682, 0.9247, 0.5632, 0.8831, 0.9615],
        [0.5890, 0.5632, 0.7172, 0.8804, 0.6908],
        [1.3748, 0.8831, 0.8804, 2.3430, 0.8233],
        [0.4605, 0.9615, 0.6908, 0.8233, 1.1882]])

In [15]:
# Benchmark every sampler on Gaussian mixtures of increasing size and collect
# the results as res_total[sampler_name][metric_name] -> list with one entry
# per element of `mass_points_counts`.
algs = {
    "mala_new": mala_new,
    "mala_old": mala_old,
    "fisher_mala": fisher_mala,
}
res_total = {alg: {} for alg in algs}

# Sweep configuration.
mass_points_counts = [2, 5, 10, 25, 50, 100, 200]
radius = 2        # forwarded to Benchmark as distance_to_mass_points
cube = 5          # component means are drawn uniformly from [-cube, cube)
chain_count = 25
# NOTE: this rebinds the notebook-level `sample_count` used by other cells.
sample_count = 1000
burn_in = 1000

# Keyword arguments forwarded to Benchmark.run for each sampler.
alg_params = {
    "mala_old": {
        "burn_in": burn_in,
        "step_size": 1e-2,
    },
    "mala_new": {
        "burn_in": burn_in,
        "sigma_init": 1e-2,
    },
    "fisher_mala": {
        # The two fisher_mala burn-in arguments add up to `burn_in`.
        "burn_in_prec": burn_in - 500,
        "burn_in_sigma": 500,
        "sigma_init": 1e-2,
    },
}

# The target is built inside the loop only: the previous version also created
# `true_means` / `true_covs` before the loop from a `mass_points_count` left
# over from an earlier cell, and those tensors were never used.
for mass_points_count in tqdm.tqdm(mass_points_counts):
    # Means uniform in [-cube, cube); `mass_points_count` sets both the number
    # of components and the size of each mean vector.
    true_means = torch.rand((mass_points_count, mass_points_count)) * (2 * cube) - cube

    # One shared covariance for all components: R @ R^T + I.
    true_cov = torch.rand(mass_points_count, mass_points_count)
    true_cov = true_cov @ true_cov.permute(1, 0) + torch.eye(mass_points_count)
    true_covs = true_cov.repeat(mass_points_count, 1, 1)

    gm = GaussianMixture(
        true_means,
        true_covs,
        torch.full((mass_points_count,), 1 / mass_points_count),
    )

    for alg, sampler in algs.items():
        benchmark = Benchmark(
            target_dist=gm,
            target_dist_title="true samples",
            dimension=mass_points_count,
            sampling_algorithm=sampler,
            sample_count=sample_count,
            chain_count=chain_count,
            target_dist_mass_points=true_means,
            distance_to_mass_points=radius,
        )

        cur_res = benchmark.run(keep_graph=False, **alg_params[alg])
        cur_res["dimension"] = mass_points_count

        # Append each metric to its per-sampler list, creating it on first use.
        for key, value in cur_res.items():
            res_total[alg].setdefault(key, []).append(value)

100%|██████████| 7/7 [26:50<00:00, 230.13s/it]


In [16]:
# Tabulate the metrics collected for "mala_old": one row per benchmarked mixture size.
pd.DataFrame.from_dict(res_total["mala_old"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.00842,0.51736635,0.0049867486,26.8739,23.428321,2
1,0.007666,0.49616697,0.0057462757,41.474016,24.648222,5
2,0.00779,0.5997244,0.005878168,171.16634,25.432154,10
3,0.007293,0.625027,0.004817148,593.445169,28.609877,25
4,0.006968,0.6502758,0.0048980564,1848.86158,32.686932,50
5,0.006671,0.6721403,0.004979343,5611.973498,59.65487,100
6,0.006445,0.72814596,0.004753406,18220.174666,209.313812,200


In [17]:
# Tabulate the metrics collected for "mala_new": one row per benchmarked mixture size.
pd.DataFrame.from_dict(res_total["mala_new"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.071355,0.41978392,0.005451159,24.641059,31.691438,2
1,0.018405,0.23169306,0.0041940007,18.060171,31.698124,5
2,0.018199,0.4061715,0.0071294564,137.517761,33.633917,10
3,0.015748,0.44305643,0.0049689864,501.289273,34.898961,25
4,0.014077,0.4193722,0.0046921787,1410.431,38.5623,50
5,0.013206,0.42431372,0.005594132,4904.157759,67.29628,100
6,0.010556,0.45828193,0.0056107547,16767.216269,217.775079,200


In [18]:
# Tabulate the metrics collected for "fisher_mala": one row per benchmarked mixture size.
pd.DataFrame.from_dict(res_total["fisher_mala"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.068552,0.41842148,0.005500972,24.236641,24.048251,2
1,0.030954,0.23055966,0.004722572,17.771032,24.537522,5
2,0.037105,0.39340174,0.0073568076,136.320696,26.199791,10
3,0.02322,0.39469257,0.0054488396,438.193267,29.287315,25
4,0.016366,0.36536884,0.004676817,1301.557125,39.114649,50
5,0.017107,0.3694105,0.0049873767,4418.660156,104.167506,100
6,0.016434,0.40463078,0.0053079813,16237.847361,504.228172,200


In [8]:
# "mala_old" results table; the saved output comes from a different run of the sweep.
pd.DataFrame.from_dict(res_total["mala_old"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.00982,0.4996605,0.0026837846,39.589665,97.001324,2
1,0.01006,0.54095554,0.002975099,48.048493,100.929728,5
2,0.010044,0.59718144,0.002739974,121.150241,108.764723,10
3,0.01,0.63365895,0.0024835747,402.311367,112.878523,25
4,0.010145,0.6351473,0.0024942579,821.776378,105.594287,50
5,0.010077,0.6428129,0.0025388985,1714.623244,137.11627,100
6,0.010103,0.64545614,0.0025014454,3516.648847,317.796911,200


In [9]:
# "mala_new" results table; the saved output comes from a different run of the sweep.
pd.DataFrame.from_dict(res_total["mala_new"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.099876,0.44438413,0.0027448875,38.579628,107.18328,2
1,0.056515,0.43524906,0.0033502397,40.389596,100.489273,5
2,0.073209,0.5398282,0.0028249314,113.297063,111.031553,10
3,0.070972,0.5962063,0.0025471672,401.516984,110.272567,25
4,0.063558,0.594541,0.0024541793,800.400773,110.25628,50
5,0.055779,0.5984195,0.0023482742,1682.441764,142.309467,100
6,0.048539,0.6071409,0.0024535258,3449.50143,374.732302,200


In [10]:
# "fisher_mala" results table; the saved output comes from a different run of the sweep.
pd.DataFrame.from_dict(res_total["fisher_mala"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,dimension
0,0.09925,0.44205633,0.0028334283,38.33455,91.613585,2
1,0.055048,0.43834776,0.003271828,41.469913,107.163142,5
2,0.074045,0.5470617,0.0028533805,116.932464,104.887605,10
3,0.069717,0.5949279,0.0024969126,390.232393,108.573844,25
4,0.059061,0.59082687,0.0025192092,795.256714,108.403614,50
5,0.047919,0.5943416,0.0025029427,1693.208986,189.702629,100
6,0.03581,0.5921154,0.002461629,3478.399266,646.193185,200
