## Post-Processing

After training and forecasting, we post-process the output to guarantee that the forecast makes sense.
For example, the forecast must never go below zero, because our target is the number of dengue cases, which cannot be negative.
In addition, the quantiles must be monotonic, i.e. if quantile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value of at least `20`.

In [83]:
import polars as pl
import numpy as np

### Getting the intervals required by the sprint

In [84]:
intervals = [0.5,0.8,0.9,0.95]

quantiles = [[np.round(0.5 - i/2,decimals=3), np.round(0.5 + i/2,decimals=3)]for i in intervals]

def estimate_quantile(predictions,target_quantile):
    """
    Estimate the value of a given quantile based on the predictions.

    Parameters:
    predictions (dict or DataFrame): Quantile predictions keyed by quantile level as a string
                                     (e.g., '0.1', '0.2', ..., '0.9'); a single row or a whole frame.
    target_quantile (float): The quantile level to estimate (e.g., 0.25, 0.75).

    Returns:
    float: The estimated value for the target quantile, interpolated if necessary.

    """
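    # Round the grid so float noise (e.g. 0.30000000000000004) does not break the membership test
    quantile_values = np.round(np.arange(0.1, 1.0, 0.1), decimals=1)
    if np.isclose(quantile_values, target_quantile).any():
        return predictions[str(np.round(target_quantile, decimals=1))]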
    if target_quantile < 0.1:
        return predictions['0.1'] - (0.1 - target_quantile)*(predictions['0.2'] - predictions['0.1'])/0.1
    if target_quantile > 0.9:
        return predictions['0.9'] + (target_quantile - 0.9)*(predictions['0.9'] - predictions['0.8'])/0.1

    lower_bound = np.round(max(q for q in quantile_values if q < target_quantile),decimals=2)
    upper_bound = np.round(min(q for q in quantile_values if q > target_quantile),decimals=2)
    lower_values = predictions[str(lower_bound)]
    upper_values = predictions[str(upper_bound)]
    slope = (upper_values - lower_values) / (upper_bound - lower_bound)
    return lower_values + slope * (target_quantile - lower_bound)
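
To make the interpolation and extrapolation rule concrete, here is a minimal, dependency-free sketch of the same idea applied to a single row. The helper name `interpolate_quantile` and the toy values are assumptions made for illustration only; they are not part of the notebook.

```python
def interpolate_quantile(levels, values, target):
    """Linearly interpolate (or extrapolate) the value at quantile level `target`.

    levels: sorted quantile levels, e.g. [0.1, ..., 0.9]
    values: predicted values at those levels (same length as `levels`)
    """
    if target <= levels[0]:
        i = 0                      # extrapolate with the first segment
    elif target >= levels[-1]:
        i = len(levels) - 2        # extrapolate with the last segment
    else:
        # index of the closest level at or below the target
        i = max(k for k in range(len(levels) - 1) if levels[k] <= target)
    slope = (values[i + 1] - values[i]) / (levels[i + 1] - levels[i])
    return values[i] + slope * (target - levels[i])


# Hypothetical toy row: the value at level q is simply 100 * q
levels = [round(0.1 * k, 1) for k in range(1, 10)]
values = [10.0 * k for k in range(1, 10)]

q25 = interpolate_quantile(levels, values, 0.25)    # between 0.2 and 0.3
q05 = interpolate_quantile(levels, values, 0.05)    # below 0.1, extrapolated
q975 = interpolate_quantile(levels, values, 0.975)  # above 0.9, extrapolated
```

On this toy row the values grow linearly with the level, so the interpolated and extrapolated results should lie on the same line (about 25, 5 and 97.5).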



Importing our predictions. 

In [85]:
# import predictions
predictions = pl.read_parquet('../predictions/autogluon_baseline.parquet')

## Checking Monotone Condition

Recall that the quantile values must be monotone, i.e. if quantile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value of at least `20`. Depending on how the quantiles are estimated, a forecasting model might give
inconsistent predictions. The following function checks each row for monotonicity.

In [86]:
def check_monotonicity(row, qs):
    vals = [row[q] for q in qs]
    return np.all(np.diff(vals) >= 0)

In [87]:
qs = [str(i) for i in np.sort([float(i) for i in predictions.columns[1:10]])]
predictions = predictions.with_columns(
    pl.struct(predictions.columns)
        .map_elements(lambda row: check_monotonicity(row,qs),return_dtype=bool).alias('mono')
)

Let us check whether there are any non-monotonic cases.

In [88]:
predictions.filter(~pl.col('mono'))

mean,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,item_id,timestamp,mono
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,str,datetime[ms],bool
-32.769592,-69.29425,-42.997375,-35.186707,-33.124023,-32.769592,-35.926514,-41.837769,-50.733459,-58.776917,"""ES""",2023-09-17 00:00:00,false
-16.645691,-61.632385,-28.760376,-20.447327,-18.525391,-16.645691,-18.636841,-24.714905,-31.975647,-22.383911,"""ES""",2023-09-24 00:00:00,false
356.431641,183.191406,283.436768,338.314453,340.759766,356.431641,355.583008,377.563477,512.13208,1317.305298,"""MG""",2023-07-02 00:00:00,false
344.495361,143.489258,262.353271,328.012695,340.793457,344.495361,328.40918,314.793457,399.235107,1066.433105,"""MG""",2023-07-09 00:00:00,false
200.110596,-10.65332,124.420654,201.552979,209.347168,200.110596,177.342041,142.052734,208.162354,901.991455,"""MG""",2023-07-16 00:00:00,false
…,…,…,…,…,…,…,…,…,…,…,…,…
-349.505615,-498.082581,-398.725708,-372.924072,-377.567078,-349.505615,-267.616943,-147.181427,138.80719,843.253174,"""SC""",2022-08-28 00:00:00,false
-276.413422,-453.482666,-324.03949,-290.936462,-307.608612,-276.413422,-199.183502,-48.1772,243.201797,975.922302,"""SC""",2022-09-04 00:00:00,false
216.537109,-447.092773,-164.146973,29.614258,150.112305,216.537109,247.182129,234.11377,318.377441,937.294678,"""SP""",2023-08-13 00:00:00,false
178.131836,-460.749512,-175.565918,-5.862793,114.979492,178.131836,179.61377,162.405762,220.858398,829.817871,"""SP""",2023-08-20 00:00:00,false


In this case there are, so we must adjust our predictions.

### Method 1: Sorting
In this method, we sort the predicted values within each row and reassign them to the quantile levels in ascending order.
Rows that are already monotonic are unchanged; non-monotonic rows become a valid set of quantiles.

In [199]:
def sort_quantiles(row: dict, prepend='s') -> dict:
    quantile_cols =  ["0.1","0.2","0.3","0.4","0.5","0.6","0.7","0.8","0.9"]
    # extract quantile values
    sorted_vals = sorted(row[q] for q in quantile_cols)
    # return a dict mapping back to the same columns
    return {prepend+col: val for col, val in zip(quantile_cols, sorted_vals)}

pred = predictions.with_columns(
    pl.struct(qs).map_elements(sort_quantiles).alias("sorted_struct")
).unnest("sorted_struct")

sqs = ['s'+q for q in qs]
pred = pred.with_columns(
    pl.struct(pred.columns)
        .map_elements(lambda row: check_monotonicity(row,sqs),return_dtype=bool).alias('mono')
)



### Check if the sorting worked 

In [200]:
# Check that each row's sum over the original quantiles matches the sum over the sorted ones
print(len(pred.with_columns(
    pl.sum_horizontal([pl.col(q) for q in qs]).alias('row_sum'),
    pl.sum_horizontal([pl.col('s'+q) for q in qs]).alias('row_sum2'),
).filter(
    np.abs(pl.col('row_sum') - pl.col('row_sum2')) > 0.1
)) == 0)


# Check if all are monotonic
pred.filter(~pl.col('mono'))

True


mean,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,item_id,timestamp,mono,0.25,0.75,0.05,0.95,0.025,0.975,s0.1,s0.2,s0.3,s0.4,s0.5,s0.6,s0.7,s0.8,s0.9
f32,f64,f32,f32,f32,f32,f32,f32,f32,f64,str,datetime[ms],bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64


#### Renaming sorted quantiles to the original.

In [201]:
predictions = pred.drop(qs).rename({f's{q}': q for q in qs})

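AutoGluon computes the quantiles 0.1, 0.2, ..., 0.9, so we need to convert them to the quantiles that bound the intervals required by the sprint
(e.g. 0.25 and 0.75 for the 50% interval). We use simple linear interpolation between neighbouring quantiles and linear extrapolation beyond 0.1 and 0.9.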

In [205]:
for q in np.hstack(quantiles):
    predictions = predictions.with_columns(
        pl.struct(predictions.columns).map_elements(lambda row: estimate_quantile(row, target_quantile=q),return_dtype=float).alias(str(q))
    )

In [207]:
predictions

mean,item_id,timestamp,mono,0.25,0.75,0.05,0.95,0.025,0.975,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9
f32,str,datetime[ms],bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
32.646072,"""AC""",2022-06-26 00:00:00,true,19.80246,44.014046,3.60585,64.790176,1.427971,67.74604,7.961609,16.673126,22.931793,28.160172,32.646072,36.753967,40.973099,47.054993,58.878448
31.444885,"""AC""",2022-07-03 00:00:00,true,18.135551,43.152397,0.296234,65.393326,-2.152435,68.670807,5.193573,14.988251,21.282852,26.381317,31.444885,35.775024,40.576355,45.728439,58.838364
29.168777,"""AC""",2022-07-10 00:00:00,true,15.024055,41.199455,-3.119896,66.7976,-5.603642,70.57188,1.847595,11.782578,18.265533,24.15863,29.168777,34.275879,38.246994,44.151917,59.249039
28.488434,"""AC""",2022-07-17 00:00:00,true,14.088104,39.98996,-2.910408,66.651703,-5.170055,70.648285,1.608887,10.647476,17.528732,23.252213,28.488434,33.140259,37.307709,42.672211,58.658539
27.669281,"""AC""",2022-07-24 00:00:00,true,13.28801,37.143944,-4.84227,63.391685,-7.314262,67.424953,0.101715,9.989685,16.586334,22.762085,27.669281,31.898636,35.09581,39.192078,55.32515
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
27.878174,"""TO""",2023-09-10 00:00:00,true,6.007065,64.427376,-22.681458,163.52887,-26.607674,177.967712,-14.829025,0.875839,11.13829,19.444809,27.878174,37.506454,51.958939,76.895813,134.651184
24.473618,"""TO""",2023-09-17 00:00:00,true,13.824799,32.799904,0.371727,38.278351,-1.440365,38.927322,3.995911,11.244278,16.405319,20.701416,24.473618,27.866333,31.215286,34.384521,36.980408
26.049774,"""TO""",2023-09-24 00:00:00,true,14.813858,33.688805,1.061951,40.015953,-0.781769,40.886822,4.74939,12.124268,17.503448,21.985672,26.049774,29.365356,32.586868,34.790741,38.274216
26.676178,"""TO""",2023-10-01 00:00:00,true,15.061493,35.624428,0.843765,44.597694,-1.075714,45.814091,4.682724,12.360641,17.762344,22.460037,26.676178,30.686005,33.949539,37.299316,42.164902


In [221]:
from lets_plot import *
LetsPlot.setup_html()

i = 3
pred = predictions.drop(['mono','mean','item_id','timestamp'])[i]
pred = pred.unpivot().rename({'variable':'q','value':'y'})
pred = pred.with_columns(
    (~pl.col('q').is_in(qs)).alias('new_q'),
    pl.col('q').cast(pl.Float64)
)

(
    ggplot(data=pred)
    + geom_line(aes(x='q', y='y'), color='blue')
    + geom_point(aes(x='q', y='y',color='new_q'))
)