In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings

import sys
from pyprojroot import here
sys.path.append("../..") 
# from laos_gggi.statistics import get_distance_to_rivers

import pandas as pd
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import matplotlib.pyplot as plt
import arviz as az
import geopandas as gpd
import xarray as xr

from laos_gggi.replication_data import create_replication_data
from laos_gggi.plotting import plotting_function_damages 
from laos_gggi.model import add_data
from laos_gggi.plotting import configure_plot_style

from laos_gggi.sample import sample_or_load
from sklearn.preprocessing import StandardScaler as Standardize

from laos_gggi.data_functions import load_emdat_data, load_shapefile, load_rivers_data

In [3]:
# Reproducible seed derived from a fixed string, plus the shared NumPy generator
SEED = sum(ord(char) for char in 'climate_bayes')
rng = np.random.default_rng(seed=SEED)

# Loading and preparing data 

In [4]:
# SEA countries (ISO 3166-1 alpha-3) used to subset the event posterior and the maps.
# The list order is the order `.sel(ISO=laos_neighboors)` gives the posterior's ISO axis.
_SEA_COUNTRY_NAMES = {
    "KHM": "Cambodia",
    "THA": "Thailand",
    "LAO": "Laos",
    "VNM": "Vietnam",
}
laos_neighboors = list(_SEA_COUNTRY_NAMES)

In [5]:
# Configure pytensor floats
floatX = pytensor.config.floatX

# Load the fitted InferenceData of the event model and of the damage model
event_idata = az.from_netcdf("model_closest_full_long.idata")
damage_idata = az.from_netcdf("damages_model.idata")

# The damage model also uses an "ISO" dimension and a "country_effect" variable;
# rename them so the two posteriors can be merged without name clashes.
damage_idata = damage_idata.rename(
    {"ISO": "ISO_damage", "country_effect": "country_effect_damage"}
)

# Merge posteriors: event posterior restricted to the SEA countries + damage posterior
event_posterior_sea = event_idata.sel(ISO=laos_neighboors).posterior
merged_posteriors = xr.merge([event_posterior_sea, damage_idata.posterior])

In [7]:
# Display the damage-model InferenceData (after the ISO/country_effect rename)
damage_idata

In [6]:
# Display the merged posterior. xr.merge returns an xarray Dataset, which has no
# `.posterior` group (that attribute only exists on az.InferenceData).
merged_posteriors

In [35]:
# Country boundaries, then the SEA region and Laos alone
world = load_shapefile('world')

sea_map, laos_map = (
    world.query('ISO_A3 in @laos_neighboors'),
    world.query('ISO_A3 == "LAO"'),
)

### Events data

In [36]:
# Load data sets; the point grids store the country code as "ISO_A3", renamed to "ISO"
iso_rename = {"ISO_A3": "ISO"}

sea_point_grid = pd.read_csv(here("data/sea_point_grid.csv"), index_col=0).rename(columns=iso_rename)
laos_point_grid = pd.read_csv(here("data/laos_point_grid.csv"), index_col=0).rename(columns=iso_rename)

sea_df = pd.read_csv(here("data/sea.csv"), index_col=0)
lao_df = pd.read_csv(here("data/lao.csv"), index_col=0)
sea_df_stand = pd.read_csv(here("data/sea_df_stand.csv"), index_col=0)
lao_df_stand = pd.read_csv(here("data/lao_df_stand.csv"), index_col=0)

# Climate/socio-economic forecasts; "time" becomes "Start_Year" to match the other frames
predictions = (
    pd.read_csv(here("data/climate_forecast.csv"))
    .rename(columns={"time": "Start_Year"})
)

In [37]:
# Attribute every forecast row to Laos
predictions["ISO"] = "LAO"

# Forecast years to map (Start_Year values are date strings)
prediction_years = [
    "2026-01-01",
    "2030-01-01",
    "2040-01-01",
    "2050-01-01",
    "2070-01-01",
]
predictions_short = predictions.query('Start_Year in @prediction_years')

In [38]:
# Number the grid points: the CSV index becomes a "point_number" column.
# Guarded so that re-running this cell does not reset the index a second time
# (which would add a duplicate "point_number" column).
if "point_number" not in laos_point_grid.columns:
    laos_point_grid = laos_point_grid.reset_index().rename(columns={"index": "point_number"})

# Attach the selected forecasts to every grid point via the shared ISO code
# (presumably every grid point has ISO == "LAO", giving one row per point and year).
pred_df = pd.merge(laos_point_grid, predictions_short, left_on="ISO", right_on="ISO", how="left")

In [39]:
# Area of Laos in km² (hard-coded), used to turn forecast population into density.
# NOTE(review): confirm the source of this figure.
LAOS_AREA_KM2 = 236_800

# Create population density
pred_df["population_density"] = pred_df["Population"] / LAOS_AREA_KM2

# Calculate logs (these are among the columns standardized below)
pred_df["log_population_density"] = np.log(pred_df["population_density"])
pred_df["log_gdp_per_cap"] = np.log(pred_df["gdp_per_cap"])

In [40]:
# Define features.
# NOTE: `event_features` is re-bound further down the notebook to the event
# model's feature order (without lat/long); this version only feeds the slices here.
event_features = ['lat', 'long',
                  'log_distance_to_river__standardized', "log_distance_to_coastline__standardized",
                  "log_distance_to_river__standardized_squared", "log_distance_to_coastline__standardized_squared",
                  "Population__standardized", "co2__standardized", "precip_deviation__standardized",
                  "dev_ocean_temp__standardized",
                  'log_population_density__standardized', 'log_gdp_per_cap__standardized',
                  "log_gdp_per_cap__standardized_squared", "log_population_density__standardized_squared"]

# Suffix appended to every standardized column name
STANDARDIZED_SUFFIX = "__standardized"

# Positional slices: lat/long + distance terms, then the time-varying terms
distance_features = event_features[0:6]
time_varying_features = event_features[6:]
# Raw (un-suffixed) names of the six linear time-varying features
time_varying_features_base = [x.removesuffix(STANDARDIZED_SUFFIX) for x in event_features[6:12]]

# Raw columns that get standardized, and the names of their standardized versions
cols_to_stand = ['Population', 'log_distance_to_river', 'dev_ocean_temp',
                 'co2', 'log_population_density', 'log_gdp_per_cap',
                 'precip_deviation', 'log_distance_to_coastline']

cols_to_stand_stand = [x + STANDARDIZED_SUFFIX for x in cols_to_stand]

# Columns carried through unchanged
cols_not_stand = ['ISO', 'Start_Year', 'lat', 'long', 'geometry', 'point_number']

In [41]:
# Fit the standardizer on the disaster observations of the SEA data.
# NOTE(review): only `is_disaster == 1` rows are used; confirm this matches how
# the standardized training data (sea_df_stand) were produced.
disaster_obs = sea_df.query('is_disaster == 1')
transformer_stand_ = Standardize().fit(disaster_obs[cols_to_stand])

In [42]:
# Standardize the prediction covariates with the SEA-fitted transformer.
# The standardized array is given pred_df's own index so that the index-based
# merge with the untouched columns below aligns rows by label, not by accident.
pred_df_stand = pd.DataFrame(
    transformer_stand_.transform(pred_df[cols_to_stand]),
    columns=cols_to_stand_stand,
    index=pred_df.index,
)
pred_df_stand = pd.merge(pred_df_stand, pred_df[cols_not_stand], left_index=True, right_index=True, how="left")

In [43]:
# Squared terms of the standardized log covariates: {new column: source column}
squared_terms = {
    "log_gdp_per_cap__standardized_squared": "log_gdp_per_cap__standardized",
    "log_population_density__standardized_squared": "log_population_density__standardized",
    "log_distance_to_river__standardized_squared": "log_distance_to_river__standardized",
    "log_distance_to_coastline__standardized_squared": "log_distance_to_coastline__standardized",
}
for squared_col, base_col in squared_terms.items():
    pred_df_stand[squared_col] = pred_df_stand[base_col] ** 2


### Center Laos lat-long using SEA data

In [44]:
# Midpoint of the SEA lat/long range; the Laos grid coordinates are centred on it
sea_center = {
    coord: (sea_df[coord].max() + sea_df[coord].min()) / 2
    for coord in ("lat", "long")
}
for coord, midpoint in sea_center.items():
    pred_df_stand[coord + "_centered"] = pred_df_stand[coord] - midpoint

In [45]:
# Split the standardized prediction frame into one DataFrame per forecast year
forecast_years = pred_df_stand["Start_Year"].unique()
pred_df_stand_dict = dict()
for year in forecast_years:
    year_frame = pred_df_stand.query('Start_Year == @year')
    pred_df_stand_dict[year] = year_frame

# Set coords

In [46]:
# Standardized damage data set (first CSV column used as the index)
damage_df_stand = pd.read_csv(
    here("data/damage_df_stand.csv"),
    index_col=0,
)

In [47]:
# Inputs of the HSGP spatial component
gp_features = ["lat", "long"]

# Damage-model covariates. Order matters: these columns are multiplied with the
# fitted `betas_damage` in the prediction model.
damage_features = [
    "Population__standardized",
    "log_population_density__standardized",
    "log_population_density__standardized_squared",
    "log_gdp_per_cap__standardized",
    "log_gdp_per_cap__standardized_squared",
    "dev_ocean_temp__standardized",
    "co2__standardized",
    "precip_deviation__standardized",
]

# Event-model covariates (re-binds the earlier `event_features`, which also held
# lat/long). Order matters: presumably add_data keeps it, and the resulting
# columns are multiplied with the fitted `beta`.
event_features = [
    "log_distance_to_river__standardized",
    "log_distance_to_coastline__standardized",
    "Population__standardized",
    "co2__standardized",
    "precip_deviation__standardized",
    "dev_ocean_temp__standardized",
    "log_population_density__standardized",
    "log_gdp_per_cap__standardized",
    "log_gdp_per_cap__standardized_squared",
    "log_population_density__standardized_squared",
    "log_distance_to_river__standardized_squared",
    "log_distance_to_coastline__standardized_squared",
]

In [48]:
# Event-model coordinates, rebuilt from sea_df_stand
obs_idx_events = sea_df_stand.index
ISO_idx_events, ISO_events = pd.factorize(sea_df_stand["ISO"])
is_disaster_idx_events, is_disaster_events = pd.factorize(sea_df_stand["is_disaster"])

# Observation coordinate as an (ISO, Start_Year) MultiIndex
events_multiindex = sea_df_stand.set_index(['ISO', 'Start_Year']).index
xr_idx_events = xr.Coordinates.from_pandas_multiindex(events_multiindex, 'obs_idx')

# Set coords
event_coords = {
    "is_disaster": is_disaster_events,
    "obs_idx": obs_idx_events,
    "ISO": ISO_events,
    "feature": event_features,
    "gp_feature": gp_features,
}


In [49]:
# Damage-model coordinates, rebuilt from damage_df_stand
ISO_idx_damage, ISO_damage = pd.factorize(damage_df_stand["ISO"])

coords_damage = {
    'ISO': ISO_damage,
    'obs_idx': damage_df_stand.index,
    'feature': damage_features,
}

# Observation coordinate as an (ISO, year) MultiIndex
damage_obs_multiindex = damage_df_stand.set_index(['ISO', 'year']).index
xr_idx_damage = xr.Coordinates.from_pandas_multiindex(damage_obs_multiindex, 'obs_idx')


In [50]:
# Prediction coords
# Every forecast year uses the same Laos grid, so 2026 supplies the point index.
obs_idx = pred_df_stand_dict["2026-01-01"]["point_number"]

# pm.sample_posterior_predictive takes posterior draws by position, not by
# coordinate label. The ISO coordinate of the prediction model must therefore
# follow the ISO order stored in `merged_posteriors` (`.sel(ISO=laos_neighboors)`
# reordered the event posterior to that list). Rebuilding it with
# pd.factorize(sea_df["ISO"]) uses first-appearance order instead, which can
# silently map LAO onto another country's `country_effect`.
ISO = pd.Index(merged_posteriors["ISO"].values)
ISO_idx = ISO.get_indexer(sea_df["ISO"])
ISO_to_idx = {name: idx for idx, name in enumerate(ISO)}
ISO_idx_laos = pred_df_stand_dict["2026-01-01"].ISO.map(ISO_to_idx.get)

years = pred_df_stand["Start_Year"].unique()

# Same for the damage model's (renamed) ISO_damage coordinate
ISO_damage = pd.Index(merged_posteriors["ISO_damage"].values)
ISO_damage_idx = ISO_damage.get_indexer(damage_df_stand["ISO"])
ISO_to_idx_2 = {name: idx for idx, name in enumerate(ISO_damage)}
ISO_damage_idx_laos = pred_df_stand_dict["2026-01-01"].ISO.map(ISO_to_idx_2.get)

# Every grid point must resolve to a country present in both posteriors
assert ISO_idx_laos.notna().all(), "LAO missing from the event posterior ISO coordinate"
assert ISO_damage_idx_laos.notna().all(), "LAO missing from the damage posterior ISO_damage coordinate"

# Set coords_predictions (gp_features is defined with the other feature lists above)
coords_predictions = {
    "obs_idx": obs_idx,
    "ISO": ISO,
    "distance_features": distance_features,
    "time_varying_features": time_varying_features,
    "gp_feature": gp_features,
    "ISO_damage": ISO_damage,
    "damage_features": damage_features,
    "event_features": event_features,
    "year": years,
}

# Compute the curves

In [51]:
# Compute center function
def compute_center(X):
    """Return the per-column midpoint ``(max + min) / 2`` of ``X``, reduced over axis 0.

    The reduction is built with pytensor and evaluated eagerly via ``.eval()``,
    so ``X`` must be a constant (e.g. a NumPy array) rather than a symbolic
    input with free variables. The result is a NumPy array.
    """
    upper = pt.max(X, axis=0)
    lower = pt.min(X, axis=0)
    return (upper + lower).eval() / 2

In [52]:
from copy import deepcopy
from pymc.model.transform.optimization import freeze_dims_and_data

# HSGP approximation settings; these must equal the settings of the fitted event model.
m0, m1, c = 35, 35, 1.5

# The HSGP basis functions are defined by the centre and the boundary L of the
# inputs the GP was fitted on. The posterior `basis_coeffs` are only meaningful
# with those same basis functions. Taking the centre (and, through `c`, L) from
# the Laos prediction grid, whose extent is much smaller than the SEA region,
# changes the basis. Instead, derive both once from the SEA coordinates, centred
# with sea_center exactly like the Laos grid's "lat_centered"/"long_centered".
# NOTE(review): assumes the event model's GP inputs were the SEA lat/long centred
# on sea_center and that it used c=1.5; confirm against the fitting notebook.
sea_gp_X = pd.DataFrame({
    "lat_centered": sea_df["lat"] - sea_center["lat"],
    "long_centered": sea_df["long"] - sea_center["long"],
})
gp_X_center = ((sea_gp_X.max() + sea_gp_X.min()) / 2).to_numpy(dtype=floatX)
gp_L = ((sea_gp_X - gp_X_center).abs().max() * c).tolist()

damage_curves_plot = {}

for year in years:
    year_df = pred_df_stand_dict[year]

    with pm.Model(coords=coords_predictions) as damage_curves_plot[year]:
        #################################### Events model ####################################
        # Set data; the feature dim is the declared "event_features" coord, shared with `beta`
        event_features_data = add_data(features=event_features, target=None, df=year_df,
                                       dims=['obs_idx', 'event_features'])
        X_gp = pm.Data("X_gp", year_df[["lat_centered", "long_centered"]].astype(floatX),
                       dims=['obs_idx', 'gp_feature'])

        # Flat variables: their values come from the posterior draws, not from a prior
        country_effect = pm.Flat("country_effect", dims=["ISO"])
        beta = pm.Flat("beta", dims=["event_features"])

        # HSGP component
        eta = pm.Flat("eta")
        ell = pm.Flat("ell", dims=["gp_feature"])
        cov_func = eta**2 * pm.gp.cov.Matern52(input_dim=2, ls=ell)

        # Fixed L and centre from the SEA domain, so the basis matches the fitted one
        gp = pm.gp.HSGP(m=[m0, m1], L=gp_L, cov_func=cov_func)
        # NOTE(review): private attribute, set before prior_linearized; presumably
        # PyMC only computes its own centre when this is unset. Re-check on upgrades.
        gp._X_center = gp_X_center

        phi, sqrt_psd = gp.prior_linearized(X=X_gp)

        basis_coeffs = pm.Flat("basis_coeffs", size=gp.n_basis_vectors)

        HSGP_component = pm.Deterministic('HSGP_component', phi @ (basis_coeffs * sqrt_psd), dims=['obs_idx'])

        # Event model components
        event_features_component = pm.Deterministic('event_features_component',
                                                    event_features_data @ beta, dims=['obs_idx'])

        logit_p = pm.Deterministic('logit_p', country_effect[ISO_idx_laos] + event_features_component
                                   + HSGP_component, dims=['obs_idx'])
        event_prob_y_hat = pm.Deterministic('event_prob_y_hat', pm.math.invlogit(logit_p), dims=['obs_idx'])

        #################################### Damages model ####################################
        # Set data
        damage_x_data = pm.Data("damage_x_data", year_df[damage_features], dims=['obs_idx', 'damage_features'])

        # Set flats
        country_effect_damage = pm.Flat("country_effect_damage", dims=["ISO_damage"])
        betas_damage = pm.Flat("betas_damage", dims=["damage_features"])
        sigma_damage = pm.Flat("sigma_damage")

        # Damage model components: normal on log damages, exponentiated to millions
        mu = country_effect_damage[ISO_damage_idx_laos] + damage_x_data @ betas_damage

        ln_damage_millions = pm.Normal("ln_damage_millions", mu=mu, sigma=sigma_damage, dims=["obs_idx"])

        damage_millions = pm.Deterministic("damage_millions", pm.math.exp(ln_damage_millions), dims=['obs_idx'])

        #################################### Damage curves ####################################
        # Damage at each grid point weighted by the event probability there
        damages_curves = pm.Deterministic("damages_curves", damage_millions * event_prob_y_hat, dims=['obs_idx'])

In [None]:
# Sample predictions: for each forecast year, evaluate the derived quantities of
# the (dims/data-frozen) prediction model using the parameters in merged_posteriors.
predicted_var_names = [
    'HSGP_component',
    'event_features_component',
    'logit_p',
    'event_prob_y_hat',
    "ln_damage_millions",
    "damage_millions",
    "damages_curves",
]

idata_plot_point = {}
for year in years:
    frozen_model = freeze_dims_and_data(damage_curves_plot[year])
    with frozen_model:
        idata_plot_point[year] = pm.sample_posterior_predictive(
            merged_posteriors,
            var_names=predicted_var_names,
            extend_inferencedata=False,
            compile_kwargs={'mode': 'JAX'},
        )

In [None]:
# Inspect the predicted damage curves for the first forecast year (2026)
idata_plot_point["2026-01-01"].posterior_predictive["damages_curves"]