In [None]:
from pathlib import Path
import pandas as pd
import xarray as xr
from ombs_senegal.benchmark_model import BenchmarkScores

# Relative path to the directory that contains the example CSV read below.
DATA_PATH = Path("../testing_data")
# Name of the CSV column used as the observed series throughout the notebook
# (presumably observed river flow, given the plot's y-label -- confirm).
OBS_COL = 'Q_obs'

In [None]:
# Read only the timestamp and observation columns; timestamps are parsed
# with pd.to_datetime and become the index.
data = pd.read_csv(
    DATA_PATH / 'hydro_example.csv',
    converters={"time": pd.to_datetime},
    index_col='time',
    usecols=['time', OBS_COL],
)
# Keep the records from 2012-01-01 onwards.
data = data.loc['2012-01-01':]

## Seasonal model

First, we will create a seasonal model to use as a benchmark.

In [None]:
class SeasonalityHandler:
    """Handle weekly seasonality in time series data.

    The seasonal pattern is the per-column mean for every ISO week of the
    year, computed from a DataFrame whose index exposes ``isocalendar()``
    (e.g. a ``DatetimeIndex``).

    This class provides methods to:
    - Compute seasonal patterns based on week of year
    - Remove seasonality from data
    - Add seasonality back to data

    Attributes:
        seasonal_pattern: pd.DataFrame or None
            The computed seasonal pattern, indexed by week of year.
            ``None`` until ``compute_seasonal_pattern`` has been called.
        original_columns: pd.Index or None
            Columns of the frame the pattern was computed from.
            ``None`` until ``compute_seasonal_pattern`` has been called.
    """

    def __init__(self):
        """Initialize the SeasonalityHandler with no pattern computed yet."""
        self.seasonal_pattern = None
        self.original_columns = None

    def compute_seasonal_pattern(self, data: pd.DataFrame) -> pd.DataFrame:
        """Compute mean values for each week of the year to capture seasonal patterns.

        Args:
            data: Frame whose index provides ``isocalendar()``.

        Returns:
            The per-week mean of every column, indexed by week number. The
            same frame is stored in ``self.seasonal_pattern``.
        """
        self.seasonal_pattern = data.groupby(data.index.isocalendar().week).mean()
        self.original_columns = data.columns
        return self.seasonal_pattern

    def _pattern_for(self, data: pd.DataFrame) -> pd.DataFrame:
        """Validate ``data`` and return the pattern limited to its columns.

        Raises:
            ValueError: If no pattern has been computed, or if ``data`` has a
                column the pattern was not computed for.
        """
        if self.seasonal_pattern is None:
            raise ValueError("Seasonal pattern not computed. Call compute_seasonal_pattern first.")

        # Fall back to the pattern's own columns when the pattern was assigned
        # directly instead of through compute_seasonal_pattern.
        known_columns = (
            self.original_columns
            if self.original_columns is not None
            else self.seasonal_pattern.columns
        )
        if not all(col in known_columns for col in data.columns):
            raise ValueError("Input data contains columns not present in seasonal pattern")

        # Restrict to the input's columns: combining with the full pattern
        # would align on the union of columns and add all-NaN columns for
        # every pattern column missing from ``data``.
        return self.seasonal_pattern[list(data.columns)]

    def remove_seasonality(self, data: pd.DataFrame) -> pd.DataFrame:
        """Remove seasonality from the data.

        Args:
            data: Frame whose columns are a subset of those the pattern was
                computed from. It is not modified.

        Returns:
            A new frame with the same columns as ``data`` holding
            ``data - seasonal_pattern`` matched on week of year.

        Raises:
            ValueError: If the pattern is missing or ``data`` has unknown columns.
        """
        pattern = self._pattern_for(data)
        # Work on a copy so the helper's in-place index change never touches the caller's frame.
        deseasonalized = self.add_week_index(data.copy(), True)
        deseasonalized = deseasonalized - pattern
        return self.remove_week_index(deseasonalized, True)

    def add_seasonality(self, data: pd.DataFrame) -> pd.DataFrame:
        """Add seasonality back to the data.

        Args:
            data: Frame whose columns are a subset of those the pattern was
                computed from. It is not modified.

        Returns:
            A new frame with the same columns as ``data`` holding
            ``data + seasonal_pattern`` matched on week of year.

        Raises:
            ValueError: If the pattern is missing or ``data`` has unknown columns.
        """
        pattern = self._pattern_for(data)
        # Work on a copy so the helper's in-place index change never touches the caller's frame.
        reseasonalized = self.add_week_index(data.copy(), True)
        reseasonalized = reseasonalized + pattern
        return self.remove_week_index(reseasonalized, True)

    def add_week_index(self, data: pd.DataFrame, inplace: bool = False) -> pd.DataFrame:
        """Add week as index level to the data.

        Args:
            data: Frame whose index provides ``isocalendar()``.
            inplace: When True, ``data`` itself is modified and returned;
                otherwise a copy is modified and returned.

        Returns:
            The frame with an extra index level named ``'week'`` appended.
        """
        if not inplace:
            data = data.copy()
        data['week'] = data.index.isocalendar().week
        data.set_index('week', inplace=True, append=True)
        return data

    def remove_week_index(self, data: pd.DataFrame, inplace: bool = False) -> pd.DataFrame:
        """Remove week as index level from the data.

        Args:
            data: Frame carrying an index level named ``'week'``.
            inplace: When True, ``data`` itself is modified and returned;
                otherwise a copy is modified and returned.

        Returns:
            The frame with the ``'week'`` index level dropped.
        """
        if not inplace:
            data = data.copy()
        data.reset_index(level='week', drop=True, inplace=True)
        return data

We will now compute the seasonal variations and show the results.

In [None]:
# Chronological hold-out: rows strictly before 2019-01-01 form the training
# set, everything from that date onwards forms the test set.
train_mask = data.index < '2019-01-01'
train_data, test_data = data.loc[train_mask], data.loc[~train_mask]

In [None]:
# Compute the weekly seasonal pattern from the training period only, so the
# test period stays unseen by the benchmark.
seasonality_handler = SeasonalityHandler()
season = seasonality_handler.compute_seasonal_pattern(train_data)

In [None]:
# Attach to every test row the training-period mean of its week as a
# "season" column, then drop the temporary 'week' index level again.
results = seasonality_handler.remove_week_index(
    seasonality_handler.add_week_index(test_data).join(
        season.rename(columns={OBS_COL: "season"}),
        on='week',
    )
)

In [None]:
# Score the seasonal prediction ("season") against the observed column on the test period.
# NOTE(review): the argument order (prediction, observation, metric names) is
# read off this call only -- confirm against BenchmarkScores.compute_scores.
benchmark_scores = BenchmarkScores()
scores = benchmark_scores.compute_scores(
    results[["season"]].to_xarray(), results[OBS_COL].to_xarray(), ["rmse", "mae", "nse", "kge"]
    )
# Tabular copy of the scores; used in the next cell to build the plot title.
scores_df = scores.to_dataframe()

In [None]:
# Format the first value of every score column to two decimals for the title.
scores_str = " ".join(
    [f"{name.upper()}: {scores_df[name].values[0]:.2f}" for name in scores_df.columns]
)
# Plot every column of `results` with the scores in the title.
ax = results.plot(ylabel="Flow (m³/s)", title=f"Seasonal Model\n{scores_str}")

In [None]:
# Bundle the seasonal prediction and its scores into a single xarray Dataset.
# (`xr` is already imported in the notebook's first cell, so it is not re-imported here.)
season_pred = xr.merge([
    results[["season"]].to_xarray(),
    # Scores as one "scores" array along a "score" dimension, keeping only the
    # "season" entry of forecast_horizon and dropping that coordinate.
    scores.to_array("score", name="scores").sel(forecast_horizon="season", drop=True),
])
# Add a length-one "model" dimension labelled "Season".
season_pred = season_pred.expand_dims({"model": ["Season"]})
season_pred