## Post-Processing

After training and forecasting, we post-process the forecasts to make sure the output makes sense.
For example, the forecast must never go below zero, because our target is the number of dengue cases, which is non-negative.
In addition, the quantiles must be monotone (non-decreasing in the quantile level). For example, if quantile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value of at least `20`.
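A minimal sketch of the non-negativity step follows. It assumes that the quantile forecasts of one row are held in a NumPy array. The example values are the 0.025, 0.05, 0.1 and 0.25 quantiles of the AC forecast for 2022-07-03, taken from the output table at the end of this section. Negative values show up in two situations: in extrapolated tails, and occasionally in quantiles that AutoGluon predicts directly (for example, 0.1 = -14.829025 for TO on 2023-09-10). `clip_to_zero` is a hypothetical helper name, not something defined in this notebook.

```python
import numpy as np


def clip_to_zero(quantile_forecasts: np.ndarray) -> np.ndarray:
    """Replace negative forecasts by zero (dengue case counts cannot be negative)."""
    return np.maximum(quantile_forecasts, 0.0)


# Quantiles 0.025, 0.05, 0.1 and 0.25 of one forecast row (AC, 2022-07-03)
row = np.array([-2.152435, 0.296234, 5.193573, 18.135551])
clipped = clip_to_zero(row)
print(clipped)

# Clipping keeps an already non-decreasing row non-decreasing
print(np.all(np.diff(clipped) >= 0))
```

Because `max(x, 0)` is itself non-decreasing in `x`, clipping cannot create quantile crossings. It is therefore safe to apply this step after the monotonicity condition below has been enforced.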

In [6]:
import polars as pl
import numpy as np

### Getting the intervals required by the sprint

In [26]:
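# Central prediction intervals required by the sprint (50%, 80%, 90% and 95% coverage)
intervals = [0.5, 0.8, 0.9, 0.95]

# Lower/upper quantile levels of each interval: [[0.25, 0.75], [0.1, 0.9], [0.05, 0.95], [0.025, 0.975]]
quantiles = [[np.round(0.5 - i/2, decimals=3), np.round(0.5 + i/2, decimals=3)] for i in intervals]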

def estimate_quantile(predictions, target_quantile):
    """
    Estimate the value of a given quantile from the quantiles predicted by the model.

    Parameters:
    predictions (dict or DataFrame): A row (e.g. the struct passed by `map_elements`) or a DataFrame
                                     with one column per predicted quantile level ('0.1', '0.2', ..., '0.9').
    target_quantile (float): The quantile level to estimate (e.g., 0.25, 0.75).

    Returns:
    float (or Series): The estimated value for the target quantile, linearly interpolated between the
    neighbouring predicted quantiles, or linearly extrapolated below 0.1 and above 0.9.
    """
    # Round the grid so that e.g. 0.3 is stored as 0.3 and not as 0.30000000000000004
    quantile_values = np.round(np.arange(0.1, 1.0, 0.1), decimals=1)
    # The level was predicted directly: return it unchanged
    if target_quantile in quantile_values:
        return predictions[str(target_quantile)]
    # Below 0.1: extrapolate with the slope between quantiles 0.1 and 0.2
    if target_quantile < 0.1:
        return predictions['0.1'] - (0.1 - target_quantile)*(predictions['0.2'] - predictions['0.1'])/0.1
    # Above 0.9: extrapolate with the slope between quantiles 0.8 and 0.9
    if target_quantile > 0.9:
        return predictions['0.9'] + (target_quantile - 0.9)*(predictions['0.9'] - predictions['0.8'])/0.1

    # Otherwise: interpolate linearly between the two neighbouring predicted quantiles
    lower_bound = np.round(max(q for q in quantile_values if q < target_quantile), decimals=2)
    upper_bound = np.round(min(q for q in quantile_values if q > target_quantile), decimals=2)
    lower_values = predictions[str(lower_bound)]
    upper_values = predictions[str(upper_bound)]
    slope = (upper_values - lower_values) / (upper_bound - lower_bound)
    return lower_values + slope * (target_quantile - lower_bound)
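The quantile levels above are rounded because the new columns are named with `str(level)`. Without rounding, floating-point error leaks into the name. For example, `0.5 - 0.9/2` evaluates to a float just below 0.05, so the column would not be called `'0.05'`. The quick check below is illustrative only. It uses Python's built-in `round` in place of `np.round`, which has the same effect for these values.

```python
# Central interval coverages required by the sprint
intervals = [0.5, 0.8, 0.9, 0.95]

# Without rounding, the lower level of the 90% interval is not exactly 0.05
raw_lower = 0.5 - 0.9 / 2
print(raw_lower == 0.05, str(raw_lower))

# With rounding, each interval maps to clean (lower, upper) levels and column names
pairs = [(round(0.5 - i / 2, 3), round(0.5 + i / 2, 3)) for i in intervals]
names = [str(level) for pair in pairs for level in pair]
print(pairs)
print(names)
```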



Importing our predictions. 

In [27]:
# import predictions
predictions = pl.read_parquet('../predictions/autogluon_baseline.parquet')

## Checking the Monotonicity Condition

As stated above, the quantile forecasts must be non-decreasing in the quantile level. Depending on how the quantiles are estimated, a forecasting model can produce
inconsistent predictions in which a higher quantile falls below a lower one (quantile crossing). The following function checks whether the values of a single row
are non-decreasing across the quantile columns.

In [45]:
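def check_monotonicity(row, qs):
    """Return True if the values of `row` in the columns `qs` (sorted by quantile level) are non-decreasing."""
    vals = [row[q] for q in qs]
    return bool(np.all(np.diff(vals) >= 0))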

In [47]:
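# Quantile column names ('0.1', ..., '0.9') sorted by level; columns[1:10] skips the 'mean' column
qs = [str(i) for i in np.sort([float(i) for i in predictions.columns[1:10]])]
# Add a boolean column 'mono' that is True when the row's quantiles are non-decreasing
predictions = predictions.with_columns(
    pl.struct(predictions.columns)
    .map_elements(lambda row: check_monotonicity(row, qs), return_dtype=pl.Boolean)
    .alias('mono')
)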
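The row-wise `map_elements` call above runs a Python function once per row. As a hedged alternative, the same check can be written for a NumPy matrix with one row per forecast and one column per quantile level, in increasing order. This form also makes it easy to repair crossings. The repair shown here sorts each row, which is the monotone rearrangement approach. It is our suggestion and is not applied in this notebook. The helper names and the small matrix are hypothetical.

```python
import numpy as np


def is_monotone(q: np.ndarray) -> np.ndarray:
    """Row-wise check: True where the quantile forecasts are non-decreasing in the level."""
    return np.all(np.diff(q, axis=1) >= 0, axis=1)


def rearrange(q: np.ndarray) -> np.ndarray:
    """Repair quantile crossing by sorting the values of each row (monotone rearrangement)."""
    return np.sort(q, axis=1)


# Hypothetical forecasts for quantile levels 0.1, 0.5 and 0.9; the second row crosses
q = np.array([
    [5.0, 20.0, 40.0],
    [8.0, 25.0, 22.0],
])

mono = is_monotone(q)
print(mono, "non-monotone rows:", int((~mono).sum()))

fixed = rearrange(q)
print(fixed)
```

In the notebook, the matrix could be obtained with `predictions.select(qs).to_numpy()`. Sorting keeps the same set of values in each row and only reassigns them to levels, so rows that are already monotone are left unchanged.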

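AutoGluon predicts the quantiles 0.1, 0.2, …, 0.9. The sprint, however, asks for the bounds of the 50%, 80%, 90% and 95% prediction intervals, i.e. the quantiles 0.25/0.75, 0.1/0.9, 0.05/0.95 and 0.025/0.975.
We obtain them with `estimate_quantile`, which interpolates linearly between neighbouring predicted quantiles and extrapolates linearly below 0.1 and above 0.9. Extrapolating the lower tail can produce negative values (see the 0.05 and 0.025 columns in the output below), which violates the non-negativity requirement stated at the start of this section.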
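Written out, `estimate_quantile` computes the forecast $\hat y_\tau$ of a target level $\tau$ from the predicted quantiles $\hat y_{p}$, $p \in \{0.1, 0.2, \dots, 0.9\}$, as follows:

$$
\hat y_\tau =
\begin{cases}
\hat y_{0.1} - (0.1 - \tau)\,\dfrac{\hat y_{0.2} - \hat y_{0.1}}{0.1}, & \tau < 0.1,\\[6pt]
\hat y_{p_k} + (\tau - p_k)\,\dfrac{\hat y_{p_{k+1}} - \hat y_{p_k}}{p_{k+1} - p_k}, & p_k < \tau < p_{k+1},\\[6pt]
\hat y_{0.9} + (\tau - 0.9)\,\dfrac{\hat y_{0.9} - \hat y_{0.8}}{0.1}, & \tau > 0.9.
\end{cases}
$$

For the first row of the output below (AC, 2022-06-26), this gives

$$
\hat y_{0.05} = 7.961609 - 0.05\cdot\frac{16.673126 - 7.961609}{0.1} \approx 7.961609 - 4.35576 \approx 3.60585,
$$

while for the following week (AC, 2022-07-03) the same extrapolation falls below zero:

$$
\hat y_{0.025} = 5.193573 - 0.075\cdot\frac{14.988251 - 5.193573}{0.1} \approx 5.193573 - 7.34601 \approx -2.152435.
$$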

In [15]:
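# Add one column per required quantile level (e.g. '0.25', '0.05', '0.975'); the levels
# already predicted (0.1 and 0.9) are re-created as Float64 columns with the same values
for q in np.hstack(quantiles):
    predictions = predictions.with_columns(
        pl.struct(predictions.columns)
        .map_elements(lambda row: estimate_quantile(row, target_quantile=q), return_dtype=pl.Float64)
        .alias(str(q))
    )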
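The loop above calls `estimate_quantile` once per row and target level through `map_elements`. Below is a vectorized sketch of the same interpolation and extrapolation. It assumes that the nine AutoGluon quantile columns have been extracted into a NumPy matrix with columns ordered 0.1, …, 0.9, for example with `predictions.select(qs).to_numpy()`. `interpolate_quantile` is a hypothetical helper. `np.interp` is deliberately not used, because it clamps to the end values outside the grid instead of extrapolating.

```python
import numpy as np

# Quantile levels predicted by the model, one per matrix column
LEVELS = np.round(np.arange(0.1, 1.0, 0.1), 1)


def interpolate_quantile(q: np.ndarray, target: float) -> np.ndarray:
    """Linear interpolation/extrapolation of quantile `target` for every row of `q`."""
    # Pick the segment [LEVELS[k], LEVELS[k+1]] containing `target`; outside the grid,
    # the first (0.1-0.2) or last (0.8-0.9) segment is extended, i.e. extrapolated.
    k = int(np.clip(np.searchsorted(LEVELS, target) - 1, 0, len(LEVELS) - 2))
    lo, hi = LEVELS[k], LEVELS[k + 1]
    slope = (q[:, k + 1] - q[:, k]) / (hi - lo)
    return q[:, k] + slope * (target - lo)


# Usage: first row of the output table below (AC, 2022-06-26), quantiles 0.1, ..., 0.9
q = np.array([[7.961609, 16.673126, 22.931793, 28.160172, 32.646072,
               36.753967, 40.973099, 47.054993, 58.878448]])
for target in [0.025, 0.05, 0.25, 0.75, 0.95, 0.975]:
    print(target, interpolate_quantile(q, target))
```

Extending the 0.8-0.9 segment up to 0.975 gives the same value as the notebook's formula, $\hat y_{0.9} + (\tau - 0.9)(\hat y_{0.9} - \hat y_{0.8})/0.1$, since both describe the same line.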


In [23]:
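from lets_plot import *
LetsPlot.setup_html()
# Optional: plot the quantile curve (quantile level vs. forecast) of a single row.
# quantile_cols = sorted(
#     [c for c in predictions.columns if c not in ('mean', 'item_id', 'timestamp', 'mono')],
#     key=float,
# )
# i = 4
# row = predictions.row(i, named=True)
# sample = {'q': [float(c) for c in quantile_cols],
#           'y': [row[c] for c in quantile_cols]}
# (
#     ggplot(data=sample)
#     + geom_line(aes(x='q', y='y'), color='blue')
#     + geom_point(aes(x='q', y='y'), color='blue')
# )
# predictions.drop(['mono'])
predictions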

| mean | 0.1 | 0.2 | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 | item_id | timestamp | 0.25 | 0.75 | 0.05 | 0.95 | 0.025 | 0.975 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| f32 | f64 | f32 | f32 | f32 | f32 | f32 | f32 | f32 | f64 | str | datetime[ms] | f64 | f64 | f64 | f64 | f64 | f64 |
| 32.646072 | 7.961609 | 16.673126 | 22.931793 | 28.160172 | 32.646072 | 36.753967 | 40.973099 | 47.054993 | 58.878448 | AC | 2022-06-26 00:00:00 | 19.80246 | 44.014046 | 3.60585 | 64.790176 | 1.427971 | 67.74604 |
| 31.444885 | 5.193573 | 14.988251 | 21.282852 | 26.381317 | 31.444885 | 35.775024 | 40.576355 | 45.728439 | 58.838364 | AC | 2022-07-03 00:00:00 | 18.135551 | 43.152397 | 0.296234 | 65.393326 | -2.152435 | 68.670807 |
| 29.168777 | 1.847595 | 11.782578 | 18.265533 | 24.15863 | 29.168777 | 34.275879 | 38.246994 | 44.151917 | 59.249039 | AC | 2022-07-10 00:00:00 | 15.024055 | 41.199455 | -3.119896 | 66.7976 | -5.603642 | 70.57188 |
| 28.488434 | 1.608887 | 10.647476 | 17.528732 | 23.252213 | 28.488434 | 33.140259 | 37.307709 | 42.672211 | 58.658539 | AC | 2022-07-17 00:00:00 | 14.088104 | 39.98996 | -2.910408 | 66.651703 | -5.170055 | 70.648285 |
| 27.669281 | 0.101715 | 9.989685 | 16.586334 | 22.762085 | 27.669281 | 31.898636 | 35.09581 | 39.192078 | 55.32515 | AC | 2022-07-24 00:00:00 | 13.28801 | 37.143944 | -4.84227 | 63.391685 | -7.314262 | 67.424953 |
| … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
| 27.878174 | -14.829025 | 0.875839 | 11.13829 | 19.444809 | 27.878174 | 37.506454 | 51.958939 | 76.895813 | 134.651184 | TO | 2023-09-10 00:00:00 | 6.007065 | 64.427376 | -22.681458 | 163.52887 | -26.607674 | 177.967712 |
| 24.473618 | 3.995911 | 11.244278 | 16.405319 | 20.701416 | 24.473618 | 27.866333 | 31.215286 | 34.384521 | 36.980408 | TO | 2023-09-17 00:00:00 | 13.824799 | 32.799904 | 0.371727 | 38.278351 | -1.440365 | 38.927322 |
| 26.049774 | 4.74939 | 12.124268 | 17.503448 | 21.985672 | 26.049774 | 29.365356 | 32.586868 | 34.790741 | 38.274216 | TO | 2023-09-24 00:00:00 | 14.813858 | 33.688805 | 1.061951 | 40.015953 | -0.781769 | 40.886822 |
| 26.676178 | 4.682724 | 12.360641 | 17.762344 | 22.460037 | 26.676178 | 30.686005 | 33.949539 | 37.299316 | 42.164902 | TO | 2023-10-01 00:00:00 | 15.061493 | 35.624428 | 0.843765 | 44.597694 | -1.075714 | 45.814091 |
