In [1]:
import pandas as pd
import numpy as np
import pyunpack
import math
import json

from data_formatters.m5 import M5Formatter
from data_formatters.base import DataTypes, InputTypes

from pytorch_dataset import TFTDataset
from models import GatedLinearUnit
from models import GateAddNormNetwork
from models import GatedResidualNetwork 
from models import ScaledDotProductAttention
from models import InterpretableMultiHeadAttention
from models import VariableSelectionNetwork

from quantile_loss import QuantileLossCalculator
from quantile_loss import NormalizedQuantileLossCalculator

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from argparse import ArgumentParser

import matplotlib.pyplot as plt

## Creating Datasets

In [2]:
#m5_data = pd.read_pickle('/home/daniel/github/kaggle_competitions/M5_Forecasting_Accuracy/data/full_data.pkl')
m5_data = pd.read_pickle('/container/home/millenium/Storage/Daniel/github/kaggle_competitions/M5_Forecasting_Accuracy/data/full_data.pkl')

In [3]:
m5_data.shape

(46881677, 34)

In [4]:
m5_data.id.unique().shape

(30490,)

In [5]:
m5_data.groupby(['id']).apply(lambda x: len(x)).value_counts().to_frame()

Unnamed: 0,0
1941,10932
1934,1043
1927,544
1906,301
1920,280
...,...
191,1
177,1
135,1
289,1


### Formating data

In [None]:
data_formatter = M5Formatter()
train, valid, test = data_formatter.split_data(m5_data)

Formatting train-valid-test splits.
Setting scalers with training data...
359
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1485
1485
1485
1485
1485
1485
1485
1485
1485
1485
1884
1884
1884
1884
1884
1884
1884
1884
1884
1863
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1625
1520
1331
1625
1604
1506
1457
1618
1597
1534
1485
1485
1485
1485
1478
1471
1478
1478
1478
1471
1401
1702
1702
1702
1793
1758
1793
1366
1366
1072
1884
1821
1884
1884
1884
1884
1870
1821
1884
1884
1702
1702
1702
1702
1793
1793
1793
1702
1702
1702
1884
1884
1884
1485
1877
1884
1884
1884
1884
1884
1044
1044
1044
750
1359
1366
1359
1345
1338
1345
1870
1814
1807
1814
1807
1814
1793
1814
1821
1814
1884
1884
1884
1807
1884
1793
1884
1884
1884
1884
1478
1485
1478
1478
1478
1478
1464
1478
1478
1471
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
188

1793
1793
1793
1702
1702
1702
1002
995
995
1002
1793
1793
1401
1401
995
1002
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1485
1485
1485
1478
1478
1478
1478
1478
1485
1478
1877
1814
1884
1814
1807
1814
1884
1821
1814
1807
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
477
477
477
379
463
477
463
470
449
456
1884
1884
1884
1884
1884
1884
1884
1884
1814
1842
1485
1128
1464
1121
1478
1457
1436
1135
1128
1478
1331
1331
1331
1324
1331
1331
1331
1331
1324
1331
1884
1884
1884
1884
1884
1884
1884
1856
1884
1884
1555
1555
1555
1562
1562
1555
1555
1562
1562
1562
1884
1884
1884
1807
1884
1814
1884
1884
1884
1877
1457
1478
1478
1121
1457
1471
1471
1471
1485
1478
1884
1877
1863
1884
1877
1884
1877
1884
1856
1870
666
764
764
757
764
666
645
666
764
666
1492
1485
1478
1478
1485
1492
1485
1478
1485
1492
1135
1135
1135
1135
1128
1135
1128
1135
112

1884
876
1884
1884
1884
1884
1884
1254
1884
1884
1835
715
1835
1835
1835
1842
1443
1254
1422
1821
1884
309
1884
1884
1884
1884
1884
1254
1646
1884
1436
309
1450
1450
1450
1443
1436
1240
1408
1450
708
295
708
701
708
708
701
708
708
708
1429
708
1429
1429
1429
1429
1429
1247
1429
1429
1450
295
1457
1443
1450
1443
834
1254
834
1436
1884
295
1884
1884
1884
1884
1772
1247
750
1884
1884
358
1884
1884
1884
1884
1884
1247
1709
1884
1884
1884
1884
1884
1884
1884
1884
1184
1877
1884
1436
302
1436
323
1884
1877
323
1254
1422
1884
1884
309
1884
1884
1884
1884
323
1247
1422
1884
1884
309
1884
1884
1884
1884
1884
1247
1429
1884
1884
1611
1884
1884
1884
1884
1884
1072
1807
1884
1485
155
1485
1387
1485
1387
1338
1114
876
1373
1884
302
1884
1884
1884
1884
1884
1247
1429
1884
1422
1135
1415
1415
1422
1422
1422
1184
1422
1422
1884
302
1884
1884
1884
1884
1450
1254
1429
1884
1863
211
1884
1884
1884
1884
1884
1254
1429
1884
617
309
610
554
617
617
617
547
617
617
1884
309
1884
1884
1884
1884
1884
1240
142

1779
1758
1884
1464
260
1457
1464
1464
1457
1464
1184
1429
1464
869
302
869
834
876
827
883
883
883
883
1751
155
1751
1751
1751
1751
1681
1114
1422
1730
1828
309
1884
1884
1870
1884
1884
1254
1373
1884
1884
1884
1884
1884
1884
1884
1884
1240
1884
1884
1884
421
1884
1877
1870
1884
1877
512
1429
1884
1821
477
1821
1821
1821
1821
1821
1254
1429
1597
1338
302
1296
1338
1338
1324
1331
1254
1338
1338
1884
309
1884
1786
1786
1884
1884
1254
1422
1786
1534
295
1534
1534
1534
1534
1527
1254
1429
1534
841
309
841
841
841
841
141
834
841
764
1884
309
1884
1884
1884
1884
1884
1247
1429
1884
1317
568
1324
1317
1324
1317
1317
1240
1324
1317
1667
302
1653
1674
1660
1681
1611
1247
1422
1667
1765
715
1884
1884
1884
1884
1842
1247
1429
1856
1884
1863
1884
1884
1884
1884
1884
1884
1877
1884
1765
365
1765
1758
1765
1765
1681
1100
1429
1765
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
309
1884
1884
1884
1884
1163
1163
1170
1884
1884
302
1884
1660
1779
1786
1268
1191
1422
1779
981
533
967
974
981
9

652
1884
1884
1821
1884
1884
1254
533
1884
876
876
876
876
876
876
813
876
876
876
1884
1051
1884
715
1884
1884
1835
1247
1870
715
1884
1821
1884
1821
1884
1814
1884
1814
1681
1821
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1639
1884
1884
1884
1884
1884
1884
1821
1884
1884
1884
1884
1842
1884
1415
1884
1884
1884
1884
1884
1884
1884
1884
1884
1877
1884
1884
806
827
834
813
813
827
827
813
834
806
750
575
750
568
743
750
750
743
750
575
1884
1345
1884
1828
1884
1884
1884
904
1884
1884
1051
1044
1051
1037
1051
1044
1051
1044
1044
1044
750
750
750
715
708
750
708
722
750
701
1884
1884
1884
1884
1884
1884
1527
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1884
1660
1562
1765
1884
1394
1884
1884
1884
1884
1506
1884
1884
1884
1877
1870
1884
1884
1884
1884
1884
1247
1429
1884
1170
1170
1170
1163
1163
1163
1163
1170
1170
1170
1884
1884
1884
1884
1884
1870
1856
1870
1884
1884
1884
1884
1884
1884
1884
1884
1884
1254
1765
1884


1023
1884
1863
1884
1884
1884
1884
1884
1254
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1807
1807
1807
1807
1807
1807
1807
1030
1807
1807
1541
302
1632
1541
1541
1520
1541
1254
1408
1534
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1114
533
1121
1121
1114
1121
1121
1121
1072
1121
449
435
442
435
435
435
421
442
442
365
463
372
449
358
365
358
379
470
414
358
1884
1884
1884
1884
1884
1884
1884
1884
1849
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1576
1576
1583
1576
1590
1583
1590
1583
1583
1562
1576
1576
1576
1569
1576
1576
1569
1583
1583
1590
463
456
456
456
442
463
428
456
456
456
1884
1884
1884
1884
1884
1884
1884
1247
1884
1835
638
519
645
631
652
652
645
652
652
659
1884
1884
1884
1884
1884
1884
1884
1884
1387
1051
204
533
617
1261
1884
1821
197
1254
1422
1856
1884
1884
1884
1884
1884
1884
1884
1884
1856
1884
1856
1849
1856
1849
1856
1856
1842
1856
1849
1807
1884
1884
1821
1884
1884
1884
1884
1884
1884
1884
1380
1597
1387
1373
1597
1380
1380
1254
1

1884
1884
1786
1870
1807
1884
1247
1247
1240
1240
1079
1247
1240
1240
1240
1240
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1520
1520
1527
1520
1506
1506
1513
1513
1513
1520
1884
1884
1884
1884
1884
1884
1870
1884
1884
1884
1205
1212
1212
1205
1212
1212
1212
1205
1205
1198
1170
1170
1170
1170
1170
1170
1170
1170
1170
1156
1212
1198
1212
1212
1212
1212
1198
1205
1212
1205
1541
1527
1541
1534
1548
1548
1534
1065
1527
1534
820
610
820
428
813
778
813
813
813
813
1114
561
1121
561
1121
1107
1114
568
1121
491
1688
1688
1688
1688
1681
1688
1688
1695
1681
1695
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1695
1688
1688
1688
1688
1688
1688
1688
1695
1695
1807
1807
1856
1807
1800
1807
1807
1800
1807
1807
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
596
575
589
582
596
561
596
561
568
596
939
946
946
953
953
953
953
946
946
953
1681
1534
1695
1688
1688
1688
1681
1681
1688
1688
925
925
925
925
911
925
918
918
911
918
757
743
757
750
743
750
750
736
736
715
1485
1485
1485
1485
1485

1730
1730
1723
1674
1884
1884
1884
1884
1884
1884
1884
1254
1884
1884
43
533
43
932
918
939
960
918
953
904
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1065
925
1065
1051
1058
1065
1058
1065
1058
1058
1450
1450
1457
1450
1436
1443
1436
1436
1443
1443
617
610
610
610
610
624
624
631
631
617
1093
1079
1100
1093
1093
1093
1100
1093
1093
1100
736
533
582
806
575
589
505
981
974
540
575
204
204
211
568
211
582
568
547
526
519
596
463
505
771
750
771
778
771
778
925
925
925
925
939
925
932
939
939
939
1877
904
1884
1884
1884
1884
1884
1009
1429
1884
1072
1079
1072
1051
1079
1079
1072
1079
1079
1079
1520
1520
1520
1520
1520
1520
1520
1520
1520
1520
1849
1849
1849
1849
1849
1849
1849
1849
1849
1856
1331
1338
1338
1317
1331
1331
1338
1310
1331
1338
1884
1856
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1653
1639
1765
1646
1758
1751
1751
1723
1730
1884
449
358
449
449
435
449
456
449
449
442
1884
1884
1856
1884
1884
1863
1884
1884
1842
1884
1884

1884
1233
1233
1233
1233
1233
1233
1233
1233
1233
1233
1422
1415
1422
1422
1436
1429
1436
1436
1450
1450
1884
1884
1884
1422
1884
1884
1884
1254
1429
1884
722
722
729
715
708
729
694
729
722
729
1793
1807
1786
1800
1744
1793
1744
1800
1786
1765
1884
568
1884
1884
1884
1884
1884
1254
1422
1884
806
806
771
785
785
799
764
806
806
806
589
561
589
575
575
589
561
589
589
267
1884
239
1884
1884
1884
1884
1877
1247
1422
1884
1184
1156
1184
1184
1184
1184
1184
1184
1184
1177
631
526
631
533
512
498
491
631
631
631
1884
1884
1870
1884
1884
1884
1884
904
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1849
1884
1884
1884
1821
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1254
1233
1261
1884
1884
1884
1884
1884
1884
1884
1884
1765
1884
1520
1520
1520
1520
1520
1520
1520
1520
1485
1520
1856
1856
1856
1856
1856
1856
1849
1856
1856
1849
1072
1142
1142
1128
1142
1142
1142
1128
1149
1072
764
778
792
785
771
757
764
792
799
79

1345
1436
1345
1338
1352
1450
1450
1345
1345
1345
1884
1835
1884
1884
1877
1884
1877
1884
1884
1884
1870
1884
1884
1835
1884
1870
1884
1884
1884
1884
988
988
974
1002
988
981
974
974
981
974
1128
1114
1121
1121
1821
1828
1471
1247
1835
1828
1128
1093
1128
1128
1100
1128
1135
1100
1086
1107
988
981
995
995
988
981
988
974
981
974
1044
1044
1044
1037
1030
1037
1037
995
1037
1037
1163
1170
1163
1170
1156
1177
1170
1156
1177
1163
456
407
421
407
456
421
393
435
456
407
743
736
743
743
757
708
701
729
736
680
1884
1884
1870
1877
1884
1877
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1814
1884
1352
1324
1352
1345
1352
1352
1352
1352
1331
1352
1884
1884
1884
1884
1884
1884
1884
1884
1884
1520
1226
1226
1219
1226
1219
1219
1226
1226
1226
1226
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1366
1345
1373
1366
1366
1373
1366
1366
1352
1352
1632
1632
1632
1625
1632
1632
1632
1632
1625
1632
988
988
743
743
946
974
953
743
386
967
1415
1422
1429
1422
1408
1429
1422
1422
1408
1387


1863
750
736
757
757
743
750
750
750
729
750
1716
1709
1709
1709
1716
1709
1709
1716
1709
1716
1142
1135
1142
1135
1149
1142
1142
1135
1142
1128
1884
1884
1884
1667
1765
1842
1884
1884
1884
1884
1884
1884
1877
1884
1884
1877
1884
1884
1877
1884
1884
1884
1884
1674
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1716
1716
1716
1695
1716
1716
1716
1695
1709
1716
1884
1884
1807
1884
1884
1884
1863
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1877
1870
1863
1870
1513
1877
1870
1884
1884
1884
1870
1835
1737
1884
1884
1737
1884
1884
1884
1884
1884
1884
743
1870
1884
1884
1233
1884
1884
1870
1814
1884
1835
1870
1870
1884
1842
1884
1870
1128
1114
1114
1100
1303
1303
1303
1282
1303
1289
1415
1415
1520
1422
1401
1415
1415
1408
1415
1415
1331
1331
1331
1331
1331
1331
1338
1324
1324
1331
1884
1884
1807
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
722
1884
1884
1884
1884
1884
1884
1884
1842
1884
1884
18

680
659
666
673
680
680
673
680
1884
1884
1877
1863
1884
1884
1884
1884
1877
1884
708
708
659
701
687
673
715
715
652
624
694
841
834
351
673
820
806
344
694
673
1828
1842
1436
1422
1625
1849
1436
1835
1793
1800
680
666
680
673
680
666
680
687
645
680
1387
1429
1275
1282
1422
1191
1268
1429
1380
1044
1282
1317
1058
540
1338
1072
1058
1359
1065
1037
1877
1884
1877
1884
1884
1884
1884
1884
1884
1884
1079
1058
1037
1051
1079
1065
1072
1051
1079
1016
1870
1870
1856
1520
1870
1849
1828
1870
1870
1863
1877
1877
1884
1604
1877
1884
1877
1884
1884
1877
1877
1884
1884
1730
1849
1870
1884
1870
1877
1877
1884
1877
1877
1877
1884
1884
1884
1877
1884
1877
1618
1625
1625
1625
1625
1548
1618
1625
1618
1625
1849
1849
1856
1828
1842
1849
1842
1849
1842
1842
1072
1044
1058
1065
1065
1107
1065
1415
1058
1051
1877
1877
1884
1492
1863
1870
1870
1877
1863
1884
1884
1877
1884
1828
1849
1835
1884
1884
1856
1863
1877
1884
1884
1877
1884
1877
1884
1884
1849
1856
687
680
708
694
694
1884
715
708
701
673
1884
188

1835
1842
1842
1842
1842
1835
1835
1842
1835
1842
1492
1506
1527
981
1513
1492
1499
1534
1506
967
1485
1772
1765
1422
1485
1485
1499
1478
1716
1436
1422
414
659
1422
1429
1422
1429
848
1114
1429
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1870
1870
1870
1863
1870
1870
1870
1870
1870
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1730
1723
1730
1716
1709
1716
1716
1716
1730
1702
1135
988
995
1128
1310
1310
1310
1310
1317
1310
1849
1856
1849
1849
1884
1849
1849
1884
1884
1884
1884
1884
1338
1884
1884
1884
1884
1884
1884
1884
1870
1870
1877
1863
1877
1870
1877
1870
1877
1870
1842
1842
1842
1842
1849
1849
1849
1849
1667
1849
925
918
400
925
897
925
918
918
764
911
1170
1177
988
981
1191
1191
1177
1177
1177
1177
1646
1646
1156
1646
1611
1646
1639
1646
1506
1646
1296
1296
1296
1296
1296
1289
1296
1296
1296
1296
1002
1002
1212
1002
995
1002
1002
1002
1002
1002
1170
1254
1191
631
1184
1191
1170
1254
1191
1247
1464
1464
1464
1464
1464
1464
1464
1464
1464
1464
1884
1870
1884
1870
1

1870
1478
1877
1870
1877
1870
1870
1870
1870
1870
1870
1877
708
316
533
680
750
708
708
722
701
715
1142
1142
1135
1135
1142
1142
1142
1135
1142
1135
743
764
771
757
771
757
771
764
757
757
1471
1478
1478
1478
1478
1478
1478
1478
1464
1471
330
358
442
421
260
281
253
358
239
358
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1065
1065
946
876
1058
1051
1065
687
323
1065
785
407
407
771
778
407
771
400
407
778
456
449
456
456
456
449
456
456
442
456
1884
1884
1884
1884
1884
1884
1884
1828
1884
1884
1485
1492
1485
1492
1492
1485
1499
1492
1492
1478
1149
1135
1149
1135
1135
1142
1142
1121
1135
1121
911
918
918
869
876
897
883
883
883
890
1884
1884
1884
1884
1884
1884
1884
1121
1128
1884
1506
1506
1513
1513
1513
1506
1513
1506
1513
1513
533
526
533
533
526
533
533
540
533
526
1177
1184
1177
981
1191
1184
1177
1177
1177
1177
1884
1884
1884
1877
1884
1884
1884
1884
1863
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1121
764
1121
778
771
1128
1128
1093
1128
1128
981
974
981
988
89

1506
1499
1485
1499
1485
1506
1723
743
1723
1709
1709
1716
1716
1723
1548
1723
1884
1884
1884
1884
1884
1884
1884
1639
1660
1884
1884
1884
1884
1884
1884
1884
1877
1884
1870
1884
1401
1401
1401
1401
1401
1401
1401
1394
1401
1401
1093
1121
1121
1100
1114
1121
1072
1086
1051
1121
981
981
981
981
981
981
974
883
883
981
1506
1499
407
1499
1506
1506
1506
1499
1506
1506
1884
1632
1492
1485
1877
1884
1884
1548
1506
1478
638
547
673
547
183
197
197
253
183
610
1828
1835
1828
1835
1828
1835
1828
1835
1835
1835
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1324
1324
1324
1324
1331
1331
1331
1317
1310
1331
1681
1667
1681
1674
1674
1674
1681
1667
1667
1667
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
750
750
750
750
750
750
750
750
750
750
1709
1709
1555
449
1709
1702
1709
1695
1569
1702
1093
1079
1093
1086
1093
1128
1121
1086
1079
1093
1877
1884
1786
1877
1884
1884
1884
1884
1751
1884
645
638
645
652
631
477
645
638
638
617
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
869
1170
820


1100
1128
1135
1128
1884
1884
1884
1863
1884
1884
1884
1884
1884
1884
1856
1877
1884
1884
1884
1884
1884
1884
1884
1884
1212
1191
1212
1212
1205
1205
1198
1198
1198
1184
1870
1870
1863
1226
1856
1870
1863
1856
1856
1751
1884
1884
1884
1884
1835
1884
1856
1863
1814
1835
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1541
1555
1541
1541
1548
1555
1555
1548
1527
1541
1884
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1807
1877
1884
1884
1870
1884
1884
1884
1863
1877
1877
1884
1842
1842
456
456
456
456
456
456
449
456
449
456
890
883
897
820
890
876
862
883
876
890
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
764
764
820
813
820
820
820
820
813
820
1107
1100
1100
1044
1079
1114
1079
1107
1114
1114
1884
1884
1884
1884
1884
1884
1884
1884
1863
1884
1884
1884
1884
1884
1884
1863
1884
1863
1884
1884
1884
1884
1877
1541
1884
1884
1884
1877
1884
1884
1366
1366
1366
1065
1359
1373
1366
1352
1324
1359
477
491
491
512
491
484
491
484
484
491
372
743
750
743
407


1884
1884
1884
1884
1884
1884
1884
1884
1884
743
750
743
743
631
624
750
750
743
743
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1884
1884
1877
1877
1884
1863
1856
1884
1884
1870
1863
1884
1884
1877
1863
736
729
827
813
729
827
729
827
729
729
1884
1884
1884
1884
1884
1884
1870
1884
1884
1884
1884
1884
1884
1877
1800
1870
1793
1884
1884
1821
1884
1884
1884
1884
1884
1877
1884
1835
1877
1884
1653
1653
1639
1625
1632
1653
1653
1646
1639
1632
1079
1135
1163
757
1121
1156
1009
1135
1107
1534
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1653
1653
1646
1646
1646
1653
1653
1653
1646
1653
1884
1884
1884
1884
1856
1884
1884
1884
1870
1870
1884
1884
1884
1884
1884
1884
1884
1884
1674
1884
1884
1884
1884
1884
1884
1884
1877
1884
1877
1884
1884
1870
1877
1884
1884
1884
1884
1877
1835
1877
1401
1401
1408
1401
1401
1408
1408
1387
1373
1387
1877
1884
1884
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1870
1884
1520
1877
1

1877
1884
1884
1884
1884
1884
1884
1877
1884
1884
1884
1884
1884
1618
1548
1674
1576
1618
1667
1688
1667
1695
1681
1856
1884
1877
1863
1870
1877
1877
1884
1870
1877
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1863
1884
1884
1870
1877
1884
1877
1884
1751
1884
757
631
778
645
757
806
806
729
701
414
1114
1100
918
904
1100
911
1107
911
904
1093
1884
1884
1884
1884
1863
1863
1884
1884
1884
1884
666
694
687
694
687
694
687
687
687
687
904
897
904
904
897
911
904
897
897
897
1163
1170
1163
1156
1156
1142
1863
1177
932
1065
694
694
792
785
792
694
687
694
792
694
1863
1877
1870
1863
1863
1870
1863
1863
1849
1863
673
631
673
652
666
666
659
673
673
666
1569
1562
1513
1492
1555
1555
1569
1534
1569
1541
1884
1884
1884
1884
1877
1884
1877
1884
1884
1877
1884
1877
1877
1884
1877
1884
1884
1884
1884
1884
1877
1870
1863
1219
1856
1856
1849
1849
1856
1849
1807
1807
1814
1233
1779
1814
1807
1814
1814
1807
1884
1884
1884
1884
1884
1884
1884
1884
1884
1884
1849
1849
1842
1842
1842
1800
1842
1849
1

### Creation Datasets

In [None]:
train_dataset = TFTDataset(train)
valid_dataset = TFTDataset(valid)
test_dataset = TFTDataset(test)

### Temporal Fusion Transformer

In [None]:
class TemporalFusionTransformer(pl.LightningModule):
    def __init__(self, hparams):
        super(TemporalFusionTransformer, self).__init__()
        
        self.hparams = hparams
        
        self.name = self.__class__.__name__

        # Data parameters
        self.time_steps = int(hparams.total_time_steps)#int(params['total_time_steps'])
        self.input_size = int(hparams.input_size)#int(params['input_size'])
        self.output_size = int(hparams.output_size)#int(params['output_size'])
        self.category_counts = json.loads(str(hparams.category_counts))#json.loads(str(params['category_counts']))
        self.num_categorical_variables = len(self.category_counts)
        self.num_regular_variables = self.input_size - self.num_categorical_variables
        self.n_multiprocessing_workers = int(hparams.multiprocessing_workers) #int(params['multiprocessing_workers'])

        # Relevant indices for TFT
        self._input_obs_loc = json.loads(str(hparams.input_obs_loc))#json.loads(str(params['input_obs_loc']))
        self._static_input_loc = json.loads(str(hparams.static_input_loc))#json.loads(str(params['static_input_loc']))
        self._known_regular_input_idx = json.loads(str(hparams.known_regular_inputs))#json.loads(str(params['known_regular_inputs']))
        self._known_categorical_input_idx = json.loads(str(hparams.known_categorical_inputs))#json.loads(str(params['known_categorical_inputs']))
        
        self.num_non_static_historical_inputs = self.get_historical_num_inputs()
        self.num_non_static_future_inputs = self.get_future_num_inputs()
        
        self.column_definition = [
                                  ('id', DataTypes.REAL_VALUED, InputTypes.ID),
                                  ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME),
                                  ('power_usage', DataTypes.REAL_VALUED, InputTypes.TARGET),
                                  ('hour', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
                                ]

        # Network params
        self.quantiles = [0.1, 0.5, 0.9]
#         self.use_cudnn = use_cudnn  # Whether to use GPU optimised LSTM
        self.hidden_layer_size = int(hparams.hidden_layer_size)#int(params['hidden_layer_size'])
        self.dropout_rate = float(hparams.dropout_rate)#float(params['dropout_rate'])
        self.max_gradient_norm = float(hparams.max_gradient_norm)#float(params['max_gradient_norm'])
        self.learning_rate = float(hparams.learning_rate)#float(params['learning_rate'])
        self.minibatch_size = int(hparams.minibatch_size)#int(params['minibatch_size'])
        self.num_epochs = int(hparams.num_epochs)#int(params['num_epochs'])
        self.early_stopping_patience = int(hparams.early_stopping_patience)#int(params['early_stopping_patience'])

        self.num_encoder_steps = int(hparams.num_encoder_steps)#int(params['num_encoder_steps'])
        self.num_stacks = int(hparams.stack_size)#int(params['stack_size'])
        self.num_heads = int(hparams.num_heads)#int(params['num_heads'])

        # Serialisation options
#         self._temp_folder = os.path.join(params['model_folder'], 'tmp')
#         self.reset_temp_folder()

        # Extra components to store Tensorflow nodes for attention computations
        self._input_placeholder = None
        self._attention_components = None
        self._prediction_parts = None

        print('*** {} params ***'.format(self.name))
        for k in vars(hparams):
            print('# {} = {}'.format(k, vars(hparams)[k]))
            
        self.train_criterion = QuantileLossCalculator(self.quantiles, self.output_size)
        self.test_criterion = NormalizedQuantileLossCalculator(self.quantiles, self.output_size)

        # Build model
        ## Build embeddings
        self.build_embeddings()
        
        ## Build Static Contex Networks
        self.build_static_context_networks()
        
        ## Building Variable Selection Networks
        self.build_variable_selection_networks()
        
        ## Build Lstm
        self.build_lstm()
        
        ## Build GLU for after lstm encoder decoder and layernorm
        self.build_post_lstm_gate_add_norm()
        
        ## Build Static Enrichment Layer
        self.build_static_enrichment()
        
        ## Building decoder multihead attention
        self.build_temporal_self_attention()
        
        ## Building positionwise decoder
        self.build_position_wise_feed_forward()
        
        ## Build output feed forward
        self.build_output_feed_forward()
        
        ## Initializing remaining weights
        self.init_weights()
        
    def init_weights(self):
        for name, p in self.named_parameters():
            if ('lstm' in name and 'ih' in name) and 'bias' not in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            elif ('lstm' in name and 'hh' in name) and 'bias' not in name:
        
                 torch.nn.init.orthogonal_(p)
            
            elif 'lstm' in name and 'bias' in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.zeros_(p)
        
    def get_historical_num_inputs(self):
        
        obs_inputs = [i for i in self._input_obs_loc]
        
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]
        
        wired_embeddings = [i for i in range(self.num_categorical_variables)
                            if i not in self._known_categorical_input_idx 
                            and i not in self._input_obs_loc]

        unknown_inputs = [i for i in range(self.num_regular_variables)
                          if i not in self._known_regular_input_idx
                          and i not in self._input_obs_loc]

        return len(obs_inputs+known_regular_inputs+known_categorical_inputs+wired_embeddings+unknown_inputs)
    
    def get_future_num_inputs(self):
            
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]

        return len(known_regular_inputs + known_categorical_inputs)
    
    def build_embeddings(self):
        self.categorical_var_embeddings = nn.ModuleList([nn.Embedding(self.category_counts[i], 
                                                                      self.hidden_layer_size) 
                                                     for i in range(self.num_categorical_variables)])

        self.regular_var_embeddings = nn.ModuleList([nn.Linear(1, 
                                                              self.hidden_layer_size) 
                                                  for i in range(self.num_regular_variables)])

    def build_variable_selection_networks(self):
        
        self.static_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                   input_size = self.hidden_layer_size * len(self._static_input_loc),
                                                   output_size = len(self._static_input_loc),
                                                   dropout_rate = self.dropout_rate)
        
        self.temporal_historical_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                                input_size = self.hidden_layer_size *
                                                                        self.num_non_static_historical_inputs,
                                                                output_size = self.num_non_static_historical_inputs,
                                                                dropout_rate = self.dropout_rate,
                                                                additional_context=self.hidden_layer_size)
        
        self.temporal_future_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                            input_size = self.hidden_layer_size *
                                                                        self.num_non_static_future_inputs,
                                                            output_size = self.num_non_static_future_inputs,
                                                            dropout_rate = self.dropout_rate,
                                                            additional_context=self.hidden_layer_size)
        
    def build_static_context_networks(self):
        
        self.static_context_variable_selection_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                                          dropout_rate=self.dropout_rate)
        
        self.static_context_enrichment_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                              dropout_rate=self.dropout_rate)

        self.static_context_state_h_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
        self.static_context_state_c_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
    def build_lstm(self):
        self.historical_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                       hidden_size = self.hidden_layer_size,
                                       batch_first = True)
        self.future_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                   hidden_size = self.hidden_layer_size,
                                   batch_first = True)
        
    def build_post_lstm_gate_add_norm(self):
        self.post_seq_encoder_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                                 self.hidden_layer_size,
                                                                 self.dropout_rate,
                                                                 activation = None)
        
    def build_static_enrichment(self):
        self.static_enrichment = GatedResidualNetwork(self.hidden_layer_size,
                                                      dropout_rate = self.dropout_rate,
                                                      additional_context=self.hidden_layer_size)
        
    def build_temporal_self_attention(self):
        self.self_attn_layer = InterpretableMultiHeadAttention(n_head = self.num_heads, 
                                                               d_model = self.hidden_layer_size,
                                                               dropout = self.dropout_rate)
        
        self.post_attn_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                           self.hidden_layer_size,
                                                           self.dropout_rate,
                                                           activation = None)
        
    def build_position_wise_feed_forward(self):
        self.GRN_positionwise = GatedResidualNetwork(self.hidden_layer_size,
                                                     dropout_rate = self.dropout_rate)
        
        self.post_tfd_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                         self.hidden_layer_size,
                                                         self.dropout_rate,
                                                         activation = None)
        
    def build_output_feed_forward(self):
        self.output_feed_forward = torch.nn.Linear(self.hidden_layer_size, 
                                                   self.output_size * len(self.quantiles))
         
    def get_decoder_mask(self, self_attn_inputs):
        """Returns causal mask to apply for self-attention layer.
        Args:
        self_attn_inputs: Inputs to self attention layer to determine mask shape
        """
        len_s = self_attn_inputs.shape[1]
        bs = self_attn_inputs.shape[0]
        mask = torch.cumsum(torch.eye(len_s), 0)
        mask = mask.repeat(bs,1,1).to(torch.float32)

        return mask.to(DEVICE)
    
    def get_tft_embeddings(self, regular_inputs, categorical_inputs):
        # Static input
        if self._static_input_loc:
            static_regular_inputs = [self.regular_var_embeddings[i](regular_inputs[:, 0, i:i + 1]) 
                                    for i in range(self.num_regular_variables)
                                    if i in self._static_input_loc]
            #print('static_regular_inputs')
            #print([print(emb.shape) for emb in static_regular_inputs])
            
            static_categorical_inputs = [self.categorical_var_embeddings[i](categorical_inputs[Ellipsis, i])[:,0,:] 
                                         for i in range(self.num_categorical_variables)
                                         if i + self.num_regular_variables in self._static_input_loc]
            #print('static_categorical_inputs')
            #print([print(emb.shape) for emb in static_categorical_inputs])
            static_inputs = torch.stack(static_regular_inputs + static_categorical_inputs, axis = 1)
        else:
            static_inputs = None
            
        # Target input
        obs_inputs = torch.stack([self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                                     for i in self._input_obs_loc], axis=-1)
        
        # Observed (a prioir unknown) inputs
        wired_embeddings = []
        for i in range(self.num_categorical_variables):
            if i not in self._known_categorical_input_idx \
            and i not in self._input_obs_loc:
                e = self.categorical_var_embeddings[i](categorical_inputs[:, :, i])
                wired_embeddings.append(e)

        unknown_inputs = []
        for i in range(self.num_regular_variables):
            if i not in self._known_regular_input_idx \
            and i not in self._input_obs_loc:
                e = self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                unknown_inputs.append(e)
                
        if unknown_inputs + wired_embeddings:
            unknown_inputs = torch.stack(unknown_inputs + wired_embeddings, axis=-1)
        else:
            unknown_inputs = None
            
        # A priori known inputs
        known_regular_inputs = [self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                                for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
        #print('known_regular_inputs')
        #print([print(emb.shape) for emb in known_regular_inputs])
        
        known_categorical_inputs = [self.categorical_var_embeddings[i](categorical_inputs[Ellipsis, i])
                                    for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]
       #print('known_categorical_inputs')
       #print([print(emb.shape) for emb in known_categorical_inputs])

        known_combined_layer = torch.stack(known_regular_inputs + known_categorical_inputs, axis=-1)
        
        return unknown_inputs, known_combined_layer, obs_inputs, static_inputs
        
    def forward(self, all_inputs):

        regular_inputs = all_inputs[:, :, :self.num_regular_variables].to(torch.float)
        #print('regular_inputs')
        #print(regular_inputs.shape)
        categorical_inputs = all_inputs[:, :, self.num_regular_variables:].to(torch.long)
        #print('categorical_inputs')
        #print(categorical_inputs.shape)
        
        unknown_inputs, known_combined_layer, obs_inputs, static_inputs \
            = self.get_tft_embeddings(regular_inputs, categorical_inputs)
        
        # Isolate known and observed historical inputs.
        if unknown_inputs is not None:
              historical_inputs = torch.cat([
                  unknown_inputs[:, :self.num_encoder_steps, :],
                  known_combined_layer[:, :self.num_encoder_steps, :],
                  obs_inputs[:, :self.num_encoder_steps, :]
              ], axis=-1)
        else:
              historical_inputs = torch.cat([
                  known_combined_layer[:, :self.num_encoder_steps, :],
                  obs_inputs[:, :self.num_encoder_steps, :]
              ], axis=-1)
                
        #print('historical_inputs')
        #print(historical_inputs.shape)
        
        # Isolate only known future inputs.
        future_inputs = known_combined_layer[:, self.num_encoder_steps:, :]
        #print('future_inputs')
        #print(future_inputs.shape)
              
        #print('static_inputs')
        #print(static_inputs.shape)
        
        static_encoder, sparse_weights = self.static_vsn(static_inputs)
        
        #print('static_encoder')
        #print(static_encoder.shape)
        
        #print('sparse_weights')
        #print(sparse_weights.shape)
        
        static_context_variable_selection = self.static_context_variable_selection_grn(static_encoder)
        #print('static_context_variable_selection')
        #print(static_context_variable_selection.shape)
        static_context_enrichment = self.static_context_enrichment_grn(static_encoder)
        #print('static_context_enrichment')
        #print(static_context_enrichment.shape)
        static_context_state_h = self.static_context_state_h_grn(static_encoder)
        #print('static_context_state_h')
        #print(static_context_state_h.shape)
        static_context_state_c = self.static_context_state_c_grn(static_encoder)
        #print('static_context_state_c')
        #print(static_context_state_c.shape)
        
        historical_features, historical_flags \
        = self.temporal_historical_vsn((historical_inputs,
                                        static_context_variable_selection))
        #print('historical_features')
        #print(historical_features.shape)
        #print('historical_flags')
        #print(historical_flags.shape)
        
        future_features, future_flags \
        = self.temporal_future_vsn((future_inputs,
                                    static_context_variable_selection))
        #print('future_features')
        #print(future_features.shape)
        #print('future_flags')
        #print(future_flags.shape)
        
        history_lstm, (state_h, state_c) \
        = self.historical_lstm(historical_features,
                               (static_context_state_h.unsqueeze(0),
                                static_context_state_c.unsqueeze(0)))
        #print('history_lstm')
        #print(history_lstm.shape)
        #print('state_h')
        #print(state_h.shape)
        #print('state_c')
        #print(state_c.shape)
        
        future_lstm, _ = self.future_lstm(future_features,
                                          (state_h,
                                           state_c))
        #print('future_lstm')
        #print(future_lstm.shape)
        
        # Apply gated skip connection
        input_embeddings = torch.cat((historical_features, future_features), axis=1)
        #print('input_embeddings')
        #print(input_embeddings.shape) 
        
        lstm_layer = torch.cat((history_lstm, future_lstm), axis=1)
        #print('lstm_layer')
        #print(lstm_layer.shape) 
        
        temporal_feature_layer = self.post_seq_encoder_gate_add_norm(lstm_layer, input_embeddings)
        #print('temporal_feature_layer')
        #print(temporal_feature_layer.shape)  
        
        # Static enrichment layers
        expanded_static_context = static_context_enrichment.unsqueeze(1)
        
        enriched = self.static_enrichment((temporal_feature_layer, expanded_static_context))
        #print('enriched')
        #print(enriched.shape)    
        
        # Decoder self attention
        #self.mask = self.get_decoder_mask(enriched)
        #print('enriched')
        #print(enriched.shape)
        x, self_att = self.self_attn_layer(enriched, 
                                           enriched, 
                                           enriched,
                                           mask = self.get_decoder_mask(enriched))
        #print('x')
        #print(x.shape)
        #print('self_att')
        #print(self_att.shape)
        
        x = self.post_attn_gate_add_norm(x, enriched)
        #print('x')
        #print(x.shape)
        
        # Nonlinear processing on outputs
        decoder = self.GRN_positionwise(x)
        #print('decoder')
        #print(decoder.shape)
        
        # Final skip connection
        transformer_layer = self.post_tfd_gate_add_norm(decoder, temporal_feature_layer)
        #print('transformer_layer')
        #print(transformer_layer.shape)
        
        outputs = self.output_feed_forward(transformer_layer[Ellipsis, self.num_encoder_steps:, :])
        #print('outputs')
        #print(outputs.shape)
        
        #ipdb.set_trace()
        
        return outputs
    
    def loss(self, y_hat, y):
        return self.train_criterion.apply(y_hat, y)
    
    def test_loss(self, y_hat, y):
        return self.test_criterion.apply(y_hat, y, self.quantiles[1])
    
    def training_step(self, batch, batch_nb):
        x, y, _ = batch
        
        x = x.to(torch.float)
        y = y.to(torch.float)
#         print('y')
#         print(y.shape)
        y_hat = self.forward(x)
#         print('y_hat')
#         print(y_hat.shape)
        loss = self.loss(y_hat, torch.cat([y, y, y], dim = -1))
        #print(loss.shape)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}
    
    def validation_step(self, batch, batch_nb):
        x, y, _ = batch
        x = x.to(torch.float)
        y = y.to(torch.float)
        y_hat = self.forward(x)
        #print(y_hat.shape)
        #print(torch.cat([y, y, y], dim = -1).shape)
        loss = self.loss(y_hat, torch.cat([y, y, y], dim = -1))
        #print(loss)
        return {'val_loss': loss}
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
    
    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y, _ = batch
        x = x.to(torch.float)
        y = y.to(torch.float)
        y_hat = self.forward(x)
        return {'test_loss': self.test_loss(y_hat[Ellipsis, 1], y[Ellipsis, 0])}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
    
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return [torch.optim.Adam(self.parameters(), lr=self.learning_rate)]
    
    def plot_grad_flow(self, named_parameters):
        ave_grads = []
        layers = []
        for name, p in named_parameters:
            if p.grad is not None:
                if (p.requires_grad) and ("bias" not in name):
                    layers.append(name)
                    ave_grads.append(p.grad.abs().mean())
                    self.logger.experiment.add_histogram(tag=name, values=p.grad,
                                                         global_step=self.trainer.global_step)
            else:
                 print('{} - {}'.format(name, p.requires_grad))
            
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads), linewidth=1, color="k" )
        plt.xticks(list(range(0,len(ave_grads), 1)), layers, rotation='vertical')
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.rcParams["figure.figsize"] = (20, 5)
    
    def on_after_backward(self):
        # example to inspect gradient information in tensorboard
        if self.trainer.global_step % 25 == 0:  
            self.plot_grad_flow(self.named_parameters())
    
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(train_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(valid_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)
    
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(test_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)