In [None]:
import numpy as np
import Utilityfunctions as utils
import jax.numpy as jnp
import scipy.optimize as opt
import regularized_optimization as reg_opt
import jax as jax
import matplotlib.pyplot as plt
import simulations as simul
# Switch JAX to 64-bit precision (its documented default is 32-bit). The
# finite-difference gradient check later in this notebook uses a step of
# h = 1e-08, which needs double precision to be meaningful.
jax.config.update("jax_enable_x64", True)

In [None]:
# Ground-truth configuration for the synthetic experiment.
# Seeded generator; it is handed to simul.simulate_dat in the sampling cell.
rng = np.random.default_rng(42)

# n fixes the sizes used throughout: data rows have 2*n+1 columns and the
# parameter matrix is reshaped to (n+1, n+1).
n = 8

# Rate parameters passed to the simulator (exact meaning is defined in
# simulations.simulate_dat -- not visible from this notebook).
lam1, lam2 = 0.4, 1

# Reference parameter matrix the fitted result is compared against at the end.
# NOTE(review): the second argument of random_theta is presumably a sparsity
# level -- confirm in Utilityfunctions.
theta = utils.random_theta(n, 0.0)


In [None]:
# Sample datapoints according to the model
n_dat = 100
dat_full, ages, psp = simul.simulate_dat(theta, n_dat, lam1, lam2, rng)

# Log-scale starting value for the second rate: log(1 / mean(ages)).
lam2_start = np.log(1/np.mean(ages))

# Split the simulated rows on their last column (0 -> dat_prim, otherwise
# dat_coupled). Boolean-mask indexing returns *copies*, so zeroing columns
# below leaves dat_full untouched. The previous row-by-row loop wrote through
# a view (silently modifying dat_full) and grew the outputs with np.vstack on
# every iteration, which is quadratic in the number of rows.
last_col_is_zero = dat_full[:, -1] == 0

prim_rows = dat_full[last_col_is_zero]
prim_rows[:, 1::2] = 0  # blank every odd-indexed column for these rows
coupled_rows = dat_full[~last_col_is_zero]

dat_prim = jnp.array(prim_rows, jnp.int32)
dat_coupled = jnp.array(coupled_rows, jnp.int32)

# Row counts of the two groups *before* the filter below is applied.
print(dat_prim.shape[0], dat_coupled.shape[0])

# Keep only coupled rows whose entries sum to more than 1.
dat_coupled = dat_coupled.at[dat_coupled.sum(axis=1)>1,:].get()
dat_coupled


In [None]:
#prim, met = simul.p_full_orders(theta, n_dat, lam1, lam2, rng)

In [None]:
#fig,ax = plt.subplots(10, figsize=(8,25))
#for i in range(n+1):
#    ax[i].bar(np.arange(0,n+1), prim[:,i])

In [None]:
#print(dat_prim.shape[0], dat_coupled.shape[0])

In [None]:
# Initial parameters for learning.
# The optimiser works on one flat vector: the flattened matrix returned by
# utils.indep followed by the two log-rate starting values.
lam1_start = np.log(1.5)

# NOTE(review): utils.indep presumably returns an independence-model estimate
# of shape (n+1, n+1) -- it is indexed as [i, j] with i, j in range(n+1) in the
# gradient check further down; confirm in Utilityfunctions.
indep_estimate = utils.indep(jnp.array(dat_prim), jnp.array(dat_coupled))
indep = np.array(indep_estimate)  # NumPy copy

log_rate_starts = [lam1_start, lam2_start]
start_params = np.concatenate((indep.ravel(), log_rate_starts))
indep

In [None]:
#reg_opt.log_lik(start_params, dat_prim, dat_coupled, n, 0.01, 0.8)

In [None]:
#g_prim_no_met, dlam1_prim_no_met = reg_opt.grad_prim_only(jnp.array(indep), dat_prim_nomet, jnp.exp(lam1_start), n)

In [None]:
#g_prim_met, dlam1_prim_met = reg_opt.grad_prim_only(jnp.array(indep), dat_prim_met, jnp.exp(lam1_start), n)

In [None]:
#g_met, dlam1_met = reg_opt.grad_met_only(jnp.array(indep), dat_met_only, jnp.exp(lam1_start), lam2_start, n)

In [None]:
# Evaluate objective and gradient once at the starting point.
# The positional arguments after the parameter vector are forwarded verbatim;
# their meaning (including the trailing 0., 0., 0.8) is defined by
# reg_opt.value_grad and is not visible from this notebook.
start_eval_args = (dat_prim, dat_coupled, dat_prim, dat_prim, n, 0., 0., 0.8)
lik, grad = reg_opt.value_grad(start_params, *start_eval_args)
print(lik)

# The last two gradient entries belong to the two rate parameters; the rest is
# shown as an (n+1, n+1) matrix.
matrix_side = n + 1
theta_grad = grad[:-2].reshape((matrix_side, matrix_side))
print(np.around(theta_grad, 3))

In [None]:
# Fit the model with L-BFGS-B starting from start_params.
# Same call signature as the evaluation above, except that the first of the
# three trailing scalars is 0.05 instead of 0.
fit_args = (dat_prim, dat_coupled, dat_prim, dat_prim, n, 0.05, 0., 0.8)
lbfgsb_options = {"maxiter":1000, "disp":True, "ftol":1e-05}

x = opt.minimize(
    reg_opt.value_grad,
    x0=start_params,
    args=fit_args,
    method="L-BFGS-B",
    jac=True,  # the objective returns (value, gradient) in one call
    options=lbfgsb_options,
)

In [None]:
# Show the fitted parameters next to the reference theta.
# x.x is the flat solution vector: matrix entries first, then two log-rates.
theta_fit = jnp.reshape(x.x[:-2], (n+1, n+1))
print(jnp.around(theta_fit, 2))

# The rates are optimised on the log scale, so exponentiate for display.
rates_fit = jnp.exp(x.x[-2:])
print(rates_fit)

print(jnp.around(theta, 2))

In [None]:
# Numerical gradient check for the matrix part of the parameters:
# forward differences (f(p + h*e_ij) - f(p)) / h, one entry at a time.
h = 1e-08
matrix_side = n + 1
grad_num = np.zeros((matrix_side, matrix_side))

# Reference value and analytic gradient at the unperturbed starting point.
val_org, grad_org = reg_opt.value_grad(
    start_params, dat_prim, dat_coupled, dat_prim, dat_prim, n, 0., 0., 0.8
)

# np.ndindex walks (i, j) in row-major order, i.e. i outer and j inner.
for i, j in np.ndindex(matrix_side, matrix_side):
    th_next = indep.copy()
    th_next[i, j] += h
    perturbed_params = np.append(th_next, [lam1_start, lam2_start])
    val_perturbed, grad = reg_opt.value_grad(
        perturbed_params, dat_prim, dat_coupled, dat_prim, dat_prim, n, 0., 0., 0.8
    )
    grad_num[i, j] = (val_perturbed - val_org)/h
        

In [None]:
# Same forward-difference check for the first log-rate parameter
# (second-to-last entry of the parameter vector). Reuses h and val_org from
# the matrix gradient check cell.
lam1_next = lam1_start + h
perturbed_params = np.append(indep, [lam1_next, lam2_start])

check_args = (dat_prim, dat_coupled, dat_prim, dat_prim, n, 0., 0., 0.8)
val_perturbed, grad = reg_opt.value_grad(perturbed_params, *check_args)

lam_num = (val_perturbed - val_org)/h
# Numerical estimate first, analytic gradient entry second.
print(lam_num, grad_org[-2])

In [None]:
# Side-by-side view: numerical gradient first, analytic gradient second,
# both rounded to 5 decimals and shaped (n+1, n+1).
print(np.round(grad_num, 5))

matrix_side = n + 1
grad_analytic = grad_org[0:matrix_side**2].reshape((matrix_side, matrix_side))
print(np.round(grad_analytic, 5))