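 # XGBoost Inference

Load the per-horizon `forward_shift_*` XGBoost models, build the feature row for the first forecast hour (840), and predict `thp_vol` for the following 168 hours.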

In [1]:
# %%
from pathlib import Path
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "xgboost_inference.ipynb"  # Manually set the notebook name

import pandas as pd
import polars as pl
import xgboost as xgb
import wandb
from tqdm.notebook import tqdm
import pickle
import numpy as np

import utils
import yaml


In [2]:
# %%
# When True, weeks 10-11 are padded with zero placeholder predictions before submission.
DEBUG = True

# Feature/inference settings. Stored as `train_config` because that is the object
# passed to utils.create_all_feature_dfs in the data-preparation cell.
config_path = Path('configs') / 'direct_inference_config_11_10_24.yaml'
with config_path.open('r') as config_file:
    train_config = yaml.safe_load(config_file)

class dotdict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``cfg.key`` returns ``cfg.get('key')``, i.e. ``None`` for a missing key, and
    ``cfg.key = v`` / ``del cfg.key`` write through to the underlying dict.

    Dunder names (``__foo__``) are never treated as keys. Looking up a missing one
    raises ``AttributeError``, as for a normal object. Otherwise protocol probes such as
    ``hasattr(cfg, '__setstate__')`` would see ``None``, conclude the hook exists,
    and then try to call it, which breaks ``copy``/``pickle`` of the config.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# Attribute-style access to the YAML settings (missing keys read as None).
train_config = dotdict(train_config)

# Notebook-level inference settings.
inference_config = dict(
    prediction_length=168,        # number of hourly horizons; one model per shift
    create_submission_csv=False,  # write the submission CSV in the final cell
)

# Directory holding the trained per-shift model checkpoints.
checkpoints_dir = 'checkpoints_final'
xgboost_models_dir = Path(checkpoints_dir)


In [3]:

# Load one pickled model per forecast horizon. Files are named `forward_shift_<k>`;
# the integer after the last underscore is the horizon (shift) k.
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load checkpoints you trust.
models = {}
model_paths = sorted(
    path for path in xgboost_models_dir.iterdir() if path.name.startswith('forward_shift_')
)
for model_path in tqdm(model_paths):
    # `.stem` drops an optional extension (e.g. `.pkl`) before the shift is parsed.
    shift = int(model_path.stem.split('_')[-1])
    with model_path.open('rb') as f:
        models[shift] = pickle.load(f)

# Fail fast here rather than with a KeyError halfway through the prediction loop.
expected_shifts = set(range(inference_config['prediction_length']))
missing_shifts = sorted(expected_shifts - models.keys())
assert not missing_shifts, f"Missing models for shifts: {missing_shifts}"

  0%|          | 0/168 [00:00<?, ?it/s]

 ## Load and Prepare Data

In [4]:
# %%
# Read the raw traffic CSVs: one frame per metric, one column per beam id.
data_dir = Path('input-data')
csv_names = {
    'thp_vol': 'traffic_DLThpVol.csv',  # This is the target variable
    'prb': 'traffic_DLPRB.csv',
    'thp_time': 'traffic_DLThpTime.csv',
    'mr_number': 'traffic_MR_number.csv',
}
target_dataframes = {metric: pl.read_csv(data_dir / csv_name) for metric, csv_name in csv_names.items()}

# The unnamed '' column is kept separately as the hour index and dropped from every frame.
idx_hour_series = target_dataframes['thp_vol']['']
target_dataframes = {metric: frame.drop('') for metric, frame in target_dataframes.items()}

# Reference column layout (beam ids) for the rest of the notebook.
template_df = target_dataframes['thp_vol']

predict_hour = 840  # forecast hour of shift 0; predictions cover predict_hour + shift

# Append one all-null row to every frame. After feature construction, the last row
# (taken with tail(1) below) presumably corresponds to this appended hour.
# NOTE(review): idx_hour_series is captured before this append, so it is one element
# shorter than the frames passed alongside it -- confirm utils.create_all_feature_dfs expects that.
null_row = pl.DataFrame({beam_id: [None] for beam_id in template_df.columns})
target_dataframes = {
    metric: pl.concat([frame, null_row], how='vertical_relaxed')
    for metric, frame in target_dataframes.items()
}

target_names = list(target_dataframes)
feature_dfs = utils.create_all_feature_dfs(target_dataframes, idx_hour_series, train_config)
# Only the final feature row of each frame is used for prediction.
# TODO: consider a LazyFrame so the unused earlier rows are never materialised.
feature_dfs = {metric: frame.tail(1) for metric, frame in feature_dfs.items()}
X_predict = utils.convert_to_long_format(feature_dfs)
cat_types = utils.make_id_cat_type(template_df.columns)

# Switch to pandas and give each id column that is present the categorical dtype from make_id_cat_type.
X_predict = X_predict.to_pandas()
id_columns = [name for name in ('beam_id', 'cell_id', 'station_id') if name in X_predict.columns]
for col in id_columns:
    X_predict[col] = X_predict[col].astype(cat_types[col])

In [11]:
# %%

# Predict every horizon with its own model. All models see the same feature rows
# (X_predict); model `shift` forecasts hour `predict_hour + shift`.
n_horizons = inference_config['prediction_length']
ys_predicted_wide = []

for shift in tqdm(range(n_horizons)):
    y_predicted = models[shift].predict(X_predict)

    # A plain list, not a nested DataFrame, so 'idx_hour' is an ordinary integer column.
    # Its length follows X_predict, which is exactly the number of predictions.
    idx_hour = [predict_hour + shift] * len(X_predict)

    y_predicted_long_df = pl.DataFrame({
        'idx_hour': idx_hour,
        'beam_id': X_predict['beam_id'],
        'thp_vol': y_predicted,
    })

    # Back to wide format (one column per beam); the helper's result is indexed by target name.
    y_predicted_wide = utils.convert_to_wide_format(y_predicted_long_df, ['thp_vol'])['thp_vol']

    ys_predicted_wide.append(y_predicted_wide)


  0%|          | 0/168 [00:00<?, ?it/s]

In [14]:
# %%
# Stack the per-horizon frames (row i comes from shift i) and stamp each row with its forecast hour.
predictions_wide = pl.concat(ys_predicted_wide, how='vertical')
forecast_hours = range(predict_hour, predict_hour + inference_config['prediction_length'])
assert predictions_wide.height == len(forecast_hours), (
    f"Expected {len(forecast_hours)} prediction rows, got {predictions_wide.height}"
)
predictions_wide = predictions_wide.with_columns(idx_hour=pl.Series(forecast_hours))


In [23]:
# Move idx_hour to the front, keep the beam columns in their existing order, and display.
other_columns = [name for name in predictions_wide.columns if name != 'idx_hour']
predictions_wide = predictions_wide.select(['idx_hour', *other_columns])
predictions_wide

idx_hour,0_0_0,0_0_1,0_0_2,0_0_3,0_0_4,0_0_5,0_0_6,0_0_7,0_0_8,0_0_9,0_0_10,0_0_11,0_0_12,0_0_13,0_0_14,0_0_15,0_0_16,0_0_17,0_0_18,0_0_19,0_0_20,0_0_21,0_0_22,0_0_23,0_0_24,0_0_25,0_0_26,0_0_27,0_0_28,0_0_29,0_0_30,0_0_31,0_1_0,0_1_1,0_1_2,0_1_3,…,29_1_27,29_1_28,29_1_29,29_1_30,29_1_31,29_2_0,29_2_1,29_2_2,29_2_3,29_2_4,29_2_5,29_2_6,29_2_7,29_2_8,29_2_9,29_2_10,29_2_11,29_2_12,29_2_13,29_2_14,29_2_15,29_2_16,29_2_17,29_2_18,29_2_19,29_2_20,29_2_21,29_2_22,29_2_23,29_2_24,29_2_25,29_2_26,29_2_27,29_2_28,29_2_29,29_2_30,29_2_31
i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
840,0.466081,0.105483,0.149764,0.07013,1.138302,0.539551,0.09427,0.07013,0.635353,0.617956,0.125508,0.151468,0.660757,0.16226,0.14043,0.116173,0.101051,0.303548,1.393848,0.714513,0.130742,0.114484,0.130271,0.341998,0.087457,0.10228,1.384943,0.093579,0.105328,0.167898,0.094431,0.077898,0.070685,0.159843,0.134162,0.495922,…,0.366006,0.563202,0.877411,0.248561,0.068914,0.094377,0.103158,0.098499,0.080331,0.070698,0.098352,0.179785,0.100533,0.081107,0.213927,0.102621,0.096659,0.096707,0.323838,0.10493,0.100565,0.071091,0.104655,0.098098,0.283177,0.17434,0.087915,0.126967,0.069353,0.093761,0.082774,0.082669,0.084514,0.069541,0.089541,0.102618,0.077246
841,0.349067,0.114122,0.365288,0.070777,0.75924,0.514641,0.088071,0.070777,0.541569,0.395453,0.134796,0.140268,0.417919,0.171833,0.132208,0.135961,0.101817,0.186731,1.192911,0.485041,0.144062,0.140639,0.135782,0.325621,0.090982,0.096628,0.169835,0.086482,0.095266,0.164192,0.086003,0.07569,0.072087,0.174577,0.128622,0.460806,…,0.264044,0.658557,0.544527,0.184946,0.069083,0.091339,0.111099,0.096764,0.078599,0.072625,0.095315,0.199519,0.097169,0.077253,0.190637,0.107674,0.098724,0.097073,0.214651,0.103337,0.094676,0.071227,0.092499,0.096107,0.202352,0.100795,0.082301,0.119546,0.071383,0.086623,0.078347,0.079846,0.081462,0.069849,0.087333,0.095188,0.074722
842,0.321606,0.142163,0.956922,0.070082,0.541167,0.503571,0.084412,0.070082,0.382485,0.0518,0.13033,0.140306,0.290956,0.224402,0.147279,0.168332,0.110168,0.220572,1.123015,0.402481,0.194486,0.201382,0.184022,0.278421,0.080271,0.098321,0.147988,0.077482,0.097235,0.166244,0.076884,0.070996,0.071204,0.172753,0.120458,0.270621,…,0.253255,0.461135,0.391571,0.152435,0.06837,0.090716,0.134313,0.102475,0.07446,0.073411,0.095694,0.272838,0.105211,0.07735,0.229896,0.178292,0.1085,0.124598,0.189833,0.125484,0.097922,0.071203,0.09154,0.101167,0.1674,0.08274,0.079931,0.098272,0.071901,0.084584,0.07976,0.077499,0.077674,0.068924,0.085145,0.095514,0.085771
843,0.256772,0.15254,0.144487,0.071674,0.652168,0.377,0.084175,0.071674,0.284285,0.178115,0.147477,0.126126,0.261665,0.194638,0.14049,0.144841,0.105706,0.214634,0.934179,0.228563,0.171008,0.177292,0.172015,0.25701,0.081144,0.097251,0.10717,0.088367,0.093451,0.150312,0.079478,0.072043,0.071444,0.164914,0.108961,0.208146,…,0.24916,0.440034,0.309431,0.12958,0.069836,0.089148,0.129243,0.100106,0.075436,0.074549,0.100461,0.228939,0.104852,0.078283,0.204459,0.1571,0.111729,0.121196,0.159668,0.121469,0.100191,0.071333,0.089208,0.102576,0.156645,0.077818,0.07884,0.089243,0.072304,0.086647,0.08159,0.080385,0.077865,0.069559,0.085092,0.09751,0.07592
844,0.201087,0.151786,0.16632,0.072165,0.517022,0.354426,0.084704,0.07248,0.195417,0.151439,0.145536,0.128983,0.180761,0.189154,0.16001,0.166237,0.116049,0.181516,0.702928,0.208299,0.152736,0.176887,0.162114,0.223751,0.08242,0.097157,0.096651,0.081549,0.094134,0.134083,0.081877,0.07248,0.072126,0.136798,0.100516,0.228552,…,0.208336,0.405753,0.245721,0.115604,0.071689,0.086226,0.129328,0.097724,0.07636,0.075975,0.096125,0.215622,0.105862,0.080775,0.18451,0.13566,0.108815,0.116357,0.120611,0.113921,0.09997,0.072282,0.094287,0.10619,0.140795,0.079859,0.08153,0.085939,0.073337,0.088561,0.082715,0.083364,0.080187,0.07105,0.084931,0.096068,0.078594
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1003,0.785802,0.19292,0.919532,0.080161,3.369255,2.517078,0.097956,0.080161,0.777475,0.934418,0.287254,0.366046,1.002874,0.475969,0.357358,0.296671,0.180528,1.0067,1.327262,0.712475,0.565407,0.56673,0.846098,0.446173,0.109263,0.139686,0.168376,0.093968,0.141049,0.246284,0.108248,0.080161,0.077338,0.438352,0.163978,0.204689,…,0.556077,0.981051,0.65788,0.191846,0.078082,0.119162,0.350075,0.255251,0.07954,0.078536,0.143547,0.617534,0.187067,0.083148,1.059229,0.593434,0.187851,0.264418,0.307057,0.307736,0.19721,0.078536,0.098454,0.200153,0.454371,0.107821,0.097007,0.122682,0.078536,0.099069,0.113965,0.110439,0.081632,0.078082,0.102592,0.157153,0.083148
1004,0.905722,0.19546,0.287623,0.078535,4.182957,3.069052,0.100289,0.078535,0.609168,0.895572,0.29844,0.366478,1.153975,0.485891,0.27166,0.244047,0.17527,0.875149,1.451405,0.780759,0.459926,0.486236,0.521021,0.421355,0.108757,0.143696,0.253706,0.093125,0.136748,0.238314,0.108382,0.078535,0.076399,0.612883,0.154773,0.216868,…,0.523206,0.969965,0.667413,0.194858,0.076593,0.115763,0.304643,0.253785,0.080951,0.077748,0.147053,0.480636,0.180538,0.081536,0.667378,0.395275,0.186872,0.219356,0.3722,0.26605,0.178965,0.077939,0.100132,0.189851,0.383976,0.123886,0.088115,0.129192,0.07749,0.100139,0.10887,0.105514,0.080583,0.07749,0.104698,0.172554,0.086122
1005,0.817579,0.169659,0.364458,0.080052,3.789816,2.408399,0.105673,0.080052,0.480036,0.479029,0.366126,0.340522,1.256107,0.35766,0.196935,0.188492,0.156621,0.669894,2.004391,0.820667,0.32756,0.338983,0.382474,0.505162,0.102963,0.134139,0.242892,0.096014,0.141536,0.223325,0.090878,0.080052,0.07771,0.565465,0.142905,0.228065,…,0.421394,0.938256,0.63229,0.210789,0.07996,0.119394,0.193897,0.179114,0.084058,0.07892,0.140954,0.266571,0.156548,0.085071,0.440643,0.212289,0.142347,0.172707,0.453927,0.220654,0.137879,0.079373,0.104431,0.139706,0.328955,0.15288,0.093049,0.147525,0.07892,0.1038,0.101524,0.097681,0.084479,0.07892,0.109229,0.150083,0.088547
1006,0.685358,0.158854,0.357996,0.08148,2.939298,1.409624,0.107322,0.08148,0.507386,0.47262,0.28434,0.301383,0.843075,0.313909,0.177207,0.230126,0.134573,0.549925,1.6049,0.764628,0.256479,0.247576,0.246518,0.397863,0.097203,0.128326,0.19369,0.102027,0.138786,0.181333,0.092976,0.08148,0.077999,0.460672,0.160305,0.258814,…,0.385403,0.851494,0.731206,0.243895,0.078996,0.110277,0.173739,0.150539,0.086866,0.08301,0.138789,0.250428,0.140676,0.083958,0.358872,0.168565,0.138082,0.15802,0.364444,0.192167,0.132767,0.079766,0.098759,0.13722,0.270502,0.163475,0.087615,0.161042,0.079766,0.110148,0.096945,0.095598,0.086474,0.079766,0.113014,0.169646,0.092821


In [16]:
def create_half_submission_df(input_df: pl.DataFrame, weeks: str) -> pl.DataFrame:
    """
    Convert one forecast window of a wide thp_vol frame into long submission rows.

    Args:
        input_df: Wide frame with an 'idx_hour' column plus one column per beam id.
            Rows outside the requested window are ignored.
        weeks: Which window to extract: '5w-6w' (hours 840-1007) or
            '10w-11w' (hours 1680-1847).

    Returns:
        Frame with columns 'ID' and 'Target'. ID is
        'traffic_DLThpVol_test_<weeks>_<hour offset>_<beam_id>', where the hour
        offset counts from 0 at the first hour of the window.

    Raises:
        ValueError: if `weeks` is not one of the known windows.
        AssertionError: if the window does not have the expected shape, contains
            nulls, or does not span the whole window.
    """
    # Inclusive idx_hour bounds of each submission window.
    week_hour_ranges = {
        '5w-6w': (840, 1007),
        '10w-11w': (1680, 1847),
    }
    if weeks not in week_hour_ranges:
        raise ValueError(f"Unknown weeks {weeks!r}; expected one of {sorted(week_hour_ranges)}")
    first_hour, last_hour = week_hour_ranges[weeks]

    # Keep only the rows whose idx_hour lies inside the window (both ends inclusive).
    input_df = input_df.filter(pl.col('idx_hour') >= first_hour, pl.col('idx_hour') <= last_hour)

    # Sanity checks: 168 hourly rows, 'idx_hour' + 2880 beam columns, no nulls, full window covered.
    assert input_df.shape == (168, 2881), f"Expected shape (168, 2881), got {input_df.shape}"
    assert not input_df.select(pl.any_horizontal(pl.all().is_null().any())).item(), "Submission dataframe contains null values"
    assert input_df['idx_hour'].head(1)[0] <= first_hour and input_df['idx_hour'].tail(1)[0] >= last_hour, "Submission dataframe does seemingly not contain the correct idx_hour values"

    # Unpivot to one row per (beam, hour): the beam columns are stacked one after another,
    # each contributing every hour of the window. The ID joins the fixed prefix, the window
    # name, the hour offset within the window and the beam id with '_'.
    return input_df.unpivot(index='idx_hour', variable_name='beam_id').with_columns(
        pl.concat_str([pl.lit('traffic_DLThpVol_test'), pl.lit(weeks), pl.col('idx_hour') - first_hour, pl.col('beam_id')], separator='_').alias('ID')
    ).select(['ID', 'value']).rename({'value': 'Target'})


def create_submission_csv(input_df: pl.DataFrame, output_filename='traffic_forecast.csv', archiving_dir='submission-csvs-archive') -> pl.DataFrame:
    """
    Build the full submission (weeks 5-6 and 10-11) and write it to CSV.

    Args:
        input_df: Wide frame with 'idx_hour' plus one column per beam, covering
            hours 840-1007 and 1680-1847.
        output_filename: Path of the submission CSV to write. It is also saved to
            wandb when a wandb run is active.
        archiving_dir: If truthy, an extra copy is written into this directory (created
            if needed). The copy is prefixed with the wandb run name, or with a
            timestamp when no wandb run is active.

    Returns:
        The concatenated submission frame (columns 'ID', 'Target').
    """
    import datetime

    # One long frame per forecast window, stacked weeks 5-6 first.
    half_submission_5w_6w = create_half_submission_df(input_df, '5w-6w')
    half_submission_10w_11w = create_half_submission_df(input_df, '10w-11w')
    submission_df = pl.concat([half_submission_5w_6w, half_submission_10w_11w], how='vertical')

    submission_df.write_csv(output_filename)

    # wandb.run is None until wandb.init() has been called (no init call is visible in this
    # notebook). In that case wandb.save fails and wandb.run.name raises AttributeError.
    run = wandb.run
    if run is not None:
        wandb.save(output_filename)
    else:
        print(f"No active wandb run; skipping wandb.save({output_filename!r})")

    if archiving_dir:
        archive_prefix = run.name if run is not None else datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        archiving_dir = Path(archiving_dir)
        archiving_dir.mkdir(parents=True, exist_ok=True)
        submission_df.write_csv(archiving_dir / f'{archive_prefix}_{output_filename}')

    return submission_df

NameError: name 'ys_predicted_long' is not defined

In [26]:
# Assemble the frame handed to create_submission_csv. Real predictions exist only for
# hours 840-1007 (weeks 5-6); weeks 10-11 (hours 1680-1847) are not forecast here.
if DEBUG:
    # Zero placeholders for weeks 10-11, so that the submission covers both windows.
    dummy_w11 = pl.DataFrame(
        {'idx_hour': list(range(1680, 1848))}
        | {beam_id: [0] * 168 for beam_id in template_df.columns}
    )
    ys_final = pl.concat([predictions_wide, dummy_w11], how='vertical_relaxed')
else:
    # Always define ys_final. Without the placeholders, the weeks 10-11 window is empty,
    # and the shape check in create_half_submission_df rejects it until real
    # week 10-11 predictions are added.
    ys_final = predictions_wide

shape: (168, 2_881)
┌──────────┬──────────┬──────────┬──────────┬───┬──────────┬──────────┬──────────┬──────────┐
│ idx_hour ┆ 0_0_0    ┆ 0_0_1    ┆ 0_0_2    ┆ … ┆ 29_2_28  ┆ 29_2_29  ┆ 29_2_30  ┆ 29_2_31  │
│ ---      ┆ ---      ┆ ---      ┆ ---      ┆   ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
│ i64      ┆ f32      ┆ f32      ┆ f32      ┆   ┆ f32      ┆ f32      ┆ f32      ┆ f32      │
╞══════════╪══════════╪══════════╪══════════╪═══╪══════════╪══════════╪══════════╪══════════╡
│ 840      ┆ 0.466081 ┆ 0.105483 ┆ 0.149764 ┆ … ┆ 0.069541 ┆ 0.089541 ┆ 0.102618 ┆ 0.077246 │
│ 841      ┆ 0.349067 ┆ 0.114122 ┆ 0.365288 ┆ … ┆ 0.069849 ┆ 0.087333 ┆ 0.095188 ┆ 0.074722 │
│ 842      ┆ 0.321606 ┆ 0.142163 ┆ 0.956922 ┆ … ┆ 0.068924 ┆ 0.085145 ┆ 0.095514 ┆ 0.085771 │
│ 843      ┆ 0.256772 ┆ 0.15254  ┆ 0.144487 ┆ … ┆ 0.069559 ┆ 0.085092 ┆ 0.09751  ┆ 0.07592  │
│ 844      ┆ 0.201087 ┆ 0.151786 ┆ 0.16632  ┆ … ┆ 0.07105  ┆ 0.084931 ┆ 0.096068 ┆ 0.078594 │
│ …        ┆ …        ┆ …        ┆ …    

In [27]:
# Optionally write (and archive) the submission CSV for the assembled predictions.
write_submission = inference_config['create_submission_csv']
if write_submission:
    submission_df = create_submission_csv(ys_final)