# Inference

In [20]:
from pathlib import Path
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "xgboost_inference.ipynb"  # Manually set the notebook name

import polars as pl
import xgboost as xgb
import wandb
from tqdm.notebook import tqdm
import pickle

import utils

In [21]:
DEBUG = True

In [22]:
# Choose training run from which to load the model, etc.
train_run_name = 'prime-wind-87'
run_path = f'esedx12/traffic-forecasting-challenge/{train_run_name}'
api = wandb.Api()
train_run = api.runs(
    path="esedx12/traffic-forecasting-challenge",
    filters={"display_name": {"$eq": train_run_name}} 
)[0]
train_config = train_run.config

In [23]:
inference_config = {
    # 'prediction_start': train_config['train_shape'][0] + 1,
    'prediction_length': 1900 - train_config['num_train_rows'] if not DEBUG else 5,
}

In [24]:
run = wandb.init(project="traffic-forecasting-challenge", tags=[train_run_name], job_type='inference',
                 entity="esedx12", config=inference_config, save_code=True, mode=('dryrun' if DEBUG else 'online'))

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [25]:
checkpoints_dir = 'checkpoints'
xgboost_models_dir = Path(checkpoints_dir) / train_run_name

models = {}
for file_name in os.listdir(xgboost_models_dir):
    if file_name.endswith('.ubj'):
        target_name = file_name[:-4]
        model_path = xgboost_models_dir / file_name
        models[target_name] = pickle.load(open(model_path, 'rb'))

## Load and prepare data

In [26]:
# %%
# Read the CSV files
data_dir = Path('input-data')
target_dataframes = {
    'thp_vol': pl.read_csv(data_dir / 'traffic_DLThpVol.csv'),  # This is the target variable
    'prb': pl.read_csv(data_dir / 'traffic_DLPRB.csv'),
    'thp_time': pl.read_csv(data_dir / 'traffic_DLThpTime.csv'),
    'mr_number': pl.read_csv(data_dir / 'traffic_MR_number.csv')
}

# Filter target dataframes based on train_config
target_dataframes = {k: v for k, v in target_dataframes.items() if k in train_config['target_df_names']}

idx_hour_series = target_dataframes['thp_vol']['']

# Drop the first column (idx hour) from each dataframe
for k in target_dataframes:
    target_dataframes[k] = target_dataframes[k].rename({'': 'idx_hour'})

# A long format beam_id column to be used for converting to wide format
beam_id_col = utils.convert_to_long_format({'beam_id': pl.DataFrame({beam_id: [beam_id] * len(target_dataframes['thp_vol']) for beam_id in target_dataframes['thp_vol'].columns})})

In [51]:
num_rows = len(target_dataframes['thp_vol'])
num_train_rows = round(num_rows * train_config['train_percentage'])
# num_val_rows = round(num_rows * train_config['val_percentage'])

# Split data into train and test
input_dataframes = {k: v.drop('idx_hour').head(num_train_rows) for k, v in target_dataframes.items()}
input_idx_hour_series = idx_hour_series.head(num_train_rows)

comparison_dataframes = {k: v.slice(num_train_rows, inference_config['prediction_length']) for k, v in target_dataframes.items()}
# TODO add different df sets form idx of validation and holdout test

## Multi-Step Inference

In [52]:
def predict_one_step(target_dataframes: dict[pl.DataFrame], idx_hour_series: pl.Series ,models: xgb.Booster, train_config: wandb.Config) -> dict[pl.DataFrame]:
    """
    Predict one step into the future using a trained model.
    Takes DataFrames of len n, returns DataFrames of len n + 1.
    """
    template_df = target_dataframes['thp_vol']
    predict_hour = idx_hour_series[-1] + 1

    null_row = pl.DataFrame({beam_id: [None] for beam_id in template_df.columns})
    target_dataframes = {k: pl.concat([v, null_row], how='vertical_relaxed') for k, v in target_dataframes.items()}

    target_names = list(target_dataframes.keys())
    feature_dfs = utils.create_all_feature_dfs(target_dataframes, idx_hour_series, train_config)
    feature_dfs = {k: v.tail(1) for k, v in feature_dfs.items()}  # maybe turn in to lazyframe for efficiency?
    X_predict = utils.convert_to_long_format(feature_dfs)

    # We predict only the idx immediately folling the last idx in the input, ie a single row
    ys_predicted_long = pl.DataFrame()
    for target_name, model in models.items():
        y_predicted = model.predict(X_predict.to_numpy())
        ys_predicted_long = pl.concat([ys_predicted_long, pl.DataFrame({target_name: y_predicted})], how='horizontal')

    # We need these long-format columns to convert the predictions to wide format
    util_dfs = {}
    util_dfs['beam_id'] = pl.DataFrame({beam_id: [beam_id] for beam_id in template_df.columns})
    util_dfs['idx_hour'] = pl.DataFrame({beam_id: [predict_hour] for beam_id in template_df.columns})
    util_long_df = utils.convert_to_long_format(util_dfs)
    ys_predicted_long = pl.concat([util_long_df, ys_predicted_long], how='horizontal')

    y_predicted_wide = utils.convert_to_wide_format(ys_predicted_long, output_df_names=target_names)    

    return (
        {target_name: pl.concat([target_dataframes[target_name].head(-1), y_predicted_wide[target_name]], how='vertical_relaxed') for target_name in target_names},
        idx_hour_series.append(pl.Series([predict_hour]))
        )

In [53]:
def predict_multi_step(target_dataframes: dict[pl.DataFrame], idx_hour_series: pl.Series, models: xgb.Booster, train_config: wandb.Config, num_steps: int, max_lag=None) -> dict[pl.DataFrame]:
    """
    Predict multiple steps into the future using a trained model.
    Takes DataFrames of len n, returns DataFrames of len n + num_steps.
    
    Args:
        target_dataframes (dict): A dictionary of DataFrames representing the target data.
        idx_hour_series (Series): Index hours CORRESPONDING to target_dataframes.

    Returns:
        dict: A dictionary of DataFrames representing the predicted target dataframes.
    """
    if max_lag:
        target_dataframes = {k: v.tail(max_lag + 5) for k, v in target_dataframes.items()}
        idx_hour_series = idx_hour_series.tail(max_lag + 5)

    for _ in tqdm(range(num_steps), desc='Predicting steps...'):
        target_dataframes, idx_hour_series = predict_one_step(target_dataframes, idx_hour_series, models, train_config)

    return {k: pl.concat([pl.DataFrame({'idx_hour': idx_hour_series}), v], how='horizontal') for k, v in target_dataframes.items()}

In [54]:
ys_pred = predict_multi_step(input_dataframes, input_idx_hour_series, models, train_config=train_config, num_steps=inference_config['prediction_length'])

Predicting steps...:   0%|          | 0/10 [00:00<?, ?it/s]

In [55]:
def mean_absolute_error(Y_true: pl.DataFrame, Y_pred: pl.DataFrame) -> float:
    """
    Compute the mean absolute error between two DataFrames.
    """
    # TODO some kind of check here even though idx_hour is no longer normally part of dfs
    assert (Y_true['idx_hour'] == Y_pred['idx_hour']).all(), "DataFrames must be aligned"
    # assert Y_true.shape == Y_pred.shape, "DataFrames must have the same shape"

    return (Y_true - Y_pred).select(pl.all().abs().mean()).mean_horizontal()[0]

In [56]:
comparison_dataframes['thp_vol'] 
ys_pred['thp_vol'].slice(503, 515).head(15)

idx_hour,0_0_0,0_0_1,0_0_2,0_0_3,0_0_4,0_0_5,0_0_6,0_0_7,0_0_8,0_0_9,0_0_10,0_0_11,0_0_12,0_0_13,0_0_14,0_0_15,0_0_16,0_0_17,0_0_18,0_0_19,0_0_20,0_0_21,0_0_22,0_0_23,0_0_24,0_0_25,0_0_26,0_0_27,0_0_28,0_0_29,0_0_30,0_0_31,0_1_0,0_1_1,0_1_2,0_1_3,…,29_1_27,29_1_28,29_1_29,29_1_30,29_1_31,29_2_0,29_2_1,29_2_2,29_2_3,29_2_4,29_2_5,29_2_6,29_2_7,29_2_8,29_2_9,29_2_10,29_2_11,29_2_12,29_2_13,29_2_14,29_2_15,29_2_16,29_2_17,29_2_18,29_2_19,29_2_20,29_2_21,29_2_22,29_2_23,29_2_24,29_2_25,29_2_26,29_2_27,29_2_28,29_2_29,29_2_30,29_2_31
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
503,0.342484,0.17033,0.21847,0.0,0.321722,0.166766,0.0,0.0,1.003408,0.0,0.204035,0.981364,0.132587,0.0,0.0,0.162998,0.066661,0.078624,0.55093,0.300764,0.101321,0.07419,0.195417,0.12561,0.0,0.133651,0.186158,0.000732,0.236092,0.0,0.19081,0.0,0.0,0.0,0.002198,0.0,…,0.203027,1.473573,0.689921,0.127269,0.0,0.0,0.074132,0.0,0.003583,0.0,0.191806,0.144702,0.078267,0.0,0.146916,0.0,0.0,0.0,1.344579,0.269339,0.0875,0.018413,0.095665,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.128113,0.0,0.081418,0.0,0.156009,0.0,0.00809
504,0.431801,0.144597,0.126805,0.09827,0.477011,0.445733,0.112457,0.102793,0.667032,0.160219,0.202257,0.398203,0.353289,0.22105,0.150519,0.144597,0.111847,0.276468,0.60596,0.220566,0.17084,0.155375,0.154788,0.225817,0.156116,0.130746,0.202349,0.195328,0.122004,0.174934,0.104827,0.098666,0.101186,0.230874,0.164885,0.1156,…,0.211631,0.958214,0.534678,0.236536,0.104158,0.117715,0.150519,0.136829,0.097634,0.097634,0.163796,0.262018,0.21509,0.097634,0.313692,0.162413,0.130746,0.134151,0.504932,0.144472,0.142101,0.097634,0.097634,0.1156,0.228833,0.097634,0.099904,0.111964,0.097634,0.207696,0.105032,0.101005,0.097634,0.097634,0.106472,0.1379,0.106472
505,0.285405,0.117619,0.106743,0.09675,0.385073,0.268494,0.100035,0.098191,0.527826,0.131386,0.157878,0.199236,0.253426,0.156706,0.120937,0.117619,0.098195,0.216932,0.527099,0.180957,0.142642,0.1215,0.1215,0.16468,0.102616,0.112139,0.136031,0.119249,0.100668,0.124636,0.097547,0.09675,0.09675,0.205148,0.123926,0.102616,…,0.192335,0.635448,0.367571,0.158032,0.097989,0.101572,0.122175,0.114206,0.09675,0.09675,0.121343,0.184391,0.131029,0.09675,0.215374,0.131265,0.110901,0.110901,0.283154,0.118612,0.110901,0.09675,0.101186,0.102616,0.143473,0.09675,0.105911,0.100019,0.09675,0.127461,0.097547,0.10081,0.097782,0.09675,0.098987,0.114878,0.098987
506,0.255296,0.117868,0.106107,0.09675,0.418617,0.271414,0.100035,0.09675,0.392848,0.127516,0.128877,0.152732,0.226745,0.144703,0.120937,0.113648,0.09675,0.138903,0.753619,0.151232,0.13054,0.1215,0.1215,0.124786,0.105354,0.108168,0.123468,0.111474,0.105354,0.153874,0.097547,0.09675,0.09675,0.200592,0.126393,0.11266,…,0.179396,0.481833,0.26873,0.114085,0.09675,0.098987,0.122175,0.111474,0.09675,0.099489,0.108168,0.133287,0.116251,0.097989,0.241005,0.126676,0.10693,0.108168,0.183876,0.110235,0.10693,0.09675,0.09675,0.106032,0.159375,0.09675,0.103647,0.100226,0.09675,0.104977,0.100131,0.09675,0.099489,0.097989,0.098987,0.108168,0.101642
507,0.180511,0.114887,0.107982,0.09675,0.455054,0.525483,0.100684,0.100603,0.295442,0.12535,0.172612,0.127668,0.219926,0.141865,0.122175,0.114887,0.098625,0.182388,0.633682,0.175223,0.122175,0.122739,0.126475,0.193516,0.103854,0.108168,0.123337,0.111474,0.103854,0.15166,0.098183,0.099489,0.09675,0.18907,0.11929,0.137654,…,0.240767,0.398367,0.337502,0.114085,0.09675,0.101211,0.124409,0.111474,0.09675,0.100188,0.108168,0.130601,0.114085,0.099489,0.282774,0.122739,0.108168,0.114494,0.163784,0.111474,0.108168,0.09675,0.09675,0.106032,0.130792,0.09675,0.097989,0.100226,0.09675,0.110153,0.116398,0.097989,0.09675,0.09675,0.098987,0.110907,0.100925
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
509,0.137196,0.114887,0.10883,0.09675,0.826051,0.441458,0.101274,0.097387,0.265679,0.132303,0.221306,0.121503,0.220666,0.137953,0.122175,0.123048,0.098195,0.137433,0.482534,0.196176,0.122739,0.128654,0.136936,0.24584,0.103854,0.108168,0.119987,0.111474,0.1017,0.134533,0.097547,0.09675,0.09675,0.189227,0.142305,0.13661,…,0.130037,0.436334,0.186077,0.122985,0.09675,0.100226,0.122175,0.111474,0.09675,0.09675,0.108168,0.131388,0.12544,0.09675,0.189991,0.122739,0.108168,0.110907,0.124786,0.111474,0.108168,0.09675,0.09675,0.106032,0.130131,0.09675,0.097989,0.100226,0.097449,0.106215,0.10137,0.09675,0.09675,0.09675,0.101726,0.108168,0.104661
510,0.133431,0.122843,0.108347,0.09675,0.834369,0.286684,0.10594,0.097782,0.271414,0.189243,0.216313,0.125026,0.370032,0.204663,0.173593,0.187552,0.098195,0.279057,0.361392,0.150618,0.13054,0.159512,0.253944,0.290509,0.103854,0.121574,0.162499,0.111474,0.162641,0.126688,0.097547,0.09675,0.09675,0.237063,0.143172,0.104942,…,0.14237,0.440314,0.143389,0.111474,0.09675,0.101895,0.158948,0.111474,0.09675,0.09675,0.112106,0.366592,0.125534,0.097782,0.351812,0.149215,0.114205,0.108168,0.172166,0.117562,0.115805,0.09675,0.097387,0.106669,0.135581,0.09675,0.09675,0.100226,0.101186,0.218074,0.102068,0.096957,0.09675,0.09675,0.098987,0.108168,0.108148
511,0.296281,0.166898,0.116655,0.100942,1.025458,0.309343,0.126335,0.100942,0.42463,0.267416,0.187796,0.145368,0.38005,0.334134,0.595489,0.247853,0.110795,0.462407,0.64951,0.37155,0.724168,0.709376,0.502902,0.286322,0.112268,0.125461,0.358006,0.141934,0.119472,0.141989,0.103241,0.100942,0.100942,0.183699,0.134972,0.107639,…,0.166531,0.383427,0.245553,0.136896,0.100942,0.109226,0.341871,0.15335,0.100942,0.100942,0.126153,0.28516,0.161022,0.102181,0.706467,0.389007,0.131936,0.15141,0.188654,0.159685,0.141715,0.100942,0.100942,0.125662,0.173479,0.100942,0.102651,0.107212,0.100942,0.125491,0.11314,0.103868,0.100942,0.109737,0.104681,0.125461,0.112167
512,0.38991,0.561767,0.138311,0.103117,0.834946,0.366681,0.137325,0.101826,0.548946,1.451273,0.398294,0.182332,0.357878,0.353254,0.436824,0.263285,0.11893,0.772944,0.889756,0.356399,0.543026,0.651263,0.593158,0.450362,0.126889,0.158606,0.215681,0.157473,0.119096,0.16225,0.110841,0.102462,0.106685,0.190096,0.147521,0.114888,…,0.266183,0.407418,0.262767,0.158765,0.101826,0.116388,0.322659,0.173077,0.101826,0.106663,0.162814,0.346573,0.180093,0.122348,0.878663,0.411464,0.161097,0.196325,0.209403,0.174103,0.170256,0.103117,0.101826,0.139831,0.243756,0.101826,0.107938,0.112688,0.101826,0.142111,0.113379,0.106435,0.101826,0.101859,0.10873,0.158606,0.109968


In [58]:
mean_absolute_error(comparison_dataframes['thp_vol'], ys_pred['thp_vol'].tail(inference_config['prediction_length']))

0.2050350198807903

## ...on Validation Set

## ...on Test Set

## ...on Validation and Test Sets

## Create Submission CSV

* Hours in 5 weeks: 840
* Hours in 6 weeks: 1008
* We need period 841-1008 (841:1009 with Python list indexing)

* Hours in 10 weeks: 1680
* Hours in 11 weeks: 1848

In [103]:
def create_half_submission_df(input_df: pl.DataFrame, weeks: str) -> pl.DataFrame:
    """
    Create a submission CSV file from a Polars DataFrame of thp_vol.
    """
    if weeks == '5w-6w':
        range = [841, 1008]
    elif weeks == '10w-11w':
        range = [1681, 1848]

    # Choose rows with first column 'idx_hour' having the values 671-840.
    input_df = input_df.filter(pl.col('idx_hour') >= range[0], pl.col('idx_hour') <= range[1])

    # Some checks on the input_df
    assert input_df.shape == (168, 2881), f"Expected shape (168, 2881), got {input_df.shape}"
    assert input_df.select(pl.any_horizontal(pl.all().is_null().any())).item() == False, "Submission dataframe contains null values"
    assert input_df['idx_hour'].head(1)[0] <= range[0] and input_df['idx_hour'].tail(1)[0] >= range[1], "Submission dataframe does seemingly not contain the correct idx_hour values"

    # Stack the dataframe with f'traffic_DLThpVol_test_5w-6w_{hour}_{beam_id}' as index
    # where it cycles through the values 671-840 for hour and then the beam_ids, which are colnames of input_df
    # return input_df.unpivot(index='idx_hour')
    return input_df.unpivot(index='idx_hour', variable_name='beam_id').with_columns(
        pl.concat_str([pl.lit('traffic_DLThpVol_test'), pl.lit(weeks), pl.col('idx_hour') - range[0], pl.col('beam_id')], separator='_').alias('ID')
    ).select(['ID', 'value']).rename({'value': 'Target'})


def create_submission_csv(input_df: pl.DataFrame, output_filename='traffic_forecast.csv', archiving_dir='submission-csvs-archive') -> pl.DataFrame:
    """
    Create a submission CSV file from data in input format that's been extended to cover weeks 5-6 and 10-11.
    """

    # Create half submission dataframes
    half_submission_5w_6w = create_half_submission_df(input_df, '5w-6w')
    half_submission_10w_11w = create_half_submission_df(input_df, '10w-11w')

    # Concatenate the two half submission dataframes
    submission_df = pl.concat([half_submission_5w_6w, half_submission_10w_11w], how='vertical')

    # Save the submission dataframe to a CSV file for submission, and to wandb
    submission_df.write_csv(output_filename)
    wandb.save(output_filename)

    # Save the submission dataframe to a CSV file for archiving
    if archiving_dir:
        archiving_dir = Path(archiving_dir)
        archiving_dir.mkdir(parents=True, exist_ok=True)
        submission_df.write_csv(archiving_dir / f'{wandb.run.name}_{output_filename}')

    return submission_df

In [70]:
if inference_config['create_submission_csv']:
    submission_df = create_submission_csv(ys_pred['thp_vol'])

In [104]:
# debug_submission_df_5w_6w = pl.DataFrame(
#     {'idx_hour': pl.Series(range(1, 1901))} | {id: pl.Series(range(1, 1901)) for id in ys_pred['thp_vol'].columns})
# debug_filtered = debug_submission_df_5w_6w.filter(pl.col('idx_hour') >= 841, pl.col('idx_hour') <= 1848)
# df = create_half_submission_df(debug_filtered, '5w-6w')
# # df = create_submission_csv(ys_pred['thp_vol'], 'traffic_forecast.csv', 'submission-csvs-archive')
# create_submission_csv(debug_filtered)

ID,Target
str,i64
"""traffic_DLThpVol_test_5w-6w_0_…",841
"""traffic_DLThpVol_test_5w-6w_1_…",842
"""traffic_DLThpVol_test_5w-6w_2_…",843
"""traffic_DLThpVol_test_5w-6w_3_…",844
"""traffic_DLThpVol_test_5w-6w_4_…",845
…,…
"""traffic_DLThpVol_test_10w-11w_…",1844
"""traffic_DLThpVol_test_10w-11w_…",1845
"""traffic_DLThpVol_test_10w-11w_…",1846
"""traffic_DLThpVol_test_10w-11w_…",1847
