In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [9]:
# Load the raw feature tables; the first CSV column becomes the (timestamp) index.
X_train = pd.read_csv('X_train.csv', index_col=0)
X_test = pd.read_csv('X_test.csv', index_col=0)

LOCATION_ONEHOT_COLUMNS = ["location_A", "location_B", "location_C"]


def add_ds_and_location(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``frame`` with ``ds`` and ``location`` added and the one-hots dropped.

    - ``ds``: the index parsed with ``pd.to_datetime`` and cast to ``datetime64[ns]``.
    - ``location``: "A" where ``location_A == 1``, otherwise "B" where
      ``location_B == 1``, otherwise "C". The "C" fallback is NOT checked against
      ``location_C``, so a row whose one-hots are all 0 is also labelled "C".
    - ``location_A`` / ``location_B`` / ``location_C`` are removed.

    Vectorized with ``np.select`` instead of a row-wise ``apply(axis=1)``: it is
    much faster on large frames and also works on an empty frame (where
    ``apply`` would return a DataFrame that cannot be assigned to one column).
    """
    out = frame.copy()  # never mutate the caller's frame
    out["ds"] = pd.to_datetime(out.index).astype('datetime64[ns]')
    # Order of the conditions encodes the precedence A > B > C of the original logic.
    out["location"] = np.select(
        [out["location_A"] == 1, out["location_B"] == 1],
        ["A", "B"],
        default="C",
    ).astype(object)
    return out.drop(columns=LOCATION_ONEHOT_COLUMNS)


X_train = add_ds_and_location(X_train)
X_test = add_ds_and_location(X_test)

In [12]:
# Model inputs: every column except the timestamp, the target and the group id.
# Filter X_train.columns directly rather than taking a set difference: str hashes
# are randomized per interpreter process, so set iteration order — and therefore
# the feature/column order handed to the model — would change between kernel
# restarts. This list keeps the CSV's column order and is reproducible.
to_remove = ["ds", "y", "location"]
FEATURES = [column for column in X_train.columns if column not in to_remove]
FEATURES


['dew_or_rime:idx',
 'snow_density:kgm3',
 'is_day:idx',
 'is_in_shadow:idx',
 'relative_humidity_1000hPa:p',
 'dew_point_2m:K',
 'quarter_of_year',
 'sfc_pressure:hPa',
 'clear_sky_energy_1h:J',
 'cloud_base_agl:m',
 'direct_rad_1h:J',
 'is_weekend',
 'day_of_week',
 'fresh_snow_1h:cm',
 'wind_speed_v_10m:ms',
 'fresh_snow_3h:cm',
 'elevation:m',
 'visibility:m',
 'wind_speed_w_1000hPa:ms',
 'msl_pressure:hPa',
 'month_of_year',
 'effective_cloud_cover:p',
 'pressure_100m:hPa',
 'diffuse_rad:W',
 'sun_elevation:d',
 'wind_speed_10m:ms',
 'direct_rad:W',
 'diffuse_rad_1h:J',
 'fresh_snow_24h:cm',
 'ceiling_height_agl:m',
 'pressure_50m:hPa',
 't_1000hPa:K',
 'fresh_snow_6h:cm',
 'hour_of_day',
 'super_cooled_liquid_water:kgm2',
 'wind_speed_u_10m:ms',
 'clear_sky_rad:W',
 'snow_drift:idx',
 'precip_5min:mm',
 'air_density_2m:kgm3',
 'absolute_humidity_2m:gm3',
 'prob_rime:p',
 'total_cloud_cover:p',
 'time_diff',
 'fresh_snow_12h:cm',
 'snow_melt_10min:mm',
 'rain_water:kgm2',
 'precip

In [8]:
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models.nhits import NHiTS

In [None]:
# define dataset
max_encoder_length = 24
max_prediction_length = 12
training_cutoff = X_train.index.max() - pd.DateOffset(months=1)
validation_cutoff = training_cutoff - pd.DateOffset(months=1)

training = TimeSeriesDataSet(
    data=X_train,
    time_idx="ds",
    target="y",
    group_ids=["location"],
    time_varying_known_reals=FEATURES,
    