In [1]:
import sys
import time

import numpy as np
import xarray

sys.path.insert(0,'../code/')
from lme_forecast_general import LME
import utils

#### Read data

In [2]:
# Panel dimensions: first T years, n_locs locations, n_ages age groups, n_sexes sexes.
T = 28
n_locs = 195
n_ages = 23
n_sexes = 2
DIMS = ('location_id', 'age_group_id', 'sex_id', 'year_id')


def _load_values(path):
    """Load the 'value' variable of a NetCDF file as a (location, age, sex, year) array.

    The dataset is opened in a ``with`` block so the file handle is released as soon
    as the values have been read into memory (``.values`` materialises the array).
    """
    with xarray.open_dataset(path) as ds:
        return ds.transpose(*DIMS)['value'].values


# Latent TB prevalence (a proportion), modelled on the logit scale.
y_data = _load_values('./data/20190726_tb_latent_prev.nc')
y = y_data[:n_locs, :n_ages, :n_sexes, :T].reshape(-1)
# Slicing silently truncates if the file is smaller than expected; fail loudly instead.
assert y.size == n_locs * n_ages * n_sexes * T, y_data.shape
# The logit is only finite on the open interval (0, 1); this also rejects NaNs.
assert np.all((y > 0) & (y < 1)), 'prevalence must lie strictly in (0, 1)'
Y = np.log(y) - np.log(1 - y)

# HAQ varies only by location and year (its age/sex axes have length 1),
# so its flattened length is n_locs * T rather than one entry per data point.
haq_data = _load_values('./data/20190726_haq.nc')
haq = haq_data[:n_locs, :n_ages, :n_sexes, :T].reshape(-1)
assert haq.size == n_locs * T, haq_data.shape
print(Y.shape, haq.shape)

(251160,) (5460,)


#### Build model

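$$ \operatorname{logit}(y_{l,a,s,t}) = \beta \left(\text{HAQ}_{l,t} - \overline{\text{HAQ}}\right) + \beta_{a,s} + \pi_l $$

where $y_{l,a,s,t}$ is TB latent prevalence for location $l$, age group $a$, sex $s$ and year $t$; HAQ is mean-centred before fitting; $\beta_{a,s}$ is a fixed effect for each age–sex cell; and $\pi_l$ is a location-level random intercept.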

In [10]:
# Build the mixed-effects model on the logit-prevalence vector Y.
# NOTE(review): the argument semantics below are read off this call and the
# fit log (n_groups, k_beta) rather than LME's source — confirm against LME.
#   * HAQ is a single mean-centred covariate varying over (location, year) only.
#   * One indicator block over (age, sex) gives a coefficient per age-sex cell
#     (1 HAQ + n_ages * n_sexes indicators = k_beta).
#   * No random effects beyond the per-location grouping (presumably the pi_l intercept).
#   * The second positional argument `1` is not explained here — TODO document.
haq_centered = haq - np.mean(haq)
model = LME([n_locs, n_ages, n_sexes, T], 1, Y,
            [(haq_centered, [True, False, False, True])],
            indicators=[[False, True, True, False]],
            global_effects_indices=[0], random_effects_list=[],
            global_intercept=False, random_intercepts=[[False, False, False]])

In [11]:
# Fit the model and report wall-clock time. perf_counter is monotonic and
# high-resolution, unlike time.time, which can jump if the system clock changes.
# inner_max_iter caps the inner optimiser's iterations (semantics per LME — confirm).
fit_start = time.perf_counter()
model.optimize(inner_max_iter=200)
print('elapsed', time.perf_counter() - fit_start)

n_groups 195
k_beta 47
k_gamma 1
total number of fixed effects variables 49
fit with gamma fixed...
finished...elapsed 2.498155117034912
elapsed 3.8607161045074463


In [13]:
# DISABLED validation check: relative L2 difference between this model's
# fixed-effects-only fit (`model.yfit_no_random`) and a reference `lme_fit`
# column read from ./data/tb_prevalence.csv.
# NOTE(review): every line here is commented out, so this cell produces no
# output; the number recorded below it is stale, left over from an earlier
# execution. A relative error of ~50 would mean the two fits do not agree
# (possibly a scale mismatch, e.g. logit vs. probability — TODO confirm), so
# either re-enable and investigate, or delete this cell.
# import pandas as pd
# lme_fit = pd.read_csv('./data/tb_prevalence.csv')['lme_fit'].values
# np.linalg.norm(model.yfit_no_random - lme_fit)/ np.linalg.norm(lme_fit) 

50.36815012431488

#### Sample $\beta$ and $\pi$

In [14]:
# Compute posterior variances for the fixed (global) effects, then for the
# random effects; presumably these must run before `model.draw()` can sample
# (TODO confirm against LME). Kept in this order in case the random-effect
# step depends on the global one.
model.postVarGlobal()
model.postVarRandom()

In [15]:
# Draw posterior samples of the fixed effects (beta) and random effects (u).
# Seed NumPy's global RNG so re-running the notebook reproduces the same draws.
# NOTE(review): assumes LME.draw samples via numpy.random's global state — confirm.
DRAW_SEED = 0
np.random.seed(DRAW_SEED)
beta_samples, u_samples = model.draw()

In [16]:
# Package the draws as an xarray Dataset: only the first fixed-effect row
# (labelled 'haq'; presumably the HAQ coefficient, as it is the first covariate)
# is kept, alongside the per-location random-effect draws.
haq_beta_draws = beta_samples[0, :].reshape(1, -1)
dataset = utils.saveDraws(
    haq_beta_draws,      # shape (1, n_draws): one covariate row
    u_samples,
    ['haq'],             # covariate names -> 'cov' coordinate
    [['location']],      # dimension names for each random effect
    ['pi_location'],     # data-variable name for each random effect
)

In [17]:
# Show the saved draws (rich xarray repr as the cell's output).
dataset

<xarray.Dataset>
Dimensions:      (cov: 1, draw: 10, location: 195)
Coordinates:
  * location     (location) int64 1 2 3 4 5 6 7 ... 189 190 191 192 193 194 195
  * draw         (draw) int64 1 2 3 4 5 6 7 8 9 10
  * cov          (cov) <U3 'haq'
Data variables:
    pi_location  (location, draw) float64 0.5564 0.5495 0.5544 ... 1.454 1.472
    beta_global  (cov, draw) float64 -0.026 -0.02601 ... -0.02584 -0.02591