In [1]:
import copy
import sys
import time

import matplotlib.pyplot as plt
import numpy as np

# ../code/ must be on sys.path before the project-local import below.
sys.path.insert(0,'../code/')
from lme_forecast_general import LME

In [2]:
# Simulate a balanced grouped data set: n_groups groups with n observations each.
# The order of the three np.random.randn draws (covariates, group effects, noise)
# fixes the simulated values for a given seed, so it must not be changed.
np.random.seed(127)
n_groups, n = 20, 40
N = n_groups * n
k_beta = 2

# True parameters: fixed effects (intercept first), random-effect variance,
# and observation-noise variance.
beta_true = [-.5,1., .5]
gamma_true = 0.1
delta_true = .1

# Design matrix: a column of ones followed by k_beta standard-normal covariates.
X = np.hstack([np.ones((N, 1)), np.random.randn(N, k_beta)])

# One random intercept per group, with variance gamma_true.
u = np.sqrt(gamma_true) * np.random.randn(n_groups)

# Rows are ordered group by group, so repeating each u[l] n times lines the
# group effect up with that group's observations.
Y_true = np.dot(X, beta_true) + np.repeat(u, n)

# Observed response: noiseless signal plus Gaussian noise with variance delta_true.
Y = Y_true + np.sqrt(delta_true) * np.random.randn(N)

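#### The model
$$ y = X\beta + \pi_{l} + \epsilon, \qquad \pi_{l} \sim N(0, \gamma), \quad \epsilon \sim N(0, \delta) $$

Here $\pi_{l}$ is a random intercept shared by all observations in group $l$, and $\epsilon$ is the observation noise.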

In [3]:
# Build the LME model on the simulated data, fit it, and report how long the fit took.
# NOTE(review): the constructor arguments are positional; their meaning is defined by
# LME.__init__ in ../code/lme_forecast_general.py — confirm there before editing.
model = LME([n_groups, n],1,Y, [(X[:,1],[True, True]), (X[:,2],[True, True])], [], [0,1], [], True,[[False]])

# perf_counter() is a monotonic, high-resolution clock, so the measured interval
# cannot be distorted by system clock adjustments the way time.time() can.
# `time` is imported with the other imports at the top of the notebook.
t0 = time.perf_counter()
model.optimize(inner_max_iter=100,outer_max_iter=1,share_obs_std=True)
print('elapsed',time.perf_counter()-t0)

n_groups 20
k_beta 3
k_gamma 1
total number of fixed effects variables 5
fit with gamma fixed...
finished...elapsed 0.19281601905822754
elapsed 0.42158007621765137


In [4]:
# Show the fitted beta_soln, gamma_soln and delta_soln attributes, one per line.
print(model.beta_soln, model.gamma_soln, model.delta_soln, sep='\n')

[-0.54887417  1.0012523   0.48747892]
[0.12856626]
[0.1075671]


In [5]:
# Call postVarRandom() on the fitted model; the cell displays nothing.
# NOTE(review): presumably computes the posterior variance of the random effects —
# confirm in lme_forecast_general.
model.postVarRandom()

In [6]:
# Call postVarGlobal() on the fitted model; the cell displays nothing.
# NOTE(review): presumably populates model.var_beta, which is inspected later in
# this notebook — confirm in lme_forecast_general.
model.postVarGlobal()

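#### Check posterior covariance

Compare the empirical mean and covariance of the sampled betas with `model.beta_soln` and `model.var_beta`. **NOTE:** in the recorded outputs below they do not agree (sample variance of the intercept ≈ 30 versus `var_beta[0, 0]` ≈ 0.0066; sample mean of the intercept ≈ -0.11 versus a fitted value of ≈ -0.55), so this check currently fails and needs investigation.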

In [7]:
# Request 2000 samples from the fitted model; the call returns a
# (beta_samples, gamma_samples) pair.
# NOTE(review): the sampling scheme is defined by LME.sampleGlobalWithLimeTr and is
# not visible here — confirm what distribution these samples come from.
beta_samples, gamma_samples = model.sampleGlobalWithLimeTr(sample_size=2000)

sampling solution progress 1.00

In [8]:
# Empirical mean of the sampled betas, one value per coefficient.
beta_samples.mean(axis=0)

array([-0.1071966 ,  0.9949025 ,  0.48082104])

In [9]:
# Fitted beta from the optimizer, shown for comparison with the sample mean above.
model.beta_soln

array([-0.54887417,  1.0012523 ,  0.48747892])

In [10]:
# Empirical covariance of the sampled betas; rowvar=False treats each column
# (coefficient) as a variable and each row (sample) as an observation.
np.cov(beta_samples, rowvar=False)

array([[3.01846679e+01, 6.60409524e-04, 1.25895560e-02],
       [6.60409524e-04, 9.57901043e-02, 8.53659203e-03],
       [1.25895560e-02, 8.53659203e-03, 9.96182615e-02]])

In [11]:
# Covariance of beta as stored on the model, for comparison with the empirical
# sample covariance above.
# NOTE(review): presumably the analytic posterior covariance set by postVarGlobal() — confirm.
model.var_beta

array([[ 6.56466994e-03, -8.32287632e-06,  1.37426772e-05],
       [-8.32287632e-06,  1.45707731e-04,  2.48645167e-06],
       [ 1.37426772e-05,  2.48645167e-06,  1.35558966e-04]])

In [12]:
# Display the raw beta samples (one row per sample, one column per coefficient);
# NumPy elides the middle rows of a large array in the output.
beta_samples

array([[-0.51973627,  1.00423075,  0.48831705],
       [-0.47806479,  1.00487552,  0.46147856],
       [-0.62453795,  1.00850441,  0.48172318],
       ...,
       [-0.1155079 ,  0.6123616 ,  0.88473721],
       [-0.97277461,  0.58901233,  0.88585649],
       [-0.65440068,  1.07054433,  0.58185424]])