In [1]:
import os
# Configure CUDA device enumeration before `torch` is imported further down:
# order devices by PCI bus ID and expose only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import sys
import math
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

import datetime
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score

In [5]:
# Extend the module search path so the project-local modules imported in the
# next cell can be resolved.
# NOTE(review): '.data/' names a directory literally called `.data` relative to
# the working directory. The recorded output of the following import cell is
# "ModuleNotFoundError: No module named 'ce_generator'", so this path may be a
# typo (e.g. for './data/') -- confirm where ce_generator.py actually lives.
sys.path.append( '.data/' )

In [7]:
# Display the current module search path as the cell output (debugging aid for
# the path appended in the previous cell).
sys.path

['/home/nesl/anaconda3/envs/iros24/lib/python312.zip',
 '/home/nesl/anaconda3/envs/iros24/lib/python3.12',
 '/home/nesl/anaconda3/envs/iros24/lib/python3.12/lib-dynload',
 '',
 '/home/nesl/anaconda3/envs/iros24/lib/python3.12/site-packages',
 '..',
 '.data/']

In [6]:
from ce_generator import CE5min
from utils import set_seeds, CEDataset, create_src_causal_mask, MultiTaskCEDataset
from loss import focal_loss
from train import train, test, test_iterative
from models import TSTransformer, MultiTaskTSTransformer

ModuleNotFoundError: No module named 'ce_generator'

In [3]:
# Pick the compute device: the first CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Show the selected device as the cell output.
device

device(type='cuda', index=0)

In [4]:
# Load the previously saved training split from disk; the recorded output shows
# 10000 sequences of 60 steps with 128 features each, and one label per step.
# NOTE(review): `ce_train_data` / `ce_train_labels` are assigned again by the
# data-generation cell further down, so in a top-to-bottom run these loaded
# arrays are replaced before the model is trained -- confirm which source is
# intended.
ce_train_data = np.load('./CE_dataset/ce5min_train_data_10000.npy')
ce_train_labels = np.load('./CE_dataset/ce5min_train_labels_10000.npy')

# Cell output: the shapes of both arrays.
ce_train_data.shape, ce_train_labels.shape

((10000, 60, 128), (10000, 60))

# Generate new training data
which contains the state embeddings

In [5]:
n_data = 10000
ce5 = CE5min(n_data, 'train', simple_label=False)

# Preview one generated event (index 0) before building the whole dataset.
(action_data, event_labels, in_states, out_states,
 actions, action_labels, windows, t) = ce5.generate_event(0)

# Contents of the preview event, one sequence per output line.
print(event_labels, in_states, out_states, actions, action_labels, sep='\n')

# Length of each per-step sequence of the preview event.
print(
    *map(len, (action_data, event_labels, in_states, out_states, actions, action_labels)),
    sep='\n',
)

print(windows, t, sep='\n')

# State name -> integer index, numbered in listing order (0..7).
state_mapping = {
    name: index
    for index, name in enumerate(
        ['s1_0', 's1_1', 's2_0', 's2_1', 's2_2', 's3_0', 's3_1', 's3_2']
    )
}

# Build the full training split. Note that `in_states` / `out_states` are
# rebound here: they held the preview event's values above and now hold the
# dataset-wide results.
ce_data, ce_labels, in_states, out_states = ce5.generate_CE_dataset(state_mapping)

# Keep the training split under stable, split-specific names.
ce_train_data, ce_train_labels = ce_data, ce_labels
ce_train_in_states, ce_train_out_states = in_states, out_states

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_1'], ['s1_1'], ['s1_1'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0']]
[['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_0'], ['s1_1'], ['s1_1'], ['s1_1'], ['s1_0'], ['s1_0'

In [6]:
# Sanity check: print the training-split shapes and one random training sample.
#
# Only the `ce_train_*` arrays are used here. The generic names (`ce_data`,
# `in_states`, `n_data`, ...) are rebound when the test split is generated, so
# relying on them would silently mix splits if cells are run out of order.
print(ce_train_data.shape, ce_train_labels.shape,
      ce_train_in_states.shape, ce_train_out_states.shape)

# Every array must describe the same number of sequences.
assert (len(ce_train_data) == len(ce_train_labels)
        == len(ce_train_in_states) == len(ce_train_out_states)), \
    "training arrays disagree on the number of sequences"

# `sample_idx` rather than `id`, which would shadow the Python builtin.
sample_idx = np.random.randint(len(ce_train_data))
print(ce_train_labels[sample_idx])
print(ce_train_in_states[sample_idx])
print(ce_train_data[sample_idx])

(10000, 60, 128) (10000, 60) (10000, 60) (10000, 60) (10000, 60, 128)
[0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 2 2 2 2 2 2 2 2 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[[0.         0.7696283  0.         ... 0.12150168 0.10205833 0.2599745 ]
 [0.         1.0695908  0.         ... 0.3475209  0.18219297 0.4804398 ]
 [0.         0.         0.         ... 0.05184354 0.22168872 0.11783302]
 ...
 [0.         0.         0.         ... 0.42290998 0.         0.23560199]
 [0.         0.08531922 0.         ... 0.30346203 0.0619647  0.09413667]
 [0.         1.3432202  0.         ... 0.30655456 0.5112179  0.37887827]]


In [7]:
n_data = 1000
ce5 = CE5min(n_data, 'test', simple_label=False)

# Preview one generated event (index 1) before building the whole dataset.
(action_data, event_labels, in_states, out_states,
 actions, action_labels, windows, t) = ce5.generate_event(1)

# Contents of the preview event, one sequence per output line.
print(event_labels, in_states, out_states, actions, action_labels, sep='\n')

# Length of each per-step sequence of the preview event.
print(
    *map(len, (action_data, event_labels, in_states, out_states, actions, action_labels)),
    sep='\n',
)

print(windows, t, sep='\n')

# State name -> integer index, numbered in listing order (0..7).
state_mapping = {
    name: index
    for index, name in enumerate(
        ['s1_0', 's1_1', 's2_0', 's2_1', 's2_2', 's3_0', 's3_1', 's3_2']
    )
}

# Build the full test split. Note that `in_states` / `out_states` are rebound
# here: they held the preview event's values above and now hold the
# dataset-wide results.
ce_data, ce_labels, in_states, out_states = ce5.generate_CE_dataset(state_mapping)

# Keep the test split under stable, split-specific names.
ce_test_data, ce_test_labels = ce_data, ce_labels
ce_test_in_states, ce_test_out_states = in_states, out_states

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[['s2_0'], ['s2_0'], ['s2_0'], ['s2_0'], ['s2_1'], ['s2_1'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2']]
[['s2_0'], ['s2_0'], ['s2_0'], ['s2_1'], ['s2_1'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'], ['s2_2'

In [8]:
# Sanity check: print the test-split shapes and one random test sample.
#
# Only the `ce_test_*` arrays are used here. The generic names (`ce_data`,
# `in_states`, `n_data`, ...) are rebound by whichever generation cell ran
# last, so relying on them would silently mix splits if cells are run out of
# order.
print(ce_test_data.shape, ce_test_labels.shape,
      ce_test_in_states.shape, ce_test_out_states.shape)

# Every array must describe the same number of sequences.
assert (len(ce_test_data) == len(ce_test_labels)
        == len(ce_test_in_states) == len(ce_test_out_states)), \
    "test arrays disagree on the number of sequences"

# `sample_idx` rather than `id`, which would shadow the Python builtin.
sample_idx = np.random.randint(len(ce_test_data))
print(ce_test_labels[sample_idx])
print(ce_test_in_states[sample_idx])
print(ce_test_data[sample_idx])

(1000, 60, 128) (1000, 60) (1000, 60) (1000, 60) (1000, 60, 128)
[0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 2 2 2 2 2 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]
[[0.         0.1919747  0.01926938 ... 0.20723204 0.253298   0.3104067 ]
 [0.         0.         0.         ... 0.23939635 0.06111969 0.35643005]
 [0.         0.1956715  0.02046966 ... 0.20782621 0.2536357  0.31088582]
 ...
 [0.         0.965396   0.         ... 0.39777046 0.4011     0.05979404]
 [0.         0.         0.         ... 0.25798956 0.         0.16314007]
 [0.         0.74343675 0.         ... 0.42359447 0.40279055 0.1620849 ]]


In [9]:
# Print the largest state index in `in_states` (whichever split was generated
# most recently). The recorded value, 7, equals the highest index in
# `state_mapping`.
print(in_states.max())

7


In [10]:
batch_size = 128

# Wrap each split (features, event labels, input states, output states) in the
# multi-task dataset class.
# (Single-task alternative for the test split, currently unused:
#  CEDataset(ce_test_data, ce_test_labels))
ce_train_dataset = MultiTaskCEDataset(
    ce_train_data,
    ce_train_labels,
    ce_train_in_states,
    ce_train_out_states,
)
ce_test_dataset = MultiTaskCEDataset(
    ce_test_data,
    ce_test_labels,
    ce_test_in_states,
    ce_test_out_states,
)

# Training batches are shuffled and loaded with 4 worker processes; test
# batches keep their original order and use the default loader settings.
ce_train_loader = DataLoader(
    ce_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
ce_test_loader = DataLoader(
    ce_test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# ce_state_test_dataset = CEDataset(ce_test_data, ce_test_labels)
# ce_state_test_loader = DataLoader(ce_state_test_dataset, batch_size=batch_size, shuffle=False)

# Define Transformer model

In [11]:
# Source mask sized from the second axis of the training data.
src_causal_mask = create_src_causal_mask(ce_train_data.shape[1])

# Input width is the last axis of the training data.
in_dim_embed = ce_train_data.shape[-1]

# Output widths: 4 event-label classes and 8 states (one per state_mapping entry).
out_dim_ce, out_dim_state = 4, 8
# state_dim = 32

In [12]:
# Focal loss with one alpha entry per event-label class: 0.005 for class 0 and
# 0.45 for each of the other three classes.
criterion = focal_loss(
    gamma=2,
    alpha=torch.tensor([0.005, 0.45, 0.45, 0.45]),
)

In [13]:
# Optimisation settings for this run.
learning_rate = 1e-3
n_epochs = 3000

# Multi-task transformer sized from the dimensions defined earlier.
# (The `hidden_dim_state=state_dim` option is currently disabled.)
model = MultiTaskTSTransformer(
    in_dim_embed=in_dim_embed,
    out_dim_ce=out_dim_ce,
    out_dim_state=out_dim_state,
    num_head=5,
    num_layers=6,
    pos_encoding=True,
)

# TODO: the test loader doubles as the validation loader here; dedicated
# validation data still needs to be generated.
train(
    model=model,
    train_loader=ce_train_loader,
    val_loader=ce_test_loader,
    criterion=criterion,
    n_epochs=n_epochs,
    lr=learning_rate,
    src_mask=src_causal_mask,
    multi_task=True,
    device=device,
)

  0%|          | 1/3000 [00:03<3:02:49,  3.66s/it]

pred labels tensor([[3, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [3, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
pred states tensor([[5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        ...,
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5]])
Training Loss: 1.780275267890737, Training Accuracy - Label: 0.3885482800535009, Training Accuracy - State: 0.6400992519257567, 
Validation Loss: 1.6148353219032288, Validation Accuracy - Label: 0.3520520420279354, Validation Accuracy - State: 0.7803297787904739
Early-stop counter: 0


  0%|          | 2/3000 [00:06<2:29:17,  2.99s/it]

pred labels tensor([[3, 3, 3,  ..., 0, 0, 0],
        [3, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 3, 0, 0],
        ...,
        [3, 0, 3,  ..., 0, 0, 3],
        [3, 3, 0,  ..., 0, 0, 0],
        [3, 3, 0,  ..., 0, 0, 0]])
pred states tensor([[5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        ...,
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5],
        [5, 5, 5,  ..., 5, 5, 5]])
Training Loss: 1.5463440025909037, Training Accuracy - Label: 0.36693534405925604, Training Accuracy - State: 0.7948065430303163, 
Validation Loss: 1.4996301233768463, Validation Accuracy - Label: 0.32492989010279416, Validation Accuracy - State: 0.7978978827595711
Early-stop counter: 0


  0%|          | 2/3000 [00:07<3:15:12,  3.91s/it]


KeyboardInterrupt: 

In [None]:
# Run `test` on the test loader with the multi-task flag enabled, reusing the
# training criterion and the causal source mask.
test(model=model, data_loader=ce_test_loader, criterion=criterion,
     src_mask=src_causal_mask, multi_task=True)

 12%|█▎        | 1/8 [00:00<00:01,  4.59it/s]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True state: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 38%|███▊      | 3/8 [00:00<00:00,  5.29it/s]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True state: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,

 50%|█████     | 4/8 [00:00<00:00,  5.21it/s]

Pred label: tensor([0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
True state: tensor([2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
Pred label: tensor([0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 75%|███████▌  | 6/8 [00:01<00:00,  5.45it/s]

tensor([0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
True state: tensor([2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

100%|██████████| 8/8 [00:01<00:00,  5.45it/s]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
True state: tensor([6, 6, 6, 6, 6, 6, 6, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,




In [None]:
# Run `test_iterative` on the same test loader, criterion and causal mask.
# The recorded progress bars show roughly 10 s per batch for this call.
test_iterative(model=model, data_loader=ce_test_loader,
               criterion=criterion, src_mask=src_causal_mask)

 12%|█▎        | 1/8 [00:10<01:13, 10.45s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True state: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 25%|██▌       | 2/8 [00:21<01:03, 10.56s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True state: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,

 38%|███▊      | 3/8 [00:30<00:49,  9.93s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True state: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,

 50%|█████     | 4/8 [00:40<00:40, 10.01s/it]

Pred label: tensor([0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
True state: tensor([2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
Pred label: tensor([0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 62%|██████▎   | 5/8 [00:50<00:30, 10.09s/it]

Pred label: tensor([0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
True state: tensor([2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 75%|███████▌  | 6/8 [01:00<00:20, 10.02s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
True state: tensor([2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
Pred label: tensor([0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

 88%|████████▊ | 7/8 [01:10<00:09,  9.94s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
True state: tensor([6, 6, 6, 6, 6, 6, 6, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

100%|██████████| 8/8 [01:18<00:00,  9.82s/it]

Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
True label : tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Pred state tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
True state: tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
Pred label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0,




In [None]:
# NOTE(review): disabled single-task training run, kept for reference only.
# `TimeSeriesTransformer`, `input_dim` and `output_dim` are not imported or
# defined anywhere in the visible notebook (the recorded output of this cell is
# a NameError for 'TimeSeriesTransformer'), so it cannot be re-enabled as-is.

# criterion = nn.CrossEntropyLoss()

# n_epochs = 2000
# learning_rate = 1e-3

# model = TimeSeriesTransformer(input_dim=input_dim, output_dim=output_dim, num_head=5, num_layers=6, pos_encoding=True)

# train(
#     model=model,
#     train_loader=ce_train_loader,
#     val_loader=ce_test_loader, # need to generate validation data
#     n_epochs=n_epochs,
#     lr=learning_rate,
#     criterion=criterion,
#     src_mask=src_causal_mask,
#     device=device
#     )

NameError: name 'TimeSeriesTransformer' is not defined

In [None]:
# Run `test` on the test loader without the `multi_task` argument.
# NOTE(review): the single-task training cell above is commented out, so in a
# top-to-bottom run `model` and `criterion` are still the MultiTaskTSTransformer
# and focal loss defined earlier, while this call (unlike the earlier `test`
# call) does not pass multi_task=True -- confirm this is intended.
test(
    model=model,
    data_loader=ce_test_loader,
    criterion=criterion,
    src_mask=src_causal_mask
    )

In [None]:
# Scratch experiment: a 10 x 3 boolean mask on the CPU, initialised to True.
src_padding_mask = torch.ones(10, 3, dtype=torch.bool, device=torch.device('cpu'))

# The column slice below stops at 0, so it selects no columns and the
# assignment leaves the mask unchanged (still all True).
src_padding_mask[:, 0:0] = False

# Show the mask as the cell output.
src_padding_mask

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])