# Start Here
This program uses Python to run the cohort component method to project population from 2025 until any given date. You can define 'treatments' that alter the population size, fertility rates, or mortality rates, and compare their effects against the baseline projection in which no treatment occurs.
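
To make the method concrete, here is a minimal sketch of one 5-year cohort-component step for a female-only toy population. It is not the program's implementation, which tracks both sexes, life-table person-years, and the sex ratio at birth. The helper name `cohort_component_step` and all numbers are made up for illustration, and newborn mortality within the period is ignored for simplicity.

```python
import numpy as np

def cohort_component_step(pop, p_survive, fertility, years_per_period=5):
    """Advance a female-only toy population by one period (hypothetical helper).

    pop        : people in each age group at the start of the period
    p_survive  : probability of surviving from each age group into the next
    fertility  : annual female births per woman in each age group
    """
    pop = np.asarray(pop, dtype=float)
    new_pop = np.zeros_like(pop)
    # survivors of each age group move up one age group
    new_pop[1:] = pop[:-1] * p_survive[:-1]
    # the last (open-ended) age group also keeps its own survivors
    new_pop[-1] += pop[-1] * p_survive[-1]
    # women exposed to childbearing: average of start-of-period and end-of-period counts
    exposure = (pop + new_pop) / 2
    # births over the whole period fill the youngest age group
    new_pop[0] = (exposure * fertility).sum() * years_per_period
    return new_pop

# toy example with 4 age groups (made-up numbers)
pop = np.array([100.0, 100.0, 100.0, 100.0])
p_survive = np.array([0.99, 0.98, 0.90, 0.50])
fertility = np.array([0.0, 0.1, 0.05, 0.0])
next_pop = cohort_component_step(pop, p_survive, fertility)
print(next_pop)
```

Repeating this step period after period, with fertility and survival inputs that change over time, is the core of a cohort-component projection.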

This program is publicly available and was last updated in **May 2023**. It was primarily written by Gage Weston, a researcher at the Population Wellbeing Initiative at the University of Texas at Austin as of May 2023, with help from Sangita Vyas at CUNY Hunter College. You may contact Gage at gageweston@utexas.edu with issues or questions about the code.

##### **Instructions:**

1. Ensure your computer can run Python 3 or above, Jupyter Notebooks, and the NumPy, pandas, and Plotly libraries. Also ensure you have the file "input.csv" in the same folder as this program. If you do not have this file, follow the steps in "Re-create input data file (optional)" below to generate it from the UN WPP 2022 data.

2. Run the code in "Background Code" near the bottom of the program.

3. In "Make Your Own Projections", edit the "modifiable inputs" as desired and run the code.

4. See 'Metadata' for a description of the data.

# Re-create input data file (optional)
This step will create the "input.csv" file using UN projections. Skip this step if you already have this file.

1. Download files in wpp_2022 listed below from the 2022 UN World Population Prospects at https://population.un.org/wpp/Download (~1GB of files)

2. Enter the "folder" (working directory) containing the data files. Leave as '' if they're all in the same folder as this program.

3. Run the code to get input_data (takes 10-30 seconds).

You may use alternative UN variants and versions or even projections from outside the WPP as long as they follow the same format as the input data file.
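
If you bring your own data, a quick format check before running projections can save time. The sketch below is a hypothetical helper (`check_input_format` is not part of this program). It assumes the expected columns are the ones that `prep_WPP_files` returns in "Background Code", and that years and ages must fall on the 5-year grid the program requires.

```python
import pandas as pd

# columns that prep_WPP_files returns; the checker itself is a hypothetical helper
EXPECTED_COLUMNS = ['variant','location','year','age','id','code','type','region','subregion',
                    'sdg_region','income_group','sex_ratio_birth','dead_years_f','dead_years_m',
                    'p_survive_f','p_survive_m','fertility','life_years_f','life_years_m',
                    'births','deaths','pop_f','pop_m']

def check_input_format(df, years_per_period=5):
    """Return a list of problems found in a candidate input dataframe (empty list = looks OK)."""
    problems = []
    missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    if missing:
        problems.append('missing columns: ' + ', '.join(missing))
    # years and ages must be multiples of the period length
    for col in ['year', 'age']:
        if col in df.columns and (df[col] % years_per_period != 0).any():
            problems.append(col + ' values must be multiples of ' + str(years_per_period))
    return problems

# toy example: a dataframe with only a few columns and an off-grid year
toy = pd.DataFrame({'location': ['World', 'World'], 'year': [2025, 2027], 'age': [0, 5]})
problems = check_input_format(toy)
for problem in problems:
    print(problem)
```

An empty list only means the column names and the year/age grid look right; it does not validate the values themselves.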

In [None]:
# folder (working directory) that contains the 2022 WPP files listed below
folder = '' 
wpp_2022 = { 
    'population'                 : 'WPP2022_PopulationByAge5GroupSex_OtherVariants.csv', # Population on 01 July, by 5-year age groups (2022-2100, other scenarios)
    'fertility'                  : 'WPP2022_Fertility_by_Age5.csv',                      # Fertility (1950-2100, 5-year age groups)
    'indicators'                 : 'WPP2022_Demographic_Indicators_OtherVariants.csv',   # Demographic Indicators (2022-2100,other scenarios)
    'life_table'                 : 'WPP2022_Life_Table_Abridged_Medium_2022-2100.csv',   # Abridged life table (2022-2100, medium)
    'labels'                     : 'WPP2022_F01_LOCATIONS.XLSX'}                         # Locations (in "Documentation")

variants = ['Zero migration'] # list of WPP variants to use in our projections before 2100. 
                              # Default to 'Zero migration', the UN 'medium' variant assuming no migration.

# import input data from UN WPP projections
input_data = prep_WPP_files(files = wpp_2022, variants = variants, folder = folder) 

# merge to get 2023 TFR for each country, drop duplicate rows
tfr_2023 = pd.read_csv(wpp_2022['indicators']).query('Variant == "Zero migration" & Time == 2023')[['Location','TFR']].rename({'Location' : 'location', 'TFR' : 'tfr_2023'},axis=1)
input_data = pd.merge(input_data,tfr_2023,how='left').drop_duplicates() 

# (optional) save input data as CSV file
input_data.to_csv('input.csv', index=False) 

# view the data and show world population as projected by WPP
print(input_data.head())
location = 'World'
tot_pop = input_data.groupby(['location','year'])[['pop_f','pop_m']].sum().sum(axis=1).rename('population').reset_index()
px.line(tot_pop.query('location == @location'),x='year',y='population',title= location + ' Population (2025-2100)')

# Make Your Own Projections
Change or run the default inputs below to generate your own population projections and visualizations. By default, no treatment occurs.

Comments below describe the function of each parameter in the main 'project_population' function. 
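
The fertility parameters (`tfr_scenarios`, `converge_speeds`) describe a TFR that changes by a fixed percent each 5-year period and stops once it reaches its target. The sketch below illustrates that idea under an assumed update rule (multiply the current TFR by `1 + speed` or `1 - speed` each period, capped at the target). The function `converge_tfr` is hypothetical, and the program's own convergence code may differ in detail.

```python
def converge_tfr(tfr_start, tfr_target, speed=0.03, periods=10):
    """Toy TFR path: change by `speed` (a fraction of current TFR) per period, stopping at the target."""
    path = [tfr_start]
    tfr = tfr_start
    for _ in range(periods):
        if tfr < tfr_target:
            # rise toward the target without overshooting it
            tfr = min(tfr * (1 + speed), tfr_target)
        elif tfr > tfr_target:
            # fall toward the target without undershooting it
            tfr = max(tfr * (1 - speed), tfr_target)
        path.append(tfr)
    return path

# 'very fast' corresponds to 0.05 in the program's speed dictionary
path = converge_tfr(tfr_start=1.5, tfr_target=2.1, speed=0.05, periods=8)
print([round(x, 3) for x in path])
```

A faster speed reaches the target in fewer periods, which changes how long a country spends below or above replacement fertility.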

In [None]:
year_end = 2300 # year when new projections end. Must be later than start_year (the earliest 5-year period for future WPP projections).
locations = 'all' # list of location names from WPP to include in final dataset. 'all' includes all locations.
keep_data = ['age','cum','dif','regions','categorical'] # See functions beginning with 'combine_'. Delete items from this list to drop data from the data output.
            # 'age' = age-specific data. 'cum' = cumulative sum of births, deaths, etc. across years. 'dif' = difference b/w treated and baseline scenarios. 'regions' = rows combining country-level data into regional data. 'categorical' = columns including regional indicators, country-codes, etc.

# treatment indicators - see 'treat_population' function
treat_type = False # indicate the "treatment" to see its effect on population. Options below:
            # to add 1 person:  Select treat_type = 'add person'. Adds 1 person to each location's population at the given age and period. See 'add_1_person' function.
            # to change mortality: select treat_type = 'mortality'. Must include 'mort_target' and 'end_treat' to pass the numerical input. See 'change_mortality' function.
            # to change fertility: leave treat_type = False and set 'tfr_scenarios_t' instead. See 'project_fertility' function.
age_treat = [0] # list of starts of age-groups to treat (e.g. '0' = age 0-5).
year_treat = input_data['year'].min() # year to perform treatment, on/after start_year. Default to the first 5-year period of the WPP future projections.

# mortality - see 'project_life_years' and 'change_mortality' functions
life_exp_max = [100] # list of maximum life-expectancies each scenario can take.
start_LY_increase = input_data['year'].max() # year to start increasing life-years (LY) / life-expectancy. Default to the last year of the WPP projections (2100).
LY_increment = False # amount to increase life-expectancy by per period.
mort_target = False # numerical input for treatments within 'change_mortality' function.
end_treat = False # year to end mortality treatment.

# fertility - see 'project_fertility' function
tfr_scenarios = ['replacement'] # list containing total fertility rate(s) (TFR) which each country's TFR will gradually converge to 
    # after UN WPP data ends in 2100. 'replacement' TFR produces zero population growth (TFR between 2 and 2.1). If not 'replacement', give numeric TFRs, e.g. [1.5, 1.8].
start_converge = input_data['year'].max() # year (or list of years) to start convergence to the TFR(s) in tfr_scenarios. default to last year of WPP projections.
converge_speeds = ['medium'] # list of speeds at which TFR converges to the TFR(s) in tfr_scenarios. can give a percent per period (e.g. 0.03 for 3%) or words in ['very slow','slow','medium','fast','very fast'].
tfr_scenarios_t = False # list of TFRs to converge to a 2nd time after initial convergence.
start_converge_t = False # list of years to start convergence to 2nd TFR in same scenario.
converge_speeds_t = False # list of speeds at which TFR converges to 2nd TFR in same scenario.



# run the projections! (takes ~10-20 seconds per 100 years per scenario)
projections = project_population(input_data = input_data, year_end = year_end, locations = locations, keep_data = keep_data,
                age_treat = age_treat, year_treat = year_treat, treat_type = treat_type, life_exp_max = life_exp_max,
                start_LY_increase = start_LY_increase, LY_increment = LY_increment, mort_target = mort_target, end_treat = end_treat,
                tfr_scenarios = tfr_scenarios, start_converge = start_converge, converge_speeds = converge_speeds,
                tfr_scenarios_t = tfr_scenarios_t, start_converge_t = start_converge_t, converge_speeds_t = converge_speeds_t)
projections

#### Run visualizations of your projections

In [None]:

# show world population size over time (or change to display whatever you want)
location = 'World'
proj = projections.query('location == @location & age == "all"')
px.line(proj,x='year',y='population')

#### Filter data and save to your computer

In [None]:
# View file size before downloading
# To reduce file size, you may want to create a new dataframe "data" where you:
    # drop rows (e.g. data = projections.query('location.isin(["World","Asia"]) & age == "all"'))
    # or drop columns (e.g. data = projections.drop(['age','type'],axis=1))
print('File Size: ' + str(round(projections.memory_usage().sum() / 1e+6)) + ' MB')

# save as CSV to your computer
projections.to_csv('pop_projections.csv', index=False)

## Test Out the Treatment Scenarios
Try performing a "treatment" to alter population size, fertility, or mortality and see how it impacts later populations. See the comments in the 1st code block of "Make Your Own Projections" above for an explanation.
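
Each example below plots `population_dif`, the treated projection minus the baseline projection. The toy calculation here shows what that column represents, using made-up numbers rather than program output. The `_t` suffix for treated columns follows the program's convention; the `population_pct_dif` column name is an assumption for illustration.

```python
import pandas as pd

# made-up numbers, not program output
baseline = pd.DataFrame({'year': [2025, 2030, 2035], 'population': [1000.0, 1010.0, 1015.0]})
treated  = pd.DataFrame({'year': [2025, 2030, 2035], 'population': [1001.0, 1011.0, 1016.2]})

# line up the two scenarios by year; treated columns get the '_t' suffix
compare = baseline.merge(treated, on='year', suffixes=('', '_t'))
# absolute and percent difference between the treated and baseline scenarios
compare['population_dif'] = compare['population_t'] - compare['population']
compare['population_pct_dif'] = 100 * compare['population_dif'] / compare['population']
print(compare)
```

A difference that grows over time, as in the last row, indicates that the treatment's effect compounds through later births rather than fading out.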

#### Population Size Treatment

In [None]:
data_pop = project_population(input_data,treat_type='add person')
px.line(data_pop.query('location == "World" & age == "all"'),x='year',y='population_dif')

#### Fertility Treatment

In [None]:
data_fert = project_population(input_data,treat_type='fertility',tfr_scenarios=[1.5],tfr_scenarios_t=['replacement'],start_converge_t=[2150],converge_speeds_t=['very fast'])
px.line(data_fert.query('location == "World" & age == "all"'),x='year',y='population_dif')

#### Mortality Treatment

In [None]:
data_mort = project_population(input_data,treat_type='mortality',mort_target=25,end_treat=2030)
px.line(data_mort.query('location == "World" & age == "all"'),x='year',y='population_dif')

# Background Code
Run this code first in order to perform the projections. Do not change the code below.

In [None]:
# run this code (without apostrophes) if you haven't already installed these python libraries
'''
pip install pandas
pip install numpy
pip install matplotlib
pip install plotly
'''

# import python libraries
import pandas as pd
import numpy as np
import plotly.express as px

# Filter out the SettingWithCopyWarning
pd.options.mode.chained_assignment = None

# read the input data from UN World Population Prospects, re-formatted by us. 
# If you don't have this file, follow steps above to re-create it using UN data
input_data = pd.read_csv('input.csv') 

# number of years in each period and age-group. Currently the program can only handle 5-year periods / age-groups
y_per_period = 5 


# import demographic projection files from UN WPP and re-format for our use
def prep_WPP_files(files,variants=['Zero migration'], folder = ''):
    for file in files:
        # prepend the folder that contains the files, if they're not in the same folder as this file
        if folder:
            files[file] = folder + '/' + files[file]
        this_file = prep_1_file(files,file)
        try:
            # merge with master dataframe
            data = pd.merge(data, this_file, how = 'left')
        except:
            # if merging failed, this is the beginning of the dataframe
            data = this_file
    # rename some types and drop if not in list of types
    types = {'Country/Area':'country','World':'world','Geographic region':'region','sdg_region':'subregion','SDG region':'sdg_region','Income group':'income_group'}
    data.loc[data['type'].isin(types),'type'] = data.loc[data['type'].isin(types),'type'].map(types)
    data = data.query('type.isin(@types.values())')
    # filter data so that only years and ages divisible by years per period remain
    data = data.query('year % @y_per_period == 0 & age % @y_per_period == 0 & variant.isin(@variants)')
    data = data.sort_values(by = ['variant','location','age','year'])
    year_end = data['year'].max()
    for col in ['fertility','sex_ratio_birth','life_years_f','life_years_m','p_survive_f','p_survive_m','dead_years_f','dead_years_m']:
        # take the average of the given period and the following period for each col given. 
        # This will make projections more accurate since we'd otherwise be skewed toward earlier birth/death rates
        data.loc[data['year'] < year_end,col] = (data.query('year < @year_end')[col].values + data.query('year > year.min()')[col].values) / 2
        data.loc[data['year'] == year_end,col] = data.query('year == @year_end - @y_per_period')[col].values
        # increase/decrease fertility in the last period by the percent change in fertility between the previous 2 periods
        if col == 'fertility':
            data['dif'] = data.groupby(['variant','location','age'])['fertility'].pct_change()
            data.loc[data['year'] == year_end, 'fertility'] *= (1 + data.query('year == @year_end - @y_per_period')['dif'].values)
        # increase all other columns in the last period by the difference of that column between the previous 2 periods
        else:
            data['dif'] = data.groupby(['variant','location','age'])[col].diff()
            data.loc[data['year'] == year_end, col] += data.query('year == @year_end - @y_per_period')['dif'].values
        data.drop('dif',axis=1,inplace=True)
    # re-arrange columns, sort rows, drop duplicates (should only apply for LatAm/Caribbean since they both have "SDG region" and "region"), reset index.
    data = data[['variant','location','year','age','id','code','type','region','subregion','sdg_region','income_group','sex_ratio_birth','dead_years_f','dead_years_m','p_survive_f','p_survive_m','fertility','life_years_f','life_years_m','births','deaths','pop_f','pop_m']]
    data = data.sort_values(by = ['variant','type','location','year','age']).drop_duplicates(subset=['variant','location','year','age']).reset_index(drop=True)
    return data

# prep each WPP file individually
def prep_1_file(files,file):
    if file == 'labels':
        data = pd.read_excel(files[file], skiprows=16)
        # get categorical labels for each country
        data = data[['Location code','Name.1','Name.2','Name.3','High-income Countries\n1503','Middle-income Countries\n1517','Low-income Countries\n1500']]
        data.loc[data['High-income Countries\n1503'].notna(),'income_group'] = 'High-income countries'
        data.loc[data['Middle-income Countries\n1517'].notna(),'income_group'] = 'Middle-income countries'
        data.loc[data['Low-income Countries\n1500'].notna(),'income_group'] = 'Low-income countries'
        data.drop(['High-income Countries\n1503','Middle-income Countries\n1517','Low-income Countries\n1500'],axis=1,inplace=True)
        data.rename({'Location code':'id','Name.1':'sdg_region','Name.2':'subregion','Name.3':'region'},axis=1,inplace=True)
        data = data[['id','region','subregion','sdg_region','income_group']]
    else:
        data = pd.read_csv(files[file], low_memory=False)
        data.rename({'LocID':'id','ISO3_code':'code','LocTypeName':'type','Time':'year','AgeGrpStart':'age'},axis=1,inplace=True)
        if file == 'life_table':
            data.rename({'Sex':'sex','px':'p_survive','Lx':'life_years','ax':'dead_years',},axis=1,inplace=True)
            data['life_years'] /= 100000
            # convert columns to lowercase and select only a few columns
            data.columns = data.columns.str.lower()
            data = data[['location','code','id','sex','age','year','p_survive','life_years','dead_years']]
            # sum years of life lived from 0 to 1 and 1 to 5 to get it into a 0 to 5 age group (life_years was divided by 100,000 above since it is given per 100,000 births)
            age_0, age_1 = data['age'] == 0, data['age'] == 1
            data.loc[age_0,'life_years'] += data.loc[age_1,'life_years'].values
            # p_survive age 0 to 5 = p_survive age 0 to 1 * p_survive age 1 to 5
            data.loc[age_0,'p_survive'] *= data.loc[age_1,'p_survive'].values
            # to get life-years lived by those who die (dead_years) in 0 to 5 age-group, algebraically solve the equation 
            # life_years = y_per_period * p_survive + dead_years * (1 - p_survive) - aka prob of death
            d = data.loc[age_0]
            data.loc[age_0,'dead_years'] = ( d['life_years'] - y_per_period * d['p_survive'] ) / (1 - d['p_survive'])
            data = data.query('sex != "Total" & age != 1')
            # make a separate col for each variable for each sex
            data = data.pivot(index = ['location','code','id','year','age'], columns = ['sex'], values = ['p_survive','life_years','dead_years']).reset_index()
            data.columns = ['_'.join(str(s).strip() for s in col if s) for col in data.columns]
            data.columns = data.columns.str.replace('Female','f')
            data.columns = data.columns.str.replace('Male','m')
        else:
            # convert columns to lowercase and select only a few columns
            data.columns = data.columns.str.lower()
            if file == 'population':
                data.rename({'popmale':'pop_m','popfemale':'pop_f'},axis=1,inplace=True)
                data = data[['id','type','variant','age','year','pop_m','pop_f']]
                data[['pop_f','pop_m']] = data[['pop_f','pop_m']] * 1000
            elif file == 'fertility':
                data.rename({'asfr':'fertility'},axis=1,inplace=True)
                data = data[['id','type','variant','age','year','fertility']]
                data['fertility'] /= 1000
            elif file == 'indicators':
                data.rename({'srb':'sex_ratio_birth'},axis=1,inplace=True)
                data = data[['id','type','variant','year','sex_ratio_birth','births','deaths']]
                data['sex_ratio_birth'] /= 100
                data[['births','deaths']] *= 1000 * y_per_period
    return data



#######



# main function. Use this function to perform all projections. See "Start Here" for explanation of what each parameter means
# Returns a dataframe containing population projections under the user-defined scenarios.
def project_population(input_data = input_data, year_end = 2300, locations = 'all', keep_data = ['age','cum','dif','pct_dif','regions','categorical'], \
  treat_type = False, age_treat = [0], year_treat = input_data['year'].min(), end_treat = input_data['year'].min() + y_per_period, \
  life_exp_max = [100], start_LY_increase = input_data['year'].max(), LY_increment = False, mort_target = False, \
  tfr_scenarios = ['replacement'], start_converge = input_data['year'].max(), converge_speeds = ['medium'], tfr_scenarios_t = False, start_converge_t = False, converge_speeds_t = False):
  # if selected specific countries, filter to only include those countries. Otherwise, include all locations.
  if locations == 'all' or 'World' in locations:
    data = input_data.copy(deep=True)
  else:
    data = input_data.query('location.isin(@locations) | location == "World"')
  print('Projecting fertility and mortality')
  # project demographic inputs (fertility, mortality, etc.) beyond existing UN projections
  data = project_life_years(data,year_end,life_exp_max,start_LY_increase,LY_increment)
  data = project_fertility(data,tfr_scenarios,start_converge,converge_speeds)
  # fill columns with zeros to be filled later. fill fertility nan with 0's
  data[['births','deaths_f','deaths_m','lives_saved','pop_f_saved','pop_m_saved','deaths_f_saved','deaths_m_saved']] = 0
  data['fertility'] = data['fertility'].fillna(0)
  # get columns to group/sort by.
  groups = ['variant','life_exp_max','tfr_scenario','location']
  # if user passed a list to tfr_scenarios_t then the treat_type is "fertility"
  if tfr_scenarios_t != False:
    treat_type = 'fertility'
  # perform a population treatment / intervention if instructed
  if treat_type != False:
    print('Performing population alteration')
    data = treat_population(input_data,data,year_end,life_exp_max,groups,treat_type,age_treat,year_treat,end_treat,mort_target, \
           start_LY_increase,LY_increment,tfr_scenarios_t,start_converge_t,converge_speeds_t)
  # treat_control indicates whether we're in a treatment/control scenario. Default is control, where columns have no suffix
  if 'pop_f_t' in data.columns:
    treat_control = ['','_t']
  else:
    treat_control = ['']
  print('Projecting population size')
  # transform fertility and life-years so we can more quickly calculate population later
  data = transform_fert_mort(data,treat_control)
  # project the population
  data = project_population_size(data,treat_control,groups,treat_type)
  print('Aggregating and cleaning up the data')
  # combine data across sexes and ages.
  data = combine_sexes(data,treat_control,treat_type)
  data = combine_ages(data,groups,treat_control,keep_data)
  # get difference in outcomes between treat and control scenario if we're adding 1 person. Otherwise, do this later.
  if treat_type == 'add person':
    data = get_scenario_difs(data)
  # if instructed, aggregate from country-level into global, regional and income-level projections.
  if locations == 'all' or 'World' in locations:
    data = combine_countries(data,treat_control,groups,treat_type,keep_data)
  else:
    # filter locations (again) to get rid of "World" that was used before
    data = data.query('location.isin(@locations)')
  # get difference in outcomes between treat and control scenario if not doing 'add person' treat_type
  if '_t' in treat_control and 'dif' in keep_data and treat_type != 'add person':
    data = get_scenario_difs(data)
  # if added 1 person or treated mortality, get population size and life-years lived for additional people born
  if treat_type in ['add person', 'mortality']:
    data = get_population_born(data)
  # get cumulative sums of outcomes and percent-differences between treatment and control groups
  data = get_cum_difs(data,groups,keep_data)
  # re-arrange columns, drop columns and sort rows
  data = arrange_columns(data,groups,keep_data)
  print('Done!')
  return data

# project years of life lived (nLx) and mortality rates in each age-group beyond original data
def project_life_years(data,year_end,life_exp_max,start_LY_increase,LY_increment = False):
    groups = ['variant','location']
    # add new rows and columns we'll transform later to get new life-years
    data = prep_life_years(data,year_end,groups,life_exp_max,LY_increment)
    # year_max = the final year in the dataset
    year_max = data['year'].max()
    # ages = an array containing all age-groups except the youngest and oldest age-groups
    ages = data['age'].unique()[1:-1]
    # project life-years for females and males
    for LY_name in ['life_years_f','life_years_m']: 
        # set index "year" for data to increase search speed in loops
        data = data.set_index('year')
        # d = data from the start of projecting life_years, with only a few columns
        d = data.query('year == @start_LY_increase')[groups + ['age',LY_name,'LY_increment','LY_max']]
        # add a column of zeros to use for quick comparisons later
        d['zeros'] = np.zeros(len(d))
        # project next period's life-years using the previous period's data
        for year in range(start_LY_increase + y_per_period, year_max + y_per_period, y_per_period):
            d = project_life_years_1_period(data,d,ages,LY_name,groups,year)
        # ensure life-exp doesn't exceed life_exp_max
        data.reset_index(inplace=True)
        try:
            data = adjust_life_exp_max(data,LY_name,groups)
        except:
            # if the line above had an error, do nothing. this likely means we didn't exceed life_exp_max in any location
            pass
    # drop unnecessary columns and sort values
    data = data.drop(['LY_increment','LY_max'],axis=1)
    return data.sort_values(by= groups + ['year','age']).reset_index(drop=True)

# add new rows and columns we'll transform later to get new life-years
def prep_life_years(data,year_end,groups,life_exp_max,LY_increment):
    # get LY_increment, the constant amount to add to life-years each period. This varies by location, sex, and age.
    if 'LY_increment' in data.columns:
      pass
    elif LY_increment == False:
        # if no increment was passed by the user, by default each year after start_LY_increase will 
        # increase life-years by an increment equal to whatever the difference in life-years was 
        # in each age-group the year before start_LY_increase
        data['LY_increment'] = data.groupby(groups + ['age'])['life_years_f'].diff()
    else:
        # if the user did pass an increment, then divide it by the number of age-groups.
        # each age-group will increase by this amount each period after start_LY_increase
        LY_increment /= len(data['age'].unique())
        data['LY_increment'] = LY_increment
    # copy the dataframe, where each copy contains a column indicating maximum life-exp that scenario will converge to
    data = copy_data(data,life_exp_max,'life_exp_max')
    groups += ['life_exp_max']
    # life-years cannot exceed the number of years in each age-group.
    data['LY_max'] = y_per_period
    # the 100+ age-group can exceed 5 only if the user passed a life_exp_max > 105
    age = data['age'] == 100
    data.loc[age,'LY_max'] = np.max([data.loc[age,'LY_max'].values,data.loc[age,'life_exp_max'].values - 100],axis=0)
    orig_end = data['year'].max()
    # if we're projecting beyond the end of UN data, then copy the last year's data.
    if year_end > orig_end:
        next_period_data = data.query('year == @orig_end')
        for year in range(orig_end + y_per_period, year_end + y_per_period, y_per_period):
            next_period_data['year'] = year
            data = pd.concat([data,next_period_data],ignore_index=True)
    return data

# copy a dataframe multiple times and concatenate rows. each copy has a column labeled 'list_name' indicating which scenario it is
def copy_data(input_data,values,list_name):
    data = pd.DataFrame()
    if type(values) != list:
      values = [values]
    for item in values:
        copy = input_data.copy(deep=True)
        copy[list_name] = item
        data = pd.concat([data,copy],ignore_index=True)
    return data.reset_index(drop=True)

# project life-years lived 1 period after the given year
def project_life_years_1_period(data,d,ages,LY_name,groups,year):
    # update a copy of this period's life-years using the per-period age-specific life-years increment for each country
    d['LY_new'] = d[LY_name] + d['LY_increment']
    # each age-group can experience no more than the life-years experienced by the previous age-group
    # except for ages 0-5 (max = 5) and 100+ (max = life_exp_max - 100, or 5 if life_exp_max < 105)
    d = d.set_index('age')
    # if filtering by age turns dataframe into a series, then we only have 1 location, 
    # so change axis to take minimum across different axis to avoid error
    axis = 1
    if isinstance(d.loc[0], pd.Series):
       axis = 0
    d.loc[0,LY_name] = d.loc[0,['LY_new','LY_max']].min(axis=axis)
    d.loc[100,LY_name] = d.loc[100,['LY_new','LY_max']].min(axis=axis)
    for age in ages:
        d.loc[age,'LY_max'] = d.loc[age - y_per_period,LY_name].values
        d.loc[age,LY_name] = d.loc[age,['LY_new','LY_max']].min(axis=axis)
    # get distance between life-years in each age-group and the maximum they can possibly be. this cannot be less than 0.
    d['dist_above_max'] = d['LY_new'] - d['LY_max']
    d['dist_above_max'] = d[['dist_above_max','zeros']].max(axis=1)
    # sum dist_above_max for each group. We will "re-distribute" this amount for 
    # each group across the age-groups that haven't yet reached their maximum
    dist_above_max_sum = d.groupby(groups)['dist_above_max'].sum().reset_index().rename({'dist_above_max' : 'dist_above_max_sum'},axis=1)
    d = pd.merge(d.reset_index(),dist_above_max_sum, how = 'left')
    # get life-years in each age-group that are below maximum, and get the sum of age-specific life-years for each group. this cannot be less than 0
    d['dist_below_max'] = d['LY_max'] - d[LY_name]
    d['dist_below_max'] = d[['dist_below_max','zeros']].max(axis=1)
    dist_below_max_sum = d.groupby(groups)['dist_below_max'].sum().replace(0,np.nan).reset_index().rename({'dist_below_max' : 'dist_below_max_sum'},axis=1)
    d = pd.merge(d,dist_below_max_sum, how = 'left')
    # ages with a higher distance from maximum get a higher proportion of the dist_above_max_sum
    # e.g. if dist_above_max_sum = 2, dist_below_max_sum = 10, and dist_below_max for age 95-100 = 3, then 
    # age 95-100 gets an additional 2 * 3 / 10 = 0.6 life-years
    d['new_increment'] = d['dist_above_max_sum'] * d['dist_below_max'] / d['dist_below_max_sum']
    # increase life-years in official dataset. return this period's data
    d[LY_name] += d['new_increment'].fillna(0)
    data.loc[year,LY_name] = d[LY_name].values
    d.drop(['dist_above_max_sum','dist_below_max_sum'],axis=1,inplace=True)
    return d

# ensure life-exp doesn't exceed life_exp_max
def adjust_life_exp_max(data,LY_name,groups):
    # get life-expectancy for each location and period
    life_exp = data.groupby(groups + ['year'])[LY_name].sum().reset_index().rename({LY_name : 'life_exp'},axis=1)
    data = pd.merge(data,life_exp, how = 'left')
    # get 'year_stop', the first year when life-exp exceeds life_exp_max for each location 
    year_stop = data.query('life_exp > life_exp_max').groupby(groups)['year'].min().reset_index().rename({'year' : 'year_stop'},axis=1)
    data = pd.merge(data,year_stop, how = 'left')
    # get the 'LY_stop', the age-specific life-years at year 'year_stop' for each location 
    LY_stop = data.query('life_exp > life_exp_max & year == year_stop').rename({LY_name : 'LY_stop'},axis=1)
    # slightly decrease LY_stop so they exactly sum to life_exp_max. 
    LY_stop['LY_dif'] = LY_stop['life_exp'] - LY_stop['life_exp_max']
    # Subtract the difference between life_exp and life_exp_max (LY_dif) from LY_stop at age 100 
    # (e.g. if life_exp_max = 100 and life_exp = 101, then subtract 1 from age 100)
    age = LY_stop['age'] == 100
    LY_stop.loc[age,'LY_stop'] -= LY_stop.loc[age,'LY_dif']
    # join LY_stop with main dataset
    LY_stop = LY_stop[groups + ['age','LY_stop']]
    data = pd.merge(data,LY_stop, how = 'left')
    # set life-years equal to  'LY_stop' for each location and for all periods 
    above_max = (data['life_exp'] >= data['life_exp_max']) & (data['year'] >= data['year_stop'])
    data.loc[above_max,LY_name] = data.loc[above_max,'LY_stop']
    data.drop(['life_exp','year_stop','LY_stop'],axis=1,inplace=True)
    return data

# project fertility rates in each age-group beyond the end of the UN data
def project_fertility(data,tfr_scenarios,start_converge,converge_speeds,converge_twice=False):
    groups = ['variant','life_exp_max','tfr_scenario','start_converge','converge_speed','location']
    # get labels indicating whether this is the 1st or 2nd time we're converging fertility in each scenario. include these in columns to group by
    if converge_twice == True:
        tfr_scenario_name, start_converge_name, converge_speed_name = 'tfr_scenario_t', 'start_converge_t', 'converge_speed_t'
        groups += [tfr_scenario_name, start_converge_name, converge_speed_name]
    else:
        tfr_scenario_name, start_converge_name, converge_speed_name = 'tfr_scenario', 'start_converge', 'converge_speed'
    # copy data with indicators for what tfr is converging to (tfr_scenario) and at what year this convergence begins (start_converge)
    # if left False, set start_converge_t and converge_speed_t to equal the baseline scenario if we're converging for the 2nd time
    data = copy_data(data,tfr_scenarios,tfr_scenario_name)
    if converge_twice == True and start_converge == False:
      data[start_converge_name] = data['start_converge']
    else:
      data = copy_data(data,start_converge,start_converge_name)
    data = assign_converge_speeds(data,converge_speed_name,converge_speeds,converge_twice)
    # converge sex ratio at birth to be the same as global average at the original year-end before we created new years. start converging at that year
    data = project_sex_ratio_birth(data,converge_twice)
    # keep only country-level data
    data = data.query('type == "country"')
    # if instructed, get 'replacement' TFR in each life_exp_max scenario and location that causes zero long-run population growth
    if 'replacement' in tfr_scenarios:
        data = get_replacement_tfr(data,tfr_scenario_name,groups)
    # all fertility rates and rates of change are the same after start_converge.
    data = project_converge_speed(data,start_converge_name,groups,converge_twice)
    # calculate future fertility rates by converging them to a given TFR
    data = converge_fertility(data,tfr_scenario_name,start_converge_name,groups)
    # ensure TFR stops moving once it reaches the given tfr target
    try:
        data = adjust_fertility(data,tfr_scenario_name,start_converge_name,groups)
    except Exception:
        # an error here is taken to mean that no location passed its tfr_scenario target,
        # so there is nothing to adjust. catching Exception (not a bare except) still lets Ctrl-C through
        pass
    # re-label tfr_scenario for replacement to say "replacement" instead of a TFR number
    if 'replacement' in tfr_scenarios:
        data.loc[~data[tfr_scenario_name].isin(tfr_scenarios),tfr_scenario_name] = 'replacement'
    # drop unnecessary columns (any that are absent are skipped)
    data = data.drop(columns=['pct_change','pct_change_world','sign','periods_since_start'],errors='ignore')
    return data

# create copies of data for each converge_speed and assign each row a converge_speed at which TFR will change per period
def assign_converge_speeds(data,converge_speed_name,converge_speeds,converge_twice):
    if converge_twice == True and converge_speeds == False:
      converge_speeds = data['converge_speed'].unique()
      data[converge_speed_name] = data['converge_speed']
    else:
      data = copy_data(data,converge_speeds,converge_speed_name)
    # if the user passed any of the named speeds in speed_dict, convert them to a number to use as the per-period pct_change
    try:
      speed_dict = {'very slow' : .01, 'slow' : .02, 'medium' : .03, 'fast' : .04, 'very fast' : 0.05}
      # 'named' marks rows whose converge speed is one of the speed_dict keys (avoids shadowing the built-in filter)
      named = data[converge_speed_name].isin(list(speed_dict))
      data.loc[named, 'pct_change'] = data.loc[named, converge_speed_name].map(speed_dict)
    except Exception:
      pass
    # keep only the numeric converge speeds, excluding strings such as "NA" or the named speeds handled above
    numeric_speeds = [speed for speed in converge_speeds if not isinstance(speed, str)]
    if len(numeric_speeds) > 0:
      # all locations and age-groups will converge to tfr_scenario at the given converge_speed, except for default option where we already have pct_change
      numeric = data[converge_speed_name].isin(numeric_speeds)
      data.loc[numeric,'pct_change'] = data.loc[numeric,converge_speed_name]
    return data

# alter sex-ratio at birth (SRB) so all countries eventually have same SRB as UN predicted global avg SRB by 2100. 
def project_sex_ratio_birth(data,converge_twice):
  orig_end = 2100
  if converge_twice == False and data['year'].max() > orig_end:
    # increase/decrease each country's SRB by pct_change each period until they reach the target SRB
    pct_change = 0.005
    data['sex_ratio_world'] = data.query('location == "World" & year == @orig_end')['sex_ratio_birth'].values[0]
    data['sign'] = 1
    data.loc[data['sex_ratio_birth'] >= data['sex_ratio_world'],'sign'] = -1
    data['periods_since_start'] = (data['year'] - orig_end) / y_per_period
    already_converged = (data['sex_ratio_birth'] > (data['sex_ratio_world'] - pct_change)) & (data['sex_ratio_birth'] < (data['sex_ratio_world'] + pct_change))
    years = (data['year'] > orig_end)
    filters = years & ~already_converged
    # increase or decrease sex ratio at birth by pct_change every period until reaching the global average from 2100
    data.loc[filters,'sex_ratio_birth'] *= (1 + data.loc[filters,'sign'] * pct_change) ** data.loc[filters,'periods_since_start']
    too_low = (data['sign'] == -1) & (data['sex_ratio_birth'] < data['sex_ratio_world'])
    too_high = (data['sign'] == 1) & (data['sex_ratio_birth'] > data['sex_ratio_world'])
    filters = years & (too_low | too_high | already_converged)
    data.loc[filters,'sex_ratio_birth'] = data.loc[filters,'sex_ratio_world']
    data.drop(['sign','periods_since_start','sex_ratio_world'],axis=1,inplace=True)
  return data
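
# Illustrative sketch (not part of the original program): the SRB rule above applied to one country.
# The helper name and sample numbers are hypothetical. It assumes the country's SRB is held at
# srb_start in every period after 2100 and uses the same 0.5% per-period step, so the path
# compounds toward srb_world and is set to srb_world once it crosses it (or starts within the tolerance).
def _example_srb_path(srb_start, srb_world, periods, pct_change=0.005):
    # countries at or above the world SRB move down, the rest move up
    sign = -1 if srb_start >= srb_world else 1
    path = []
    for k in range(1, periods + 1):
        if abs(srb_start - srb_world) < pct_change:
            srb = srb_world
        else:
            srb = srb_start * (1 + sign * pct_change) ** k
            crossed = (sign == -1 and srb < srb_world) or (sign == 1 and srb > srb_world)
            if crossed:
                srb = srb_world
        path.append(srb)
    return path
# Possible usage: _example_srb_path(1.10, 1.05, 12) declines each period and then stays at 1.05.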

# get 'replacement' TFR in each life_exp_max scenario and location that causes zero long-run population growth
def get_replacement_tfr(data,tfr_scenario_name,groups):
    # get the age at which women stop having children
    age_end_fert = data.query('fertility > 0 & fertility.notna()')['age'].max()
    replacement_data = data[data[tfr_scenario_name] == 'replacement'].drop(tfr_scenario_name,axis=1)
    # get data for ages below when women stop having children 
    # for the final year in the dataset (assumed to have the maximum life-expectancy)
    replacement_data_2 = replacement_data.query('age < @age_end_fert & year == year.max()')\
      [['life_exp_max','location','age','life_years_f','sex_ratio_birth']].drop_duplicates(subset=['life_exp_max','location','age']).groupby(['life_exp_max','location'])
    # life-years experienced by the average woman by age_end_fert
    life_exp = replacement_data_2['life_years_f'].sum()
    # probability of surviving to age_end_fert
    p_survive = life_exp / age_end_fert
    # percent of births that are female
    births_pct_f = 1 / (1 + replacement_data_2['sex_ratio_birth'].mean())
    # only a share births_pct_f of births are female, and only a share p_survive of those survive to the end of the childbearing ages,
    # so each woman needs 1 / (births_pct_f * p_survive) births on average for one surviving daughter to replace her
    replacement_tfr = 1 / (births_pct_f * p_survive)
    # the unnamed result becomes column 0 after reset_index. rename it to tfr_scenario_name so the merge restores the column we dropped above
    replacement_tfr = replacement_tfr.reset_index().rename({0 : tfr_scenario_name},axis=1)
    replacement_data = pd.merge(replacement_data,replacement_tfr, how = 'left')
    # merge replacement data with the rest of the data, if there is any
    data = data[data[tfr_scenario_name] != 'replacement']
    if len(data) > 0:
      data = pd.concat([data,replacement_data],ignore_index=True)
    else:
      data = replacement_data
    return data.sort_values(by = groups + ['year']).reset_index(drop=True)
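
# Illustrative sketch (not part of the original program): the replacement-TFR arithmetic above for a
# single location. The helper name and inputs are hypothetical; sex_ratio_birth is males per female
# at birth and p_survive is the share of women surviving through the childbearing ages.
def _example_replacement_tfr(sex_ratio_birth, p_survive):
    # share of births that are female
    births_pct_f = 1 / (1 + sex_ratio_birth)
    # births per woman needed for one daughter to survive and replace her
    return 1 / (births_pct_f * p_survive)
# Possible usage: _example_replacement_tfr(1.05, 0.975) is 2.05 / 0.975, a little above 2.1.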

# all fertility rates and rates of change are the same after start_converge.
def project_converge_speed(data,start_converge_name,groups,converge_twice):
    # hold fertility and pct_change in every period after start_converge at their start_converge values. converge_fertility updates fertility afterwards
    start_data = data.loc[data['year'] == data[start_converge_name], groups + ['age','fertility','pct_change']]
    start_data.rename({'fertility':'fert_converge','pct_change':'pct_change_converge'},axis=1,inplace=True)
    data = pd.merge(data,start_data, how = 'left')
    converging = data['year'] > data[start_converge_name]
    data.loc[converging,'fertility'] = data.loc[converging,'fert_converge']
    data.loc[converging,'pct_change'] = data.loc[converging,'pct_change_converge']
    data.drop(['fert_converge','pct_change_converge'],axis=1,inplace=True)
    # if this is the 2nd time we're converging and pct_change is zero (we've already converged), set fertility pct_change to the world average.
    if converge_twice == True and 'pct_change_world' in data.columns:
        # set pct_change equal to the last year in which pct_change_world is greater than 0 
        pct_change_world = data.query('pct_change_world > 0')[groups + ['year','age','pct_change_world']]
        pct_change_world = pct_change_world.query('year == year.max()').drop('year',axis=1)
        data.drop('pct_change_world',axis=1,inplace=True)
        data = pd.merge(data,pct_change_world, how = 'left')
        data['pct_change'] = data[['pct_change','pct_change_world']].max(axis=1)
    return data

# calculate future fertility rates by converging them to a given TFR
def converge_fertility(data,tfr_scenario_name,start_converge_name,groups):
    # get total fertility rate for each location for each year
    tfr = y_per_period * data.groupby(groups + ['year'])['fertility'].sum()
    tfr = tfr.reset_index().rename({'fertility':'tfr'},axis=1)
    data = pd.merge(data,tfr, how = 'left')
    # create col "sign" indicating whether TFR is above vs below tfr_scenario during year start_converge
    # those above get sign of -1 so fertility will decrease while those below get sign of 1 so fertility will increase
    data.loc[(data['year'] > data[start_converge_name]) & (data['tfr'] > data[tfr_scenario_name]),'sign'] = -1
    data.loc[(data['year'] > data[start_converge_name]) & (data['tfr'] < data[tfr_scenario_name]),'sign'] = 1
    sign = data.query('sign.notna()').groupby(groups)['sign'].mean().reset_index()
    data.drop(['sign','tfr'],axis=1,inplace=True)
    data = pd.merge(data,sign, how = 'left')
    # each period on or after start_converge, this period's fertility rate = fertility rate at year of starting convergence multiplied by
    # the percent change per period raised to the power of the number of periods that have passed since starting convergence
    data['periods_since_start'] = (data['year'] - data[start_converge_name]) / y_per_period
    years = data['year'] > data[start_converge_name]
    data.loc[years,'fertility'] *= (1 + data.loc[years,'sign'] * data.loc[years,'pct_change']) ** data.loc[years,'periods_since_start']
    return data

# ensure TFR stops moving once it reaches the given tfr target
def adjust_fertility(data,tfr_scenario_name,start_converge_name,groups):
    # get total fertility rate for each location and period
    tfr = y_per_period * data.groupby(groups + ['year'])['fertility'].sum()
    tfr = tfr.reset_index().rename({'fertility':'tfr'},axis=1)
    data = pd.merge(data,tfr, how = 'left')
    # for each group get the year just after fertility should be fully converged (stops moving)
    tfr_too_high = (data['sign'] == 1) & (data['tfr'] > data[tfr_scenario_name])
    tfr_too_low = (data['sign'] == -1) & (data['tfr'] < data[tfr_scenario_name])
    converged = (tfr_too_high | tfr_too_low) & (data['year'] >= data[start_converge_name])
    year_stop = data.loc[converged].groupby(groups)['year'].min().reset_index().rename({'year' : 'year_stop'},axis=1)
    data = pd.merge(data,year_stop, how = 'left')
    # get 'fert_stop', the age-specific fertility rates at year 'year_stop' for each location
    # adjust fert_stop slightly (down if TFR overshot from below, up if it overshot from above) so age-specific rates sum exactly to tfr_scenario
    fert_stop = data.loc[converged & (data['year'] == data['year_stop'])].rename({'fertility' : 'fert_stop'},axis=1)
    # spread the gap between fert_stop TFR and target TFR across age-groups in proportion to each age-group's share of TFR
    fert_stop['fert_dif'] = abs(fert_stop['tfr'] - fert_stop[tfr_scenario_name])
    fert_stop['age_weights'] = fert_stop['fert_stop'] / fert_stop['tfr']
    fert_stop['fert_stop'] += (-1) * fert_stop['sign'] * fert_stop['fert_dif'] * fert_stop['age_weights']
    fert_stop = fert_stop[groups + ['age','fert_stop']]
    data = pd.merge(data,fert_stop, how = 'left')
    # set fertility equal to  'fert_stop' for each location and for all periods after converging beyond tfr_scenario
    tfr_too_high = (data['sign'] == 1) & (data['tfr'] > data[tfr_scenario_name])
    tfr_too_low = (data['sign'] == -1) & (data['tfr'] < data[tfr_scenario_name])
    converged = (tfr_too_high | tfr_too_low) & (data['year'] >= data[start_converge_name])
    data.loc[converged,'fertility'] = data.loc[converged,'fert_stop']
    data.drop(['tfr','year_stop','fert_stop'],axis=1,inplace=True)
    return data
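
# Illustrative sketch (not part of the original program): the convergence rule in converge_fertility
# and the stopping rule in adjust_fertility, reduced to one location. Helper names and sample numbers
# are hypothetical, and y_per_period = 5 is an assumption about the period length.
def _example_converge_tfr(tfr_start, tfr_target, pct_change, periods):
    # TFR above the target falls, TFR below the target rises
    sign = -1 if tfr_start > tfr_target else 1
    path = []
    for k in range(periods + 1):
        # compound the per-period change from the starting level
        tfr = tfr_start * (1 + sign * pct_change) ** k
        overshoot = (sign == -1 and tfr < tfr_target) or (sign == 1 and tfr > tfr_target)
        # once the target is passed, hold TFR at the target
        path.append(tfr_target if overshoot else tfr)
    return path

def _example_rescale_to_target(asfr, tfr_target, y_per_period=5):
    # removing each age-group's share of the gap is the same as scaling every rate by target / TFR
    tfr = y_per_period * sum(asfr)
    return [f * tfr_target / tfr for f in asfr]
# Possible usage: _example_converge_tfr(3.0, 2.1, 0.03, 20) falls 3% per period and then stays at 2.1.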

# perform a population treatment / intervention
def treat_population(input_data,data,year_end,life_exp_max,groups,treat_type,age_treat,year_treat,end_treat,mort_target, \
    start_LY_increase,LY_increment,tfr_scenarios_t,start_converge_t,converge_speeds_t):
    # format age_treat so we can treat the given age-groups
    if age_treat == 'all' or age_treat == ['all']:
        ages_treat = data['age'].unique()
    elif isinstance(age_treat, (int, np.integer)):
        ages_treat = [age_treat]
    else:
        ages_treat = age_treat
    # if adding 1 person, don't perform the intervention yet. just create new rows for every age_treat and year_treat
    if treat_type == 'add person':
      data = copy_data(data,ages_treat,'age_treat')
      data = copy_data(data,year_treat,'year_treat')
      treat_data = data
      treat_data['treatment'] = treat_type
      groups += ['treatment','age_treat','year_treat']
      # add column indicating high vs low mortality countries, if instructed
      if mort_target != False and end_treat != False:
        p_survive_target = 1 - (mort_target / 1000)
        countries = treat_data.query('year == @end_treat & age == 0 & p_survive_m < @p_survive_target')['location'].unique()
        treat_data['mortality_group'] = 'Low-mortality countries'
        treat_data.loc[treat_data['location'].isin(countries), 'mortality_group'] = 'High-mortality countries'
    # if changing mortality, generate new life year projections and merge new columns with baseline projection
    elif treat_type == 'mortality':
      treat_data = reduce_mortality(input_data,year_treat,end_treat,ages_treat,mort_target,year_end,life_exp_max,start_LY_increase,LY_increment)
      cols = ['life_years_f','life_years_m','dead_years_f','dead_years_m','p_survive_f','p_survive_m']
      d = data.drop(cols,axis=1)
      treat_data = treat_data[['variant','life_exp_max','location','age_treat','year_treat','treatment','year','age','mortality_group'] + cols]
      treat_data = pd.merge(d,treat_data, how = 'left')
      groups += ['treatment','age_treat','year_treat']
    # if changing fertility, generate new projections given new fertility convergence assumptions. merge fertility with baseline projections
    elif treat_type == 'fertility':
      treat_data = project_fertility(data,tfr_scenarios_t,start_converge_t,converge_speeds_t,converge_twice = True)
      treat_data['treatment'] = treat_type
      #data[['tfr_scenario_t','start_converge_t','converge_speed_t']].replace('NA',np.nan,inplace=True)
      groups += ['treatment','start_converge','converge_speed','tfr_scenario_t','start_converge_t','converge_speed_t']
    # if none of the above are true, we didn't treat the population. return original data
    else:
      return data
    # merge treated columns with baseline columns in official dataframe. treated column names end with '_t'
    cols = ['pop_m','pop_f','fertility','life_years_f','life_years_m','births','deaths_f','deaths_m','p_survive_f','p_survive_m']
    treat_cols = groups + ['year','age'] + cols
    if 'mortality_group' in treat_data.columns:
        treat_cols += ['mortality_group']
    treat_data = treat_data[treat_cols].rename({c : c + '_t' for c in cols},axis=1) 
    data = pd.merge(data,treat_data, how = 'left')
    # only add these columns if changing mortality
    if treat_type == 'mortality':
      for col in ['life_years_f','life_years_m','p_survive_f','p_survive_m']:
        data[col + '_t'] = data[[col, col + '_t']].max(axis=1)
        #data.loc[data['year'] >=2100, col] = data.loc[data['year'] >=2100, col + '_t'].values # delete later
    return data

# move survival toward the mortality target in countries whose mortality is above the target, in ages age_treat, starting at year_treat and ending at end_treat
def reduce_mortality(input_data,year_treat,end_treat,age_treat,mort_target,year_end,life_exp_max,start_LY_increase,LY_increment):
  groups = ['variant','location','age']
  input_data = input_data.sort_values(by = groups + ['year']).reset_index(drop=True)
  data = input_data.copy(deep=True)
  # convert mortality target (deaths per 1,000) into a survival rate. 'countries' = list of country names whose survival
  # in ages age_treat is still below the target (i.e. mortality above mort_target) by year end_treat
  p_survive_target = 1 - (mort_target / 1000)
  countries = data.query('year == @end_treat & age == @age_treat & p_survive_m < @p_survive_target')['location'].unique()
  # change female and male survival rates to reach the target
  for sex in ['_f','_m']:
    # get original survival rates at start of mortality treatment
    p_survive, population, dead_years = 'p_survive' + sex, 'pop' + sex, 'dead_years' + sex
    new_p_survive = data.query('year == @year_treat & age == @age_treat')[groups + [p_survive]].rename({p_survive:'new_p_survive'},axis=1)
    data = pd.merge(data,new_p_survive,how='left')
    # get how far (as pct) each period should be from completing convergence.
    data['pct_converged'] = (data['year'] + y_per_period - year_treat) / (end_treat + y_per_period - year_treat)
    # increase each period's survival rate (from year_treat) by the difference between the survival target
    # and that survival rate, multiplied by how far in time (as pct) it is from completing the treatment
    data['new_p_survive'] += data['pct_converged'] * (p_survive_target - data['new_p_survive'])
    # after converging, decrease mortality at the same speed it would have decreased by over that time-period in the baseline scenario
    data['p_survive_increment'] = data.groupby(groups)[p_survive].diff()
    # set all new_p_survive after end_treat equal to new_p_survive at end_treat
    end_treat_data = data.query('year == @end_treat')[groups + ['new_p_survive','p_survive_increment']].rename({'new_p_survive':'p_survive_end_treat'},axis=1)
    data.drop('p_survive_increment',axis=1,inplace=True)
    data = pd.merge(data,end_treat_data,how='left')
    data.loc[data['year'] > end_treat, 'new_p_survive'] = data.loc[data['year'] > end_treat, 'p_survive_end_treat']
    # increase new_p_survive by same increment each period after year end_treat
    years = [y for y in data['year'].unique() if y > end_treat]
    data.set_index('year',inplace=True)
    for year in years:
      data.loc[year,'new_p_survive'] = data.loc[year - y_per_period, 'new_p_survive'].values + data.loc[year, 'p_survive_increment'].values
    data.reset_index(inplace=True)
    # prob survival must be between 0 and 1 (missing values are left as they are)
    data['new_p_survive'] = data['new_p_survive'].clip(lower=0, upper=1)
    # set the new p_survive as the maximum of either the original p_survive or the new p_survive
    # only do this for the treated countries and age-groups
    treat = (data['location'].isin(countries)) & (data['age'].isin(age_treat))
    data.loc[treat, p_survive] = data.loc[treat, ['new_p_survive', p_survive]].max(axis=1)
    # set new p_survive as the minimum of either the p_survive from the previous step or the population-weighted average
    # p_survive from low-mortality (high_p_survive) countries. This way, previously high-mortality countries aren't outperforming previously low mortality countries
    high_p_survive = data.query('location.isin(@countries) == False & age.isin(@age_treat)')
    high_p_survive = weighted_avg(high_p_survive, values = [p_survive,dead_years], weights = population, groups = ['variant','year','age']).drop_duplicates().rename({p_survive : 'high_p_survive', dead_years:'dead_years'},axis=1)
    data = pd.merge(data, high_p_survive, how = 'left')
    treat = (data['location'].isin(countries)) & (data['age'].isin(age_treat))
    data.loc[treat, p_survive] = data.loc[treat, [p_survive,'high_p_survive']].min(axis=1)
    data.drop(['new_p_survive','pct_converged','p_survive_increment','high_p_survive','p_survive_end_treat','dead_years'],axis=1,inplace=True)
  # convert from probability of survival into life-years
  data = data.drop_duplicates()
  data = convert_survival_to_life_years(data)
  # indicate whether each country was in the low (untreated) vs high (treated) mortality group
  data['mortality_group'] = 'Low-mortality countries'
  data.loc[data['location'].isin(countries), 'mortality_group'] = 'High-mortality countries'
  data = data.sort_values(by = ['variant','location','year','age']).reset_index(drop=True)
  # project life-years past original UN date
  data = project_life_years(data,year_end,life_exp_max=life_exp_max,start_LY_increase=start_LY_increase,LY_increment=LY_increment)
  # set new life-years (LY) as the minimum of either the LY from the previous step or the population-weighted average
  # LY from low-mortality (high_LY) countries. This way, previously high-mortality countries aren't outperforming previously low mortality countries
  for sex in ['_f','_m']:
    life_years, population = 'life_years'+sex, 'pop'+sex
    high_LY = data.query('mortality_group == "Low-mortality countries" & age.isin(@age_treat)')
    high_LY = weighted_avg(high_LY, values = life_years, weights = population, groups = ['variant','year','age']).drop_duplicates().rename({life_years : 'high_LY'},axis=1)
    data = pd.merge(data, high_LY, how = 'left')
    treat = (data['mortality_group'] == "High-mortality countries") & (data['age'].isin(age_treat))
    data.loc[treat, life_years] = data.loc[treat, [life_years,'high_LY']].min(axis=1)
  # insert labels indicating the treatment type, its parameters, and the years and ages treated
  if type(age_treat) == int:
      data['age_treat'] = age_treat
  elif len(age_treat) == 1:
      data['age_treat'] = age_treat[0]
  else:
      data['age_treat'] = str(age_treat)
  data['year_treat'] = str(year_treat) + '-' + str(end_treat)
  data['treatment'] = 'mortality'
  return data

# get weighted average of given values in dataframe
def weighted_avg(data, values, weights, groups):
    d = data.copy(deep=True)
    # accept either a single column name or a list of column names
    if not isinstance(values, list):
      values = [values]
    d[values] = d[values].multiply(d[weights], axis = 'index')
    d = d.groupby(groups)[values + [weights]].sum().reset_index()
    d[values] = d[values].divide(d[weights], axis = 'index')
    return d.drop(weights, axis = 1)

# convert new survival rates into life-years lived (nLx)
def convert_survival_to_life_years(input_data):
    data = input_data.copy(deep=True).sort_values(by=['age','variant','location','year']).set_index('age')
    for sex in ['_f','_m']:
        for age in data.index.unique():
            if age == 0:
                # set hypothetical survivors at age 0 equal to 1
                data.loc[age,'survivors'] = 1
            else:
                # survivors beginning this age = survivors beginning last age * prob they survived to this age
                data.loc[age,'survivors'] = data.loc[age - y_per_period,'survivors'].values * data.loc[age - y_per_period,'p_survive' + sex].values
            d = data.loc[age]
            # prob of surviving * given number of years in this age-group interval
            LY_survivors = d['p_survive' + sex] * y_per_period
            # prob of death * given number of years (a(x,n)) that those who die typically live in this age-group
            LY_dead = (1 - d['p_survive' + sex]) * d['dead_years' + sex]
            # total life-years lived in this age-group = survivors at start of age * (life-years lived by survivors of this age + life-years lived by those who die in this age-group)
            data.loc[age,'life_years' + sex] = d['survivors'] * (LY_survivors + LY_dead)
        data.drop('survivors',axis=1,inplace=True)
    # change data back into original sorting and format
    data = data.reset_index().sort_values(by=['variant','location','year','age']).reset_index(drop=True)
    return data
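
# Illustrative sketch (not part of the original program): the life-table step above for one location
# and sex, using plain lists ordered from youngest to oldest age-group. The helper name and inputs are
# hypothetical, and y_per_period = 5 is an assumption about the width of each age-group.
def _example_life_years(p_survive, dead_years, y_per_period=5):
    survivors = 1.0  # hypothetical cohort of 1 person at age 0
    life_years = []
    for p, a in zip(p_survive, dead_years):
        # survivors live the whole interval. those who die live 'a' years of it on average
        life_years.append(survivors * (p * y_per_period + (1 - p) * a))
        # share of the original cohort still alive at the start of the next age-group
        survivors *= p
    return life_years
# Possible usage: _example_life_years([0.9, 0.8], [2.5, 2.5]) gives about 4.75 and 4.05 life-years.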

# transform fertility and life-years so we can multiply population by them to get births and survivors in each age-group
def transform_fert_mort(data,treat_control):
  age_max = data['age'].max()
  for treat in treat_control:
    for sex in ['_f','_m']:
      life_years, survival_ratio = 'life_years' + sex + treat, 'survival_ratio' + sex + treat
      # survival ratio = life-years in the next age-group / life-years in this age-group
      data.loc[data['age'] != age_max,survival_ratio] = data.loc[data['age'] != 0,life_years].values / data.loc[data['age'] != age_max,life_years].values
      # survival ratio in oldest and second-oldest age-group = life-years in oldest age-group / (life-years in oldest age-group + life-years in 2nd oldest age-group)
      oldest_rows, oldest_2nd_rows = data['age'] == age_max, data['age'] == age_max - y_per_period
      oldest, oldest_2nd = data.loc[oldest_rows,life_years].values, data.loc[oldest_2nd_rows,life_years].values
      data.loc[oldest_rows,survival_ratio] = data.loc[oldest_2nd_rows,survival_ratio] = oldest / (oldest + oldest_2nd)
      # If we treated mortality, survival ratio in the treated scenario cannot be lower than survival ratio in untreated scenario
      if treat == '_t':
        data['survival_ratio' + sex + '_t'] = data[['survival_ratio' + sex, 'survival_ratio' + sex + '_t']].max(axis=1)
      # survival ratio cannot exceed 1 (missing values are left as they are)
      data[survival_ratio] = data[survival_ratio].clip(upper=1)
    # merge each row with the next age-group's survival ratio and fertility
    fertility, fertility_f, fertility_m, survival_ratio_f = 'fertility' + treat, 'fertility_f' + treat, 'fertility_m' + treat, 'survival_ratio_f' + treat
    data.loc[data['age'] != age_max,'survival_ratio_f_next_age'] = data.loc[data['age'] != 0,survival_ratio_f].values
    data.loc[data['age'] != age_max,'fertility_next_age'] = data.loc[data['age'] != 0,fertility].values
    # get the average number of births per woman between the current age-group and the next age-group, where only those who survive the period have children
    data[fertility_f] = (data[fertility] + data['survival_ratio_f_next_age'] * data['fertility_next_age']) / 2
    data.drop(['fertility_next_age','survival_ratio_f_next_age'],axis=1,inplace=True)
    # fertility in the oldest age-group is just its own fertility rate (fertility_m is derived from fertility_f just below)
    data.loc[data['age'] == age_max,fertility_f] = data.loc[data['age'] == age_max,fertility]
    # use same fertility rates for male and female births. adjust using sex ratio of males to females
    data[fertility_m] = data[fertility_f] * data['sex_ratio_birth'] / (data['sex_ratio_birth'] + 1)
    # adjust female births per woman based on sex ratio at birth
    data[fertility_f] /= (data['sex_ratio_birth'] + 1)
    # set fertility for children under age 10 to 0
    data.loc[data['age'] < 10, [fertility_f, fertility_m]] = 0
  return data
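
# Illustrative sketch (not part of the original program): the survival ratios computed above for one
# location and sex, from a list of life-years ordered youngest to oldest. The helper name and sample
# numbers are hypothetical, and the list is assumed to hold at least two age-groups.
def _example_survival_ratios(life_years):
    # ratio for each age-group = life-years in the next age-group / life-years in this one
    ratios = [nxt / cur for cur, nxt in zip(life_years[:-1], life_years[1:])]
    # the two oldest age-groups share one ratio because the oldest group is open-ended
    open_ended = life_years[-1] / (life_years[-1] + life_years[-2])
    ratios[-1] = open_ended
    ratios.append(open_ended)
    # survival ratio cannot exceed 1
    return [min(r, 1.0) for r in ratios]
# Possible usage: _example_survival_ratios([5.0, 4.0, 2.0, 1.0]) gives 0.8, 0.5, then 1/3 for the two oldest groups.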

# project the population using previously projected fertility and mortality rates
def project_population_size(data,treat_control,groups,treat_type):
  age_max = data['age'].max()
  survivor_ages = data['age'].unique()
  survivor_ages = survivor_ages[survivor_ages != 0]
  year_max = data['year'].max()
  years = data['year'].unique()
  data = data.sort_values(by = groups).set_index(['year','age']).sort_index()
  # indicate whether the treatment changes mortality
  mort_target = treat_type == 'mortality'
  for year in years:
    if treat_type == 'add person':
      # if instructed and the year is right, try adding 1 person to the treat population. Reset index so we can filter by year later
      data.loc[year] = add_1_person(data.loc[year],year)
    for treat in treat_control:
      for sex in ['_f','_m']:
        # project population for the next period using this period's data
        d = data.loc[year].copy(deep=True)
        # get col names for this sex and treatment scenario
        suffix = sex + treat
        pop, fertility, deaths, births, survival_ratio, life_years, pop_f \
        = 'pop' + suffix, 'fertility' + suffix, 'deaths' + suffix, 'births' + treat, 'survival_ratio' + suffix, 'life_years' + suffix, 'pop_f' + treat
        # annual births of this sex = number of births of this sex per woman * female population
        d[fertility] *= d[pop_f]
        # multiply annual births by years per period to get total births over period. Update official dataset's birth count
        data.loc[year,births] = d[births].values + d[fertility].values * y_per_period
        births = d.groupby(groups)[fertility].sum().reset_index(drop=True)
        # if this is baseline scenario and we're doing a mortality treatment or adding 1 person, track deaths as if we had treatment scenario mortality in a hypothetical
        # where the people whose deaths are averted do not enter the general population and do not have children. project their population's survivors separately.
        # this allows us to isolate the effect of averted deaths on population/life-years separately from indirect births
        if treat == '' and (mort_target == True or treat_type == 'add person'):
          data = project_lives_saved(data,d,births,year,year_max,age_max,sex)
        # get number of people who survive from beginning to the end of the period. Increase deaths in the official dataset by population minus survivors
        d[pop] *= d[survival_ratio]
        data.loc[year,deaths] = d[deaths].values + (data.loc[year,pop].values - d[pop].values)
        # get number of people who both were born (annual births * years per period) and died this period. save in "deaths"
        # births that died = total births (annual births * y_per_period) - survived births (annual births * life-years of the youngest age-group)
        LY = np.array(d.loc[0,life_years])
        dead_births = data.loc[(year,0),deaths]
        if isinstance(dead_births, pd.Series):
          data.loc[(year,0),deaths] = dead_births.values + births.values * (y_per_period - LY)
        else:
          data.loc[(year,0),deaths] = dead_births + births.values * (y_per_period - LY)
        # next period's population is an array containing births (ages 0-5) plus survivors (now ages 5-100+)
        if year != year_max:
          # multiply annual births by life-years in 0-5 age-group to get number of births that survived to next period
          births *= LY
          # increase age by y_per_period since they're getting this much older each period
          survivors = d.query('age != @age_max')
          # add survivors from age 100+ to itself since some of them survived
          survivors_100 = d.loc[age_max,pop]
          survivors.loc[age_max - y_per_period,pop] += survivors_100.values if type(survivors_100) == pd.Series else survivors_100
          # update official dataset's population. births = youngest age-groups. survivors = all other age-groups
          data.loc[year + y_per_period,pop] = pd.concat([births,survivors[pop]],ignore_index=True).values
  return data.reset_index()
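
# Illustrative sketch (not part of the original program): one period of the cohort component step
# above for a single sex and location, using plain lists ordered youngest to oldest. The helper name
# and sample numbers are hypothetical, and annual_births is assumed to be already split by sex.
def _example_project_one_period(pop, survival_ratio, annual_births, life_years_age0):
    # people alive at the end of the period, still indexed by their starting age-group
    survivors = [n * s for n, s in zip(pop, survival_ratio)]
    # youngest age-group = births that survived. every other cohort moves up one age-group
    next_pop = [annual_births * life_years_age0] + survivors[:-1]
    # the oldest age-group is open-ended, so its own survivors stay in it
    next_pop[-1] += survivors[-1]
    return next_pop
# Possible usage: _example_project_one_period([100, 80, 50], [0.9, 0.5, 0.2], 10, 4.5) gives about [45, 90, 50].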

# add 1 person to the population at the given year and age-group(s) in all countries. split between males and females
def add_1_person(data,year):
  if year in data['year_treat'].unique():
    # only treat when age is age_treat and year is year_treat.
    treat = (data.index == data['age_treat']) & (data['year_treat'] == year)
    d = data.loc[treat]
    # increase population at the specified age-group by the expected number of people added in that sex.
    # assume we add 1 person whose chance of being each sex is that sex's percent of total population.
    # also track this additional person in a separate col for each sex where we'll project just that person's survival forward assuming they have no children.
    total_pop = d['pop_f_t'] + d['pop_m_t']
    for sex in ['_f','_m']:
      additional_people = d['pop' + sex + '_t'] / total_pop
      for pop in ['pop' + sex + '_t', 'pop' + sex + '_saved']:
        data.loc[treat,pop] += additional_people
  data = data.reset_index()
  data['year'] = year
  return data.set_index(['year','age'])

# returns dataset where we track this period's deaths as if we had treatment scenario mortality in a hypothetical
# where the people whose deaths are averted do not enter the general population and do not have children. project their population's survivors separately.
# this allows us to isolate the effect of averted deaths on population/life-years separately from indirect births
def project_lives_saved(data,d,births,year,year_max,age_max,sex):
  pop, pop_saved, survival_ratio_t, survival_ratio, life_years_t, life_years, deaths_saved \
    = 'pop' + sex, 'pop' + sex + '_saved', 'survival_ratio' + sex + '_t', 'survival_ratio' + sex, 'life_years' + sex + '_t', 'life_years' + sex, 'deaths' + sex + '_saved'
  # get difference in survival ratios between treatment and baseline scenario and multiply by 
  # population in the control scenario to get the difference in survivors = lives saved over the period
  d['lives_saved'] = d[pop] * (d[survival_ratio_t] - d[survival_ratio])
  # ages that aren't supposed to be treated should have 0 lives saved
  d.loc[d.index != d['age_treat'],'lives_saved'] = 0
  # get lives saved for children born during this period = births survived using treatment life-years minus births survived using control life years
  d0 = d.loc[0]
  births_saved = pd.Series(births * np.array(d0[life_years_t] - d0[life_years]))
  # get deaths among the population_saved by end of period and record this in the official dataset
  data.loc[year,deaths_saved] = d[pop_saved].values * (1 - d[survival_ratio_t].values)
  # update the next period's population_saved if this isn't the last period
  if year != year_max:
    # get survivors by end of period
    d[pop_saved] *= d[survival_ratio_t]
    # add in lives saved over this period to the population of people whose lives were saved
    d[pop_saved] += d['lives_saved']
    # increase age by y_per_period since they're getting this much older each period
    survivors = d.query('age != @age_max')
    # add survivors from age 100+ to itself since it's an open-ended age-group
    survivors_100 = d.loc[age_max,pop_saved]
    survivors.loc[age_max - y_per_period,pop_saved] += survivors_100.values if isinstance(survivors_100, pd.Series) else survivors_100
    # pop_saved = stock of people still alive next period whose deaths were averted by the treatment
    data.loc[year + y_per_period,pop_saved] = pd.concat([births_saved,survivors[pop_saved]],ignore_index=True).values
  # update lives_saved in the copy and the official dataset for the CURRENT period. count both lives saved 
  # for those already alive at beginning of period as well as those born during the period (births_saved, which will go into this period's 0-5 age-group)
  d.loc[0,'lives_saved'] = np.array(d0['lives_saved'] + births_saved.values)
  data.loc[year,'lives_saved'] = data.loc[year,'lives_saved'].values +  d['lives_saved'].values
  return data

# combine sex-specific data into both-sexes totals
def combine_sexes(data,treat_control,treat_type):
  for treat in treat_control:
    fertility,births,deaths,deaths_f,deaths_m,population,pop_f,pop_m,years_lived,life_exp,life_years_f,life_years_m,p_survive_f,p_survive_m,mortality \
      = 'fertility'+treat,'births'+treat,'deaths'+treat,'deaths_f'+treat,'deaths_m'+treat,'population'+treat,'pop_f'+treat,'pop_m'+treat,'years_lived'+treat,'life_exp'+treat,\
        'life_years_f'+treat,'life_years_m'+treat,'p_survive_f'+treat,'p_survive_m'+treat,'mortality'+treat
    # multiply that age-group's fertility by years per period to get average births for each age-interval
    data[fertility] *= y_per_period
    # convert all numerical columns to float so we can sum later
    for col in ['pop_f_saved','pop_m_saved','lives_saved','deaths_f_saved','deaths_m_saved',\
                fertility,births,deaths,deaths_f,deaths_m,population,pop_f,pop_m, \
                years_lived,life_exp,life_years_f,life_years_m,p_survive_f,p_survive_m,mortality]:
      if col in data.columns:
        data[col] = data[col].astype(float)
    # combine sexes for population and deaths
    data[population] = data[pop_f] + data[pop_m]
    data[deaths] = data[deaths_f] + data[deaths_m]
    # combine sexes for life-years. call this 'life_exp' since the sum across ages equals life-expectancy
    d = data.loc[data[population] > 0]
    data[life_exp] = 0.0
    pct_f, pct_m = d[pop_f] / d[population], d[pop_m] / d[population]
    data.loc[data[population] > 0, life_exp] = pct_f * d[life_years_f] + pct_m * d[life_years_m]
    # calculate mortality (1000 * (1-p survival)) by re-arranging formula L = l * (5p + a * (1-p)) to solve for p. calculations below are only accurate for youngest age-group.
    data['dead_years'] = pct_f * data['dead_years_f'] + pct_m * data['dead_years_m']
    data[mortality] = (data[life_exp] - data['dead_years'])  / (y_per_period - data['dead_years'])
    data[mortality] = 1000 * (1 - data[mortality])
    data.loc[data[mortality].isna(), mortality] = 0
    # count years of life lived over each period by the whole population. Survivors (population - deaths) live the total period number of years. Those who die live for 'dead_years' number of years
    YL_survivors = y_per_period * (data[population] - data[deaths])
    YL_dead = data[deaths_f] * data['dead_years_f'] + data[deaths_m] * data['dead_years_m']
    data[years_lived] = YL_survivors + YL_dead
    # if we performed a life-saving intervention, calculate life-years and population for the lives that were saved
    if treat == '_t' and treat_type in ['add person','mortality']:
      # combined sexes for population_saved and calculate years of life lived by people in population_saved over each period
      data['population_saved'] = data['pop_f_saved'] + data['pop_m_saved']
      YL_survivors_saved = y_per_period * (data['population_saved'] - data['deaths_f_saved'] - data['deaths_m_saved'])
      YL_dead_saved = data['deaths_f_saved'] * data['dead_years_f'] + data['deaths_m_saved'] * data['dead_years_m']
      # life-years lived by lives saved = however many years they would have lived over that period after they would have otherwise died (dead_years)
      YL_lives_saved = data['lives_saved'] * (y_per_period - data['dead_years'])
      data['years_lived_saved'] = YL_survivors_saved + YL_dead_saved + YL_lives_saved
      data.drop(['deaths_m_saved','deaths_f_saved'],axis=1,inplace=True)
    data.drop(['dead_years'],axis=1,inplace=True)
  return data

# combine age-specific data into all-ages totals
def combine_ages(data,groups,treat_control,keep_data):
  groups_2 = groups + ['year']
  data = data.sort_values(by=groups_2 + ['age']).reset_index(drop=True)
  categorical_cols = ['location','code','id','type','region','subregion','sdg_region','income_group','mortality_group']
  categorical = data[[c for c in categorical_cols if c in data.columns]].drop_duplicates()
  for treat in treat_control:
    # get population size of females of fertile ages. later, use this to calculate TFR
    fertile_females = data.query('fertility > 0')[groups_2 + ['age','pop_f'+treat]].rename({'pop_f'+treat:'fertile_females'+treat},axis=1)
    data = pd.merge(data,fertile_females,how='left')
  # get columns to take the sum of across countries in the age-group
  sum_cols = []
  for var in ['fertility','births','deaths','population','years_lived','life_exp','lives_saved','years_lived_saved','population_saved','fertile_females']:
    for treat in treat_control:
      if var + treat in data.columns:
        sum_cols += [var + treat]
  # aggregate age-groups into a single-year sum across all ages
  age_sum = data.groupby(groups_2)[sum_cols].sum()
  # get the youngest age's mortality rate in each treatment group
  for treat in treat_control:
    mortality_rates = data.query('age == 0').set_index(groups_2)['mortality'+treat]
    age_sum = age_sum.join(mortality_rates)
  age_sum['age'] = 'all'
  age_sum.reset_index(inplace=True)
  # either keep or drop age-specific data
  if 'age' in keep_data:
    data = data[[c for c in age_sum.columns if c in data.columns]]
    data = pd.concat([data,age_sum],ignore_index=True)
  else:
    data = age_sum
  for treat in treat_control:
    # set some age-specific data to NaN since it won't be accurate or useful to calculate
    data.loc[data['age'] != 'all', ['mortality'+treat,'life_exp'+treat,'births'+treat]] = np.nan
  # merge each location's categorical data back onto age_sums
  return pd.merge(data,categorical,how='left')

# get difference between treated and untreated scenario columns. new differenced columns end in '_dif'
def get_scenario_difs(data):
  for col in ['population','births','deaths','years_lived','mortality','life_exp','fertility']:
    if col + '_t' in data.columns:
      data[col + '_dif'] = data[col + '_t'] - data[col]
  return data

# aggregate country-level data into region, subregion, income-group, and global data
def combine_countries(data,treat_control,groups,treat_type,keep_data):
  data = data.sort_values(by=groups + ['year','age'])
  countries = data.copy(deep=True)
  groups.remove('location')
  # get columns to aggregate. avg_cols will be averaged, sum_cols will be summed, 
  # and dif_cols will either be averaged or summed depending on treatment type
  avg_cols, dif_cols, sum_cols, all_cols = [], [], [], []
  data_cols = data.columns
  for var in ['births','deaths','years_lived','population','mortality','life_exp','lives_saved','years_lived_saved','population_saved']:
    for suffix in ['','_t','_dif']:
      col = var + suffix
      if col in data_cols:
        all_cols += [col]
    # if changing mortality based on mortality target, set low-mortality countries
    # to be equal in the treated and baseline scenarios since they should've received no treatment
    if treat_type == 'target' and var + '_t' in data_cols:
      low_mortality = (data['mortality_group'] == "Low-mortality countries")
      data.loc[low_mortality, var + '_t'] = data.loc[low_mortality, var]
  for col in np.unique(all_cols):
    if '_%' in col or 'life_exp' in col or 'mortality' in col:
      avg_cols += [col]
    elif '_dif' in col or '_saved' in col or '_born' in col:
      dif_cols += [col]
    else:
      sum_cols += [col]
  treat_avg, control_avg = [c for c in avg_cols if '_t' in c], [c for c in avg_cols if '_t' not in c]
  # generate a new aggregation for LMICs that didn't exist before
  countries.loc[countries['income_group'] != 'High-income countries', 'LMIC'] = 'Low- and middle-income countries'
  # aggregate countries into regional/income categories
  if 'regions' in keep_data:
    region_cols = ['world','region','subregion','sdg_region','income_group','LMIC']
    if 'mortality_group' in data.columns:
      region_cols += ['mortality_group']
  else:
    region_cols = ['world']
  for region in region_cols:
    if region == 'world':
      groups_2 = groups + ['year','age']
    else:
      groups_2 = groups + [region,'year','age']
    # get sum across countries and merge with averages
    region_data = countries.groupby(groups_2)[sum_cols].sum()
    for treat in treat_control:
      # calculate average TFR in the region, weighted by population of females of reproductive age in each country
      fertile_females = weighted_avg(data = countries, values = 'fertility' + treat, weights = 'fertile_females' + treat, groups = groups_2).set_index(groups_2)
      region_data = region_data.join(fertile_females)
      # if we performed a treatment, take regional averages of treatment effects for 
      # treat_type == 'add person' and take sums for everything else
      if treat == '_t':
        avg_cols = treat_avg
        if treat_type == 'add person':
          diffs = weighted_avg(data = countries, values = dif_cols, weights = 'population_t', groups = groups_2).set_index(groups_2)
        else:
          diffs = countries.groupby(groups_2)[dif_cols].sum()
        region_data = region_data.join(diffs)
      else:
        avg_cols = control_avg
      # average the given demographic columns
      avgs = weighted_avg(data = countries, values = avg_cols, weights = 'population' + treat, groups = groups_2).set_index(groups_2)
      region_data = region_data.join(avgs)
    region_data.reset_index(inplace=True)
    # create / rename columns indicating the location and type
    region_data['type'] = region
    if region == 'world':
      region_data['location'] = 'World'
    else:
      region_data['location'] = region_data[region]
      if region == 'LMIC':
        region_data['type'] = 'income_group'
        region_data['income_group'] = 'Low- and middle-income countries'
        region_data.drop('LMIC',axis=1,inplace=True)
    # concatenate regional data with country-level data
    data = pd.concat([data,region_data],ignore_index=True)
    if '_t' in treat_control:
      # drop columns that don't change based on the treatment type (nothing is dropped for other treatment types)
      drop_cols = []
      if treat_type == 'fertility':
        drop_cols = ['lives_saved','mortality_t']
      elif treat_type == 'add person':
        drop_cols = ['lives_saved'] + [c for c in data.columns if c[-2:] == '_t']
      elif treat_type == 'mortality':
        drop_cols = ['fertility_t']
      data.drop(drop_cols,axis=1,inplace=True)
  return data

# get population of additional people born, if we performed a mortality intervention
def get_population_born(data):
  if pd.api.types.is_integer_dtype(data['year_treat']):
    filter = data['year'] < data['year_treat'] + 10
  else:
    filter = data['year'] < data['year_treat'].str[:4].astype(int) + 10
  for var in ['years_lived','population']:
    # for the 1st 10 years after intervention, set years lived by lives saved to equal total difference in years lived. 
    # This adjusts for the very small (but unavoidable) calculation error causing an undercounting of life years lived by lives saved
    # we only do this for the 1st 10 years since their descendants can't have children in the 1st 10 years. 
    # Thereafter the error remains but is less significant as the effect grows
    data.loc[filter, var + '_saved'] = data.loc[filter, var + '_dif']
    # subtract "saved" population and years_lived from total to get the amount due to the additional people "born" from the intervention
    data[var + '_born'] = data[var + '_dif'] - data[var + '_saved'] 
  return data

# get cumulative sums of outcomes across years (for baseline, treated, and difference columns)
def get_cum_difs(data,groups,keep_data):
  # get cumulative sum of each outcome across years
  if 'cum' in keep_data:
    variables, orig_cols, cum_cols = ['births','deaths','years_lived','lives_saved','years_lived_saved','years_lived_born'], [], []
    for treat in ['','_t','_dif']:
      orig_cols += [v + treat for v in variables if v + treat in data.columns]
      cum_cols += [v + '_cum' + treat for v in variables if v + treat in data.columns]
    data[cum_cols] = data.groupby(groups + ['location','age'])[orig_cols].cumsum()
  return data

# re-arrange columns, drop columns and sort rows
def arrange_columns(data,scenarios,keep_data):
  # select columns to keep and in what order to arrange them. add "location" to scenarios if not already in
  scenarios = ['location'] + scenarios if 'location' not in scenarios else scenarios
  categorical_cols = [c for c in scenarios + ['year','age','code','type','region','subregion','sdg_region','income_group','mortality_group'] if c in data.columns] 
  numerical_cols = []
  # remove categorical columns if instructed
  if 'categorical' not in keep_data:
    cols = [c for c in categorical_cols if c not in ['code','type','region','subregion','sdg_region','income_group']]
    categorical_cols = cols
    if 'variant' in scenarios:
      scenarios.remove('variant')
  # re-arrange columns
  for suffix in ['','_cum','_t','_cum_t','_dif','_cum_dif']:
    for var in ['fertility','mortality','life_exp','population','births','deaths','years_lived','lives_saved','years_lived_saved','population_saved','years_lived_born','population_born']:
      # skip columns that are not in the dataframe or that contain only a single value (e.g. all NaN or all 0)
      col = var + suffix
      if col in data.columns:
        # replace NaN values with 0
        data[col] = data[col].replace(np.nan, 0)
        if len(data[col].unique()) > 1:
          numerical_cols += [col]
  # only include child-mortality if we altered mortality
  if 'lives_saved' not in numerical_cols:
    numerical_cols = [c for c in numerical_cols if 'mortality' not in c]
  # round population, births, deaths and years lived to nearest whole number, unless taking the difference between worlds. Round everything else to nearest thousandth
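  # note: test the column NAME (col.name); `'population' in col` would check the Series index instead
  data[numerical_cols] = data[numerical_cols].apply(lambda col: col.round(0) if \
    (('population' in col.name or 'births' in col.name or 'deaths' in col.name or 'years_lived' in col.name) and '_dif' not in col.name) else col.round(3))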
  return data[categorical_cols + numerical_cols].sort_values(by = scenarios + ['year','age']).reset_index(drop=True)

# Metadata
Below we describe each column in the output dataset. Some columns appear in the projection output only for certain input parameters (e.g. treatment type).

### Categorical Columns
**variant**: UN WPP variant

**life_exp_max**: maximum life-expectancy the scenario can reach

**tfr_scenario**: asymptotic total fertility rate (TFR) which all countries converge to in the long-run

**start_converge**: year when TFR started converging to tfr_scenario (default to final year of WPP data)

**converge_speed**: per-period rate of change of TFR between time start_converge and when tfr_scenario is reached

**treatment**: indicates which treatment was performed (add 1 person, change mortality, or change fertility)

**year_treat**: year(s) when treatment occurred

**age_treat**: age-groups that were treated

**location**: name of country or area (using UN names)

**year**: starting year of the period

**age**: starting age of the age-group. 'all' indicates a sum/average across all age-groups

**code**: 3-letter ISO alpha-3 code used by the UN

**type**: what type of location this is ('country','region','subregion','sdg_region','income_group', or 'world')

**region**: UN region

**subregion**: UN subregion

**sdg_region**: Sustainable Development Goals (SDG) region

**income_group**: income status of the country according to UN (high, middle, low)

**mortality_group**: indicates whether the location was above or below the mortality target in the target age-group when treatment started (applies to the mortality treatment only)

### Numerical Columns

**fertility**: births per woman. Total fertility rate (TFR) is given when age is "all"

**p_survive (px)**: probability of someone at the start of the age-group surviving the next 5 years

**dead_years (ax)**: average years lived by people who die within the age-group

**life_years (Lx)**: years someone would have lived over the age-group if they experienced the mortality rates of that period from birth until the end of the given age

**mortality (mx)**: deaths per 1000 births over the next 5 years. Under-5 mortality is given when age is "all"

**life_exp (e0)**: life-expectancy at birth, i.e. the sum of life_years (Lx) across all age-groups

**population (or 'pop')**: people alive at start of the period

**births**: total births over the period

**deaths**: total deaths over the period

**years_lived**: total years of life lived over the period ( = survivors * years per period + deaths * dead_years)

**lives_saved**: number of people who die in the given period in the baseline scenario but who survive through the period in the treatment scenario
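
As a rough illustration of the **years_lived** definition above, the sketch below applies the formula to made-up numbers. The function name `years_lived_over_period` and all input values are assumptions for illustration only; they are not part of the program or taken from the UN WPP data.

```python
# Minimal sketch of the years_lived formula described above.
# All numbers are made up; the function name is hypothetical.

def years_lived_over_period(population, deaths, dead_years, y_per_period=5):
    """Years of life lived over one period by a whole population.

    Survivors (population - deaths) live the full period; people who die
    live `dead_years` on average within the period.
    """
    survivors = population - deaths
    return survivors * y_per_period + deaths * dead_years

# 1000 people at the start of the period, 40 deaths, 2.5 years lived on average by those who die
total = years_lived_over_period(population=1000.0, deaths=40.0, dead_years=2.5)
print(total)  # 960 survivors * 5 years + 40 deaths * 2.5 years = 4900.0
```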

### Column Suffixes
**_f or _m**: female or male version of the variable

**_cum**: cumulative count of the variable across time within the location/scenario

**_t**: distinguishes variables in the 'treated' scenario from the 'untreated'/baseline scenario that don't have a '_t'

**_dif**: difference between treated and untreated scenario (e.g. births_dif = births_t - births)

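**_saved**: version of the variable for the people whose deaths were averted by the treatment (the stock of lives_saved still alive), projected forward assuming they have no children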

**_born**: population of additional people born due to the treatment
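
The suffixes above can be combined, as in this small sketch on made-up data for a single hypothetical location "A". The values are illustrative assumptions, not program output. The column names follow the conventions described above (e.g. `births_cum_dif` is the cumulative sum of `births_dif`).

```python
import pandas as pd

# Made-up values for one location across three periods (illustration only).
df = pd.DataFrame({
    'location': ['A', 'A', 'A'],
    'year': [2025, 2030, 2035],
    'births': [100.0, 110.0, 120.0],          # baseline scenario
    'births_t': [100.0, 112.0, 125.0],        # treated scenario
    'population': [1000.0, 1050.0, 1100.0],   # baseline scenario
    'population_t': [1000.0, 1060.0, 1120.0], # treated scenario
    'population_saved': [0.0, 8.0, 12.0],     # people alive whose deaths were averted
})

# _dif: treated minus baseline
df['births_dif'] = df['births_t'] - df['births']
df['population_dif'] = df['population_t'] - df['population']

# _born: the part of the difference not accounted for by the people saved
df['population_born'] = df['population_dif'] - df['population_saved']

# _cum: running total across years within each location
df['births_cum_dif'] = df.groupby('location')['births_dif'].cumsum()

print(df[['year', 'births_dif', 'births_cum_dif', 'population_dif', 'population_born']])
```

In this toy example the treated scenario has 2 and then 5 extra births, so `births_cum_dif` runs 0, 2, 7, while `population_born` is the 0, 2, 8 people left over after subtracting `population_saved` from `population_dif`.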
