In [None]:
#--(a) estimate the population weighted PM2.5 exposure for each city 
# author: Shiyu Deng
# date: January 2026
# email: shiyu.deng.23@ucl.ac.uk
# affiliation: University College London
# description: This script calculates the population-weighted PM2.5 exposure for each city in China using gridded population and PM2.5 data.

import os
import numpy as np
import xarray as xr
from shapely.geometry import mapping
import geopandas as gpd
import rioxarray
import logging
import pandas as pd

# Configure logging once for the whole run: timestamp - level - message, INFO and above.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Input locations used by the processing functions and the driver loop.
population_file_path = '/gridded_population/'  # directory holding the gridded population rasters
pm2_5_file_path = '/resampled_pm/'  # directory holding the resampled PM2.5 rasters
shp_file = 'shi_en.shp'  # city shapefile

# City table read from the shapefile; its geometries and CRS drive every clip below.
cities = gpd.read_file(shp_file)

def process_population_data(input_tif, cities):
    """Open a gridded population raster and clip it to the city geometries.

    The raster is opened with rioxarray, squeezed, and cleaned so that NaN
    cells and cells equal to -9999 both become 0. It is then tagged as
    EPSG:4326 and clipped to all geometries in ``cities``.

    Args:
        input_tif: Path of the population raster to open.
        cities: City table providing ``geometry`` and ``crs``.

    Returns:
        The clipped raster, or ``None`` when loading or clipping fails (the
        error is logged rather than raised).
    """
    # Stage 1: load and clean the raster.
    try:
        raster = rioxarray.open_rasterio(input_tif).squeeze()

        # NaN cells and cells equal to -9999 (presumably the no-data marker;
        # confirm against the raster metadata) are counted as zero population.
        raster = raster.fillna(0)
        raster = xr.where(raster == -9999, 0, raster)

        raster = raster.rio.set_spatial_dims(x_dim='x', y_dim='y', inplace=True)
        # NOTE(review): the CRS is set to EPSG:4326 unconditionally, whatever
        # the file itself declares.
        raster.rio.write_crs("EPSG:4326", inplace=True)

        logging.info(f"Population dataset bounds: {raster.rio.bounds()}")
        logging.info(f"Cities CRS: {cities.crs}")
    except Exception as e:
        logging.error(f"Error processing population data: {e}")
        return None

    # Stage 2: clip to the city geometries; failures get their own message.
    try:
        city_shapes = cities.geometry.apply(mapping)
        return raster.rio.clip(city_shapes, cities.crs, drop=True)
    except Exception as e:
        logging.error(f"Error clipping population data: {e}")
        return None

def process_pm25_data(input_tif, cities):
    """Open a PM2.5 raster and clip it to the city geometries.

    The raster is opened with rioxarray, squeezed, NaN cells are replaced by
    0, the result is tagged as EPSG:4326 and clipped to all geometries in
    ``cities``.

    Args:
        input_tif: Path of the PM2.5 raster to open.
        cities: City table providing ``geometry`` and ``crs``.

    Returns:
        The clipped raster, or ``None`` when loading or clipping fails (the
        error is logged rather than raised).
    """
    # Stage 1: load and clean the raster.
    try:
        raster = rioxarray.open_rasterio(input_tif).squeeze()

        # NOTE(review): missing concentrations become 0 here, so downstream
        # code cannot tell "no data" from "zero PM2.5".
        raster = raster.fillna(0)

        raster = raster.rio.set_spatial_dims(x_dim='x', y_dim='y', inplace=True)
        # NOTE(review): the CRS is set to EPSG:4326 unconditionally, whatever
        # the file itself declares.
        raster.rio.write_crs("EPSG:4326", inplace=True)

        logging.info(f"PM2.5 dataset bounds: {raster.rio.bounds()}")
    except Exception as e:
        logging.error(f"Error processing PM2.5 data: {e}")
        return None

    # Stage 2: clip to the city geometries; failures get their own message.
    try:
        city_shapes = cities.geometry.apply(mapping)
        return raster.rio.clip(city_shapes, cities.crs, drop=True)
    except Exception as e:
        logging.error(f"Error clipping PM2.5 data: {e}")
        return None

def _population_weighted_mean(city_pop, city_pm25, city_name):
    """Return the population-weighted mean PM2.5 for one city's clipped grids.

    Args:
        city_pop: Clipped population grid exposing ``fillna`` and ``values``.
        city_pm25: Clipped PM2.5 grid exposing ``values``.
        city_name: City label, used only in log messages.

    Returns:
        ``sum(pm25 * pop) / sum(pop)`` taken over the cells that have a PM2.5
        value, or NaN when no population falls on such cells.
    """
    # Cells without population data carry no weight.
    pop_array = city_pop.fillna(0).values
    pm25_array = city_pm25.values

    # The two grids are paired cell by cell from the top-left corner. When
    # their shapes differ, only the common extent is used; say so instead of
    # truncating silently, because the pairing may then be misaligned.
    if pop_array.shape != pm25_array.shape:
        logging.warning(
            f"Grid shape mismatch for {city_name}: population {pop_array.shape} "
            f"vs PM2.5 {pm25_array.shape}; using the common extent only"
        )
    n_rows = min(pop_array.shape[0], pm25_array.shape[0])
    n_cols = min(pop_array.shape[1], pm25_array.shape[1])
    pop_array = pop_array[:n_rows, :n_cols]
    pm25_array = pm25_array[:n_rows, :n_cols]

    total_population = np.sum(pop_array)
    logging.info(f"Total population for {city_name}: {total_population}")

    # A missing PM2.5 value is "unknown", not "zero concentration". Drop those
    # cells from both the numerator and the denominator so they cannot pull
    # the mean towards zero.
    has_pm25 = ~np.isnan(pm25_array)
    covered_pop = np.where(has_pm25, pop_array, 0)
    covered_population = np.sum(covered_pop)

    if covered_population > 0:
        pm25_weighted_sum = np.sum(np.where(has_pm25, pm25_array, 0) * covered_pop)
        logging.info(f"PM2.5 weighted sum for {city_name}: {pm25_weighted_sum}")
        weighted_pm25_value = pm25_weighted_sum / covered_population
    else:
        weighted_pm25_value = np.nan
    logging.info(f"PM2.5 pop_weighted for {city_name}: {weighted_pm25_value}")
    return weighted_pm25_value


def calculate_weighted_pm25(pop_data, pm_data, cities):
    """Compute the population-weighted PM2.5 exposure of every city.

    For each row of ``cities`` the population and PM2.5 rasters are clipped to
    the city's geometry and combined into a population-weighted mean.

    Args:
        pop_data: Population raster exposing ``rio.bounds()`` and ``rio.clip()``.
        pm_data: PM2.5 raster exposing ``rio.clip()``.
        cities: City table exposing ``columns``, ``crs`` and ``itertuples()``.
            The ``city_code`` and ``English`` columns are used for the id and
            name when present; otherwise ``'Unknown'`` is reported.

    Returns:
        A list with exactly one dict per city, in table order, with keys
        ``city_id``, ``city`` and ``weighted_pm25``. The value is NaN when the
        city lies outside the population raster, when clipping fails, or when
        no population falls on cells with PM2.5 data.
    """
    # Loop invariants: column availability and the population raster extent.
    has_code = 'city_code' in cities.columns
    has_name = 'English' in cities.columns
    pop_bounds = pop_data.rio.bounds()

    weighted_pm25 = []
    for city in cities.itertuples():
        city_id = city.city_code if has_code else 'Unknown'
        city_name = city.English if has_name else 'Unknown'
        logging.info(f"Processing city: {city_name}")
        logging.info(f"City bounds: {city.geometry.bounds}")

        # Default result; replaced only when a mean can actually be computed.
        weighted_pm25_value = np.nan

        # Bounding-box test against the population raster extent.
        min_x, min_y, max_x, max_y = city.geometry.bounds
        outside = (max_x < pop_bounds[0] or min_x > pop_bounds[2] or
                   max_y < pop_bounds[1] or min_y > pop_bounds[3])

        if outside:
            logging.warning(f"City {city_name} is out of population data bounds")
        else:
            city_geom = [city.geometry]
            try:
                city_pop = pop_data.rio.clip(city_geom, cities.crs, drop=True)
                city_pm25 = pm_data.rio.clip(city_geom, cities.crs, drop=True)
            except Exception as e:
                # Keep the city in the output with NaN, as for out-of-bounds
                # cities, instead of dropping its row.
                logging.error(f"Error clipping data for city {city_name}: {e}")
            else:
                weighted_pm25_value = _population_weighted_mean(
                    city_pop, city_pm25, city_name
                )

        weighted_pm25.append({
            'city_id': city_id,
            'city': city_name,
            'weighted_pm25': weighted_pm25_value,
        })

    return weighted_pm25


0       Anqing
1       Bengbu
2       Bozhou
3      Chizhou
4      Chuzhou
        ...   
365       None
366       None
367       None
368       None
369       None
Name: English, Length: 370, dtype: object


In [None]:
# Scenario axes. Each name is spliced verbatim into the input and output file names.
air_scenarios = ['ref', 'cleanair', 'earlypeak', 'ontimepeak_CL', 'ontimepeak_NZ_CL']  # PM2.5 rasters
fer_scenarios = ['high', 'mid', 'low']  # population rasters (presumably fertility levels; confirm)
RCP_scenarios = ['RCP2_6', 'RCP4_5', 'RCP8_5']  # population rasters

# One CSV is written per (air scenario, fer scenario, RCP scenario, year)
# combination whose two input rasters both exist and load successfully.
for year in (2030, 2040, 2050, 2060):
    for scenario1 in air_scenarios:
        # The PM2.5 raster depends only on (air scenario, year), so it is
        # loaded once and reused for every population scenario below.
        pm25_file = os.path.join(pm2_5_file_path, f'{scenario1}_{year}.tif')
        if not os.path.exists(pm25_file):
            logging.error(f"PM2.5 file not found: {pm25_file}")
            continue

        pm_data = process_pm25_data(pm25_file, cities)
        if pm_data is None:
            logging.error(f"Failed to process PM2.5 data for {scenario1} {year}")
            continue

        for scenario2 in fer_scenarios:
            for scenario3 in RCP_scenarios:
                population_name = f'grid_pop_count{year}_{scenario2}_{scenario3}_new.tif'
                population_file = os.path.join(population_file_path, population_name)
                if not os.path.exists(population_file):
                    logging.error(f"Population file not found: {population_file}")
                    continue

                pop_data = process_population_data(population_file, cities)
                if pop_data is None:
                    # The failure was already logged while loading the raster.
                    continue

                weighted_pm25 = calculate_weighted_pm25(pop_data, pm_data, cities)
                output_csv = f'/{scenario1}_{scenario2}_{scenario3}_{year}.csv'
                pd.DataFrame(weighted_pm25).to_csv(output_csv, index=False)
                logging.info(f"Successfully saved weighted PM2.5 data to {output_csv}")