In [None]:
import numpy as np
import scipy.optimize as opt
import jax
import jax.numpy as jnp

# JAX requires jax_enable_x64 to be set at startup, before any arrays are
# created. It is therefore set before the project modules below are imported,
# in case any of them build jnp arrays at import time (those would otherwise
# silently default to float32).
jax.config.update("jax_enable_x64", True)

import Utilityfunctions as utils
import ssr_kronvec_jax as ssr
import ssr_likelihood_jax as jax_lik
import vanilla as mhn
import regularized_optimization as reg_opt

In [None]:
# Model setup: number of events n, a random ground-truth theta, the two
# rates handed to utils.simulate_dat, and a seeded generator for sampling.
n = 5
theta = utils.random_theta(n, 0.0)
lam1, lam2 = 1, 2
rng = np.random.default_rng(42)  # seeded so the simulated data are reproducible
print(theta)

In [None]:
# Sample datapoints according to the model and split them into the four
# observation groups consumed by reg_opt.log_lik.
n_dat = 500
dat_full, ages = utils.simulate_dat(theta, n_dat, lam1, lam2, rng)
# Starting value for the (log-)parameter of lam2: log of the reciprocal mean age.
lam2_start = np.log(1/np.mean(ages))

# Collect rows in lists and stack once at the end. Calling np.vstack inside the
# loop copies the whole accumulated array on every append (quadratic in n_dat).
# The empty (0, 2n+1) int seed keeps the column count and dtype promotion of a
# zero-initialised int buffer, and yields a (0, 2n+1) array for an empty group.
empty_rows = np.zeros(shape=(0, 2*n+1), dtype=int)
rows_prim_nomet, rows_prim_met, rows_coupled, rows_met_only = [], [], [], []
for state in dat_full:
    if state[-1] == 0:
        # Last column 0: goes to the "prim_nomet" group.
        rows_prim_nomet.append(state)
    else:
        # Otherwise assign uniformly at random to one of three groups. This is
        # the same rng draw per row, so the split is reproducible for a given seed.
        group = rng.choice(np.array([0,1,2]), size = 1)[0]
        if group == 0:
            rows_prim_met.append(state)
        elif group == 1:
            rows_met_only.append(state)
        else:
            rows_coupled.append(state)
dat_prim_nomet = jnp.array(np.vstack([empty_rows, *rows_prim_nomet]))
dat_prim_met = jnp.array(np.vstack([empty_rows, *rows_prim_met]))
dat_coupled = jnp.array(np.vstack([empty_rows, *rows_coupled]))
dat_met_only = jnp.array(np.vstack([empty_rows, *rows_met_only]))
print(dat_prim_nomet.shape[0], dat_prim_met.shape[0], dat_coupled.shape[0], dat_met_only.shape[0])

In [None]:
# Initial parameters for learning: the flattened theta returned by utils.indep
# on the full data, followed by the log-scale starting values for lam1 and lam2.
indep = utils.indep(jnp.array(dat_full))
print(indep)
lam1_start = np.log(1.5)
start_params = np.concatenate([np.ravel(indep), [lam1_start, lam2_start]])

In [None]:
# Objective value at the starting parameters, with the last argument set to 0.0.
likelihood_data = (dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only)
reg_opt.log_lik(start_params, *likelihood_data, 0.0)

In [None]:
#g_prim_no_met, dlam1_prim_no_met = reg_opt.grad_prim_only(jnp.array(indep), dat_prim_nomet, jnp.exp(lam1_start), n)

In [None]:
#g_prim_met, dlam1_prim_met = reg_opt.grad_prim_only(jnp.array(indep), dat_prim_met, jnp.exp(lam1_start), n)

In [None]:
#g_met, dlam1_met = reg_opt.grad_met_only(jnp.array(indep), dat_met_only, jnp.exp(lam1_start), lam2_start, n)

In [None]:
#g_coupled, dlam1_coupled = reg_opt.grad_coupled(jnp.array(indep), dat_coupled, jnp.exp(lam1_start), lam2_start, n)

In [None]:
#print(g_prim_met + g_prim_no_met + g_met + g_coupled)
#print(dlam1_prim_no_met + dlam1_prim_met + dlam1_met + dlam1_coupled)

In [None]:
# Finite-difference check of reg_opt.lp_coupled, to compare against the
# analytic reg_opt.grad_coupled output further down.
# Central differences: truncation error is O(h**2), and with h = 1e-6 round-off
# stays small. A forward difference with h = 1e-10 loses most significant
# digits in float64.
h = 1e-6
theta_start = jnp.asarray(indep)
lam2_rate = jnp.exp(lam2_start)

def _lp_coupled_at(th, log_lam1=lam1_start):
    """Coupled log-likelihood at theta `th` and log-rate `log_lam1`; other inputs fixed."""
    # Pass the same `n` as grad_coupled(..., n) so both results are comparable.
    return reg_opt.lp_coupled(th, dat_coupled, jnp.exp(log_lam1), lam2_rate, n)

# Derivatives with respect to each theta entry. JAX arrays are immutable, so
# .at[...].add returns a perturbed copy and leaves theta_start untouched.
g_num = np.zeros((n+1, n+1))
for i in range(n+1):
    for j in range(n+1):
        score_plus = _lp_coupled_at(theta_start.at[i, j].add(h))
        score_minus = _lp_coupled_at(theta_start.at[i, j].add(-h))
        g_num[i, j] = float(score_plus - score_minus) / (2*h)
print(np.around(g_num, 3))

# Derivative with respect to log(lam1), the parameterisation used in start_params.
dlam1_num = float(_lp_coupled_at(theta_start, lam1_start + h)
                  - _lp_coupled_at(theta_start, lam1_start - h)) / (2*h)
print(np.around(dlam1_num, 3))

In [None]:
# Finite-difference gradient of the full objective reg_opt.log_lik over every
# entry of start_params: the (n+1)**2 theta entries, then log(lam1), log(lam2).
# Central differences are used for the same accuracy reasons as the check above.
h = 1e-6
lik_args = (dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, 0.0)
g_num = np.zeros(start_params.size)
# Iterate over all parameters, including the final lam2 entry.
for i in range(start_params.size):
    st_plus = start_params.copy()
    st_minus = start_params.copy()
    st_plus[i] += h
    st_minus[i] -= h
    g_num[i] = float(reg_opt.log_lik(st_plus, *lik_args)
                     - reg_opt.log_lik(st_minus, *lik_args)) / (2*h)
print(np.around(g_num[:(n+1)**2].reshape((n+1, n+1)), 3))
print(np.around(g_num[-2:], 3))  # derivatives w.r.t. log(lam1), log(lam2)

In [None]:
# Analytic gradient of the coupled-data term at the starting point. All but the
# last two entries are shown as an (n+1, n+1) matrix; entry -2 is printed on
# its own (presumably the lam1 component — TODO confirm against reg_opt).
lam1_rate = jnp.exp(lam1_start)
lam2_rate = jnp.exp(lam2_start)
res = reg_opt.grad_coupled(jnp.array(indep), dat_coupled, lam1_rate, lam2_rate, n)
theta_part = res[:-2].reshape((n+1, n+1))
print(jnp.around(theta_part, 5))
print(res[-2])

In [None]:
# Analytic gradient of the full objective at the starting point; only the theta
# block (everything except the last two entries) is displayed.
res = reg_opt.grad(start_params, dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, 0.0)
theta_grad = res[:-2].reshape((n+1, n+1))
print(np.around(theta_grad, 3))

In [None]:
# Fit all parameters with L-BFGS-B using the analytic gradient reg_opt.grad.
# The trailing 0.05 is the extra argument forwarded to log_lik/grad
# (presumably the regularisation weight — confirm in regularized_optimization).
fit_args = (dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, 0.05)
lbfgs_options = {"maxiter": 100, "disp": True, "ftol": 1e-05}
x = opt.minimize(
    reg_opt.log_lik,
    x0=start_params,
    args=fit_args,
    method="L-BFGS-B",
    jac=reg_opt.grad,
    options=lbfgs_options,
)

In [None]:
# Compare the fit with the ground truth: learned theta, learned rates (exp of
# the last two parameters), and the true theta used for simulation.
learned_theta = jnp.reshape(x.x[:-2], (n+1, n+1))
learned_rates = jnp.exp(x.x[-2:])
print(jnp.around(learned_theta, 2))
print(jnp.around(learned_rates, 2))
print(jnp.around(theta, 2))