In [None]:
from functools import partial

import gsd
import gsd.experimental as gsde
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import tensorflow_probability.substrates.jax as tfp
from gsd.experimental.bootstrap import pp_plot_data
from gsd.experimental.fit import GridEstimator
from gsd.fit import GSDParams, allowed_region, make_logits
from jax import Array
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike


# Short aliases for the TensorFlow Probability sub-namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors

Let's use one experiment from the [SUREAL](https://github.com/Netflix/sureal) library: the public Netflix (NFLX) dataset.

In [None]:
# The SUREAL dataset is distributed as a Python module; executing it fills
# `dataset` with the names that module defines (e.g. "dis_videos", used below).
# SECURITY(review): this exec()s code downloaded from a moving branch
# (`master`). Only run it if you trust the source; consider pinning the URL
# to a specific commit.
REQUEST_TIMEOUT_S = 30  # without a timeout, requests can wait forever on a stalled connection

url = "https://raw.githubusercontent.com/Netflix/sureal/master/test/resource/NFLX_dataset_public_raw.py"
dataset = {}
try:
    response = requests.get(url, timeout=REQUEST_TIMEOUT_S)
    if response.status_code == 200:
        content = response.text
        exec(content, {}, dataset)
    else:
        print(f"Failed to retrieve the file. Status code: {response.status_code}")
except requests.RequestException as e:
    # requests.Timeout is a RequestException subclass, so timeouts land here too.
    print(f"Error fetching the file: {e}")

In [None]:
# One row per entry of `dataset["dis_videos"]`, built from its "os" field
# (assumes every "os" list has the same length — TODO confirm).
o = np.array([video["os"] for video in dataset["dis_videos"]])
print(o.shape)
# Per-row sufficient statistics, computed in a single vectorised call.
counts = jax.vmap(gsd.sufficient_statistic)(o)

In [None]:
# Set to True to replace the SUREAL scores with the local HDTV experiment
# (loaded in the next cell).
hdtv = False

In [None]:
import pandas as pd

# NOTE(review): machine-specific absolute path; point this at your own copy.
HDTV_CSV_PATH = "/Users/krzysiek/Documents/lts_analysis_soft/pyits/log/hdtv_data.csv"

if hdtv:
    # Keep the table under its own name. Rebinding `hdtv` (the boolean flag) to
    # a DataFrame made `if hdtv:` raise "truth value of a DataFrame is
    # ambiguous" whenever this cell was re-run.
    hdtv_raw = pd.read_csv(HDTV_CSV_PATH)
    exp1 = hdtv_raw[hdtv_raw.Experiment == 1]
    # Experiment 1 is expected to hold 168 PVSs x 24 testers (not enforced).
    # Rows are PVS_id, columns are Tester_id, values are Score.
    o = pd.pivot(exp1, columns="Tester_id", values="Score", index="PVS_id").to_numpy()

    counts = jax.vmap(gsd.sufficient_statistic)(o)

In [None]:
@jax.jit
def gsdfit(x: Array):
    """Maximum-likelihood GSD fit of one PVS's score counts via ``gsde.fit_mle``.

    Runs at most 200 optimiser iterations. The optimiser state is discarded
    and only the fitted parameters are returned.
    """
    fitted, _ = gsde.fit_mle(data=x, max_iterations=200)
    return fitted

Fit the model for a single PVS:

In [None]:
# Fit the score counts of the first PVS with the gradient-based MLE (gsdfit).
gsdfit(counts[0])

Compare this fit with one estimated without gradients, using TensorFlow Probability's Nelder–Mead optimizer:


In [None]:
# Template parameters: they fix the flat-vector layout used by Nelder–Mead and
# provide `unravel_fn`, which maps a flat vector back to GSDParams.
theta0 = GSDParams(psi=2.0, rho=0.9)
x0, unravel_fn = ravel_pytree(theta0)


def nll(x: ArrayLike, data: Array) -> Array:
    """Objective for the gradient-free fit.

    Maps the flat vector ``x`` back to ``GSDParams``, builds the model logits
    and returns ``-dot(logits, data)``. If ``make_logits`` returns
    per-category log-probabilities and ``data`` holds category counts
    (TODO confirm), this is the negative log-likelihood up to a constant.

    NOTE(review): the ``allowed_region(logits, data.sum())`` check is
    disabled, so nothing keeps the optimizer inside the valid parameter region.
    """
    model_logits = make_logits(unravel_fn(x))
    return -jnp.dot(model_logits, data)


@jax.jit
def tfpfit(data: Array):
    """Fit GSD parameters to ``data`` with TFP's Nelder–Mead minimizer.

    Returns the final position unravelled into ``GSDParams``.
    Convergence is not checked.
    """
    # Starting triangle in the flat parameter space (presumably (psi, rho)).
    start = jnp.asarray(np.asarray([[4.9, 0.1], [1.1, 0.9], [4.9, 0.9]]))
    objective = partial(nll, data=data)
    result = tfp.optimizer.nelder_mead_minimize(objective, initial_simplex=start)
    return unravel_fn(result.position)

In [None]:
# Side by side for the first PVS: gradient-based MLE (gsdfit) vs. Nelder–Mead (tfpfit).
[gsdfit(counts[0]), tfpfit(counts[0])]

Let's estimate the parameters for all PVSs.
For this we use `jax.lax.map`.
_Note that `vmap` is not the best choice here, because each estimation contains control-flow instructions._
  
 

In [None]:
# One gsdfit per row of `counts`. lax.map is used instead of vmap because each
# fit contains control flow. The result is batched along the leading axis.
fits = jax.lax.map(gsdfit, counts)

In [None]:
# Grid-search estimator. `num` sets the grid resolution per parameter
# (presumably the number of psi and rho grid points — TODO confirm).
num = GSDParams(512, 128)
grid = GridEstimator.make(num)

# Compare every estimator on a single PVS.
n = 3
counts_n = counts[n]
print("counts:      ", counts_n)
# Row n of the batched lax.map fits.
print("gsdfit (map):", jax.tree_util.tree_map(lambda x: x[n], fits))
print("tfpfit:      ", tfpfit(counts_n))
print("grid:        ", grid(counts_n))
print("fit_mle_grid:", gsde.fit_mle_grid(counts_n, num=num, constrain_by_pmax=False))

# PP-plot

In [None]:
# Reproducible randomness: one root key, split into one key per PVS.
key = jax.random.key(42)
keys = jax.random.split(key, len(counts))


@jax.jit
def estimator(x):
    """Jitted wrapper around the grid estimator, handed to the bootstrap."""
    return grid(x)


n_b = 999  # bootstrap samples per PVS

# Stack the PP-plot data of every PVS, each computed with its own PRNG key.
pvals = np.stack(
    [
        pp_plot_data(
            pvs_counts, estimator=estimator, key=pvs_key, n_bootstrap_samples=n_b
        )
        for pvs_counts, pvs_key in zip(counts, keys)
    ]
)


In [None]:
from scipy.stats import norm
import matplotlib.pyplot as plt

def pp_plot(pvalues: np.ndarray, thresh_pvalue=0.2, ax=None):
    """Draw a PP-plot of per-PVS p-values against the uniform CDF.

    Each p-value is plotted against the fraction of PVSs whose p-value does
    not exceed it (the empirical CDF). The black line is a one-sided 95%
    upper band for that fraction under a uniform reference,
    ``p + z_0.95 * sqrt(p * (1 - p) / n)``. Points above it suggest more
    small p-values than a uniform distribution would produce.

    Args:
        pvalues: one p-value per PVS (flattened, so a trailing unit axis is OK).
        thresh_pvalue: upper x-limit, and upper end of the reference band.
        ax: optional matplotlib Axes to draw on. When omitted, a new figure is
            created and shown with ``plt.show()``.

    Raises:
        ValueError: if ``pvalues`` is empty (the band would divide by zero).
    """
    p = np.asarray(pvalues, dtype=float).ravel()
    n_pvs = p.size
    if n_pvs == 0:
        raise ValueError("pp_plot needs at least one p-value")

    ref_p_values = np.linspace(start=0.001, stop=thresh_pvalue, num=100)
    significance_line = ref_p_values + norm.ppf(0.95) * np.sqrt(
        ref_p_values * (1 - ref_p_values) / n_pvs
    )

    # Empirical CDF at each sample: #{j : p_j <= p_i} / n. A sorted search
    # gives O(n log n) instead of comparing every pair of p-values.
    pvs_fraction_gsd = np.searchsorted(np.sort(p), p, side="right") / n_pvs

    show = ax is None
    if show:
        _, ax = plt.subplots()

    ax.scatter(p, pvs_fraction_gsd, label="GSD")
    ax.plot(ref_p_values, significance_line, "-k", label="95% upper band")
    ax.set_xlabel("theoretical uniform cdf")
    ax.set_ylabel("ecdf of $p$-values")
    ax.set_xlim([0, thresh_pvalue])
    ax.set_ylim([0, thresh_pvalue + 0.1])
    ax.minorticks_on()
    # Draw the legend so the series labels are actually displayed.
    ax.legend()
    if show:
        plt.show()


# PP-plot of the bootstrap p-values computed above (one per PVS).
pp_plot(pvals)