In [None]:
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import multivariate_normal as mvn
from sklearn.metrics import adjusted_rand_score

from interface import Sampler, to_numpy
import pp_mix_cpp  # noqa

In [None]:
# Simulation configuration: dimensions and true model parameters.
fdim = 2  # number of fixed (time-invariant) covariates
ldim = 5  # number of longitudinal (time-varying) covariates
rdim = 3  # number of response components

# One autoregressive matrix per regime; regime k uses phi0[k].
phi0 = [
    np.eye(rdim) * 1.1,
    np.array([[1.1, -0.1, 0], [-0.1, 1.1, 0], [-0.1, 0, 0.9]]),
    np.array([[1.0, 0, -0.2], [-0.1, 0.5, 0], [0, 0.25, 1.1]]),
]
sigma0 = np.eye(rdim) * 0.2  # innovation covariance
beta0 = np.zeros((rdim, ldim))  # longitudinal-covariate coefficients (all zero)
gamma0 = np.zeros((rdim, fdim))  # fixed-covariate coefficients (all zero)

# The simulators multiply phi0[k] by the previous response row, so every matrix
# must be rdim x rdim.
assert all(p.shape == (rdim, rdim) for p in phi0)
assert sigma0.shape == (rdim, rdim)

def generate_data1(clus, T=10):
    """Simulate one series from regime ``clus`` of the autoregressive model.

    The response evolves as

        y[t] = phi0[clus] @ y[t-1] + beta0 @ curr_long[t] + gamma0 @ curr_fixed + eps[t],

    with eps[t] ~ N(0, sigma0) and y[0] ~ N(5, I). The longitudinal covariates
    are i.i.d. standard normal; the fixed covariates are N(means_fixed[clus], I),
    where ``means_fixed`` is the local list below.

    Reads the module-level ``ldim``, ``rdim``, ``phi0``, ``beta0``, ``gamma0``
    and ``sigma0``. All randomness comes from NumPy's global RNG.

    Args:
        clus: regime index (0, 1 or 2) into ``phi0`` and ``means_fixed``.
        T: number of time steps.

    Returns:
        Tuple ``(y, curr_long, curr_fixed)`` with shapes (T, rdim), (T, ldim)
        and (2,).
    """
    means_fixed = [np.array([-3, 0]), np.array([3, 0]), np.array([0, 3])]
    curr_long = np.random.normal(size=(T, ldim))
    # Index by the ``clus`` argument. The previous code read a global ``c``
    # and was only correct when the caller's loop variable had that name.
    curr_fixed = np.random.normal(loc=means_fixed[clus])
    y = np.zeros((T, rdim))
    y[0, :] = np.random.normal(loc=5, size=rdim)
    for t in range(1, T):
        mean = (np.matmul(phi0[clus], y[t - 1, :])
                + np.matmul(beta0, curr_long[t, :])
                + np.matmul(gamma0, curr_fixed))
        err = mvn.rvs(cov=sigma0)
        y[t, :] = mean + err
    return y, curr_long, curr_fixed


def generate_data2(clus, T=10):
    """Simulate one series whose fixed covariates separate regime 0 from the rest.

    Regime 0 draws the fixed covariates from N(0, 0.5**2 I). Any other regime
    keeps drawing from N(0, 2**2 I) until the draw has Euclidean norm > 5.
    The response then follows the same recursion as in ``generate_data1``,
    using ``phi0[clus]``, ``beta0``, ``gamma0`` and innovations N(0, sigma0),
    starting from y[0] ~ N(5, I).

    Args:
        clus: regime index into ``phi0``.
        T: number of time steps.

    Returns:
        Tuple ``(y, curr_long, curr_fixed)`` with shapes (T, rdim), (T, ldim)
        and (fdim,).
    """
    curr_long = np.random.normal(size=(T, ldim))

    if clus != 0:
        # Rejection sampling: keep only draws outside the radius-5 disk.
        curr_fixed = np.random.normal(size=fdim, scale=2)
        while not np.linalg.norm(curr_fixed) > 5:
            curr_fixed = np.random.normal(size=fdim, scale=2)
    else:
        curr_fixed = np.random.normal(size=fdim, scale=0.5)

    y = np.zeros((T, rdim))
    y[0] = np.random.normal(loc=5, size=rdim)
    transition = phi0[clus]
    for t in range(1, T):
        drift = transition @ y[t - 1] + beta0 @ curr_long[t] + gamma0 @ curr_fixed
        y[t] = drift + mvn.rvs(cov=sigma0)

    return y, curr_long, curr_fixed

def score(samples, true_y):
    """Squared-error score of the median prediction.

    Takes the element-wise median of ``samples`` over the first axis and
    reshapes it to ``true_y.shape``. Returns the sum of squared differences
    from ``true_y``, divided by the number of rows of ``true_y``.
    """
    center = np.median(samples, axis=0).reshape(true_y.shape)
    n_rows = true_y.shape[0]
    squared_error = (center - true_y) ** 2
    return squared_error.sum() / n_rows

In [None]:
np.random.seed(202011)

# Simulated training data. Model dimensions and true parameters (fdim, ldim,
# rdim, phi0, sigma0, beta0, gamma0) come from the configuration cell at the
# top of the notebook.
ndata = 300
T = 10
n_full = 200  # series with index < n_full keep all T steps for training
t_split = 5   # later series keep steps [0, t_split) for training; the rest is held out

true_clus = np.zeros(ndata, dtype=int)  # np.int was removed in NumPy 1.24
# NOTE(review): `weights` is not used below; labels are drawn uniformly from {0, 1}.
weights = np.ones(3) / 3

fixed_covs = np.zeros((ndata, fdim))
long_covs = []
resps = []
insample_test_resp = []
insample_test_long = []

# NOTE(review): not read by generate_data1, which defines its own local means.
means_fixed = [np.array([-5, 0]), np.array([5, 0]), np.array([0, 5])]

for i in range(ndata):
    c = np.random.choice(np.arange(2))
    y, long, fix = generate_data1(c, T)
    if i >= n_full:
        # Hold out the tail of the series for the in-sample forecast checks.
        insample_test_resp.append(y[t_split:, :])
        insample_test_long.append(long[t_split:, :])
        y = y[:t_split, :]
        long = long[:t_split, :]
    resps.append(y)
    long_covs.append(long)
    fixed_covs[i, :] = fix
    true_clus[i] = c

assert len(resps) == len(long_covs) == ndata
assert len(insample_test_resp) == len(insample_test_long) == ndata - n_full
    
# Independent out-of-sample test series, simulated like the training data
# but always kept in full (T steps).
ntest = 300
test_y = []
test_long = []
test_fix = np.zeros((ntest, fdim))
true_clus_test = np.zeros(ntest, dtype=int)  # np.int was removed in NumPy 1.24

for i in range(ntest):
    c = np.random.choice(np.arange(2))
    y, long, fix = generate_data1(c, T)
    test_y.append(y)
    test_long.append(long)
    test_fix[i, :] = fix
    true_clus_test[i] = c

assert len(test_y) == len(test_long) == ntest

In [None]:
# One panel per response component; each training series is coloured by its
# true cluster label.
cols = np.array(["steelblue", "orange", "forestgreen"])
fig, axis = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))
for i, series in enumerate(resps[:ndata]):
    colour = cols[true_clus[i]]
    timeline = np.arange(len(series))
    for j in range(rdim):
        axis[j].plot(timeline, series[:, j], color=colour)

In [None]:
# Fit the LSB mixture model to the training series.
lsb_sampler = Sampler(10, "LSB")
# All-zero masks, one per series (presumably "no entry missing" — confirm in Sampler).
is_missing = list(map(np.zeros_like, resps))
# NOTE(review): the positional settings (0, 5000, 10000, 1) are interpreted by
# Sampler.run_mcmc; presumably burn-in / iterations / thinning — TODO confirm.
lsb_chains = lsb_sampler.run_mcmc(
    0, 5000, 10000, 1,
    resps, long_covs, fixed_covs, is_missing,
)

In [None]:
# Adjusted Rand index between a stored LSB cluster allocation and the true labels.
# NOTE(review): lsb_chains[-2] is the second-to-last stored state — confirm this is intended.
lsb_allocs = lsb_chains[-2].clus_allocs
adjusted_rand_score(lsb_allocs, true_clus)

In [None]:
# Posterior predictive draws of the autoregressive matrix for three values of
# the fixed covariates, compared entry by entry.
phi_new1 = lsb_sampler.sample_phi_predictive(np.array([-2, 0]))
phi_new2 = lsb_sampler.sample_phi_predictive(np.array([0.0, 0]))
phi_new3 = lsb_sampler.sample_phi_predictive(np.array([2, 0]))

fig, axes = plt.subplots(nrows=rdim, ncols=rdim, figsize=(10, 10))
for i in range(rdim):
    for j in range(rdim):
        ax = axes[i][j]
        # NOTE(review): only the x = (-2, 0) draws are truncated to |phi_ij| < 1.5,
        # and each curve uses a different bw_adjust — confirm this is intended.
        sns.kdeplot([x[i, j] for x in phi_new1 if np.abs(x[i, j]) < 1.5], ax=ax,
                    bw_adjust=10, fill=True, alpha=0.1, label="x = (-2, 0)")
        sns.kdeplot([x[i, j] for x in phi_new3], ax=ax,
                    bw_adjust=5, fill=True, alpha=0.1, label="x = (2, 0)")
        sns.kdeplot([x[i, j] for x in phi_new2], ax=ax,
                    bw_adjust=1, fill=True, alpha=0.4, label="x = (0, 0)")
        ax.set_xlim((-2, 2))
        ax.set_ylabel("")
        ax.set_title(f"phi[{i}, {j}]")

# The curves are otherwise indistinguishable, so label them once.
axes[0][0].legend()
fig.savefig("predictive_phi.pdf")

In [None]:
# Cluster sizes of the simulated training labels (the raw ndata-long label
# array is too long to inspect usefully).
np.bincount(true_clus)

In [None]:
# Histogram of one entry of the predictive phi draws at x = (0, 0) (phi_new2).
# An explicit figure keeps this from drawing onto whatever figure is current.
i = 0
j = 0
fig, ax = plt.subplots()
ax.hist([x[i, j] for x in phi_new2], density=True, alpha=0.3, bins=100)
ax.set(xlabel=f"phi[{i}, {j}]", ylabel="density")
plt.show()

In [None]:
# Out-of-sample forecast error of the LSB model on the test series. Each
# forecast receives the series' covariates and only its first response row.
mse_full_lsb = []
for long_i, fix_i, y_i in zip(test_long, test_fix, test_y):
    pred_full = lsb_sampler.sample_predictive(long_i, fix_i, y_i[0, :])
    mse_full_lsb.append(score(pred_full, y_i))

assert len(mse_full_lsb) == ntest

In [None]:
# Mean and standard deviation of the LSB out-of-sample scores.
print(*(stat(mse_full_lsb) for stat in (np.mean, np.std)))

In [None]:
# In-sample forecast error of the LSB model on the series whose tails were held
# out when the training data was built (the last len(insample_test_resp) series).
first_heldout = ndata - len(insample_test_resp)  # 200 with the current split
mse_insample_lsb = []
for k, (long_k, resp_k) in enumerate(zip(insample_test_long, insample_test_resp)):
    data_idx = first_heldout + k
    # NOTE(review): 5 equals the number of held-out steps; presumably the
    # forecast horizon expected by predict_insample — TODO confirm.
    pred_insample = lsb_sampler.predict_insample(
        data_idx, long_k, fixed_covs[data_idx, :], 5)
    mse_insample_lsb.append(score(pred_insample, resp_k))

In [None]:
# Mean and standard deviation of the LSB in-sample scores.
print(*(stat(mse_insample_lsb) for stat in (np.mean, np.std)))

In [None]:
# Predictive bands for one fully observed training series under the LSB model:
#  - "INS": in-sample forecast given the series' first n_obs observations;
#  - "OOS": forecast of the whole series given only its first response row.
# Both use the median and the same 95% band, so they are directly comparable.
idx = 15
n_obs = 5
q_lo, q_hi = 0.025, 0.975

pred_insample = pp_mix_cpp.sample_predictive_insample(
    idx, n_obs, resps[idx][:n_obs, :], long_covs[idx][:n_obs, :], fixed_covs[idx, :],
    lsb_sampler._serialized_chains)
draws_insample = np.stack(pred_insample)
lower_insample, upper_insample = np.quantile(draws_insample, [q_lo, q_hi], axis=0)
median_insample = np.median(draws_insample, axis=0)

pred_full = lsb_sampler.sample_predictive(
    long_covs[idx], fixed_covs[idx, :], resps[idx][0, :])
draws_full = np.stack(pred_full)
lower_full, upper_full = np.quantile(draws_full, [q_lo, q_hi], axis=0)
median_full = np.median(draws_full, axis=0)

steps_observed = np.arange(resps[idx].shape[0])
steps_full = np.arange(median_full.shape[0])
steps_insample = np.arange(n_obs, n_obs + median_insample.shape[0])

fig, axes = plt.subplots(nrows=1, ncols=rdim, figsize=(10, 3))
for i in range(rdim):
    ax = axes[i]
    ax.plot(steps_observed, resps[idx][:, i], lw=2, label="Observed")
    ax.plot(steps_full, median_full[:, i], lw=2, color="orange", label="OOS")
    ax.plot(steps_full, lower_full[:, i], "--", color="orange")
    ax.plot(steps_full, upper_full[:, i], "--", color="orange")
    ax.plot(steps_insample, median_insample[:, i], lw=2, color="forestgreen", label="INS")
    ax.plot(steps_insample, lower_insample[:, i], "--", color="forestgreen")
    ax.plot(steps_insample, upper_insample[:, i], "--", color="forestgreen")
    ax.set_title(f"response {i}")
axes[0].legend()
fig.savefig("pred_lsb.pdf")

# Dirichlet Process

In [None]:
# Fit the Dirichlet-process mixture model to the same training series.
dp_sampler = Sampler(25, "DP")
# All-zero masks, one per series (presumably "no entry missing" — confirm in Sampler).
is_missing = list(map(np.zeros_like, resps))
# NOTE(review): these positional settings (0, 10000, 1000, 10) differ from the
# LSB run (0, 5000, 10000, 1); their meaning is defined by Sampler.run_mcmc —
# confirm the difference is intended.
dp_chains = dp_sampler.run_mcmc(
    0, 10000, 1000, 10,
    resps, long_covs, fixed_covs, is_missing,
)

In [None]:
# Out-of-sample forecast error of the DP model on the test series. Each
# forecast receives the series' covariates and only its first response row.
mse_full_dp = []
for long_i, fix_i, y_i in zip(test_long, test_fix, test_y):
    pred_full = dp_sampler.sample_predictive(long_i, fix_i, y_i[0, :])
    mse_full_dp.append(score(pred_full, y_i))

assert len(mse_full_dp) == ntest

In [None]:
# Mean and standard deviation of the DP out-of-sample scores.
print(*(stat(mse_full_dp) for stat in (np.mean, np.std)))

In [None]:
# In-sample forecast error of the DP model on the series whose tails were held
# out when the training data was built (the last len(insample_test_resp) series).
first_heldout = ndata - len(insample_test_resp)  # 200 with the current split
mse_insample_dp = []
for k, (long_k, resp_k) in enumerate(zip(insample_test_long, insample_test_resp)):
    data_idx = first_heldout + k
    # NOTE(review): 5 equals the number of held-out steps; presumably the
    # forecast horizon expected by predict_insample — TODO confirm.
    pred_insample = dp_sampler.predict_insample(
        data_idx, long_k, fixed_covs[data_idx, :], 5)
    mse_insample_dp.append(score(pred_insample, resp_k))

In [None]:
# Predictive bands for one fully observed training series under the DP model:
#  - "INS": in-sample forecast given the series' first n_obs observations;
#  - "OOS": forecast of the whole series given only its first response row.
# Both use the median and the same 95% band, so they are directly comparable.
idx = 15
n_obs = 5
q_lo, q_hi = 0.025, 0.975

pred_insample = pp_mix_cpp.sample_predictive_insample(
    idx, n_obs, resps[idx][:n_obs, :], long_covs[idx][:n_obs, :], fixed_covs[idx, :],
    dp_sampler._serialized_chains)
draws_insample = np.stack(pred_insample)
lower_insample, upper_insample = np.quantile(draws_insample, [q_lo, q_hi], axis=0)
median_insample = np.median(draws_insample, axis=0)

pred_full = dp_sampler.sample_predictive(
    long_covs[idx], fixed_covs[idx, :], resps[idx][0, :])
draws_full = np.stack(pred_full)
lower_full, upper_full = np.quantile(draws_full, [q_lo, q_hi], axis=0)
median_full = np.median(draws_full, axis=0)

steps_observed = np.arange(resps[idx].shape[0])
steps_full = np.arange(median_full.shape[0])
steps_insample = np.arange(n_obs, n_obs + median_insample.shape[0])

fig, axes = plt.subplots(nrows=1, ncols=rdim, figsize=(10, 3))
for i in range(rdim):
    ax = axes[i]
    ax.plot(steps_observed, resps[idx][:, i], lw=2, label="Observed")
    ax.plot(steps_full, median_full[:, i], lw=2, color="orange", label="OOS")
    ax.plot(steps_full, lower_full[:, i], "--", color="orange")
    ax.plot(steps_full, upper_full[:, i], "--", color="orange")
    ax.plot(steps_insample, median_insample[:, i], lw=2, color="forestgreen", label="INS")
    ax.plot(steps_insample, lower_insample[:, i], "--", color="forestgreen")
    ax.plot(steps_insample, upper_insample[:, i], "--", color="forestgreen")
    ax.set_title(f"response {i}")
axes[0].legend()
fig.savefig("pred_dp.pdf")