In [79]:
#pip install -U save_results
#pip install -U db_queries 

In [None]:
# IMPORT LIBRARIES & FUNCTIONS -------------------------------------------------------
# Import libraries
import getpass
import numpy as np
import os
import pandas as pd
import polars as pl
import pymysql
import re
import shutil
import subprocess
import time
from datetime import datetime


# Import Shared Functions
from db_queries import get_demographics, get_ids, get_location_metadata, get_population, \
                       get_sequela_metadata, get_best_model_versions
from save_results import save_results_epi, save_results_cod


# Import SDI regression function 
# (this function runs a regression and returns the coefficient for association between SDI & incidence
#  -- this coefficient can then be passed to the interpolation function to interpolate based on SDI)
from sdi_regression import sdi_regression

In [None]:
# DEFINE DEFAULTS --------------------------------------------------------------------
# Cause-specific defaults for the optional pipeline steps, keyed by the LEVEL_3 cause
ind_split_default = {'intest': True, 'ints': False}
sdi_interpolation_default = {'intest': True, 'ints': True}

# USER-DEFINED RUN PARAMETERS --------------------------------------------------------
# Set run parameters 
RELEASE = 16              # Numeric release id for the current GBD cycle
LEVEL_3 = 'ints'          # Set to either 'ints' or 'intest'
CLEAR_OUTPUT_DIR = False  # Do you want to clear out the contents of the output folder before launching (T/F)
BEST = True               # Do you want to mark the resulting models as best during the upload (T/F)

# Fail fast with a readable message if LEVEL_3 is mistyped; otherwise the first symptom
# would be a bare KeyError from one of the cause-keyed lookups below.
VALID_LEVEL_3 = tuple(ind_split_default)
if LEVEL_3 not in VALID_LEVEL_3:
    raise ValueError(f"LEVEL_3 must be one of {VALID_LEVEL_3}, got {LEVEL_3!r}")

IND_SPLIT = ind_split_default[LEVEL_3]   # Split Indian subnational estimates based on CoD (T/F; per-cause default above)
SDI_INTERPOLATION = sdi_interpolation_default[LEVEL_3] # Interpolate estimates based on SDI (T/F; per-cause default above)


# DB CREDENTIALS ---------------------------------------------------------------------
# Set the database host and name, and read in the username and password 
db_credentials = pd.read_csv('FILEPATH')
EPI_DB_HOST = 'ADDRESS'
EPI_DB_NAME = 'DATABASE'
EPI_DB_USER = str(db_credentials.loc[0, 'user']) 
EPI_DB_PASS = str(db_credentials.loc[0, 'pw']) 


# USERNAME ---------------------------------------------------------------------------
# Get the current username -- we'll use this below to submit jobs 
user = getpass.getuser()


# I/O DIRECTORIES --------------------------------------------------------------------
# Create input and output file paths based on cause and release
input_dir = 'FILEPATH'
output_dir = 'FILEPATH'


# CAUSE, MODELABLE ENTITY, & SEQUELA IDS ---------------------------------------------
# Indicate the meids for the core high-burden incidence models (used for the interpolate function)
incidence_meids = {'intest': 10140, 'ints': 20291}

# Indicate the meids for the low-burden incidence models (these include all data points and are used for citation tracking)
bundle_meids = {'intest': 10139, 'ints': 20292}

# Make dictionaries with cause ids, and ids for intermediate outputs 
cids = {'intest': [319, 320], 'ints': [959]}
intermediate_ids_to_upload = {'intest': [2523], 'ints': [27540, 28000]}
intermediate_ids_no_upload = {'intest': [23991, 23992], 'ints': [9999, 9959, 196800, 196801]}

# Pull the list of all sequela ids that are needed for the current cause
seq_meta = get_sequela_metadata(sequela_set_id=2, release_id=RELEASE)
seq_ids = seq_meta.loc[seq_meta.cause_id.isin(cids[LEVEL_3]), 'modelable_entity_id'].tolist()

# We want to upload outputs for cause ids, sequela ids, and select intermediate results that are useful for reference and collaborator requests
upload_ids = cids[LEVEL_3] + seq_ids + intermediate_ids_to_upload[LEVEL_3]

# The full set of ids that the pipeline produces estimates for include those that we upload, plus a few intermediate results that we don't upload
output_ids = upload_ids + intermediate_ids_no_upload[LEVEL_3]



In [None]:

# IMPLEMENT THE EFFECTS OF SETTINGS --------------------------------------------------
# If we're doing the CoD-based India subnational split then we need to pull the list of Indian sub-national IDs.
# Matching ',163,' keeps locations that have 163 as an interior element of path_to_top_parent
# (assumes each path ends with the location's own id, so location 163 itself is excluded).
if IND_SPLIT:
    loc_meta = get_location_metadata(35, release_id = RELEASE)
    ind_locs = loc_meta[loc_meta.path_to_top_parent.str.contains(',163,')]['location_id'].tolist()
else:
    ind_locs = []


# If we're going to interpolate based on SDI, run the sdi regression to get the coefficient
if SDI_INTERPOLATION:
    best_model = get_best_model_versions("modelable_entity", ids = [incidence_meids[LEVEL_3]], release_id = RELEASE)
    # Fail with a clear message instead of an IndexError from .iloc[0] when no best model is marked
    if best_model.empty:
        raise ValueError(f"No best model version found for modelable_entity_id {incidence_meids[LEVEL_3]} "
                         f"in release {RELEASE}; cannot run the SDI regression.")
    best_model = best_model['model_version_id'].iloc[0]
    sdi_coefs = sdi_regression(RELEASE, best_model, incidence_meids[LEVEL_3])
    print(sdi_coefs)
else:
    sdi_coefs = None


# If we've requested the output directory be cleared, do so here, otherwise make sure all output directories exist
for output_id in output_ids + ['model_info']:
    if CLEAR_OUTPUT_DIR:
        shutil.rmtree(os.path.join(output_dir, str(output_id)), ignore_errors = True)
    os.makedirs(os.path.join(output_dir, str(output_id)), exist_ok = True)


# SET LAUNCH PARAMETERS -----------------------------------------------------------
# We're going to launch a job for every location, get the necessary locations now
cod_demog = get_demographics(gbd_team="cod", release_id = RELEASE)
locs = cod_demog['location_id']


# We want to know the time that we launched the jobs so we can determine if all files in the output folder are newly created
# (If we haven't cleared out the folder contents, then we need to know that all outputs were created by this launch and 
#  are not old residual files from a previous run)
launch_time = datetime.now()

# Rerun state: the output-check cell sets these when some upload ids are incomplete.
# Initialise both here so the launch cell never depends on a variable defined in a later cell.
RERUN = False
rerun_locs = []

In [None]:

# LAUNCH THE JOBS --------------------------------------------------------------------
# Set for a rerun of failed jobs or a standard full run
# (RERUN / rerun_locs are set by the output-check cell when some upload ids are incomplete)
launch_locs = rerun_locs if RERUN else locs


# Loop through the locations and launch the jobs, remembering any submissions sbatch rejected
failed_submissions = []
for loc in launch_locs:
    # The jobs for Indian subnationals (where we're doing the IND SPLIT) require more memory
    if LEVEL_3 == 'intest' and loc in ind_locs:
        mem = '100G'
    else:
        mem = '20G'

    # Construct the sbatch command to submit the job
    submission_list = ['sbatch', '-J', f'{LEVEL_3}_{loc}', '-e', f'FILEPATH',
                       '-o', f'FILEPATH', '-A', 'proj_erf',
                       f'--mem={mem}', '-c', '4', '-t', '600', '-p', 'all.q', 
                       'FILEPATH' +  'FILEPATH/enteric_split.py ' + 
                       str(loc) + ' ' + str(RELEASE) + ' ' + str(LEVEL_3) + ' ' +
                       str(sdi_coefs) + ' ' + str(IND_SPLIT)]

    submission_str = " ".join(submission_list)

    # Submit the job; os.system returns the shell's exit status, which is non-zero when sbatch fails
    if os.system(submission_str) != 0:
        failed_submissions.append(loc)

if failed_submissions:
    print(f'sbatch failed for {len(failed_submissions)} location(s): {failed_submissions}')


# CHECK JOB STATUS -------------------------------------------------------------------
# Poll the queue until none of this run's jobs (named f'{LEVEL_3}_{loc}' above) remain
POLL_SECONDS = 30
job_prefix = f'{LEVEL_3}_'
while True:
    time.sleep(POLL_SECONDS)

    # Output is a header line followed by one job name per line; drop the header
    all_jobs = subprocess.check_output('squeue --me -o "%j"', shell = True)
    all_jobs = all_jobs.decode().splitlines()[1:]

    matching_jobs = [job for job in all_jobs if job.startswith(job_prefix)]
    if len(matching_jobs) == 0:
        print('All jobs done running. Checking output files.')
        break
    else:
        print(f'{len(matching_jobs)} jobs still running.  Will check again in {POLL_SECONDS} seconds.')
        

In [None]:
# CHECK OUTPUTS ----------------------------------------------------------------------
# Define the function to check outputs (determines if outputs exist and if they were newly created during this run)
def file_checker(file, launch):
    """Classify an output file relative to the time the jobs were launched.

    Parameters
    ----------
    file : str
        Path to the output file.
    launch : datetime.datetime
        Launch time to compare against. Expected to be a naive local datetime
        (e.g. from ``datetime.now()``), matching ``datetime.fromtimestamp``.

    Returns
    -------
    str
        ``'new'`` if the file exists and was modified strictly after ``launch``,
        ``'old'`` if it exists but was modified at or before ``launch``,
        ``'missing'`` if no such file exists.
    """
    if not os.path.isfile(file):
        return 'missing'
    mtime = datetime.fromtimestamp(os.path.getmtime(file))
    # Use the caller-supplied launch time, not a notebook global, so the function is self-contained
    return 'new' if mtime > launch else 'old'

# Run file checker function on all combinations of output_id and location to determine if each file is new, old, or missing 
# (new means the file was newly created during the current run; old means the file was created during a previous run; missing = no file exists)
check_rows = [
    [output_id, loc, file_checker(os.path.join(output_dir, str(output_id), f"{loc}.csv"), launch_time)]
    for output_id in output_ids
    for loc in locs
]
checks = pd.DataFrame(check_rows, columns = ['meid', 'location_id', 'status'])
checks['complete'] = checks['status'] == 'new'

# Determine which MEIDs have a complete set of new files (every location 'new') and are therefore ready for upload
ready = checks.groupby(['meid'])['complete'].all().reset_index()

meids_to_upload = ready.loc[ready['complete'], 'meid']
meids_to_upload = [x for x in meids_to_upload if x in upload_ids]

# Print out a frequency table of file status by MEID so we can see where things stand
print(checks.groupby('meid')['status'].value_counts())

# Determine which upload IDs are not complete (if any)
missing_upload_ids = set(upload_ids) - set(meids_to_upload)

if len(missing_upload_ids) == 0:
    print("Results are complete for all upload IDs")
    # Clear rerun state left by an earlier check so re-running the launch cell cannot relaunch stale locations
    rerun_locs = []
    RERUN = False
else:
    print("The following upload ids do not have complete results: " + str(missing_upload_ids))
    # Relaunch every location that has an old or missing file for any id we intend to upload
    rerun_locs = checks.loc[(checks['status'].isin(['old', 'missing'])) & (checks['meid'].isin(upload_ids)), 'location_id'].unique().tolist()
    RERUN = True
    
    print(rerun_locs)
    print(RERUN)

In [None]:
# CREATE THE MODEL DESCRIPTION -------------------------------------------------------
# Every model uploaded to the database needs a description
# we'll build that here based on inputs

# Every job saves a model_info file with details about the component inputs -- read those in here.
# Sort after de-duplicating so downstream ordering (and hence the description text) is reproducible.
model_info = (
    pl.scan_csv(os.path.join(output_dir, 'model_info', '*.csv'))
    .collect(streaming = True)
    .unique()
    .sort(['tool', 'model_version_id'])
)

# Make a list of all model_version_ids by model type (e.g. DisMod, CODEm).
# maintain_order=True keeps groups in the sorted order above instead of an arbitrary one.
label_parts = []
for tool, group in model_info.group_by('tool', maintain_order = True):
    # Assumes polars yields the group key as a tuple (hence tool[0])
    label_parts.append(str(tool[0]) + ' = ' + ', '.join(str(mvid) for mvid in group['model_version_id']))

# Join everything together into a string label that we can use for the uploader and print for visual inspection during run
label = '; and '.join(label_parts)
description = f'Natural hx / CODEm hybrid using {label}, with python pipeline'
print(description)


# GET BUNDLE AND CROSSWALK VERSIONS FOR SOURCE TRACKING ------------------------------
# We need to know the bundle_ids and crosswalk_version_ids for the input data used in the DisMod models for the uploader
# as this is needed for source tracking.  We're going to pull that information from the database here

# First, pull the list of the DisMod model version ids from the model_info table (compiled above)
dismod_model_ids = model_info.filter(pl.col('tool') == 'dismod')['model_version_id'].to_list()

# An empty list would otherwise produce invalid SQL ("IN ()")
if not dismod_model_ids:
    raise ValueError("No DisMod model versions found in model_info; cannot look up bundle/crosswalk versions.")

# Define the list of the variables we want
varlist = ['modelable_entity_id', 'bundle_id', 'crosswalk_version_id'] 

# Let the driver substitute the ids (one %s placeholder per id) instead of pasting them into the SQL text
placeholders = ', '.join(['%s'] * len(dismod_model_ids))
query = f"SELECT {', '.join(varlist)} FROM model_version WHERE model_version_id IN ({placeholders})"

# Connect to the epi database, open the cursor, and execute the SQL query
db = pymysql.connect(host = EPI_DB_HOST, user = EPI_DB_USER, password = EPI_DB_PASS, database = EPI_DB_NAME) 

with db:
    with db.cursor() as cursor:
        cursor.execute(query, dismod_model_ids)

        # Fetch all rows of data, put them in a data frame and add column names
        model_data = pd.DataFrame(cursor.fetchall(), columns = varlist)

# Get the input data bundle ids and store in space separated string
bundle_list = model_data.loc[model_data['modelable_entity_id'] == bundle_meids[LEVEL_3], 'bundle_id'].tolist()
bundles = ' '.join(map(str, bundle_list))

# Get the input data crosswalk versions and store in space separated string
xwalk_list = model_data.loc[model_data['bundle_id'].isin(bundle_list), 'crosswalk_version_id'].tolist()
xwalks  = ' '.join(map(str, xwalk_list))

print(bundles)
print(xwalks)
print(model_data)


In [None]:

# SUBMIT THE JOBS TO UPLOAD THE MODELS -----------------------------------------------
# Loop through the MEIDs that are ready for upload and submit a job to upload the corresponding estimates

# Location whose per-location output file is read to discover which measures an epi output contains
MEASURE_SAMPLE_LOC = 161

failed_uploads = []
for id in meids_to_upload:
    # CoD models are larger and therefore need more memory to upload.
    # (upload_type rather than `type`, so the builtin is not shadowed for the rest of the kernel session)
    if id in cids[LEVEL_3]:
        upload_type = 'cod'
        mem = '200G'
        measures = 1 # CoD results are always measure 1, so we can hard code this

    else:
        upload_type = 'epi'
        mem = '100G'

        # Non-fatal estimates can include a variety of measures (e.g. incidence, prevalence), 
        # so read in a sample file to find the list of measures (sorted so the command is deterministic)
        m_test = pl.read_csv(os.path.join(output_dir, str(id), f'{MEASURE_SAMPLE_LOC}.csv'))
        measures = ' '.join(map(str, m_test['measure_id'].unique().sort()))

    # Construct the sbatch command to submit the job
    submission_list = ['sbatch', '-J', f'upload_{id}', '-e', f'FILEPATH',
                       '-o', f'FILEPATH', '-A', 'proj_erf',
                       f'--mem={mem}', '-c', '4', '-t', '1000', '-p', 'all.q', 
                       'FILPATH ' +  'FILEPATH/upload_results.py ' + 
                       f'--type {upload_type} --id {id} --path {os.path.join(output_dir, str(id))} --description "{description}" ' +
                       f'--measure {measures} --best {BEST} --release {RELEASE} --bundle {bundles} --xwalk {xwalks}']

    # Submit the job (and print the submission command for visual inspection);
    # a non-zero exit status from os.system means sbatch did not accept the job
    submission_str = " ".join(submission_list)
    print(submission_str)
    if os.system(submission_str) != 0:
        failed_uploads.append(id)

if failed_uploads:
    print(f'sbatch failed for upload id(s): {failed_uploads}')

# Done