### Analysis of bias-adjusted precipitation data from CMIP6 GCMs

In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import torch
import pandas as pd
import seaborn as sns
import numpy as np
from mpl_toolkits.basemap import Basemap
import geopandas as gpd
from esda.moran import Moran
from libpysal.weights import KNN
from eval.metrics import *
from eval.plot import *
import data.valid_crd as valid_crd

In [None]:
# Define start and end date
clim = 'ensemble'
start_date = "1981-01-01"
end_date = "1990-12-31"
period = [1981, 1990]
degree = 2
testep = 50

cmip6_dir = '/pscratch/sd/k/kas7897/cmip6'
ref_path = '/pscratch/sd/k/kas7897/Livneh/unsplit/'
if clim ==  'ensemble':
    y_clim = 'access_cm2'
else:
    y_clim = clim
ds_sample = xr.open_dataset(f"{ref_path}/precipitation/{y_clim}/prec.1980.nc")
valid_coords = valid_crd.valid_lat_lon(ds_sample)

emph_quantile = 0.5
x = torch.load(f'/pscratch/sd/k/kas7897/diffDownscale/jobs/{clim}-livneh/QM_ANN_layers4_degree{degree}_quantile{emph_quantile}/all/1950_1980/{period[0]}_{period[1]}/x.pt', weights_only = False).to('cpu').squeeze(-1).numpy()
y = torch.load(f'/pscratch/sd/k/kas7897/diffDownscale/jobs/{clim}-livneh/QM_ANN_layers4_degree{degree}_quantile{emph_quantile}/all/1950_1980/{period[0]}_{period[1]}/y.pt', weights_only = False).to('cpu').squeeze(-1).numpy()
xt = torch.load(f'/pscratch/sd/k/kas7897/diffDownscale/jobs/{clim}-livneh/QM_ANN_layers4_degree{degree}_quantile{emph_quantile}/all/1950_1980/{period[0]}_{period[1]}/ep{testep}/xt.pt', weights_only = False)

time = torch.load(f'/pscratch/sd/k/kas7897/diffDownscale/jobs/{clim}-livneh/QM_ANN_layers4_degree{degree}_quantile{emph_quantile}/all/1950_1980/{period[0]}_{period[1]}/time.pt', weights_only = False)


loca = xr.open_dataset(f'{cmip6_dir}/{clim}/historical/precipitation/loca/coarse_USclip.nc')
loca = loca['pr'].sel(lat=xr.DataArray(valid_coords[:, 0], dims='points'),
                                    lon=xr.DataArray(valid_coords[:, 1], dims='points'),
                                    method='nearest')
loca = loca.sel(time =slice(f'{period[0]}', f'{period[1]}')).values

#unit conversion
loca = loca*86400

In [None]:
## this block filters 'y' based on 'x' calender

x_time_np = np.array([pd.Timestamp(str(t)) for t in time])
x_time_np = np.array([pd.Timestamp(t).replace(hour=0, minute=0, second=0) for t in x_time_np], dtype='datetime64[D]')
# Generate a daily time array following the standard Gregorian calendar
y_time = pd.date_range(start=start_date, end=end_date, freq="D")

# Convert to NumPy array for indexing and comparison
y_time_np = y_time.to_numpy()

# Find indices where observed time matches model time
matched_indices = np.where(np.isin(y_time_np, x_time_np))[0]

y = y[matched_indices,:]

### Season Filtering

In [None]:
y_s = load_seasonal_data(x_time_np, y)
x_s = load_seasonal_data(x_time_np, x)
xt_s = load_seasonal_data(x_time_np, xt)

loca_s = load_seasonal_data(x_time_np, loca)

t_s = load_seasonal_data(x_time_np, x_time_np)

In [None]:

# Initialize climate indices manager
climate_indices = ClimateIndices()

day_bias_percentages = get_day_bias_percentages(x, y, xt, climate_indices)
mean_bias_percentages = get_mean_bias_percentages(x, y, xt, x_time_np, climate_indices)

loca_day_bias_percentages = get_day_bias_percentages(x, y, loca, climate_indices)
loca_mean_bias_percentages = get_mean_bias_percentages(x, y, loca, x_time_np, climate_indices)


## Arranging delCLIMAD and LOCA in one dictionary
for key in day_bias_percentages.keys():
    day_bias_percentages[key] = day_bias_percentages[key] + (loca_day_bias_percentages[key][1],)


for key in mean_bias_percentages.keys():
    mean_bias_percentages[key] = mean_bias_percentages[key] + (loca_mean_bias_percentages[key][1],)


## Temporal Analysis

In [None]:
keys = ['SDII (Monthly)','CDD (Yearly)', 'CWD (Yearly)', "Rx1day", "Rx5day", "R10mm",  "R20mm", "R95pTOT", "R99pTOT"]
d = dict(filter(lambda item: item[0] in keys , mean_bias_percentages.items()))

keys = ["Dry Days", "Wet Days >1mm", "Very Wet Days >10mm", "Very Very Wet Days >20mm"]
d4 = dict(filter(lambda item: item[0] in keys , day_bias_percentages.items()))

# Create a 2x2 subplot figure
fig, axes = plt.subplots(2, 1, figsize=(24, 12), sharey=True)

# Ensure axes is flattened for easy indexing
axes = axes.flatten()
method_names = [f"delCLIMAD-BA(degree{degree}_quantile{emph_quantile})", "LOCA"]


# Call the function for each dataset
plot_violin_bias(axes[0], d, "Bias(%)", "Mean Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)

plot_violin_bias(axes[1], d4, "Bias(%)", "Day Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)
fig.suptitle(f'{clim}', fontsize=20, fontweight="bold", y=1.02)


plt.tight_layout()
plt.show()

## Spatial Analysis

In [None]:
threshold_types = ["Dry Days", "Wet Days >1mm", "Very Wet Days >10mm", "Very Very Wet Days >20mm"]
method_names = ["delCLIMAD-BA", "LOCA"]

plot_spatial_bias(valid_coords=valid_coords,
                         bias_data_dict=day_bias_percentages,
                         threshold_types=threshold_types,
                         label="Bias Days %",
                         method_names=method_names,
                         vmin=-100,
                         vmax=100,
                         cmap="coolwarm")

In [None]:
threshold_types = ['SDII (Monthly)','CDD (Yearly)', 'CWD (Yearly)', "Rx1day", "Rx5day", "R10mm",  "R20mm", "R95pTOT", "R99pTOT"]
method_names = ["delCLIMAD-BA", "LOCA"]

plot_spatial_bias(valid_coords=valid_coords,
                         bias_data_dict=mean_bias_percentages,
                         threshold_types=threshold_types,
                         label="Mean Bias %",
                         method_names=method_names,
                         vmin=-100,
                         vmax=100,
                         cmap="coolwarm")

In [None]:
# morans_i, p_value = compute_morans_i(mean_biases['Very Wet Days >10mm'][1], valid_coords)
# plot_moran_scatter(mean_biases['Very Wet Days >10mm'][1], valid_coords)


## Seasonal Analysis

In [None]:
for season in ['Spring', 'Winter', 'Autumn', 'Summer']:
    x_s_s = x_s[season]
    y_s_s = y_s[season]
    xt_s_s = xt_s[season]
    loca_s_s = loca_s[season]
    t_s_s = t_s[season]

    seasonal_day_bias_percentages = get_day_bias_percentages(x_s_s, y_s_s, xt_s_s, climate_indices)
    seasonal_mean_bias_percentages = get_mean_bias_percentages(x_s_s, y_s_s, xt_s_s, t_s_s, climate_indices)
    seasonal_loca_day_bias_percentages = get_day_bias_percentages(x_s_s, y_s_s, loca_s_s, climate_indices)
    seasonal_loca_mean_bias_percentages = get_mean_bias_percentages(x_s_s, y_s_s, loca_s_s, t_s_s, climate_indices)

    ## Arranging delCLIMAD and LOCA in one dictionary
    for key in day_bias_percentages.keys():
        seasonal_day_bias_percentages[key] = seasonal_day_bias_percentages[key] + (seasonal_loca_day_bias_percentages[key][1],)


    for key in mean_bias_percentages.keys():
        seasonal_mean_bias_percentages[key] = seasonal_mean_bias_percentages[key] + (seasonal_loca_mean_bias_percentages[key][1],)



    keys = ['SDII (Monthly)','CDD (Yearly)', 'CWD (Yearly)', "Rx1day", "Rx5day", "R10mm",  "R20mm", "R95pTOT", "R99pTOT"]
    d = dict(filter(lambda item: item[0] in keys , seasonal_mean_bias_percentages.items()))

    keys = ["Dry Days", "Wet Days >1mm", "Very Wet Days >10mm", "Very Very Wet Days >20mm"]
    d2 = dict(filter(lambda item: item[0] in keys , seasonal_day_bias_percentages.items()))

    # Create a 2x2 subplot figure
    fig, axes = plt.subplots(2, 1, figsize=(24, 12), sharey=True)

    # Ensure axes is flattened for easy indexing
    axes = axes.flatten()
    method_names = [f"delCLIMAD-BA(degree{degree}_quantile{emph_quantile})", "LOCA"]


    # Call the function for each dataset
    plot_violin_bias(axes[0], d, "Bias(%)", f"Mean Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)

    plot_violin_bias(axes[1], d2, "Bias(%)", f"Day Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)
    fig.suptitle(f'{clim}_{season}', fontsize=20, fontweight="bold", y=1.02)


    plt.tight_layout()
    plt.show()

In [None]:
# Spring
season = 'Spring'
x_s_s = x_s[season]
y_s_s = y_s[season]
xt_s_s = xt_s[season]
loca_s_s = loca_s[season]
t_s_s = t_s[season]

seasonal_day_bias_percentages = get_day_bias_percentages(x_s_s, y_s_s, xt_s_s, climate_indices)
seasonal_mean_bias_percentages = get_mean_bias_percentages(x_s_s, y_s_s, xt_s_s, t_s_s, climate_indices)
seasonal_loca_day_bias_percentages = get_day_bias_percentages(x_s_s, y_s_s, loca_s_s, climate_indices)
seasonal_loca_mean_bias_percentages = get_mean_bias_percentages(x_s_s, y_s_s, loca_s_s, t_s_s, climate_indices)


## Arranging delCLIMAD and LOCA in one dictionary
for key in day_bias_percentages.keys():
    seasonal_day_bias_percentages[key] = seasonal_day_bias_percentages[key] + (seasonal_loca_day_bias_percentages[key][1],)


for key in mean_bias_percentages.keys():
    seasonal_mean_bias_percentages[key] = seasonal_mean_bias_percentages[key] + (seasonal_loca_mean_bias_percentages[key][1],)



keys = ['SDII (Monthly)','CDD (Yearly)', 'CWD (Yearly)', "Rx1day", "Rx5day", "R10mm",  "R20mm", "R95pTOT", "R99pTOT"]
d = dict(filter(lambda item: item[0] in keys , seasonal_mean_bias_percentages.items()))

keys = ["Dry Days", "Wet Days >1mm", "Very Wet Days >10mm", "Very Very Wet Days >20mm"]
d2 = dict(filter(lambda item: item[0] in keys , seasonal_day_bias_percentages.items()))

# Create a 2x2 subplot figure
fig, axes = plt.subplots(2, 1, figsize=(24, 12), sharey=True)

# Ensure axes is flattened for easy indexing
axes = axes.flatten()
method_names = [f"delCLIMAD-BA(degree{degree}_quantile{emph_quantile})", "LOCA"]


# Call the function for each dataset
plot_violin_bias(axes[0], d, "Bias(%)", f"Mean Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)

plot_violin_bias(axes[1], d2, "Bias(%)", f"Day Bias(%) for Different Precipitation Indices",  method_names=method_names, remove_outlier=True)
fig.suptitle(f'{clim}_{season}', fontsize=20, fontweight="bold", y=1.02)


plt.tight_layout()
plt.show()