In [21]:
import sys
import os
import shutil
import math
import numpy as np
import pandas as pd
import scipy.sparse as ss
from sklearn.preprocessing import StandardScaler
from datetime import datetime
import time
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchsummary import summary
import Metrics
from STGCN import *
from Param import *

def getXSYS_single(data, mode):
    TRAIN_NUM = int(data.shape[0] * TRAINRATIO)
    print('data.shape:',data.shape)
    XS, YS = [], []
    if mode == 'TRAIN':    
        for i in range(TRAIN_NUM - TIMESTEP_OUT - TIMESTEP_IN + 1):
            x = data[i:i+TIMESTEP_IN, :]
            y = data[i+TIMESTEP_IN:i+TIMESTEP_IN+1, :]
            XS.append(x), YS.append(y)
    elif mode == 'TEST':
        for i in range(TRAIN_NUM - TIMESTEP_IN,  data.shape[0] - TIMESTEP_OUT - TIMESTEP_IN + 1):
            x = data[i:i+TIMESTEP_IN, :]
            y = data[i+TIMESTEP_IN:i+TIMESTEP_IN+1, :]
            XS.append(x), YS.append(y)
    XS, YS = np.array(XS), np.array(YS)
    XS, YS = XS[:, np.newaxis, :, :], YS[:, np.newaxis, :]
    return XS, YS

def getXSYS(data, mode):
    TRAIN_NUM = int(data.shape[0] * TRAINRATIO)
    XS, YS = [], []
    if mode == 'TRAIN':    
        for i in range(TRAIN_NUM - TIMESTEP_OUT - TIMESTEP_IN + 1):
            x = data[i:i+TIMESTEP_IN, :]
            y = data[i+TIMESTEP_IN:i+TIMESTEP_IN+TIMESTEP_OUT, :]
            XS.append(x), YS.append(y)
    elif mode == 'TEST':
        for i in range(TRAIN_NUM - TIMESTEP_IN,  data.shape[0] - TIMESTEP_OUT - TIMESTEP_IN + 1):
            x = data[i:i+TIMESTEP_IN, :]
            y = data[i+TIMESTEP_IN:i+TIMESTEP_IN+TIMESTEP_OUT, :]
            XS.append(x), YS.append(y)
    XS, YS = np.array(XS), np.array(YS)
    XS, YS = XS[:, np.newaxis, :, :], YS[:, np.newaxis, :]
    return XS, YS

def getModel(name):
    if name == 'STGCN':
        ks, kt, bs, T, n, p = 3, 3, [[CHANNEL, 32, 64], [64, 32, 128]], TIMESTEP_IN, N_NODE, 0
        A = load_matrix(ADJPATH)
        W = weight_matrix(A)
        L = scaled_laplacian(W)
        Lk = cheb_poly(L, ks)
        Lk = torch.Tensor(Lk.astype(np.float32)).to(device)
        model = STGCN(ks, kt, bs, T, n, Lk, p).to(device)
        return model
    else:
        return None
    
def evaluateModel(model, criterion, data_iter):
    model.eval()
    l_sum, n = 0.0, 0
    with torch.no_grad():
        for x, y in data_iter:
            y_pred = model(x)
            l = criterion(y_pred, y)
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]
        return l_sum / n

def predictModel(model, data_iter):
    YS_pred = []
    model.eval()
    with torch.no_grad():
        for x, y in data_iter:
            YS_pred_batch = model(x)
            YS_pred_batch = YS_pred_batch.cpu().numpy()
            YS_pred.append(YS_pred_batch)
        YS_pred = np.vstack(YS_pred)
    return YS_pred

def predictModel_multi(model, data_iter):
    YS_pred_multi = []
    model.eval()
    with torch.no_grad():
        for x, y in data_iter:
            print('8'*40)
            print('x.shape',x.shape)
            print('y.shape',y.shape)
            XS_pred_multi_batch, YS_pred_multi_batch = [x], []
            for i in range(TIMESTEP_OUT):
                tmp_torch = torch.cat(XS_pred_multi_batch, axis=2)[:, :, i:, :]
                print('tmp_torch.shape',tmp_torch.shape)
                yhat = model(tmp_torch)
                print('yhat.shape',yhat.shape)
                XS_pred_multi_batch.append(yhat)
                YS_pred_multi_batch.append(yhat)
            YS_pred_multi_batch = torch.cat(YS_pred_multi_batch, axis=2).cpu().numpy()
            YS_pred_multi.append(YS_pred_multi_batch)
        YS_pred_multi = np.vstack(YS_pred_multi)
    return YS_pred_multi

def trainModel(name, mode, XS, YS):
    print('Model Training Started ...', time.ctime())
    print('TIMESTEP_IN, TIMESTEP_OUT', TIMESTEP_IN, TIMESTEP_OUT)
    model = getModel(name)
#     summary(model, (CHANNEL, TIMESTEP_IN, N_NODE), device="cuda:{}".format(GPU))
    XS_torch, YS_torch = torch.Tensor(XS).to(device), torch.Tensor(YS).to(device)
    trainval_data = torch.utils.data.TensorDataset(XS_torch, YS_torch)
    trainval_size = len(trainval_data)
    train_size = int(trainval_size * (1-TRAINVALSPLIT))
    train_data = torch.utils.data.Subset(trainval_data, list(range(0, train_size)))
    val_data = torch.utils.data.Subset(trainval_data, list(range(train_size, trainval_size)))
    train_iter = torch.utils.data.DataLoader(train_data, BATCHSIZE, shuffle=True)
    val_iter = torch.utils.data.DataLoader(val_data, BATCHSIZE, shuffle=True)
    
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARN)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
    
    min_val_loss = np.inf
    wait = 0
    for epoch in range(EPOCH):
        starttime = datetime.now()     
        loss_sum, n = 0.0, 0
        model.train()
        for x, y in train_iter:
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() * y.shape[0]
            n += y.shape[0]
        # scheduler.step()
        train_loss = loss_sum / n
        val_loss = evaluateModel(model, criterion, val_iter)
        if val_loss < min_val_loss:
            wait = 0
            min_val_loss = val_loss
            torch.save(model.state_dict(), PATH + '/' + name + '.pt')
        else:
            wait += 1
            if wait == PATIENCE:
                print('Early stopping at epoch: %d' % epoch)
                break
        endtime = datetime.now()
        epoch_time = (endtime - starttime).seconds
        print("epoch", epoch, "time used:", epoch_time," seconds ", "train loss:", train_loss, ", validation loss:", val_loss)
    
    torch_score = evaluateModel(model, criterion, train_iter)
    YS_pred = predictModel(model, torch.utils.data.DataLoader(trainval_data, BATCHSIZE, shuffle=False))
    print('YS.shape, YS_pred.shape,', YS.shape, YS_pred.shape)
    YS, YS_pred = scaler.inverse_transform(np.squeeze(YS)), scaler.inverse_transform(np.squeeze(YS_pred))
    print('YS.shape, YS_pred.shape,', YS.shape, YS_pred.shape)
    MSE, RMSE, MAE, MAPE = Metrics.evaluate(YS, YS_pred)
    f = open(PATH + '/' + name + '_prediction_scores.txt', 'a')
    f.write("%s, %s, Torch MSE, %.10e, %.10f\n" % (name, mode, torch_score, torch_score))
    f.write("%s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (name, mode, MSE, RMSE, MAE, MAPE))
    f.close()
    print('*' * 40)
    print("%s, %s, Torch MSE, %.10e, %.10f\n" % (name, mode, torch_score, torch_score))
    print("%s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (name, mode, MSE, RMSE, MAE, MAPE))
    print('Model Training Ended ...', time.ctime())
        
def testModel(name, mode, XS, YS, YS_multi):
    print('Model Testing Started ...', time.ctime())
    print('TIMESTEP_IN, TIMESTEP_OUT', TIMESTEP_IN, TIMESTEP_OUT)
    XS_torch, YS_torch = torch.Tensor(XS).to(device), torch.Tensor(YS).to(device)
    test_data = torch.utils.data.TensorDataset(XS_torch, YS_torch)
    test_iter = torch.utils.data.DataLoader(test_data, BATCHSIZE, shuffle=False)
    model = getModel(name)
    model.load_state_dict(torch.load('STGCN.pt'))
    criterion = nn.MSELoss()
    torch_score = evaluateModel(model, criterion, test_iter)
    YS_pred_multi = predictModel_multi(model, test_iter)
    print('YS_multi.shape, YS_pred_multi.shape,', YS_multi.shape, YS_pred_multi.shape)
    YS_multi, YS_pred_multi = np.squeeze(YS_multi), np.squeeze(YS_pred_multi)
    for i in range(YS_multi.shape[1]):
        YS_multi[:, i, :] = scaler.inverse_transform(YS_multi[:, i, :])
        YS_pred_multi[:, i, :] = scaler.inverse_transform(YS_pred_multi[:, i, :])
    print('YS_multi.shape, YS_pred_multi.shape,', YS_multi.shape, YS_pred_multi.shape)
#     np.save(PATH + '/' + MODELNAME + '_prediction.npy', YS_pred_multi)
#     np.save(PATH + '/' + MODELNAME + '_groundtruth.npy', YS_multi)
    MSE, RMSE, MAE, MAPE = Metrics.evaluate(YS_multi, YS_pred_multi)
    print('*' * 40)
    print("%s, %s, Torch MSE, %.10e, %.10f\n" % (name, mode, torch_score, torch_score))
#     f = open(PATH + '/' + name + '_prediction_scores.txt', 'a')
#     f.write("%s, %s, Torch MSE, %.10e, %.10f\n" % (name, mode, torch_score, torch_score))
#     print("all pred steps, %s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (name, mode, MSE, RMSE, MAE, MAPE))
#     f.write("all pred steps, %s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (name, mode, MSE, RMSE, MAE, MAPE))
#     for i in [2, 5, 11]:
#         MSE, RMSE, MAE, MAPE = Metrics.evaluate(YS_multi[:, i, :], YS_pred_multi[:, i, :])
#         print("%d step, %s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (i, name, mode, MSE, RMSE, MAE, MAPE))
#         f.write("%d step, %s, %s, MSE, RMSE, MAE, MAPE, %.10f, %.10f, %.10f, %.10f\n" % (i, name, mode, MSE, RMSE, MAE, MAPE))
#     f.close()
    print('Model Testing Ended ...', time.ctime())
        
################# Parameter Setting #######################
MODELNAME = 'STGCN'
KEYWORD = 'pred_' + DATANAME + '_' + MODELNAME + '_' + datetime.now().strftime("%y%m%d%H%M")
PATH = '../' + KEYWORD
torch.manual_seed(100)
torch.cuda.manual_seed(100)
np.random.seed(100)
torch.backends.cudnn.deterministic = True
###########################################################
param = sys.argv
if len(param) == 2:
    GPU = param[-1]
else:
    GPU = '3'
device = torch.device("cuda:{}".format(GPU)) if torch.cuda.is_available() else torch.device("cpu")
###########################################################

data = pd.read_hdf(FLOWPATH).values
scaler = StandardScaler()
data = scaler.fit_transform(data)
print('data.shape', data.shape)


data.shape (34272, 207)


In [3]:
if not os.path.exists(PATH):
    os.makedirs(PATH)
currentPython = sys.argv[0]
shutil.copy2(currentPython, PATH)
shutil.copy2('Param.py', PATH)

print(KEYWORD, 'training started', time.ctime())
trainXS, trainYS = getXSYS_single(data, 'TRAIN')
print('TRAIN XS.shape YS,shape', trainXS.shape, trainYS.shape)
trainModel(MODELNAME, 'train', trainXS, trainYS)

print(KEYWORD, 'testing started', time.ctime())
testXS, testYS = getXSYS_single(data, 'TEST')
testXS_multi, testYS_multi = getXSYS(data, 'TEST')
print('TEST XS.shape, YS.shape, XS_multi.shape, YS_multi.shape', testXS.shape, testYS.shape, testXS_multi.shape, testYS_multi.shape)
testModel(MODELNAME, 'test', testXS, testYS, testYS_multi)

pred_METR-LA_STGCN_2105300113 training started Sun May 30 01:13:56 2021
data.shape: (34272, 207)
TRAIN XS.shape YS,shape (27394, 1, 12, 207) (27394, 1, 1, 207)
Model Training Started ... Sun May 30 01:13:57 2021
TIMESTEP_IN, TIMESTEP_OUT 12 12
epoch 0 time used: 10  seconds  train loss: 0.13095902523298308 , validation loss: 0.09504216520725192
YS.shape, YS_pred.shape, (27394, 1, 1, 207) (27394, 1, 1, 207)
YS.shape, YS_pred.shape, (27394, 207) (27394, 207)
MSE: 39.1961344396
RMSE: 6.2606816274
MAE: 3.1770077839
0.0 0.0
MAPE: 6.9630661724 %
PCC: 0.9478494283
0.0 0.0
****************************************
STGCN, train, Torch MSE, 1.1002553005e-01, 0.1100255300

STGCN, train, MSE, RMSE, MAE, MAPE, 39.1869610845, 6.2599489682, 3.1741758748, 0.0696306617

Model Training Ended ... Sun May 30 01:14:20 2021
pred_METR-LA_STGCN_2105300113 testing started Sun May 30 01:14:20 2021
data.shape: (34272, 207)
TEST XS.shape, YS.shape, XS_multi.shape, YS_multi.shape (6844, 1, 12, 207) (6844, 1, 1, 207

In [24]:
x = torch.randn(64, 2,12,1)

for i in range(12):
    tmp_torch = torch.cat([x], axis=2)[:, :, i:, :]
    print(i)
    print(tmp_torch.shape)

0
torch.Size([64, 2, 12, 1])
1
torch.Size([64, 2, 11, 1])
2
torch.Size([64, 2, 10, 1])
3
torch.Size([64, 2, 9, 1])
4
torch.Size([64, 2, 8, 1])
5
torch.Size([64, 2, 7, 1])
6
torch.Size([64, 2, 6, 1])
7
torch.Size([64, 2, 5, 1])
8
torch.Size([64, 2, 4, 1])
9
torch.Size([64, 2, 3, 1])
10
torch.Size([64, 2, 2, 1])
11
torch.Size([64, 2, 1, 1])


In [22]:
testModel('STGCN', 'test', testXS, testYS, testYS_multi)

Model Testing Started ... Sun May 30 01:27:35 2021
TIMESTEP_IN, TIMESTEP_OUT 12 12
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 2

x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 

yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 

yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 

yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])


yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
x.shape torch.Size([64, 1, 12, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
tmp_torch.shape torch.Size([64, 1, 12, 207])
yhat.shape torch.Size([64, 1, 1, 207])
