In [1]:
import time
import tqdm
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import Distribution as torchDist

from distributions import SamplableDistribution, GaussianMixture
from samplers.mala_ex2mcmc import mala as mala_old
from samplers.mala_modified import mala as mala_new

from tools.benchmark import BenchmarkUtils, Benchmark

2024-01-24 00:43:03.181474: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-24 00:43:03.181511: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-24 00:43:03.182489: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
seed = 123

torch.manual_seed(seed)

mass_points_count = 25
true_means = torch.rand((mass_points_count, mass_points_count)) * 2 - 1
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

sample_count = 1000
gm = GaussianMixture(true_means, true_covs,
                     torch.full((mass_points_count,), 1/mass_points_count))

starting_points = true_means
target_dist = gm

true_samples = gm.sample(sample_count)

In [3]:
mcmc_samples = BenchmarkUtils.sample_mcmc(mala_old, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=100,
                                          step_size=0.5, keep_graph=False)

# BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

{'ess': 0.0758526,
 'tv_mean': Array(0.06440658, dtype=float32),
 'tv_conf_sigma': Array(0.00083773, dtype=float32),
 'wasserstein': 30.07535656635286}

In [5]:
mcmc_samples = BenchmarkUtils.sample_mcmc(mala_new, starting_points, target_dist,
                                          sample_count=100,
                                          burn_in=100,
                                          keep_graph=False,
                                          sigma_init=0.01)

# BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples[0].detach(), true_samples)

noise torch.Size([25, 25])
nan proposal torch.Size([25, 25])
logp_y torch.Size([25])
grad_y torch.Size([25, 25])
log_qyx torch.Size([25])
log_qxy torch.Size([25])
accept torch.Size([25])
point torch.Size([25, 25])
logpx torch.Size([25])
grad_x torch.Size([25, 25])
accept torch.Size([25, 1])
sigma torch.Size([25, 1])

noise torch.Size([25, 25])
nan proposal torch.Size([25, 25])
logp_y torch.Size([25])
grad_y torch.Size([25, 25])
log_qyx torch.Size([25])
log_qxy torch.Size([25])
accept torch.Size([25])
point torch.Size([25, 25])
logpx torch.Size([25])
grad_x torch.Size([25, 25])
accept torch.Size([25, 1])
sigma torch.Size([25, 1])

noise torch.Size([25, 25])
nan proposal torch.Size([25, 25])
logp_y torch.Size([25])
grad_y torch.Size([25, 25])
log_qyx torch.Size([25])
log_qxy torch.Size([25])
accept torch.Size([25])
point torch.Size([25, 25])
logpx torch.Size([25])
grad_x torch.Size([25, 25])
accept torch.Size([25, 1])
sigma torch.Size([25, 1])

noise torch.Size([25, 25])
nan proposal tor

In [None]:
mcmc_samples[1]["sigma"][-1]

tensor([[2.5459e-01],
        [2.9277e-01],
        [2.1264e-01],
        [3.7643e-01],
        [6.2230e-02],
        [2.3434e-01],
        [4.5781e-08],
        [4.7687e-08],
        [2.1299e-01],
        [2.7496e-01],
        [2.7109e-01],
        [6.3111e-08],
        [3.2799e-01],
        [2.5853e-01],
        [3.3318e-01],
        [1.2736e-01],
        [3.2388e-01],
        [2.3401e-01],
        [3.2539e-01],
        [6.3485e-08],
        [2.6707e-01],
        [5.9591e-08],
        [2.1664e-01],
        [3.5879e-01],
        [1.5427e-01]], grad_fn=<AsStridedBackward0>)

In [None]:

prec = Tensor([
    [[1, 0], [2, 1]],
    [[2, 1], [2, 4]]
])

In [None]:
z = Tensor([[1, 3], [1, 3]])
z.requires_grad_()
v = Tensor([[1, 2], [4, 5]])
v.requires_grad_()

target_dist2 = torch.distributions.MultivariateNormal(
    torch.zeros(2),
    torch.eye(2),
)


In [None]:
h(z, z, torch.ones(2), [prec, prec.permute(0, 2, 1)], target_dist2, False)

tensor([ -7.2500, -29.1250], grad_fn=<MulBackward0>)

In [None]:
A = Tensor([
    [1, 1],
    [0, 1]
])

B = Tensor([
    [2, -2],
    [3, 3]
])

(A[:, None, :] @ B[..., None]).squeeze()

tensor([0., 3.])

In [None]:
from functools import partial


def h(z: Tensor, v: Tensor, sigma: Tensor, prec_factors: list[Tensor], 
      target_dist: Union[SamplableDistribution, torchDist], keep_graph: bool) -> Tensor:
    """
    z, v (sample_count, n_dim)
    sigma (sample_count)
    prec_factors List[(sample_count, n_dim, n_dim)]
    """

    logp_v = target_dist.log_prob(v)
    if keep_graph:
        grad_v = torch.autograd.grad(
            logp_v.sum(),
            v,
            create_graph=keep_graph,
            retain_graph=keep_graph,
        )[0]
    else:
        grad_v = torch.autograd.grad(logp_v.sum(), v)[0].detach()
    
    grad_v_img = prec_factors[-1] @ grad_v[..., None]
    for factor in reversed(prec_factors[:-1]):
        grad_v_img = factor @ grad_v_img

    grad_v_img = grad_v_img.squeeze()

    return 0.5 * (grad_v[:, None, :] @ 
                  (z - v - 0.25 * grad_v_img * sigma[..., None] ** 2)[..., None]
                 ).squeeze()


def fisher_mala(
    starting_points: torch.Tensor,
    target_dist: Union[SamplableDistribution, torchDist],
    sample_count: int,
    burn_in: int,
    project: Callable = lambda x: x,
    *,
    sigma_init: float = 1.,
    damping: float = 10.,
    rho: float = 0.015,
    alpha: float = 0.574,
    verbose: bool = False,
    meta: Optional[Dict] = None,
    keep_graph: bool = False,
) -> Tuple[torch.Tensor, Dict]:
    """
    starting_points (sample_count, n_dim)
    sigma (sample_count)
    """


    if sample_count + burn_in <= 0:
        raise ValueError("Number of steps might be positive")

    chains = []
    point = starting_points.clone()
    point.requires_grad_()
    point.grad = None
    device = point.device

    proposal_dist = torch.distributions.MultivariateNormal(
        torch.zeros(point.shape[-1], device=device),
        torch.eye(point.shape[-1], device=device),
    )

    meta = meta or dict()
    meta["mh_accept"] = meta.get("mh_accept", [])
    meta["step_size"] = meta.get("step_size", [])
    meta["logp"] = logp_x = target_dist.log_prob(point)
    meta["sigma"] = meta.get("sigma", [])

    if "grad" not in meta:
        if keep_graph:
            grad_x = torch.autograd.grad(
                meta["logp"].sum(),
                point,
                create_graph=keep_graph,
                retain_graph=keep_graph,
            )[0]
        else:
            grad_x = torch.autograd.grad(logp_x.sum(), point)[0].detach()
        meta["grad"] = grad_x
    else:
        grad_x = meta["grad"]

    sigma = torch.full(point.shape[:-1], sigma_init)[..., None]
    # print("sigma", sigma.shape)

    pbar = trange if verbose else range
    for step_id in pbar(burn_in):
        noise = proposal_dist.sample(point.shape[:-1])
        # print("noise", noise.shape)

        proposal_point = point + 0.5 * sigma ** 2 * grad_x + noise * sigma 
        # print("nan proposal", torch.isnan(proposal_point).sum())

        if not keep_graph:
            proposal_point = proposal_point.detach().requires_grad_()

        logp_y = target_dist.log_prob(proposal_point)
        # print("logp_y", logp_y)
        # print("nan logp_y", torch.isnan(logp_y).sum())

        grad_y = torch.autograd.grad(
            logp_y.sum(),
            proposal_point,
            create_graph=keep_graph,
            retain_graph=keep_graph,
        )[
            0
        ]  # .detach()
        # print("grad_y", grad_y)
        # print("nan grad_y", torch.isnan(grad_y).sum())

        log_qyx = proposal_dist.log_prob(noise)

        # print("nan num", torch.isnan(point - proposal_point - sigma ** 2 * grad_y).sum())
        log_qxy = proposal_dist.log_prob(
            (point - proposal_point - sigma ** 2 * grad_y) / sigma
        )

        accept_prob = torch.clamp((logp_y + log_qxy - logp_x - log_qyx).exp(), max=1)
        mask = torch.rand_like(accept_prob) < accept_prob
        mask = mask.detach()

        if keep_graph:
            mask_f = mask.float()
            point = point * (1 - mask_f) + proposal_point * mask_f
            logp_x = logp_x * (1 - mask_f) + logp_y * mask_f
            grad_x = grad_x * (1 - mask_f) + grad_y * mask_f
        else:
            with torch.no_grad():
                mask_f = mask.float()

                # point[mask] = proposal_point[mask]
                # logp_x[mask] = logp_y[mask]
                # grad_x[mask] = grad_y[mask]
                point = point * (1 - mask_f) + proposal_point * mask_f
                logp_x = logp_x * (1 - mask_f) + logp_y * mask_f
                grad_x = grad_x * (1 - mask_f) + grad_y * mask_f

        last_accept = mask.float().mean().item()
        meta["mh_accept"].append(last_accept)

        sigma *= (1 + rho * (accept_prob[..., None] - alpha)) ** 0.5
        # print("sigma", sigma)
        meta["sigma"].append(sigma)

        if not keep_graph:
            point = point.detach().requires_grad_()


    R = torch.eye(point.shape[-1]).repeat(*point.shape[:-1], 1, 1)
    sigma_R = sigma[..., None]
    sigma_ = sigma_R.clone()

    h_ = partial(h, prec_factors=[R, R.permute(0, 2, 1)], keep_graph=keep_graph,
                 target_dist=target_dist)

    for step_id in pbar(sample_count):
        # print("step", step_id)
        noise = proposal_dist.sample(point.shape[:-1])

        grad_x_img = grad_x[..., None]
        grad_x_img = R @ (R.permute(0, 2, 1) @ grad_x_img)
        # print("nan grad_transf", torch.isnan(grad_x_img).sum())

        # print("grad_transf", grad_transf.shape)

        proposal_point = point + (
            0.5 * grad_x_img * sigma_R ** 2 + R @ noise[..., None] * sigma_R
        ).squeeze()
        # print("nan proposal_point", torch.isnan(proposal_point).sum())

        # print("proposal point", proposal_point.shape)

        if not keep_graph:
            proposal_point = proposal_point.detach().requires_grad_()

        logp_y = target_dist.log_prob(proposal_point)
        # print("logpy", logp_y.shape)
        # print("nan logp_y", torch.isnan(logp_y).sum())        
        # print("logp_y", logp_y)
        
        grad_y = torch.autograd.grad(
            logp_y.sum(),
            proposal_point,
            create_graph=keep_graph,
            retain_graph=keep_graph,
        )[
            0
        ]  # .detach()

        grad_y_img = grad_y[..., None]
        grad_y_img = R @ (R.permute(0, 2, 1) @ grad_y_img)

        # log_qyx = proposal_dist.log_prob(noise)
        # log_qxy = proposal_dist.log_prob(
        #     (
        #         (R * sigma_R).inverse() @ 
        #         (point - proposal_point - (0.5 * grad_y_img * sigma_R ** 2).squeeze())[..., None]
        #     ).squeeze()
        # )

        # accept_prob = torch.clamp((logp_y + log_qxy - logp_x - log_qyx).exp(), max=1)
        accept_prob = torch.clamp(
            torch.exp(
                logp_y + h_(point, proposal_point, sigma_R.squeeze()) - logp_x \
                - h_(proposal_point, point, sigma_R.squeeze())
            ),
            max=1
        )
        # print("accept_prob", accept_prob)
        # print("nan accept_prob", torch.isnan(accept_prob).sum())


        # print("accept", accept_prob.shape)
        # print("grad_y - grad_x", (grad_y - grad_x).shape)

        signal_adaptation = torch.sqrt(accept_prob)[..., None] * (grad_y - grad_x)
        # print("nan signal_adaptation", torch.isnan(signal_adaptation).sum())
        # print("sqrt accept", torch.sqrt(accept_prob))
        
        # print("sig adapt", signal_adaptation)

        phi_n = R.permute(0, 2, 1) @ signal_adaptation[..., None]
        # print("nan phi_n", torch.isnan(phi_n).sum())
        # print("phi_n", phi_n)

        gramm_diag = phi_n.permute(0, 2, 1) @ phi_n
        # print("gramm_diag", gramm_diag)

        if step_id == 0:
            r_1 = 1. / (1 + torch.sqrt(damping / (damping + gramm_diag)))
            shift = phi_n @ phi_n.permute(0, 2, 1)
            # print("shift", shift)
            R = 1. / damping ** 0.5 * (R - shift * r_1 / (damping + gramm_diag))
        else:
            r_n = 1. / (1 + torch.sqrt(1 / (1 + gramm_diag)))
            # print("nan rn", torch.isnan(r_n).sum())
            # print("r_n", r_n)

            # print((R @ phi_n).shape)
            # print((phi_n.permute(0, 2, 1)).shape)
            shift = (R @ phi_n) @ phi_n.permute(0, 2, 1)
            # print("shift", shift)
            R = R - shift * r_n / (1 + gramm_diag)

        
        # print("R", R)
        
        # print("sigma update", (1 + rho * (accept_prob - alpha)))
        sigma_[..., 0, 0] *= (1 + rho * (accept_prob - alpha)) ** 0.5
        # sigma_R = torch.full_like(sigma_R, 1)

        # print("sigma_R before norm", sigma_R)

        trace_prec = (R[..., None, :] @ R[..., None]).sum(dim=1)
        # print("trace", trace_prec)
        normalizer = (1. / point.shape[-1]) * trace_prec
        # print("normalizer", normalizer)
        sigma_R = sigma_ / normalizer ** 0.5
        # print("sigma_R", sigma_R)

        A_n = R * sigma_R
        # print("R * sigma", A_n)

        A_n = A_n @ A_n.permute(0, 2, 1)
        trace_A = [A.trace() for A in A_n]

        # print("trace_A", trace_A)

        # sigma_R = torch.full_like(sigma_R, 1)
        # R = torch.eye(point.shape[-1]).repeat(*point.shape[:-1], 1, 1)
        # print("normalizer", normalizer)
        # print("sigma_R", sigma_R)  

        # print()

        mask = torch.rand_like(accept_prob) < accept_prob
        mask = mask.detach()[..., None]

        if keep_graph:
            mask_f = mask.float()
            point = point * (1 - mask_f) + proposal_point * mask_f
            logp_x = logp_x * (1 - mask_f) + logp_y * mask_f
            grad_x = grad_x * (1 - mask_f) + grad_y * mask_f
        else:
            with torch.no_grad():
                mask_f = mask.float()

                # point[mask] = proposal_point[mask]
                # logp_x[mask] = logp_y[mask]
                # grad_x[mask] = grad_y[mask]
                point = point * (1 - mask_f) + proposal_point * mask_f
                logp_x = logp_x * (1 - mask_f) + logp_y * mask_f
                grad_x = grad_x * (1 - mask_f) + grad_y * mask_f

        last_accept = mask.float().mean().item()
        meta["mh_accept"].append(last_accept)

        # meta["sigma"].append(sigma)

        if not keep_graph:
            point = point.detach().requires_grad_()

        chains.append(point.cpu().clone())
        
    chains = torch.stack(chains, 0)

    meta["logp"] = logp_x
    meta["grad"] = grad_x
    meta["mask"] = mask.cpu()

    return chains, meta

In [None]:
torch.manual_seed(seed)
mcmc_samples = BenchmarkUtils.sample_mcmc(fisher_mala, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=5000,
                                          keep_graph=False,
                                          sigma_init=1.,
                                          damping=10)[0].detach()

# BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples, true_samples)

KeyboardInterrupt: 

In [None]:
# Dummy example

torch.manual_seed(seed)

mass_points_count = 25
true_means = torch.rand((mass_points_count, mass_points_count)) * 2 - 1
true_covs = torch.eye(mass_points_count).repeat(mass_points_count, 1, 1)

sample_count = 1000
gm = GaussianMixture(true_means, true_covs,
                     torch.full((mass_points_count,), 1/mass_points_count))

starting_points = true_means
target_dist = gm

true_samples = gm.sample(sample_count)

# mcmc_samples = BenchmarkUtils.sample_mcmc(mala, starting_points, target_dist,
#                                           sample_count=sample_count,
#                                           burn_in=100,
#                                           step_size=0.5, keep_graph=False)[0].detach()

# # mcmc_samples = BenchmarkUtils.sample_mcmc(fisher_mala, starting_points, target_dist,
# #                                           sample_count=sample_count,
# #                                           burn_in=100,
# #                                           keep_graph=True)[0].detach()

# # BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
# BenchmarkUtils.compute_metrics(mcmc_samples, true_samples)

In [None]:
mcmc_samples = BenchmarkUtils.sample_mcmc(ada_mala, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=5000,
                                          keep_graph=True)[0].detach()

# BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples, true_samples)

sigma tensor([[0.9970],
        [1.0032],
        [1.0020],
        [1.0032],
        [1.0032],
        [0.9991],
        [0.9990],
        [1.0032],
        [1.0030],
        [1.0032],
        [0.9974],
        [0.9989],
        [0.9995],
        [1.0032],
        [1.0023],
        [1.0032],
        [0.9957],
        [0.9990],
        [0.9969],
        [1.0032],
        [1.0027],
        [1.0032],
        [1.0030],
        [1.0032],
        [1.0026]], grad_fn=<AsStridedBackward0>)
sigma tensor([[0.9928],
        [1.0064],
        [1.0011],
        [1.0064],
        [1.0064],
        [0.9993],
        [0.9947],
        [1.0064],
        [1.0062],
        [1.0063],
        [0.9931],
        [0.9946],
        [1.0027],
        [1.0064],
        [0.9981],
        [1.0064],
        [0.9914],
        [0.9947],
        [1.0001],
        [0.9998],
        [1.0059],
        [1.0012],
        [1.0062],
        [1.0064],
        [0.9982]], grad_fn=<AsStridedBackward0>)
sigma tensor([[0.9885],
  

KeyboardInterrupt: 

In [None]:
mcmc_samples = BenchmarkUtils.sample_mcmc(fisher_mala, starting_points, target_dist,
                                          sample_count=sample_count,
                                          burn_in=100,
                                          keep_graph=True)[0].detach()

# BenchmarkUtils.create_plot(mcmc_samples, true_samples, "true dist")
BenchmarkUtils.compute_metrics(mcmc_samples, true_samples)

{'ess': 0.038977556,
 'tv_mean': Array(0.28513286, dtype=float32),
 'tv_conf_sigma': Array(0.0056896, dtype=float32),
 'wasserstein': 39.994763852539066}

In [None]:
# algs = [mala, ada_mala]
algs = [ada_mala, fisher_mala]
res_total = {
    alg.__name__: {}
    for alg in algs
}

alg_params = {
    alg.__name__: {}
    for alg in algs
}


alg_params["ada_mala"] = {
    "sigma_init": 1.,
}

alg_params["fisher_mala"] = {
    "sigma_init": 1.,
    "damping": 10
}
distances = [0.01, 0.1, 1, 2, 8]

for distance in tqdm.tqdm(distances):
    for alg in algs:
        benchmark = Benchmark(
            target_dist=gm,
            target_dist_title="true samples",
            dimension=mass_points_count,
            sampling_algorithm=alg, 
            sample_count=sample_count,
            chain_count=50,
            target_dist_mass_points=true_means,
            distance_to_mass_points=distance
        )

        cur_res = benchmark.run(burn_in=10000,
                                keep_graph=False,
                                **alg_params[alg.__name__])
        cur_res["distance"] = distance
        
        for key in cur_res:
            if key in res_total[alg.__name__]:
                res_total[alg.__name__][key].append(cur_res[key])
            else:
                res_total[alg.__name__][key] = [cur_res[key]]

  0%|          | 0/5 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (25) must match the size of tensor b (50) at non-singleton dimension 1

In [None]:
pd.DataFrame(res_total["fisher_mala"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,distance
0,0.017917,0.36468434,0.0077017285,48.963574,24.564249,0.01
1,0.010343,0.34175932,0.006006761,49.940879,24.90175,0.1
2,0.035573,0.9860077,0.56881034,57.310989,26.047411,1.0
3,0.009297,0.28547817,0.005291307,44.639782,24.365457,2.0
4,0.038017,0.3665052,0.008227379,50.279976,24.43841,8.0


In [None]:
pd.DataFrame(res_total["ada_mala"])

Unnamed: 0,ess,tv_mean,tv_conf_sigma,wasserstein,time_elapsed,distance
0,0.028792,0.31372875,0.008866465,46.858629,23.512178,0.01
1,0.014141,0.25455028,0.007784676,42.17215,22.8452,0.1
2,0.012125,0.25432864,0.0061967955,45.037771,23.412413,1.0
3,0.023894,0.33077544,0.008398838,47.881754,23.088268,2.0
4,0.029026,3.0803134,2.7288668,50.324107,23.602863,8.0
