## Libraries

In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.metrics import QuantileLoss

## Preprocessing Data

In [2]:
# Directory that holds the pickled train / validation splits.
path = './gludata/data'


def _load_pickle(file_path):
    """Return the object stored in the pickle file at ``file_path``."""
    # NOTE(review): unpickling can execute arbitrary code - only load files
    # that come from a trusted source.
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)


train_data_raw = _load_pickle(path+"/train_data_pyforecast.pkl")
val_data_raw = _load_pickle(path+"/val_data_pyforecast.pkl")

In [4]:
# Labels (first element of every record) of the series to keep.
subset = [0, 1, 2, 3]


def _keep_subset(records):
    """Return the records whose first element is listed in ``subset``."""
    return list(filter(lambda record: record[0] in subset, records))


train_data_raw = _keep_subset(train_data_raw)
val_data_raw = _keep_subset(val_data_raw)

In [5]:
def read_data(data, id_start):
    """Flatten a list of per-series records into one long-format DataFrame.

    Parameters
    ----------
    data : sequence
        One record per series. Only positions 0, 1 and 3 of each record are
        read: ``record[0]`` is the subject label, ``record[1]`` the CGM
        readings (array-like, flattened to 1-D) and ``record[3]`` the
        timestamps, one per reading.
    id_start : int
        ``id`` given to the first series; the following series are numbered
        consecutively from it.

    Returns
    -------
    pandas.DataFrame
        One row per reading with a fresh 0..n-1 index and the columns
        ``timeidx`` (int, position inside its series), ``id`` and ``subject``
        (category), ``CGM`` (float), ``dayofyear``, ``dayofmonth``,
        ``dayofweek``, ``hour``, ``minute`` (category, derived from ``date``)
        and ``date`` (datetime).

    Raises
    ------
    ValueError
        If a record has a different number of timestamps than readings.
    """
    # Collect plain Python lists and build the frame once at the end. The
    # previous implementation filled a pre-allocated frame through chained
    # indexing (``frame[col][a:b] = ...``), which pandas warns about and which
    # does not write to the frame at all under Copy-on-Write.
    time_index, series_ids, subjects, cgm_values, dates = [], [], [], [], []
    for offset, record in enumerate(data):
        readings = np.asarray(record[1], dtype="float").flatten()
        stamps = list(record[3])
        block_len = len(readings)
        if len(stamps) != block_len:
            raise ValueError(
                f"series {offset}: got {len(stamps)} timestamps "
                f"for {block_len} readings"
            )
        time_index.extend(range(block_len))
        series_ids.extend([str(id_start + offset)] * block_len)
        subjects.extend([str(record[0])] * block_len)
        cgm_values.extend(readings.tolist())
        dates.extend(stamps)

    def as_category(values):
        # Go through the "string" dtype so the categories are strings.
        return values.astype("string").astype("category")

    date = pd.to_datetime(pd.Series(dates, dtype="object"))

    # The dict order fixes the column order of the result.
    data_pd = pd.DataFrame(
        {
            "timeidx": pd.Series(time_index, dtype="int"),
            "id": as_category(pd.Series(series_ids, dtype="object")),
            "subject": as_category(pd.Series(subjects, dtype="object")),
            "CGM": pd.Series(cgm_values, dtype="float"),
            # calendar features extracted from the timestamp
            "dayofyear": as_category(date.dt.dayofyear),
            "dayofmonth": as_category(date.dt.day),
            "dayofweek": as_category(date.dt.dayofweek),
            "hour": as_category(date.dt.hour),
            "minute": as_category(date.dt.minute),
            "date": date,
        }
    )
    return data_pd

# Series ids continue across the frames so that no two frames share an id.
n_train_series = len(train_data_raw)
n_val_series = len(val_data_raw)

train_data_pd = read_data(train_data_raw, 0)
val_data_pd = read_data(val_data_raw, n_train_series)
# NOTE(review): the test frame is built from the validation records
# (only the ids differ) - confirm this is intended.
test_data_pd = read_data(val_data_raw, n_train_series + n_val_series)

In [6]:
# Show the assembled training frame as the cell output.
train_data_pd

Unnamed: 0,timeidx,id,subject,CGM,dayofyear,dayofmonth,dayofweek,hour,minute,date
0,0,0,0,-2.664835,274,1,4,17,17,2010-10-01 17:17:00
1,1,0,0,-2.637363,274,1,4,17,22,2010-10-01 17:22:00
2,2,0,0,-2.692308,274,1,4,17,27,2010-10-01 17:27:00
3,3,0,0,-2.747253,274,1,4,17,32,2010-10-01 17:32:00
4,4,0,0,-3.214286,274,1,4,17,37,2010-10-01 17:37:00
...,...,...,...,...,...,...,...,...,...,...
39273,736,48,3,0.164835,81,22,4,2,15,2013-03-22 02:15:00
39274,737,48,3,0.247253,81,22,4,2,20,2013-03-22 02:20:00
39275,738,48,3,0.302198,81,22,4,2,25,2013-03-22 02:25:00
39276,739,48,3,0.247253,81,22,4,2,30,2013-03-22 02:30:00


In [7]:
# Show the assembled validation frame as the cell output.
val_data_pd

Unnamed: 0,timeidx,id,subject,CGM,dayofyear,dayofmonth,dayofweek,hour,minute,date
0,0,49,0,-2.637363,258,15,6,14,7,2013-09-15 14:07:00
1,1,49,0,-2.637363,258,15,6,14,12,2013-09-15 14:12:00
2,2,49,0,-2.692308,258,15,6,14,17,2013-09-15 14:17:00
3,3,49,0,-2.829670,258,15,6,14,22,2013-09-15 14:22:00
4,4,49,0,-2.994505,258,15,6,14,27,2013-09-15 14:27:00
...,...,...,...,...,...,...,...,...,...,...
2712,253,54,3,-0.219780,259,16,0,12,5,2013-09-16 12:05:00
2713,254,54,3,0.219780,259,16,0,12,10,2013-09-16 12:10:00
2714,255,54,3,0.741758,259,16,0,12,15,2013-09-16 12:15:00
2715,256,54,3,0.714286,259,16,0,12,20,2013-09-16 12:20:00


In [8]:
# Print column dtypes, non-null counts and memory usage of the training frame.
train_data_pd.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 39278 entries, 0 to 39277
Data columns (total 10 columns):
 #   Column      Non-Null Count  Dtype         
---  ------      --------------  -----         
 0   timeidx     39278 non-null  int64         
 1   id          39278 non-null  category      
 2   subject     39278 non-null  category      
 3   CGM         39278 non-null  float64       
 4   dayofyear   39278 non-null  category      
 5   dayofmonth  39278 non-null  category      
 6   dayofweek   39278 non-null  category      
 7   hour        39278 non-null  category      
 8   minute      39278 non-null  category      
 9   date        39278 non-null  datetime64[ns]
dtypes: category(7), datetime64[ns](1), float64(1), int64(1)
memory usage: 1.2 MB


In [9]:
# Calendar features derived from the timestamp; they are known for future steps.
calendar_categoricals = ["dayofyear", "dayofmonth", "dayofweek", "hour", "minute"]

train_data = TimeSeriesDataSet(
    train_data_pd,
    time_idx="timeidx",
    target="CGM",
    group_ids=["id"],
    max_encoder_length=180,
    max_prediction_length=12,
    static_categoricals=["subject"],
    time_varying_known_categoricals=calendar_categoricals,
    time_varying_known_reals=["timeidx"],
    time_varying_unknown_reals=["CGM"],
    # `scalers` is documented as a dict (variable name -> scaler); the
    # previous empty list only worked by accident.
    scalers={},
    # Reserve an "unknown" class so calendar values that never occur in the
    # training frame (e.g. a day of year) do not raise when the validation
    # frame is encoded with the training encoders.
    categorical_encoders={
        name: NaNLabelEncoder(add_nan=True) for name in calendar_categoricals
    },
    add_relative_time_idx=True,
    add_encoder_length=True,
)
train_dataloader = train_data.to_dataloader(train=True, batch_size=64, num_workers=24)

# Derive the validation set from the training set so that it reuses the
# training set's parameters, fitted categorical encoders and scalers.
# Building it independently refits them on the validation frame, so the same
# category can map to a different embedding index than during training.
val_data = TimeSeriesDataSet.from_dataset(
    train_data,
    val_data_pd,
    stop_randomization=True,
)
val_dataloader = val_data.to_dataloader(train=False, batch_size=64, num_workers=24)

## Training the model

In [10]:
# Checkpointing: keep only the single checkpoint with the lowest "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved_models',
    filename="tft-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# TensorBoard logger rooted at "lightning_logs".
logger = TensorBoardLogger("lightning_logs")

# Trainer settings collected in one place.
trainer_options = dict(
    max_epochs=1000,
    gpus=1,
    weights_summary="top",
    callbacks=[checkpoint_callback],
    logger=logger,
)
trainer = pl.Trainer(**trainer_options)

# Model hyper-parameters; everything else is inferred from the training dataset.
tft_options = dict(
    learning_rate=0.01,
    hidden_size=160,
    attention_head_size=4,
    dropout=0.1,
    # presumably one output per quantile of the default QuantileLoss - confirm
    output_size=7,
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches
    reduce_on_plateau_patience=4,
)
tft = TemporalFusionTransformer.from_dataset(train_data, **tft_options)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Number of parameters in network: 1313.1k


In [11]:
# Train the TFT on the training batches, validating on the validation batches.
fit_loaders = dict(
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
trainer.fit(tft, **fit_loaders)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 1.9 K 
3  | prescalers                         | ModuleDict                      | 64    
4  | static_variable_selection          | VariableSelectionNetwork        | 4.4 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 16.1 K
6  | decoder_variable_selection         | VariableSelectionNetwork        | 11.9 K
7  | static_context_variable_selection  | GatedResidualNetwork            | 103 K 
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 103 K 
9  | static_context_initial_cell

Validation sanity check:   0%| | 0/2 [00:00<?, ?it/

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Epoch 0:  95%|▉| 467/492 [05:22<00:17,  1.45it/s, loss=0.162, v_num=9, train_los
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                        | 0/25 [00:00<?, ?it/s][A
Epoch 0:  95%|▉| 469/492 [05:24<00:15,  1.45it/s, loss=0.162, v_num=9, train_los[A
Validating:   8%|██▌                             | 2/25 [00:02<00:27,  1.20s/it][A
Epoch 0:  96%|▉| 471/492 [05:25<00:14,  1.45it/s, loss=0.162, v_num=9, train_los[A
Validating:  16%|█████                           | 4/25 [00:03<00:15,  1.40it/s][A
Epoch 0:  96%|▉| 473/492 [05:26<00:13,  1.45it/s, loss=0.162, v_num=9, train_los[A
Validating:  24%|███████▋                        | 6/25 [00:04<00:10,  1.74it/s][A
Epoch 0:  97%|▉| 475/492 [05:27<00:11,  1.45it/s, loss=0.162, v_num=9, train_los[A
Validating:  32%|██████████▏                     | 8/25 [00:05<00:08,  1.92it/s][A
Epoch 0:  97%|▉| 477/492 [05:28<00:10,  1.45it/s, loss=0.162, v_num=9, train_los[A
Validating:  40%|████████████▍               

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Epoch 0: 100%|█| 492/492 [05:36<00:00,  1.46it/s, loss=0.162, v_num=9, train_los
Epoch 1:  74%|▋| 365/492 [04:12<01:27,  1.45it/s, loss=0.156, v_num=9, train_loss_step=0.145, val_loss=0.194, train_loss_epoch

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
