In [1]:
import logging
import sys
from pathlib import Path
import itertools

import click
import numpy as np
import pandas as pd
import rasterra as rt
from rra_tools import jobmon
from scipy.special import expit
import xarray as xr
import tqdm
import rasterra as rt
from rasterio.features import rasterize
import geopandas as gpd

from spatial_temp_cgf import paths
from spatial_temp_cgf.training_data_prep import income_funcs
from spatial_temp_cgf.training_data_prep.location_mapping import load_fhs_lsae_mapping
from spatial_temp_cgf import cli_options as clio
from spatial_temp_cgf.data import DEFAULT_ROOT, ClimateMalnutritionData

In [2]:
# Input locations used by this notebook. All three are absolute paths on the
# shared filesystem, so the notebook only runs where those mounts exist.

# Template raster (a "1km" grid, per the file name) that the other rasters are aligned to.
RASTER_TEMPLATE_PATH = Path(
    '/mnt/team/rapidresponse/pub/population/data/01-raw-data/other-gridded-pop-projects/global-human-settlement-layer/1km_template.tif'
)

# Admin-2 polygons, stored as a parquet file.
SHAPE_PATH = Path(
    '/mnt/team/rapidresponse/pub/population/data/02-processed-data/ihme/lbd_admin2.parquet'
)

# Admin-2 LDI-per-capita estimates (CSV).
LDIPC_FILEPATH = Path(
    '/share/resource_tracking/forecasting/poverty/GK_2024_income_distribution_forecasts/income_forecasting_through2100_admin2_final_nocoviddummy_intshift/admin2_ldipc_estimates.csv'
)

In [16]:
# Run configuration: which malnutrition measure and fitted model version to
# use, and which climate scenario / year to build covariate rasters for.
measure = 'stunting'
model_version = '2024_07_01.04'
cmip6_scenario = 'ssp126'
year = 2025

# Root directory of the project's outputs; the measure name is appended to it
# when the data interface is created.
output_dir = Path(DEFAULT_ROOT)

In [4]:
# Data-access object rooted at <output_dir>/<measure>; the fitted model family
# is loaded through it further down.
cm_data = ClimateMalnutritionData(output_dir / measure)

In [5]:
# Admin-2 polygons; the `loc_id` and `geometry` columns are used when the
# income values are rasterized.
a2 = gpd.read_parquet(SHAPE_PATH)

# Reference raster whose shape, transform and CRS every other raster in this
# notebook is built on or resampled to.
raster_template = rt.load_raster(RASTER_TEMPLATE_PATH)

In [6]:
# Admin-2 LDI-per-capita estimates with year, location and
# population-percentile columns.
ldi = pd.read_csv(LDIPC_FILEPATH)

# Fill missing `ldipc` values with the mean over the rows that share the same
# year, national location and population percentile.
national_mean = ldi.groupby(
    ['year_id', 'national_ihme_loc_id', 'population_percentile']
).ldipc.transform('mean')
null_mask = ldi.ldipc.isnull()
ldi.loc[null_mask, 'ldipc'] = national_mean.loc[null_mask]

# Divide by the number of days in a year (presumably annual -> per-day income;
# confirm units), then average all rows for each year and location.
ldi['ldi_pc_pd'] = ldi['ldipc'] / 365.25
ldi = ldi.groupby(['year_id', 'location_id']).ldi_pc_pd.mean().reset_index()

# Polygons of the admin-2 units that have an LDI estimate, indexed by location_id.
polys = (
    a2.loc[a2.loc_id.isin(ldi.location_id.unique()), ['loc_id', 'geometry']]
    .rename(columns={'loc_id': 'location_id'})
    .set_index('location_id')
    .geometry
)

# Select the configured `year` rather than a hard-coded one, so the income
# raster describes the same year as the climate raster built from `year`.
year_ldi = ldi[ldi.year_id == year].set_index('location_id').ldi_pc_pd
shapes = [
    (row.geometry, row.ldi_pc_pd)
    for row in pd.concat([year_ldi, polys.sort_index()], axis=1).itertuples()
]

# Burn each polygon's value into a grid shaped like the template.
# NOTE(review): pixels outside every polygon keep the 0 that `out` is
# initialised with, while the raster below declares NaN as its no-data value --
# confirm that treating those zeros as valid data is intended.
arr = rasterize(
    shapes,
    out=np.zeros_like(raster_template),
    transform=raster_template.transform,
)
r = rt.RasterArray(arr, transform=raster_template.transform, crs=raster_template.crs, no_data_value=np.nan)

In [21]:
# Load the fitted model family stored under `model_version`.
models = cm_data.load_model_family(model_version)

In [22]:
# Take the first entry of the family to work with; its 'model' key is used in
# the following cells.
m = models[0]

In [23]:
# Show the model's per-variable info; each variable carries a 'transformer'
# that is applied to its covariate before prediction (see the output below).
m['model'].var_info

{'intercept': {'transformer': <spatial_temp_cgf.scaling.Scaler at 0x7f8f84772360>},
 'mean_temperature': {'transformer': <spatial_temp_cgf.binning.Binner at 0x7f8f37d92420>},
 'ldi_pc_pd': {'transformer': <spatial_temp_cgf.scaling.Scaler at 0x7f8f3df5a390>}}

In [None]:
# Transformer registered for the 'ldi_pc_pd' covariate (a Scaler, per the
# var_info output above).
t = m['model'].var_info['ldi_pc_pd']['transformer']
# Overwrite the private strategy's recorded feature count with the length of
# the raster's second axis.
# NOTE(review): presumably a workaround so the fitted transformer accepts the
# whole 2-D raster in one call -- confirm. This mutates the loaded model's
# transformer in place, so it affects every later use of `m`.
t._strategy.n_features_in_ = r.shape[1]

In [13]:
# Run the income raster through the model's LDI transformer, then wrap the
# result in a raster that reuses the georeferencing of the untransformed one.
transformed_values = t(r)
r_trans = rt.RasterArray(
    transformed_values,
    transform=r.transform, crs=r.crs, no_data_value=r.no_data_value,
)

In [54]:
# NOTE: `xarray_to_raster` is defined in a later cell of this notebook; that
# cell has to be run before this one.

# Years from 2024 onward are read from the scenario folder, earlier years from
# the 'historical' folder.
subfolder = cmip6_scenario if year >= 2024 else 'historical'
path = paths.CLIMATE_PROJECTIONS_ROOT / subfolder / "mean_temperature" / f"{year}.nc"

# Open the file in a context manager so its handle is released as soon as the
# raster has been built; the values are read while the file is still open.
with xr.open_dataset(path) as climate_ds:
    climate_raster = xarray_to_raster(
        climate_ds.sel(year=year)['value'], nodata=np.nan
    ).resample_to(raster_template)

RasterArray
dimensions    : 36000, 18000 (x, y)
resolution    : 0.01, -0.01 (x, y)
extent        : -180.0, 180.0, -90.0, 90.0 (xmin, xmax, ymin, ymax)
crs           : EPSG:4326
no_data_value : nan
size          : 4943.85 MB
dtype         : float64

In [50]:
def xarray_to_raster(ds: xr.DataArray, nodata: float | int) -> rt.RasterArray:
    """Convert an xarray DataArray to a RasterArray.

    The affine transform places the first raster row at ``lat[-1]`` and steps
    by ``-dlat`` per row, i.e. it describes the rows in the reverse of their
    stored order. The values are therefore reversed along the first axis so
    that each row ends up at the latitude it was stored with.

    Parameters
    ----------
    ds
        DataArray with ``latitude`` and ``longitude`` coordinates. The values
        are assumed to be 2-D with latitude as the leading axis -- TODO
        confirm against the climate files.
    nodata
        Value recorded as the raster's ``no_data_value``.

    Returns
    -------
    rt.RasterArray
        Raster holding the values of ``ds``, tagged as EPSG:4326.

    Raises
    ------
    ValueError
        If either coordinate has fewer than two points, because the grid
        spacing cannot be derived from a single coordinate.
    """
    # Kept as a local import: affine is not imported at the top of the notebook.
    from affine import Affine

    lat, lon = ds["latitude"].data, ds["longitude"].data
    if lat.size < 2 or lon.size < 2:
        raise ValueError(
            "Need at least two latitude and two longitude points to derive the grid spacing."
        )

    # Signed mean spacing between consecutive coordinate values.
    dlat = (lat[1:] - lat[:-1]).mean()
    dlon = (lon[1:] - lon[:-1]).mean()

    # NOTE(review): the origin is taken straight from the coordinate values;
    # if those are cell centres rather than cell edges the raster is shifted
    # by half a pixel -- confirm.
    transform = Affine(
        a=dlon,
        b=0.0,
        c=lon[0],
        d=0.0,
        e=-dlat,
        f=lat[-1],
    )
    raster = rt.RasterArray(
        # Reverse the rows so row 0 holds the values stored at lat[-1],
        # matching the origin and row step of the transform above.
        data=ds.data[::-1],
        transform=transform,
        crs="EPSG:4326",
        no_data_value=nodata,
    )
    return raster