In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from scipy.stats import multivariate_normal as mvn
import seaborn as sns

# Seaborn's current default colour cycle as an array, so it can be
# fancy-indexed with an integer label array (e.g. ``cols[true_clus]``).
cols = np.asarray(sns.color_palette(as_cmap=True))

In [None]:
# Number of subjects used for fitting and for out-of-sample evaluation.
ndata = 300
ntest = 300

# Seed NumPy's global RNG once. make_moons with random_state=None, and the
# np.random / scipy draws used to simulate trajectories in later cells, all
# fall back to this global state, so the notebook becomes reproducible
# under Restart & Run All.
SEED = 0
np.random.seed(SEED)

# Two interleaving half-moons; the 2-D coordinates act as fixed covariates.
X, true_clus = datasets.make_moons(ndata + ntest, noise=0.1)
# Prepend an intercept column: each row of X is [1, x1, x2].
X = np.hstack([np.ones(X.shape[0]).reshape(-1, 1), X])

In [None]:
# Split make_moons' label-0 moon in two: its points with a positive first
# coordinate (column 1 of X, after the intercept) become a new cluster 2,
# giving three clusters in total.
split_mask = (true_clus == 0) & (X[:, 1] > 0)
true_clus[split_mask] = 2

In [None]:
# The first `ndata` subjects are used for fitting and the rest are held out.
# All four are views into X / true_clus, so writes propagate back to them.
fixed_covs_train, fixed_covs_test = X[:ndata], X[ndata:]
true_clus_train, true_clus_test = true_clus[:ndata], true_clus[ndata:]

In [None]:
# Number of training subjects per true cluster (index = label), rather than
# dumping all the individual labels.
np.bincount(true_clus_train)

In [None]:
# Fixed covariates (intercept column skipped), coloured by true cluster label.
fig, ax = plt.subplots()
ax.scatter(X[:, 1], X[:, 2], c=cols[true_clus])
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
fig.savefig("banana_shaped_clusters.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Dimensions of the fixed covariates (intercept + 2 moon coordinates), the
# per-time-step longitudinal covariates, and the response vector.
fdim, ldim, rdim = 3, 5, 3

# True VAR(1) transition matrices, indexed by cluster label. Only labels
# 0, 1 and 2 are produced above, so the fourth matrix is unused here.
# NOTE(review): phi0[0] has eigenvalue 1.5 (third component) and phi0[1] is
# the identity, so those clusters are non-stationary. This is presumably
# deliberate, to make the clusters easy to separate; confirm.
phi0 = [
    np.array([
        [0.9, -0.1, 0],
        [-0.1, 0.9, 0],
        [-0.1, 0, 1.5],
    ]),
    np.eye(3),
    0.7 * np.eye(3),
    np.eye(3),
]

# Innovation covariance, and zero regression coefficients for the
# longitudinal (beta0) and fixed (gamma0) covariates.
sigma0 = 0.2 * np.eye(3)
beta0 = np.zeros((rdim, ldim))
gamma0 = np.zeros((rdim, fdim))


def generate_data(fixed_cov, clus, T=10):
    """Simulate one subject's response trajectory from the true VAR(1) model.

    ``y[0]`` has i.i.d. N(5, 0.7^2) components. For t >= 1:

        y[t] = phi0[clus] @ y[t-1] + beta0 @ long_cov[t] + gamma0 @ fixed_cov + e_t

    where e_t ~ N(0, sigma0).

    Args:
        fixed_cov: time-invariant covariate vector (length ``fdim``).
        clus: cluster label; selects the transition matrix ``phi0[clus]``.
        T: number of time steps to simulate.

    Returns:
        ``(y, long_cov)``: responses of shape (T, rdim) and standard-normal
        longitudinal covariates of shape (T, ldim). Row 0 of ``long_cov`` is
        drawn but not used in the simulation.
    """
    long_cov = np.random.normal(size=(T, ldim))
    y = np.zeros((T, rdim))
    y[0] = np.random.normal(loc=5, scale=0.7, size=rdim)
    for step in range(1, T):
        drift = (phi0[clus] @ y[step - 1]
                 + beta0 @ long_cov[step]
                 + gamma0 @ fixed_cov)
        y[step] = drift + mvn.rvs(cov=sigma0)
    return y, long_cov


# Number of time steps simulated per subject.
T = 10

# Training subjects with index >= insample_start contribute only their first
# n_obs_insample time steps to the fit; the remaining steps are kept aside for
# in-sample forecast evaluation.
insample_start = 200
n_obs_insample = 5

resps = []
long_covs = []
insample_test_resp = []
insample_test_long = []

for i in range(ndata):
    y, long = generate_data(fixed_covs_train[i, :], true_clus_train[i], T)
    if i >= insample_start:
        insample_test_resp.append(y[n_obs_insample:, :])
        insample_test_long.append(long[n_obs_insample:, :])
        y = y[:n_obs_insample, :]
        long = long[:n_obs_insample, :]
    resps.append(y)
    long_covs.append(long)

# Held-out subjects: full trajectories simulated from their true labels.
# true_clus_test is left untouched so it stays consistent with the data it
# generated. It is also a view of true_clus, so overwriting it would corrupt
# the full label vector as well.
test_y = []
test_long = []

for i in range(ntest):
    y, long = generate_data(fixed_covs_test[i, :], true_clus_test[i], T)
    test_y.append(y)
    test_long.append(long)

In [None]:
# Training response trajectories, one panel per response component,
# coloured by each subject's true cluster.
fig, axis = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))
for traj, label in zip(resps, true_clus_train):
    steps = np.arange(len(traj))
    for comp in range(rdim):
        axis[comp].plot(steps, traj[:, comp], color=cols[label])

In [None]:
from interface import Sampler, to_numpy, writeChains, loadChains

In [None]:
# Fit the "LSB" mixture model. The meaning of Sampler's first argument (50)
# and of run_mcmc's leading positional arguments is defined in `interface`,
# which is not visible here.
lsb_sampler = Sampler(50, "LSB")
lsb_prior = dict(
    phi00=np.eye(rdim),
    v00=np.eye(rdim * rdim),
    nu=10,
    tau=15,
    lamb=0.01,
    sigma0=np.eye(rdim) * 0.4,
    beta0=np.zeros((rdim, ldim)),
    gamma0=np.zeros((rdim, fdim)),
    alpha0=np.zeros(fdim),
    vara=10,
)
lsb_sampler.set_prior(**lsb_prior)

# Presumably "no missing observations" — confirm against interface.Sampler.
is_missing = []
lsb_chains = lsb_sampler.run_mcmc(
    0, 20000, 10000, 10, resps, long_covs, fixed_covs_train, is_missing)

In [None]:
def score(samples, true_y):
    """Squared forecast error of the posterior-median prediction.

    Args:
        samples: predictive draws stacked along axis 0. The remaining axes
            must hold exactly ``true_y.size`` values.
        true_y: observed responses; axis 0 is the time dimension.

    Returns:
        Sum of squared errors over all entries, divided by the number of rows
        (time steps) of ``true_y``. This is not divided by the number of
        response components.
    """
    point_pred = np.median(samples, axis=0)
    point_pred = np.reshape(point_pred, true_y.shape)
    n_rows = true_y.shape[0]
    return ((point_pred - true_y) ** 2).sum() / n_rows


def get_out_of_sample_mse(sampler, test_long, test_fix, test_y):
    """Forecast error of ``sampler`` on every held-out subject.

    For subject i, the posterior predictive is drawn from the longitudinal
    covariates ``test_long[i]``, the fixed covariates ``test_fix[i]`` and the
    first observed response ``test_y[i][0, :]``. It is then scored against
    ``test_y[i]`` with ``score``.

    Args:
        sampler: fitted sampler exposing ``sample_predictive``.
        test_long: per-subject longitudinal covariates.
        test_fix: per-subject fixed covariates (indexable by subject).
        test_y: per-subject observed response trajectories.

    Returns:
        list with one ``score`` value per element of ``test_y``.
    """
    mse_full = []
    # The subject count must come from test_y itself. Using a stray global
    # (e.g. the last simulated trajectory `y`) silently evaluates only its
    # first T subjects.
    ntest = len(test_y)
    for i in range(ntest):
        pred_full = sampler.sample_predictive(
            test_long[i], test_fix[i], test_y[i][0, :])
        mse_full.append(score(pred_full, test_y[i]))
    return mse_full


def get_in_sample_mse(sampler, test_y, test_long, data_idx, fixed_covs=None):
    """In-sample forecast error for subjects whose later time steps were held out.

    NOTE(review): a later cell redefines ``get_in_sample_mse(sampler)`` with a
    different signature, which shadows this definition once that cell has
    run. Consider keeping only one of them.

    Args:
        sampler: fitted sampler exposing ``predict_insample``. The argument is
            used instead of any global sampler.
        test_y: held-out response blocks, one per subject.
        test_long: matching held-out longitudinal covariates.
        data_idx: either an int giving the training index of the subject for
            ``test_y[0]`` (subject i is then at ``data_idx + i``), or a
            sequence with the training index of every subject.
        fixed_covs: fixed-covariate matrix indexed by training index.
            Defaults to the global ``fixed_covs_train``.

    Returns:
        list with one ``score`` value per subject in ``test_y``.
    """
    if fixed_covs is None:
        fixed_covs = fixed_covs_train
    if np.ndim(data_idx) == 0:
        data_idx = [data_idx + i for i in range(len(test_y))]
    mse_insample = []
    for i, idx in enumerate(data_idx):
        # The trailing 5 is passed through unchanged from the original
        # notebook. Its meaning is defined by predict_insample (not visible
        # here).
        pred_insample = sampler.predict_insample(
            idx, test_long[i], fixed_covs[idx, :], 5)
        mse_insample.append(score(pred_insample, test_y[i]))
    return mse_insample

In [None]:
# Held-out forecast errors of the LSB model (recomputed alongside the DP
# model in a later cell).
oos_mse_lsb = get_out_of_sample_mse(
    sampler=lsb_sampler,
    test_long=test_long,
    test_fix=fixed_covs_test,
    test_y=test_y,
)

In [None]:
# Persist the LSB chains so the expensive MCMC run need not be repeated.
lsb_chains_path = "chains/lsb_banana_test.recordio"
writeChains(lsb_chains, lsb_chains_path)

In [None]:
# Fit the "DP" mixture model. The prior matches the LSB cell above except for
# phi00 (zeros instead of identity), lamb (0.1 instead of 0.01) and the
# absence of vara. run_mcmc's second argument is also 10000 here vs 20000 for
# LSB.
dp_sampler = Sampler(50, "DP")
dp_prior = dict(
    phi00=np.zeros((rdim, rdim)),
    v00=np.eye(rdim * rdim),
    nu=10,
    tau=15,
    lamb=0.1,
    sigma0=np.eye(rdim) * 0.4,
    beta0=np.zeros((rdim, ldim)),
    gamma0=np.zeros((rdim, fdim)),
    alpha0=np.zeros(fdim),
)
dp_sampler.set_prior(**dp_prior)

is_missing = []
dp_chains = dp_sampler.run_mcmc(
    0, 10000, 10000, 10, resps, long_covs, fixed_covs_train, is_missing)

dp_chains_path = "chains/dp_banana_test.recordio"
writeChains(dp_chains, dp_chains_path)

In [None]:
# Held-out forecast errors for both fitted models on the same test subjects.
oos_args = (test_long, fixed_covs_test, test_y)
oos_mse_lsb = get_out_of_sample_mse(lsb_sampler, *oos_args)
oos_mse_dp = get_out_of_sample_mse(dp_sampler, *oos_args)

In [None]:
# DP held-out error: median, mean, standard deviation.
print(*(stat(oos_mse_dp) for stat in (np.median, np.mean, np.std)))

In [None]:
# DP held-out error from an earlier run (presumably median, mean; std not recorded): 623.5017037628643 1604.4668253759644

In [None]:
# LSB held-out error: median, mean, standard deviation.
print(*(stat(oos_mse_lsb) for stat in (np.median, np.mean, np.std)))

In [None]:
# LSB held-out error from an earlier run (presumably median, mean; std not recorded): 91.34710232244836 141.41943421454047

In [None]:
# Training covariates coloured by the DP cluster allocation in dp_chains[-8].
# Clusters with at most 2 members are not drawn.
# NOTE(review): the choice of element -8 is not explained here.
dp_alloc = np.array(dp_chains[-8].clus_allocs)
labels, counts = np.unique(dp_alloc, return_counts=True)
for label in labels[counts > 2]:
    members = dp_alloc == label
    plt.scatter(fixed_covs_train[members, 1], fixed_covs_train[members, 2])

In [None]:
# Training covariates coloured by the LSB cluster allocation in the first
# element of lsb_chains. Clusters with at most 2 members are not drawn.
lsb_alloc = np.array(lsb_chains[0].clus_allocs)
labels, counts = np.unique(lsb_alloc, return_counts=True)
for label in labels[counts > 2]:
    members = lsb_alloc == label
    plt.scatter(fixed_covs_train[members, 1], fixed_covs_train[members, 2])

In [None]:
def get_in_sample_mse(sampler, start_idx=200, n_steps=5):
    """In-sample forecast error of ``sampler`` for partially observed subjects.

    Training subjects ``start_idx, start_idx + 1, ...`` were fitted on their
    first time steps only. Their held-out remainder is stored, in the same
    order, in the globals ``insample_test_resp`` / ``insample_test_long``.

    Args:
        sampler: fitted sampler exposing ``predict_insample``.
        start_idx: training index of the first partially observed subject.
            Must match the split used when simulating the data.
        n_steps: last positional argument forwarded to ``predict_insample``.
            Presumably the number of steps to predict; confirm.

    Returns:
        list with one ``score`` value per entry of ``insample_test_resp``.
    """
    out = []
    # Iterate over the stored held-out blocks instead of a hard-coded count,
    # so the loop stays correct if ndata or the split point changes.
    for i, (resp_i, long_i) in enumerate(zip(insample_test_resp, insample_test_long)):
        data_idx = start_idx + i
        pred_insample = sampler.predict_insample(
            data_idx, long_i, fixed_covs_train[data_idx, :], n_steps)
        out.append(score(pred_insample, resp_i))
    return out

In [None]:
# In-sample forecast errors of the DP model.
insample_mse_dp = get_in_sample_mse(sampler=dp_sampler)

In [None]:
# In-sample forecast errors of the LSB model.
insample_mse_ldp = get_in_sample_mse(sampler=lsb_sampler)

In [None]:
# DP in-sample error: median, mean, standard deviation.
print(*(stat(insample_mse_dp) for stat in (np.median, np.mean, np.std)))

In [None]:
# LSB in-sample error: median, mean, standard deviation.
print(*(stat(insample_mse_ldp) for stat in (np.median, np.mean, np.std)))