In [1]:
import xarray
import numpy as np
import sys
sys.path.insert(0,'../code/')
from lme_forecast_general import LME
import utils

#### Read data

In [3]:
# Extent of the modelling grid along each axis. Both inputs are cropped to
# these sizes below, and the same numbers describe the grid passed to LME.
n_years = 28
n_locs = 195
n_ages = 23
n_sexes = 2

# Axis order applied to both inputs before cropping and flattening.
DIM_ORDER = ('location_id', 'age_group_id', 'sex_id', 'year_id')


def load_values(path):
    """Return the 'value' variable of the NetCDF file at `path` as an array.

    The axes are transposed to DIM_ORDER. The dataset is opened in a `with`
    block so the underlying file handle is released as soon as the values
    have been read into memory.
    """
    with xarray.open_dataset(path) as ds:
        return ds.transpose(*DIM_ORDER)['value'].values


# Response variable, cropped to the grid and flattened to one long vector.
y_data = load_values('./data/20190726_tb_latent_prev.nc')
y = y_data[:n_locs, :n_ages, :n_sexes, :n_years].reshape(-1)
# Logit transform. NOTE(review): entries equal to exactly 0 or 1 would map to
# -inf / +inf here -- confirm the input is strictly inside (0, 1).
Y = np.log(y) - np.log(1-y)

# Covariate, cropped the same way. Slicing past the end of an axis is a no-op,
# so a covariate stored with shorter axes keeps its own (smaller) shape; the
# two lengths printed below therefore need not match.
haq_data = load_values('./data/20190726_haq.nc')
haq = haq_data[:n_locs,:n_ages,:n_sexes,:n_years].reshape(-1)
print(Y.shape, haq.shape)

(251160,) (5460,)


#### Build model

$$ \text{logit}(y) = \beta \, \text{HAQ} + \beta_I \, \text{age-sex} + \pi_l $$

In [6]:
# Grid dimensions are given in the same (location, age, sex, year) order used
# to crop and flatten Y. n_sexes is used instead of a bare 2 so the grid cannot
# drift from the data if the constants are changed.
# NOTE(review): each four-element boolean list presumably marks which of those
# axes a term varies over (HAQ: location and year; indicators: age and sex;
# random effect: location) -- confirm against lme_forecast_general.LME.
model = LME([n_locs, n_ages, n_sexes, n_years], 1, Y,
            [(haq - np.mean(haq), [True, False, False, True])],  # mean-centred covariate
            indicators=[[False, True, True, False]],
            global_effects_indices=[0], global_intercept=False,
            random_effects_list=[(None, [True, False, False, False])])

In [7]:
import time

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
t0 = time.perf_counter()
model.optimize(inner_max_iter=200)
print('elapsed', time.perf_counter()-t0)

n_groups 195
k_beta 47
k_gamma 1
total number of fixed effects variables 49
fit with gamma fixed...
finished...elapsed 2.462899923324585
elapsed 3.8048720359802246


In [5]:
# Disabled check (kept for reference): relative L2 difference between
# model.yfit_no_random (presumably the fit without random effects) and the
# 'lme_fit' column of ./data/tb_prevalence.csv. Uncomment all three lines to
# run it; it needs pandas and that CSV file.
# import pandas as pd
# lme_fit = pd.read_csv('./data/tb_prevalence.csv')['lme_fit'].values
# np.linalg.norm(model.yfit_no_random - lme_fit)/ np.linalg.norm(lme_fit) 

#### Sample $\beta$ and $\pi$

In [8]:
# Both calls are made on the fitted model before model.draw() in the next cell.
# NOTE(review): the names suggest they compute posterior variances of the
# global and the random effects that draw() then samples from; whether their
# order matters is not visible here -- confirm in lme_forecast_general.
model.postVarGlobal()
model.postVarRandom()

In [9]:
# model.draw() returns two objects; both are passed on to utils.saveDraws below.
# NOTE(review): no random seed is set in the visible cells, so if draw() samples
# from a random number generator the values will change between runs --
# consider seeding before this call.
beta_samples, u_samples = model.draw()

In [10]:
# Read the location labels inside a `with` block so the file handle is closed
# once the coordinate values have been copied out.
with xarray.open_dataset('./data/20190726_tb_latent_prev.nc') as prev_ds:
    # Crop to n_locs, matching how Y was cropped along the location axis, so the
    # labels line up even if the file holds more locations.
    location_ids = prev_ds.coords['location_id'].values[:n_locs]
coord_dict = {'location_id':location_ids}
# Only row 0 of beta_samples is saved, reshaped to a single row and labelled
# 'haq'. NOTE(review): presumably row 0 is the HAQ coefficient (it matches
# global_effects_indices=[0]) -- confirm against LME.draw.
dataset = utils.saveDraws(beta_samples[0,:].reshape((1,-1)), u_samples, ['haq'], [['location_id']], 
                          ['pi_location'], coord_dict)

In [11]:
# Display the xarray Dataset returned by utils.saveDraws; a bare last
# expression is rendered as the cell's output.
dataset

<xarray.Dataset>
Dimensions:      (cov: 1, draw: 10, location_id: 195)
Coordinates:
  * location_id  (location_id) int64 6 7 8 10 11 12 ... 351 376 385 422 435 522
  * draw         (draw) int64 1 2 3 4 5 6 7 8 9 10
  * cov          (cov) <U3 'haq'
Data variables:
    pi_location  (location_id, draw) float64 0.5555 0.5482 0.559 ... 1.463 1.462
    beta_global  (cov, draw) float64 -0.02584 -0.02598 ... -0.02604 -0.02606