In [None]:
from pathlib import Path
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "xgboost_inference.ipynb"  # Manually set the notebook name

import polars as pl
import xgboost as xgb


import wandb
from tqdm.notebook import tqdm

In [None]:
DEBUG: bool = True  # Debug toggle; not referenced in the visible cells — presumably consumed elsewhere (TODO confirm)

In [None]:
# Read the four traffic KPI CSV files (one row per hour; the remaining columns are per-beam values).
data_dir = Path('input-data')
thp_vol = pl.read_csv(data_dir / 'traffic_DLThpVol.csv')  # This is the target variable
prb = pl.read_csv(data_dir / 'traffic_DLPRB.csv')
thp_time = pl.read_csv(data_dir / 'traffic_DLThpTime.csv')
mr_number = pl.read_csv(data_dir / 'traffic_MR_number.csv')

# The first CSV column has an empty header: rename it to 'idx_hour'.
# NOTE: only the copies in `target_dataframes` are renamed; the standalone
# `thp_vol`, `prb`, `thp_time`, `mr_number` frames keep the original '' column.
target_dataframes = {
    name: df.rename({'': "idx_hour"})
    for name, df in {
        'thp_vol': thp_vol,
        'prb': prb,
        'thp_time': thp_time,
        'mr_number': mr_number,
    }.items()
}

## Multi-Step Inference

In [None]:
def predict_one_step(
    target_dataframes: dict[str, pl.DataFrame],
    model: xgb.Booster,
    config: wandb.Config,
    dropped_cols: tuple[str, ...] = ('idx_hour', 'beam_id'),
) -> dict[str, pl.DataFrame]:
    """
    Predict one step into the future using a trained model.

    Takes DataFrames of len n, returns DataFrames of len n + 1. A placeholder row
    for hour `last idx_hour + 1` is appended, features are rebuilt, and the
    model's prediction for that hour replaces the placeholder in every frame.

    Args:
        target_dataframes: wide frames keyed by target name, each with an
            'idx_hour' column. The first frame is used as the template for the
            new row, so all frames are assumed to share columns and hours.
        model: trained model whose `predict` returns one column per target, in
            the order of `target_dataframes` keys.
            NOTE(review): `xgb.Booster.predict` expects a DMatrix. `model` is
            presumably a sklearn-style XGBRegressor — confirm the type hint.
        config: run config forwarded to `create_all_feature_dfs`.
        dropped_cols: identifier columns removed from the long feature table
            before prediction. They must match the columns dropped at training
            time. Previously this read a global defined in a later cell.

    Returns:
        Dict with the same keys, each frame extended by the predicted hour.
    """
    template_df = next(iter(target_dataframes.values()))
    predict_hour = template_df['idx_hour'][-1] + 1

    # One all-null row for the hour to predict; only 'idx_hour' is filled in.
    null_row = pl.DataFrame({col: [predict_hour] if col == 'idx_hour' else [None] for col in template_df.columns})
    extended = {k: pl.concat([v, null_row], how='vertical_relaxed') for k, v in target_dataframes.items()}

    target_names = list(extended.keys())
    feature_dfs = create_all_feature_dfs(extended, config)
    # Only the newly appended hour needs a prediction.
    feature_dfs = {k: v.tail(1) for k, v in feature_dfs.items()}  # maybe turn in to lazyframe for efficiency?
    X_predict = convert_to_long_format(feature_dfs)
    idx_hour, beam_id = X_predict['idx_hour'], X_predict['beam_id']
    X_predict = X_predict.drop(list(dropped_cols))

    # We predict only the idx immediately following the last idx in the input, i.e. a single hour
    y_predicted_long = model.predict(X_predict)
    # orient='row': each prediction row holds one value per target. Without it, polars
    # infers the orientation and falls back to columns when the shape is ambiguous.
    y_predicted_long = pl.DataFrame(y_predicted_long, schema=target_names, orient='row').with_columns([idx_hour, beam_id])
    y_predicted_wide = convert_to_wide_format(y_predicted_long, output_df_names=target_names)

    # Drop the placeholder row and append the predicted row in its place.
    return {
        name: pl.concat([extended[name].head(-1), y_predicted_wide[name]], how='vertical_relaxed')
        for name in target_names
    }

In [None]:
# Show both shapes side by side. As separate statements, only the last expression
# of a cell is displayed, so the prediction shape was silently discarded.
model.predict(X_test).shape, X_test.shape

(38352, 54)

In [None]:
# Inspect the ground-truth target frame (thp_vol) of the test set.
test_dataframes['thp_vol']

idx_hour,0_0_0,0_0_1,0_0_2,0_0_3,0_0_4,0_0_5,0_0_6,0_0_7,0_0_8,0_0_9,0_0_10,0_0_11,0_0_12,0_0_13,0_0_14,0_0_15,0_0_16,0_0_17,0_0_18,0_0_19,0_0_20,0_0_21,0_0_22,0_0_23,0_0_24,0_0_25,0_0_26,0_0_27,0_0_28,0_0_29,0_0_30,0_0_31,0_1_0,0_1_1,0_1_2,0_1_3,…,7_2_26,7_2_27,7_2_28,7_2_29,7_2_30,7_2_31,8_0_0,8_0_1,8_0_2,8_0_3,8_0_4,8_0_5,8_0_6,8_0_7,8_0_8,8_0_9,8_0_10,8_0_11,8_0_12,8_0_13,8_0_14,8_0_15,8_0_16,8_0_17,8_0_18,8_0_19,8_0_20,8_0_21,8_0_22,8_0_23,8_0_24,8_0_25,8_0_26,8_0_27,8_0_28,8_0_29,8_0_30
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
280,1.394905,0.429558,0.094461,0.0,1.226118,0.374851,0.555201,0.066917,2.082938,2.359717,0.183934,0.185242,0.0,0.711513,0.882289,0.447663,0.033392,0.423404,2.290486,0.061054,0.791223,2.932756,5.758433,0.154,0.149582,0.376519,0.0,0.05051,0.023011,0.041854,1.381302,0.051797,0.058078,0.0,0.0,0.314952,…,0.354536,0.0,0.266968,0.78298,0.202991,0.149535,0.0,0.116287,0.079129,0.544606,1.011335,0.508489,0.104174,0.218241,0.040294,0.0,0.0,0.140568,1.183508,0.967467,0.190622,0.0,0.217498,0.123826,0.271984,0.54442,4.646362,6.901463,0.502357,0.138504,0.0,0.0,0.119672,0.222028,0.0,0.279824,0.601916
281,1.361382,0.164671,0.0,0.157446,1.818427,0.890517,0.207697,0.0,0.414001,1.562626,0.513849,0.164002,1.50566,0.268562,0.544356,0.157376,0.344504,0.618929,3.570099,0.375449,0.77161,1.762884,1.803496,0.163807,0.007883,0.0,0.064631,0.02056,0.113126,0.0,0.096561,0.0,0.0,0.262083,0.017426,0.092695,…,0.205197,0.097805,0.106249,0.899108,0.0,0.231301,0.204437,0.0,0.214736,0.296146,0.604106,0.003663,0.635056,0.203999,0.0,0.287216,0.220086,0.0,0.746307,0.327765,0.0,0.370205,0.0,0.0,0.055574,0.591691,6.484226,7.718435,0.858236,0.0,0.0,0.107582,0.021942,0.187172,0.122306,0.955285,0.905826
282,0.690832,0.0,0.0,0.060399,4.306769,0.930586,0.0,0.0,0.804336,0.341009,0.030314,0.0,1.035338,0.574537,0.176991,0.153407,0.0,0.519305,0.757881,0.201867,0.108562,0.0,0.455723,0.811248,0.139837,0.114758,0.249496,0.0,0.297342,0.167366,0.154341,0.0,0.0,0.286578,0.0,0.0,…,0.019013,5.218378,1.22582,0.551091,0.209292,0.002189,0.012969,0.0,0.030771,0.285711,0.456578,0.463025,0.0,0.0,0.0,0.173658,0.0,0.0,0.475697,1.644528,0.426923,0.085647,0.035213,0.181422,0.740102,2.088093,7.475278,5.392865,1.588921,0.213785,0.381893,0.241738,0.059294,0.066321,0.0,1.270753,0.0
283,0.241227,0.0,0.0,0.133241,5.183051,0.795806,0.09242,0.094404,1.618754,0.383314,0.145809,0.435745,1.53758,0.992989,0.243869,0.185276,0.131322,0.555589,1.302331,0.563051,0.238022,0.25776,0.425951,0.0,0.007544,0.0,0.290367,0.0,0.241948,0.331998,0.0,0.0,0.067894,0.137073,0.193574,0.0,…,0.054799,0.336743,0.457812,0.36802,0.505496,0.105446,0.215994,0.0,0.107215,0.0,0.445634,0.216285,0.036876,0.0,0.0,0.0,0.0,0.054128,0.340284,1.491914,0.56239,0.106991,0.0,0.0,0.220424,1.927696,3.698893,5.228314,2.10877,0.162849,0.843928,0.074823,0.0,0.0,0.049861,0.482616,0.0
284,0.116194,0.0,0.0,0.0,5.620047,1.554181,0.0,0.018737,2.086414,0.0,0.20496,0.06057,2.114513,0.799929,0.046056,0.271809,0.087906,1.025342,3.105536,0.143985,0.082518,0.333745,0.237563,0.133263,0.033766,0.0,0.055464,0.236226,0.0,0.0,0.0,0.0,0.153475,0.603833,0.0,0.376693,…,0.029933,0.0,0.105368,0.050666,0.18248,0.0,0.038179,0.013657,0.0,0.136815,0.405815,0.080497,0.288377,0.189874,0.0,0.216828,0.0,0.090629,0.609325,0.794011,0.60714,0.066821,0.18579,0.163424,0.0,2.436909,3.36954,5.008002,1.325702,0.322979,0.038541,0.185916,0.188993,0.005235,0.0,1.104151,0.165653
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
395,0.866853,0.060494,23.840542,0.0,2.71631,0.696229,0.169646,0.0,0.197882,1.606409,0.824158,0.206217,0.811744,0.439725,1.190046,1.169239,0.010567,12.296133,28.877837,1.066428,0.797706,1.408293,10.95912,1.4947,0.041537,0.295054,76.899842,0.0,0.635374,0.308932,0.154891,0.238477,0.057884,0.693554,0.0,0.91502,…,0.36705,0.253666,0.404308,2.746635,0.243393,0.0,0.093298,0.0,0.0,0.470833,0.612456,0.118999,0.207755,0.054382,0.0,0.127565,0.010687,0.0,0.338718,1.223425,1.100356,0.513492,0.0,0.050423,1.013315,4.617613,1.913166,10.275319,3.561647,0.362016,0.0,0.0,0.0,0.308471,0.089815,3.03349,0.310813
396,0.180718,0.0,55.870284,0.0,0.280475,0.265168,1.026338,0.120535,1.432047,0.0,0.90414,0.75382,0.0,0.0,0.70288,0.482404,0.0,6.603545,8.085353,1.632242,0.430794,0.829319,3.620469,0.734741,0.0,0.0,9.873692,0.0,0.680461,0.528039,0.0,0.226355,0.0,0.07923,0.0,1.003442,…,0.01573,0.686935,1.33025,0.820575,0.13542,0.0,0.0,0.0,0.053656,0.167889,0.436016,0.682557,0.291408,0.380079,0.0,0.211704,0.0,0.0,0.630729,0.719716,0.676325,0.05256,0.0,0.042565,0.914567,3.111639,2.827034,7.198893,1.928745,0.0,0.180342,0.0,0.165236,0.078116,0.305266,0.924199,0.980539
397,0.624104,0.101212,0.0,0.0,0.432237,0.285216,0.417385,0.0,0.793596,1.069535,0.177304,0.177004,0.192179,1.031682,0.437941,1.726503,0.10391,1.009184,2.438767,0.15185,0.777208,1.432838,3.528387,0.824026,0.0,0.164852,5.43519,0.0,0.0,0.195098,0.148888,0.0,0.195597,0.052875,0.134794,0.454587,…,0.192635,1.243714,0.15779,0.022784,0.0,0.0,0.154473,0.136831,0.0,0.367299,0.288833,0.047654,0.0,0.044122,0.206223,0.01134,0.023025,0.0,0.994231,1.405295,0.399594,0.050588,0.128999,0.0,1.249393,2.512841,2.848478,7.933981,2.136853,0.204602,0.236151,0.076319,0.0,0.577113,0.462353,0.945524,0.249167
398,0.279431,0.0,0.234371,0.0,0.099509,0.960125,0.0,0.014794,0.462177,1.654164,0.213959,0.180923,0.346395,4.818276,0.944466,0.819463,0.014055,1.459292,5.024005,0.1152,0.632807,0.76857,0.984666,0.970977,0.0,0.058807,0.38958,0.09802,0.0,0.280848,0.209541,0.0,0.0,0.0,0.0,0.414549,…,0.0,0.261838,0.17454,0.073404,0.0,0.0,0.0,0.077856,0.0,0.125863,0.059378,0.864292,0.153123,0.089032,0.204714,0.0,0.041575,0.199477,0.128244,1.020702,0.499593,0.100492,0.20931,0.127876,0.111976,1.873898,0.835592,6.956599,0.600975,0.135257,0.0,0.045188,0.049552,0.062312,0.016458,0.425627,0.851423


In [None]:
# Sanity check: one recursive step on the test frames; each returned frame gains one hour.
y = predict_one_step(test_dataframes, model, wandb.config)


The argument `columns` for `DataFrame.pivot` is deprecated. It has been renamed to `on`.



In [None]:
# The last row should be the newly predicted hour (last test idx_hour + 1).
y['thp_vol']

idx_hour,0_0_0,0_0_1,0_0_2,0_0_3,0_0_4,0_0_5,0_0_6,0_0_7,0_0_8,0_0_9,0_0_10,0_0_11,0_0_12,0_0_13,0_0_14,0_0_15,0_0_16,0_0_17,0_0_18,0_0_19,0_0_20,0_0_21,0_0_22,0_0_23,0_0_24,0_0_25,0_0_26,0_0_27,0_0_28,0_0_29,0_0_30,0_0_31,0_1_0,0_1_1,0_1_2,0_1_3,…,7_2_26,7_2_27,7_2_28,7_2_29,7_2_30,7_2_31,8_0_0,8_0_1,8_0_2,8_0_3,8_0_4,8_0_5,8_0_6,8_0_7,8_0_8,8_0_9,8_0_10,8_0_11,8_0_12,8_0_13,8_0_14,8_0_15,8_0_16,8_0_17,8_0_18,8_0_19,8_0_20,8_0_21,8_0_22,8_0_23,8_0_24,8_0_25,8_0_26,8_0_27,8_0_28,8_0_29,8_0_30
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
280,1.394905,0.429558,0.094461,0.0,1.226118,0.374851,0.555201,0.066917,2.082938,2.359717,0.183934,0.185242,0.0,0.711513,0.882289,0.447663,0.033392,0.423404,2.290486,0.061054,0.791223,2.932756,5.758433,0.154,0.149582,0.376519,0.0,0.05051,0.023011,0.041854,1.381302,0.051797,0.058078,0.0,0.0,0.314952,…,0.354536,0.0,0.266968,0.78298,0.202991,0.149535,0.0,0.116287,0.079129,0.544606,1.011335,0.508489,0.104174,0.218241,0.040294,0.0,0.0,0.140568,1.183508,0.967467,0.190622,0.0,0.217498,0.123826,0.271984,0.54442,4.646362,6.901463,0.502357,0.138504,0.0,0.0,0.119672,0.222028,0.0,0.279824,0.601916
281,1.361382,0.164671,0.0,0.157446,1.818427,0.890517,0.207697,0.0,0.414001,1.562626,0.513849,0.164002,1.50566,0.268562,0.544356,0.157376,0.344504,0.618929,3.570099,0.375449,0.77161,1.762884,1.803496,0.163807,0.007883,0.0,0.064631,0.02056,0.113126,0.0,0.096561,0.0,0.0,0.262083,0.017426,0.092695,…,0.205197,0.097805,0.106249,0.899108,0.0,0.231301,0.204437,0.0,0.214736,0.296146,0.604106,0.003663,0.635056,0.203999,0.0,0.287216,0.220086,0.0,0.746307,0.327765,0.0,0.370205,0.0,0.0,0.055574,0.591691,6.484226,7.718435,0.858236,0.0,0.0,0.107582,0.021942,0.187172,0.122306,0.955285,0.905826
282,0.690832,0.0,0.0,0.060399,4.306769,0.930586,0.0,0.0,0.804336,0.341009,0.030314,0.0,1.035338,0.574537,0.176991,0.153407,0.0,0.519305,0.757881,0.201867,0.108562,0.0,0.455723,0.811248,0.139837,0.114758,0.249496,0.0,0.297342,0.167366,0.154341,0.0,0.0,0.286578,0.0,0.0,…,0.019013,5.218378,1.22582,0.551091,0.209292,0.002189,0.012969,0.0,0.030771,0.285711,0.456578,0.463025,0.0,0.0,0.0,0.173658,0.0,0.0,0.475697,1.644528,0.426923,0.085647,0.035213,0.181422,0.740102,2.088093,7.475278,5.392865,1.588921,0.213785,0.381893,0.241738,0.059294,0.066321,0.0,1.270753,0.0
283,0.241227,0.0,0.0,0.133241,5.183051,0.795806,0.09242,0.094404,1.618754,0.383314,0.145809,0.435745,1.53758,0.992989,0.243869,0.185276,0.131322,0.555589,1.302331,0.563051,0.238022,0.25776,0.425951,0.0,0.007544,0.0,0.290367,0.0,0.241948,0.331998,0.0,0.0,0.067894,0.137073,0.193574,0.0,…,0.054799,0.336743,0.457812,0.36802,0.505496,0.105446,0.215994,0.0,0.107215,0.0,0.445634,0.216285,0.036876,0.0,0.0,0.0,0.0,0.054128,0.340284,1.491914,0.56239,0.106991,0.0,0.0,0.220424,1.927696,3.698893,5.228314,2.10877,0.162849,0.843928,0.074823,0.0,0.0,0.049861,0.482616,0.0
284,0.116194,0.0,0.0,0.0,5.620047,1.554181,0.0,0.018737,2.086414,0.0,0.20496,0.06057,2.114513,0.799929,0.046056,0.271809,0.087906,1.025342,3.105536,0.143985,0.082518,0.333745,0.237563,0.133263,0.033766,0.0,0.055464,0.236226,0.0,0.0,0.0,0.0,0.153475,0.603833,0.0,0.376693,…,0.029933,0.0,0.105368,0.050666,0.18248,0.0,0.038179,0.013657,0.0,0.136815,0.405815,0.080497,0.288377,0.189874,0.0,0.216828,0.0,0.090629,0.609325,0.794011,0.60714,0.066821,0.18579,0.163424,0.0,2.436909,3.36954,5.008002,1.325702,0.322979,0.038541,0.185916,0.188993,0.005235,0.0,1.104151,0.165653
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
396,0.180718,0.0,55.870284,0.0,0.280475,0.265168,1.026338,0.120535,1.432047,0.0,0.90414,0.75382,0.0,0.0,0.70288,0.482404,0.0,6.603545,8.085353,1.632242,0.430794,0.829319,3.620469,0.734741,0.0,0.0,9.873692,0.0,0.680461,0.528039,0.0,0.226355,0.0,0.07923,0.0,1.003442,…,0.01573,0.686935,1.33025,0.820575,0.13542,0.0,0.0,0.0,0.053656,0.167889,0.436016,0.682557,0.291408,0.380079,0.0,0.211704,0.0,0.0,0.630729,0.719716,0.676325,0.05256,0.0,0.042565,0.914567,3.111639,2.827034,7.198893,1.928745,0.0,0.180342,0.0,0.165236,0.078116,0.305266,0.924199,0.980539
397,0.624104,0.101212,0.0,0.0,0.432237,0.285216,0.417385,0.0,0.793596,1.069535,0.177304,0.177004,0.192179,1.031682,0.437941,1.726503,0.10391,1.009184,2.438767,0.15185,0.777208,1.432838,3.528387,0.824026,0.0,0.164852,5.43519,0.0,0.0,0.195098,0.148888,0.0,0.195597,0.052875,0.134794,0.454587,…,0.192635,1.243714,0.15779,0.022784,0.0,0.0,0.154473,0.136831,0.0,0.367299,0.288833,0.047654,0.0,0.044122,0.206223,0.01134,0.023025,0.0,0.994231,1.405295,0.399594,0.050588,0.128999,0.0,1.249393,2.512841,2.848478,7.933981,2.136853,0.204602,0.236151,0.076319,0.0,0.577113,0.462353,0.945524,0.249167
398,0.279431,0.0,0.234371,0.0,0.099509,0.960125,0.0,0.014794,0.462177,1.654164,0.213959,0.180923,0.346395,4.818276,0.944466,0.819463,0.014055,1.459292,5.024005,0.1152,0.632807,0.76857,0.984666,0.970977,0.0,0.058807,0.38958,0.09802,0.0,0.280848,0.209541,0.0,0.0,0.0,0.0,0.414549,…,0.0,0.261838,0.17454,0.073404,0.0,0.0,0.0,0.077856,0.0,0.125863,0.059378,0.864292,0.153123,0.089032,0.204714,0.0,0.041575,0.199477,0.128244,1.020702,0.499593,0.100492,0.20931,0.127876,0.111976,1.873898,0.835592,6.956599,0.600975,0.135257,0.0,0.045188,0.049552,0.062312,0.016458,0.425627,0.851423
399,0.423216,0.096952,0.0,0.080912,0.326296,0.0,0.12406,0.096386,1.117871,1.205607,0.212986,0.0,0.585459,1.933617,1.269501,0.559622,0.107714,1.042602,5.44947,0.314388,0.0,0.616449,0.796853,1.108213,0.017554,0.161621,0.0,0.0,0.087075,0.178272,0.0,0.0,0.055621,0.138237,0.452213,0.258883,…,0.448979,0.376689,0.129059,0.355229,0.045102,0.0,0.0,0.0,0.050582,0.042599,0.394194,3.199615,1.231384,0.175372,0.224746,0.0,0.0,0.0,0.272897,0.822952,0.574984,0.088533,0.1806,0.0,0.628068,2.819004,1.012273,6.795228,1.552637,0.020124,0.0,0.0,0.061305,0.116182,0.705867,1.489329,0.681226


In [None]:
# Long-format tables (presumably one row per (idx_hour, beam_id)) for train and test.
# NOTE(review): the train table is built from `target_dataframes` (all loaded data), while the
# multi-step forecast is seeded from `train_dataframes`. If `test_dataframes` is a slice of
# `target_dataframes`, test hours leak into training — confirm which was intended.
long_train_df = create_long_train_df(target_dataframes, wandb.config)
long_test_df = create_long_train_df(test_dataframes, wandb.config)

# Targets are what the model predicts; identifier columns are not model inputs.
target_cols = list(target_dataframes)
dropped_cols = ['idx_hour', 'beam_id']

X_train = long_train_df.drop(dropped_cols + target_cols)
y_train = long_train_df.select(target_cols)
X_test = long_test_df.drop(dropped_cols + target_cols)
y_test = long_test_df.select(target_cols)

wandb.log({
    'train shape': X_train.shape,
    'test shape': X_test.shape,
    'train_feats': X_train.columns,
})

In [None]:
def predict_multi_step(target_dataframes: dict[str, pl.DataFrame], model: xgb.Booster, config: wandb.Config, num_steps: int, max_lag=None, lag_buffer: int = 5) -> dict[str, pl.DataFrame]:
    """
    Predict multiple steps into the future using a trained model.

    Each step feeds the previous step's predictions back in as history
    (recursive forecasting via `predict_one_step`). Takes DataFrames of len n and
    returns DataFrames of len n + num_steps. When `max_lag` truncates the history,
    the result has len min(n, max_lag + lag_buffer) + num_steps.

    Args:
        target_dataframes: wide frames keyed by target name, each with 'idx_hour'.
        model: trained model passed through to `predict_one_step`.
        config: run config passed through to `predict_one_step`.
        num_steps: number of future hours to predict.
        max_lag: if not None, keep only the last `max_lag + lag_buffer` rows of each
            frame before forecasting. This is presumably just enough history for the
            lag features — TODO confirm against `create_all_feature_dfs`.
        lag_buffer: extra rows kept beyond `max_lag` (previously a hard-coded 5).

    Returns:
        Dict with the same keys; each frame ends with the `num_steps` predicted hours.
    """
    history = target_dataframes
    # `is not None` so that max_lag=0 still truncates instead of meaning "keep everything".
    if max_lag is not None:
        history = {k: v.tail(max_lag + lag_buffer) for k, v in history.items()}

    for _ in tqdm(range(num_steps)):
        history = predict_one_step(history, model, config)

    return history

In [None]:
# Forecast the whole test horizon recursively, seeded with the last CONTEXT_HOURS training hours.
# predict_multi_step returns a dict of DataFrames (there is no `.mean_horizontal()` on a dict);
# score it with `mean_absolute_error` instead.
CONTEXT_HOURS = 80
y = predict_multi_step(
    {k: v.tail(CONTEXT_HOURS) for k, v in train_dataframes.items()},
    model,
    wandb.config,
    num_steps=len(test_dataframes['thp_vol']),
)

  0%|          | 0/120 [00:00<?, ?it/s]


The argument `columns` for `DataFrame.pivot` is deprecated. It has been renamed to `on`.



AttributeError: 'dict' object has no attribute 'mean_horizontal'

In [None]:
def mean_absolute_error(Y_true: pl.DataFrame, Y_pred: pl.DataFrame) -> float:
    """
    Compute the mean absolute error between two DataFrames.

    Both frames need an 'idx_hour' column and the same value columns, in any
    order. Rows are compared positionally and must carry identical 'idx_hour'
    values. Per-column MAEs are averaged with equal weight, and nulls are
    skipped by the per-column mean.

    Raises:
        AssertionError: if the frames differ in height, columns, or 'idx_hour' values.
    """
    # Check height first: comparing Series of different lengths raises an opaque polars error.
    assert Y_true.height == Y_pred.height, f"DataFrames must have the same height, got {Y_true.height} and {Y_pred.height}"
    assert set(Y_true.columns) == set(Y_pred.columns), "DataFrames must have the same columns"
    assert (Y_true['idx_hour'] == Y_pred['idx_hour']).all(), "DataFrames must be aligned"

    value_cols = [col for col in Y_true.columns if col != 'idx_hour']
    # DataFrame - DataFrame pairs columns by position, so give both frames the same column order.
    abs_errors = (Y_true.select(value_cols) - Y_pred.select(value_cols)).select(pl.all().abs().mean())
    return abs_errors.mean_horizontal()[0]

In [None]:
# Score the recursive forecast: compare its last len(test) hours with the thp_vol ground truth.
mean_absolute_error(test_dataframes['thp_vol'], y['thp_vol'].tail(len(test_dataframes['thp_vol'])))

0.24911949991279875

## Create Submission CSV

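* Hours in 5 weeks: 840
* Hours in 6 weeks: 1008
* We need `idx_hour` 841-1008 inclusive (`range(841, 1009)` in Python), i.e. 168 hours

* Hours in 10 weeks: 1680
* Hours in 11 weeks: 1848
* We need `idx_hour` 1681-1848 inclusive (`range(1681, 1849)` in Python), i.e. 168 hours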

In [None]:
def create_half_submission_df(input_df: pl.DataFrame, weeks: str) -> pl.DataFrame:
    """
    Build one half of the submission (long format) for a 168-hour window.

    Args:
        input_df: wide DataFrame with an 'idx_hour' column and one column per
            beam id, already extended to cover the requested window.
        weeks: '5w-6w' (idx_hour 841-1008) or '10w-11w' (idx_hour 1681-1848).

    Returns:
        DataFrame with columns 'ID' and 'Target', one row per (beam, hour).
        ID is f'traffic_DLThpVol_test_{weeks}_{hour}_{beam_id}', where hour is the
        0-based position (0-167) within the window. Rows cycle through the hours
        of one beam before moving to the next beam.

    Raises:
        ValueError: if `weeks` is not one of the supported windows.
        AssertionError: if the window is not 168 x 2881 or contains nulls.
    """
    # Inclusive idx_hour bounds of each window.
    hour_windows = {'5w-6w': (841, 1008), '10w-11w': (1681, 1848)}
    if weeks not in hour_windows:
        raise ValueError(f"Unknown weeks {weeks!r}; expected one of {list(hour_windows)}")
    first_hour, last_hour = hour_windows[weeks]

    # Keep every hour of the window (both ends inclusive), in chronological order.
    # (The previous `is_in([first, last])` selected only the two endpoint hours.)
    window_df = input_df.filter(pl.col('idx_hour').is_between(first_hour, last_hour)).sort('idx_hour')

    # Check that shape of dataframe is (168, 2881): 168 hours x ('idx_hour' + beam columns)
    assert window_df.shape == (168, 2881), f"Expected shape (168, 2881), got {window_df.shape}"

    # Check that there is no null value in the dataframe
    assert window_df.null_count().sum_horizontal()[0] == 0, "Submission dataframe contains null values"

    # 0-based hour position within the window, used in the ID. It is added after the
    # shape check and kept as an unpivot index column so it is not treated as a beam.
    window_df = window_df.with_row_index('row_index')

    # Stack the beam columns into rows and build f'traffic_DLThpVol_test_{weeks}_{hour}_{beam_id}' IDs.
    return (
        window_df.unpivot(index=['idx_hour', 'row_index'], variable_name='variable', value_name='value')
        .with_columns(
            pl.format(f'traffic_DLThpVol_test_{weeks}_{{}}_{{}}', pl.col('row_index'), pl.col('variable')).alias('ID')
        )
        .select(['ID', 'value'])
        .rename({'value': 'Target'})
    )


def create_submission_csv(input_df: pl.DataFrame, output_filename='traffic_forecast.csv', archiving_dir='submission-csvs-archive') -> pl.DataFrame:
    """
    Create a submission CSV file from data in input format that's been extended to cover weeks 5-6 and 10-11.

    Args:
        input_df: wide DataFrame ('idx_hour' + one column per beam) covering
            idx_hour 841-1008 and 1681-1848.
        output_filename: path of the CSV written for submission.
        archiving_dir: if truthy, a copy is also written into this directory. The
            copy's name is prefixed with the active wandb run name, or
            'no-wandb-run' when no run is active.

    Returns:
        The concatenated submission DataFrame ('ID', 'Target'), 5w-6w rows first.
    """
    # Build both halves and stack them: 5w-6w first, then 10w-11w.
    submission_df = pl.concat(
        [create_half_submission_df(input_df, weeks) for weeks in ('5w-6w', '10w-11w')],
        how='vertical',
    )

    # Save the submission dataframe to a CSV file for submission
    submission_df.write_csv(output_filename)

    # Save the submission dataframe to a CSV file for archiving
    if archiving_dir:
        archive_path = Path(archiving_dir)
        archive_path.mkdir(parents=True, exist_ok=True)
        # wandb.run is None without an active run; don't crash after the main CSV is already written.
        run_name = wandb.run.name if wandb.run is not None else 'no-wandb-run'
        # Use only the file name so an output_filename containing directories still lands in archive_path.
        submission_df.write_csv(archive_path / f'{run_name}_{Path(output_filename).name}')

    return submission_df