In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from jax.scipy.linalg import solve_triangular
from pykronecker import KroneckerProduct, KroneckerDiag
from jax.scipy.special import expit

from kreg.utils import cartesian_prod
from kreg.kernel.factory import vectorize_kfunc
from kreg.kernel import KroneckerKernel
from kreg.likelihood import BinomialLikelihood
from kreg.model import KernelRegModel

plt.style.use("ggplot")

In [None]:
# Age grid 0..100 (inclusive) in unit steps, and 1000 integer location ids.
age_grid = jnp.arange(0., 101, 1)
location_ids = jnp.arange(0, 1000)

# One row per location with columns (id % 10, id % 100, id).
raw_signature = jnp.vstack(
    [location_ids % 10, location_ids % 100, location_ids]
).T
# np.sort with axis=0 sorts every column independently, so the rows no longer
# pair up the original modulo values; with 1000 ids each row i ends up as
# (i // 100, i // 10, i), i.e. a nested three-level hierarchy.
location_signature = np.sort(raw_signature, axis=0)

# Product of the location ids with the age grid. Later cells read column 1 as
# age and reshape derived vectors to (len(location_ids), len(age_grid)), so
# this presumably yields one row per (location, age) pair with location
# varying slowest -- TODO confirm against kreg.utils.cartesian_prod.
full_X = cartesian_prod(location_ids, age_grid)


# Base and additive exponent shift of the location kernel.
exp_a = 0.66
exp_b = 0.0


def k_region_single(x, y):
    """Location kernel for one pair of signatures.

    Counts the entries in which ``x`` and ``y`` differ and returns
    ``exp_a`` raised to that count plus ``exp_b``.
    """
    n_mismatch = jnp.sum(x != y)
    exponent = n_mismatch + exp_b
    return jnp.power(exp_a, exponent)


def get_gaussianRBF(gamma):
    """Return a Gaussian RBF kernel function with length scale ``gamma``.

    The returned callable maps a pair ``(x, y)`` to
    ``exp(-sum((x - y)**2) / (2 * gamma**2))``.
    """

    def rbf_kernel(x, y):
        diff = x - y
        sq_dist = jnp.sum(diff ** 2)
        return jnp.exp(-sq_dist / (2 * gamma**2))

    return rbf_kernel


# Length scale used for the age dimension.
gamma = 10
k_age_rbf = get_gaussianRBF(gamma)


def k_age_single(x, y):
    """Age kernel for one pair of ages: the Gaussian RBF plus a constant 1.0."""
    # A diagonal jitter term, 1e-10 * (x == y).all(), was tried here and is
    # left disabled.
    rbf_value = k_age_rbf(x, y)
    return rbf_value + 1.0


def k_combined_single(x, y):
    """Product kernel on (location, age) pairs.

    ``x[0]``/``y[0]`` go to the location kernel and ``x[1]``/``y[1]`` to the
    age kernel; the two values are multiplied.
    """
    region_part = k_region_single(x[0], y[0])
    age_part = k_age_single(x[1], y[1])
    return region_part * age_part


# Wrap each pairwise kernel with kreg's vectorize_kfunc, in the same order as
# the names on the left-hand side.
k_combined, k_region, k_age = map(
    vectorize_kfunc,
    (k_combined_single, k_region_single, k_age_single),
)

In [None]:
# Kernel matrix of each factor: locations against locations, ages against ages.
K_regions = k_region(location_signature, location_signature)
K_age = k_age(age_grid, age_grid)

# Symmetric eigendecomposition of each factor. The outer product of the two
# eigenvalue vectors holds every pairwise product of factor eigenvalues, and
# `left` is the Kronecker product of the eigenvector matrices.
reg_eigval, reg_eigvec = jnp.linalg.eigh(K_regions)
age_eigval, age_eigvec = jnp.linalg.eigh(K_age)
kronvals = jnp.outer(reg_eigval, age_eigval)
left = KroneckerProduct([reg_eigvec, age_eigvec])

# Full kernel and its inverse, both kept in Kronecker-factored form.
# NOTE(review): K_age is inverted here without the jitter that is added before
# its Cholesky factorisation further down -- confirm the inverse is usable.
K = KroneckerProduct([K_regions, K_age])
region_inverse = jnp.linalg.inv(K_regions)
age_inverse = jnp.linalg.inv(K_age)
P = KroneckerProduct([region_inverse, age_inverse])


def P_beta(lam, beta):
    """Build ``left @ diag(kronvals / (lam + beta * kronvals)) @ left.T``.

    ``left`` and ``kronvals`` are the notebook-level Kronecker eigenvectors
    and eigenvalue products of the kernel.
    """
    rescaled_spectrum = kronvals / (lam + beta * kronvals)
    return left @ KroneckerDiag(rescaled_spectrum) @ left.T


# Cholesky factor of the location kernel and its inverse.
# jnp.linalg.cholesky returns the lower-triangular factor, so the triangular
# solve must be told lower=True; the default (lower=False) would read only the
# upper triangle of `cholreg` and return the wrong inverse.
cholreg = jnp.linalg.cholesky(K_regions)
cholreg_inv = solve_triangular(
    cholreg, jnp.identity(len(K_regions)), lower=True
)

# Same for the age kernel, with a small diagonal jitter added before factoring.
# NOTE(review): under JAX's default float32 a 1e-8 jitter is below the
# resolution of the diagonal entries (RBF(0) + 1 = 2) -- confirm x64 is
# enabled or that the factorisation succeeds without it.
cholage = jnp.linalg.cholesky(K_age + 1e-8 * jnp.identity(len(K_age)))
cholage_inv = solve_triangular(cholage, jnp.identity(len(K_age)), lower=True)

# Kronecker product of the two factors; used below to draw a sample from K.
chol_K = KroneckerProduct([cholreg, cholage])


# cholkron = jnp.kron(cholreg,cholage)
# cholkron_inv = jnp.kron(cholreg_inv,cholage_inv)

# Seed NumPy's global generator so the simulated truth is reproducible.
np.random.seed(12)
n_locations, n_ages = len(location_ids), len(age_grid)

# Standard-normal draws pushed through the Kronecker Cholesky factor.
S = jnp.array(np.random.randn(n_locations * n_ages))
truth_sample_raw = chol_K @ S

# Deterministic offset computed from column 1 of full_X: a sinusoid plus a
# linear downward trend.
ages_flat = full_X[:, 1]
offset = 0.2 * jnp.sin(ages_flat / 2) - 0.05 * ages_flat
offset_mat = offset.reshape(n_locations, n_ages)

truth_sample = truth_sample_raw + offset
# Plot the first few location curves: left with the offset included, right
# with the offset subtracted again.
truth_mat = truth_sample.reshape(len(location_ids), len(age_grid))
# Plain Python ints, so the builtin min is enough for the loop bound.
n_curves = min(len(location_ids), 10)

fig, (ax_with_offset, ax_gp_only) = plt.subplots(1, 2, figsize=(12, 5))
ax_with_offset.set_title("With offset added")
ax_gp_only.set_title("Actual Gaussian Process Truth")
for i in range(n_curves):
    ax_with_offset.plot(truth_mat[i])
    ax_gp_only.plot(truth_mat[i] - offset_mat[i])
plt.show()

# Simulation settings: ages observed per location and the range the
# per-observation sample sizes are drawn from (upper bound excluded by range()).
num_per_location = 80
sample_size_lower = 50
sample_size_upper = 100

# For each location pick `num_per_location` distinct age positions and shift
# them by that location's starting position in the flattened grid.
ages_per_location = len(age_grid)
observed_blocks = []
for loc_idx in range(len(location_ids)):
    sampled_ages = np.random.choice(
        ages_per_location, num_per_location, replace=False
    )
    observed_blocks.append(loc_idx * ages_per_location + sampled_ages)
observed_indices = jnp.vstack(observed_blocks).flatten()

# One sample size per observed cell.
sample_sizes_observed = np.random.choice(
    range(sample_size_lower, sample_size_upper), len(observed_indices)
)
# Success probabilities: logistic transform of the truth at the observed cells.
probs_observed = expit(truth_sample[observed_indices])

In [None]:
import pandas as pd
from kreg.kernel import KernelComponent, KroneckerKernel

In [None]:
offset = jnp.array(offset)

# Dense vectors over the flattened grid; cells that were not observed stay 0.
n_cells = len(location_signature) * len(age_grid)

sample_sizes = (
    jnp.zeros(n_cells).at[observed_indices].set(sample_sizes_observed).astype(int)
)
prob_vals = jnp.zeros(n_cells).at[observed_indices].set(probs_observed)

# Binomial draws for every cell; a sample size of 0 gives a count of 0.
obs_counts = np.random.binomial(sample_sizes, prob_vals)

# Observed rate = count / sample size, computed only at observed cells so that
# unobserved cells never divide by zero.
observed_rates = obs_counts[observed_indices] / sample_sizes[observed_indices]
obs_rate = jnp.zeros(n_cells).at[observed_indices].set(observed_rates)

# Measurement columns, one row per cell of the flattened grid.
measurements = pd.DataFrame(
    {
        "obs": obs_rate,
        "counts": obs_counts,
        "weights": sample_sizes,
        "offset": offset,
    }
)

# Label columns: the cross join pairs every location signature with every age.
label_region = pd.DataFrame(
    location_signature,
    columns=["super_region_id", "region_id", "location_id"],
)
label_age = pd.DataFrame({"age_mid": age_grid})
label = label_region.merge(label_age, how="cross")

# Labels and measurements are aligned by position, then ordered by hierarchy
# level and age.
data = pd.concat([label, measurements], axis=1).sort_values(
    ["super_region_id", "region_id", "location_id", "age_mid"]
)

# One kernel component per factor: the three hierarchy columns with the
# location kernel, and the age column with the age kernel.
region_component = KernelComponent(
    ["super_region_id", "region_id", "location_id"], k_region
)
age_component = KernelComponent("age_mid", k_age)
kernel_components = [region_component, age_component]

In [None]:
# Assemble the model. The likelihood's string arguments match the "obs",
# "weights" and "offset" columns of `data`.
kernel = KroneckerKernel(kernel_components)
likelihood = BinomialLikelihood("obs", "weights", "offset")
lam = 1.0
model = KernelRegModel(kernel, likelihood, lam)

# Solver settings forwarded unchanged to model.fit.
fit_options = dict(
    gtol=5e-4,
    max_iter=50,
    cg_maxiter=100,
    cg_maxiter_increment=2,
    nystroem_rank=50,
)
y_opt, conv = model.fit(data, **fit_options)

In [None]:
# Loss values recorded by the fit, on a logarithmic y-axis.
fig, ax = plt.subplots()
ax.plot(conv["loss_vals"])
ax.set_yscale("log")

#### MSE/Var(truth)

In [None]:
# Mean squared difference between the fit and the offset-free truth, divided
# by the variance of that truth.
residual = y_opt - truth_sample_raw
jnp.mean(residual ** 2) / jnp.var(truth_sample_raw)

In [None]:
# Run the model's own predict on the first four rows of the training frame.
new_data = data.head(4)
model.predict(new_data, y_opt)

In [None]:
model.kernel.attach(data)
kc = model.kernel.kernel_components

# Hand-computed prediction for a single row: evaluate each component's kernel
# between that row and the component's grid, take the Kronecker product of the
# two results, and take its dot product with op_p @ y_opt.
ind = 0
row = data.iloc[ind : ind + 1]
region_comp, age_comp = kc[0], kc[1]
region_cross = region_comp.kfunc(row[region_comp.name].values, region_comp.grid)
age_cross = age_comp.kfunc(row[age_comp.name].values, age_comp.grid)
jnp.dot(jnp.kron(region_cross, age_cross), model.kernel.op_p @ y_opt)

In [None]:
# First entry of the fitted vector, shown for comparison with the
# hand-computed prediction for row 0 in the previous cell.
y_opt[0]

In [None]:
# NOTE(review): this cell repeats the attach / kernel_components statements of
# an earlier cell verbatim; it looks redundant on a top-to-bottom run --
# confirm attach() is safe to call twice or drop the cell.
model.kernel.attach(data)
kc = model.kernel.kernel_components

In [None]:
import jax

# One input array per kernel component, taken from the first four rows of
# `data` using the column(s) named by that component.
components = model.kernel.kernel_components
new_data = data.head(4)
prediction_inputs = [
    jnp.array(new_data[component.name].values) for component in components
]


from functools import reduce
def predict_single(*single_input):
    """Predict for one row from its per-component inputs.

    Expects exactly one positional argument: a sequence holding one input per
    kernel component, in the order of the notebook-level ``components``. Each
    component's kernel is evaluated between that input and the component's
    grid, the results are combined with ``jnp.kron``, and the combined row is
    dotted with ``model.kernel.op_p @ y_opt``.
    """
    cross_kernels = [
        component.kfunc(jnp.array([value]), component.grid)
        for component, value in zip(components, *single_input)
    ]
    joint_row = reduce(jnp.kron, cross_kernels)
    return jnp.dot(joint_row, model.kernel.op_p @ y_opt)

# jax.vmap maps the jitted per-row predictor over the leading axis of every
# array in the input list, giving one prediction per row.
predict_vec = jax.vmap(jax.jit(predict_single))
predict_vec(prediction_inputs)

In [None]:
def predict(new_data, y_opt):
    """Evaluate the fitted kernel expansion at the rows of ``new_data``.

    Parameters
    ----------
    new_data
        Frame that holds, for every kernel component, the column(s) named by
        ``component.name``.
    y_opt
        Fitted vector; it is multiplied by ``model.kernel.op_p`` to obtain the
        expansion coefficients.

    Returns
    -------
    The vmapped per-row predictions, with one leading entry per row of
    ``new_data``.

    Notes
    -----
    Reads the notebook-level ``model`` and relies on the notebook's ``jax``,
    ``jnp`` and ``functools.reduce`` imports.
    """
    kernel = model.kernel
    if not kernel.matrices_computed:
        kernel.build_matrices()

    components = kernel.kernel_components
    prediction_inputs = [
        jnp.array(new_data[component.name].values) for component in components
    ]

    # The coefficients do not depend on the row being predicted, so they are
    # computed once here instead of inside the per-row function.
    coefficients = kernel.op_p @ y_opt

    def _predict_single(single_input):
        # single_input holds one value per component, in component order.
        cross_kernels = [
            component.kfunc(jnp.array([value]), component.grid)
            for component, value in zip(components, single_input)
        ]
        return jnp.dot(reduce(jnp.kron, cross_kernels), coefficients)

    predict_vec = jax.vmap(jax.jit(_predict_single))
    return predict_vec(prediction_inputs)

In [None]:
# Notebook-level predict on the same four rows that were passed to
# model.predict earlier, for comparison.
predict(new_data,y_opt)

In [None]:
# Un-vectorised check: take the first row's input for each component and call
# the per-row predictor directly.
first_row_inputs = [column[0] for column in prediction_inputs]
predict_single(first_row_inputs)