In [1]:
"""
Main running script
"""
from itertools import product
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from emmodel.data import (DataManager, add_time, select_cols,
                          select_groups)
from emmodel.model import (ExcessMortalityModel, plot_data, plot_model,
                           plot_time_trend)
from emmodel.variable import SeasonalityModelVariables, TimeModelVariables
from pandas import DataFrame
from regmod.utils import SplineSpecs
from regmod.variable import SplineVariable


def get_group_specific_data(dm: DataManager,
                            location: str,
                            age_group: str,
                            sex_group: str) -> List[DataFrame]:
    """Load one location's CSV and return the frames for one age/sex group.

    ``<location>.csv`` under ``dm.i_folder`` is reduced to the year, time and
    data columns named in ``dm.meta[location]``, filtered to the requested
    ``age_name`` / ``sex`` values and passed through ``add_time``.

    Returns
    -------
    List[DataFrame]
        Two frames, one per ``time_end_id`` (0, then 1) as returned by
        ``dm.truncate_time_location``. Rows with missing ``deaths`` are
        dropped from the first frame only; both frames get an ``offset_0``
        column equal to ``log(population)``.
    """
    meta = dm.meta[location]
    year_col = meta["col_year"]
    time_col = meta["col_time"]
    data_cols = meta["col_data"]
    start = meta["time_start"]

    raw = pd.read_csv(dm.i_folder / f"{location}.csv", low_memory=False)
    kept = select_cols(raw, [year_col, time_col] + data_cols)
    grouped = select_groups(kept, group_specs={"age_name": [age_group],
                                               "sex": [sex_group]})
    timed = add_time(grouped, year_col, time_col, start)

    frames = []
    for time_end_id in (0, 1):
        frame = dm.truncate_time_location(location, timed,
                                          time_end_id=time_end_id)
        # Only the first frame (the one later handed to the model
        # constructor) is stripped of rows without observed deaths.
        if time_end_id == 0:
            frame = frame[frame.deaths.notna()].reset_index(drop=True)
        frame["offset_0"] = np.log(frame.population)
        frames.append(frame)
    return frames


def get_data(dm: DataManager) -> Dict[str, List[DataFrame]]:
    """Collect the group-specific frames for every location in ``dm``.

    For each location, every (age group, sex group) pair listed in
    ``dm.meta[location]`` is loaded with ``get_group_specific_data``.

    Returns
    -------
    Dict[str, List[DataFrame]]
        Keys have the form ``"<location>-<age_group>-<sex_group>"`` with
        spaces in the age group replaced by underscores.
    """
    frames_by_group = {}
    for location in dm.locations:
        meta = dm.meta[location]
        pairs = product(meta["age_groups"], meta["sex_groups"])
        for age_group, sex_group in pairs:
            frames = get_group_specific_data(dm, location,
                                             age_group, sex_group)
            key = f"{location}-{age_group.replace(' ', '_')}-{sex_group}"
            frames_by_group[key] = frames
    return frames_by_group


def get_time_knots(time_min: int,
                   time_max: int,
                   knots: np.ndarray) -> np.ndarray:
    """Bracket the usable knots with the ends of the time range.

    Knots that do not lie strictly between ``time_min`` and ``time_max`` are
    discarded; the remaining ones keep their input order and are placed
    between ``time_min`` and ``time_max`` in the returned 1-D array.
    """
    interior = []
    for knot in knots:
        if time_min < knot < time_max:
            interior.append(knot)
    return np.hstack([time_min, interior, time_max])


def get_mortality_pattern_model(df: DataFrame,
                                col_time: str = "time_start",
                                units_per_year: int = 12,
                                knots: np.ndarray = np.arange(2010, 2021),
                                smooth_order: int = 1) -> ExcessMortalityModel:
    """Assemble the seasonality + time-trend model for one data frame.

    Parameters
    ----------
    df
        Data to fit; must contain a ``time`` column and ``col_time``.
    col_time
        Column carrying the within-year time unit; used for the seasonal
        spline.
    units_per_year
        Accepted for interface compatibility; not used in this function.
    knots
        Candidate knots for the time trend. Only those strictly inside the
        observed ``df.time`` range are kept (see ``get_time_knots``).
    smooth_order
        Forwarded to ``SeasonalityModelVariables``.

    Returns
    -------
    ExcessMortalityModel
        Model built on ``df`` with a seasonality stage followed by a
        time-trend stage.
    """
    # Seasonal component: cubic spline over col_time with five evenly
    # spaced knots given relative to the column's domain.
    seasonal_spline = SplineVariable(
        col_time,
        spline_specs=SplineSpecs(knots=np.linspace(0.0, 1.0, 5),
                                 degree=3,
                                 knots_type="rel_domain"),
    )
    seasonality = SeasonalityModelVariables([seasonal_spline],
                                            col_time,
                                            smooth_order)

    # Trend component: piecewise-linear spline over "time" with absolute
    # knots clipped to the observed time range.
    trend_spline = SplineVariable(
        "time",
        spline_specs=SplineSpecs(
            knots=get_time_knots(df.time.min(), df.time.max(), knots),
            degree=1,
            knots_type="abs",
        ),
    )
    trend = TimeModelVariables([trend_spline])

    return ExcessMortalityModel(df, [seasonality, trend])


def get_mortality_pattern_models(dm: DataManager,
                                 data: Dict[str, List[DataFrame]]) -> Dict[str, ExcessMortalityModel]:
    """Build one mortality-pattern model per location/age/sex group.

    Parameters
    ----------
    dm
        Data manager; ``dm.meta[location]`` supplies ``col_year``,
        ``col_time``, ``time_start``, ``knots`` and ``smooth_order``.
    data
        Mapping as returned by ``get_data``: group name to a list of
        frames. Only the first frame of each list is used here. The
        location is taken as the text before the first ``"-"`` in the
        group name, so location ids must not contain ``"-"`` themselves.

    Returns
    -------
    Dict[str, ExcessMortalityModel]
        Unfitted models keyed by the same group names as ``data``.
    """
    models = {}
    for name, dfs in data.items():
        fit_df = dfs[0]
        location = name.split("-")[0]
        meta = dm.meta[location]
        col_year = meta["col_year"]
        col_time = meta["col_time"]
        time_start = meta["time_start"]

        # First (year, time unit) actually present in the fitting frame.
        year_start = fit_df[col_year].min()
        unit_start = fit_df.loc[fit_df[col_year] == year_start,
                                col_time].min()

        # NOTE(review): this overwrites attributes on the object read from
        # dm.meta, so (presumably) every later reader of
        # dm.meta[location]["time_start"] sees the new values, and each
        # group of the same location overwrites them again. The frames were
        # built with add_time() before this change - confirm the knots and
        # the "time" column still share the same origin.
        time_start.year = year_start
        time_start.time = unit_start

        units_per_year = time_start.units_per_year
        # Convert knots given in years to positions on the time-unit axis
        # measured from time_start.
        knots = (np.array(meta["knots"])*units_per_year + 1 -
                 time_start.year*units_per_year - time_start.time + 1)
        models[name] = get_mortality_pattern_model(fit_df,
                                                   col_time,
                                                   units_per_year,
                                                   knots,
                                                   meta["smooth_order"])
    return models


def plot_models(dm: DataManager,
                results: Dict[str, DataFrame]):
    """Plot data, fitted cases and time trend for every result frame.

    One figure per entry of ``results`` is written to
    ``dm.o_folder / "<name>.pdf"``; all open figures are closed after each
    save. The location (text before the first ``"-"`` in the name) selects
    the ``time_unit`` and ``col_year`` settings from ``dm.meta``.
    """
    for name, df in results.items():
        meta = dm.meta[name.split("-")[0]]
        time_unit = meta["time_unit"]
        col_year = meta["col_year"]

        # Data panel with the model prediction overlaid.
        ax, axs = plot_data(df, time_unit, col_year)
        ax = plot_model(ax, df, "cases", color="#008080")
        ax.set_title(name, loc="left")
        ax.legend()

        # Time trend goes on the second axis returned by plot_data.
        plot_time_trend(axs[1], df, time_unit, col_year)

        plt.savefig(dm.o_folder / f"{name}.pdf", bbox_inches="tight")
        plt.close("all")


def main(dm: DataManager,
         seed: Optional[int] = None,
         num_samples: int = 1000):
    """Run the full pipeline: load, fit, predict, draw, save and plot.

    Parameters
    ----------
    dm
        Data manager describing inputs, outputs and per-location metadata.
    seed
        If given, passed to ``np.random.seed`` before any model is run.
        This makes the draws repeatable only if the model samples from
        NumPy's global generator (presumably it does - TODO confirm).
        ``None`` (default) leaves the generator untouched.
    num_samples
        Number of draws requested from ``model.get_draws`` per group.

    Side effects
    ------------
    Writes the prediction frames through ``dm.write_data`` and saves one
    PDF per group via ``plot_models``.
    """
    if seed is not None:
        np.random.seed(seed)

    # get dataframes for each location, age_group and sex_group combination
    data = get_data(dm)

    # get mortality pattern models
    mortality_pattern_models = get_mortality_pattern_models(dm, data)

    # fit mortality pattern models and predict results
    for name, model in mortality_pattern_models.items():
        model.run_models()
        # The second frame of each group is the prediction frame.
        data[name][1] = model.predict(data[name][1],
                                      col_pred="cases")
        draws = model.get_draws(data[name][1],
                                col_pred="cases",
                                num_samples=num_samples)
        case_draws = draws["cases"]
        # One column per draw; the first axis of case_draws indexes draws.
        for i in range(case_draws.shape[0]):
            data[name][1][f"cases_draw_{i}"] = case_draws[i]
    results = {name: dfs[1] for name, dfs in data.items()}

    # save the mortality pattern results
    dm.write_data(results)

    # plot results and save figures
    plot_models(dm, results)

In [2]:
## Measles
# Location ids for the measles run are read from the `ihme_loc_id` column.

import pandas as pd

locs = pd.read_csv("/filepath/format.csv")
number_column = locs["ihme_loc_id"]

# Array of ids; the run cell below passes it to DataManager as `locations`.
locations = number_column.values

In [3]:
if __name__ == "__main__":
    # inputs
    i_folder = "/filepath"
    o_folder = "/filepath"
    meta_filename = "meta.yaml"
    # `locations` is not set here: it must already exist from the
    # location-list cell executed before this one.

    main(DataManager(i_folder, o_folder, meta_filename, locations))

  result = getattr(ufunc, method)(*inputs, **kwargs)
  return np.exp(x)
  np.random.multivariate_normal(
  df[f"offset_{i + 1}"] = np.log(pred)


In [None]:
## THIS IS FOR FLU
# Location ids for the flu run are read from the `ihme_loc_id` column.
import pandas as pd

locs = pd.read_csv("/filepath/loc_list_flu_new_final.csv")
number_column = locs["ihme_loc_id"]

# Array of ids; the run cell below passes it to DataManager as `locations`.
locations = number_column.values

In [None]:
if __name__ == "__main__":
    # inputs
    i_folder = "/filepath"
    o_folder = "/filepath"
    meta_filename = "meta.yaml"
    # `locations` is not set here: it must already exist from the
    # location-list cell executed before this one.

    main(DataManager(i_folder, o_folder, meta_filename, locations))