In [1]:
import matplotlib.pyplot as plt

In [2]:
from posteriordb import PosteriorDatabase

pdb_path = "/workspaces/jupyter-data/posteriordb/posterior_database/"
my_pdb = PosteriorDatabase(pdb_path)
model_name = "diamonds"
model = my_pdb.model(model_name)
data = my_pdb.data(model_name)

In [4]:
print(model.stan_code())

// generated with brms 2.10.0

functions {
}
data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  // temporary intercept for centered predictors
  real Intercept;
  real<lower=0> sigma;  // residual SD
}
transformed parameters {
}
model {
  // priors including all constants
  target += normal_lpdf(b | 0, 1);
  target += student_t_lpdf(Intercept | 3, 8, 10);
  target += student_t_lpdf(sigma | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  // likelihood includ

In [3]:
import torch

N = torch.tensor(data.values()["N"])
k = torch.tensor(data.values()["K"])
y = torch.tensor(data.values()["Y"])
x = (
    torch.tensor(data.values()["X"]) - torch.tensor(data.values()["X"]).mean(dim=0)
)[:, 1:]

In [4]:
def p(alpha, beta, sigma):
    return (
        dist.Independent(dist.Normal(torch.zeros(k-1), torch.ones(k-1)), 1).log_prob(beta) +
        dist.StudentT(3, 8, 10).log_prob(alpha) +
        dist.StudentT(3, 0, 10).log_prob(sigma) - torch.log(torch.tensor(0.5)) +
        dist.Independent(dist.Normal((alpha + x @ beta.T).T, sigma.unsqueeze(1)), 1).log_prob(y)
    )

In [7]:
import torch
import torch.distributions as dist
import flowtorch.bijectors
import flowtorch.params

flow = flowtorch.bijectors.AffineAutoregressive(
    flowtorch.params.DenseAutoregressive(
        hidden_dims=(32,32,32)
    )
)
base_dist = dist.Normal(torch.zeros(k+2), torch.ones(k+2))
new_dist, params = flow(base_dist)

In [8]:
import torch.optim as optim
import torch.nn as nn
from tqdm.auto import tqdm

optimizer = optim.Adam(params.parameters(), lr=1e-2)

n = 1000
for _ in tqdm(range(int(1e3))):
    optimizer.zero_grad()
    samples = new_dist.rsample((n,))
    log_q = new_dist.log_prob(samples)

    alpha = samples[:,0]
    beta = samples[:,1:k]
    sigma = samples[:,k].exp()
    log_p = p(alpha, beta, sigma) #+ samples[:,k] * sigma
    loss = (log_q - log_p).mean()
    loss.backward(retain_graph=True)
    optimizer.step()

    marginal_lik = torch.logsumexp(log_p - log_q - torch.log(torch.tensor([n])), dim=0)
    tqdm.write(f"ELBO: {-loss}, log p(y): {marginal_lik}", end='')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2744.085205078125, log p(y): 2855.183349609375555


In [6]:
import torch
import torch.distributions as dist
import flowtorch.bijectors
import flowtorch.params
import torch.optim as optim
import torch.nn as nn
from tqdm.auto import tqdm

data = []

for trial in range(100):
    for num_layers in range(1,5):
        flow = flowtorch.bijectors.AffineAutoregressive(
            flowtorch.params.DenseAutoregressive(
                hidden_dims=(32,)*num_layers
            )
        )
        base_dist = dist.Normal(torch.zeros(k+2), torch.ones(k+2))
        new_dist, params = flow(base_dist)
        optimizer = optim.Adam(params.parameters(), lr=1e-2)

        n = 1000
        for _ in tqdm(range(int(1e3))):
            optimizer.zero_grad()
            samples = new_dist.rsample((n,))
            if torch.isnan(samples).any(): continue

            log_q = new_dist.log_prob(samples)

            alpha = samples[:,0]
            beta = samples[:,1:k]
            sigma = samples[:,k].exp()
            log_p = p(alpha, beta, sigma) #+ samples[:,k] * sigma
            loss = (log_q - log_p).mean()
            loss.backward(retain_graph=True)
            optimizer.step()

            marginal_lik = torch.logsumexp(log_p - log_q - torch.log(torch.tensor([n])), dim=0)
            tqdm.write(f"ELBO: {-loss}, log p(y): {marginal_lik}", end='')

        if marginal_lik < 2500: continue
        data.append({
            'method': 'ADVI',
            'num_layers': num_layers,
            'elbo': -loss.item(),
            'log_p_y': marginal_lik.item(),
        })

        # TAF, train just DF
        dfs = nn.Parameter(data=torch.ones(1)*10)
        new_dist.base_dist = dist.StudentT(df=dfs*torch.ones(k+2), loc=torch.zeros(k+2), scale=torch.ones(k+2))
        optimizer = optim.Adam([dfs], lr=1e-2)
        for _ in tqdm(range(int(1e2))):
            optimizer.zero_grad()
            samples = new_dist.rsample((n,))
            if torch.isnan(samples).any(): continue
            log_q = new_dist.log_prob(samples)

            alpha = samples[:,0]
            beta = samples[:,1:k]
            sigma = samples[:,k].exp()
            log_p = p(alpha, beta, sigma) #+ samples[:,k] * sigma
            loss = (log_q - log_p).mean()
            loss.backward(retain_graph=True)
            optimizer.step()

            marginal_lik = torch.logsumexp(log_p - log_q - torch.log(torch.tensor([n])), dim=0)
            tqdm.write(f"ELBO: {-loss}, log p(y): {marginal_lik}", end='')

        data.append({
            'method': 'TAF',
            'num_layers': num_layers,
            'elbo': -loss.item(),
            'log_p_y': marginal_lik.item(),
        })

        # TAF, train just DF
        dfs = nn.Parameter(data=dfs.data*torch.ones(k+2))
        new_dist.base_dist = dist.StudentT(df=dfs, loc=torch.zeros(k+2), scale=torch.ones(k+2))
        optimizer = optim.Adam([dfs], lr=1e-2)
        for _ in tqdm(range(int(1e2))):
            optimizer.zero_grad()
            samples = new_dist.rsample((n,))
            if torch.isnan(samples).any(): continue
            log_q = new_dist.log_prob(samples)

            alpha = samples[:,0]
            beta = samples[:,1:k]
            sigma = samples[:,k].exp()
            log_p = p(alpha, beta, sigma) #+ samples[:,k] * sigma
            loss = (log_q - log_p).mean()
            loss.backward(retain_graph=True)
            optimizer.step()

            marginal_lik = torch.logsumexp(log_p - log_q - torch.log(torch.tensor([n])), dim=0)
            tqdm.write(f"ELBO: {-loss}, log p(y): {marginal_lik}", end='')

        data.append({
            'method': 'ATAF',
            'num_layers': num_layers,
            'elbo': -loss.item(),
            'log_p_y': marginal_lik.item(),
        })
        
pickle.dump(data, open("diamonds.pkl", "wb"))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 856.3377685546875, log p(y): 2562.155517578125575


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 1004.0053100585938, log p(y): 2740.8449707031255


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 535.5269775390625, log p(y): 2807.1596679687555


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2923.075927734375, log p(y): 2981.962158203125552


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2904.009033203125, log p(y): 2979.973388671875


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2905.238525390625, log p(y): 2981.435302734375


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: -4923.806640625, log p(y): 2667.7080078125755555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: -5033.8056640625, log p(y): 2731.0578613281255


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: -4963.328125, log p(y): 2645.75927734375203125


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: -4381.68359375, log p(y): 2513.9897460937587555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: -4744.85498046875, log p(y): 2732.173583984375


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: -4553.74853515625, log p(y): 2667.333984375755


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2639.10205078125, log p(y): 2900.3315429687575555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2580.299072265625, log p(y): 2893.175048828125


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2557.64111328125, log p(y): 2891.2597656253125


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 1131.0946044921875, log p(y): 1798.27490234375555


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2303.601318359375, log p(y): 2611.1179199218755552


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2309.259033203125, log p(y): 2672.373291015625


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2306.002685546875, log p(y): 2668.589355468755


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2414.762939453125, log p(y): 2757.3920898437555844


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2513.810546875, log p(y): 2820.873779296875525


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2518.125244140625, log p(y): 2839.578125256255


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: -7053.357421875, log p(y): 46.109443664550788868


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2824.724609375, log p(y): 2957.505859375093755555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2775.265380859375, log p(y): 2967.718505859375


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2780.264892578125, log p(y): 2972.066894531255


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2524.40478515625, log p(y): 2812.5173339843755555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2626.662841796875, log p(y): 2872.464599609375


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2616.268798828125, log p(y): 2862.489013671875


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 1761.2716064453125, log p(y): 2606.811767578125526


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 1883.1842041015625, log p(y): 2655.429199218755


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 1843.176025390625, log p(y): 2635.2065429687525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2590.58935546875, log p(y): 2816.9677734375875555


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2498.173095703125, log p(y): 2834.150878906255


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

ELBO: 2487.996337890625, log p(y): 2899.777099609375


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

ELBO: 2627.158935546875, log p(y): 2817.402832031255555


KeyboardInterrupt: 

In [27]:
import pickle
#pickle.dump(data, open("diamonds.pkl", "wb"))
data = pickle.load(open("diamonds.pkl", "rb"))

In [28]:
import pandas as pd
import seaborn as sns
sns.set_style("darkgrid")

df = pd.DataFrame(data).melt(id_vars=['method','num_layers'])

In [29]:
df

Unnamed: 0,method,num_layers,variable,value
0,ADVI,1,elbo,856.337769
1,TAF,1,elbo,1004.005310
2,ATAF,1,elbo,535.526978
3,ADVI,2,elbo,2923.075928
4,TAF,2,elbo,2904.009033
...,...,...,...,...
61,TAF,4,log_p_y,2655.429199
62,ATAF,4,log_p_y,2635.206543
63,ADVI,1,log_p_y,2816.967773
64,TAF,1,log_p_y,2834.150879


In [30]:
df.groupby(['method', 'num_layers', 'variable']).describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,value,value,value,value,value,value,value,value
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,count,mean,std,min,25%,50%,75%,max
method,num_layers,variable,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
ADVI,1,elbo,3.0,2028.676392,1015.564747,856.337769,1723.463562,2590.589355,2614.845703,2639.102051
ADVI,1,log_p_y,3.0,2759.818278,176.182589,2562.155518,2689.561646,2816.967773,2858.649658,2900.331543
ADVI,2,elbo,2.0,2873.900269,69.544884,2824.724609,2849.312439,2873.900269,2898.488098,2923.075928
ADVI,2,log_p_y,2.0,2969.734009,17.293215,2957.505859,2963.619934,2969.734009,2975.848083,2981.962158
ADVI,3,elbo,3.0,-31.933512,4237.924677,-4923.806641,-1310.102661,2303.601318,2414.003052,2524.404785
ADVI,3,log_p_y,3.0,2697.114421,103.870033,2611.11792,2639.412964,2667.708008,2740.112671,2812.517334
ADVI,4,elbo,3.0,-68.549683,3749.547417,-4381.683594,-1310.205994,1761.271606,2088.017273,2414.762939
ADVI,4,log_p_y,3.0,2626.064535,122.838011,2513.989746,2560.400757,2606.811768,2682.101929,2757.39209
ATAF,1,elbo,3.0,1860.388143,1147.891732,535.526978,1511.761658,2487.996338,2522.818726,2557.641113
ATAF,1,log_p_y,3.0,2866.065511,51.191406,2807.159668,2849.209717,2891.259766,2895.518433,2899.7771


In [32]:
import numpy as np

gs = my_pdb.posterior("diamonds-diamonds").reference_draws()
gs_dict = {}
num_chains = len(gs)
num_samples = len(gs[0][next(iter(gs[0]))])

for i,chain in enumerate(gs):
    for var in chain:
        if '[' not in var:
            if var not in gs_dict:
                gs_dict[var] = np.zeros((num_chains, num_samples))
            gs_dict[var][i,:] = np.array(chain[var])
        else:
            name = var.split('[')[0]
            idx = int(var.split('[')[1].split(']')[0]) - 1
            if name not in gs_dict:
                var_size = len(list(filter(lambda x: x.startswith(name), chain)))
                gs_dict[name] = np.zeros((num_chains,num_samples,var_size))
            gs_dict[name][i,:,idx] = np.array(chain[var])
            
data = my_pdb.data(model_name)
N = torch.tensor(data.values()["N"])
k = torch.tensor(data.values()["K"])
y = torch.tensor(data.values()["Y"])
x = (
    torch.tensor(data.values()["X"]) - torch.tensor(data.values()["X"]).mean(dim=0)
)[:, 1:]

In [33]:
gs_dict.keys()

dict_keys(['b', 'Intercept', 'sigma'])

In [38]:
list(map(lambda x: x.shape, [
    torch.tensor(gs_dict['Intercept']),
    torch.tensor(gs_dict['b']),
    torch.tensor(gs_dict['sigma'])
]))

[torch.Size([10, 1000]), torch.Size([10, 1000, 24]), torch.Size([10, 1000])]

In [40]:
import torch.distributions as dist
            
#log_p = p(alpha, beta, sigma) #+ samples[:,k] * sigma
ps = p(
    torch.tensor(gs_dict['Intercept']).flatten(end_dim=1).float(),
    torch.tensor(gs_dict['b']).flatten(end_dim=1).float(),
    torch.tensor(gs_dict['sigma']).flatten(end_dim=1).float(),
)
print(ps.mean(), ps.std() / np.sqrt(ps.shape[0]))

tensor(3274.5872) tensor(0.0366)


In [42]:
y.shape

torch.Size([5000])

In [36]:
p_gs.shape

NameError: name 'mu' is not defined

In [51]:
torch.tensor(gs_dict['tau']).flatten(end_dim=1).outer(torch.ones(3)).shape

torch.Size([10000, 3])

In [38]:
gs_dict['tau'].shape

(10, 1000)

In [39]:
gs_dict['theta'].shape

(10, 1000, 8)