## Post-Processing

After training and forecasting, we need to do a post-processing in order to guarantee that our output forecast makes sense.
For example, our forecast must never go below zero, as our target is number of dengue cases, which is either zero or larger.
Another point is that our quantiles must be monotonic, i.e. if qunatile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value larger than `20`.

In [1]:
import polars as pl
import numpy as np

### Getting the intervals required by the sprint

In [36]:
intervals = [0.5,0.8,0.9,0.95]

quantiles = [[np.round(0.5 - i/2,decimals=3), np.round(0.5 + i/2,decimals=3)]for i in intervals]
qs = [str(i) for i in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]]

def estimate_quantile(predictions,target_quantile):
    """
    Estimate the value of a given quantile based on the predictions.

    Parameters:
    predictions (DataFrame): A DataFrame containing quantile predictions.
                                     Columns should represent quantile levels (e.g., '0.1', '0.2', ..., '0.9').
    target_quantile (float): The quantile level to estimate (e.g., 0.25, 0.75).

    Returns:
    float: The estimated value for the target quantile, interpolated if necessary.

    """
    quantile_values = np.arange(0.1,1.0, 0.1)
    if target_quantile in quantile_values:
        return predictions[str(target_quantile)]
    if target_quantile < 0.1:
        return predictions['0.1'] - (0.1 - target_quantile)*(predictions['0.2'] - predictions['0.1'])/0.1
    if target_quantile > 0.9:
        return predictions['0.9'] + (target_quantile - 0.9)*(predictions['0.9'] - predictions['0.8'])/0.1

    lower_bound = np.round(max(q for q in quantile_values if q < target_quantile),decimals=2)
    upper_bound = np.round(min(q for q in quantile_values if q > target_quantile),decimals=2)
    lower_values = predictions[str(lower_bound)]
    upper_values = predictions[str(upper_bound)]
    slope = (upper_values - lower_values) / (upper_bound - lower_bound)
    return lower_values + slope * (target_quantile - lower_bound)



Importing our predictions. 

In [48]:
# import predictions
validation1 = pl.read_parquet('../data/4_model_output/validation_sprint_1.parquet').with_columns(
    pl.col("date").cast(pl.Date)
)
validation2 = pl.read_parquet('../data/4_model_output/validation_sprint_2.parquet').with_columns(
    pl.col("date").cast(pl.Date)
)
validation3 = pl.read_parquet('../data/4_model_output/validation_sprint_3.parquet').with_columns(
    pl.col("date").cast(pl.Date)
)

## Checking Monotone Condition

Note that the quantile values must be monotone, i.e. if quantile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value larger.  Depending on how the quantiles are estimated, a forecasting model might give
inconsistent predictions. The following function checks for monotonicity.

In [49]:
def check_monotonicity(row, qs):
    vals = [row[q] for q in qs]
    return np.all(np.diff(vals) >= 0)

def apply_check_monotonicity(predictions, qs=['0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9']):
    predictions = predictions.with_columns(
        pl.struct(predictions.columns)
            .map_elements(lambda row: check_monotonicity(row,qs),return_dtype=bool).alias('mono')
    )
    return predictions

Let us check if there are non monotonic cases. 

In [50]:
validation1 = apply_check_monotonicity(validation1)
validation1.filter(~pl.col('mono'))

uf,date,mean,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,epiweek,target_1,mono
str,date,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64,bool,bool
"""RS""",2022-11-27,-81.416016,-212.614136,-131.288422,-99.96698,-81.089691,-81.416016,-52.639679,28.986839,253.777161,1052.812622,202248,True,False
"""SP""",2022-11-06,613.699463,-138.401367,251.768066,444.932129,571.827881,613.699463,611.859375,688.529785,933.813965,2247.476807,202245,True,False
"""PI""",2023-09-03,60.015114,31.10347,30.667221,29.79837,37.710167,60.015114,105.979538,182.353897,321.660278,616.369019,202336,True,False
"""PI""",2023-09-10,46.822922,17.735245,14.239532,14.124344,22.992615,46.822922,95.299263,177.751862,316.703461,609.36084,202337,True,False


In this case, we do. Therefore, we must adjust our predictions.

### Method 1: Sorting
In this method, we sort the values, and re-attribute the quantiles based on this sorting.
For the already monotonic cases, nothing will change. For non-monotonic, we will get a proper quantile collection.

In [51]:
def sort_quantiles(row: dict, prepend='s') -> dict:
    quantile_cols =  ["0.1","0.2","0.3","0.4","0.5","0.6","0.7","0.8","0.9"]
    # extract quantile values
    sorted_vals = sorted(row[q] for q in quantile_cols)
    # return a dict mapping back to the same columns
    return {prepend+col: val for col, val in zip(quantile_cols, sorted_vals)}


def apply_sort_quantiles(predictions, qs=['0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9']):
    pred = predictions.with_columns(
        pl.struct(qs).map_elements(sort_quantiles).alias("sorted_struct")
    ).unnest("sorted_struct")

    sqs = ['s'+q for q in qs]
    pred = pred.with_columns(
        pl.struct(pred.columns)
            .map_elements(lambda row: check_monotonicity(row,sqs),return_dtype=bool).alias('mono')
    )
    return pred

In [53]:
validation1 = apply_sort_quantiles(validation1)

  pred = predictions.with_columns(


### Check if the sorting worked 

In [58]:
# Check if the sum of the orignal quantiles matches with the sorted
def check_sum_quantiles(validation, qs=['0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9']):
    print(len(validation1.with_columns(
        pl.sum_horizontal([pl.col(q) for q in qs]).alias('row_sum'),
        pl.sum_horizontal([pl.col('s'+q) for q in qs]).alias('row_sum2'),
    ).filter(
        np.abs(pl.col('row_sum') - pl.col('row_sum2')) > 0.1
    )) == 0)


    # Check if all are monotonic
    return validation.filter(~pl.col('mono'))

check_sum_quantiles(validation1)

True


uf,date,mean,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,epiweek,target_1,mono,s0.1,s0.2,s0.3,s0.4,s0.5,s0.6,s0.7,s0.8,s0.9
str,date,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64


#### Renaming sorted quantiles to the original.

In [59]:
validation1 = validation1.drop(qs).rename({f's{q}': q for q in qs})

Autogluon computes quantiles 0.1,0.2,...,0.9, so we need to convert them to the intervals required by the sprint.
We use simple linear interpolation/extrapolation.

In [60]:
for q in np.hstack(quantiles):
    validation1 = validation1.with_columns(
        pl.struct(validation1.columns).map_elements(lambda row: estimate_quantile(row, target_quantile=q),return_dtype=float).alias(str(q))
    )

In [62]:
validation1

uf,date,mean,epiweek,target_1,mono,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.25,0.75,0.05,0.95,0.025,0.975
str,date,f32,i64,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""RJ""",2022-10-09,20.754639,202241,true,true,-44.537598,-19.891479,-3.844604,8.201904,20.754639,39.189087,61.92627,88.855957,151.351807,-11.868042,75.391113,-56.860657,182.599731,-63.022186,198.223694
"""RJ""",2022-10-16,45.361938,202242,true,true,-40.073486,-12.325562,8.807251,27.176025,45.361938,64.186646,86.517944,114.416382,184.415283,-1.759155,100.467163,-53.947449,219.414734,-60.88443,236.914459
"""RJ""",2022-10-23,42.41626,202243,true,true,-43.28186,-15.07373,8.538452,26.965576,42.41626,64.560181,91.47937,124.641479,205.903564,-3.267639,108.060425,-57.385925,246.534607,-64.437958,266.850128
"""RJ""",2022-10-30,43.486084,202244,true,true,-45.106689,-18.593384,5.735229,25.024414,43.486084,67.587402,96.228638,135.265747,220.12915,-6.429077,115.747192,-58.363342,262.560852,-64.991669,283.776703
"""RJ""",2022-11-06,52.406006,202245,true,true,-43.868286,-13.547241,8.329102,31.474609,52.406006,74.221436,107.44519,149.462402,257.632568,-2.60907,128.453796,-59.028809,311.717651,-66.60907,338.760193
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""SE""",2023-09-03,45.724743,202336,true,true,5.682419,15.928577,22.071175,30.505674,45.724743,70.900909,107.585052,166.706589,270.222809,18.999876,137.145821,0.55934,321.980919,-2.0022,347.859974
"""SE""",2023-09-10,42.899181,202337,true,true,4.129044,12.537964,18.511459,26.938179,42.899181,69.379967,108.192032,167.329208,273.653961,15.524712,137.76062,-0.075417,326.816338,-2.177647,353.397526
"""SE""",2023-09-17,37.646095,202338,true,true,5.043617,10.878052,16.421257,22.115238,37.646095,64.161903,102.333832,160.041809,250.794937,13.649654,131.18782,2.1264,296.171501,0.667791,318.859783
"""SE""",2023-09-24,35.854889,202339,true,true,4.977779,9.557594,15.631424,22.161182,35.854889,62.490421,100.214928,155.457962,232.420456,12.594509,127.836445,2.687872,270.901703,1.542918,290.142326


In [71]:
from lets_plot import *
LetsPlot.setup_html()

i = 1
pred = validation1.drop(['mono','mean','uf','date','epiweek','target_1'])[i]
pred = pred.unpivot().rename({'variable':'q','value':'y'})
pred = pred.with_columns(
    (~pl.col('q').is_in(qs)).alias('new_q'),
    pl.col('q').cast(pl.Float64)
)

(
    ggplot(data=pred)
    + geom_line(aes(x='q', y='y'), color='blue')
    + geom_point(aes(x='q', y='y',color='new_q'))
)

### Cases > 0

The model needs to predict always cases > 0. If the model predicts a value < 0, we need to set it to 0.

In [72]:
qcols = validation1.columns[4:]
validation1=validation1.with_columns([
    pl.when(pl.col(q) < 0).then(0).otherwise(pl.col(q)).alias(q)
    for q in qcols
])

## Adjusting Format for Submission

In [21]:
# lower_95	lower_90	lower_80	lower_50	pred	upper_50	upper_80	upper_90	upper_95	date
# [2.5, 5, 10, 25, 50, 75, 90, 95, 97.5]
columns_submission = ['lower_95','lower_90','lower_80','lower_50','pred','upper_50','upper_80','upper_90','upper_95','date']
submission = predictions.rename(
    {
        'timestamp':'date',
        '0.025':'lower_95',
        '0.05' :'lower_90',
        '0.1'  :'lower_80',
        '0.25' :'lower_50',
        '0.5':'pred',
        '0.75' :'upper_50',
        '0.9'  :'upper_80',
        '0.95' :'upper_90',
        '0.975':'upper_95',
    })[columns_submission]
submission = submission.with_columns(
    pl.col('date').dt.date()
)

In [22]:
submission

lower_95,lower_90,lower_80,lower_50,pred,upper_50,upper_80,upper_90,upper_95,date
f64,f64,f64,f64,f64,f64,f64,f64,f64,date
1.427971,3.60585,7.961609,19.80246,32.646072,44.014046,58.878448,64.790176,67.74604,2022-06-26
0.0,0.296234,5.193573,18.135551,31.444885,43.152397,58.838364,65.393326,68.670807,2022-07-03
0.0,0.0,1.847595,15.024055,29.168777,41.199455,59.249039,66.7976,70.57188,2022-07-10
0.0,0.0,1.608887,14.088104,28.488434,39.98996,58.658539,66.651703,70.648285,2022-07-17
0.0,0.0,0.101715,13.28801,27.669281,37.143944,55.32515,63.391685,67.424953,2022-07-24
…,…,…,…,…,…,…,…,…,…
0.0,0.0,0.0,6.007065,27.878174,64.427376,134.651184,163.52887,177.967712,2023-09-10
0.0,0.371727,3.995911,13.824799,24.473618,32.799904,36.980408,38.278351,38.927322,2023-09-17
0.0,1.061951,4.74939,14.813858,26.049774,33.688805,38.274216,40.015953,40.886822,2023-09-24
0.0,0.843765,4.682724,15.061493,26.676178,35.624428,42.164902,44.597694,45.814091,2023-10-01
