In [5]:
import sys
# Make the repository root importable from inside the ``notebooks`` folder by
# stripping the '/notebooks' suffix from the first ``sys.path`` entry.
project_root = sys.path[0].replace('/notebooks', '')
# Only register the root once: without this guard every re-run of the cell
# appends the same entry to ``sys.path`` again.
if project_root not in sys.path:
    sys.path.append(project_root)
print(sys.path)

['/home/gcardoso/projets/br_snis/notebooks', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '', '/home/gcardoso/projets/venv/lib/python3.8/site-packages', '/home/gcardoso/projets/br_snis', '/home/gcardoso/projets/br_snis']


In [39]:
from typing import Tuple
from torch import ones, diag, inf, arange, Tensor, zeros, eye, no_grad, DeviceObjType, tensor, cat, rand, from_numpy, FloatTensor
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal, Distribution
from torch.linalg import norm
from pyro.distributions import MultivariateStudentT
from typing import Union, Callable
from functools import partial
from typing import Callable
from torch import Tensor, inf, rand, cos, tensor, zeros, from_numpy, randint, cat
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal, Distribution, Normal, StudentT
from br_snis import br_snis, snis
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from math import ceil

# Matplotlib rcParams overrides shared by the figures of this notebook
# (used through ``plt.rc_context(rc_context)``).
rc_context = {
    # Axes frame, ticks, labels and titles are all drawn in black.
    **{key: 'black' for key in ('axes.edgecolor',
                                'ytick.color',
                                'xtick.color',
                                'axes.labelcolor',
                                'axes.titlecolor')},
    # Font sizes.
    'legend.fontsize': 22,
    'font.size': 22,
    # Subplot margins.
    'figure.subplot.right': 0.98,
    'figure.subplot.top': 0.98,
    'figure.subplot.bottom': 0.10,
    'figure.subplot.left': 0.12,
    # Marker and line sizes.
    'lines.markersize': 15,
    'lines.linewidth': 3,
    'xtick.labelsize': 'x-large',
}

#### Code for the mixture of Gaussians example

In [2]:

def full_f(x, rectangle_1_coordinates, rectangle_2_coordinates):
    """Signed indicator of two axis-aligned boxes.

    Each ``rectangle_*_coordinates`` argument is a ``(center, heights)`` pair
    of tensors. A point lies inside a box when
    ``max_i |(x_i - center_i) / heights_i| < 1`` (strict inequality), the
    maximum being taken over the last dimension of ``x``.

    Returns a float tensor equal to ``1[x in box 1] - 1[x in box 2]``.
    """
    def _inside(box):
        center, heights = box
        # Rescale so that the box becomes the open unit ball of the sup-norm.
        scaled_offset = (x - center[None, :]) / heights[None, :]
        return norm(scaled_offset, ord=inf, dim=-1) < 1

    in_first = _inside(rectangle_1_coordinates).float()
    in_second = _inside(rectangle_2_coordinates).float()
    return in_first - in_second


def build_f(d: int = 7,
            device: Union[DeviceObjType, str] = 'cpu') -> Callable[[Tensor], Tensor]:
    """Build the test function of the toy problem in dimension ``d``.

    The result is ``full_f`` with both boxes bound:

    * first box: centre ``(-4, 0, ..., 0)``, heights ``(2, 0.5, 1, ..., 1)``;
    * second box: centre ``(1, 1.5, 0, ..., 0)``, heights
      ``(0.25, 0.5, 0.1, ..., 0.1)``.

    All tensors are created on ``device``.
    """
    # First box.
    first_center = zeros(d, device=device)
    first_center[0] = -4
    first_heights = 2 * ones(d, device=device)
    first_heights[1] = .5
    first_heights[2:] = first_heights[2:] * .5

    # Second box.
    second_center = zeros(d, device=device)
    second_center[0] = 1
    second_center[1] = 1.5
    second_heights = ones(d, device=device)
    second_heights[0] = .25
    second_heights[1] = .5
    second_heights[2:] = second_heights[2:] * .1

    return partial(full_f,
                   rectangle_1_coordinates=(first_center, first_heights),
                   rectangle_2_coordinates=(second_center, second_heights))


def build_params(d, device):
    """Return the means and matrices of the two mixture components.

    Returns ``(mu_1, mu_2, sigma_1, sigma_2)`` where ``mu_1 = (1, 1, 0, ...)``,
    ``mu_2 = (-2, 0, ...)`` and both ``sigma`` are the ``d x d`` diagonal
    matrix with entries ``1 / d``. Everything is allocated on ``device``.
    """
    mu_1, mu_2 = zeros(d, device=device), zeros(d, device=device)
    mu_1[:2] = 1
    mu_2[0] = -2
    # Two independent (non-aliased) copies of the same diagonal matrix.
    sigmas = tuple(diag(ones(d, device=device) / d) for _ in range(2))
    return (mu_1, mu_2, *sigmas)


def get_pi_dist(p, mus, sigmas):
    """Two-component Gaussian mixture with weights ``(p, 1 - p)``.

    ``mus`` and ``sigmas`` are forwarded positionally to
    ``MultivariateNormal`` (location and covariance matrix); the mixture
    weights are created on ``mus.device``.
    """
    weights = tensor([p, 1 - p], device=mus.device)
    components = MultivariateNormal(mus, sigmas)
    return MixtureSameFamily(Categorical(weights), components)


def get_lambda_dist(nu, d, device):
    """Multivariate Student-t in dimension ``d`` with ``nu`` degrees of freedom.

    The location is the zero vector and ``scale_tril`` is the identity
    matrix, both created on ``device``.
    """
    location = zeros(d, device=device)
    identity_scale = eye(d, device=device)
    return MultivariateStudentT(df=nu, loc=location, scale_tril=identity_scale)


def get_toy_problem_distributions(p: float = 1/3,
                                  nu: int = 3,
                                  dim: int = 7,
                                  device: Union[str, DeviceObjType] = 'cpu') -> Tuple[Distribution, Distribution]:
    """Build the (target, proposal) pair of the toy problem.

    Args:
        p: weight of the first mixture component (the second gets ``1 - p``).
        nu: degrees of freedom of the Student-t proposal.
        dim: dimension of the space.
        device: device on which all parameters are created.

    Returns:
        ``(pi, lda)``: the two-component Gaussian mixture from
        ``get_pi_dist`` and the Student-t distribution from
        ``get_lambda_dist``.
    """
    mu_1, mu_2, sigma_1, sigma_2 = build_params(dim, device=device)
    # Stack the per-component parameters along a new leading axis of size 2.
    component_means = cat([mu_1.unsqueeze(0), mu_2.unsqueeze(0)])
    component_sigmas = cat([sigma_1.unsqueeze(0), sigma_2.unsqueeze(0)])
    pi = get_pi_dist(p, component_means, component_sigmas)
    lda = get_lambda_dist(nu, d=dim, device=device)

    return pi, lda


#### Code for Unbiased-PIMH

In [14]:

def snis_pimh(N:int,
              proposal:Distribution,
              target:Distribution):
    """Draw ``N`` particles from ``proposal`` and weight them against ``target``.

    Returns a dict with:
        "nw": the self-normalised importance weights (they sum to one),
        "logavew": the log of the average unnormalised weight,
        "x": the sampled particles.
    """
    particles = proposal.sample((N,))
    log_weights = target.log_prob(particles) - proposal.log_prob(particles)
    # Shift by the largest log-weight before exponentiating so that the
    # largest shifted weight is exactly exp(0) = 1.
    shift = log_weights.max(dim=0).values
    shifted_weights = (log_weights - shift).exp()
    return {
        "nw": shifted_weights / shifted_weights.sum(),
        "logavew": shift + shifted_weights.mean(dim=0).log(),
        "x": particles
    }


def debiasedis2(N, proposal, target, test_function):
    """Run one coupled pair of importance-sampling states until they meet.

    Two particle systems are drawn with ``snis_pimh`` and ordered so that the
    first has the larger ``"logavew"``. Fresh particle systems are then
    proposed and accepted by each state with the same uniform draw, until one
    proposal is accepted by both states.

    Returns a dict with:
        "estimate": the average of the bias-corrected estimate built on the
            first state and the plain SNIS estimate of the second state,
        "tau": the iteration at which both states accepted the same proposal
            (1 when that happens at the initial comparison).
    """
    def _snis_estimate(state):
        # Self-normalised estimate of test_function under this particle system.
        return (state["nw"] * test_function(state["x"])).sum()

    first = snis_pimh(N, proposal, target)
    device = first["x"].device
    second = snis_pimh(N, proposal, target)
    # Keep the state with the larger average weight in `first`.
    if first["logavew"] < second["logavew"]:
        first, second = second, first

    estimate1 = _snis_estimate(first)
    estimate2 = _snis_estimate(second)

    # Initial comparison: would `first` move to `second`?
    log_u = rand(1).to(device).log()
    tau = 1 if log_u < (second["logavew"] - first["logavew"]) else None

    time = 1
    while tau is None:
        # Correction term accumulated while the two states still differ.
        estimate1 += _snis_estimate(first) - _snis_estimate(second)
        time += 1
        candidate = snis_pimh(N, proposal, target)
        # One shared uniform drives both accept/reject decisions.
        log_u = rand(1).to(device).log()
        accept_first = log_u < (candidate["logavew"] - first["logavew"])
        accept_second = log_u < (candidate["logavew"] - second["logavew"])
        if accept_first:
            first = candidate
        if accept_second:
            second = candidate
        if accept_first and accept_second:
            tau = time

    return{
        "estimate": 0.5*estimate1+0.5*estimate2,
        "tau": tau
    }

In [44]:
def build_unbiased_pimh_estimates_for_budget(budget,
                                             minimal_n_replications,
                                             target,
                                             proposal,
                                             h,
                                             ref_value,
                                             hard=True,
                                             plot=True):
    """Benchmark the ``debiasedis2`` estimator under a particle budget.

    For every particle count ``N`` in ``budget // 4, ..., budget // 32``
    (kept only when ``N > 2``), ``minimal_n_replications`` independent
    replications are run. A replication averages successive ``debiasedis2``
    estimates, each charged ``2 * N + (tau - 1) * N`` particles:

    * ``hard=True``: an estimate that would push the cost above ``budget`` is
      discarded and the replication stops;
    * ``hard=False``: that estimate is kept and the replication stops once the
      cost reaches ``budget``.

    Replications that produced no estimate are counted in ``"fails"`` and left
    out of the bias / standard-deviation computation.

    Args:
        budget: particle budget of one replication.
        minimal_n_replications: number of replications per value of ``N``.
        target, proposal: distributions forwarded to ``debiasedis2``.
        h: test function forwarded to ``debiasedis2``.
        ref_value: reference value subtracted from the mean to get the bias.
        hard: whether the budget is a hard limit (see above).
        plot: if true, draw a histogram of the per-replication costs.

    Returns:
        A ``pandas.DataFrame`` with one row per value of ``N``.
    """
    Nseq = [budget // 2**l for l in range(2, 6) if budget // 2**l > 2]
    mode = 'HARD' if hard else 'SOFT'
    results_pimh = []
    for N in Nseq:
        replication_estimates = []
        costs = []
        n_zeros = 0
        for irep in tqdm(range(minimal_n_replications),
                         desc=f'PIMH Unbiased-PIMH {mode} {N}'):
            round_dis_estimate = []
            running_cost = 0
            for i_est in range(budget // N):
                dis_result = debiasedis2(N,
                                         proposal=proposal,
                                         target=target,
                                         test_function=h)
                # Two initial particle systems plus one per extra iteration.
                cost = 2 * N + (dis_result["tau"] - 1) * N
                if hard and (running_cost + cost > budget):
                    # Hard budget: this estimate does not fit, drop it.
                    break
                running_cost += cost
                round_dis_estimate.append(dis_result["estimate"].item())
                if running_cost >= budget:
                    break

            costs.append(running_cost)
            if round_dis_estimate:
                replication_estimates.append(np.mean(round_dis_estimate))
            else:
                n_zeros += 1
        replication_estimates = np.array(replication_estimates)
        if plot:
            with plt.rc_context(rc_context):
                fig, ax = plt.subplots(1, 1, figsize=(6, 6))
                ax.hist(costs, alpha=.7)
                ax.axvline(budget, color='red')
                ax.set_yscale('log')
                fig.show()
        results_pimh.append({"bias": np.mean(replication_estimates) - ref_value,
                             # ddof=1 matches torch.Tensor.std(), which the
                             # SNIS and BR-SNIS summaries use for this column.
                             "std deviation": np.std(replication_estimates, ddof=1),
                             "N": N,
                             "replications": len(replication_estimates),
                             "algorithm": "Unbiased-PIMH Hard" if hard else "Unbiased-PIMH Soft",
                             "average M": sum(costs) / minimal_n_replications,
                             "fails": n_zeros})

    results_pimh = pd.DataFrame.from_records(results_pimh)
    return results_pimh

In [34]:
def build_snis_estimates_for_budget(budget,
                                    minimal_n_replications,
                                    all_f_values,
                                    all_log_weights,
                                    ref_value,
                                    device):
    """Summarise plain SNIS estimates that each use ``budget`` particles.

    The first ``budget * minimal_n_replications`` pre-computed values are cut
    into ``minimal_n_replications`` rows of ``budget`` particles, moved to
    ``device`` and passed to ``snis``.

    Returns a one-row ``pandas.DataFrame`` with the bias (mean estimate minus
    ``ref_value``), the standard deviation and the number of estimates.
    """
    n_used = budget * minimal_n_replications
    f_batch = all_f_values[:n_used].reshape(minimal_n_replications, budget, 1).to(device)
    log_weight_batch = all_log_weights[:n_used].reshape(minimal_n_replications, budget).to(device)
    estimations = snis(f_values=f_batch, log_weights=log_weight_batch).cpu()

    summary = {"N": budget,
               "algorithm": "SNIS",
               "bias": estimations.mean().item() - ref_value,
               "std deviation": estimations.std().item(),
               "replications": estimations.shape[0],
               "average M": budget}
    return pd.DataFrame.from_records([summary])

In [37]:
def build_br_snis_estimates_for_budget(budget,
                                    minimal_n_replications,
                                    all_f_values,
                                    all_log_weights,
                                    ref_value,
                                    n_chains,   
                                    device):
    """Summarise BR-SNIS estimates for several minibatch sizes under a budget.

    For each minibatch size in ``[64, 128, 256, 512]`` with
    ``k = budget // minibatch_size > 1``, ``minimal_n_replications`` chains
    are run through ``br_snis`` in batches of at most ``n_chains`` chains.
    Each batch reads a contiguous window of the pre-computed values starting
    at a random offset. The reported estimate of a chain is the mean over the
    last ``ceil(0.1 * k)`` entries along the first axis of the stacked
    ``br_snis`` outputs.

    Args:
        budget: particle budget per chain.
        minimal_n_replications: total number of chains per configuration.
        all_f_values, all_log_weights: pre-computed 1-D tensors of test
            function values and log importance weights.
        ref_value: reference value subtracted from the mean to get the bias.
        n_chains: maximum number of chains processed in one ``br_snis`` call.
        device: device on which ``br_snis`` is evaluated.

    Returns:
        A ``pandas.DataFrame`` with one row per ``(minibatch_size, k)`` pair.
    """
    br_snis_estimations = {}
    n_rep = minimal_n_replications

    for minibatch_size in [64, 128, 256, 512]:
        k = budget // minibatch_size
        if k <= 1:
            break
        # Particles actually consumed by one chain. This equals `budget` only
        # when `budget` is a multiple of `minibatch_size`; reshaping with
        # `budget` in the other case raises because the window is smaller.
        chain_length = k * minibatch_size
        n_chains_br_snis = min(n_rep, n_chains)

        # Full batches first, then one smaller batch with the leftover chains.
        n_full_batches, n_leftover_chains = divmod(n_rep, n_chains_br_snis)
        chains_per_batch = [n_chains_br_snis] * n_full_batches
        if n_leftover_chains > 0:
            chains_per_batch.append(n_leftover_chains)

        batch_estimations = []
        for batch_chains in tqdm(chains_per_batch,
                                 desc=f'BR SNIS N={minibatch_size + 1}, k={k}, n_chains={n_chains_br_snis}'):
            total_size_per_batch = chain_length * batch_chains

            # Contiguous window of pre-computed particles at a random offset.
            start = randint(high=all_f_values.shape[0] - total_size_per_batch, size=(1,)).item()
            perm = np.arange(start, start + total_size_per_batch)

            estimations = br_snis(k_max=k,
                                  n_particles=minibatch_size,
                                  f_values=all_f_values[perm].reshape(batch_chains, chain_length, 1).to(device),
                                  log_weights=all_log_weights[perm].reshape(batch_chains, chain_length).to(device),
                                  n_bootstrap=k).cpu()
            batch_estimations.append(estimations[:, :, :, 0].mean(dim=-1))
        # Concatenate the batches along the chain axis.
        br_snis_estimations[(minibatch_size, k)] = cat(batch_estimations, dim=1)

    results_br_snis = []
    for (minibatch_size, k), estimations in br_snis_estimations.items():
        # Average the last 10% (rounded up) of the entries along axis 0.
        estimate = estimations[-ceil(.1*k):].mean(dim=0)
        results_br_snis.append({"N": minibatch_size + 1,
                                "k": k,
                                "algorithm": "BR-SNIS",
                                "bias": estimate.mean().item() - ref_value,
                                "std deviation": estimate.std().item(),
                                "replications": estimations.shape[-1],
                                "average M": budget})
    return pd.DataFrame.from_records(results_br_snis)

## Generating the dataset for the mixture of Gaussians example

In [8]:
# Sampling plan: n_total_particles proposal draws, generated n_batch at a time.
n_total_particles = 10_000_000_000
n_batch = 500_000

device = 'cuda:6'  # set to 'cpu' to run without a GPU
target, proposal = get_toy_problem_distributions(device=device)
f = build_f(device=device)

all_log_weights, all_f_values = [], []
with no_grad():
    for i in tqdm(range(n_total_particles // n_batch), desc="Generating samples"):
        particles = proposal.sample((n_batch,))
        f_values = f(particles)
        log_weights = target.log_prob(particles) - proposal.log_prob(particles)

        # Only the log-weights and test-function values are kept, on the CPU.
        all_f_values.append(f_values.squeeze(0).cpu())
        all_log_weights.append(log_weights.squeeze(0).cpu())

# Drop the last on-device batch before releasing the cached GPU memory.
del particles, log_weights, f_values
if device != 'cpu':
    from torch.cuda import empty_cache
    empty_cache()

# Merge the per-batch chunks into two flat tensors.
all_log_weights = cat(all_log_weights).cpu()
all_f_values = cat(all_f_values).cpu()

Generating samples: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [02:58<00:00, 112.21it/s]


### Calculating reference value by using SNIS with all the particles

In [9]:
# Reference value: a single SNIS estimate that uses every generated particle
# (one row containing all the values).
ref_value = snis(
    f_values=all_f_values.reshape(1, -1, 1),
    log_weights=all_log_weights.reshape(1, -1),
).item()

## Running all algorithms

In [45]:
# For each particle budget, run SNIS, BR-SNIS and both Unbiased-PIMH variants
# (hard and soft budget) and collect the summary tables, keyed by budget.
by_budget_results = {}
minimal_n_replications = 1024
for budget in [2**9, 2**13, 2**16]:
    results_snis = build_snis_estimates_for_budget(
        budget=budget,
        minimal_n_replications=minimal_n_replications,
        all_f_values=all_f_values,
        all_log_weights=all_log_weights,
        ref_value=ref_value,
        device=device,
    )

    # All replications are run as chains of a single BR-SNIS batch.
    results_br_snis = build_br_snis_estimates_for_budget(
        budget=budget,
        minimal_n_replications=minimal_n_replications,
        all_f_values=all_f_values,
        all_log_weights=all_log_weights,
        ref_value=ref_value,
        n_chains=minimal_n_replications,
        device=device,
    )

    results_hard_unbiased_pimh = build_unbiased_pimh_estimates_for_budget(
        budget,
        minimal_n_replications,
        target=target,
        proposal=proposal,
        h=f,
        ref_value=ref_value,
        hard=True,
        plot=False,
    )
    results_soft_unbiased_pimh = build_unbiased_pimh_estimates_for_budget(
        budget,
        minimal_n_replications,
        target=target,
        proposal=proposal,
        h=f,
        ref_value=ref_value,
        hard=False,
        plot=False,
    )

    per_algorithm_tables = [results_snis,
                            results_br_snis,
                            results_hard_unbiased_pimh,
                            results_soft_unbiased_pimh]
    by_budget_results[budget] = pd.concat(per_algorithm_tables, axis=0)
    print(by_budget_results[budget].to_string())


BR SNIS N=65, k=8, n_chains=1024: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 81.08it/s]
BR SNIS N=129, k=4, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 170.56it/s]
BR SNIS N=257, k=2, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 252.47it/s]
PIMH Unbiased-PIMH HARD 128: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [00:08<00:00, 122.02it/s]
PIMH Unbiased-PIMH HARD 64: 100%|███████████████████████████████████

     N           algorithm      bias  std deviation  replications   average M    k  fails
0  512                SNIS -0.147887       0.231681          1024  512.000000  NaN    NaN
0   65             BR-SNIS -0.144086       0.257288          1024  512.000000  8.0    NaN
1  129             BR-SNIS -0.147218       0.255165          1024  512.000000  4.0    NaN
2  257             BR-SNIS -0.140107       0.259455          1024  512.000000  2.0    NaN
0  128  Unbiased-PIMH Hard -0.200403       0.207071           834  312.750000  NaN  190.0
1   64  Unbiased-PIMH Hard -0.195070       0.306194           953  374.437500  NaN   71.0
2   32  Unbiased-PIMH Hard -0.189756       0.454365           984  402.250000  NaN   40.0
3   16  Unbiased-PIMH Hard -0.187309       0.537857          1006  429.000000  NaN   18.0
0  128  Unbiased-PIMH Soft  0.148017       1.789264          1024  870.250000  NaN    0.0
1   64  Unbiased-PIMH Soft  0.123835       3.172055          1024  753.750000  NaN    0.0
2   32  Un

BR SNIS N=65, k=128, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.84it/s]
BR SNIS N=129, k=64, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.89it/s]
BR SNIS N=257, k=32, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.35it/s]
BR SNIS N=513, k=16, n_chains=1024: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.07it/s]
PIMH Unbiased-PIMH HARD 2048: 100%|█████████████████████████████████

      N           algorithm      bias  std deviation  replications  average M      k  fails
0  8192                SNIS -0.032600       0.148053          1024    8192.00    NaN    NaN
0    65             BR-SNIS -0.004890       0.176529          1024    8192.00  128.0    NaN
1   129             BR-SNIS -0.003195       0.182782          1024    8192.00   64.0    NaN
2   257             BR-SNIS -0.015152       0.177535          1024    8192.00   32.0    NaN
3   513             BR-SNIS  0.002019       0.192616          1024    8192.00   16.0    NaN
0  2048  Unbiased-PIMH Hard -0.054032       0.209503           967    6280.00    NaN   57.0
1  1024  Unbiased-PIMH Hard -0.039495       0.308501          1000    6822.00    NaN   24.0
2   512  Unbiased-PIMH Hard -0.048938       0.326717          1018    7317.00    NaN    6.0
3   256  Unbiased-PIMH Hard -0.028220       0.396899          1021    7485.75    NaN    3.0
0  2048  Unbiased-PIMH Soft  0.048879       0.471244          1024   10256.00   

BR SNIS N=65, k=1024, n_chains=1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:19<00:00, 19.38s/it]
BR SNIS N=129, k=512, n_chains=1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.22s/it]
BR SNIS N=257, k=256, n_chains=1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.51s/it]
BR SNIS N=513, k=128, n_chains=1024: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.32s/it]
PIMH Unbiased-PIMH HARD 16384: 100%|████████████████████████████████

       N           algorithm      bias  std deviation  replications  average M       k  fails
0  65536                SNIS -0.004101       0.065690          1024    65536.0     NaN    NaN
0     65             BR-SNIS -0.000468       0.065383          1024    65536.0  1024.0    NaN
1    129             BR-SNIS -0.000112       0.065356          1024    65536.0   512.0    NaN
2    257             BR-SNIS -0.000666       0.066600          1024    65536.0   256.0    NaN
3    513             BR-SNIS -0.000179       0.066557          1024    65536.0   128.0    NaN
0  16384  Unbiased-PIMH Hard -0.002150       0.086162          1019    57824.0     NaN    5.0
1   8192  Unbiased-PIMH Hard -0.002680       0.093069          1024    59112.0     NaN    0.0
2   4096  Unbiased-PIMH Hard  0.004311       0.114922          1024    61776.0     NaN    0.0
3   2048  Unbiased-PIMH Hard -0.000146       0.119510          1024    63198.0     NaN    0.0
0  16384  Unbiased-PIMH Soft  0.002075       0.090535       


