# Degree-Corrected Poisson Random Geometric Graph (RGG) Inference Using Pyro

In [None]:
# Re-import edited modules before each cell runs (autoreload mode 2 reloads all
# modules) and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Make helper modules under ../utils importable (presumably the local `graphs`
# and `plot` modules used below -- TODO confirm where they live).
import sys
sys.path.append('../utils')

In [None]:
# Apply the project's default plotting style from porpoise.
from porpoise.plot import Style
Style.from_default().apply()

In [None]:
from graphs import (get_independent_components_rgg, make_inter_vertex_distances, 
                    deg_corrected_poissonian_random_geometric_graph)

In [None]:
from plot import plot_multigraph
from networkx.linalg.graphmatrix import adjacency_matrix
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import math
from time import time
import pandas as pd

In [None]:
# Experiment configuration: RNG seed, number of nodes, spatial dimension.
# NOTE(review): `dim` is not read by any visible cell.
SEED, n, dim = 42, 10, 2

In [None]:
lambda_r_truth = 2


def exponential_kernel(dist, ki, kj):
    """Return the ground-truth Poisson rate for a node pair.

    rate = ki * kj * exp(-lambda_r_truth * dist)

    Args:
        dist: distance between the two nodes.
        ki, kj: per-node multiplicative rate factors (kolness) of the pair.
    """
    # lambda_r_truth is read from module scope at call time, exactly as the
    # former lambda did.
    return ki * kj * math.exp(-lambda_r_truth * dist)

In [None]:
# Ground-truth per-node rate factors ("kolness"): every node gets 1.0 except
# node 0, whose factor is 10x larger.
kolness_truth = np.full(n, 1.0)
kolness_truth[0] = 10.0

In [None]:
# Draw the observed multigraph from the degree-corrected Poisson RGG generator
# with a seeded RNG so the run is reproducible.
# NOTE(review): the meaning of the positional 1000 is defined by the helper in
# `graphs` -- TODO confirm and name it.
r = np.random.RandomState(SEED)
G_poisson = deg_corrected_poissonian_random_geometric_graph(n, 1000, kolness_truth, exponential_kernel, r)

In [None]:
# Dense adjacency matrix of the sampled graph (sparse result densified).
# Presumably entries count parallel edges of the multigraph, in G_poisson's
# node order -- TODO confirm it matches the ordering of `distances`.
adj = adjacency_matrix(G_poisson).toarray()

In [None]:
# Visualise the sampled multigraph with the local `plot` helper.
ax = plot_multigraph(G_poisson)

In [None]:
# Pairwise vertex-distance matrix from the `graphs` helper; the model below
# expects it to be (n_nodes, n_nodes) and aligned with `adj` -- TODO confirm ordering.
distances = make_inter_vertex_distances(G_poisson)

## Inference

In [None]:
import pyro.distributions as dist
import pyro
from torch.distributions import constraints
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import torch

In [None]:
# Copy the distance matrix into a torch tensor for the Pyro model (torch.tensor
# copies the data; dtype is inferred from the input, presumably float64).
distances = torch.tensor(distances)
distances

In [None]:
# Copy the dense numpy adjacency matrix (from `.toarray()` above) into a torch
# tensor; it is used as the Poisson observations in the model.
adj = torch.tensor(adj)
adj

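The likelihood of the observed multigraph $G$ (edge counts $A_{ij}$), given pairwise distances $D$, kernel parameters $\theta$ and per-node factors $k$, is a product of independent Poisson terms over unordered node pairs:

$$P(G \mid D, \theta, k) = \prod_{i<j} \frac{\left(k_i k_j \mathcal{F}(D_{ij}, \theta)\right)^{A_{ij}}}{A_{ij}!} \exp\left(-k_i k_j \mathcal{F}(D_{ij}, \theta)\right)$$

In the model below, $\mathcal{F}(D_{ij}, \theta) = e^{-\lambda_r D_{ij}}$ with $\theta = \lambda_r$.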

In [None]:
# Gamma(concentration=1, rate=0.1) is the same distribution as the
# Exponential(0.1) prior the model places on kolness and lambda_r.
g = dist.Gamma(concentration=1, rate=0.1)

In [None]:
# Plot the prior density of `g`. log_prob returns the natural-log density, so
# it is exponentiated with exp() to recover the density itself.
x = np.linspace(0.01, 20, 100)
plt.plot(x, g.log_prob(torch.as_tensor(x)).exp())

In [None]:
def dc_poisson_rgg_model(distances, adj):
    """Degree-corrected Poisson random-geometric-graph model.

    Priors:
        kolness k ~ Exponential(0.1) per node, drawn as ONE vector-valued
            sample site of length n_nodes.
        lambda_r ~ Exponential(0.1), decay rate of the distance kernel.

    Likelihood, for every unordered pair i < j:
        A_ij ~ Poisson(k_i * k_j * exp(-lambda_r * D_ij))

    Args:
        distances: (n_nodes, n_nodes) tensor of pairwise distances D.
        adj: (n_nodes, n_nodes) tensor of edge counts A; only the strict upper
            triangle is observed.
    """
    n_nodes = adj.shape[0]
    # to_event(1) reinterprets the n_nodes batch dimension as an event
    # dimension, so "kolness" is a single vector-valued site with a joint log_prob.
    k = pyro.sample("kolness", dist.Exponential(0.1).expand([n_nodes]).to_event(1))
    lambda_r = pyro.sample("lambda_r", dist.Exponential(0.1))
    conn_kernel = torch.exp(-lambda_r * distances)

    assert k.shape == (n_nodes,)
    assert conn_kernel.shape == (n_nodes, n_nodes)

    # Strict upper triangle: each unordered node pair appears exactly once.
    rows, cols = torch.triu_indices(n_nodes, n_nodes, offset=1)

    # Pairs are conditionally independent given k and lambda_r. Indexing through
    # the plate's indices `i` keeps the model compatible with subsampling.
    with pyro.plate("data_loop", len(rows)) as i:
        rate_i = k[rows[i]] * k[cols[i]] * conn_kernel[rows[i], cols[i]]
        pyro.sample("obs", dist.Poisson(rate_i), obs=adj[rows[i], cols[i]])


## Asymptotically exact inference (MCMC with NUTS)

In [None]:
import arviz as az
from pyro.infer import MCMC, NUTS

In [None]:
# No-U-Turn sampler over the model's continuous latents (kolness, lambda_r).
nuts_kernel = NUTS(model=dc_poisson_rgg_model)

In [None]:
mcmc = MCMC(
    kernel=nuts_kernel,
    num_samples=1000,
    # NOTE(review): Pyro's default warmup equals num_samples; 100 steps may be
    # short for step-size / mass-matrix adaptation.
    warmup_steps=100,
    num_chains=4,
)

In [None]:
# Arguments are forwarded to dc_poisson_rgg_model.
mcmc.run(distances=distances, adj=adj)

In [None]:
# Wrap the MCMC draws in an ArviZ InferenceData for summaries and plots.
inferred = az.from_pyro(posterior=mcmc)

In [None]:
# Posterior summary table (means, HDIs, diagnostics) for all variables.
# To restrict it: var_names=['lambda_r', 'kolness']
summary = az.summary(data=inferred)
summary

In [None]:
# Per-chain trace and marginal density plots; the semicolon hides the repr.
az.plot_trace(data=inferred, var_names=["lambda_r", "kolness"]);

In [None]:
# Posterior marginals with the simulation's true values as reference markers,
# listed lambda_r first, then each kolness entry (the panel order).
az.plot_posterior(
    data=inferred,
    var_names=["lambda_r", "kolness"],
    ref_val=[lambda_r_truth, *kolness_truth],
);
plt.tight_layout()

In [None]:
# Cache the MCMC results to disk so sampling need not be repeated.
az.to_netcdf(inferred, filename='pyro-poisson-rgg.netcdf')

In [None]:
# Reload the cached MCMC results (lets later cells run without re-sampling).
inferred = az.from_netcdf(filename='pyro-poisson-rgg.netcdf')

## MAP estimation

In [None]:
from pyro.infer.autoguide import AutoDelta

In [None]:
# Hand-written MAP guide, kept for reference only; superseded by AutoDelta below.
# NOTE(review): as written it would fail -- pyro.param returns a tensor, and
# `.to_event(1)` is a distribution method (it presumably belongs on
# `dist.Delta(kolness_map)`).
# def dc_poisson_rgg_guide_map(distances, adj):    
#     n_nodes = adj.shape[0]
#     kolness_map = pyro.param("kolness_map", torch.tensor(1).expand([n_nodes]), constraint=constraints.positive).to_event(1)
#     lambda_r_map = pyro.param("lambda_r_map", torch.tensor(1), constraint=constraints.positive)
#     kolness = pyro.sample("kolness", dist.Delta(kolness_map))
#     lambda_r = pyro.sample("lambda_r", dist.Delta(lambda_r_map))

# AutoDelta puts a point-mass (Delta) guide on every latent site, so SVI with it
# gives MAP estimates; its params are named "AutoDelta.<site>".
dc_poisson_rgg_guide_map = AutoDelta(dc_poisson_rgg_model)
    

In [None]:
def train(model, guide, data, lr=0.01, n_steps=1000, *, optim_args=None, log_every=100):
    """Fit `guide` to `model` with SVI (Adam + Trace_ELBO); return the loss history.

    The global Pyro param store is cleared first, so previously fitted
    parameters are discarded.

    Args:
        model: Pyro model callable.
        guide: Pyro guide callable (e.g. AutoDelta or a hand-written guide).
        data: dict of keyword arguments passed to model and guide on every step.
        lr: Adam learning rate; ignored when `optim_args` is given.
        n_steps: number of SVI steps.
        optim_args: optional full Adam argument dict, e.g.
            {"lr": 0.005, "betas": (0.95, 0.999)}; defaults to {"lr": lr}.
        log_every: print the loss every `log_every` steps; 0 or None disables it.

    Returns:
        List of per-step losses (Trace_ELBO loss, i.e. negative ELBO estimates).
    """
    pyro.clear_param_store()
    adam = pyro.optim.Adam(optim_args if optim_args is not None else {"lr": lr})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    losses = []
    for step in range(n_steps):
        loss = svi.step(**data)
        losses.append(loss)
        if log_every and step % log_every == 0:
            print('[iter {}]  loss: {:.4f}'.format(step, loss))

    return losses

In [None]:
# Keyword arguments forwarded to model and guide at every SVI step.
data = dict(distances=distances, adj=adj)

In [None]:
# MAP fit: SVI with the point-mass AutoDelta guide.
losses = train(model=dc_poisson_rgg_model, guide=dc_poisson_rgg_guide_map, data=data)

In [None]:
# Training curve of the MAP fit. Trace_ELBO reports a loss equal to the
# negative ELBO, so the curve should decrease.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Negative ELBO (SVI loss)")
ax.set_xlabel("step")
ax.set_ylabel("loss");

In [None]:
# Show every fitted parameter. items() already yields the constrained values,
# so there is no need to look each one up again with pyro.param.
for name, value in pyro.get_param_store().items():
    print(name, value)

In [None]:
# Table pairing each MAP kolness point estimate with its simulation truth.
map_kolness = pyro.param('AutoDelta.kolness').tolist()
map_estimate = pd.DataFrame(list(zip(map_kolness, kolness_truth)), columns=['MAP', 'truth'])
map_estimate

In [None]:
# MAP estimate vs. truth; points on the diagonal mean perfect recovery.
ax = map_estimate.plot.scatter(x='truth', y='MAP')

---

## Full SVI with a Gamma variational guide (experimental)

In [None]:
from torch.distributions.gamma import Gamma

In [None]:
def dc_poisson_rgg_guide(distances, adj):
    """Mean-field Gamma variational guide for dc_poisson_rgg_model.

    Each kolness entry gets its own Gamma(concentration_i, rate_i), and the
    scalar lambda_r gets one Gamma. All are initialised at concentration=1,
    rate=0.1, i.e. the model's Exponential(0.1) prior.

    Args:
        distances: unused; accepted so the guide matches the model signature.
        adj: only adj.shape[0] is read, to size the per-node parameters.
    """
    n_nodes = adj.shape[0]

    # Per-node parameters; event_dim=1 because "kolness" is a single
    # vector-valued site of length n_nodes.
    concentration_k = pyro.param('concentration_kolness', torch.ones(n_nodes),
                                 constraint=constraints.positive, event_dim=1)
    rate_k = pyro.param('rate_kolness', torch.full((n_nodes,), 0.1),
                        constraint=constraints.positive, event_dim=1)
    pyro.sample("kolness", dist.Gamma(concentration_k, rate_k).to_event(1))

    # lambda_r is a scalar site, so its parameters are 0-dim (no event dims).
    concentration_lr = pyro.param('concentration_lambda_r', torch.tensor(1.0),
                                  constraint=constraints.positive)
    rate_lr = pyro.param('rate_lambda_r', torch.tensor(0.1),
                         constraint=constraints.positive)
    pyro.sample("lambda_r", dist.Gamma(concentration_lr, rate_lr))

In [None]:
# NOTE(review): `adam` and `svi` are not used by the next cell. It calls
# train(...), which clears the param store and constructs its own Adam optimizer
# and SVI object. So lr=0.005 and betas=(0.95, 0.999) below do not affect that
# fit unless they are passed through to train.
adam = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
svi = SVI(dc_poisson_rgg_model, dc_poisson_rgg_guide, adam, loss=Trace_ELBO())

In [None]:
# Full SVI fit with the mean-field Gamma guide.
losses = train(model=dc_poisson_rgg_model, guide=dc_poisson_rgg_guide, data=data)

In [None]:
# Training curve of the Gamma-guide SVI fit. Trace_ELBO reports a loss equal
# to the negative ELBO, so the curve should decrease.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Negative ELBO (SVI loss)")
ax.set_xlabel("step")
ax.set_ylabel("loss");

In [None]:
# Name -> fitted (constrained) parameter value, snapshotted from the param store.
inferred_svi = dict(pyro.get_param_store().items())

In [None]:
# Compare NUTS posterior samples (histograms) with the fitted SVI Gamma
# densities (red curves): one panel per kolness entry plus one for lambda_r,
# three panels per row. The grid is sized from the data, so it works for any n.
n_kolness = len(kolness_truth)
ncols = 3
nrows = math.ceil((n_kolness + 1) / ncols)  # +1 panel for lambda_r
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5 * ncols, 5 * nrows))
axs = axs.ravel()

# Each kolness density is evaluated on [0.01, 5 * true value].
k_sp_upper = kolness_truth * 5

for i in range(n_kolness):
    k_sp = np.linspace(0.01, k_sp_upper[i])
    axs[i].hist(inferred.posterior.data_vars['kolness'].values[:, :, i].ravel(), density=True, bins='auto', label='MCMC')
    g = Gamma(inferred_svi['concentration_kolness'][i], inferred_svi['rate_kolness'][i])
    axs[i].plot(k_sp, np.exp(g.log_prob(k_sp).detach().numpy()), '-r', label='SVI (gamma)')
    axs[i].legend()
    axs[i].set_title(f'KOLness [{i}]')

# lambda_r takes the panel right after the last kolness panel.
lr_sp = np.linspace(0.1, 8)
axs[n_kolness].hist(inferred.posterior.data_vars['lambda_r'].values.ravel(), density=True, bins='auto', label='MCMC')
g = Gamma(inferred_svi['concentration_lambda_r'], inferred_svi['rate_lambda_r'])
axs[n_kolness].plot(lr_sp, g.log_prob(lr_sp).exp().detach().numpy(), '-r', label='SVI (gamma)')
axs[n_kolness].legend()
axs[n_kolness].set_title('lambda_r')

# Hide any unused trailing panels.
for j in range(n_kolness + 1, len(axs)):
    axs[j].axis('off')

plt.tight_layout()
plt.savefig(f'svi_{n}.png')