In [1]:
import sys, os, re, json, time

import pandas as pd
import pickle
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import plotting
from PIL import Image
from tqdm import tqdm
from utils import imread, img_data_2_mini_batch, imgs2batch

from sklearn import metrics

from naive import Enc, Dec, EncDec
from data_loader import VQADataSet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms

%matplotlib inline
%reload_ext autoreload
%autoreload 2

ModuleNotFoundError: No module named 'bokeh'

In [2]:
N = 2000
dataset_filename = "./data/data_{}.pkl".format(N)
dataset = None
print(dataset_filename)
if (os.path.exists(dataset_filename)):
    with open(dataset_filename, 'rb') as handle:
        print("reading from " + dataset_filename)
        dataset = pickle.load(handle)
else:
    dataset = VQADataSet(Q=N)
    with open(dataset_filename, 'wb') as handle:
        print("writing to " + dataset_filename)
        pickle.dump(dataset, handle)

assert(dataset is not None)
def debug(v,q,a):
    print('\nV: {}\nQ: {}\nA: {}'.format(v.shape, q.shape, a.shape))


./data/data_2000.pkl


100%|██████████| 2000/2000 [00:05<00:00, 375.14it/s]
100%|██████████| 1979/1979 [00:00<00:00, 691941.28it/s]
100%|██████████| 1979/1979 [00:00<00:00, 7857.37it/s]
100%|██████████| 1979/1979 [00:00<00:00, 244860.55it/s]
100%|██████████| 1979/1979 [00:00<00:00, 321489.12it/s]
100%|██████████| 509/509 [00:00<00:00, 53819.22it/s]


VQADataSet init time: 11.126109838485718
writing to ./data/data_2000.pkl


In [3]:
embed_size        = 128
hidden_size       = 128
batch_size        = 32
ques_vocab_size   = len(dataset.vocab['question'])
ans_vocab_size    = len(dataset.vocab['answer'])
rnn_layers        = 1
n_epochs          = 25
learning_rate     = 0.01
momentum          = 0.98

print(ques_vocab_size, ans_vocab_size)

1469 509


In [4]:
def eval_model(data_loader, model, criterion, optimizer, batch_size, training=False,
              total_loss_over_epochs=[], scores_over_epochs=[]):
    running_loss = 0.
    final_labels, final_preds = [], []
    if data_loader is None:
        return
    
    if training:
        model.train()
    else:
        model.eval()
    
    for i, minibatch in enumerate(data_loader):
        # extract minibatch
        t0 = time.time()
        idxs, v, q, a, q_len = minibatch
        
        # convert torch's DataLoader output to proper format.
        # torch gives a List[Tensor_1, ... ] where tensor has been transposed. 
        # batchify transposes back.`
        v = v.to(device)
        q = VQADataSet.batchify_questions(q).to(device)
        a = a.to(device)

        logits = model(v, q, q_len)
        preds = torch.argmax(logits, dim=1)

#         loss = criterion(logits, a)
        loss = F.nll_loss(logits, a)
        running_loss += loss.item()
        
        score = metrics.precision_recall_fscore_support(preds.tolist(),
                                                        a.tolist(),
                                                        average='weighted')
        
        total_loss_over_epochs['train_loss'].append(loss)
        scores_over_epochs['train_scores'].append(score)
        
        if training and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        final_labels += a.tolist()
        final_preds  += preds.tolist()
        if True:#(i%20==0):
#             plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(7,3))
#             plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(7, 3), title="Loss")
            print("Loss: {} - score: {} - t: {}".format(loss, score, time.time()-t0))
            
    return running_loss, final_labels, final_preds

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = EncDec(embed_size, hidden_size, ques_vocab_size, ans_vocab_size, rnn_layers).to(device)

criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.get_parameters(), lr=learning_rate, momentum=momentum)
optimizer = torch.optim.Adam(model.get_parameters(), lr=learning_rate)


In [6]:

train_loader = dataset.build_data_loader(train=True, args={'batch_size': batch_size})
test_loader  = dataset.build_data_loader(test=True, args={'batch_size': batch_size})

best_score = 0

train_all_loss, train_all_labels, train_all_preds = [], [], []

total_loss_over_epochs, scores_over_epochs = plotting.get_empty_stat_over_n_epoch_dictionaries()

for epoch in tqdm(range(n_epochs)):
    t0= time.time()
    tr_loss, tr_labels, tr_preds = eval_model(data_loader = train_loader,
                                     model       = model,
                                     criterion   = criterion,
                                     optimizer   = optimizer,
                                     batch_size  = batch_size,
                                     training    = True,
                                     total_loss_over_epochs = total_loss_over_epochs,
                                     scores_over_epochs     = scores_over_epochs)
    
#     train_scores = metrics.precision_recall_fscore_support(tr_labels,
#                                                            tr_preds,
#                                                            average='weighted')
    
#     total_loss_over_epochs['train_loss'].append(tr_loss)
#     scores_over_epochs['train_scores'].append(train_scores)
    
#     if True:# or epoch%1 == 0:
#         print("#==#"*5 + "epoch: {}".format(epoch) + "#==#"*5)
#         print("time: {}".format(time.time()-t0))
#         print(train_scores)
#     plotting.plot_score_over_n_epochs(scores_over_epochs, score_type='precision', fig_size=(8,5))
#     plotting.plot_loss_over_n_epochs(total_loss_over_epochs, fig_size=(8, 5), title="Loss")
    
    
    
    

  0%|          | 0/25 [00:00<?, ?it/s]

batch_size: 32 shuffle: True
batch_size: 32 shuffle: False


  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


Loss: 6.2394843101501465 - score: (0.0, 0.0, 0.0, None) - t: 0.28528499603271484
Loss: 6.18806266784668 - score: (0.0, 0.0, 0.0, None) - t: 0.2519233226776123
Loss: 6.1274237632751465 - score: (0.125, 0.0625, 0.08333333333333333, None) - t: 0.24086761474609375
Loss: 5.885769844055176 - score: (0.4618055555555556, 0.21875, 0.296875, None) - t: 0.24924707412719727
Loss: 5.541611194610596 - score: (0.8616071428571429, 0.25, 0.3211245519713262, None) - t: 0.24949860572814941
Loss: 5.2758073806762695 - score: (0.875, 0.09375, 0.16935483870967738, None) - t: 0.24796700477600098
Loss: 4.5459160804748535 - score: (0.37734375, 0.15625, 0.21389751552795033, None) - t: 0.24123477935791016
Loss: 5.480690956115723 - score: (1.0, 0.09375, 0.17142857142857143, None) - t: 0.2491469383239746
Loss: 5.495992660522461 - score: (1.0, 0.15625, 0.2702702702702703, None) - t: 0.2487189769744873
Loss: 4.906850337982178 - score: (0.90625, 0.1875, 0.3107142857142857, None) - t: 0.2502307891845703
Loss: 4.6163821

  4%|▍         | 1/25 [00:46<18:27, 46.13s/it]

Loss: 5.092282295227051 - score: (0.6896551724137931, 0.20689655172413793, 0.3022546419098143, None) - t: 0.2135171890258789
Loss: 3.427579402923584 - score: (0.65625, 0.25, 0.3618551587301587, None) - t: 0.2486882209777832
Loss: 4.327885150909424 - score: (0.47395833333333337, 0.21875, 0.2855392156862745, None) - t: 0.2441718578338623
Loss: 3.7157018184661865 - score: (0.5625, 0.15625, 0.2269345238095238, None) - t: 0.25115060806274414
Loss: 3.429935932159424 - score: (0.625, 0.3125, 0.4063988095238095, None) - t: 0.24688172340393066
Loss: 2.9368233680725098 - score: (0.459375, 0.28125, 0.2927389705882353, None) - t: 0.25553154945373535
Loss: 3.8653032779693604 - score: (0.4166666666666667, 0.21875, 0.2765625, None) - t: 0.2543144226074219
Loss: 3.403931140899658 - score: (0.5625, 0.1875, 0.28125, None) - t: 0.24820590019226074
Loss: 3.388094186782837 - score: (0.484375, 0.3125, 0.3791666666666667, None) - t: 0.24448823928833008
Loss: 4.332608699798584 - score: (0.5625, 0.15625, 0.232

  8%|▊         | 2/25 [01:24<16:49, 43.87s/it]

Loss: 3.5074210166931152 - score: (0.2798029556650246, 0.2413793103448276, 0.25914315569487983, None) - t: 0.2216484546661377
Loss: 2.831212282180786 - score: (0.5880681818181819, 0.5, 0.5394345238095237, None) - t: 0.24596023559570312
Loss: 3.374818801879883 - score: (0.375, 0.3125, 0.328125, None) - t: 0.24685430526733398
Loss: 3.249868392944336 - score: (0.4375, 0.25, 0.3035714285714286, None) - t: 0.2457714080810547
Loss: 3.4846560955047607 - score: (0.5, 0.21875, 0.2991071428571429, None) - t: 0.24392938613891602
Loss: 2.5046236515045166 - score: (0.6875, 0.21875, 0.32028508771929826, None) - t: 0.24571633338928223
Loss: 3.4667458534240723 - score: (0.59375, 0.3125, 0.39943181818181817, None) - t: 0.24668121337890625
Loss: 2.3499085903167725 - score: (0.5738636363636364, 0.21875, 0.2525641025641026, None) - t: 0.24506139755249023
Loss: 2.952432155609131 - score: (0.3645833333333333, 0.28125, 0.296875, None) - t: 0.24825382232666016
Loss: 3.2663486003875732 - score: (0.34375, 0.156

 12%|█▏        | 3/25 [02:00<15:13, 41.53s/it]

Loss: 2.9095189571380615 - score: (0.3793103448275862, 0.2413793103448276, 0.24674329501915707, None) - t: 0.22798585891723633
Loss: 2.234834671020508 - score: (0.6875, 0.59375, 0.625, None) - t: 0.25049781799316406
Loss: 1.7666211128234863 - score: (0.78515625, 0.625, 0.6534801136363637, None) - t: 0.24587297439575195
Loss: 1.8658267259597778 - score: (0.734375, 0.59375, 0.6056547619047619, None) - t: 0.2461566925048828
Loss: 2.3991878032684326 - score: (0.5859375, 0.5625, 0.5636904761904762, None) - t: 0.24617218971252441
Loss: 1.7497574090957642 - score: (0.7535714285714286, 0.5, 0.5605696386946387, None) - t: 0.24708318710327148
Loss: 1.9351184368133545 - score: (0.703125, 0.59375, 0.638507326007326, None) - t: 0.24680399894714355
Loss: 1.6882929801940918 - score: (0.734375, 0.6875, 0.6979166666666666, None) - t: 0.24686884880065918
Loss: 1.8675575256347656 - score: (0.6597222222222222, 0.625, 0.6351338612368024, None) - t: 0.2461528778076172
Loss: 2.186174154281616 - score: (0.417

 16%|█▌        | 4/25 [02:36<13:55, 39.79s/it]

Loss: 1.712647557258606 - score: (0.5004926108374385, 0.4827586206896552, 0.4827586206896552, None) - t: 0.21894049644470215
Loss: 1.364761471748352 - score: (0.6822916666666666, 0.6875, 0.6749999999999999, None) - t: 0.24702119827270508
Loss: 1.5607997179031372 - score: (0.6197916666666666, 0.5625, 0.5625, None) - t: 0.24505376815795898
Loss: 1.4622180461883545 - score: (0.78125, 0.71875, 0.7317708333333333, None) - t: 0.2466127872467041
Loss: 1.080470085144043 - score: (0.9032738095238095, 0.71875, 0.7530381944444444, None) - t: 0.24831461906433105
Loss: 1.1821550130844116 - score: (0.75, 0.65625, 0.646780303030303, None) - t: 0.2473735809326172
Loss: 1.3590643405914307 - score: (0.6875, 0.59375, 0.6041666666666666, None) - t: 0.2474079132080078
Loss: 0.8535659909248352 - score: (0.875, 0.875, 0.875, None) - t: 0.24779987335205078
Loss: 1.1160372495651245 - score: (0.640625, 0.6875, 0.65625, None) - t: 0.24410557746887207
Loss: 1.030616044998169 - score: (0.75, 0.71875, 0.71875, None

 20%|██        | 5/25 [03:14<13:04, 39.25s/it]

Loss: 1.5568526983261108 - score: (0.5885057471264369, 0.5172413793103449, 0.541871921182266, None) - t: 0.21711516380310059
Loss: 0.5805186033248901 - score: (0.8308035714285714, 0.8125, 0.8118444055944056, None) - t: 0.24634790420532227
Loss: 0.6255047917366028 - score: (0.8958333333333333, 0.8125, 0.8291666666666666, None) - t: 0.24834609031677246
Loss: 0.4547421336174011 - score: (0.8576388888888888, 0.875, 0.8618607954545454, None) - t: 0.24651741981506348
Loss: 0.790121853351593 - score: (0.8125, 0.6875, 0.7277777777777777, None) - t: 0.24593591690063477
Loss: 0.3718580901622772 - score: (1.0, 0.9375, 0.9583333333333334, None) - t: 0.24796199798583984
Loss: 0.5318841934204102 - score: (0.93125, 0.875, 0.8794642857142857, None) - t: 0.24822211265563965
Loss: 0.5514596700668335 - score: (0.875, 0.875, 0.875, None) - t: 0.24714946746826172
Loss: 0.5096904039382935 - score: (0.90625, 0.875, 0.8854166666666666, None) - t: 0.24696135520935059
Loss: 0.557156503200531 - score: (0.8802083

 24%|██▍       | 6/25 [04:03<13:19, 42.08s/it]

Loss: 0.9551457166671753 - score: (0.7807881773399015, 0.7241379310344828, 0.735042735042735, None) - t: 0.21750783920288086
Loss: 0.37941503524780273 - score: (0.9114583333333334, 0.90625, 0.9024621212121211, None) - t: 0.24802446365356445
Loss: 0.438060462474823 - score: (0.8697916666666666, 0.875, 0.8662878787878787, None) - t: 0.24921870231628418
Loss: 0.31053751707077026 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24663162231445312
Loss: 0.2854231297969818 - score: (0.953125, 0.9375, 0.9395833333333334, None) - t: 0.24818038940429688
Loss: 0.2703598141670227 - score: (0.9090909090909091, 0.90625, 0.905952380952381, None) - t: 0.2446274757385254
Loss: 0.3579314351081848 - score: (0.90625, 0.90625, 0.90625, None) - t: 0.24680089950561523
Loss: 0.2849922776222229 - score: (0.9375, 0.9375, 0.9375, None) - t: 0.24619436264038086
Loss: 0.4448101818561554 - score: (0.90625, 0.875, 0.8854166666666666, None) - t: 0.2489621639251709
Loss: 0.4259703457355499 - score: (0.9125, 0.90625, 

 28%|██▊       | 7/25 [04:50<13:06, 43.72s/it]

Loss: 0.462769091129303 - score: (0.8706896551724138, 0.8620689655172413, 0.8633825944170772, None) - t: 0.221480131149292
Loss: 0.13564753532409668 - score: (0.9722222222222222, 0.96875, 0.9686274509803922, None) - t: 0.24392366409301758
Loss: 0.21275025606155396 - score: (0.94375, 0.90625, 0.9173611111111111, None) - t: 0.24715089797973633
Loss: 0.17208480834960938 - score: (0.9375, 0.90625, 0.9166666666666666, None) - t: 0.24441123008728027
Loss: 0.2037065178155899 - score: (0.9375, 0.9375, 0.9375, None) - t: 0.24598240852355957
Loss: 0.14573323726654053 - score: (0.9427083333333334, 0.96875, 0.9545454545454546, None) - t: 0.24740171432495117
Loss: 0.22354696691036224 - score: (0.9419642857142857, 0.9375, 0.9372814685314685, None) - t: 0.24541997909545898
Loss: 0.23931409418582916 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2532525062561035
Loss: 0.25269219279289246 - score: (0.90625, 0.90625, 0.90625, None) - t: 0.24529409408569336
Loss: 0.17829956114292145 - score: (0.96875,

 32%|███▏      | 8/25 [05:26<11:44, 41.44s/it]

Loss: 0.43434619903564453 - score: (0.8620689655172413, 0.8620689655172413, 0.8620689655172413, None) - t: 0.22272133827209473
Loss: 0.12587955594062805 - score: (0.9791666666666666, 0.96875, 0.96875, None) - t: 0.2488570213317871
Loss: 0.15642181038856506 - score: (0.96875, 0.9375, 0.95, None) - t: 0.2492365837097168
Loss: 0.20659469068050385 - score: (1.0, 0.96875, 0.9791666666666666, None) - t: 0.24801206588745117
Loss: 0.10934224724769592 - score: (1.0, 1.0, 1.0, None) - t: 0.24552106857299805
Loss: 0.21280711889266968 - score: (0.9114583333333334, 0.90625, 0.9067513368983957, None) - t: 0.2455294132232666
Loss: 0.2701762318611145 - score: (0.890625, 0.90625, 0.8958333333333333, None) - t: 0.24550127983093262
Loss: 0.1075071394443512 - score: (0.9765625, 0.96875, 0.9657738095238094, None) - t: 0.2496049404144287
Loss: 0.14857545495033264 - score: (1.0, 0.96875, 0.9791666666666666, None) - t: 0.24596333503723145
Loss: 0.054061055183410645 - score: (1.0, 1.0, 1.0, None) - t: 0.245852

 36%|███▌      | 9/25 [06:03<10:39, 39.95s/it]

Loss: 0.06699226796627045 - score: (1.0, 1.0, 1.0, None) - t: 0.21597599983215332
Loss: 0.1255986988544464 - score: (0.96875, 0.9375, 0.9479166666666666, None) - t: 0.24900197982788086
Loss: 0.10613685846328735 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2492363452911377
Loss: 0.05228722095489502 - score: (1.0, 1.0, 1.0, None) - t: 0.2484602928161621
Loss: 0.04145243763923645 - score: (1.0, 1.0, 1.0, None) - t: 0.2479393482208252
Loss: 0.12369778752326965 - score: (1.0, 1.0, 1.0, None) - t: 0.24870538711547852
Loss: 0.17049437761306763 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2509286403656006
Loss: 0.13100576400756836 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24902939796447754
Loss: 0.0510006844997406 - score: (1.0, 1.0, 1.0, None) - t: 0.25282812118530273
Loss: 0.05005127191543579 - score: (1.0, 1.0, 1.0, None) - t: 0.25153636932373047
Loss: 0.2562367022037506 - score: (0.875, 0.875, 0.875, None) - t: 0.24989604949951172
Loss: 0.15731438994407654 - score: (0.984

 40%|████      | 10/25 [06:41<09:49, 39.33s/it]

Loss: 0.08198501169681549 - score: (1.0, 1.0, 1.0, None) - t: 0.21700000762939453
Loss: 0.11997118592262268 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24536657333374023
Loss: 0.03383883833885193 - score: (1.0, 1.0, 1.0, None) - t: 0.24517226219177246
Loss: 0.1517198085784912 - score: (0.9140625, 0.9375, 0.9241071428571428, None) - t: 0.251049280166626
Loss: 0.064372718334198 - score: (1.0, 1.0, 1.0, None) - t: 0.25250887870788574
Loss: 0.10813972353935242 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2534759044647217
Loss: 0.04991424083709717 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24783754348754883
Loss: 0.08174008131027222 - score: (1.0, 1.0, 1.0, None) - t: 0.24945425987243652
Loss: 0.06963041424751282 - score: (0.975, 0.96875, 0.9690656565656566, None) - t: 0.24779033660888672
Loss: 0.0288560688495636 - score: (1.0, 1.0, 1.0, None) - t: 0.2524435520172119
Loss: 0.15585142374038696 - score: (0.9375, 0.9375, 0.9375, None) - t: 0.255598783493042
Loss: 0.1724381595

 44%|████▍     | 11/25 [07:19<09:05, 38.94s/it]

Loss: 0.18772754073143005 - score: (0.9655172413793104, 0.9655172413793104, 0.9655172413793104, None) - t: 0.21920514106750488
Loss: 0.06538054347038269 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2435166835784912
Loss: 0.0352228581905365 - score: (1.0, 1.0, 1.0, None) - t: 0.24964332580566406
Loss: 0.1934482455253601 - score: (0.9375, 0.9375, 0.9375, None) - t: 0.24395346641540527
Loss: 0.04766830801963806 - score: (1.0, 1.0, 1.0, None) - t: 0.24714279174804688
Loss: 0.12364715337753296 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2480325698852539
Loss: 0.04656180739402771 - score: (1.0, 1.0, 1.0, None) - t: 0.25034666061401367
Loss: 0.04773804545402527 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24703264236450195
Loss: 0.07963374257087708 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.2487959861755371
Loss: 0.06457695364952087 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24413824081420898
Loss: 0.10430148243904114 - score: (1.0, 0.96875, 0.9791666666666666, N

 48%|████▊     | 12/25 [08:02<08:42, 40.18s/it]

Loss: 0.12307015806436539 - score: (0.9482758620689655, 0.9655172413793104, 0.9540229885057471, None) - t: 0.21663522720336914
Loss: 0.03420951962471008 - score: (1.0, 1.0, 1.0, None) - t: 0.24928808212280273
Loss: 0.09547868371009827 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24723267555236816
Loss: 0.061547279357910156 - score: (1.0, 0.96875, 0.9791666666666666, None) - t: 0.24667692184448242
Loss: 0.07579755783081055 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24834275245666504
Loss: 0.0706055760383606 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24505925178527832
Loss: 0.09219726920127869 - score: (0.953125, 0.96875, 0.9583333333333333, None) - t: 0.24703240394592285
Loss: 0.015090852975845337 - score: (1.0, 1.0, 1.0, None) - t: 0.24846935272216797
Loss: 0.050667911767959595 - score: (0.953125, 0.96875, 0.9583333333333333, None) - t: 0.2546577453613281
Loss: 0.02308604121208191 - score: (1.0, 1.0, 1.0, None) - t: 0.25429677963256836
Loss: 0.05750301480293274 - sco

 52%|█████▏    | 13/25 [08:41<07:58, 39.85s/it]

Loss: 0.1816045492887497 - score: (0.9359605911330049, 0.9310344827586207, 0.930793344586448, None) - t: 0.21824860572814941
Loss: 0.12877006828784943 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.24834036827087402
Loss: 0.07056084275245667 - score: (1.0, 0.96875, 0.98125, None) - t: 0.254910945892334
Loss: 0.061870574951171875 - score: (0.9732142857142857, 0.96875, 0.9689102564102564, None) - t: 0.25023460388183594
Loss: 0.12833629548549652 - score: (1.0, 0.96875, 0.9791666666666666, None) - t: 0.25290727615356445
Loss: 0.03703814744949341 - score: (1.0, 1.0, 1.0, None) - t: 0.2543973922729492
Loss: 0.11107541620731354 - score: (0.9375, 0.9375, 0.9375, None) - t: 0.25595808029174805
Loss: 0.026928424835205078 - score: (1.0, 1.0, 1.0, None) - t: 0.2524862289428711
Loss: 0.033330559730529785 - score: (1.0, 1.0, 1.0, None) - t: 0.2510077953338623
Loss: 0.050582826137542725 - score: (1.0, 1.0, 1.0, None) - t: 0.24727201461791992
Loss: 0.06021009385585785 - score: (1.0, 1.0, 1.0, None)

 56%|█████▌    | 14/25 [09:27<07:38, 41.72s/it]

Loss: 0.09191510081291199 - score: (1.0, 0.9655172413793104, 0.981609195402299, None) - t: 0.21768689155578613
Loss: 0.03647094964981079 - score: (1.0, 1.0, 1.0, None) - t: 0.25031423568725586
Loss: 0.03871944546699524 - score: (1.0, 1.0, 1.0, None) - t: 0.2462151050567627
Loss: 0.04218283295631409 - score: (1.0, 1.0, 1.0, None) - t: 0.24758625030517578
Loss: 0.042114078998565674 - score: (1.0, 1.0, 1.0, None) - t: 0.2484893798828125
Loss: 0.038497358560562134 - score: (1.0, 1.0, 1.0, None) - t: 0.24848079681396484
Loss: 0.022426187992095947 - score: (1.0, 1.0, 1.0, None) - t: 0.2500476837158203
Loss: 0.10148191452026367 - score: (0.9419642857142857, 0.9375, 0.9379578754578755, None) - t: 0.24740123748779297
Loss: 0.08760307729244232 - score: (0.96875, 0.96875, 0.96875, None) - t: 0.25130605697631836
Loss: 0.022877246141433716 - score: (1.0, 1.0, 1.0, None) - t: 0.24778532981872559
Loss: 0.07037971913814545 - score: (0.984375, 0.96875, 0.96875, None) - t: 0.2473127841949463
Loss: 0.118

KeyboardInterrupt: 

In [None]:
print(type(tr_labels))
print(type(tr_preds))


In [None]:
print(tr_labels[0])
print(tr_preds[0])