In [1]:
import pickle
from itertools import product

import jax.numpy as jnp
import numpy as np
from scipy.special import softmax

from pysketch.utils import log_gen_fac

# Full Sketch DP vs PY

In [2]:
from pysketch.species import py_pred_full, dp_pred, py_pred_single

In [9]:
# Build three synthetic sketches of width J, each paired with the bucket
# index (js) that the later cells query. Counts are cast to int32, which
# truncates the fractional part of every expected count.
J = 10

# Sketch 1: 50 items spread evenly across the J buckets; bucket 5 is queried.
uniform_sketch = (np.ones(J) * 50 / J).astype(np.int32)

# Shared denominators 5, 6, ..., J + 4 for the two decaying profiles.
offsets = 4 + np.arange(1, J + 1)

# Sketch 2: bucket weights proportional to 1 / (4 + k)^2, scaled to 55 items.
probs = 1 / offsets ** 2
probs = probs / probs.sum()
quadratic_sketch = (55 * probs).astype(np.int32)

# Sketch 3: bucket weights proportional to 1 / (4 + k), scaled to 55 items.
probs = 1 / offsets
probs = probs / probs.sum()
linear_sketch = (55 * probs).astype(np.int32)

sketches = [uniform_sketch, quadratic_sketch, linear_sketch]
js = [5, 3, 3]

In [None]:
# Evaluate py_pred_full for every (sketch, bucket index) pair under each
# sigma, then pickle the results together with the inputs that produced them.
gamma = 1
sigmas = [0.1, 0.3, 0.5]
py_pmfs_all = []

for sigma in sigmas:
    print("sigma: ", sigma)
    # The table depends on sigma, so it is rebuilt once per sigma.
    log_gen_fac_table = jnp.array(log_gen_fac(sigma, 1000))

    pmfs_for_sigma = []
    for counts, bucket in zip(sketches, js):
        print(counts)
        counts_jnp = jnp.array(counts)
        # Grid 0..count of the queried bucket, i.e. count + 1 points.
        l_max = counts_jnp[bucket] + 1
        pmf = py_pred_full(l_max, counts_jnp, gamma, sigma, bucket, J, log_gen_fac_table)
        pmfs_for_sigma.append((np.arange(l_max), pmf))

    py_pmfs_all.append(pmfs_for_sigma)

results = {
    "gamma": gamma,
    "sketches": sketches,
    "js": js,
    "sigmas": sigmas,
    "py_pmfs": py_pmfs_all,
}
with open("py_pmfs_all_sigmas.pickle", "wb") as sink:
    pickle.dump(results, sink)

In [6]:
# Call py_pred_single with fixed arguments (5, 5, 50) once per sigma and
# pickle the resulting list alongside gamma and the sigma grid.
gamma = 1
sigmas = [0.1, 0.3, 0.5]
J = 10
py_pmfs_single = []

for sigma in sigmas:
    # The table depends on sigma, so it is rebuilt on every iteration.
    log_gen_fac_table = jnp.array(log_gen_fac(sigma, 1000))
    single_pmf = py_pred_single(5, 5, 50, gamma, sigma, J, log_gen_fac_table)
    py_pmfs_single.append(single_pmf)

payload = {
    "gamma": gamma,
    "sigmas": sigmas,
    "py_pmfs": py_pmfs_single,
}
with open("py_pmfs_single_sigmas.pickle", "wb") as sink:
    pickle.dump(payload, sink)

In [10]:
# Evaluate dp_pred on the last sketch and its bucket index, reusing the
# gamma and J already defined in the notebook, then pickle the result.
last_sketch, last_bucket = sketches[-1], js[-1]
dp_pmf = dp_pred(5, last_sketch, gamma, last_bucket, J)

dp_payload = {"gamma": gamma, "dp_pmf": dp_pmf}
with open("dp_pmf.pickle", "wb") as sink:
    pickle.dump(dp_payload, sink)

# Single Bucket Posteriors

In [None]:
# Package name fixed from the misspelled `pysketc` to `pysketch`, as used by
# every other import in this notebook.
# NOTE(review): the Poisson helpers further down are imported from
# `pysketch.trait` (singular) — confirm which module exports ngg_pred_single.
from pysketch.traits import ngg_pred_single

In [None]:
import pickle

# Single-bucket setup. m, c, J and l_max are also read by the plotting and
# NGG cells that follow this one.
m = 100
c = 10
J = 10

l_max = c


theta = 2
sigmas = [0, 0.25, 0.5, 0.75]
py_lpmfs = []
# NOTE(review): dp_lpmf and py_lpmf_single are not imported anywhere in this
# notebook; they must be brought into scope before this cell can run on a
# fresh kernel.
for sig in sigmas:
    if sig == 0:
        # sigma == 0 is handled by the DP routine, which takes no sigma/table.
        py_lpmfs.append(dp_lpmf(l_max, c, m, theta, J))
    else:
        # Use the loop variable `sig`; reading the global `sigma` here would
        # silently reuse whatever value an earlier cell left behind.
        log_gen_fac_table = log_gen_fac(sig, 120)
        py_lpmfs.append(py_lpmf_single(l_max, c, m, theta, sig, J, log_gen_fac_table))

with open("py_lpmfs_single_new.pickle", "wb") as fp:
    out = {"m": m, "c": c, "J": J, "theta": theta, "sigmas": sigmas,
           "lpmfs": py_lpmfs}
    pickle.dump(out, fp)

In [None]:
# One bar chart per sigma, laid out in a single row and saved as a PDF.
# NOTE(review): `plt` is used here but never imported in this notebook.
n_panels = len(sigmas)
fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(15, 5))

# Common x grid 0..l_max for every panel.
x = np.arange(l_max + 1)
for i, sig_value in enumerate(sigmas):
    panel = axes[i]
    panel.bar(x, py_lpmfs[i])
    panel.set_title("Gamma: {0}, sigma: {1}".format(theta, sig_value), fontsize=15)
    # Fixed y-range so the panels are directly comparable.
    panel.set_ylim((0, 0.6))

plt.savefig("py_lpmfs_single_new.pdf")

In [None]:
# Evaluate ngg_pred_single over a grid of betas at sigma = 0.25 and pickle
# the results with the settings used.
sigma = 0.25
# Built here but not passed to ngg_pred_single below.
log_gen_fac_table = log_gen_fac(sigma, 120)

betas = [0.3, 0.7, 1.1, 1.5]
ngg_lpmfs = [ngg_pred_single(l_max, c, m, sigma, beta, J) for beta in betas]

ngg_payload = {
    "m": m,
    "c": c,
    "J": J,
    "sigma": sigma,
    "betas": betas,
    "lpmfs": ngg_lpmfs,
}
with open("ngg_lpmfs_single_new.pickle", "wb") as sink:
    pickle.dump(ngg_payload, sink)

In [None]:
# Same beta sweep as the previous cell, at sigma2 = 0.75.
sigma2 = 0.75
# Build the table from sigma2 so it matches the sigma used in this cell.
log_gen_fac_table = log_gen_fac(sigma2, 120)

betas = [0.3, 0.7, 1.1, 1.5]
ngg_lpmfs2 = []
for beta in betas:
    ngg_lpmfs2.append(ngg_pred_single(l_max, c, m, sigma2, beta, J))


with open("ngg_lpmfs_single2_new.pickle", "wb") as fp:
    # Record sigma2, the value the lpmfs were actually computed with.
    out = {"m": m, "c": c, "J": J, "sigma": sigma2, "betas": betas,
           "lpmfs": ngg_lpmfs2}
    pickle.dump(out, fp)

# Trait Sketch: Poisson likelihood

In [None]:
from pysketch.trait import poisson_gamma_pred, poisson_gg_pred

In [None]:
# Grid of panels, one per (a, b) pair: the poisson_gamma_pred curve overlaid
# with the poisson_gg_pred curves for each sigma, saved as a single PDF.
# NOTE(review): `plt` is used here but never imported in this notebook.
m = 1000
c = 50
bs = [5, 10, 15]
as_ = [1, 2, 3, 4]

theta = 0.3
J = 50

# Evaluation grid 0..c.
l = np.arange(c + 1)

tau = 1
r = 1

sigmas = [0.25, 0.75]

fig, axes = plt.subplots(nrows=len(as_), ncols=len(bs), figsize=(15, 15))
axes = axes.flat

# Materialise the (a, b) grid once so both passes visit panels in the same order.
ab_grid = list(product(as_, bs))

for i, (a, b) in enumerate(ab_grid):
    # Exponentiate and normalise so the plotted values sum to one.
    poi_pmf = np.exp(poisson_gamma_pred(l, c, b, a, theta, J))
    poi_pmf /= np.sum(poi_pmf)
    axes[i].bar(l, poi_pmf, alpha=0.15)
    axes[i].plot(l, poi_pmf, lw=3, label="Gamma")
    axes[i].set_title("a: {0}, b: {1}".format(a, b), fontsize=15)

for sigma in sigmas:
    for i, (a, b) in enumerate(ab_grid):
        poi_pmf = softmax(poisson_gg_pred(l, c, b, a, m, theta, sigma, tau, r, J))
        axes[i].bar(l, poi_pmf, alpha=0.15)
        axes[i].plot(l, poi_pmf, lw=3, label="GG, sigma: {0}".format(sigma))

# Every panel holds the same three series, so one legend is enough.
axes[0].legend()

# Lay out once, after all artists have been added.
plt.tight_layout()
plt.savefig("poi_post_new.pdf", bbox_inches="tight")