In [25]:
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from config import *
from typing import Tuple

from data.util import split_weekdays_and_weekends

from prediction_models.dbn import *
from prediction_models.kelm import *

%reload_ext autoreload
%autoreload 2

In [26]:
DBN_HIDDEN_LAYER_SIZES = CONFIG['DBN_HIDDEN_LAYER_SIZES']
GIBBS_SAMPLING_STEPS = CONFIG['GIBBS_SAMPLING_STEPS']
READ_START_DATE = datetime.strptime(CONFIG['READ_START_DATE'], DATE_FORMAT)
READ_END_DATE = datetime.strptime(CONFIG['READ_END_DATE'], DATE_FORMAT)
TRAIN_START_DATE = datetime.strptime(CONFIG['TRAIN_START_DATE'], DATE_FORMAT)
TRAIN_END_DATE = datetime.strptime(CONFIG['TRAIN_END_DATE'], DATE_FORMAT)
TIME_WINDOW_LENGTH = CONFIG['TIME_WINDOW_LENGTH']

### Prepare training dataset

Load the matrix $C$ constructed from $Q$

In [27]:
mat_c = torch.load(out_path('mat_c.pt'))
TIME_DIM, SPACE_DIM = mat_c.shape
mat_c.shape

torch.Size([2880, 194])

Split $C$ into workdays and weekends data

In [28]:
mat_c_wd, mat_c_we = split_weekdays_and_weekends(mat_c, TRAIN_START_DATE, TRAIN_END_DATE)
assert mat_c_wd.shape[1] == mat_c_we.shape[1] == mat_c.shape[1]
assert mat_c_wd.shape[0] + mat_c_we.shape[0] == mat_c.shape[0]
mat_c_wd.shape, mat_c_we.shape

(torch.Size([2112, 194]), torch.Size([768, 194]))

Create a dataset for pre-training RBMs

In [29]:

class BinaryVectorDataset(Dataset):
    def __init__(self, n_bits: int):
        self.n_bits = n_bits
        self.format_str = f'{{0:0{self.n_bits}b}}'
    
    def __len__(self) -> int:
        return 200
    
    def __getitem__(self, index: int) \
            -> Tuple[torch.TensorType, torch.TensorType]:
        bin_str = self.format_str.format(index)
        return torch.tensor([float(c) for c in bin_str]), torch.tensor([index,])


class RandomBinaryVectorDataset(Dataset):
    def __init__(self, n_samples, bits_per_sample):
        self.n_samples = n_samples
        self.bits_per_sample = bits_per_sample

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, index: int) -> torch.TensorType:
        return (torch.rand(self.bits_per_sample) > 0.5).to(dtype=torch.float32)


dbn_pre_train_loader = DataLoader(RandomBinaryVectorDataset(10000, TIME_WINDOW_LENGTH), batch_size=256)

### Pre-train RBMs inside DBN

Train to reconstruct all possible combinations of 0's and 1's

In [30]:
dbn = DBN(TIME_WINDOW_LENGTH, DBN_HIDDEN_LAYER_SIZES, k=GIBBS_SAMPLING_STEPS)
pre_train_dbn(dbn, dbn_pre_train_loader, print_every=1)

Epoch 0:
	RBM 0: Loss=73.64987182617188
	RBM 1: Loss=8.299970626831055
	RBM 2: Loss=10.80990219116211
	RBM 0: Loss=68.73383331298828
	RBM 1: Loss=8.232904434204102
	RBM 2: Loss=10.564075469970703
	RBM 0: Loss=65.10677337646484
	RBM 1: Loss=8.373233795166016
	RBM 2: Loss=10.32261848449707
	RBM 0: Loss=62.56462478637695
	RBM 1: Loss=8.371500015258789
	RBM 2: Loss=10.13857650756836
	RBM 0: Loss=57.93785095214844
	RBM 1: Loss=8.284734725952148
	RBM 2: Loss=9.936759948730469
	RBM 0: Loss=54.35929489135742
	RBM 1: Loss=7.859210968017578
	RBM 2: Loss=9.712125778198242
	RBM 0: Loss=53.37949752807617
	RBM 1: Loss=7.796375274658203
	RBM 2: Loss=9.559505462646484
	RBM 0: Loss=50.37427520751953
	RBM 1: Loss=7.687204360961914
	RBM 2: Loss=9.353076934814453
	RBM 0: Loss=46.90052795410156
	RBM 1: Loss=7.513519287109375
	RBM 2: Loss=9.220352172851562
	RBM 0: Loss=44.63683319091797
	RBM 1: Loss=7.22064208984375
	RBM 2: Loss=9.053524017333984
	RBM 0: Loss=42.92268753051758
	RBM 1: Loss=6.84453010559082


DBN(
  (rbms): ModuleList(
    (0-2): 3 x RBM()
  )
)

#### Train prediction of the first section

Define the loss function

In [31]:
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(dbn.parameters())
print([p.shape for p in dbn.parameters()])

[torch.Size([1, 60]), torch.Size([1, 14]), torch.Size([14, 60]), torch.Size([1, 14]), torch.Size([1, 13]), torch.Size([13, 14]), torch.Size([1, 13]), torch.Size([1, 12]), torch.Size([12, 13])]


In [35]:
kelm = KELM()
kelm.random_fit(SPACE_DIM, DBN_HIDDEN_LAYER_SIZES[-1], 1)
dbn_features = []
dbn_labels = []
torch.autograd.set_detect_anomaly(True)
dbn.train()
for start_time in range(TIME_DIM - TIME_WINDOW_LENGTH - 1):
    end_time = start_time + TIME_WINDOW_LENGTH
    train_window = mat_c[start_time:end_time].T
    labels = mat_c[end_time]
    features = dbn(train_window)
    
    kelm_X = torch.concat(dbn_features + [features,], dim=0)
    kelm_y = torch.concat(dbn_labels + [labels,])[:, None]

    optim.zero_grad()
    pred = kelm(kelm_X)
    loss = loss_fn(pred, kelm_y)
    loss.backward()
    optim.step()

    print(loss.item())

    dbn_features.append(features.detach())
    dbn_labels.append(labels.detach())

    kelm_X = torch.concat(dbn_features, dim=0)
    kelm_y = torch.concat(dbn_labels)[:, None]
    kelm.fit(kelm_X, kelm_y)

    # kelm.fit(features, label_window)
    # predictions = kelm(features) # [60, 60]
    # print(predictions.shape)
    # print(loss_fn(predictions, label_window))
    # USE THIS TO TRAIN DBN
dbn.train(False)

300.87164306640625
225.3878631591797
295.4187316894531
347.23193359375
195.1676025390625
143.13645935058594
184.87777709960938
122.7111587524414
154.31849670410156
194.88087463378906
142.9027099609375
145.71397399902344
122.37610626220703
91.82546997070312
145.36585998535156
81.17040252685547
113.90858459472656
70.92010498046875
89.09101867675781
94.96537780761719
59.87431716918945
69.29869079589844
126.6322021484375
62.317047119140625
110.71247100830078
46.54066848754883
211.67677307128906
74.9046859741211
184.08676147460938
91.52769470214844
71.27245330810547
119.71369934082031
224.0331268310547
966.8876342773438
127.86042022705078
108.13761138916016
68.83271789550781
56.2658576965332
54.23207092285156
47.234169006347656
39.88899612426758
43.80222702026367
34.466651916503906
78.5123291015625
62.14506149291992
43.747127532958984
61.448246002197266
29.750465393066406
59.040245056152344
123.2210464477539
137.98876953125
31.317447662353516
43.05730438232422
56.302825927734375
51.45318984

DBN(
  (rbms): ModuleList(
    (0-2): 3 x RBM()
  )
)