# Explore CWatM data

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from pathlib import Path

from tqdm.notebook import tqdm
import pandas as pd
import xarray as xr

import plotly.express as px

import src.data.cwatm_data as cwatm_data

In [9]:
# Root folder of the processed datasets. The path is relative, so it resolves
# against the current working directory (presumably the notebook's own folder;
# TODO confirm).
PROCESSED_DATA_FOLDER_PATH = Path("../data/processed")

## Load CWatM data

In [16]:
# Open each NetCDF file in a context manager: `to_dataframe()` materialises the
# data in memory, so the underlying file handle can be released right after
# instead of staying open for the lifetime of the kernel.
with xr.open_dataset(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "all.nc")) as all_ds:
    all_df = all_ds.to_dataframe()
with xr.open_dataset(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "forcings.nc")) as forcings_ds:
    forcings_df = forcings_ds.to_dataframe()
with xr.open_dataset(PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data", "outputs.nc")) as outputs_ds:
    outputs_df = outputs_ds.to_dataframe()

In [None]:
all_df

In [None]:
forcings_df

In [None]:
outputs_df

### Process the data

In [None]:
# Run the project's input-processing step on the full table. Judging by the
# `_land` naming this presumably keeps only land cells; confirm against
# src.data.cwatm_data.process_inputs_df.
all_land_df = cwatm_data.process_inputs_df(all_df)
all_land_df  # last expression of the cell: displays the processed table

In [None]:
# Keep only the forcing rows whose index labels appear in all_land_df
# (assumes both tables are indexed the same way; TODO confirm).
forcings_land_df = forcings_df.loc[all_land_df.index]
forcings_land_df  # last expression of the cell: displays the selection

In [None]:
# Keep only the output rows whose index labels appear in all_land_df
# (assumes both tables are indexed the same way; TODO confirm).
outputs_land_df = outputs_df.loc[all_land_df.index]
outputs_land_df  # last expression of the cell: displays the selection

In [29]:
# Persist the three land tables as parquet files in the CWatM data folder.
cwatm_dir = PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data")
for parquet_name, land_frame in (
    ("all_land.parquet", all_land_df),
    ("forcings_land.parquet", forcings_land_df),
    ("outputs_land.parquet", outputs_land_df),
):
    land_frame.to_parquet(cwatm_dir.joinpath(parquet_name))

## Load CWatM `_land` data

In [10]:
# Reload the three land tables from their parquet files, in the same order
# they were written: all, forcings, outputs.
cwatm_dir = PROCESSED_DATA_FOLDER_PATH.joinpath("CWatM_data")
all_land_df, forcings_land_df, outputs_land_df = (
    pd.read_parquet(cwatm_dir.joinpath(parquet_name))
    for parquet_name in (
        "all_land.parquet",
        "forcings_land.parquet",
        "outputs_land.parquet",
    )
)

In [None]:
all_land_df.describe()

In [None]:
forcings_land_df.describe()

In [None]:
outputs_land_df.describe()

## Visualize data

In [11]:
# Put the three land tables side by side (column-wise) in a single frame.
data_df = pd.concat(
    [all_land_df, forcings_land_df, outputs_land_df],
    axis="columns",
)

In [None]:
# Input columns: every column of the processed land table.
INPUTS_COLUMNS = list(all_land_df.columns)
# Hand-picked alternative, kept for reference:
# INPUTS_COLUMNS = [
#     "porosity",
#     "firstStorDepth",
#     "percolationImp",
#     "tanslope",
#     "maxRootDepth_forest",
#     "maxRootDepth_grassland"
# ]

# Forcing columns to plot. Alternative: list(forcings_land_df.columns)
FORCINGS_COLUMNS = [
    "pr",
    "tas",
    "tasmax",
    "tasmin",
    "ps",
    "rlds",
    "rsds",
    "sfcwind",
    "hurs",
    "huss",
]
# Output columns to plot. Alternative: list(outputs_land_df.columns)
OUTPUTS_COLUMNS = [
    "evap-total",
    "potevap",
    "qr",
    "qtot"
]

# Uncomment to work on the first 1000 rows only:
# data_df = data_df.iloc[:1000]

# Last expression of the cell: displays (n_rows, n_columns) of the combined table.
data_df.shape

In [17]:
from itertools import product
from tqdm import tqdm

import matplotlib.pyplot as plt

def display_individual_scatterplots(df,
                                    dst_path: Path,
                                    valid_x,
                                    valid_y,
                                    ) -> None:
    """Save one scatter plot per (x, y) column pair as a PNG file.

    Despite the name, no figure is shown: each one is written to ``dst_path``
    as ``"<x>_<y>.png"`` (300 dpi) and closed right away.

    Args:
        df: Table indexable by column name (``df[col]``) that holds every
            column listed in ``valid_x`` and ``valid_y``.
        dst_path: Directory receiving the PNG files; created if missing.
        valid_x: Column names plotted on the x axis.
        valid_y: Column names plotted on the y axis.
    """
    # savefig does not create missing directories, so make sure the target
    # exists before the (potentially long) plotting loop starts.
    dst_path = Path(dst_path)
    dst_path.mkdir(parents=True, exist_ok=True)

    # Materialise the pairs so tqdm knows the total number of plots.
    column_pairs = list(product(valid_x, valid_y))

    for input_col, output_col in tqdm(column_pairs, desc="Computing input-output combinations"):

        fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))

        try:
            axis.scatter(df[input_col], df[output_col],
                         alpha=0.2)
            axis.set_title(f"Input '{input_col}' - Output '{output_col}'")
            axis.set_xlabel(input_col)
            axis.set_ylabel(output_col)

            # Adjust the layout of this figure, then write it to disk.
            fig.tight_layout()
            fig.savefig(dst_path.joinpath(f"{input_col}_{output_col}.png"), dpi=300)
        finally:
            # Close this specific figure (not just the "current" one), even if
            # plotting or saving failed, so figures do not pile up in memory.
            plt.close(fig)


In [None]:
# Write one scatter plot per (input or forcing, output) column pair.
scatterplots_dir = Path("../reports/figures/CWatM_data/scatterplots")
x_columns = INPUTS_COLUMNS + FORCINGS_COLUMNS

display_individual_scatterplots(
    df=data_df,
    dst_path=scatterplots_dir,
    valid_x=x_columns,
    valid_y=OUTPUTS_COLUMNS,
)

In [None]:
from src.visualization import visualize


# Plot a random 20% subsample, stored under its own name so data_df stays
# intact. Sampling data_df in place would keep 20% of the previous 20% every
# time this cell is re-run.
sampled_df = data_df.sample(frac=0.2)

visualize.plot_scatter_with_dropdown(df=sampled_df,
                                     default_x="pr",
                                     default_y="potevap",
                                     valid_x=sorted(INPUTS_COLUMNS + FORCINGS_COLUMNS),
                                     valid_y=sorted(OUTPUTS_COLUMNS))