In [2]:
import time

import numpy as np
import xarray

from lme.lme_forecast_verbose import LME
import lme.utils as utils

#### Read data

In [3]:
# Size of the modelling grid: only the first n_* entries along each axis are used.
n_years = 28
n_locs = 195
n_ages = 23
n_sexes = 2

# Axis order every array is put into before it is flattened.
dim_order = ('location_id', 'age_group_id', 'sex_id', 'year_id')

# Response: latent TB prevalence. The context manager closes the file handle once
# the values have been pulled into a numpy array.
with xarray.open_dataset('./data/20190726_tb_latent_prev.nc') as prev_ds:
    y_data = prev_ds.transpose(*dim_order)['value'].values
y = y_data[:n_locs, :n_ages, :n_sexes, :n_years].reshape(-1)
# Logit transform of the prevalence; this is the quantity the model is fitted to.
# NOTE(review): values of exactly 0 or 1 would produce -inf/+inf here — confirm the data excludes them.
Y = np.log(y) - np.log(1 - y)

# Covariate: HAQ, read and flattened the same way as the response.
with xarray.open_dataset('./data/20190726_haq.nc') as haq_ds:
    haq_data = haq_ds.transpose(*dim_order)['value'].values
haq = haq_data[:n_locs, :n_ages, :n_sexes, :n_years].reshape(-1)

# haq comes out shorter than Y (see the printed shapes): its length equals n_locs * n_years.
print(Y.shape, haq.shape)

(251160,) (5460,)


#### Build model

$$ \text{logit}(y) = \beta \, \text{HAQ} + \beta_I \, \text{age-sex} + \pi_l $$

In [4]:
# Build the mixed-effects model on the logit-transformed response Y.
# Each boolean list has one flag per grid axis, in the order
# [location, age, sex, year], marking the axes a term varies over.
# NOTE(review): the meaning of the second positional argument (1) is not visible here — confirm against LME.
model = LME([n_locs, n_ages, n_sexes, n_years], 1, Y,
            # Mean-centred HAQ covariate, flagged for the location and year axes.
            {'haq': (haq - np.mean(haq), [True, False, False, True])},
            # Indicator terms flagged for the age and sex axes.
            indicators={'ind_age-sex': [False, True, True, False]},
            global_effects_names=['haq'],
            global_intercept=False,
            # Random intercept flagged for the location axis only.
            random_effects={'intercept': [True, False, False, False]})

In [5]:
# Fit the model and report how long it took. perf_counter() is a monotonic clock,
# so the measured interval is not affected by system clock adjustments.
t0 = time.perf_counter()
model.optimize(inner_max_iter=200)
print('elapsed', time.perf_counter() - t0)

n_groups 195
k_beta 47
k_gamma 1
total number of fixed effects variables 49
fit with gamma fixed...
finished...elapsed 2.399401903152466
elapsed 3.727315902709961


In [6]:
# Show the solver's termination message to check how the fit ended.
model.info

b'Algorithm terminated successfully at a locally optimal point, satisfying the convergence tolerances (can be specified by options).'

#### Sample $\beta$ and $\pi$

In [8]:
# Both calls are made on the fitted model before samples are drawn from it.
# NOTE(review): judging by the names, these compute the posterior variance of the
# global effects and of the random effects respectively — confirm against the LME class.
model.postVarGlobal()
model.postVarRandom()

In [9]:
# Draw samples from the fitted model: one array for beta and one for the random effects u.
# NOTE(review): no random seed is set in the visible cells — if draw() relies on a
# global RNG the samples will differ between runs; confirm and seed if needed.
beta_samples, u_samples = model.draw()

In [10]:
# Read the location ids from the same file the response was loaded from; the
# context manager closes the file handle once the ids are in memory.
with xarray.open_dataset('./data/20190726_tb_latent_prev.nc') as prev_ds:
    # Keep only the first n_locs ids so the coordinate lines up with the
    # locations that were actually used when the response was sliced.
    location_ids = prev_ds.coords['location_id'].values[:n_locs]
coord_dict = {'location_id': location_ids}

# Only the first row of beta_samples is passed on, labelled 'haq'.
# NOTE(review): presumably the remaining rows are the age-sex indicator coefficients — confirm.
dataset = utils.saveDraws(beta_samples[0, :].reshape((1, -1)), u_samples, ['haq'], [['location_id']],
                          ['pi_location'], coord_dict)

In [11]:
# Display the dataset returned by utils.saveDraws.
dataset

<xarray.Dataset>
Dimensions:      (cov: 1, draw: 10, location_id: 195)
Coordinates:
  * location_id  (location_id) int64 6 7 8 10 11 12 ... 351 376 385 422 435 522
  * draw         (draw) int64 1 2 3 4 5 6 7 8 9 10
  * cov          (cov) <U3 'haq'
Data variables:
    pi_location  (location_id, draw) float64 0.5565 0.5553 0.5521 ... 1.46 1.459
    beta_global  (cov, draw) float64 -0.02599 -0.02579 ... -0.02609 -0.02597