In [1]:
import tqdm
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Distribution as torchDist

from distributions import SamplableDistribution, GaussianMixture
from samplers.mala_ex2mcmc import mala as mala_old
from samplers.mala_modified import mala as mala_new
from samplers.fisher_mala import fisher_mala

from tools.benchmark import BenchmarkUtils, Benchmark

2024-01-26 01:25:18.788243: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-26 01:25:18.788288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-26 01:25:18.789286: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
A = Tensor([
    [1, 2, 4],
    [3, 4, 5]
])

C = Tensor([3, -4])

A * C[..., None]

tensor([[  3.,   6.,  12.],
        [-12., -16., -20.]])

In [3]:
torch.isnan(A).sum(dim=0)

tensor([0, 0, 0])

In [3]:
seed = 123

torch.manual_seed(seed)

mass_points_count = 10
true_means = torch.rand((mass_points_count, mass_points_count)) * 10 - 5
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

sample_count = 1000
gm = GaussianMixture(true_means, true_covs,
                     torch.full((mass_points_count,), 1/mass_points_count))

starting_points = true_means
target_dist = gm

true_samples = gm.sample(sample_count)

In [34]:
mcmc_samples = BenchmarkUtils.sample_mcmc(mala_new, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=1000,
                                          keep_graph=False,
                                          sigma_init=1.)

# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "trues dist")
# BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

In [36]:
mcmc_samples[0].shape

torch.Size([1000, 10, 10])

In [6]:
mcmc_samples = BenchmarkUtils.sample_mcmc(mala_old, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=1000,
                                          step_size=0.5, keep_graph=False)

# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.16419676,
 'tv_mean': Array(0.54042053, dtype=float32),
 'tv_conf_sigma': Array(0.0091321, dtype=float32),
 'wasserstein': 130.96836972618107}

In [4]:
mcmc_samples = BenchmarkUtils.sample_mcmc(fisher_mala, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in_prec=1,
                                          burn_in_sigma=500,
                                          sigma_init=1., damping=10, keep_graph=False)

# BenchmarkUtils.create_plot(mcmc_samples[0].detach(), true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.044364877,
 'tv_mean': Array(0.6811118, dtype=float32),
 'tv_conf_sigma': Array(0.00663659, dtype=float32),
 'wasserstein': 140.8879412521362}

In [10]:
# algs = [mala, ada_mala]
algs = {
    "mala_new": mala_new,
    "mala_old": mala_old,
    "fisher_mala": fisher_mala
}
res_total = {
    alg : {}
    for alg in algs
}

mass_points_counts = [2, 5, 10, 15]
radius = 2
cube = 5
chain_count = 30
sample_count = 4000
burn_in = 4000
true_means = torch.rand((mass_points_count, mass_points_count)) * 2 * cube - cube
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

for mass_points_count in tqdm.tqdm(mass_points_counts):
    true_means = torch.rand((mass_points_count, mass_points_count)) * 10 - 5
    true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)
    gm = GaussianMixture(true_means, true_covs,
                         torch.full((mass_points_count,), 1/mass_points_count))

    for alg in algs:
        benchmark = Benchmark(
            target_dist=gm,
            target_dist_title="true samples",
            dimension=mass_points_count,
            sampling_algorithm=algs[alg], 
            sample_count=sample_count,
            chain_count=chain_count,
            target_dist_mass_points=true_means,
            distance_to_mass_points=radius
        )

        if alg == "fisher_mala":
            cur_res = benchmark.run(burn_in_sigma=500,
                                    burn_in_prec=burn_in-500,
                                    keep_graph=False)
        else:
            cur_res = benchmark.run(burn_in=burn_in,
                                    keep_graph=False)
        for key in cur_res:
            if key in res_total[alg]:
                res_total[alg][key].append(cur_res[key])
            else:
                res_total[alg][key] = [cur_res[key]]

  0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
pd.DataFrame(res_total["mala_old"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed
0,0.637523,0.44291103,0.004538939,33.462368,120.688992


In [23]:
pd.DataFrame(res_total["mala_new"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,distance
0,0.168025,0.52380097,0.0058963136,123.588007,19.599287,0.01
1,0.171416,0.5246308,0.0059378766,120.137128,20.258705,0.1
2,0.165194,0.52777714,0.0059478884,124.745348,19.79745,1.0
3,0.180255,0.5360203,0.0058807116,128.227959,20.752135,2.0
4,0.173619,0.5427313,0.0062374915,138.726523,21.658707,8.0


In [1]:
from ex2mcmc.metrics.chain import ESS, acl_spectrum, autocovariance

2024-01-25 20:54:40.815639: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-25 20:54:40.815711: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-25 20:54:40.842863: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [37]:
import torch
import jax.numpy as jnp

n = 100
dist = torch.distributions.MultivariateNormal(
    torch.zeros(n),
    torch.eye(n)
)

X = dist.sample((n, ))[None, ...]

In [28]:
X.repeat(10, 1, 1).shape

torch.Size([10, 100, 100])

In [59]:
Y = jnp.array([
    [1, 2],
    [-3, -4],
    [5, -6],
    [7, 8],
    [-4, 1]
])

In [60]:
Y[:-2] * Y[2:,]

Array([[  5, -12],
       [-21, -32],
       [-20,  -6]], dtype=int32)

In [58]:
jnp.mean(Y[:-2] * Y[2:,], axis=0)

Array([ -8., -22.], dtype=float32)

In [61]:
autocovariance(Y, 0)

Array([20. , 24.2], dtype=float32)

In [45]:
acl_spectrum(Y - Y.mean(0)[None, ...])

array([[ 0.9999998,  0.9999998],
       [ 0.       ,  0.       ],
       [-1.4999998, -1.4999998],
       [       nan,        nan],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0.       , -0.       ],
       [-0

In [10]:
ESS(acl_spectrum(X - X.mean(0)))

TypeError: sum() received an invalid combination of arguments - got (out=NoneType, axis=int, ), but expected one of:
 * (*, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: out, axis
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)
