In [1]:
# Python Path non-sense
%load_ext autoreload
%autoreload 2

import sys
import os

sys.path = [x for x in sys.path if 'bayes_gsl' not in x]
new_path = '/Users/maxw/projects/gsl-bnn/' ## change this to the path of the src directory!
if new_path not in sys.path:
    sys.path.append(new_path)

# Now try importing your module using the absolute path as a check
from src.models import dpg_bnn


In [2]:
import numpy as np
import networkx as nx
import time
import os
import matplotlib.pyplot as plt
import pickle
from scipy.linalg import block_diag

import jax
import jax.numpy as jnp
import jax.random
from jax.random import PRNGKey
from jax import random as jax_random

import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro.infer import init_to_value
import matplotlib_inline

# Global matplotlib look for every figure in this notebook.
plt.style.use("bmh")

# When the notebook is rendered by a Sphinx docs build (signalled by NUMPYRO_SPHINXBUILD),
# emit inline figures as SVG instead of the default raster format.
if os.environ.get("NUMPYRO_SPHINXBUILD") is not None:
    matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide

from src.utils import degrees_from_upper_tri
from src.iterates_and_unroll import unroll_dpg
from src.utils import adj2vec, vec2adj, edge_density
from src.models import dpg_bnn

from src.config import w_init_scale, lam_init_scale, altered_prior

from src.metrics import compute_metrics

### Load Synthetic Data

In [3]:
from src import SYNTHETIC_DATA_ROOT

# Parameters that select which pre-generated dataset file is loaded (they only build the filename).
graph_size, graph_radius, graph_dim = 20, 0.5, 2
graph_distribution = SYNTHETIC_DATA_ROOT + f"RG_N={graph_size}_r={graph_radius}_dim={graph_dim}.pt"
# inf -> use the 'expected' (analytic) Euclidean distance matrices; a finite value selects the
# entry keyed by str(num_signals).
num_signals = float('inf')
print(f"Loading Data: {graph_distribution}")
# The context manager closes the file handle deterministically.
# NOTE: pickle.load can execute arbitrary code -- only load dataset files you generated or trust.
with open(graph_distribution, "rb") as data_file:
    data_dict = pickle.load(data_file)

# data is pairs of adjacency matrices and euclidean distance matrices
adjacencies = data_dict['adjacencies'].astype(np.float32)
dataset_key = 'expected' if num_signals == float('inf') else str(num_signals)
euclidean_distance_matrices = data_dict[dataset_key]

# Convert matrices to edge vectors (presumably the strict upper triangle, given that
# num_edges = n(n-1)/2 below -- see src.utils.adj2vec).
adjacencies = adj2vec(adjacencies)
euclidean_distance_matrices = adj2vec(euclidean_distance_matrices)
num_edges = adjacencies.shape[-1]

# Concatenate features and targets so both are split with the same row indices.
data = np.concatenate([euclidean_distance_matrices, adjacencies], axis=1)

# Predetermined train/val/test split; every split takes exactly its requested number of rows.
num_train, num_val, num_test = 50, 50, 100
val_end = num_train + num_val
test_end = val_end + num_test
assert data.shape[0] >= test_end, (
    f"split needs {test_end} samples but the dataset only has {data.shape[0]}")
train, val, test = data[:num_train], data[num_train:val_end], data[val_end:test_end]
data = {"train": (train[:, :num_edges], train[:, num_edges:]),
        "val": (val[:, :num_edges], val[:, num_edges:]),
        "test": (test[:, :num_edges], test[:, num_edges:])}

# Unpack training data and recover the node count by inverting num_edges = n(n-1)/2.
x_total, y_total = data['train']
num_edges = x_total.shape[-1]
n = int(0.5 * (np.sqrt(8 * num_edges + 1) + 1))
assert n * (n - 1) // 2 == num_edges, (
    f"{num_edges} edge entries do not correspond to a whole number of nodes")

num_train_samples_ = 50
x_total, y_total = jnp.array(x_total), jnp.array(y_total)
x, y = x_total[:num_train_samples_], y_total[:num_train_samples_]

Loading Data: /Users/maxw/projects/gsl-bnn/data/synthetic/RG_N=20_r=0.5_dim=2.pt


### Configure DPG

In [4]:
# Passed to the model as `depth`; presumably the number of unrolled DPG iterations
# (cf. src.iterates_and_unroll.unroll_dpg) -- TODO confirm.
depth = 30

# Derive the batch size from the training inputs so the initial iterates always have one row per
# sample in x, even if the number of training samples is changed in the data-loading cell.
num_train_samples_ = x.shape[0]

# Constant initial iterates: one row of edge weights (num_edges) and one row of
# per-node values (n) for each training sample.
w_init = w_init_scale * jnp.ones((num_train_samples_, num_edges))
lam_init = lam_init_scale * jnp.ones((num_train_samples_, n))

S = jnp.array(degrees_from_upper_tri(n))
model = dpg_bnn.model
model_args = {'x': x, 'y': y,
              'depth': depth,
              'w_init': w_init, 'lam_init': lam_init,
              'S': S,
              'dummy': False,
              'prior_settings': altered_prior}  # priors for model parameters

### Run Inference

In [5]:
import multiprocessing

In [6]:
# Parallel chains need one XLA device per chain. XLA reads
# --xla_force_host_platform_device_count from XLA_FLAGS only when the JAX backend is first
# initialised, which already happened in earlier cells (jnp.array / jnp.ones). Setting the flag
# at this point therefore cannot change the device count. It must be set before the first JAX
# computation, e.g. at the very top of the notebook, followed by a kernel restart.
# Blackjax tutorial: https://blackjax-devs.github.io/blackjax/examples/howto_sample_multiple_chains.html
num_cpu_cores = multiprocessing.cpu_count()
num_jax_devices = jax.local_device_count()
print("Number of JAX devices:", num_jax_devices)
if jax.default_backend() == "cpu" and num_jax_devices < num_cpu_cores:
    print(f"Warning: only {num_jax_devices} of {num_cpu_cores} CPU cores are exposed as JAX devices; "
          "NumPyro will draw chains sequentially. Set XLA_FLAGS before the first JAX computation "
          "and restart the kernel to run chains in parallel.")

Number of CPU cores: 1


In [7]:
print(f'\n\n********** Running Inference: depth={depth}, num_train_samples={num_train_samples_} **********')

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = jax_random.PRNGKey(0)
rng_key, rng_key_ = jax_random.split(rng_key)
kernel = NUTS(model, forward_mode_differentiation=True)
num_chains, num_warmup_samples, num_samples = 2, 100, 100
# 'parallel' needs one JAX device per chain. With fewer devices, request 'sequential' explicitly
# instead of relying on NumPyro's warn-and-fall-back path (the sampling result is the same).
chain_method = 'parallel' if jax.local_device_count() >= num_chains else 'sequential'
mcmc = MCMC(kernel,
            num_warmup=num_warmup_samples, num_samples=num_samples,
            progress_bar=True,
            num_chains=num_chains, chain_method=chain_method)
# perf_counter is the monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()
mcmc.run(rng_key_, **model_args)
end_time = time.perf_counter()
print(f"Time taken for inference using {num_chains} chains with {num_warmup_samples} warmup samples and {num_samples} samples: {end_time - start_time}")
mcmc.print_summary()
print("^^********** Finished  **********^^\n\n")
# Posterior draws; chains are merged by default, so each site has num_chains * num_samples rows.
samples = mcmc.get_samples()



********** Running Inference: depth=30, num_train_samples=50 **********


  mcmc = MCMC(kernel,
sample: 100%|██████████| 200/200 [00:14<00:00, 13.90it/s, 15 steps of size 4.99e-03. acc. prob=0.98] 
sample: 100%|██████████| 200/200 [00:11<00:00, 16.85it/s, 31 steps of size 3.24e-03. acc. prob=0.97]


Time taken for inference using 2 with 100 warmup samples and 100 samples: 27.661314249038696

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b     11.93      0.68     11.93     10.86     13.13     55.33      0.99
     delta    128.03      5.25    127.77    117.39    135.51     53.56      1.00
     theta      0.14      0.00      0.14      0.14      0.15     74.00      1.00

Number of divergences: 0
^^********** Finished  **********^^




In [8]:
num_test_samples = 100
x_total_test, y_total_test = data['test']
# Recover the node count by inverting num_edges = n(n-1)/2.
n, num_edges = int(0.5*(np.sqrt(8 * x_total_test.shape[-1] + 1) + 1)), x_total_test.shape[-1]
# The forward pass reuses the training S, so test graphs must have the same size.
assert S.shape == (n, num_edges), (
    f"test graphs imply S of shape {(n, num_edges)}, but S built for training has shape {S.shape}")
x_total_test, y_total_test = jnp.array(x_total_test), jnp.array(y_total_test)
x_test, y_test = x_total_test[:num_test_samples], y_total_test[:num_test_samples]
# Size the initial iterates from the rows actually selected, which can be fewer than
# num_test_samples if the test split is smaller, so they always line up with x_test.
num_test_rows = x_test.shape[0]
w_test = w_init_scale * jnp.ones((num_test_rows, num_edges))
lam_test = lam_init_scale * jnp.ones((num_test_rows, n))


# print the shapes of all the inputs to the model
for label, array in [('samples[theta]', samples["theta"]),
                     ('samples[delta]', samples["delta"]),
                     ('samples[b]', samples["b"]),
                     ('x_test', x_test),
                     ('w_test', w_test),
                     ('lam_test', lam_test),
                     ('S', S)]:
    print(f'{label}.shape: {array.shape}')

samples[theta].shape: (200,)
samples[delta].shape: (200,)
samples[b].shape: (200,)
x_test.shape: (100, 190)
w_test.shape: (100, 190)
lam_test.shape: (100, 20)
S.shape: (20, 190)


In [9]:
# Build the forward-pass callable once. Judging by the name and by the next cell, which passes
# whole arrays of posterior draws (theta, delta, b), it is presumably vmapped over posterior
# samples -- TODO confirm against src.models.dpg_bnn.forward_pass_vmap.
dpg_bnn_forward_pass = dpg_bnn.forward_pass_vmap()

In [10]:
# Evaluate the DPG-BNN on the test graphs for every posterior draw. Argument order is
# positional: posterior parameters (theta, delta, b), then test inputs and initial iterates,
# then the unroll depth and S.
edge_logits = dpg_bnn_forward_pass(
    *(samples[site] for site in ("theta", "delta", "b")),
    x_test, w_test, lam_test,
    depth, S,
)

In [11]:
from src import NUM_BINS

# Score the test-set edge predictions; NUM_BINS is passed for the calibration (ECE) computation.
metrics_dict = compute_metrics(edge_logits, y_test, NUM_BINS)
calibration_dict = metrics_dict['calibration_dict']

# '\\pm' prints a literal "\pm" (LaTeX-ready). A bare '\p' in a normal string is an invalid
# escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
print(f'Test Error: {1 - metrics_dict["accuracies"].mean():.5f} \\pm {metrics_dict["accuracies"].std():.5f}')
# NLL is reported as the negated mean log-likelihood.
print(f'Test NLL: {-1 * metrics_dict["log_likelihoods"].mean():.3f} \\pm {metrics_dict["log_likelihoods"].std():.3f}')
print(f'Test BS: {metrics_dict["brier_scores"].mean():.5f} \\pm {metrics_dict["brier_scores"].std():.5f}')
print(f'Test ECE:{calibration_dict["ece"]:.5f}')

Test Error: 0.01363 \pm 0.01274
Test NLL: 10.161 \pm 9.809
Test BS: 0.01129 \pm 0.00940
Test ECE:0.00246
