In [1]:
#!/usr/bin/env python
# coding: utf-8


import argparse
import copy
import pickle
from datetime import datetime
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
from numpy import load
from tensorboardX import SummaryWriter

from source.Attention_LSTM import RNNModel
from source.train import train_epoch, evaluate
from source.train import train_epoch_lstm, evaluate_lstm
from source.transformer import Seq2SeqTransformer
from DataInfo import *

# Training configuration, declared as argparse options so the same code can be
# driven from a shell; inside the notebook the defaults below are what is used.
parser = argparse.ArgumentParser(description='Train Config')

# (flag, type, default) for every option, registered in this order.
_OPTION_SPECS = (
    ('--model', str, 'Trans'),
    ('--epoch', int, 1000),
    ('--batch_size', int, 64),
    ('--lr_initial', float, 1e-3),
    ('--hid_dim', int, 128),
    ('--emb_dim', int, 128),
    ('--num_layers', int, 4),
    ('--num_head', int, 8),
    ('--mbr_no', int, None),
    ('--brn_no', int, None),
    ('--data', str, '1111Train_08'),
    # Alternative dataset suffixes that were used previously:
    # ('--data', str, '1516Train'),
    # ('--data', str, '_10_0823Train'),
    ('--trainname', str, None),
    ('--device', str, 'cuda:0'),
)
for _flag, _kind, _default in _OPTION_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default)

# An empty argv is passed on purpose: sys.argv belongs to the notebook kernel,
# so every option takes its default here.
# args = parser.parse_args()
args = parser.parse_args(args=[])

# Echo the effective hyper-parameters, one value per line.
_ECHOED_OPTIONS = ('model', 'epoch', 'batch_size', 'lr_initial', 'emb_dim',
                   'hid_dim', 'num_head', 'num_layers', 'mbr_no', 'brn_no')
for _option in _ECHOED_OPTIONS:
    print(getattr(args, _option))

# Derive the run configuration from the parsed options.

# Only these three architectures are handled by the model construction below.
modeltype = args.model
if modeltype not in ['Trans', 'ALSTM', 'LSTM']:
    raise ValueError(
        f"Unsupported --model {modeltype!r}; expected one of 'Trans', 'ALSTM', 'LSTM'")

# Suffix that selects which dataset files are loaded.
datasubfix = args.data

# Run label: the explicit --trainname if given, otherwise today's date (MM_DD_YYYY).
if args.trainname:
    trainname = args.trainname
else:
    now = datetime.now()
    date_time = now.strftime("%m_%d_%Y")
    trainname = date_time
device = args.device

num_epochs = args.epoch
# Passed as the last argument to the train/evaluate helpers.
# NOTE(review): presumably the sequence (window) length, it matches the 39-row
# windows the data is cut into — confirm against source.train.
bptt = 39
TGT_VOCAB_SIZE = 3
EMB_SIZE = args.emb_dim
NHEAD = args.num_head
FFN_HID_DIM = args.hid_dim
BATCH_SIZE = args.batch_size
lr_init = args.lr_initial
# The transformer splits the layer budget evenly between encoder and decoder
# (integer division, so an odd --num_layers drops one layer); the LSTM variants
# use the full count.
NUM_ENCODER_LAYERS = args.num_layers // 2
NUM_DECODER_LAYERS = args.num_layers // 2
NUM_LAYERS = args.num_layers

# (mbr, brn) pairs to train on. A pair supplied through --mbr_no/--brn_no takes
# precedence; otherwise fall back to the default pair (42, 1).
# Previously the default assignment ran unconditionally and silently discarded
# the command-line pair, and a truthiness test would have rejected an id of 0.
if args.mbr_no is not None and args.brn_no is not None:
    mbrnlist = [(args.mbr_no, args.brn_no)]
elif args.mbr_no is None and args.brn_no is None:
    mbrnlist = [(42, 1)]
else:
    # Half a pair cannot name a dataset file; fail loudly instead of guessing.
    raise ValueError("--mbr_no and --brn_no must be given together")
print(mbrnlist)

# load array
# For every (mbr, brn) pair: load its arrays, cut them into fixed-length
# windows, split the windows 90/10 into train/test, build the selected model
# and train it while tracking the epoch with the lowest validation loss.
for mbr, brn in mbrnlist:
    # ---- Load ---------------------------------------------------------------
    DataSubfix = str(mbr) + '_' + str(brn) + datasubfix
    XDataname = 'Train_ORD' + '_' + DataSubfix + '.npy'
    # NOTE(review): this produces a double underscore ("Train_ORD_Label__...").
    # Left unchanged because it must match the file names on disk — confirm.
    YDataname = 'Train_ORD_Label_' + '_' + DataSubfix + '.npy'
    XData = load('/Data/LOBData/TrainData/' + XDataname)
    YData = load('/Data/LOBData/TrainData/' + YDataname)

    # ---- Window and split ---------------------------------------------------
    # Rows are consumed in consecutive windows of `seq_len` rows; trailing rows
    # that do not fill a whole window are ignored.
    # NOTE(review): the window length was a literal 39 everywhere; it is tied to
    # `bptt` (also 39) here on the assumption that they are the same quantity.
    seq_len = bptt
    n_windows = len(XData) // seq_len

    Xtrain_data = []
    Ytrain_data = []
    Xtest_data = []
    Ytest_data = []

    for idx in range(n_windows):
        rows = slice(seq_len * idx, seq_len * (idx + 1))
        x_win = XData[rows]
        y_win = YData[rows]

        # Abort on infinite values rather than training on corrupt data.
        # NOTE(review): NaN is not rejected by this check.
        if np.isinf(x_win).any():
            raise RuntimeError(
                f"Infinite feature value in window {idx} of {XDataname} "
                f"(mbr={mbr}, brn={brn}):\n{x_win.tolist()}")
        if np.isinf(y_win).any():
            raise RuntimeError(
                f"Infinite label value in window {idx} of {YDataname} "
                f"(mbr={mbr}, brn={brn})")

        # The first 90% of the windows are used for training, the rest for
        # testing. The last feature column is dropped from the model input
        # (its meaning is not visible here).
        if idx < n_windows * 0.9:
            Xtrain_data.append(x_win[:, :-1])
            Ytrain_data.append(y_win)
        else:
            Xtest_data.append(x_win[:, :-1])
            Ytest_data.append(y_win)

    if not Xtrain_data or not Xtest_data:
        # np.vstack would raise an opaque ValueError on an empty list.
        raise ValueError(
            f"{XDataname}: {n_windows} window(s) of {seq_len} rows is not enough "
            f"to build both a training and a test split")

    Xtrain_data = torch.FloatTensor(np.vstack(Xtrain_data))
    Xtest_data = torch.FloatTensor(np.vstack(Xtest_data))
    # Labels are flattened to one class index per row.
    Ytrain_data = torch.LongTensor(np.vstack(Ytrain_data)).view(-1)
    Ytest_data = torch.LongTensor(np.vstack(Ytest_data)).view(-1)

    # Fixed seed so the weight initialisation is repeatable.
    torch.manual_seed(0)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Number of input feature columns.
    SRC_VOCAB_SIZE = Xtrain_data.shape[1]

    # ---- Model --------------------------------------------------------------
    if modeltype == 'Trans':
        model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                   NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
    elif modeltype == "ALSTM":
        model = RNNModel(rnn_type='LSTM', ntoken=SRC_VOCAB_SIZE, ninp=EMB_SIZE, nhid=FFN_HID_DIM,
                         nlayers=NUM_LAYERS,
                         proj_size=TGT_VOCAB_SIZE,
                         attention_width=seq_len)  # was a literal 39; presumably the window length — confirm
    elif modeltype == "LSTM":
        model = RNNModel(rnn_type='LSTM', ntoken=SRC_VOCAB_SIZE, ninp=EMB_SIZE, nhid=FFN_HID_DIM,
                         nlayers=NUM_LAYERS,
                         proj_size=TGT_VOCAB_SIZE,
                         attention=False)
    else:
        raise ValueError(f"Unsupported model type {modeltype!r}")

    # The transformer and the LSTM variants use different train/eval helpers
    # with the same call signature; pick the pair once instead of every epoch.
    if modeltype == 'Trans':
        train_fn, eval_fn = train_epoch, evaluate
    else:
        train_fn, eval_fn = train_epoch_lstm, evaluate_lstm

    summary = SummaryWriter()

    # Xavier-initialise every weight matrix (parameters with more than one dim).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    model = model.to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init, betas=(0.9, 0.98), eps=1e-9)

    # Per-epoch history.
    Val_loss = []
    Train_loss = []
    Accuracy = []
    F1score = []
    NUM_EPOCHS = num_epochs
    best_val_loss = float('inf')

    # ---- Train --------------------------------------------------------------
    try:
        for epoch in range(1, NUM_EPOCHS + 1):
            start_time = timer()
            train_loss, _ = train_fn(model, optimizer, Xtrain_data, Ytrain_data, loss_fn, device,
                                     BATCH_SIZE, bptt)
            end_time = timer()  # evaluation time is deliberately not included

            val_loss, acc, prec, reca, f1sc, confusion = eval_fn(model, Xtest_data, Ytest_data, loss_fn,
                                                                 device, BATCH_SIZE, bptt)
            print((
                f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
                f"Epoch time = {(end_time - start_time):.3f}s"))

            # The writer used to be created and then never used; record the
            # loss curves so the run shows up in TensorBoard.
            summary.add_scalar('Loss/train', train_loss, epoch)
            summary.add_scalar('Loss/val', val_loss, epoch)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_confusion = confusion
                best_acc = acc
                best_prec = prec
                best_reca = reca
                best_f1sc = f1sc
                # Snapshot the weights. Binding `model` itself would only alias
                # the live module, so "best" would silently follow every later
                # epoch and end up equal to the last one.
                best_model = copy.deepcopy(model)

            Val_loss.append(val_loss)
            Train_loss.append(train_loss)
            Accuracy.append(acc)
            F1score.append(f1sc)
    finally:
        # Flush and release the event file even when training is interrupted.
        summary.close()

#     PATH = 'results/best_model_' + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname
# #     torch.save(best_model.state_dict(), PATH)

#     file_name = 'results/result_' + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname + '.txt'
#     text_to_append = PATH + '\t' + "Acc:" + str(best_acc) + '\t' + "prec:" + str(best_prec) + '\t' + "recall:" + str(
#         best_reca) + '\t' + "f1sc:" + str(best_f1sc)
#     print(text_to_append)

#     with open(file_name, "a+") as file_object:
#         # Move read cursor to the start of file.
#         file_object.seek(0)
#         # If file is not empty then append '\n'
#         data = file_object.read(100)
#         if len(data) > 0:
#             file_object.write("\n")
#         # Append text at the end of file
#         file_object.write(text_to_append)

#     with open("results/Val_loss_" + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname, "wb") as fp:  # Pickling
#         pickle.dump(Val_loss, fp)
#     with open("results/Train_loss_" + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname, "wb") as fp:  # Pickling
#         pickle.dump(Train_loss, fp)
#     with open("results/Accuracy_" + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname, "wb") as fp:  # Pickling
#         pickle.dump(Accuracy, fp)
#     with open("results/F1_" + modeltype + '_' + str(mbr) + '_' + str(brn) + trainname, "wb") as fp:  # Pickling
#         pickle.dump(F1score, fp)


Trans
1000
64
0.001
128
128
8
4
None
None
[(42, 1)]
4992
1270
Acc: 0.25440705128205127
Prec 0.39363479758828596
Recall 0.3785366756790882
F1 0.38593813128107
Time elapsed 0.004464149475097656
2496
130
Acc: 0.052083333333333336
Prec 0.5066852717299017
Recall 0.40529221397296783
F1 0.45035233607469427
Epoch: 1, Train loss: 0.001, Val loss: 0.007, Epoch time = 0.521s
4992
1880
Acc: 0.3766025641025641
Prec 0.6406663361518152
Recall 0.4641395889998961
F1 0.53830017232527
Time elapsed 0.0057392120361328125
2496
266
Acc: 0.10657051282051282
Prec 0.7622065593256857
Recall 0.7688531723368959
F1 0.7655154387440887
Epoch: 2, Train loss: 0.001, Val loss: 0.002, Epoch time = 0.043s
4992
2384
Acc: 0.4775641025641026
Prec 0.6677716694274537
Recall 0.6626969261469107
F1 0.6652246196108431
Time elapsed 0.007207155227661133
2496
248
Acc: 0.09935897435897435
Prec 0.734638499166801
Recall 0.7040636643006716
F1 0.7190261984636515
Epoch: 3, Train loss: 0.001, Val loss: 0.003, Epoch time = 0.051s
4992
2303
A

Time elapsed 0.006134033203125
2496
278
Acc: 0.11137820512820513
Prec 0.7919212919212919
Recall 0.7969493622691796
F1 0.7944273712898475
Epoch: 29, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.042s
4992
2607
Acc: 0.5222355769230769
Prec 0.7341967201715716
Recall 0.7220893743899984
F1 0.7280927179456281
Time elapsed 0.005861043930053711
2496
279
Acc: 0.11177884615384616
Prec 0.7946783993295621
Recall 0.79992555274537
F1 0.7972933429787893
Epoch: 30, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.047s
4992
2601
Acc: 0.5210336538461539
Prec 0.7325696426318619
Recall 0.7225895517627071
F1 0.7275453733769148
Time elapsed 0.004468441009521484
2496
281
Acc: 0.11258012820512821
Prec 0.8001031121109862
Recall 0.806137989176253
F1 0.8031092136963942
Epoch: 31, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2588
Acc: 0.5184294871794872
Prec 0.726771283230962
Recall 0.7183591791052514
F1 0.7225407477399091
Time elapsed 0.006489992141723633
2496
281
Acc: 0.11258012820512821


4992
2680
Acc: 0.5368589743589743
Prec 0.7573893484591698
Recall 0.7449036769723888
F1 0.7510946280335513
Time elapsed 0.006703615188598633
2496
283
Acc: 0.11338141025641026
Prec 0.8055198055198055
Recall 0.8128705365641403
F1 0.809178477526978
Epoch: 58, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.048s
4992
2650
Acc: 0.5308493589743589
Prec 0.7464086666601606
Recall 0.735591227951878
F1 0.7409604678901863
Time elapsed 0.004557371139526367
2496
283
Acc: 0.11338141025641026
Prec 0.8088432314821953
Recall 0.8089391096244323
F1 0.8088911677121967
Epoch: 59, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.046s
4992
2649
Acc: 0.5306490384615384
Prec 0.7488139324930055
Recall 0.735231341206449
F1 0.741960480125358
Time elapsed 0.00524449348449707
2496
283
Acc: 0.11338141025641026
Prec 0.8091284105302797
Recall 0.8084189986674281
F1 0.8087735490345755
Epoch: 60, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.043s
4992
2677
Acc: 0.5362580128205128
Prec 0.7554525895277743
Recall 0

Time elapsed 0.005711078643798828
2496
283
Acc: 0.11338141025641026
Prec 0.8063961877915365
Recall 0.8113051045660983
F1 0.8088431981214989
Epoch: 86, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.043s
4992
2699
Acc: 0.5406650641025641
Prec 0.7604755791061314
Recall 0.7514219686740041
F1 0.7559216662788486
Time elapsed 0.00570225715637207
2496
280
Acc: 0.11217948717948718
Prec 0.7989860741614693
Recall 0.800800902885426
F1 0.7998924591353652
Epoch: 87, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.046s
4992
2730
Acc: 0.546875
Prec 0.7747954159037844
Recall 0.7578762874868139
F1 0.7662424667565165
Time elapsed 0.005478620529174805
2496
282
Acc: 0.11298076923076923
Prec 0.8054140078654609
Recall 0.8064932283593048
F1 0.8059532568270532
Epoch: 88, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2741
Acc: 0.5490785256410257
Prec 0.774586179351331
Recall 0.7661900158184757
F1 0.7703652209457484
Time elapsed 0.00577235221862793
2496
276
Acc: 0.11057692307692307
Prec 0.

4992
2760
Acc: 0.5528846153846154
Prec 0.7826205776615872
Recall 0.7710290175776869
F1 0.7767815561237457
Time elapsed 0.005741119384765625
2496
287
Acc: 0.11498397435897435
Prec 0.8174463937621832
Recall 0.8250404530744336
F1 0.8212258678022775
Epoch: 115, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.048s
4992
2783
Acc: 0.5574919871794872
Prec 0.7885519953508217
Recall 0.7769527576013511
F1 0.7827094055697014
Time elapsed 0.004460573196411133
2496
280
Acc: 0.11217948717948718
Prec 0.7989860741614693
Recall 0.800800902885426
F1 0.7998924591353652
Epoch: 116, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2754
Acc: 0.5516826923076923
Prec 0.7820316886264487
Recall 0.7656795753551623
F1 0.7737692490799035
Time elapsed 0.00582432746887207
2496
284
Acc: 0.11378205128205128
Prec 0.8089071214071214
Recall 0.8145413505207908
F1 0.8117144590790553
Epoch: 117, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2776
Acc: 0.5560897435897436
Prec 0.7849204434630032
Re

Time elapsed 0.0068361759185791016
2496
271
Acc: 0.10857371794871795
Prec 0.7777744439045019
Recall 0.7735001767697369
F1 0.7756314218376571
Epoch: 143, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.045s
4992
2845
Acc: 0.5699118589743589
Prec 0.8095187390366472
Recall 0.7966914998632874
F1 0.8030539001074349
Time elapsed 0.004744768142700195
2496
278
Acc: 0.11137820512820513
Prec 0.7935800571989585
Recall 0.7956337874955808
F1 0.7946055953363877
Epoch: 144, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.046s
4992
2829
Acc: 0.5667067307692307
Prec 0.803382075375516
Recall 0.7932297693232129
F1 0.7982736448366671
Time elapsed 0.005566835403442383
2496
260
Acc: 0.10416666666666667
Prec 0.7593682258944292
Recall 0.7363156427619592
F1 0.7476642826197631
Epoch: 145, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2838
Acc: 0.5685096153846154
Prec 0.8074298357188204
Recall 0.7953206599077811
F1 0.8013295039064285
Time elapsed 0.005717754364013672
2496
265
Acc: 0.10616987

4992
2865
Acc: 0.5739182692307693
Prec 0.8131997849644909
Recall 0.806168843352817
F1 0.8096690507593629
Time elapsed 0.004601716995239258
2496
259
Acc: 0.10376602564102565
Prec 0.7354540373841845
Recall 0.7469541214544071
F1 0.7411594722937772
Epoch: 171, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.045s
4992
2882
Acc: 0.577323717948718
Prec 0.8204466395293587
Recall 0.8092657654123535
F1 0.8148178484810229
Time elapsed 0.004947185516357422
2496
261
Acc: 0.1045673076923077
Prec 0.7475497475497476
Recall 0.7466141796524436
F1 0.747081670699067
Epoch: 172, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.043s
4992
2865
Acc: 0.5739182692307693
Prec 0.8142336652680995
Recall 0.8040829640438573
F1 0.8091264801146169
Time elapsed 0.0062143802642822266
2496
234
Acc: 0.09375
Prec 0.675659532126574
Recall 0.6775770988006853
F1 0.676616956846125
Epoch: 173, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.047s
4992
2877
Acc: 0.5763221153846154
Prec 0.8172697858579429
Recall 0.81031247

4992
2922
Acc: 0.5853365384615384
Prec 0.8334766605709519
Recall 0.8219710678427985
F1 0.8276838814693493
Time elapsed 0.005236387252807617
2496
216
Acc: 0.08653846153846154
Prec 0.6224867724867725
Recall 0.6268968752549564
F1 0.6246840404304211
Epoch: 200, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.044s
4992
2912
Acc: 0.5833333333333334
Prec 0.8275508169775088
Recall 0.8210772303833519
F1 0.8243013139235331
Time elapsed 0.005706787109375
2496
224
Acc: 0.08974358974358974
Prec 0.6471197844422124
Recall 0.645984607435207
F1 0.6465516976700296
Epoch: 201, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.045s
4992
2914
Acc: 0.5837339743589743
Prec 0.8300434608057324
Recall 0.8214006511274166
F1 0.825699439954747
Time elapsed 0.0056684017181396484
2496
226
Acc: 0.09054487179487179
Prec 0.6552129370925409
Recall 0.6553432052432623
F1 0.6552780646936306
Epoch: 202, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.045s
4992
2904
Acc: 0.5817307692307693
Prec 0.824993649983822
Recal

4992
2967
Acc: 0.5943509615384616
Prec 0.8497783245706109
Recall 0.8359013324124187
F1 0.8427827088274786
Time elapsed 0.0052301883697509766
2496
229
Acc: 0.09174679487179487
Prec 0.6645544931259217
Recall 0.6579692556634305
F1 0.661245479470775
Epoch: 228, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.048s
4992
2931
Acc: 0.5871394230769231
Prec 0.8330282807391011
Recall 0.8303796303818043
F1 0.8317018468327414
Time elapsed 0.0044100284576416016
2496
232
Acc: 0.09294871794871795
Prec 0.6693240169955654
Recall 0.6708292540317098
F1 0.6700757901855009
Epoch: 229, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.043s
4992
2964
Acc: 0.59375
Prec 0.8438254937791818
Recall 0.8379494995188227
F1 0.8408772315098828
Time elapsed 0.006455183029174805
2496
230
Acc: 0.0921474358974359
Prec 0.6561683634854366
Recall 0.6719493622691796
F1 0.6639651061576352
Epoch: 230, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.044s
4992
2959
Acc: 0.5927483974358975
Prec 0.8429531390312254
Recall 0.83

4992
3003
Acc: 0.6015625
Prec 0.8544023880906096
Recall 0.854510777006229
F1 0.854456579111101
Time elapsed 0.00704193115234375
2496
233
Acc: 0.09334935897435898
Prec 0.6685201686341417
Recall 0.6745958091974654
F1 0.6715442472103723
Epoch: 257, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.045s
4992
3001
Acc: 0.6011618589743589
Prec 0.8555981666712689
Recall 0.8498898353658038
F1 0.8527344480207275
Time elapsed 0.004670858383178711
2496
239
Acc: 0.09575320512820513
Prec 0.6829208833947179
Recall 0.6906070680699464
F1 0.6867424700051284
Epoch: 258, Train loss: 0.000, Val loss: 0.002, Epoch time = 0.046s
4992
3025
Acc: 0.6059695512820513
Prec 0.8659697207381339
Recall 0.8532988654523814
F1 0.8595876015617294
Time elapsed 0.005646705627441406
2496
220
Acc: 0.08814102564102565
Prec 0.6428269254697025
Recall 0.6380265698512415
F1 0.6404177523248822
Epoch: 259, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.044s
4992
3026
Acc: 0.6061698717948718
Prec 0.862378202766319
Recall 0.8599

4992
3045
Acc: 0.6099759615384616
Prec 0.8682279006973785
Recall 0.8684599511942636
F1 0.8683439104429105
Time elapsed 0.0077817440032958984
2496
220
Acc: 0.08814102564102565
Prec 0.6356379950354126
Recall 0.637491161513149
F1 0.636563229539953
Epoch: 285, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.051s
4992
3055
Acc: 0.6119791666666666
Prec 0.8747497148119301
Recall 0.8657334475922586
F1 0.8702182276080084
Time elapsed 0.004960536956787109
2496
229
Acc: 0.09174679487179487
Prec 0.6579828217759253
Recall 0.6611205161676321
F1 0.6595479372224107
Epoch: 286, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.048s
4992
3044
Acc: 0.6097756410256411
Prec 0.8694340145857437
Recall 0.8647662446988229
F1 0.8670938477623213
Time elapsed 0.004363298416137695
2496
237
Acc: 0.09495192307692307
Prec 0.6785705856967539
Recall 0.6822988904299585
F1 0.6804296309401429
Epoch: 287, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.044s
4992
3047
Acc: 0.6103766025641025
Prec 0.8686724867117652
R

4992
3052
Acc: 0.6113782051282052
Prec 0.8701538428924146
Recall 0.86755256825984
F1 0.8688512585758542
Time elapsed 0.0073359012603759766
2496
216
Acc: 0.08653846153846154
Prec 0.6201271928053099
Recall 0.6308283021946642
F1 0.6254319769899167
Epoch: 313, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.050s
4992
3086
Acc: 0.6181891025641025
Prec 0.8811310463091796
Recall 0.8778490305393625
F1 0.8794869765284443
Time elapsed 0.006742715835571289
2496
236
Acc: 0.09455128205128205
Prec 0.6743484498422241
Recall 0.6795776563052405
F1 0.6769529548120514
Epoch: 314, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.051s
4992
3087
Acc: 0.6183894230769231
Prec 0.8834285374341816
Recall 0.8750751185805722
F1 0.8792319873871128
Time elapsed 0.005303859710693359
2496
237
Acc: 0.09495192307692307
Prec 0.68582995951417
Recall 0.6757414130700824
F1 0.6807483108113384
Epoch: 315, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.050s
4992
3079
Acc: 0.6167868589743589
Prec 0.8794085787544574
Rec

4992
3077
Acc: 0.616386217948718
Prec 0.8800678689274498
Recall 0.8722112046090474
F1 0.8761219233711749
Time elapsed 0.0045850276947021484
2496
230
Acc: 0.0921474358974359
Prec 0.6570564448471425
Recall 0.6653969840363331
F1 0.6612004131019525
Epoch: 341, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.046s
4992
3115
Acc: 0.6239983974358975
Prec 0.8906316443195391
Recall 0.8873612637491775
F1 0.8889934463201469
Time elapsed 0.005088329315185547
2496
228
Acc: 0.09134615384615384
Prec 0.6573164459303072
Recall 0.6547330097087379
F1 0.6560221844158871
Epoch: 342, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.044s
4992
3107
Acc: 0.6223958333333334
Prec 0.8864387379691062
Recall 0.8864334791615917
F1 0.8864361085575495
Time elapsed 0.007441997528076172
2496
227
Acc: 0.09094551282051282
Prec 0.6552489362635435
Recall 0.6493959234179109
F1 0.6523093006894809
Epoch: 343, Train loss: 0.000, Val loss: 0.003, Epoch time = 0.051s
4992
3103
Acc: 0.6215945512820513
Prec 0.8887571967896871
R

Time elapsed 0.005867719650268555
2496
226
Acc: 0.09054487179487179
Prec 0.6501917317895974
Recall 0.6508814690924914
F1 0.650536417615994
Epoch: 369, Train loss: 0.000, Val loss: 0.004, Epoch time = 0.043s
4992
3135
Acc: 0.6280048076923077
Prec 0.8970127121660032
Recall 0.8929533699322275
F1 0.8949784380961501
Time elapsed 0.005547523498535156
2496
220
Acc: 0.08814102564102565
Prec 0.6405358298209406
Recall 0.6346101547415083
F1 0.6375592238086055
Epoch: 370, Train loss: 0.000, Val loss: 0.004, Epoch time = 0.046s
4992
3140
Acc: 0.6290064102564102
Prec 0.8958691632639404
Recall 0.8998987771816077
F1 0.8978794490961963
Time elapsed 0.007147789001464844
2496
218
Acc: 0.08733974358974358
Prec 0.6293740712345364
Recall 0.6323240461233036
F1 0.6308456100125058
Epoch: 371, Train loss: 0.000, Val loss: 0.004, Epoch time = 0.051s
4992
3129
Acc: 0.6268028846153846
Prec 0.8937359201703199
Recall 0.8916658548922692
F1 0.892699687473606
Time elapsed 0.0055484771728515625
2496
227
Acc: 0.090945512

KeyboardInterrupt: 

In [None]:
# Peek at the flattened training labels at positions 90-119.
Ytrain_data[90:120]

In [None]:
# Display the test feature tensor (the whole tensor is shown, not a slice).
Xtest_data

In [None]:
# Shape of the training feature tensor: (rows, feature columns).
Xtrain_data.shape

In [None]:
# Number of training labels equal to class index 1.
(Ytrain_data==1).sum()

In [None]:
# Print the positional arguments used to build Seq2SeqTransformer, in the same order.
print(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                   NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

In [None]:
src shape torch.Size([39, 64, 10])
tgt shape torch.Size([40, 64, 1])
2 2 128 8 10 3 128