In [None]:
import tqdm
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Distribution as torchDist

from distributions import SamplableDistribution, GaussianMixture
from samplers.mala_ex2mcmc import mala as mala_old
from samplers.mala_modified import mala as mala_new
from samplers.fisher_mala import fisher_mala

from tools.benchmark import BenchmarkUtils, Benchmark

In [None]:
# Toy target for the single-run comparisons below: a Gaussian mixture whose
# dimension is `mass_points_count` (it is passed as `dimension=` to Benchmark later)
# and whose number of components is `gaussian_count`.
seed = 123

torch.manual_seed(seed)

mass_points_count = 2  # dimension of the target
gaussian_count = 1     # number of mixture components
true_means = torch.rand((gaussian_count, mass_points_count)) * 2 - 1  # uniform in [-1, 1)^d
true_covs = torch.eye(mass_points_count).repeat(gaussian_count, 1, 1)  # identity covariance per component

sample_count = 1000
# Uniform mixture weights: one weight per component, so each is 1 / gaussian_count
# (previously 1 / mass_points_count, which does not sum to 1 unless d == gaussian_count).
gm = GaussianMixture(true_means, true_covs,
                     torch.full((gaussian_count,), 1 / gaussian_count, dtype=torch.float64))

# Chains are started exactly at the component means.
starting_points = true_means
target_dist = gm

# Reference samples drawn directly from the target, used by every metric / plot cell.
true_samples = gm.sample(sample_count)

In [None]:
# Fisher-preconditioned MALA on the toy target with a short burn-in
# (burn_in_prec=500, burn_in_sigma=500), then plot and score the result.
fisher_short_kwargs = dict(
    sample_count=sample_count,
    burn_in_prec=500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala, starting_points, target_dist, **fisher_short_kwargs
)

# Detach once and reuse the same tensor for both the plot and the metrics.
fisher_short_draws = mcmc_samples[0].detach()
BenchmarkUtils.create_plot(fisher_short_draws, true_samples, "true dist")
BenchmarkUtils.compute_metrics(fisher_short_draws, true_samples)

In [None]:
# Modified MALA on the toy target with a 5000-step burn-in.
mala_new_kwargs = dict(
    sample_count=sample_count,
    burn_in=5000,
    keep_graph=False,
    sigma_init=1e-2,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_new, starting_points, target_dist, **mala_new_kwargs
)

mala_new_draws = mcmc_samples[0].detach()
# BenchmarkUtils.create_plot(mala_new_draws, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mala_new_draws, true_samples)

In [None]:
# Uncomment to let autograd report where a failing backward op originated:
# torch.autograd.set_detect_anomaly(True)

# Original (ex2mcmc) MALA on the toy target with a 5000-step burn-in.
# NOTE(review): keep_graph=True here, while the benchmark loop runs mala_old with
# keep_graph=False — confirm which setting is intended.
mala_old_kwargs = dict(
    sample_count=sample_count,
    burn_in=5000,
    step_size=1e-2,
    keep_graph=True,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    mala_old, starting_points, target_dist, **mala_old_kwargs
)

# Metrics for this run are computed in the following cells.
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# Scores whatever `mcmc_samples` currently holds — the mala_old run when cells are executed in order.
# NOTE(review): the next cell is identical; presumably repeated to gauge run-to-run variability
# of compute_metrics (TODO confirm), otherwise remove the duplicate.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# NOTE(review): identical to the previous cell (mala_old run when executed in order);
# keep only if repeated evaluation of compute_metrics is intentional.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# Fisher-preconditioned MALA with burn_in_prec=4500 and burn_in_sigma=500
# (presumably 5000 burn-in steps in total, comparable to burn_in=5000 above — confirm).
fisher_long_kwargs = dict(
    sample_count=sample_count,
    burn_in_prec=4500,
    burn_in_sigma=500,
    sigma_init=1e-2,
    damping=10,
    keep_graph=False,
)
mcmc_samples = BenchmarkUtils.sample_mcmc(
    fisher_mala, starting_points, target_dist, **fisher_long_kwargs
)

# Metrics for this run are computed in the following cells.
# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# Scores whatever `mcmc_samples` currently holds — the fisher_mala run with burn_in_prec=4500
# when cells are executed in order.
# NOTE(review): this cell is repeated in the next three cells; presumably to gauge run-to-run
# variability of compute_metrics (TODO confirm), otherwise remove the duplicates.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# NOTE(review): identical repeat of the previous metrics cell (fisher_mala, burn_in_prec=4500
# when executed in order); keep only if repeated evaluation is intentional.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# NOTE(review): identical repeat of the previous metrics cell (fisher_mala, burn_in_prec=4500
# when executed in order); keep only if repeated evaluation is intentional.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# NOTE(review): identical repeat of the previous metrics cell (fisher_mala, burn_in_prec=4500
# when executed in order); keep only if repeated evaluation is intentional.
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [None]:
# Scratch check: M @ M^T is a symmetric positive semi-definite matrix (the benchmark
# builds its covariance the same way and adds the identity to make it positive definite).
# A private generator is used so this cell does not advance the global torch RNG —
# otherwise running or skipping it would change the random draws of every later cell.
_scratch_rng = torch.Generator().manual_seed(0)
A = torch.rand(5, 5, generator=_scratch_rng)
A = A @ A.permute(1, 0)
A

In [None]:
# Benchmark every sampler on random Gaussian mixtures of increasing dimension and
# collect the per-run metrics into `res_total`.
algs = {
    "mala_new": mala_new,
    "mala_old": mala_old,
    "fisher_mala": fisher_mala,
}
# res_total[alg][key] -> list of values, one entry per dimension in `mass_points_counts`.
res_total = {alg: {} for alg in algs}

mass_points_counts = [2, 5, 10, 25, 50]  # 100 and 200 currently disabled
gaussian_count = 5   # number of mixture components
radius = 0.5         # passed to Benchmark as distance_to_mass_points
cube = 1             # component means are drawn uniformly from a box of side 2 * cube ...
means_shift = 5      # ... shifted by -means_shift per coordinate, i.e. [-5, -3)^d with cube=1.
# NOTE(review): the toy experiment above centres its means on the origin ([-1, 1)^d);
# confirm this offset is intended rather than a leftover of `- cube`.
chain_count = 25
sample_count = 1000
burn_in = 1000
# fisher_mala splits its burn-in into burn_in_prec + burn_in_sigma; keep the total equal
# to `burn_in` so it is presumably comparable with the other samplers.
burn_in_sigma = 500
assert 0 <= burn_in_sigma <= burn_in, "burn_in_sigma must fit inside burn_in"

alg_params = {
    "mala_old": {
        "burn_in": burn_in,
        "step_size": 1e-2,
    },
    "mala_new": {
        "burn_in": burn_in,
        "sigma_init": 1e-2,
    },
    # NOTE(review): the standalone fisher_mala runs above pass damping=10; here
    # fisher_mala's own default damping is used — confirm this is intended.
    "fisher_mala": {
        "burn_in_prec": burn_in - burn_in_sigma,
        "burn_in_sigma": burn_in_sigma,
        "sigma_init": 1e-2,
    },
}

# Reseed so the benchmark does not depend on how many random draws earlier cells consumed.
torch.manual_seed(seed)

for mass_points_count in tqdm.tqdm(mass_points_counts, desc="dimension"):
    true_means = torch.rand((gaussian_count, mass_points_count)) * cube * 2 - means_shift
    # R @ R^T is positive semi-definite; adding the identity makes it positive definite.
    true_cov = torch.rand(mass_points_count, mass_points_count)
    true_cov = true_cov @ true_cov.permute(1, 0) + torch.eye(mass_points_count)
    true_covs = true_cov.repeat(gaussian_count, 1, 1)  # same covariance for every component

    # Uniform mixture weights: one per component, so each is 1 / gaussian_count
    # (previously 1 / mass_points_count, which only sums to 1 when d == gaussian_count).
    gm = GaussianMixture(true_means, true_covs,
                         torch.full((gaussian_count,), 1 / gaussian_count))

    for alg, sampler in algs.items():
        benchmark = Benchmark(
            target_dist=gm,
            target_dist_title="true samples",
            dimension=mass_points_count,
            sampling_algorithm=sampler,
            sample_count=sample_count,
            chain_count=chain_count,
            target_dist_mass_points=true_means,
            distance_to_mass_points=radius,
        )

        cur_res = benchmark.run(keep_graph=False, **alg_params[alg])
        cur_res["dimension"] = mass_points_count

        # Append this run's value for every reported key.
        for key in cur_res:
            res_total[alg].setdefault(key, []).append(cur_res[key])

In [None]:
# One row per benchmarked dimension, one column per reported metric (mala_old).
pd.DataFrame.from_dict(res_total["mala_old"])

In [None]:
# One row per benchmarked dimension, one column per reported metric (mala_new).
pd.DataFrame.from_dict(res_total["mala_new"])

In [None]:
# One row per benchmarked dimension, one column per reported metric (fisher_mala).
pd.DataFrame.from_dict(res_total["fisher_mala"])

In [None]:
# NOTE(review): identical to the earlier res_total["mala_old"] display cell; consider removing this duplicate.
pd.DataFrame(res_total["mala_old"])

In [None]:
# NOTE(review): identical to the earlier res_total["mala_new"] display cell; consider removing this duplicate.
pd.DataFrame(res_total["mala_new"])

In [None]:
# NOTE(review): identical to the earlier res_total["fisher_mala"] display cell; consider removing this duplicate.
pd.DataFrame(res_total["fisher_mala"])