# Explore CWatM data

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from pathlib import Path

from tqdm.notebook import tqdm
import pandas as pd
import xarray as xr

import plotly.express as px

import src.data.cwatm_data as cwatm_data

In [None]:
# Folder holding the processed datasets (relative path, resolved against the
# current working directory).
PROCESSED_DATA_FOLDER_PATH = Path("..") / "data" / "processed"

## Load CWatM data

In [None]:
def _read_netcdf_as_dataframe(nc_path):
    """Load a NetCDF file into a pandas DataFrame and close the file afterwards.

    Parameters
    ----------
    nc_path : path-like
        Location of the NetCDF file to read.

    Returns
    -------
    pd.DataFrame
        The dataset converted with ``Dataset.to_dataframe()``.
    """
    # open_dataset keeps the underlying file open; the context manager
    # releases the handle once the data has been materialised in memory.
    with xr.open_dataset(nc_path) as dataset:
        return dataset.to_dataframe()


all_df = _read_netcdf_as_dataframe(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "all.nc"))
forcings_df = _read_netcdf_as_dataframe(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "forcings.nc"))
outputs_df = _read_netcdf_as_dataframe(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "outputs.nc"))

In [None]:
# Render the DataFrame loaded from all.nc as the cell output.
all_df

In [None]:
# Render the DataFrame loaded from forcings.nc as the cell output.
forcings_df

In [None]:
# Render the DataFrame loaded from outputs.nc as the cell output.
outputs_df

### Process the data

In [None]:
# Run the project's input-processing helper on the full dataset.
# NOTE(review): judging by the `_land` suffix this presumably keeps only land
# cells -- confirm against src/data/cwatm_data.py.
all_land_df = cwatm_data.process_inputs_df(all_df)
all_land_df

In [None]:
# Select the forcing rows whose index labels appear in all_land_df, in the
# same order. `.loc` raises KeyError if any of those labels is missing from
# forcings_df.
forcings_land_df = forcings_df.loc[all_land_df.index]
forcings_land_df

In [None]:
# Select the output rows whose index labels appear in all_land_df, in the
# same order. `.loc` raises KeyError if any of those labels is missing from
# outputs_df.
outputs_land_df = outputs_df.loc[all_land_df.index]
outputs_land_df

In [None]:
# Persist the three processed tables as parquet files in the CWatM_data
# folder, one file per table.
for _stem, _frame in {
    "all_land": all_land_df,
    "forcings_land": forcings_land_df,
    "outputs_land": outputs_land_df,
}.items():
    _frame.to_parquet(PROCESSED_DATA_FOLDER_PATH / "CWatM_data" / f"{_stem}.parquet")

## Load CWatM `_land` data

In [None]:
# Read the three `_land` tables back from their parquet files.
all_land_df, forcings_land_df, outputs_land_df = (
    pd.read_parquet(PROCESSED_DATA_FOLDER_PATH / "CWatM_data" / f"{stem}_land.parquet")
    for stem in ("all", "forcings", "outputs")
)

In [None]:
# Summary statistics of all_land_df, shown as the cell output.
all_land_df.describe()

In [None]:
# Summary statistics of forcings_land_df, shown as the cell output.
forcings_land_df.describe()

In [None]:
# Summary statistics of outputs_land_df, shown as the cell output.
outputs_land_df.describe()

## Visualize data

In [None]:
# Put the three tables side by side (aligned on their index) so that every
# column is available from a single DataFrame.
data_df = pd.concat(
    [all_land_df, forcings_land_df, outputs_land_df],
    axis="columns",
)

In [None]:
# Column groups used for plotting: every column of the corresponding table.
# The names listed in the comments are the remnants of explicit lists that
# were previously commented out (apparently an earlier hand-picked selection).

# porosity, firstStorDepth, percolationImp, tanslope,
# maxRootDepth_forest, maxRootDepth_grassland
INPUTS_COLUMNS = [*all_land_df.columns]

# pr, tas, tasmax, tasmin, ps, rlds, rsds, sfcwind, hurs, huss
FORCINGS_COLUMNS = [*forcings_land_df.columns]

# evap-total, potevap, qr, qtot
OUTPUTS_COLUMNS = [*outputs_land_df.columns]

# Disabled: restrict the data to its first 1000 rows.
# data_df = data_df.iloc[:1000]

data_df.shape

In [None]:
from itertools import product
from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle

def display_individual_scatterplots(df: pd.DataFrame,
                                    dst_path: Path,
                                    valid_x,
                                    valid_y,
                                    regions_df = None,
                                    regions_2x2 = True
                                    ):
    """Save one scatter plot per (x column, y column) pair as a PNG file.

    Every combination of a column from ``valid_x`` with a column from
    ``valid_y`` is drawn and written to
    ``dst_path / f"{x}_{y}.png"`` at 300 dpi. Nothing is returned.

    Parameters
    ----------
    df : pd.DataFrame
        Data to plot; must contain every column named in ``valid_x`` and
        ``valid_y``.
    dst_path : Path
        Output directory. It is created (including parents) if missing.
    valid_x, valid_y : iterable of str
        Column names used for the x axis and the y axis respectively.
    regions_df : pd.DataFrame, optional
        Table with a ``"region"`` column whose index uses the same labels as
        ``df``. When given, the points are coloured per region; only rows of
        ``df`` whose index label appears in ``regions_df`` are drawn. Rows
        with a missing region label are ignored.
    regions_2x2 : bool, default True
        Only used when ``regions_df`` is given. ``True`` draws each region in
        its own panel of a 2x2 grid; ``False`` overlays all regions on one
        axis and adds a centred rolling mean (window of 3000 points) per
        region.

    Raises
    ------
    KeyError
        If a region label is not one of the four palette keys.
    IndexError
        If ``regions_2x2`` is true and there are more than four regions.
    """
    import matplotlib.patheffects as mpe

    # savefig does not create missing directories.
    dst_path.mkdir(parents=True, exist_ok=True)

    marker_style = MarkerStyle(marker=".",
                               fillstyle="full")

    palette = {"wet warm": "#018571",
               "dry warm": "#a6611a",
               "wet cold": "#80cdc1",
               "dry cold": "#dfc27d"}

    # White halo that keeps the rolling-mean line readable on top of the points.
    outline = mpe.withStroke(linewidth=4, foreground='white')

    # The rows belonging to each region do not depend on the plotted columns,
    # so split `df` once instead of once per column combination.
    region_frames = {}
    if regions_df is not None:
        for region in regions_df["region"].dropna().unique():
            region_index = regions_df[regions_df["region"] == region].index
            # Keep only the rows of `df` that are present in this region.
            region_frames[region] = df.loc[df.index.isin(region_index)]

    # Materialised so that tqdm knows the total number of plots.
    combinations = list(product(valid_x, valid_y))

    for input_col, output_col in tqdm(combinations, desc="Computing input-output combinations"):

        title = f"Input '{input_col}' - Output '{output_col}'"

        if regions_df is None:
            fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
            fig.suptitle(title)

            axis.scatter(x=df[input_col], y=df[output_col],
                         marker=marker_style,
                         edgecolor="none",
                         s=30,
                         alpha=0.25,
                         )

            axis.set_xlabel(input_col)
            axis.set_ylabel(output_col)

        elif regions_2x2:
            fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6),
                                     sharex=True, sharey=True, constrained_layout=False)
            axes = axes.flatten()
            fig.suptitle(title)

            for i, (region, region_data_df) in enumerate(region_frames.items()):
                axes[i].scatter(x=region_data_df[input_col], y=region_data_df[output_col],
                                label=region,
                                c=palette[region],
                                marker=marker_style,
                                s=10,
                                edgecolor="none",
                                alpha=0.25,
                                )

                axes[i].set_xlabel(input_col)
                axes[i].set_ylabel(output_col)
                # Axes are shared: only show labels on the outer panels.
                axes[i].label_outer()

        else:
            fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
            fig.suptitle(title)

            for region, region_data_df in region_frames.items():
                axis.scatter(x=region_data_df[input_col], y=region_data_df[output_col],
                             label=region,
                             c=palette[region],
                             marker=marker_style,
                             s=10,
                             edgecolor="none",
                             alpha=0.25,
                             )

                # The rolling mean must run along increasing x; sort a copy so
                # the cached per-region frame is left untouched.
                ordered_df = region_data_df.sort_values(input_col, ascending=True)
                rolling_mean = ordered_df[output_col].rolling(window=3000,
                                                              center=True,
                                                              ).mean()

                axis.plot(ordered_df[input_col], rolling_mean,
                          c=palette[region],
                          path_effects=[outline],
                          # Leading underscore keeps the line out of the legend.
                          label=f"_{region}"
                          )

            axis.set_xlabel(input_col)
            axis.set_ylabel(output_col)
            axis.label_outer()

        # Adjust layout and save the figure
        fig.tight_layout()
        if regions_df is not None:
            # Leave room below the axes for the shared legend.
            fig.subplots_adjust(bottom=0.13)
            legend = fig.legend(markerscale=3,
                                loc="lower center",
                                ncol=4)

            # Legend markers inherit alpha=0.25 from the points; make them opaque.
            for legobj in legend.legend_handles:
                legobj.set_alpha(1)

        fig.savefig(dst_path.joinpath(f"{input_col}_{output_col}.png"), dpi=300)

        # Close this specific figure so figures do not accumulate in memory.
        plt.close(fig)


### Global

In [None]:
# One PNG per (input or forcing column, output column) pair, without any
# region colouring.
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS
                                )

### Gnann regions

In [None]:
RAW_DATA_FOLDER_PATH = Path("..") / "data" / "raw"

# Keep the coordinates and the single domain classification used here.
domains_df = pd.read_csv(
    RAW_DATA_FOLDER_PATH / "ISIMIP_2b_aggregated_variables" / "domains.csv"
)[["lon", "lat", "domain_days_below_1_0.08_aridity_netrad"]]

# Same table with the classification column renamed to "region" and indexed by
# (lon, lat), so region labels can be looked up by coordinate.
regions_df = (
    domains_df
    .rename(columns={"domain_days_below_1_0.08_aridity_netrad": "region"})
    .set_index(["lon", "lat"])
)

In [None]:
# Region-coloured scatter plots, one panel per region in a 2x2 grid.
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_2x2_Gnann"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=True
                                )

In [None]:
# Region-coloured scatter plots with all regions overlaid on a single axis,
# including a rolling-mean line per region.
display_individual_scatterplots(df=data_df,
                                dst_path=Path("../reports/figures/CWatM_data/scatterplots_regions_Gnann"),
                                valid_x=INPUTS_COLUMNS + FORCINGS_COLUMNS,
                                valid_y=OUTPUTS_COLUMNS,
                                regions_df=regions_df,
                                regions_2x2=False
                                )

## Interactive

In [None]:
from src.visualization import visualize


# Plot a 20% random subset (presumably to keep the interactive figure light).
# The sample goes into its own variable so `data_df` stays intact: re-running
# this cell no longer shrinks the data by another factor of five each time.
# The fixed seed makes the subset reproducible between runs.
interactive_df = data_df.sample(frac=0.2, random_state=0)

visualize.plot_scatter_with_dropdown(df=interactive_df,
                                     default_x="pr",
                                     default_y="potevap",
                                     valid_x=sorted(INPUTS_COLUMNS + FORCINGS_COLUMNS),
                                     valid_y=sorted(OUTPUTS_COLUMNS))