## Post-Processing

After training and forecasting, we need a post-processing step to guarantee that the output forecast makes sense.
For example, the forecast must never go below zero, since our target is the number of dengue cases, which is always zero or larger.
In addition, the quantiles must be monotonic: if quantile `0.5` forecasts a value of `20`, then quantile `0.6` must forecast
a value greater than or equal to `20`.

In [1]:
import polars as pl
import numpy as np

### Getting the intervals required by the sprint

In [2]:
intervals = [0.5, 0.8, 0.9, 0.95]

# A central interval of width w is bounded by the quantiles 0.5 - w/2 and
# 0.5 + w/2 (e.g. 0.9 -> [0.05, 0.95]); rounding removes float noise.
quantiles = [
    [np.round(0.5 + sign * width / 2, decimals=3) for sign in (-1, 1)]
    for width in intervals
]

# Column names of the deciles produced by the model.
qs = [str(level) for level in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)]

def estimate_quantile(predictions, target_quantile):
    """
    Estimate the value of an arbitrary quantile from the nine predicted deciles.

    Parameters:
    predictions (Mapping or DataFrame): Anything indexable by the column names
        '0.1', '0.2', ..., '0.9' whose values support arithmetic, e.g. a row
        dict (as produced by ``pl.struct(...).map_elements``) or a DataFrame.
    target_quantile (float): The quantile level to estimate (e.g. 0.025, 0.25).

    Returns:
    The estimated value(s) for ``target_quantile``. On an exact match this is
    the stored decile. Inside [0.1, 0.9] it is a linear interpolation between
    the neighbouring deciles. Below 0.1 or above 0.9 it is a linear
    extrapolation using the outermost pair of deciles.
    """
    # Decile levels; round() strips float noise (0.1 * 3 == 0.30000000000000004)
    # so that str(level) yields exactly the column names '0.1' ... '0.9'.
    levels = [round(0.1 * k, 1) for k in range(1, 10)]
    target = float(target_quantile)

    # Exact hit, up to float tolerance: look the column up by its canonical key,
    # not by str(target), which may carry float noise and miss the column.
    for level in levels:
        if np.isclose(target, level):
            return predictions[str(level)]

    if target < levels[0]:
        # Extrapolate below 0.1 with the slope between deciles 0.1 and 0.2.
        return predictions['0.1'] - (0.1 - target)*(predictions['0.2'] - predictions['0.1'])/0.1
    if target > levels[-1]:
        # Extrapolate above 0.9 with the slope between deciles 0.8 and 0.9.
        return predictions['0.9'] + (target - 0.9)*(predictions['0.9'] - predictions['0.8'])/0.1

    # Strictly inside (0.1, 0.9) and not on a decile: both neighbours exist.
    lower_bound = max(level for level in levels if level < target)
    upper_bound = min(level for level in levels if level > target)
    lower_values = predictions[str(lower_bound)]
    upper_values = predictions[str(upper_bound)]
    slope = (upper_values - lower_values) / (upper_bound - lower_bound)
    return lower_values + slope * (target - lower_bound)



In [None]:
# Row-wise decile sorting.
def sort_quantiles(row: dict, prepend='s') -> dict:
    """
    Return the row's decile predictions in ascending order.

    The values stored under '0.1' ... '0.9' are sorted and reassigned to those
    levels in order, under keys prefixed with ``prepend`` (e.g. 's0.1').
    Any other keys in ``row`` are ignored.
    """
    levels = ("0.1", "0.2", "0.3", "0.4", "0.5", "0.6", "0.7", "0.8", "0.9")
    ascending = [row[level] for level in levels]
    ascending.sort()
    return dict(zip((prepend + level for level in levels), ascending))


def apply_sort_quantiles(predictions, qs=None):
    """
    Sort each row's quantile predictions so they are non-decreasing.

    The sorted values are appended as new columns named ``'s' + q`` (e.g.
    ``'s0.1'``), leaving the original columns untouched. A boolean ``'mono'``
    column is added, True when the sorted row is non-decreasing. Because the
    check uses ``np.diff``, any NaN in the row makes it False.

    Parameters:
    predictions (pl.DataFrame): Frame holding one column per entry of ``qs``.
    qs (list[str] | None): Quantile column names in ascending level order.
        Defaults to '0.1' ... '0.9'.

    Returns:
    pl.DataFrame: ``predictions`` plus the ``'s' + q`` columns and ``'mono'``.
    """
    if qs is None:
        qs = ['0.1', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9']
    sqs = ['s' + q for q in qs]

    # Sort the requested columns row-wise in one vectorized pass, then spread
    # the sorted list back out into one column per level.
    sorted_row = pl.concat_list([pl.col(q) for q in qs]).list.sort()
    pred = predictions.with_columns(
        [sorted_row.list.get(i).alias(sq) for i, sq in enumerate(sqs)]
    )

    # The monotonicity test is inlined rather than delegated to
    # check_monotonicity, which is defined in a later cell; this keeps the
    # notebook runnable top to bottom.
    pred = pred.with_columns(
        pl.struct(sqs)
        .map_elements(
            lambda row: bool(np.all(np.diff([row[sq] for sq in sqs]) >= 0)),
            return_dtype=pl.Boolean,
        )
        .alias('mono')
    )
    return pred


# Check that sorting only reordered the quantiles (row sums are unchanged).
def check_sum_quantiles(validation, qs=None, tol=0.1):
    """
    Sanity-check the output of ``apply_sort_quantiles``.

    Prints True when, for every row, the original quantile columns and their
    sorted ``'s' + q`` counterparts sum to the same value within ``tol``.
    Sorting must only reorder the values, never change them.

    Parameters:
    validation (pl.DataFrame): Frame with columns ``q``, ``'s' + q`` for each
        ``q`` in ``qs``, and a boolean ``'mono'`` column.
    qs (list[str] | None): Quantile column names. Defaults to '0.1' ... '0.9'.
    tol (float): Largest tolerated absolute difference between the two sums.

    Returns:
    pl.DataFrame: The rows whose 'mono' flag is False (empty when every row
    is monotonic).
    """
    if qs is None:
        qs = ['0.1', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9']

    mismatched = validation.with_columns(
        pl.sum_horizontal([pl.col(q) for q in qs]).alias('row_sum'),
        pl.sum_horizontal([pl.col('s' + q) for q in qs]).alias('row_sum2'),
    ).filter(
        (pl.col('row_sum') - pl.col('row_sum2')).abs() > tol
    )
    print(mismatched.is_empty())

    # Rows that are still not monotonic after sorting (expected to be empty).
    return validation.filter(~pl.col('mono'))



In [4]:
def post_process_validation(validation):
    """
    Turn raw decile predictions into a submission-format DataFrame.

    The steps are:
    1. Sort each row's deciles so they are monotonic.
    2. Derive every interval bound listed in ``quantiles`` from the deciles.
    3. Clip negative values to 0.
    4. Rename and select the submission columns.

    Parameters:
    validation (pl.DataFrame): Model output containing a 'date' column and
        one column per decile in ``qs`` ('0.1' ... '0.9').

    Returns:
    pl.DataFrame: Columns lower_95, lower_90, lower_80, lower_50, pred,
    upper_50, upper_80, upper_90, upper_95, date (in that order).
    """
    validation = apply_sort_quantiles(validation)
    # Replace the raw deciles with their sorted ('s'-prefixed) counterparts.
    validation = validation.drop(qs).rename({f's{q}': q for q in qs})

    # Compute each interval bound. estimate_quantile only looks columns up by
    # name and does arithmetic on them, so it is applied to the whole frame at
    # once (column-wise) instead of row by row through map_elements.
    bound_cols = []
    for q in np.hstack(quantiles):
        col = str(q)
        bound_cols.append(col)
        validation = validation.with_columns(
            estimate_quantile(validation, target_quantile=q).cast(pl.Float64).alias(col)
        )

    # Case counts cannot be negative, so clip every quantile column at 0.
    # The columns are named explicitly instead of sliced as columns[4:],
    # which depended on the number of leading non-quantile columns and also
    # hit the boolean 'mono' flag. The list is deduplicated because '0.1' and
    # '0.9' appear in both qs and bound_cols.
    qcols = list(dict.fromkeys([*qs, *bound_cols]))
    validation = validation.with_columns([
        pl.when(pl.col(q) < 0).then(0).otherwise(pl.col(q)).alias(q)
        for q in qcols
    ])

    # rename columns to match the submission format
    # lower_95	lower_90	lower_80	lower_50	pred	upper_50	upper_80	upper_90	upper_95	date
    # [2.5, 5, 10, 25, 50, 75, 90, 95, 97.5]
    columns_submission = ['lower_95','lower_90','lower_80','lower_50','pred','upper_50','upper_80','upper_90','upper_95','date']
    validation = validation.rename(
        {
            '0.025':'lower_95',
            '0.05' :'lower_90',
            '0.1'  :'lower_80',
            '0.25' :'lower_50',
            '0.5':'pred',
            '0.75' :'upper_50',
            '0.9'  :'upper_80',
            '0.95' :'upper_90',
            '0.975':'upper_95',
        })[columns_submission]
    return validation

In [5]:
# Load each sprint's validation-period model output and make sure 'date' is
# a proper Date column.
validation1, validation2, validation3 = [
    pl.read_parquet(path).with_columns(pl.col("date").cast(pl.Date))
    for path in (
        '../data/4_model_output/validation_sprint_1.parquet',
        '../data/4_model_output/validation_sprint_2.parquet',
        '../data/4_model_output/validation_sprint_3.parquet',
    )
]

In [6]:
# Post-process every sprint's predictions into submission format.
submission1, submission2, submission3 = [
    post_process_validation(frame)
    for frame in (validation1, validation2, validation3)
]

  pred = predictions.with_columns(
  pred = predictions.with_columns(
  pred = predictions.with_columns(


In [28]:
def check_monotonicity(row, qs):
    """
    Return whether the values of ``row`` at keys ``qs`` are non-decreasing.

    Consecutive differences are computed with ``np.diff``, so a NaN among the
    values makes the result False. The result is a NumPy boolean.
    """
    steps = np.diff([row[key] for key in qs])
    return (steps >= 0).all()

def apply_check_monotonicity(predictions, qs=['0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9']):
    """
    Return True when every row of ``predictions`` is non-decreasing over ``qs``.

    Each row is tested with ``check_monotonicity``; the result is True only if
    no row is flagged as non-monotonic.
    """
    mono_flag = (
        pl.struct(predictions.columns)
        .map_elements(lambda row: check_monotonicity(row, qs), return_dtype=bool)
        .alias('mono')
    )
    flagged = predictions.with_columns(mono_flag)
    offending = flagged.filter(~pl.col('mono'))
    return offending.shape[0] == 0


def apply_check_nonnegative(submission, qs=['lower_95','lower_90','lower_80','lower_50','pred','upper_50','upper_80','upper_90','upper_95']):
    """
    For each column in ``qs``, report whether it contains no negative values.

    Returns a list of booleans aligned with ``qs``.
    """
    results = []
    for col in qs:
        negative_rows = submission.filter(pl.col(col) < 0)
        results.append(negative_rows.shape[0] == 0)
    return results

# Final sanity checks on every submission. First, each row's quantile columns
# must be monotonic; then, no column may contain negative case counts.
_submission_quantiles = ['lower_95','lower_90','lower_80','lower_50','pred','upper_50','upper_80','upper_90','upper_95']
for _check in (apply_check_monotonicity, apply_check_nonnegative):
    for _submission in (submission1, submission2, submission3):
        print(_check(_submission, qs=_submission_quantiles))

True
True
True
[True, True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True, True]
[True, True, True, True, True, True, True, True, True]
