In [1]:
import torch
import numpy as np
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet


In [2]:
from torch.utils.data import DataLoader
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import MAE
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
# Read the cleaned, merged dataset (semicolon-separated), using the parsed
# 'utc_timestamp' column as the index, and keep only the rows labelled 2019.
df_TFT = pd.read_csv(
    '../data/processed/merged_dataset_cleaned.csv',
    sep=';',
    index_col='utc_timestamp',
    parse_dates=True,
).loc['2019-01-01':'2019-12-31']

# Quick sanity checks: first/last rows, shape, column dtypes and null counts.
for overview in (
    df_TFT.head(),
    df_TFT.tail(),
    df_TFT.shape,
    df_TFT.dtypes,
    df_TFT.isnull().sum(),
):
    print(overview)

                                 cet_cest_timestamp  \
utc_timestamp                                         
2019-01-01 00:00:00+00:00  2019-01-01T01:00:00+0100   
2019-01-01 00:15:00+00:00  2019-01-01T01:15:00+0100   
2019-01-01 00:30:00+00:00  2019-01-01T01:30:00+0100   
2019-01-01 00:45:00+00:00  2019-01-01T01:45:00+0100   
2019-01-01 01:00:00+00:00  2019-01-01T02:00:00+0100   

                           DE_load_actual_entsoe_transparency  \
utc_timestamp                                                   
2019-01-01 00:00:00+00:00                            42254.95   
2019-01-01 00:15:00+00:00                            41718.84   
2019-01-01 00:30:00+00:00                            41349.07   
2019-01-01 00:45:00+00:00                            40924.25   
2019-01-01 01:00:00+00:00                            40984.90   

                           DE_load_forecast_entsoe_transparency  \
utc_timestamp                                                     
2019-01-01 00:00:00+00:0

In [4]:
# Drop forecasting, radiation, cet timestamp, and all profile columns
cols_to_drop = [
    "cet_cest_timestamp",
    "DE_load_forecast_entsoe_transparency",
    "DE_radiation_direct_horizontal",
    "DE_radiation_diffuse_horizontal",
    "DE_solar_profile",
    "DE_wind_profile",
    "DE_wind_offshore_profile",
    "DE_wind_onshore_profile"
]

# errors="ignore" makes this cell safe to re-run: df_TFT is reassigned here, so on a
# second execution the columns are already gone and a plain drop() would raise KeyError.
# Trade-off: a misspelled name in cols_to_drop is skipped silently, so check the shape below.
df_TFT = df_TFT.drop(columns=cols_to_drop, errors="ignore")

df_TFT.shape

(34940, 13)

In [5]:
# Show the dtype of each remaining column, then (as the cell's displayed value)
# the number of missing entries per column.
column_types = df_TFT.dtypes
print(column_types)
df_TFT.isna().sum()

DE_load_actual_entsoe_transparency    float64
DE_solar_capacity                       int64
DE_solar_generation_actual            float64
DE_wind_capacity                        int64
DE_wind_generation_actual             float64
DE_wind_offshore_capacity               int64
DE_wind_offshore_generation_actual    float64
DE_wind_onshore_capacity                int64
DE_wind_onshore_generation_actual     float64
DE_temperature                        float64
hour                                  float64
is_daylight                             int64
month                                   int64
dtype: object


DE_load_actual_entsoe_transparency    0
DE_solar_capacity                     0
DE_solar_generation_actual            0
DE_wind_capacity                      0
DE_wind_generation_actual             0
DE_wind_offshore_capacity             0
DE_wind_offshore_generation_actual    0
DE_wind_onshore_capacity              0
DE_wind_onshore_generation_actual     0
DE_temperature                        0
hour                                  0
is_daylight                           0
month                                 0
dtype: int64

In [6]:
# Add the bookkeeping columns that the TimeSeriesDataSet below refers to by name.
# .assign returns a new frame, so the frame bound to df_TFT before this cell is not mutated.
df_TFT = df_TFT.assign(
    # sequential time index: 0, 1, 2, ... in current row order
    # (assumes rows are sorted by timestamp with no gaps — TODO confirm)
    time_idx=range(len(df_TFT)),
    # only one group
    group_id="DE",
    # target variable: a copy of the actual-load column
    target=df_TFT["DE_load_actual_entsoe_transparency"],
)
print(df_TFT.head())

                           DE_load_actual_entsoe_transparency  \
utc_timestamp                                                   
2019-01-01 00:00:00+00:00                            42254.95   
2019-01-01 00:15:00+00:00                            41718.84   
2019-01-01 00:30:00+00:00                            41349.07   
2019-01-01 00:45:00+00:00                            40924.25   
2019-01-01 01:00:00+00:00                            40984.90   

                           DE_solar_capacity  DE_solar_generation_actual  \
utc_timestamp                                                              
2019-01-01 00:00:00+00:00              47480                         0.0   
2019-01-01 00:15:00+00:00              47480                         0.0   
2019-01-01 00:30:00+00:00              47480                         0.0   
2019-01-01 00:45:00+00:00              47480                         0.0   
2019-01-01 01:00:00+00:00              47480                         0.0   

           

In [7]:
max_encoder_length = 96      # past 1 day (15-min intervals)
max_prediction_length = 2880 # future 30 days

# NOTE(review): despite the name, these columns are numeric and are passed to the
# dataset below as time-varying *known reals*, not as categoricals — confirm that
# this is the intended encoding.
categorical_columns = ["month", "hour", "is_daylight"]
continuous_columns = [
    "DE_solar_capacity", "DE_solar_generation_actual",
    "DE_wind_capacity", "DE_wind_generation_actual",
    "DE_wind_offshore_capacity", "DE_wind_offshore_generation_actual",
    "DE_wind_onshore_capacity", "DE_wind_onshore_generation_actual",
    "DE_temperature",
]

# Hold out the final prediction window (the last `max_prediction_length` steps) so the
# model is not fitted on the period that can later serve as validation data.
training_cutoff = df_TFT["time_idx"].max() - max_prediction_length

training_dataset = TimeSeriesDataSet(
    df_TFT[df_TFT["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    # inputs whose future values are known at prediction time
    time_varying_known_reals=["time_idx"] + categorical_columns,
    # inputs only observed up to the prediction start (including the target itself)
    time_varying_unknown_reals=["target"] + continuous_columns,
    static_categoricals=["group_id"],
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)


Moving on: wrap the dataset in a dataloader, then initialise and train the model.

In [8]:
# Build the training dataloader from the dataset defined above.
batch_size = 64  # can tune later
train_dataloader = training_dataset.to_dataloader(
    train=True,
    batch_size=batch_size,
    num_workers=0,
)

In [9]:
# Initialize the Temporal Fusion Transformer; from_dataset takes the dataset
# as its first argument, followed by the hyperparameters below.
tft = TemporalFusionTransformer.from_dataset(
    training_dataset,
    # network size / regularisation
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    # optimisation
    loss=MAE(),
    learning_rate=0.03,
    reduce_on_plateau_patience=4,
    # logging
    log_interval=10,
)

  rank_zero_warn(
  rank_zero_warn(


In [10]:
# Set up the trainer and train.
#
# EarlyStopping monitors "val_loss", which is only logged when a validation loop runs.
# Without a validation dataloader the metric never exists and early stopping cannot work,
# so build one here by reusing the training dataset's configuration.
# predict=True selects the final prediction window of each group as the validation sample.
# NOTE(review): if training_dataset was fitted on all rows of df_TFT, that window overlaps
# the training data; hold it out when building the training set for an unbiased val_loss.
validation_dataset = TimeSeriesDataSet.from_dataset(
    training_dataset,
    df_TFT,
    predict=True,
    stop_randomization=True,
)
val_dataloader = validation_dataset.to_dataloader(
    train=False,
    batch_size=batch_size,
    num_workers=0,
)

early_stop_callback = EarlyStopping(monitor="val_loss", patience=10, verbose=True, mode="min")

trainer = Trainer(
    max_epochs=30,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback],
    accelerator="auto"
)

# Now call fit, passing the validation loader so that "val_loss" is produced each epoch
trainer.fit(
    model=tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /Users/mariarumpf/thesis-electricity-demand-forecasting/notebooks/lightning_logs

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | MAE                             | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 1     
3  | prescalers                         | ModuleDict                      | 288   
4  | static_variable_selection          | VariableSelectionNetwork        | 1.8 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 10.8 K
6  | decoder_variable_selection         | VariableSelectionN

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [2]:
import pandas as pd

# Load the cleaned, merged dataset (no date filter) and print a column summary.
# NOTE(review): the earlier comment called this the "test dataset", but the path is the
# same file read into df_TFT above — confirm whether a separate test file was intended.
df = pd.read_csv(
    '../data/processed/merged_dataset_cleaned.csv',
    sep=';',
    index_col='utc_timestamp',
    parse_dates=True,
)
df.info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 175167 entries, 2015-01-01 07:15:00+00:00 to 2019-12-30 22:45:00+00:00
Data columns (total 21 columns):
 #   Column                                Non-Null Count   Dtype  
---  ------                                --------------   -----  
 0   cet_cest_timestamp                    175167 non-null  object 
 1   DE_load_actual_entsoe_transparency    175167 non-null  float64
 2   DE_load_forecast_entsoe_transparency  175065 non-null  float64
 3   DE_solar_capacity                     175167 non-null  int64  
 4   DE_solar_generation_actual            175167 non-null  float64
 5   DE_solar_profile                      174783 non-null  float64
 6   DE_wind_capacity                      175167 non-null  int64  
 7   DE_wind_generation_actual             175167 non-null  float64
 8   DE_wind_profile                       174869 non-null  float64
 9   DE_wind_offshore_capacity             175167 non-null  int64  
 10  DE_wind_offshore_gener