# Start Here
Description of the program and instructions to use it.

**What does this program do?**

This program uses python to run the cohort component method to project population from 2025 until any given date. You can define different 'treatments' which can alter the population size, fertility rates, or mortality rates and compare the effect of these treatments with the baseline projection where this treatment doesn't happen. We use fertility and mortality conditions from the United Nations World Population Prospects from 2025 until 2100, after which the program generates its own figures which you can change using the code at the bottom of this section.

This program was last updated on **January 31, 2023**. It was primarily written by Gage Weston, a researcher at Population Wellbeing Initiative at University of Texas at Austin as of January 2023. You may contact Gage at gageweston@utexas.edu for issues or questions about the code.


**Instructions:**

1. Ensure your computer can run Python 3 or above, Jupyter Notebooks, and Numpy and Pandas libraries.

2. Run the code block below to import input_data.

2. Edit the "modifiable inputs" in the 2nd code block below or do not change to keep default inputs.

3. Click the play button on the "Main Code" block to load the code into the programming environment.

4. Click the play button in the code block immediately below to run population projections.

5. See 'Replicate Our Projections' to replicate the projections and figures in our paper on long-term fertility scenarios.

6. See 'Metadata' for a description of the data columns.

In [1]:
# Do not change this code!

# run this code if you haven't already installed these python libraries
'''pip install pandas
pip install numpy
pop install matplotlib
pip install plotly'''

# import python libraries
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
%matplotlib inline

y_per_period = 5 # number of years in each period and age-group. Currently can only put 5

# If you haven't downloaded our prepared version of UN WPP data, you can download files in demographi_input_2022 from the 2022 UN World Population Prospects at https://population.un.org/wpp/Download
# and then delete apostrophes (''') and run 'prep_WPP_files' to get the input data. Otherwise, just read the file at bottom of this code block
# make sure all files are in same working directory
'''
input_to_csv = True # boolean. If true, saves input data to csv with name 'WPP_input_data.csv'
variants = ['Zero migration'] # list of WPP variants to use in our projections before 2100. Default to 'Zero migration'.
demographic_input_2022 = { 
    'population'                 : 'WPP2022_PopulationByAge5GroupSex_OtherVariants.csv', # 
    'fertility'                  : 'WPP2022_Fertility_by_Age5.csv', # 
    'demographics_summary'       : 'WPP2022_Demographic_Indicators_OtherVariants.csv', # 
    'life_table'                 : 'WPP2022_Life_Table_Abridged_Medium_2022-2100.csv', # 
    'labels'                     : 'WPP2022_F01_LOCATIONS.XLSX'} # 
input_data = prep_WPP_files(files = demographic_input_2022, variants = variants, to_csv = input_to_csv) # import input data from UN WPP projections (takes ~10-30 seconds). You only need to do this once.
'''

input_data = pd.read_csv('WPP_input_data.csv')
input_data

Unnamed: 0,variant,location,year,age,id,code,type,region,subregion,sdg_region,...,dead_years_m,p_survive_f,p_survive_m,fertility,life_years_f,life_years_m,births,deaths,pop_f,pop_m
0,Zero migration,Afghanistan,2025,0,4,AFG,country,Asia,Southern Asia,Central and Southern Asia,...,0.687082,0.958344,0.948473,,4.822705,4.777786,7451955.0,1341785.0,3401914.0,3549556.0
1,Zero migration,Afghanistan,2025,5,4,AFG,country,Asia,Southern Asia,Central and Southern Asia,...,2.186507,0.995983,0.994931,,4.780809,4.728842,7451955.0,1341785.0,3078327.0,3218608.0
2,Zero migration,Afghanistan,2025,10,4,AFG,country,Asia,Southern Asia,Central and Southern Asia,...,2.712384,0.996846,0.995800,0.001271,4.765339,4.709266,7451955.0,1341785.0,2744229.0,2876325.0
3,Zero migration,Afghanistan,2025,15,4,AFG,country,Asia,Southern Asia,Central and Southern Asia,...,2.902121,0.994906,0.988668,0.067218,4.746259,4.676184,7451955.0,1341785.0,2446902.0,2567309.0
4,Zero migration,Afghanistan,2025,20,4,AFG,country,Asia,Southern Asia,Central and Southern Asia,...,2.568810,0.992873,0.978148,0.199725,4.716900,4.595943,7451955.0,1341785.0,2211379.0,2294798.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86011,Zero migration,World,2100,80,900,,world,,,,...,2.580449,0.826553,0.786565,,3.298723,2.767858,557806165.0,615468810.0,208132465.0,183261905.0
86012,Zero migration,World,2100,85,900,,world,,,,...,2.548619,0.734518,0.679432,,2.589600,2.045712,557806165.0,615468810.0,161663451.0,134076902.0
86013,Zero migration,World,2100,90,900,,world,,,,...,2.460577,0.590800,0.516681,,1.738510,1.243302,557806165.0,615468810.0,100121242.0,74477538.0
86014,Zero migration,World,2100,95,900,,world,,,,...,2.281629,0.387750,0.305369,,0.873163,0.528718,557806165.0,615468810.0,44844576.0,27710245.0


In [3]:
# Change these (optional), or use default values given. 
# By default, no treatment occurs.

year_end = 2200 # year when new projections end. Must be later than start_year (the earliest 5-year period for future WPP projections).
locations = 'all' # list of location names from WPP to include in final dataset. 'all' includes all locations.
keep_data = ['age','cum','dif','pct_dif','regions','categorical'] # See functions begining with 'combine_' . delete items from this list to drop data from the data output.
    # 'age' = age-specific data. 'cum' = cumulative sum of births, deaths, etc. across years. 'dif' and 'pct_dif' = difference and percent difference b/w treated and baseline scenarios. 'regions' = rows combining country-level data into regional data. 'categorical' = columns including regional indicators, country-codes, etc.
to_csv = False # pass string as title of CSV file to save the projections to CSV in same folder path as this file. Pass True to have the default file name 'projections_output'. Leave False to NOT save to csv. 

# treatment indicators - see 'treat_population' function
treat_type = False # indicate the "treatment" to see its effect on population. Options below:
    # to add 1 person:  Select treat_type = 'add 1 person'. Adds 1 person to each location's population at the given age and period. See 'add_1_person' function.
    # to change mortality: select treat_type = either 'target', 'update', '%', or 'pp'. Must include 'mort_treatment' to pass the numerical input. See 'change_mortality' function.
    # to change fertility: leave False but change tfr_scenarios_t. see 'project_fertility'
age_treat = [0] # list of starts of age-groups to treat (e.g. '0' = age 0-5).
year_treat = input_data['year'].min() # year to perform treatment, on/after start_year. Default to the first 5-year period of the WPP future projections.

# mortality - see 'project_life_years' and 'change_mortality' functons
life_exp_max = [100] # list of maximum life-expectancies each scenario can take.
start_LY_increase = input_data['year'].max() # year to start increasing life-years (LY) / life-expectancy. Default to the last year of the WPP projections (2100).
LY_increment = False # amount to increase life-expectancy by per period.
mort_treatment = False # numerical input for treatments within 'change_mortality' function.
end_treat = False # year to end mortality treatment.

# fertility - see 'project_fertility' function
tfr_scenarios = ['replacement'] # list containing total fertility rate(s) (TFR) which each country's TFR will gradually converge to 
    # after UN WPP data ends in 2100. 'replacement' TFR produces zero-population growth (TFR = between 2-2.1). If not 'replacement', give numeric TFRs, e.g. [1.5, 1.8].
start_converge = input_data['year'].max() # list of years to start convergence to 2nd TFR in same scenario. default to last year of WPP projections.
converge_speeds = ['medium'] # list of speeds at which TFR converges to 2nd TFR in same scenario. can give a percent (e.g. 0.03 for 3%) or words in ['very slow','slow','medium','fast','very fast'].
tfr_scenarios_t = False # list of TFRs to converge to a 2nd time after initial convergence.
start_converge_t = False # list of years to start convergence to 2nd TFR in same scenario.
converge_speeds_t = False # list of speeds at which TFR converges to 2nd TFR in same scenario.


# run the projections! (takes ~10-30 seconds per 100 years per scenario)
# recommended: do not change below
projections = project_the_universe(input_data = input_data, year_end = year_end, locations = locations, keep_data = keep_data, to_csv = to_csv, \
            age_treat = age_treat, year_treat = year_treat, treat_type = treat_type, \
            life_exp_max = life_exp_max, start_LY_increase = start_LY_increase, LY_increment = LY_increment, mort_treatment = mort_treatment, end_treat = end_treat, \
            tfr_scenarios = tfr_scenarios, start_converge = start_converge, converge_speeds = converge_speeds, tfr_scenarios_t = tfr_scenarios_t, start_converge_t = start_converge_t, converge_speeds_t = converge_speeds_t)
projections

projecting life years


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  next_period_data['year'] = year


projecting fertility
projecting population
aggregating and cleaning up data


Unnamed: 0,variant,life_exp_max,tfr_scenario,start_converge,converge_speed,location,year,age,code,type,...,fertility,life_exp,life_years,population,births,deaths,years_lived,births_cum,deaths_cum,years_lived_cum
0,Zero migration,100,replacement,2100,medium,Afghanistan,2025,0,AFG,country,...,,,4.800,6.951470e+06,,371427.069,3.316445e+07,,3.714271e+05,3.316445e+07
1,Zero migration,100,replacement,2100,medium,Afghanistan,2025,5,AFG,country,...,,,4.754,6.296935e+06,,23285.025,3.141894e+07,,2.328503e+04,3.141894e+07
2,Zero migration,100,replacement,2100,medium,Afghanistan,2025,10,AFG,country,...,0.006,,4.737,5.620554e+06,,31193.672,2.803047e+07,,3.119367e+04,2.803047e+07
3,Zero migration,100,replacement,2100,medium,Afghanistan,2025,15,AFG,country,...,0.336,,4.710,5.014211e+06,,59189.641,2.494372e+07,,5.918964e+04,2.494372e+07
4,Zero migration,100,replacement,2100,medium,Afghanistan,2025,20,AFG,country,...,0.999,,4.655,4.506177e+06,,69085.853,2.236314e+07,,6.908585e+04,2.236314e+07
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
218587,Zero migration,100,replacement,2100,medium,Zimbabwe,2200,85,ZWE,country,...,,,2.243,8.921622e+05,,435638.911,3.315260e+06,,8.042021e+06,5.337027e+07
218588,Zero migration,100,replacement,2100,medium,Zimbabwe,2200,90,ZWE,country,...,,,1.167,4.281594e+05,,279010.800,1.358069e+06,,4.439873e+06,1.961308e+07
218589,Zero migration,100,replacement,2100,medium,Zimbabwe,2200,95,ZWE,country,...,,,0.419,1.367223e+05,,75373.326,4.534825e+05,,1.364634e+06,4.982619e+06
218590,Zero migration,100,replacement,2100,medium,Zimbabwe,2200,100,ZWE,country,...,,,0.726,9.938674e+04,,42813.745,3.800193e+05,,3.406415e+05,1.803464e+06


# Main Code
The code below will perform the population projections when executed above. See 'Start Here' for explanation.

In [2]:
# import demographic projection files from UN WPP and re-format for our use
def prep_WPP_files(files,variants=['Zero migration'], to_csv = True):
    for file in files:
        this_file = prep_1_file(files,file)
        try:
            # merge with master dataframe
            data = pd.merge(data, this_file, how = 'left')
        except:
            # if merging failed, this is the beginning of the dataframe
            data = this_file
    # rename some types and drop if not in list of types
    types = {'Country/Area':'country','World':'world','Geographic region':'region','sdg_region':'subregion','SDG region':'sdg_region','Income group':'income_group'}
    data.loc[data['type'].isin(types),'type'] = data.loc[data['type'].isin(types),'type'].map(types)
    data = data.query('type.isin(@types.values())')
    # filter data so that only years and ages divisible by years per period remain
    data = data.query('year % @y_per_period == 0 & age % @y_per_period == 0 & variant.isin(@variants)')
    data = data.sort_values(by = ['variant','location','age','year'])
    year_end = data['year'].max()
    for col in ['fertility','sex_ratio_birth','life_years_f','life_years_m','p_survive_f','p_survive_m','dead_years_f','dead_years_m']:
        # take the average of the given period and the following period for each col given. 
        # This will make projections more accurate since we'd otherwise be skewed toward earlier birth/death rates
        data.loc[data['year'] < year_end,col] = (data.query('year < @year_end')[col].array + data.query('year > year.min()')[col].array) / 2
        data.loc[data['year'] == year_end,col] = data.query('year == @year_end - @y_per_period')[col].array
        # increase/decrease fertility in the last period by the percent change in fertility between the previous 2 periods
        if col == 'fertility':
            data['dif'] = data.groupby(['variant','location','age'])['fertility'].pct_change()
            data.loc[data['year'] == year_end, 'fertility'] *= (1 + data.query('year == @year_end - @y_per_period')['dif'].array)
        # increase all other columns in the last period by the difference of that column between the previous 2 periods
        else:
            data['dif'] = data.groupby(['variant','location','age'])[col].diff()
            data.loc[data['year'] == year_end, col] += data.query('year == @year_end - @y_per_period')['dif'].array
        data.drop('dif',axis=1,inplace=True)
    # re-arrange columns, sort rows, drop duplicates (should only apply for LatAm/Caribbean since they both have "SDG region" and "region"), reset index.
    data = data[['variant','location','year','age','id','code','type','region','subregion','sdg_region','income_group','sex_ratio_birth','dead_years_f','dead_years_m','p_survive_f','p_survive_m','fertility','life_years_f','life_years_m','births','deaths','pop_f','pop_m']]
    data = data.sort_values(by = ['variant','type','location','year','age']).drop_duplicates(subset=['variant','location','year','age']).reset_index(drop=True)
    # if instructed, save as CSV file
    if to_csv == True:
      input_data.to_csv('WPP_input_data.csv', index = False)
    return data

# prep each WPP file individually
def prep_1_file(files,file):
    if file == 'labels':
        data = pd.read_excel(files[file], skiprows=16)
        # get categorical labels for each country
        data = data[['Location code','Name.1','Name.2','Name.3','High-income Countries\n1503','Middle-income Countries\n1517','Low-income Countries\n1500']]
        data.loc[data['High-income Countries\n1503'].notna(),'income_group'] = 'High-income countries'
        data.loc[data['Middle-income Countries\n1517'].notna(),'income_group'] = 'Middle-income countries'
        data.loc[data['Low-income Countries\n1500'].notna(),'income_group'] = 'Low-income countries'
        data.drop(['High-income Countries\n1503','Middle-income Countries\n1517','Low-income Countries\n1500'],axis=1,inplace=True)
        data.rename({'Location code':'id','Name.1':'sdg_region','Name.2':'subregion','Name.3':'region'},axis=1,inplace=True)
        data = data[['id','region','subregion','sdg_region','income_group']]
    else:
        data = pd.read_csv(files[file])
        data.rename({'LocID':'id','ISO3_code':'code','LocTypeName':'type','Time':'year','AgeGrpStart':'age'},axis=1,inplace=True)
        if file == 'life_table':
            data.rename({'Sex':'sex','px':'p_survive','Lx':'life_years','ax':'dead_years',},axis=1,inplace=True)
            data['life_years'] /= 100000
            # convert columns to lowercase and select only a few columns
            data.columns = data.columns.str.lower()
            data = data[['location','code','id','sex','age','year','p_survive','life_years','dead_years']]
            # sum years of life lived from 0 to 1 and 1 to 5 to get it into a 0 to 5 age group. Then divide by 100,000 since they are given per 100,000 births
            age_0, age_1 = data['age'] == 0, data['age'] == 1
            data.loc[age_0,'life_years'] += data.loc[age_1,'life_years'].array
            # p_survive age 0 to 5 = p_survive age 0 to 1 * p_survive age 1 to 5
            data.loc[age_0,'p_survive'] *= data.loc[age_1,'p_survive'].array
            # to get life-years lived by those who die (dead_years) in 0 to 5 age-group, algebraically solve the equation 
            # life_years = y_per_period * p_survive + dead_years * (1 - p_survive) - aka prob of death
            d = data.loc[age_0]
            data.loc[age_0,'dead_years'] = ( d['life_years'] - y_per_period * d['p_survive'] ) / (1 - d['p_survive'])
            data = data.query('sex != "Total" & age != 1')
            # make a separate col for each variable for each sex
            data = data.pivot(index = ['location','code','id','year','age'], columns = ['sex'], values = ['p_survive','life_years','dead_years']).reset_index()
            data.columns = ['_'.join(str(s).strip() for s in col if s) for col in data.columns]
            data.columns = data.columns.str.replace('Female','f')
            data.columns = data.columns.str.replace('Male','m')
        else:
            # convert columns to lowercase and select only a few columns
            data.columns = data.columns.str.lower()
            if file == 'population':
                data.rename({'popmale':'pop_m','popfemale':'pop_f'},axis=1,inplace=True)
                data = data[['id','type','variant','age','year','pop_m','pop_f']]
                data[['pop_f','pop_m']] = data[['pop_f','pop_m']] * 1000
            elif file == 'fertility':
                data.rename({'asfr':'fertility'},axis=1,inplace=True)
                data = data[['id','type','variant','age','year','fertility']]
                data['fertility'] /= 1000
            elif file == 'demographics_summary':
                data.rename({'srb':'sex_ratio_birth'},axis=1,inplace=True)
                data = data[['id','type','variant','year','sex_ratio_birth','births','deaths']]
                data['sex_ratio_birth'] /= 100
                data[['births','deaths']] *= 1000 * y_per_period
    return data

# main function. Use this function to perform all projections. See "Start Here" for explanation of what each parameter means
# Returns a dataframe containing population projections under the user-defined scenarios.
def project_the_universe(input_data = input_data, year_end = 2300, locations = 'all', keep_data = ['age','cum','dif','pct_dif','regions','categorical'], to_csv = False, \
  age_treat = [0], year_treat = input_data['year'].min(), treat_type = False, \
  life_exp_max = [100], start_LY_increase = input_data['year'].max(), LY_increment = False, mort_treatment = False, end_treat = False, \
  tfr_scenarios = ['replacement'], start_converge = input_data['year'].max(), converge_speeds = ['medium'], tfr_scenarios_t = False, start_converge_t = False, converge_speeds_t = False):
  # if selected specific countries, filter to only include those countries. Otherwise, include all locations.
  if locations == 'all' or 'World' in locations:
    data = input_data.copy(deep=True)
  else:
    data = input_data.query('location.isin(@locations)')
  # project demographic inputs (fertility, mortality, etc.) beyond existing UN projections
  print('projecting life years')
  data = project_life_years(data,year_end,life_exp_max,start_LY_increase,LY_increment,treat_type = False)
  print('projecting fertility')
  data = project_fertility(data,tfr_scenarios,start_converge,converge_speeds)
  # fill columns with zeros to be filled later. fill fertility nan with 0's
  data[['births','deaths_f','deaths_m','deaths_averted','pop_f_saved','pop_m_saved']] = 0
  data['fertility'] = data['fertility'].fillna(0)
  # get columns to group/sort by.
  groups = ['variant','life_exp_max','tfr_scenario','start_converge','converge_speed','location']
  # if user passsed a list to tfr_scenarios_t then the treat_type is "change TFR"
  if tfr_scenarios_t != False:
    treat_type = 'change TFR'
  # perform a population treatment / intervention if instructed
  if treat_type != False:
    print('treating population')
    data = treat_population(input_data,data,year_end,groups,treat_type,age_treat,year_treat,end_treat,mort_treatment, \
           start_LY_increase,LY_increment,tfr_scenarios_t,start_converge_t,converge_speeds_t)
  # treat_control indicates whether we're in a treatment/control scenario. Default is control, where columns have no suffix
  if 'pop_f_t' in data.columns:
    treat_control = ['','_t']
  else:
    treat_control = ['']
  print('projecting population')
  # transform fertility and life-years so we can more quickly calculate population later
  data = transform_fert_mort(data,treat_control)
  # project the population
  data = project_population(data,treat_control,groups,treat_type)
  # combine data across sexes and ages.
  print('aggregating and cleaning up data')
  data = combine_sexes(data,treat_control)
  data = combine_ages(data,groups,treat_control,keep_data)
  # get difference in outcomes between treat and control scenario if we're adding 1 person. Otherwise, do this later.
  if treat_type == 'add 1 person':
    data = get_scenario_difs(data)
  # if instructed, aggregate from country-level into global, regional and income-level projections.
  if locations == 'all' or 'World' in locations:
    data = combine_countries(data,treat_control,groups,treat_type,tfr_scenarios_t,keep_data)
    # if user selected "World", then include only global-level data
    if 'World' in locations:
      data = data.query('location == "World"')
  if '_t' in treat_control:
    # get difference in outcomes between treat and control scenario if not doing 'add 1 person' treat_type
    if 'dif' in keep_data and treat_type != 'add 1 person':
      data = get_scenario_difs(data)
  # get cummulative sums of outcomes and percent-differences between treatment and control groups
  data = get_cum_pct_difs(data,treat_type,groups,keep_data)
  # re-arrange columns, drop columns and sort rows
  data = arrange_columns(data,groups,keep_data)
  # if instructed, save file as CSV
  if to_csv != False:
    if type(to_csv) == str:
      file_name = to_csv
    else:
      file_name = 'projections_output'
    data.to_csv(file_name + '.csv', index = False)
  return data

# project years of life lived (nLx) and mortality rates in each age-group beyond original data
def project_life_years(data,year_end,life_exp_max,start_LY_increase,LY_increment = False, treat_type = False):
    groups = ['variant','location']
    # add new rows and columns we'll transform later to get new life-years
    data = prep_life_years(data,year_end,groups,life_exp_max,LY_increment,treat_type)
    # year_max = the final year in the dataset
    year_max = data['year'].max()
    # ages = an array containing all age-groups except the youngest and oldest age-groups
    ages = data['age'].unique()[1:-1]
    # project life-years for females and males
    for LY_name in ['life_years_f','life_years_m']:
        # set index "year" for data to increase search speed in loops
        data = data.set_index('year')
        # d = data from the start of projecting life_years, with only a few columns
        d = data.query('year == @start_LY_increase')[groups + ['age',LY_name,'LY_increment','LY_max']]
        # add a column of zeros to use for quick comparisons later
        d['zeros'] = np.zeros(len(d))
        # project next period's life-years using the previous period's data
        for year in range(start_LY_increase + y_per_period, year_max + y_per_period, y_per_period):
            d = project_life_years_1_period(data,d,ages,LY_name,groups,year)
        # ensure life-exp doesn't exceed life_exp_max
        data.reset_index(inplace=True)
        try:
            data = adjust_life_exp_max(data,LY_name,groups)
        except:
            # if the line above had an error, do nothing. this likely means we didn't exceed life_exp_max in any location
            pass
    # drop unnecessary columns and sort values
    data = data.drop(['LY_increment','LY_max'],axis=1)
    return data.sort_values(by= groups + ['year','age']).reset_index(drop=True)

# add new rows and columns we'll transform later to get new life-years
def prep_life_years(data,year_end,groups,life_exp_max,LY_increment,treat_type):
    # get LY_increment, the constant amount to add to life-years each period. This varies by location, sex, and age.
    if LY_increment == False:
        # if no increment was passed by the user, by default each year after start_LY_increase will 
        # increase life-years by an increment equal to whatever the difference in life-years was 
        # in each age-group the year before start_LY_increase
        data['LY_increment'] = data.groupby(groups + ['age'])['life_years_f'].diff()
    else:
        # if the user did pass an increment, then divide it by the number of age-groups.
        # each age-group will increase by this amount each period after start_LY_increase
        LY_increment /= len(data['age'].unique())
        data['LY_increment'] = LY_increment
    # copy the dataframe, where each copy contains a column indicating maximum life-exp that scenario will converge to
    data = copy_data(data,life_exp_max,'life_exp_max')
    groups += ['life_exp_max']
    # life-years cannot exceed the number of years in each age-group.
    data['LY_max'] = y_per_period
    # the 100+ age-group can exceed 5 only if the user passed a life_exp_max > 105
    age = data['age'] == 100
    data.loc[age,'LY_max'] = np.max([data.loc[age,'LY_max'].array,data.loc[age,'life_exp_max'].array - 100],axis=0)
    orig_end = data['year'].max()
    # if we're projecting beyond the end of UN data, then copy the last year's data.
    if year_end > orig_end:
        next_period_data = data.query('year == @orig_end')
        for year in range(orig_end + y_per_period, year_end + y_per_period, y_per_period):
            next_period_data['year'] = year
            data = pd.concat([data,next_period_data],ignore_index=True)
    return data

# copy a dataframe multiple times and concatenate rows. each copy has a column labeled 'list_name' indicating which scenario it is
def copy_data(input_data,values,list_name):
    data = pd.DataFrame()
    if type(values) != list:
      values = [values]
    for item in values:
        copy = input_data.copy(deep=True)
        copy[list_name] = item
        data = pd.concat([data,copy],ignore_index=True)
    return data.reset_index(drop=True)

# project life-years lived 1 period after the given year
def project_life_years_1_period(data,d,ages,LY_name,groups,year):
    # update a copy of this period's life-years using the per-period age-specific life-years increment for each country
    d['LY_new'] = d[LY_name] + d['LY_increment']
    # each age-group can experience no more than the life-years experienced by the previous age-group
    # except for ages 0-5 (max = 5) and 100+ (max = life_exp_max - 100, or 5 if life_exp_max < 105)
    d = d.set_index('age')
    d.loc[0,LY_name] = d.loc[0,['LY_new','LY_max']].min(axis=1)
    d.loc[100,LY_name] = d.loc[100,['LY_new','LY_max']].min(axis=1)
    for age in ages:
        d.loc[age,'LY_max'] = d.loc[age - y_per_period,LY_name].array
        d.loc[age,LY_name] = d.loc[age,['LY_new','LY_max']].min(axis=1)
    # get distance between life-years in each age-group and the maximum they can possibly be. this cannot be less than 0.
    d['dist_above_max'] = d['LY_new'] - d['LY_max']
    d['dist_above_max'] = d[['dist_above_max','zeros']].max(axis=1)
    # sum dist_above_max for each group. We will "re-distribute" this amount for 
    # each group across the age-groups that haven't yet reached their maximum
    dist_above_max_sum = d.groupby(groups)['dist_above_max'].sum().reset_index().rename({'dist_above_max' : 'dist_above_max_sum'},axis=1)
    d = pd.merge(d.reset_index(),dist_above_max_sum, how = 'left')
    # get life-years in each age-group that are below maximum, and get the sum of age-specific life-years for each group. this cannot be less than 0
    d['dist_below_max'] = d['LY_max'] - d[LY_name]
    d['dist_below_max'] = d[['dist_below_max','zeros']].max(axis=1)
    dist_below_max_sum = d.groupby(groups)['dist_below_max'].sum().replace(0,np.nan).reset_index().rename({'dist_below_max' : 'dist_below_max_sum'},axis=1)
    d = pd.merge(d,dist_below_max_sum, how = 'left')
    # ages with a higher distance from maximum get a higher proportion of the dist_above_max_sum
    # e.g. if dist_above_max_sum = 2, dist_below_max_sum = 10, and dist_below_max for age 95-100 = 3, then 
    # age 95-100 gets an additional 2 * 3 / 10 = 0.6 life-years
    d['new_increment'] = d['dist_above_max_sum'] * d['dist_below_max'] / d['dist_below_max_sum']
    # increase life-years in official dataset. return this period's data
    d[LY_name] += d['new_increment'].fillna(0)
    data.loc[year,LY_name] = d[LY_name].array
    d.drop(['dist_above_max_sum','dist_below_max_sum'],axis=1,inplace=True)
    return d

# ensure life-exp doesn't exceed life_exp_max
def adjust_life_exp_max(data,LY_name,groups):
    # get life-expectancy for each location and period
    life_exp = data.groupby(groups + ['year'])[LY_name].sum().reset_index().rename({LY_name : 'life_exp'},axis=1)
    data = pd.merge(data,life_exp, how = 'left')
    # get 'year_stop', the year when life-exp is closest to but not above life_exp_max for each location 
    year_stop = data.query('life_exp > life_exp_max').groupby(groups)['year'].min().reset_index().rename({'year' : 'year_stop'},axis=1)
    data = pd.merge(data,year_stop, how = 'left')
    # get the 'LY_stop', the age-specific life-years at year 'year_stop' for each location 
    LY_stop = data.query('life_exp > life_exp_max & year == year_stop').rename({LY_name : 'LY_stop'},axis=1)
    # slightly decrease LY_stop so they exactly sum to life_exp_max. 
    LY_stop['LY_dif'] = LY_stop['life_exp'] - LY_stop['life_exp_max']
    # Subtract the difference between LY_stop and life_exp_max from LY_stop at age 100 
    # (e.g. if life_exp = 100 and ending life-exp = 101, then subtract 1 from age 100)
    age = LY_stop['age'] == 100
    LY_stop.loc[age,'LY_stop'] -= LY_stop.loc[age,'LY_dif']
    # join LY_stop with main dataset
    LY_stop = LY_stop[groups + ['age','LY_stop']]
    data = pd.merge(data,LY_stop, how = 'left')
    # set life-years equal to  'LY_stop' for each location and for all periods 
    above_max = (data['life_exp'] >= data['life_exp_max']) & (data['year'] >= data['year_stop'])
    data.loc[above_max,LY_name] = data.loc[above_max,'LY_stop']
    data.drop(['life_exp','year_stop','LY_stop'],axis=1,inplace=True)
    return data

# output: projections of years of fertility rates in each age-group beyond UN end
def project_fertility(data,tfr_scenarios,start_converge,converge_speeds,converge_twice=False):
    groups = ['variant','life_exp_max','tfr_scenario','start_converge','converge_speed','location']
    # get labels indicating whether this is the 1st or 2nd time we're converging fertility in each scenario. include these in columns to group by
    if converge_twice == True:
        tfr_scenario_name, start_converge_name, converge_speed_name = 'tfr_scenario_t', 'start_converge_t', 'converge_speed_t'
        groups += [tfr_scenario_name, start_converge_name, converge_speed_name]
    else:
        tfr_scenario_name, start_converge_name, converge_speed_name = 'tfr_scenario', 'start_converge', 'converge_speed'
    # copy data with indicators for what tfr is converging to (tfr_scenario) and at what year this convergence begins (start_converge)
    # if left False, set start_converge_t and converge_speed_t to equal the baseline scenario if we're converging for the 2nd time
    data = copy_data(data,tfr_scenarios,tfr_scenario_name)
    if converge_twice == True and start_converge == False:
      data[start_converge_name] = data['start_converge']
    else:
      data = copy_data(data,start_converge,start_converge_name)
    data = assign_converge_speeds(data,converge_speed_name,converge_speeds,groups,converge_twice)
    # converge sex ratio at birth to be the same as global average at the original year-end before we created new years. start converging at that year
    data = project_sex_ratio_birth(data,converge_twice)
    # keep only country-level data
    data = data.query('type == "country"')
    # if instructed, get 'replacement' TFR in each life_exp_max scenario and location that causes zero long-run population growth
    if 'replacement' in tfr_scenarios:
        data = get_replacement_tfr(data,tfr_scenario_name,groups)
    # all fertility rates and rates of change are the same after start_converge.
    data = project_converge_speed(data,start_converge_name,groups,converge_twice)
    # calculate future fertility rates by converging them to a given TFR
    data = converge_fertility(data,tfr_scenario_name,start_converge_name,groups)
    # ensure TFR stops moving once it reaches the given tfr target
    try:
        data = adjust_fertility(data,tfr_scenario_name,start_converge_name,groups)
    except:
        # if the "try" statement had an error, do nothing. 
        # this means we didn't exceed tfr_scenario in any location
        pass
    # re-label tfr_scenario for replacement to say "replacement" instead of a TFR number
    if 'replacement' in tfr_scenarios:
        data.loc[data[tfr_scenario_name].isin(tfr_scenarios) == False,tfr_scenario_name] = 'replacement'
    # drop unnecessary columns
    for col in ['pct_change','pct_change_world','sign','periods_since_start']:
        if col in data.columns:
            data.drop(col,axis=1,inplace=True)
    return data

# create copies of data for each converge_speed and assign each row a converge_speed at which TFR will change per period
def assign_converge_speeds(data,converge_speed_name,converge_speeds,groups,converge_twice):
    if converge_twice == True and converge_speeds == False:
      converge_speeds = data['converge_speed'].unique()
      data[converge_speed_name] = data['converge_speed']
    else:
      data = copy_data(data,converge_speeds,converge_speed_name)
    # if user passed these strings in speed_dict, convert to a number to use as the per-period pct_change
    try:
      speed_dict = {'very slow' : .01, 'slow' : .02, 'medium' : .03, 'fast' : .04, 'very fast' : 0.05}
      filter = data[converge_speed_name].isin(speed_dict)
      data.loc[filter, 'pct_change'] = data.loc[filter, converge_speed_name].map(speed_dict)
    except:
      pass
    # get absoltue value of converge speeds and exclude "NA" value if it exists
    # if the user did enter a converge_speed, then copy this data for each converge-speed given in the list. 
    numeric_speeds = [speed for speed in converge_speeds if type(speed) != str]
    if len(numeric_speeds) > 0:
      # all locations and age-groups will converge to tfr_scenario at the given converge_speed, except for default option where we already have pct_change
      data.loc[data[converge_speed_name].isin(numeric_speeds),'pct_change'] = data.loc[data[converge_speed_name] != 'NA',converge_speed_name]
    return data

# alter sex-ratio at birth (SRB) so all countries eventually have same SRB as UN predicted global avg SRB by 2100. 
def project_sex_ratio_birth(data,converge_twice):
  orig_end = 2100
  if converge_twice == False and data['year'].max() > orig_end:
    # increase/decrease each country's SRB by pct_change each period until they reach the target SRB
    pct_change = 0.005
    data['sex_ratio_world'] = data.query('location == "World" & year == @orig_end')['sex_ratio_birth'].array[0]
    data['sign'] = 1
    data.loc[data['sex_ratio_birth'] >= data['sex_ratio_world'],'sign'] = -1
    data['periods_since_start'] = (data['year'] - orig_end) / y_per_period
    already_converged = (data['sex_ratio_birth'] > (data['sex_ratio_world'] - pct_change)) & (data['sex_ratio_birth'] < (data['sex_ratio_world'] + pct_change))
    years = (data['year'] > orig_end)
    filters = years & (already_converged == False)
    # increase or decrease sex ratio at birth by pct_change every period until reaching the global average from 2100
    data.loc[filters,'sex_ratio_birth'] *= (1 + data.loc[filters,'sign'] * pct_change) ** data.loc[filters,'periods_since_start']
    too_low = (data['sign'] == -1) & (data['sex_ratio_birth'] < data['sex_ratio_world'])
    too_high = (data['sign'] == 1) & (data['sex_ratio_birth'] > data['sex_ratio_world'])
    filters = years & (too_low | too_high | already_converged)
    data.loc[filters,'sex_ratio_birth'] = data.loc[filters,'sex_ratio_world']
    data.drop(['sign','periods_since_start','sex_ratio_world'],axis=1,inplace=True)
  return data

# get 'replacement' TFR in each life_exp_max scenario and location that causes zero long-run population growth
def get_replacement_tfr(data,tfr_scenario_name,groups):
    # get the age at which women stop having children
    age_end_fert = data.query('fertility > 0 & fertility.notna()')['age'].max()
    replacement_data = data[data[tfr_scenario_name] == 'replacement'].drop(tfr_scenario_name,axis=1)
    # get data for ages below when women stop having children 
    # for the final year in the dataset (assumed to have the maximum life-expectancy)
    replacement_data_2 = replacement_data.query('age < @age_end_fert & year == year.max()')\
      [['life_exp_max','location','age','life_years_f','sex_ratio_birth']].drop_duplicates(subset=['life_exp_max','location','age']).groupby(['life_exp_max','location'])
    # life-years experienced by the average woman by age_end_fert
    life_exp = replacement_data_2['life_years_f'].sum()
    # probability of surviving to age_end_fert
    p_survive = life_exp / age_end_fert
    # percent of births that are female
    births_pct_f = 1 / (1 + replacement_data_2['sex_ratio_birth'].mean())
    # We need (1 / births_pct_f) in order to have enough females to make up for the fact that only females have children
    # and we need (1 / p_survive) females born in order for the average female born to replace her and the father
    replacement_tfr = 1 / (births_pct_f * p_survive)
    # set original dataset's tfr_scenario indicator to NaN so we can replace it with the new replacement-rate tfr_scenario when we merge
    replacement_tfr = replacement_tfr.reset_index().rename({0 : tfr_scenario_name},axis=1)
    replacement_data = pd.merge(replacement_data,replacement_tfr, how = 'left')
    # merge replacement data with the rest of the data, if there is any
    data = data[data[tfr_scenario_name] != 'replacement']
    if len(data) > 0:
      data = pd.concat([data,replacement_data],ignore_index=True)
    else:
      data = replacement_data
    return data.sort_values(by = groups + ['year']).reset_index(drop=True)

# all fertility rates and rates of change are the same after start_converge.
def project_converge_speed(data,start_converge_name,groups,converge_twice):
    # set fertility in periods after original year end to be the same. we'll update this soon
    start_data = data.loc[data['year'] == data[start_converge_name], groups + ['age','fertility','pct_change']]
    start_data.rename({'fertility':'fert_converge','pct_change':'pct_change_converge'},axis=1,inplace=True)
    data = pd.merge(data,start_data, how = 'left')
    converging = data['year'] > data[start_converge_name]
    data.loc[converging,'fertility'] = data.loc[converging,'fert_converge']
    data.loc[converging,'pct_change'] = data.loc[converging,'pct_change_converge']
    data.drop(['fert_converge','pct_change_converge'],axis=1,inplace=True)
    # if this is the 2nd time we're converging and pct_change is zero (we've already converged), set fertility pct_change to the world average.
    if converge_twice == True and 'pct_change_world' in data.columns:
        # set pct_change equal to the last year in which pct_change_world is greater than 0 
        pct_change_world = data.query('pct_change_world > 0')[groups + ['year','age','pct_change_world']]
        pct_change_world = pct_change_world.query('year == year.max()').drop('year',axis=1)
        data.drop('pct_change_world',axis=1,inplace=True)
        data = pd.merge(data,pct_change_world, how = 'left')
        data['pct_change'] = data[['pct_change','pct_change_world']].max(axis=1)
    return data

# calculate future fertility rates by converging them to a given TFR
def converge_fertility(data,tfr_scenario_name,start_converge_name,groups):
    # get total fertility rate for each location for each year
    tfr = y_per_period * data.groupby(groups + ['year'])['fertility'].sum()
    tfr = tfr.reset_index().rename({'fertility':'tfr'},axis=1)
    data = pd.merge(data,tfr, how = 'left')
    # create col "sign" indicating whether TFR is above vs below tfr_scenario during year start_converge
    # those above get sign of -1 so fertility will decrease while those below get sign of 1 so fertility will increase
    data.loc[(data['year'] > data[start_converge_name]) & (data['tfr'] > data[tfr_scenario_name]),'sign'] = -1
    data.loc[(data['year'] > data[start_converge_name]) & (data['tfr'] < data[tfr_scenario_name]),'sign'] = 1
    sign = data.query('sign.notna()').groupby(groups)['sign'].mean().reset_index()
    data.drop(['sign','tfr'],axis=1,inplace=True)
    data = pd.merge(data,sign, how = 'left')
    # each period on or after start_converge, this period's fertility rate = fertility rate at year of starting convergence multiplied by
    # the percent change per period raised to the power of the number of periods that have passed since starting convergence
    data['periods_since_start'] = (data['year'] - data[start_converge_name]) / y_per_period
    years = data['year'] > data[start_converge_name]
    data.loc[years,'fertility'] *= (1 + data.loc[years,'sign'] * data.loc[years,'pct_change']) ** data.loc[years,'periods_since_start']
    return data

# ensure TFR stops moving once it reaches the given tfr target
def adjust_fertility(data,tfr_scenario_name,start_converge_name,groups):
    # get total fertility rate for each location and period
    tfr = y_per_period * data.groupby(groups + ['year'])['fertility'].sum()
    tfr = tfr.reset_index().rename({'fertility':'tfr'},axis=1)
    data = pd.merge(data,tfr, how = 'left')
    # for each group get the year just after fertility should be fully converged (stops moving)
    tfr_too_high = (data['sign'] == 1) & (data['tfr'] > data[tfr_scenario_name])
    tfr_too_low = (data['sign'] == -1) & (data['tfr'] < data[tfr_scenario_name])
    converged = (tfr_too_high | tfr_too_low) & (data['year'] >= data[start_converge_name])
    year_stop = data.loc[converged].groupby(groups)['year'].min().reset_index().rename({'year' : 'year_stop'},axis=1)
    data = pd.merge(data,year_stop, how = 'left')
    # get the 'fert_stop', the age-specific fertility rates at year 'year_stop' for each location 
    # slightly decrease fert_stop so age-specific fertility rates exactly sum to tfr_scenario. 
    fert_stop = data.loc[converged & (data['year'] == data['year_stop'])].rename({'fertility' : 'fert_stop'},axis=1)
    # subtract from fert_stop the amount needed to have fert_stop TFR exactly equal target TFR
    fert_stop['fert_dif'] = abs(fert_stop['tfr'] - fert_stop[tfr_scenario_name])
    fert_stop['age_weights'] = fert_stop['fert_stop'] / fert_stop['tfr']
    fert_stop['fert_stop'] += (-1) * fert_stop['sign'] * fert_stop['fert_dif'] * fert_stop['age_weights']
    fert_stop = fert_stop[groups + ['age','fert_stop']]
    data = pd.merge(data,fert_stop, how = 'left')
    # set fertility equal to  'fert_stop' for each location and for all periods after converging beyond tfr_scenario
    tfr_too_high = (data['sign'] == 1) & (data['tfr'] > data[tfr_scenario_name])
    tfr_too_low = (data['sign'] == -1) & (data['tfr'] < data[tfr_scenario_name])
    converged = (tfr_too_high | tfr_too_low) & (data['year'] >= data[start_converge_name])
    data.loc[converged,'fertility'] = data.loc[converged,'fert_stop']
    data.drop(['tfr','year_stop','fert_stop'],axis=1,inplace=True)
    return data

# perform a population treatment / intervention
def treat_population(input_data,data,year_end,groups,treat_type,age_treat,year_treat,end_treat,mort_treatment, \
    start_LY_increase,LY_increment,tfr_scenarios_t,start_converge_t,converge_speeds_t):
    # format age_treat so we can treat the given age-groups
    if age_treat == 'all' or age_treat == ['all']:
        ages_treat = data['age'].unique()
    elif type(age_treat) == int:
        ages_treat = [age_treat]
    else:
        ages_treat = age_treat
    # if adding 1 person, don't perform the intervention yet. just create new rows for every age_treat and year_treat
    if treat_type == 'add 1 person':
      data = copy_data(data,ages_treat,'age_treat')
      data = copy_data(data,year_treat,'year_treat')
      treat_data = data
      treat_data['treatment'] = treat_type
      groups += ['treatment','age_treat','year_treat']
    # if changing mortality, generate new life year projections and merge new columns with baseline projection
    elif treat_type in ['target','update','pp','%']:
      treat_data = change_mortality(input_data,ages_treat,year_treat,end_treat,treat_type,mort_treatment)
      treat_data = project_life_years(treat_data,year_end,life_exp_max,start_LY_increase,LY_increment,treat_type)
      cols = ['life_years_f','life_years_m','dead_years_f','dead_years_m','p_survive_f','p_survive_m']
      d = data.drop(cols,axis=1)
      if 'mortality_group' in treat_data.columns:
        cols += ['mortality_group']
      treat_data = treat_data[['variant','life_exp_max','location','age_treat','year_treat','treatment','year','age'] + cols]
      treat_data = pd.merge(d,treat_data, how = 'left')
      groups += ['treatment','age_treat','year_treat']
    # if changing fertility, generate new projections given new fertility convergence assumptions. merge fertility with baseline projections
    elif treat_type == 'change TFR':
      treat_data = project_fertility(data,tfr_scenarios_t,start_converge_t,converge_speeds_t,converge_twice = True)
      treat_data['treatment'] = treat_type
      #data[['tfr_scenario_t','start_converge_t','converge_speed_t']].replace('NA',np.nan,inplace=True)
      groups += ['treatment','tfr_scenario_t','start_converge_t','converge_speed_t']
    # if none of the above are true, we didn't treat the population. return original data
    else:
      return data
    # merge treated columns with baseline columns in official dataframe. treated column names end with '_t'
    columns = ['pop_m','pop_f','fertility','life_years_f','life_years_m','births','deaths_f','deaths_m','p_survive_f','p_survive_m']
    all_cols = groups + ['year','age'] + columns
    if 'mortality_group' in treat_data.columns:
        all_cols += ['mortality_group']
    treat_data = treat_data[all_cols].rename({c : c + '_t' for c in columns},axis=1) 
    data = pd.merge(data,treat_data, how = 'left')
    # only add these columns if changing mortality
    if treat_type in ['target','update','pp','%']:
      for col in ['life_years_f','life_years_m','p_survive_f','p_survive_m']:
        data[col + '_t'] = data[[col, col + '_t']].max(axis=1)
    return data

# change mortality rates and get the resulting life-years lived (nLx) 
# select treat_type = either 'target' (decrease mortality to reach a given threshold), 'update' (immediately change mortality in given years and ages), 
    # '%' (decrease mortality in given years and ages by given percentage), or 'pp' (decrease mortality in given years and ages by given percentage points). 
def change_mortality(input_data,age_treat,start_treat,end_treat,treat_type,mort_treatment=False):
    data = input_data.copy(deep=True)
    # update mortality to a given mortality target in countries below the target in ages age_treat starting at start_treat and ending at end_treat
    if treat_type == 'target':
      data = target_mortality(data,start_treat,end_treat,age_treat,mort_treatment)
      treatment_label = 'decrease mortality to ' + str(mort_treatment)
    # interventions below change probability of survival in different ways and then recalculates life-years lived for each sex
    else:
        # get booleans to filter data
        years = data['year'].isin(list(range(start_treat, end_treat + y_per_period, y_per_period)))
        if age_treat == 'all' or age_treat == ['all']:
            ages = data['age'].isin(data['age'].unique())
        else:
            ages = data['age'].isin(age_treat)
        for sex in ['_f','_m']:
            p_survive = 'p_survive' + sex
            if treat_type == 'update':
                # assign the new survival rate as 1 - the given mortality rate
                # e.g. new_mortality = 0.01 would mean a 99% probability of survival
                data.loc[ages & years,p_survive] = (1 - mort_treatment)
                treatment_label = 'new mortality = ' + str(mort_treatment * 100) + '%'
            elif treat_type == 'pp':
                # increase probability of surviving by adding the percentage points to reduce mortality by
                # e.g. pct_pt_reduction = 5 would increase probability of survival by 5 percentage points (max of 100% total)
                data.loc[ages & years,p_survive] += (mort_treatment / 100)
                treatment_label = 'reduce mortality by ' + str(mort_treatment) + 'pp'
            elif treat_type == '%':
                # multiply mortality by a given percentage of current mortality, then turn new mortality into prob of survival (1- mortality)
                # e.g. reduction_factor = 0.5 would cut mortality in half.
                mortality = mort_treatment * (1 - data.loc[ages & years,p_survive])
                data.loc[ages & years,p_survive] = 1 - mortality
                treatment_label = 'reduce mortality by ' + str(mort_treatment * 100) + '%'
            # prob of survival is between 0 and 1
            data.loc[data[p_survive] > 1,p_survive] = 1
            data.loc[data[p_survive] < 0,p_survive] = 0
        # given the new survival rates, convert these into life-years lived
        data = convert_survival_to_life_years(data)
    # insert labels indicating the treatment type and its parameters and the years and ages treat
    if type(age_treat) == int:
        age_treat = age_treat
    elif len(age_treat) == 1:
        age_treat = age_treat[0]
    else:
        age_treat = str(age_treat)
    data['year_treat'], data['age_treat'], data['treatment'] = str(start_treat) + '-' + str(end_treat), age_treat, treatment_label
    return data

# update mortality to a given mortality target in countries below the target in ages age_treat starting at start_treat and ending at end_treat
def target_mortality(input_data,start_treat,end_treat,age_treat,mort_target):
  data = input_data.copy(deep=True).sort_values(by=['variant','location','year','age']).reset_index(drop=True)
  groups = ['variant','location','age']
  data, input_data = data.sort_values(by = groups + ['year']).reset_index(drop=True), input_data.sort_values(by = groups + ['year']).reset_index(drop=True)
  # converge mortality target into survival rate. 'countries' = list of country names which are below mort_target by year end_treat in ages age_treat
  p_survive_target = 1 - (mort_target / 1000)
  countries = data.query('year == @end_treat & age == @age_treat & p_survive_m < @p_survive_target')['location'].unique()
  # change female and male survival rates to reach the target
  for sex in ['_f','_m']:
    p_survive, population = 'p_survive' + sex, 'pop' + sex
    # get how far (as pct) each period should be from completing convergence.
    pct_converged = (data['year'] + y_per_period - start_treat) / (end_treat + y_per_period - start_treat)
    data['new_p_survive'] = data[p_survive] + pct_converged * (p_survive_target - data[p_survive])
    # new_p_survive is always at least as high as old p_survive
    data['new_p_survive'] = data[['new_p_survive',p_survive]].max(axis=1)
    # after converging, decrease mortality at speed of high mortality countries
    high_p_survive = data.query('location.isin(@countries) & age.isin(@age_treat)')
    high_p_survive = weighted_avg(high_p_survive, values = p_survive, weights = population, groups = ['variant','year','age']).drop_duplicates().rename({p_survive : 'p_survive_increment'},axis=1)
    high_p_survive['p_survive_increment'] = high_p_survive.groupby(['variant','age'])['p_survive_increment'].diff()
    data = pd.merge(data, high_p_survive, how = 'left')
    # increase new_p_survive by same increment each period until reaching target by year end_treat
    years = [y for y in data['year'].unique() if y > end_treat]
    data.set_index('year',inplace=True)
    for year in years:
      data.loc[year,'new_p_survive'] = data.loc[year - y_per_period, 'new_p_survive'].array + data.loc[year, 'p_survive_increment'].array
    data.reset_index(inplace=True)
    # prob survival must be between 0 and 1
    try:
      data.loc[data['new_p_survive'] > 1, 'new_p_survive'] = 1
      data.loc[data['new_p_survive'] < 0, 'new_p_survive'] = 0
    except:
      pass
    # set the new p_survive as the maximum of either the original p_survive or the new p_survive
    treat = (data['location'].isin(countries)) & (data['age'].isin(age_treat))
    data.loc[treat, p_survive] = data.loc[treat, ['new_p_survive', p_survive]].max(axis=1)
    # drop unnecessary columns. 
    data.drop(['new_p_survive','p_survive_increment'],axis=1,inplace=True)
  # convert from probability of survival into life-years
  data = convert_survival_to_life_years(data).drop_duplicates()
  # indicate whether each country was in the low (untreated) vs high (treated) mortality group
  data['mortality_group'] = 'Low-mortality countries'
  data.loc[data['location'].isin(countries), 'mortality_group'] = 'High-mortality countries'
  return data.sort_values(by = ['variant','location','year','age']).reset_index(drop=True)

# get weighted average of given values in dataframe
def weighted_avg(data, values, weights, groups):
    d = data.copy(deep=True)
    d[values] = d[values].multiply(d[weights], axis = 'index')
    if type(values) != list:
      values = [values]
    d = d.groupby(groups)[values + [weights]].sum().reset_index()
    d[values] = d[values].divide(d[weights], axis = 'index')
    return d.drop(weights, axis = 1)

# convert new survival rates into life-years lived (nLx)
def convert_survival_to_life_years(input_data):
    data = input_data.copy(deep=True).sort_values(by=['age','variant','location','year']).set_index('age')
    for sex in ['_f','_m']:
        for age in data.index.unique():
            if age == 0:
                # set hypothetical survivors at age 0 equal to 1
                data.loc[age,'survivors'] = 1
            else:
                # survivors beginning this age = survivors beginning last age * prob they survived to this age
                data.loc[age,'survivors'] = data.loc[age - y_per_period,'survivors'].array * data.loc[age - y_per_period,'p_survive' + sex].array
            d = data.loc[age]
            # prob of surviving * given number of years in this age-group interval
            LY_survivors = d['p_survive' + sex] * y_per_period
            # prob of death * given number of years (a(x,n)) that those who die typically live in this age-group
            LY_dead = (1 - d['p_survive' + sex]) * d['dead_years' + sex]
            # total life-years lived in this age-group = survivors at start of age * (life-yeard lived by survivors of this age + life-years lived by those who die this age-group)
            data.loc[age,'life_years' + sex] = d['survivors'] * (LY_survivors + LY_dead)
        data.drop('survivors',axis=1,inplace=True)
    # change data back into original sorting and format
    data = data.reset_index().sort_values(by=['variant','location','year','age']).reset_index(drop=True)
    return data

# transform fertility and life-years so we can multiply population by them to get births and survivors in each age-group
def transform_fert_mort(data,treat_control):
  age_max = data['age'].max()
  for treat in treat_control:
    for sex in ['_f','_m']:
      life_years, survival_ratio = 'life_years' + sex + treat, 'survival_ratio' + sex + treat
      # survival ratio = life-years in that age / life-years in the previous age
      data.loc[data['age'] != age_max,survival_ratio] = data.loc[data['age'] != 0,life_years].array / data.loc[data['age'] != age_max,life_years].array
      # survival ratio in oldest and sceond-oldest age-group = life-years in oldest age-group / (life-years in oldest age-group + life-years in 2nd oldest age-group)
      filter, filter2 = data['age'] == age_max, data['age'] == age_max - y_per_period
      oldest, oldest_2nd = data.loc[filter,life_years].array, data.loc[filter2,life_years].array
      data.loc[filter,survival_ratio] = data.loc[filter2,survival_ratio] = oldest / (oldest + oldest_2nd)
    # merge each row with the next age-group's survival ratio and fertility
    fertility, fertility_f, fertility_m, survival_ratio_f = 'fertility' + treat, 'fertility_f' + treat, 'fertility_m' + treat, 'survival_ratio_f' + treat
    data.loc[data['age'] != age_max,'survival_ratio_f_next_age'] = data.loc[data['age'] != 0,survival_ratio_f].array
    data.loc[data['age'] != age_max,'fertility_next_age'] = data.loc[data['age'] != 0,fertility].array
    # get the average number of births per woman between the current age-group and the next age-group, where only those who survive the period have children
    data[fertility_f] = (data[fertility] + data['survival_ratio_f_next_age'] * data['fertility_next_age']) / 2
    data.drop(['fertility_next_age','survival_ratio_f_next_age'],axis=1,inplace=True)
    # fertility in the oldest age-group is just its own fertility rate
    data.loc[data['age'] == age_max,[fertility_f,fertility_m]] = data.loc[data['age'] == age_max,fertility]
    # use same fertility rates for male and female births. adjust using sex ratio of males to females
    data[fertility_m] = data[fertility_f] * data['sex_ratio_birth'] / (data['sex_ratio_birth'] + 1)
    # adjust female births per woman based on sex ratio at birth
    data[fertility_f] /= (data['sex_ratio_birth'] + 1)
  return data

# project the population using previously projected fertility and mortality rates
def project_population(data,treat_control,groups,treat_type):
  age_max = data['age'].max()
  survivor_ages = data['age'].unique()
  survivor_ages = survivor_ages[survivor_ages != 0]
  year_max = data['year'].max()
  years = data['year'].unique()
  data = data.sort_values(by = groups).set_index(['year','age']).sort_index()
  mort_treatment = treat_type in ['target','update','pp','%']
  # indicate whether treatment == "add 1 person" (i.e. the user passed a number of age_treat)
  for year in years:
    if treat_type == 'add 1 person':
      # if instructed and the year is right, try adding 1 person to the treat population. Reset index so we can filter by year later
      data.loc[year] = add_1_person(data.loc[year],year)
    for treat in treat_control:
      for sex in ['_f','_m']:
        # project population for the next period using this period's data
        d = data.loc[year].copy(deep=True)
        # get col names for this sex and treatment scenario
        suffix = sex + treat
        pop, fertility, pop_saved, deaths, births, survival_ratio, life_years, pop_f \
        = 'pop' + suffix, 'fertility' + suffix, 'pop' + sex + '_saved', 'deaths' + suffix, 'births' + treat, 'survival_ratio' + suffix, 'life_years' + suffix, 'pop_f' + treat
        # annual births of this sex = number of births of this sex per woman * female population
        d[fertility] *= d[pop_f]
        # multiply annual births by years per period to get total births over period. Update official dataset's birth count
        data.loc[year,births] = d[births].array + d[fertility].array * y_per_period
        births = d.groupby(groups)[fertility].sum().reset_index(drop=True)
        # if this is baseline scenario and we're doing a mortality treatment or adding 1 person, track deaths as if we had treatment scenario mortality in a hypothetical
        # where the people whose deaths are averted do not enter the general population and do not have children. project their population's survivors separately.
        # this alows us to isolate the effect on deaths on population/life-years separately from indirect births
        if treat == '' and (mort_treatment == True or treat_type == 'add 1 person'):
          data = project_deaths_averted(data,d,births,year,year_max,age_max,groups,sex)
        # get number of people who survive from beginning to the end of the period. Increase deaths in the official dataset by population minus survivors
        d[pop] *= d[survival_ratio]
        data.loc[year,deaths] = d[deaths].array + (data.loc[year,pop].array - d[pop].array)
        # get number of people who both were born (annual births * years per period) and died this period. save in "deaths"
        # births that died = total births (annual births * years in each age-group) - survived births (annual births * life-years of age 0-5)
        LY = d.loc[0,life_years].array
        data.loc[(year,0),deaths] = data.loc[(year,0),deaths] + births.array * (y_per_period - LY)
        # next period's population is an array containing births (ages 0-5) plus survivors (now ages 5-100+)
        if year != year_max:
          # multiply annual births by life-years in 0-5 age-group to get number of births that survived to next period
          births *= LY
          # increase age by y_per_period since they're getting this much older each period
          survivors = d.query('age != @age_max')
          # add survivors from age 100+ to itself since some of them survived
          survivors.loc[age_max - y_per_period,pop] += d.loc[age_max,pop].array
          # update official dataset's population. births = youngest age-groups. survivors = all other age-groups
          data.loc[year + y_per_period,pop] = pd.concat([births,survivors[pop]],ignore_index=True).array
  return data.reset_index()

# add 1 person to the population at the given year and age-group(s) in all countries. split between males and females
def add_1_person(data,year):
  if year in data['year_treat'].unique():
    # only treat when age is age_treat and year is year_treat.
    treat = (data.index == data['age_treat'])
    d = data.loc[treat]
    # increase population at the specified age-group by the expected number of people added in that sex.
    # assume we add 1 person whose chance of being each sex is that sex's percent of total population.
    # also track this additional person in a separate col for each sex where we'll project just that person's survival forward assuming they have no children.
    total_pop = d['pop_f_t'] + d['pop_m_t']
    for sex in ['_f','_m']:
      additional_people = d['pop' + sex + '_t'] / total_pop
      for pop in ['pop' + sex + '_t', 'pop' + sex + '_saved']:
        data.loc[treat,pop] += additional_people
  data = data.reset_index()
  data['year'] = year
  return data.set_index(['year','age'])
    
# returns dataset where we track this period's deaths as if we had treatment scenario mortality in a hypothetical
# where the people whose deaths are averted do not enter the general population and do not have children. project their population's survivors separately.
# this alows us to isolate the effect on deaths on population/life-years separately from indirect births
def project_deaths_averted(data,d,births,year,year_max,age_max,groups,sex):
  pop, pop_saved, survival_ratio_t, survival_ratio_c, life_years_t, life_years_c = 'pop' + sex, 'pop' + sex + '_saved', 'survival_ratio' + sex + '_t', 'survival_ratio' + sex, 'life_years' + sex + '_t', 'life_years' + sex
  # get difference in survival ratios between treatment and baseline scenario and multiply by 
  # population in the control scenario to get the difference in survivors = deaths averted over the period
  d['deaths_averted'] = d[pop] * (d[survival_ratio_t] - d[survival_ratio_c])
  # ages that aren't supposed to be treat should have 0 deaths averted
  d.loc[d.index != d['age_treat'],'deaths_averted'] = 0
  # get deaths averted for children born during this period = births survived using treatment life-years minus births survived using control life years
  d0 = d.loc[0]
  births_saved = births * (d0[life_years_t].array - d0[life_years_c].array)
  if year != year_max:
    # get number of people who survive from beginning to the end of the period. We won't count the deaths averted here to avoid double-counting.
    d[pop_saved] *= d[survival_ratio_t]
    # add in deaths averted over this period to the population of people whose deaths were averted
    d[pop_saved] += d['deaths_averted']
    # increase age by y_per_period since they're getting this much older each period
    survivors = d.query('age != @age_max')
    # add survivors from age 100+ to itself since some of them survived
    survivors.loc[age_max - y_per_period,pop_saved] += d.loc[age_max,pop_saved].array
    # pop_saved = stock of people still alive next period whose deaths were averted by the treatment
    data.loc[year + y_per_period,pop_saved] = pd.concat([births_saved,survivors[pop_saved]],ignore_index=True).array
  # update deaths_averted in the official dataset for the CURRENT period. count both deaths averted 
  # for those already alive at beginning of period as well as those born during the period (births_saved, which will go into this period's 0-5 age-group)
  d.loc[0,'deaths_averted'] = d0['deaths_averted'].array + births_saved.array
  data.loc[year,'deaths_averted'] = data.loc[year,'deaths_averted'].array +  d['deaths_averted'].array
  return data

# combine sex-specific data into a single-sex
def combine_sexes(data,treat_control):
  for treat in treat_control:
    fertility,births,deaths,deaths_f,deaths_m,population,pop_f,pop_m,years_lived,life_years,life_years_f,life_years_m,p_survive_f,p_survive_m,mortality \
      = 'fertility'+treat,'births'+treat,'deaths'+treat,'deaths_f'+treat,'deaths_m'+treat,'population'+treat,'pop_f'+treat,'pop_m'+treat,'years_lived'+treat,'life_years'+treat,\
        'life_years_f'+treat,'life_years_m'+treat,'p_survive_f'+treat,'p_survive_m'+treat,'mortality'+treat
    # multiply that age-group's fertility by years per period to get average births for each age-interval
    data[fertility] *= y_per_period
    # convert all numerical columns to float so we can sum later
    for col in ['pop_f_saved','pop_m_saved','deaths_averted',fertility,births,deaths,deaths_f,deaths_m,population,pop_f,pop_m, \
      years_lived,life_years,life_years_f,life_years_m,p_survive_f,p_survive_m,mortality]:
      if col in data.columns:
        data[col] = data[col].astype(float)
    # combine sex data
    data[population] = data[pop_f] + data[pop_m]
    data['population_saved'] = data['pop_f_saved'] + data['pop_m_saved']
    data[deaths] = data[deaths_f] + data[deaths_m]
    # count years of life lived over each period by the whole population. 
    # Survivors live the total period # of years. Those who die live for 'dead_years' number of years
    YL_survivors = y_per_period * (data[population] - data[deaths])
    YL_dead = data[deaths_f] * data['dead_years_f'] + data[deaths_m] * data['dead_years_m']
    data[years_lived] = YL_survivors + YL_dead
    # get annual mortality rate per 1000 people and life-expectancy by dividing total deaths over the period by years per period * population at period's start
    # we cannot divide by 0, so only apply this to non-zero populations
    d = data.loc[data[population] > 0]
    data[life_years] = 0
    pct_f, pct_m = d[pop_f] / d[population], d[pop_m] / d[population]
    # combine life-years lived (nLx) and mortality rates for both sexes using percent of pop in each sex.
    data.loc[data[population] > 0, life_years] = pct_f * d[life_years_f] + pct_m * d[life_years_m]
    data[mortality] = pct_f * data[p_survive_f] + pct_m * data[p_survive_m]
    # calculate mortality for years after 2100 by re-arranging formula L = l * (5p + aq) to solve for 1 - "p". 
    # calculations below only accurate for youngest age-group.
    data['dead_years'] = pct_f * data['dead_years_f'] + pct_m * data['dead_years_m']
    data['new_mort'] = (data[life_years] - data['dead_years'])  / (y_per_period - data['dead_years'])
    data.loc[data['year'] > 2100, mortality] = data.loc[data['year'] > 2100, 'new_mort']
    data.drop(['dead_years','new_mort'],axis=1,inplace=True)
    data[mortality] = 1000 * (1 - data[mortality])
    data.loc[data[mortality].notna() == False, mortality] = 0
  return data

# combine age-specific data into a single-age
def combine_ages(data,groups,treat_control,keep_data):
  groups_2 = groups + ['year']
  data = data.sort_values(by=groups_2 + ['age']).reset_index(drop=True)
  categorical_cols = ['location','code','id','type','region','subregion','sdg_region','income_group','mortality_group']
  categorical = data[[c for c in categorical_cols if c in data.columns]].drop_duplicates()
  for treat in treat_control:
    # TBD
    fertile_females = data.query('fertility > 0')[groups_2 + ['age','pop_f'+treat]].rename({'pop_f'+treat:'fertile_females'+treat},axis=1)
    data = pd.merge(data,fertile_females,how='left')
  # TBD
  sum_cols = []
  for var in ['fertility','births','deaths','population','years_lived','life_years','deaths_averted','population_saved','fertile_females']:
    for treat in treat_control:
      if var + treat in data.columns:
        sum_cols += [var + treat]
  # aggregate age-groups into a single-year sum across all ages
  age_sum = data.groupby(groups_2)[sum_cols].sum()
  # get the youngest age's mortality rate in each treatment group
  for treat in treat_control:
    mortality_rates = data.query('age == 0').set_index(groups_2)['mortality'+treat]
    age_sum = age_sum.join(mortality_rates)
  age_sum['age'] = 'all'
  age_sum.reset_index(inplace=True)
  # either keep or drop age-specific data
  if 'age' in keep_data:
    data = data[[c for c in age_sum.columns if c in data.columns]]
    data = pd.concat([data,age_sum],ignore_index=True)
  else:
    data = age_sum
  for treat in treat_control:
    # create new column "life_exp" that gives life-expectancy at birth (sum of nLx across age-groups), separate from life_years
    data['life_exp'+treat] = data['life_years'+treat]
    # set some age-specific data to NaN since it won't be accurate or useful to calculate
    data.loc[data['age'] != 'all', ['mortality'+treat,'life_exp'+treat,'births'+treat]] = np.nan
    data.loc[data['age'] == 'all', 'life_years'+treat] = np.nan
  # merge each location's categorical data back onto age_sums
  return pd.merge(data,categorical,how='left')

# get difference between treated and untreated scenario columns. new differenced columns end in '_dif'
def get_scenario_difs(data):
  for col in ['births','deaths','years_lived','population','mortality','life_exp','fertility']:
    if col + '_t' in data.columns:
      data[col + '_dif'] = data[col + '_t'] - data[col]
  return data

# aggregate country-level data into regional, region, income-group, and global data
def combine_countries(data,treat_control,groups,treat_type,tfr_scenarios_t,keep_data):
  data = data.sort_values(by=groups + ['year','age'])
  countries = data.copy(deep=True)
  groups.remove('location')
  # drop columns that don't change based on the treatment type
  if '_t' in treat_control:
    if treat_type == 'change TFR':
      drop_cols = ['deaths_averted','population_saved','mortality_t']
    elif treat_type == 'add 1 person':
      drop_cols = ['deaths_averted'] + [c for c in data.columns if c[-2:] == '_t']
    elif treat_type in ['target','update','pp','%']:
      drop_cols = ['fertility_t']
    data.drop(drop_cols,axis=1,inplace=True)
  # get columns to aggregate. avg_cols will be averaged, sum_cols will be summed, 
  # and dif_cols will either be averaged or summed depending on treatment type
  avg_cols, dif_cols, sum_cols, all_cols = [], [], [], []
  data_cols = data.columns
  for var in ['births','deaths','deaths_averted','population_saved','years_lived','population','mortality','life_exp','life_years']:
    for suffix in ['','_t','_dif']:
      col = var + suffix
      if col in data_cols:
        all_cols += [col]
  for col in np.unique(all_cols):
    if '_%' in col or 'life_exp' in col or 'mortality' in col or 'life_years' in col:
      avg_cols += [col]
    elif col == 'deaths_averted' or col == 'population_born' or col == 'population_saved' or '_dif' in col:
      dif_cols += [col]
    else:
      sum_cols += [col]
  treat_avg, control_avg = [c for c in avg_cols if '_t' in c], [c for c in avg_cols if '_t' not in c]
  # generate a new aggregation for LMICs that didn't exist before
  countries.loc[countries['income_group'] != 'High-income countries','LMIC'] = 'Low- and middle-income countries'
  # aggregate countries by into regional/income categories
  if 'regions' in keep_data:
    region_cols = ['world','region','subregion','sdg_region','income_group','LMIC']
    if 'mortality_group' in data.columns:
      region_cols += ['mortality_group']
  else:
    region_cols = ['world']
  for region in region_cols:
    if region == 'world':
      groups_2 = groups + ['year','age']
    else:
      groups_2 = groups + [region,'year','age']
    # get sum across countries and merge with averages
    region_data = countries.groupby(groups_2)[sum_cols].sum()
    for treat in treat_control:
      # calculate average TFR in the region, weighted by population of females of reproductive age in each country
      fertile_females = weighted_avg(data = countries, values = 'fertility' + treat, weights = 'fertile_females' + treat, groups = groups_2).set_index(groups_2)
      region_data = region_data.join(fertile_females)
      # if we perfomed a treatment, take regional averages of treatment effects for 
      # treatment == "add 1 person" and take sums for everything else
      if treat == '_t':
        avg_cols = treat_avg
        if treat_type == 'add 1 person':
          diffs = weighted_avg(data = countries, values = dif_cols, weights = 'population_t', groups = groups_2).set_index(groups_2)
        else:
          diffs = countries.groupby(groups_2)[dif_cols].sum()
        region_data = region_data.join(diffs)
      else:
        avg_cols = control_avg
      # average the given demographic columns
      avgs = weighted_avg(data = countries, values = avg_cols, weights = 'population' + treat, groups = groups_2).set_index(groups_2)
      region_data = region_data.join(avgs)
    region_data.reset_index(inplace=True)
    # create / rename columns indicating the location and type
    region_data['type'] = region
    if region == 'world':
      region_data['location'] = 'World'
    else:
      region_data['location'] = region_data[region]
      if region == 'LMIC':
        region_data['type'] = 'income_group'
        region_data['income_group'] = 'Low- and middle-income countries'
        region_data.drop('LMIC',axis=1,inplace=True)
    # concatenate regional data with country-level data
    data = pd.concat([data,region_data],ignore_index=True)
  return data

# get cummulative sums of outcomes and percent-differences between treatment and control groups
def get_cum_pct_difs(data,treat_type,groups,keep_data):
  # get cumulative sum of each outcome across years
  if 'cum' in keep_data:
    variables, orig_cols, cum_cols = ['births','deaths','years_lived','deaths_averted'], [], []
    for treat in ['','_t','_dif']:
      orig_cols += [v + treat for v in variables if v + treat in data.columns]
      cum_cols += [v + '_cum' + treat for v in variables if v + treat in data.columns]
    data[cum_cols] = data.groupby(groups + ['location','age'])[orig_cols].apply(lambda col: col.cumsum())
  # get percent difference between treated and baseline scenario
  if 'dif' in keep_data and 'pct_dif' in keep_data and treat_type != 'add 1 person' and treat_type != False:
    for col in ['births','deaths','years_lived','population']:
      for cum in ['','_cum']:
        dif, pct_dif, treat, control = col + cum + '_dif', col + cum + '_%_dif', col + cum + '_t', col + cum
        if treat in data.columns:
          data[pct_dif] = 100 * data[dif] / data[control]
  # if added 1 person or treated mortality, add column indicating the population of additional births caused by the intervention who are still alive in each period
  if treat_type == 'add 1 person' or treat_type in ['converge','target','update','pp','%']:
    data['population_born'] = data['population_dif'] - data['population_saved']
  return data

# re-arrange columns, drop columns and sort rows
def arrange_columns(data,scenarios,keep_data):
  # replace 0 values with NaN
  data.replace(0, np.nan, inplace = True)
  for col in ['age','age_treat']:
    if col in data.columns:
      data[col].replace(np.nan, 0, inplace = True)
  # select columns to keep and in what order to arrange them
  numerical_cols, categorical_cols = [], scenarios + ['location','year','age','code','type','region','subregion','sdg_region','income_group']
  # remove categorical columns if instructed
  if 'categorical' not in keep_data:
    cols = [c for c in categorical_cols if c not in ['variant','code','type','region','subregion','sdg_region','income_group']]
    categorical_cols = cols
    scenarios.remove('variant')
  # re-arrange columns
  for suffix in ['','_cum','_t','_cum_t','_dif','_cum_dif','_%_dif','_cum_%_dif']:
    for var in ['fertility','mortality','life_exp','life_years','population','births','deaths','years_lived','deaths_averted','population_saved','population_born']:
      # drop columns that 1. are not in the dataframe, and 2. only contain NaN or 0 values (except for age_treat since that can be 0)
      col = var + suffix
      if col in data.columns and len(data[col].unique()) > 1:
        numerical_cols += [col]
  # only include child-mortality if we altered mortality
  if 'deaths_averted' not in numerical_cols:
    numerical_cols.remove('mortality')
  # round population, births, deaths and years lived to nearest whole number, unless taking the difference between worlds. Round everything else to nearest thousandth
  data[numerical_cols] = data[numerical_cols].apply(lambda col: col.round(0) if \
    (('population' in col or 'births' in col or 'deaths' in col or 'years_lived' in col) and '_dif' not in col) else col.round(3))
  return data[categorical_cols + numerical_cols].sort_values(by = scenarios + ['location','year','age']).reset_index(drop=True)

# Replicate Our Projections
Replicate the projections and figures in our paper on long-term fertility scenarios.

## Main Paper
Replicate the projections from the main portion of our paper. Figures are generated separately using these files run through a stata "do" file in this program's GitHub folder.

In [None]:
tfr_scenarios = [i / 10 for i in range(10,20,1)] + ['replacement',1.66,1.85]
year_end = 3000
main_output = project_the_universe(input_data,year_end=year_end,tfr_scenarios=tfr_scenarios,keep_data=['age','cum'])
main_output.query('location == "World"')[['tfr_scenario','year','age','fertility','life_exp','population','births','births_cum']].to_csv('main_output.csv', index=False)
rebound_output = project_the_universe(input_data,year_end=year_end,tfr_scenarios=tfr_scenarios,tfr_scenarios_t=['replacement'],start_converge_t=[2125,2150,2175],keep_data=['age'])
rebound_output.query('location == "World" & age == "all"')[['tfr_scenario','start_converge_t','year','population_t','births_t']].to_csv('rebound_output.csv', index=False)

In [None]:
tfr_scenarios = [i / 10 for i in range(10,20,1)] + ['replacement',1.66,1.85]
rebound_output = project_the_universe(input_data,year_end=year_end,tfr_scenarios=tfr_scenarios,life_exp_max=life_exp_max,tfr_scenarios_t=['replacement'],start_converge_t=[2125,2150,2175],keep_data=['age'])
rebound_output.query('location == "World" & age == "all"')[['tfr_scenario','start_converge_t','year','population_t','births_t']].to_csv('rebound_output_2.csv', index=False)

## Appendix
Replicate the figures from the appendix of our paper.

In [None]:
year_end = 3000
keep_data = []
tfr_scenarios = [i / 10 for i in range(10,20,1)]

# TFR converges to various asymptotic TFRs at different speeds from 2025-3000. 
print('main')
main = project_the_universe(input_data,year_end=year_end,tfr_scenarios=tfr_scenarios + ['replacement'],converge_speeds=['very slow','medium','very fast'],keep_data=keep_data)
main['figure'] = 'main'



In [None]:
# TFR "rebounds" to replacement after falling below various tfr_scenarios
print('rebound')
rebound = project_the_universe(input_data,year_end=2400,tfr_scenarios=tfr_scenarios,tfr_scenarios_t=['replacement'],start_converge_t=[2125,2150,2175],keep_data=keep_data)
rebound['figure'] = 'rebound'

In [None]:
# life-expectancy scenarios
print('life_exp')
life_exp = project_the_universe(input_data,year_end=year_end,tfr_scenarios=[1.2,1.66,1.8],life_exp_max=[100,120],keep_data=keep_data)
life_exp['figure'] = 'life_exp'

In [None]:
# age-specific mortality and fertility
print('ages')
ages_all = project_the_universe(input_data,year_end=2400,tfr_scenarios=[1.66],keep_data=['age'])
ages = ages_all.query('location.isin(["World","India","Nigeria","Republic of Korea"])')
ages['figure'] = 'ages'


# clean up and combine the data
# i.e. concatenate dataframes together, filter to only include global data for some dataframes, convert TFR scenarios to float, arrange columns, save as CSV
print('putting it together')
appendix_all = pd.concat([main,rebound,life_exp],ignore_index=True)
appendix = appendix_all.query('location == "World" & age == "all"')
appendix = pd.concat([appendix,ages],ignore_index=True)
appendix.loc[appendix['tfr_scenario'] == 'replacement', 'tfr_scenario'] = 2.05
appendix.tfr_scenario = appendix.tfr_scenario.astype(float)
appendix = appendix[['figure','life_exp_max','tfr_scenario', 'converge_speed', 'start_converge_t',
       'year', 'age', 'life_exp', 'life_years', 'fertility', 'fertility_t',
       'population', 'births', 'deaths', 'years_lived', 'population_t',
       'births_t', 'deaths_t', 'years_lived_t']]
appendix.to_csv('appendix_output.csv', index=False)

# Metadata
Below we describe what all the columns in the output dataset mean. Some columns are only included in the projections output data depending on what parameters are passed in (e.g. treatment type)

### Categorical Columns
**variant**: UN WPP variant (we use "Zero migration" by default)

**life_exp_max**: maximum life-expectancy the scenario can reach

**tfr_scenario**: asymptotic total fertility rate (TFR) which all countries converge to in the long-run

**start_converge**: year when TFR started converging to tfr_scenario (default to final year of WPP data)

**converge_speed**: per-period rate of change of TFR between time start_converge and when tfr_scenario is reached

**treatment**: indicates which treatment was performed (add 1 person, change mortality, or change fertility)

**year_treat**: year(s) when treatment occcurred

**age_treat**: age-groups that were treated

**location**: name of country or area (using UN names)

**year**: starting year of the period

**age**: starting age of the age-group. 'all' indicates a sum/average across all age-groups

**code**: 3-letter ISO alpha code by UN

**type**: what type of location this is ('country','region','subregion','sdg_region','income_group', or 'world')

**region**: UN region

**subregion**: UN subregion

**sdg_region**: Sustainable Development Goals (SDG) region

**income_group**: income status of the country according to UN (high, medium, low)

**mortality_group**: indicates whether or not the location was above or below the mortality target at the time of treatment starting in the target age-group (applies for mortality treatment only)

### Numerical Columns

**fertility**: births per woman

**p_survive (p)**: probability of someone at the start of the age-group surviving the next 5 years

**dead_years (a)**: average years lived by people who die within the age-group

**life_years (L)**: years someone would have lived over the age-group if they experienced the mortality rates of that period from birth until the end of the given age

**mortality**: deaths per 1000 births over the next 5 years

**life_exp (e0)**: life-expectancy at birth. the sum of life_years (L)

**population (or 'pop')**: people alive at start of the period

**births**: total births over the period

**deaths**: total deaths over the period

**years_lived**: total years of life lived oevr the period ( = roughly population * years per period)

**deaths_averted**: number of deaths that would have been prevented in the baseline scenario had the population been exposed to the treated scenario's mortality rates

**population_saved**: population of deaths_averted who are still alive at start of the period

**population_born**: population of additional people born due to the treatment who are still alive at start of the period


### Column Suffixes
**_f or _m**: female or male version of the variable

**_cum**: cumulative count of the variable across time within the location/scenario

**_t**: distinguishes variables in the 'treated' scenario from the 'untreated'/baseline scenario that don't have a '_t'

**_dif**: difference between treated and untreated scenario (e.g. births_dif = births_t - births)

**_%_dif**: difference between treated and untreated scenario (e.g. births_dif = (births_t - births) / births)

**_cum_dif**: difference between cumulative version of treated and untreated scenario

**_cum_%_dif**: percent difference between cumulative version of treated and untreated scenario
