## Convert deaths calculated per grid cell to deaths by country

In [1]:
import geopandas as gpd
import xarray as xr
from cartopy import crs as ccrs
import seaborn as sns; sns.set_theme()
import os
import fiona
import country_converter as coco
import dask
import dask.array as da
import netCDF4 as nc
import regionmask
from matplotlib import cm
import numpy as np
from matplotlib import pyplot as plt
import country_converter as coco
import pyogrio
#pyogrio.set_gdal_config_options({"SHAPE_RESTORE_SHX": "YES"})
import pandas as pd
from cartopy.util import add_cyclic_point
import nc_time_axis
import glob
import cdo
import pandas as pd
import cartopy.feature as cfeature


In [2]:
import warnings
# Suppress all warnings for the rest of the session to keep cell output clean.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

In [3]:
def global_mean_xarray(ds_XXLL):
    """
    Area-weighted global mean of a field defined on a lat/lon grid.

    Each latitude row is weighted by cos(lat), normalised so that the
    weights average to one. The input should not contain NaN.

    Parameters
    ----------
    ds_XXLL   : xarray object with ``lat`` and ``lon`` among its
                dimensions. ``ds_XXLL.lat`` (in degrees) supplies the
                area weights.

    Returns
    ----------
    xarray object with ``lat`` and ``lon`` averaged out.

    """
    # cos(lat) is proportional to the grid-cell area on a regular grid
    cos_lat = np.cos(np.deg2rad(ds_XXLL.coords['lat']))
    area_weight = cos_lat / np.mean(cos_lat)

    # Average around each latitude circle first, then across latitudes
    zonal_mean = ds_XXLL.mean(dim=['lon'])
    return (zonal_mean * area_weight).mean(dim=['lat'])

def weighted_temporal_mean_l(ds, var=None):
    """
    Annual mean in which every time step is weighted by the number of days
    in its month.

    Parameters
    ----------
    ds   : xarray object with a datetime-like ``time`` coordinate
           (``ds.time.dt.days_in_month`` supplies the weights).
    var  : name of the variable to average; if None, ``ds`` itself is
           averaged.

    Returns
    ----------
    xarray object with one value per year along ``time``. Missing values
    are left out of both the weighted sum and the sum of weights.

    Raises
    ----------
    AssertionError if the weights within any year do not add up to 1.
    """
    # Days in the month of each time step, grouped by calendar year
    month_length = ds.time.dt.days_in_month
    by_year = month_length.groupby("time.year")

    # Fraction of its year that each time step represents
    wgts = by_year / by_year.sum()

    # Make sure the weights in each year add up to 1
    np.testing.assert_allclose(wgts.groupby("time.year").sum(xr.ALL_DIMS), 1.0)

    # Subset our dataset for our variable
    obs = ds if var is None else ds[var]

    # 0 where the data are missing, 1 elsewhere, so the denominator only
    # counts the weights of time steps that contributed to the numerator
    cond = obs.isnull()
    ones = xr.where(cond, 0.0, 1.0)

    # "YS" (year start) is used instead of the alias "AS", which pandas has
    # deprecated; both label each annual bin at the start of the year.
    obs_sum = (obs * wgts).resample(time="YS").sum(dim="time")
    ones_out = (ones * wgts).resample(time="YS").sum(dim="time")

    # Return the weighted average
    return obs_sum / ones_out

In [4]:
# Country polygons: Gridded Population of the World, Version 4 (GPWv4):
# National Identifier Grid, Revision 11.
# Download: https://www.earthdata.nasa.gov/data/catalog/sedac-ciesin-sedac-gpwv4-natiden-r11-4.11
# Citation: Center For International Earth Science Information Network-CIESIN-Columbia University. (2018). Gridded Population of the World, Version 4 (GPWv4): National Identifier Grid, Revision 11 (Version 4.11) [Data set]. Palisades, NY: NASA Socioeconomic Data and Applications Center (SEDAC). https://doi.org/10.7927/H4TD9VDP
shapefile_path = "./gpw-v4-national-identifier-grid-rev11_15_min_asc/gpw_v4_national_identifier_grid_rev11_15_min.shp"

# Read the shapefile into a GeoDataFrame; its ISOCODE column is used later
# as the country label and its row order as the country index.
gdf = gpd.read_file(shapefile_path)
#gdf_list = pd.read_csv(f'./Mortality_data/country_borders/gpw_new.csv')

In [6]:
# Country-level mortality table; its ISO3 column decides which countries are
# aggregated further down.
mortality = pd.read_csv('./IFs/mortality_all_new.csv')
#mortality_og = pd.read_csv(f'./Mortality_data/IFs/mortality_all_new.csv')

# Gridded baseline mortality rates (BMR) by cause and sex. Every file matching
# the pattern is opened lazily and the files are concatenated along "time".
# NOTE(review): a comment in this cell read "mortality per thousand" —
# presumably the unit of these rates; confirm against the input files.
_bmr_open_kwargs = dict(parallel=True, combine='nested', concat_dim='time')

#mortality = xr.open_mfdataset(f'/glade/derecho/scratch/cindywang625/BMR/*.nc', parallel=True, combine='nested', concat_dim='time')['mortality']
mortality_cardio_female = xr.open_mfdataset('./bmr/Cardio/female/*.nc', **_bmr_open_kwargs)['mortality']
mortality_cardio_male = xr.open_mfdataset('./bmr/Cardio/male/*.nc', **_bmr_open_kwargs)['mortality']

mortality_resp_female = xr.open_mfdataset('./bmr/Resp/female/*.nc', **_bmr_open_kwargs)['mortality']
mortality_resp_male = xr.open_mfdataset('./bmr/Resp/male/*.nc', **_bmr_open_kwargs)['mortality']

# Earlier sex-combined inputs, kept for reference:
#mortality_cardio = xr.open_mfdataset(f'/glade/derecho/scratch/cindywang625/BMR/for_o3/cardio/*.nc', parallel=True, combine='nested', concat_dim='time')['mortality']
#mortality_resp = xr.open_mfdataset(f'/glade/derecho/scratch/cindywang625/BMR/for_o3/resp/*.nc', parallel=True, combine='nested', concat_dim='time')['mortality']

# Demographic data for 2020, one file per sex (variable "demo_pop")
demographic_female = xr.open_mfdataset('./demographic/countries/female_2020_demo_frac.nc')['demo_pop']
demographic_male = xr.open_mfdataset('./demographic/countries/male_2020_demo_frac.nc')['demo_pop']

#demographic = demographic.sum("age")

# SSP2 2020 population file; only its lat/lon grid is used, to build the
# country mask.
population = xr.open_dataset('./population/SSP2/Total/NetCDF/ssp2_2020.nc')


In [7]:
# Build the country mask on the population grid: one lat/lon layer per
# polygon (row) of gdf.
lon = population['lon']
lat = population['lat']
print(lat.shape)
# drop=False keeps a layer for every row of gdf, even for a polygon that
# covers no grid point. The aggregation loops index mask[i] by the row
# position in gdf, so the layers must stay aligned one-to-one with those
# rows; the default (drop=True) would silently remove empty layers and
# shift every later index.
mask = regionmask.mask_3D_geopandas(gdf, lon, lat, drop=False)

(1117,)


In [50]:
# Replace the time coordinate of the BMR arrays with the one from deaths_val.
# NOTE(review): `mortality_cardio` and `mortality_resp` have no active
# definition in the visible notebook — the only assignments to them are
# commented out in the data-loading cell, which creates the *_female / *_male
# arrays instead. Run from a fresh kernel this cell would raise NameError;
# confirm whether it is stale or should target the per-sex arrays.
# `deaths_val` comes from the cell numbered In [36], which appears further
# down in the file.
mortality_cardio['time'] = deaths_val['time'].values
mortality_resp['time'] = deaths_val['time'].values

### Calculating health impacts associated with air pollution (PM2.5)

Calculate mortalities attributed to PM2.5 from noncommunicable diseases and lower respiratory infections (combined), based on epidemiological cohort studies of long-term exposure to PM2.5:

\begin{equation}
M = BMR \times P \times AF
\end{equation}

- M = mortality due to PM2.5
- BMR = baseline mortality rate
- P = exposed population
- AF = attributable fraction (the share of baseline deaths attributable to the exposure)

In [36]:
# Ensemble member to process; it selects both the gridded input file here and
# the by-country output file written later.
#ensemble = '003'
ensemble = '010'
# Gridded deaths for this ensemble member, opened lazily.
# NOTE(review): '__xarray_dataarray_variable__' is presumably the default name
# xarray gave the variable when an unnamed DataArray was saved — confirm.
deaths_val = xr.open_mfdataset(f'./deaths_ozone/arise-sai-1.0/by_grid/{ensemble}.nc')['__xarray_dataarray_variable__']

# Scenario labels (presumably alternatives to the 'arise-sai-1.0' folder used
# above — confirm):
#arise_2035
#arise_ssp
#ssp_2035

In [37]:
# Output table of country totals: one row per time step of deaths_val and one
# column per row of gdf, labelled with its ISOCODE. It starts as all-NaN so
# that countries never filled in stay missing rather than zero.
# The time length is taken from deaths_val instead of being hard-coded, so the
# table always matches the "time" coordinate attached to it.
mortality_value = xr.DataArray(
    np.full((deaths_val.sizes["time"], len(gdf['ISOCODE'])), np.nan),
    dims=["time", "country"],
    coords={"time": deaths_val.time, "country": gdf['ISOCODE']}
)

In [38]:
# Aggregate gridded deaths to country totals, one time step at a time.
# Countries whose ISO code is missing from the mortality table are left NaN.

# Build the lookup set once instead of once per country per time step.
iso_codes = set(mortality['ISO3'])

# Take the loop bounds from the output table rather than hard-coding 35 x 240.
n_time, n_country = mortality_value.shape

for j in range(n_time):
    print(f"\rCurrent j: {j}", end="", flush=True)
    #year = 2035 + j
    mortality_years = np.full(n_country, np.nan)

    # deaths_val is opened lazily; load this time step once so the
    # per-country sums below do not each re-evaluate it.
    deaths_j = deaths_val[j].compute()

    for i, country_code in enumerate(gdf['ISOCODE']):
        if country_code in iso_codes:
            # Sum the grid cells inside country i (mask layers follow gdf row order)
            mortality_years[i] = deaths_j.where(mask[i] != 0).sum(dim=["lat", "lon"], skipna=True)

    mortality_value[j, :] = mortality_years

Current j: 34

In [39]:
# Write the (time, country) table of country totals for the selected ensemble
# member to NetCDF.
mortality_value.to_netcdf(f'./deaths_ozone/arise-sai-1.0/by_country/{ensemble}.nc')



In [None]:
# Aggregate gridded deaths to country totals, one time step at a time.
# Countries whose ISO code is missing from the mortality table are left NaN.
iso_set = set(mortality['ISO3'])  # Cache once

# Take the loop bounds from the output table rather than hard-coding 35 x 240.
n_time, n_country = mortality_value.shape

for j in range(n_time):
    print(f"\rCurrent j: {j}", end="", flush=True)

    # Preallocate with NaNs in the output table's own dtype; staging the sums
    # in float32 would truncate them before they are stored.
    mortality_years = np.full(n_country, np.nan, dtype=mortality_value.dtype)

    # deaths_val is opened lazily; load this time step once so the
    # per-country sums below do not each re-evaluate it.
    deaths_j = deaths_val[j].compute()

    for i, country_code in enumerate(gdf['ISOCODE']):
        if country_code in iso_set:
            # Sum the grid cells inside country i (mask layers follow gdf row order)
            mortality_years[i] = deaths_j.where(mask[i] != 0).sum(dim=["lat", "lon"], skipna=True)

    mortality_value[j, :] = mortality_years
