In [1]:
import os
import sys
import time
import math

import dill
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from src import utils
from src import model2 as m2
import src.dataset as dset
import src.pytorch_utils as ptu
import src.chu_liu_edmonds as chu

import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed NumPy and PyTorch with the same value.
# NOTE(review): Python's built-in `random` module is not seeded in this cell.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner so repeated
# runs pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Directory name passed to ptu.Checkpoint as `versions_dir` and used to build
# the trials/log file paths further down.
versions_dir = 'models'

cuda


In [2]:
# train_dataset = dset.DataSet('data/train.labeled', tqdm_bar=True, use_glove=True)
# test_dataset = dset.DataSet('data/test.labeled', train_dataset=train_dataset, tqdm_bar=True, use_glove=True)
# comp_dataset = dset.DataSet('data/comp.unlabeled', train_dataset=train_dataset, tagged=False, use_glove=True, tqdm_bar=True)

# with open(os.path.join('data', 'train_dataset.pth'), 'wb') as f:
#     dill.dump(train_dataset, f)
# with open(os.path.join('data', 'test_dataset.pth'), 'wb') as f:
#     dill.dump(test_dataset, f)

In [3]:
# Load the pre-built, dill-serialised train and test datasets from `data/`.
# The commented-out cell above shows how these files were produced.
# NOTE(review): dill.load executes arbitrary code embedded in the file, so only
# load files that were produced locally / come from a trusted source.
with open(os.path.join('data', 'train_dataset.pth'), "rb") as f:
    train_dataset = dill.load(f)
with open(os.path.join('data', 'test_dataset.pth'), "rb") as f:
    test_dataset = dill.load(f)

In [4]:
# Identifier of this experiment; passed to ptu.Checkpoint as `version`.
version = 'V2_1.10'
# Passed as `save=` to both ptu.Checkpoint(...) and checkpoint.train(...).
save = True

In [5]:
# Build the parser model. The commented-out keyword lines are alternative
# settings kept for reference.
# NOTE(review): `glove=True, freeze=True` presumably loads pretrained GloVe
# embeddings and keeps them fixed (the printed summary reports fewer trainable
# than total parameters) - confirm in src/model2.py.
model = m2.Model2(train_dataset=train_dataset,
                  word_embed_dim=300,
                  tag_embed_dim=32,
                  hidden_dim=200,
                  num_layers=4,
                  bias=True,
                  attention_dim=200,
                  lstm_activation=None,
#                   attn_activation=nn.Tanh(),
                  p_dropout=0.4,  # 0.5
                  attention=utils.MultiplicativeAttention,
#                   softmax=nn.Softmax(dim=2),
                  softmax=nn.LogSoftmax(dim=2),
                  glove=True,
                  freeze=True)

# Wrap the model in a Checkpoint, which is later used for training (see the
# `checkpoint.train(...)` loop below).
# - `score` is the fraction of positions where the prediction equals the target.
# - `out_decision_func` turns model outputs into predictions; the active choice
#   is chu.test_chu_liu_edmonds, the commented-out alternative is a plain argmax.
# - `criterion=nn.NLLLoss` expects log-probabilities, which matches the
#   LogSoftmax passed to the model above.
# - `optimizer` / `criterion` are passed as classes, not instances.
checkpoint = ptu.Checkpoint(versions_dir=versions_dir,
                            version=version,
                            model=model,
                            score=lambda y_true, y_pred: (np.array(y_true) == np.array(y_pred)).mean(),
                            loss_decision_func=utils.loss_decision_func,
#                             out_decision_func=lambda y_pred, flat_y_pred, mask, padding: flat_y_pred.argmax(axis=1),
                            out_decision_func=chu.test_chu_liu_edmonds,
                            seed=42,
#                             optimizer=QHAdam,
                            optimizer=torch.optim.Adam,
#                             optimizer=torch.optim.AdamW,
                            criterion=nn.NLLLoss,
#                             criterion=nn.MSELoss,
                            save=save,
                            prints=True)

model version: V2_1.10
Number of parameters 8155736 trainable 3905936


In [6]:
# checkpoint = ptu.load_model(version=version, versions_dir=versions_dir, epoch=40, seed=42)
# display(checkpoint.log)

In [7]:
# First positional argument of `train_dataset.dataset(...)` below.
# NOTE(review): presumably the word-dropout strength - confirm in src/dataset.py.
word_dropout_alpha = 0.25
# Training schedule: each dict is one training session and is splatted into
# `checkpoint.train(**session)`. Sessions run back-to-back on the same
# checkpoint; commented-out entries are earlier/alternative schedules.
# NOTE(review): the meaning of 'lr_decay' is defined by ptu.Checkpoint.train - confirm there.
hyperparam_list = [
    {'train_epochs': 20, 'batch_size': 8, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 5e-7}},
    {'train_epochs': 20, 'batch_size': 8, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 1e-6}, 'lr_decay': 0.2},
#     {'train_epochs': 15, 'batch_size': 16, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 0.0}},  # 5e-7
    {'train_epochs': 20, 'batch_size': 8, 'optimizer_params': {'lr': 1e-3, 'weight_decay': 0.0}, 'lr_decay': 0.2},  # 1e-6
#     {'train_epochs': 30, 'batch_size': 32, 'optimizer_params': {}},  # 1e-6
#     {'train_epochs': 20, 'batch_size': 16, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 1e-5}, 'lr_decay': 0.2},
#     {'train_epochs': 20, 'batch_size': 32, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 1.5e-6}, 'lr_decay': 0.2},
#     {'train_epochs': 20, 'batch_size': 64, 'optimizer_params': {'lr': 2e-3, 'weight_decay': 1.5e-6}, 'lr_decay': 0.2},
]

# Run every session in order. The training view of the train set is rebuilt
# for each session; the test set is passed as the validation set.
for session in hyperparam_list:
    checkpoint.train(device=device,
                     train_dataset=train_dataset.dataset(word_dropout_alpha, train=True),
                     val_dataset=test_dataset.dataset(train=False),
                     prints=True,
                     epochs_save=5,
                     save=save,
#                      early_stop=5,
                     **session)

epoch   1/ 20 | train_loss 0.45094 | val_loss 0.48463 | train_score 0.86644 | val_score 0.85714 | train_time   1.02 min *
epoch   2/ 20 | train_loss 0.32350 | val_loss 0.40231 | train_score 0.89766 | val_score 0.87692 | train_time   2.60 min *
epoch   3/ 20 | train_loss 0.25138 | val_loss 0.37215 | train_score 0.91963 | val_score 0.88576 | train_time   3.90 min *
epoch   4/ 20 | train_loss 0.20477 | val_loss 0.37632 | train_score 0.93485 | val_score 0.88785 | train_time   5.22 min *
epoch   5/ 20 | train_loss 0.16689 | val_loss 0.37997 | train_score 0.94613 | val_score 0.89159 | train_time   6.43 min *
epoch   6/ 20 | train_loss 0.14055 | val_loss 0.40346 | train_score 0.95381 | val_score 0.89118 | train_time   7.90 min
epoch   7/ 20 | train_loss 0.11738 | val_loss 0.40196 | train_score 0.96157 | val_score 0.90195 | train_time   9.00 min *
epoch   8/ 20 | train_loss 0.10202 | val_loss 0.40818 | train_score 0.96722 | val_score 0.89517 | train_time  10.39 min
epoch   9/ 20 | train_loss 0

In [None]:
# NOTE(review): unfinished scratch cell, parked as a comment so the notebook
# survives "Restart & Run All". As written it could never execute:
#   * `scores_numpy = ` has no right-hand side (SyntaxError);
#   * `self.decoder`, `batch_size`, `max_sentence_len` and
#     `unsorted_sentences_lengths` are not defined at cell scope.
# TODO: either finish this as a proper decoding helper or delete the cell.
#
# scores_numpy = <TODO: model scores as a NumPy array>
# estimated_tree_heads = np.zeros((batch_size, max_sentence_len))
# for b in range(batch_size):
#     # sentence len is the unpadded sentence length - important when we add batches
#     estimated_tree_heads[b, :], _ = self.decoder(scores_numpy[b, :, :], unsorted_sentences_lengths[b], has_labels=False)
            

In [None]:
def unlabeled_attachment_score(scores, heads, lengths):
    """Accumulate per-sentence head-prediction accuracy over a batch.

    For every sentence in the batch the score matrix is decoded with
    ``decode_mst`` and the decoded heads are compared with the gold heads over
    the first ``length`` positions.

    Note that the return value is the *sum* of the per-sentence accuracies,
    not their mean; the caller has to divide by the number of sentences if an
    average is wanted.

    ``decode_mst`` must be available in the enclosing namespace; it is neither
    defined nor imported in this cell.

    Args:
        scores: indexable as ``scores[b, :, :]``; each slice must support
            ``.detach().cpu().numpy()`` (i.e. a torch tensor).
        heads: indexable as ``heads[b, i]`` with elements supporting ``.item()``.
        lengths: iterable of per-sentence lengths supporting ``.item()``.

    Returns:
        Sum over sentences of (matching heads / sentence length).
    """
    total = 0
    for idx, n_tokens in enumerate(lengths):
        n_tokens = n_tokens.item()
        sentence_scores = scores[idx, :, :].detach().cpu().numpy()
        predicted_heads, _ = decode_mst(sentence_scores, n_tokens, has_labels=False)
        matches = [predicted_heads[pos] == heads[idx, pos].item() for pos in range(n_tokens)]
        total += sum(matches) / n_tokens
    return total


In [None]:
def run_chu_li(y_pred, flat_y_pred, mask, padding):
    """Unfinished decision-function stub.

    Takes the same four arguments as the ``out_decision_func`` callables used
    elsewhere in this notebook, but currently only reduces ``mask`` along
    dimension 1 and then falls through, so it always returns ``None``.
    """
    # Row-wise sum of the mask; computed but not used yet.
    row_totals = mask.sum(dim=1)


In [4]:
# Experiment identifier for the hyper-parameter search; used below to build
# the `<versions_dir>/<version>/trials.pth` and `trials_log.csv` paths.
version = 'V2_hpo_1.0'

In [5]:
# Name -> attention class registry. The search space samples one of these keys
# and init_objective looks the class up by name.
attentions = {
    'Additive': utils.AdditiveAttention,
    'Multiplicative': utils.MultiplicativeAttention,
}

# Candidate output normalisations, looked up by name.
# Only LogSoftmax is active; plain `nn.Softmax(dim=2)` is currently disabled.
softmaxs = dict(LogSoftmax=nn.LogSoftmax(dim=2))

# Candidate activation modules, looked up by name and stored in alphabetical
# key order.
# Disabled candidates: relu, elu, leaky_relu, relu6, gelu, sigmoid.
# NOTE(review): every value is a single module instance created once here.
# nn.PReLU carries a learnable weight, so handing this same instance to several
# models shares that weight between them - clone it if that is not intended.
activations = {
    label: module
    for label, module in sorted(
        {
            'tanh': nn.Tanh(),
            'hard_tanh': nn.Hardtanh(),
            'p_relu': nn.PReLU(),
        }.items(),
        key=lambda entry: entry[0],
    )
}

In [6]:
import hyperopt as hpo

In [7]:
# Hyperopt search space, stored with its keys in alphabetical order.
# Plain constants are fixed settings; `hpo.hp.*` expressions are sampled per
# trial. `quniform` yields floats, so consumers cast those entries with int().
# Trailing comments record disabled alternatives / earlier values.
init_space = {
    name: setting
    for name, setting in sorted(
        {
            # --- optimisation ---
            # 'train_epochs': 50,
            'batch_size': 16,  # alt: hpo.hp.quniform('batch_size', low=4, high=5, q=1)  # 16-32-64
            'optimizer__lr': hpo.hp.uniform('optimizer__lr', low=8e-4, high=2e-3),
            'optimizer__wd': hpo.hp.uniform('optimizer__wd', low=5e-7, high=5e-6),  # 0.0
            # 'early_stop': 5,

            # --- architecture ---
            'word_embed_dim': 300,  # 300
            'tag_embed_dim': 32,  # alt: hpo.hp.quniform('tag_embed_dim', low=30, high=50, q=4)  # 25
            'hidden_dim': hpo.hp.quniform('hidden_dim', low=200, high=300, q=50),  # 125
            'num_layers': hpo.hp.quniform('num_layers', low=3, high=4, q=1),  # 2
            'bias': True,  # alt: hpo.hp.choice('bias', [True, False])
            'attention_dim': hpo.hp.quniform('attention_dim', low=200, high=300, q=50),  # 100
            'attention': hpo.hp.choice('attention', list(attentions)),
            'activation': hpo.hp.choice('activation', list(activations)),
            'softmax': hpo.hp.choice('softmax', list(softmaxs)),

            # --- regularisation / schedule ---
            'p_dropout': hpo.hp.uniform('p_dropout', low=0.3, high=0.6),  # 0.1
            'lr_decay': hpo.hp.uniform('lr_decay', low=0.15, high=0.25),  # 0.1
            'freeze': True,  # alt: hpo.hp.choice('freeze', [True, False])
        }.items(),
        key=lambda entry: entry[0],
    )
}

def init_objective(space, save=False):
    """Hyperopt objective: train one model for a sampled configuration.

    Builds a fresh ``m2.Model2`` and ``ptu.Checkpoint`` from ``space``, runs the
    two-session training schedule, logs the outcome and returns the loss that
    hyperopt minimises.

    Parameters
    ----------
    space : dict
        One sample drawn from ``init_space``. Entries produced by
        ``hp.quniform`` arrive as floats and are cast with ``int()`` here.
    save : bool, optional
        Forwarded to ``Checkpoint.train(save=...)``. The checkpoint itself is
        always constructed with ``save=False``.

    Returns
    -------
    float
        ``-test_score``, where ``test_score`` is the ``val_score`` logged for
        the last epoch (``epoch=-1``).

    Side effects
    ------------
    Relies on and mutates notebook globals: appends a row to ``init_log``,
    writes ``init_trials`` to ``<versions_dir>/<version>/trials.pth`` and
    ``init_log`` to ``trials_log.csv``, and calls ``init_checkpoint.save`` when
    the new score beats the best ``test_score`` already in ``init_log``.
    """
    import copy  # function-scope import: only used to clone the shared activation module

    display(space)
    last_score = init_log['test_score'].max() if len(init_log) > 0 else 0.0

    # `init_space['batch_size']` holds the batch size itself (16), not an
    # exponent. The former `int(2 ** space['batch_size'])` evaluated to 65536
    # and was never used, while the sessions below hard-coded 16; the sampled
    # value is now the one actually used (and the one recorded in the log).
    batch_size = int(space['batch_size'])

    # Clone the activation: `activations` stores one module instance per name,
    # and nn.PReLU has a learnable weight. Without the copy every trial that
    # samples the same activation would keep training the weight left behind
    # by the previous trial. The softmax modules have no parameters, so they
    # are passed through unchanged.
    activation = copy.deepcopy(activations[space['activation']])

    model = m2.Model2(train_dataset=train_dataset,
                      word_embed_dim=space['word_embed_dim'],
                      tag_embed_dim=space['tag_embed_dim'],
                      hidden_dim=int(space['hidden_dim']),
                      num_layers=int(space['num_layers']),
                      bias=space['bias'],
                      attention_dim=int(space['attention_dim']),
                      activation=activation,
                      p_dropout=space['p_dropout'],
                      attention=attentions[space['attention']],
                      softmax=softmaxs[space['softmax']],
                      glove=True,
                      freeze=space['freeze'])

    # During the search, predictions come from a plain argmax over the
    # flattened scores.
    init_checkpoint = ptu.Checkpoint(versions_dir=versions_dir,
                                     version=version,
                                     model=model,
                                     score=lambda y_true, y_pred: (np.array(y_true) == np.array(y_pred)).mean(),
                                     loss_decision_func=utils.loss_decision_func,
                                     out_decision_func=lambda y_pred, flat_y_pred, mask, padding: flat_y_pred.argmax(axis=1),
                                     seed=42,
                                     optimizer=torch.optim.AdamW,
                                     criterion=nn.NLLLoss,
                                     save=False,
                                     prints=True)

    word_dropout_alpha = 0.25
    # Two back-to-back sessions: the first with a fixed weight decay, the
    # second with the sampled weight decay and learning-rate decay.
    hyperparam_list = [
        {'train_epochs': 20,
         'batch_size': batch_size,
         'optimizer_params': {'lr': space['optimizer__lr'], 'weight_decay': 5e-7}},
        {'train_epochs': 20,
         'batch_size': batch_size,
         'optimizer_params': {'lr': space['optimizer__lr'], 'weight_decay': space['optimizer__wd']},
         'lr_decay': space['lr_decay']},
    ]

    for session in hyperparam_list:
        init_checkpoint.train(device=device,
                              train_dataset=train_dataset.dataset(word_dropout_alpha, train=True),
                              val_dataset=test_dataset.dataset(train=False),
                              prints=True,
                              epochs_save=5,
                              save=save,
                              **session)

    # train_score is fetched but not logged at the moment (its log column is
    # disabled); the call is kept so the lookup behaviour stays unchanged.
    train_score = init_checkpoint.get_log(col='train_score', epoch=-1)
    test_score = init_checkpoint.get_log(col='val_score', epoch=-1)

    # Persist the checkpoint only when this trial beats the best logged score.
    if test_score > last_score:
        init_checkpoint.save(epoch=True)

    # Append one row to the log. The values are positional, so this relies on
    # `space` iterating in the same order as the hyper-parameter columns of
    # `init_log`.
    next_row = init_log.index.max() + 1 if len(init_log) > 0 else 0
    init_log.loc[next_row] = [time.strftime('%d-%m-%Y %H:%M:%S'),
                              test_score,
                              space] + list(space.values())

    # Write trials and log after every evaluation so an interrupted search can
    # be resumed from disk.
    with open(os.path.join(versions_dir, version, 'trials.pth'), 'wb') as f:
        dill.dump(init_trials, f)
    init_log.to_csv(os.path.join(versions_dir, version, 'trials_log.csv'), index=False)

    # hyperopt minimises, so return the negated score.
    return -test_score

# session_space = dict(sorted(list({
#     'train_epochs': 5,
#     'batch_size_mult': min(len(X_train), int(2**hpo.hp.quniform('batch_size_mult', low=5, high=9, q=1))),
#     'optimizer__lr_mult': hpo.hp.uniform('optimizer__lr_mult', low=1e-5, high=1e-3),
#     'optimizer__weight_decay': hpo.hp.uniform('optimizer__weight_decay', low=1e-5, high=1e-3),
#     'p_dropout': max(0.0, min(0.9, hpo.hp.normal('p_dropout', mu=0.5, sigma=0.15))),
# }.items()), key=lambda x: x[0]))

In [8]:
# Bootstrap code (commented out): creates an empty Trials object and an empty
# log frame and writes both to disk. Presumably meant to be run once for a
# brand-new `version`, so that the loading code below has files to read.
# init_trials = hpo.Trials()
# init_log = pd.DataFrame(columns=['timestamp',
#                                  # 'train_score',
#                                  'test_score',
#                                  'space'] + list(init_space.keys()))

# with open(os.path.join(versions_dir, version, 'trials.pth'), 'wb') as f:
#     dill.dump(init_trials, f)
# init_log.to_csv(os.path.join(versions_dir, version, 'trials_log.csv'), index=False)

# Load the trials object and the trial log stored under
# `<versions_dir>/<version>/`, then show the log.
# NOTE(review): dill.load executes arbitrary code embedded in the file, so only
# load files that were produced locally / come from a trusted source.
with open(os.path.join(versions_dir, version, 'trials.pth'), "rb") as f:
    init_trials = dill.load(f)
init_log = pd.read_csv(os.path.join(versions_dir, version, 'trials_log.csv'))
display(init_log)

Unnamed: 0,timestamp,test_score,space,activation,attention,attention_dim,batch_size,bias,freeze,hidden_dim,lr_decay,num_layers,optimizer__lr,optimizer__wd,p_dropout,softmax,tag_embed_dim,word_embed_dim
0,27-06-2020 00:29:23,0.907996,"{'activation': 'tanh', 'attention': 'Multiplic...",tanh,Multiplicative,250.0,16,True,False,250.0,0.182835,4.0,0.001348,4.488547e-06,0.327025,LogSoftmax,32,300
1,27-06-2020 01:23:47,0.905858,"{'activation': 'hard_tanh', 'attention': 'Addi...",hard_tanh,Additive,250.0,16,True,True,200.0,0.160718,3.0,0.001237,6.287681e-07,0.400194,LogSoftmax,32,300
2,27-06-2020 01:38:03,0.905612,"{'activation': 'tanh', 'attention': 'Multiplic...",tanh,Multiplicative,250.0,16,True,False,200.0,0.232751,4.0,0.001468,2.992325e-06,0.431738,LogSoftmax,32,300
3,27-06-2020 01:49:03,0.911531,"{'activation': 'p_relu', 'attention': 'Multipl...",p_relu,Multiplicative,250.0,16,True,True,200.0,0.165322,3.0,0.001895,3.360282e-06,0.515076,LogSoftmax,32,300
4,27-06-2020 02:45:22,0.895005,"{'activation': 'tanh', 'attention': 'Additive'...",tanh,Additive,250.0,16,True,False,250.0,0.216024,3.0,0.001702,4.9402e-06,0.475179,LogSoftmax,32,300
5,27-06-2020 03:33:44,0.897842,"{'activation': 'tanh', 'attention': 'Additive'...",tanh,Additive,200.0,16,True,True,200.0,0.170843,4.0,0.001632,2.365344e-06,0.426961,LogSoftmax,32,300


In [9]:
# Upper bound on the number of trials, passed to fmin as `max_evals`.
iters = 500

# Run the TPE search over `init_space`, minimising `init_objective`.
# The previously loaded `init_trials` object is passed in, so the search
# continues from the trials it already holds (the saved progress bar starts at
# 5/500 rather than 0/500).
# The return value (best parameters found) is discarded; results are read from
# `init_trials` / `init_log` instead.
_ = hpo.fmin(init_objective,
             init_space,
             algo=hpo.tpe.suggest,
             trials=init_trials,
             max_queue_len=1,
             max_evals=iters)

  1%|          | 5/500 [00:00<?, ?trial/s, best loss=?]

{'activation': 'hard_tanh',
 'attention': 'Additive',
 'attention_dim': 250.0,
 'batch_size': 16,
 'bias': True,
 'freeze': False,
 'hidden_dim': 300.0,
 'lr_decay': 0.18403316927816465,
 'num_layers': 3.0,
 'optimizer__lr': 0.0010580590955336128,
 'optimizer__wd': 3.0539062944273563e-06,
 'p_dropout': 0.542777185183592,
 'softmax': 'LogSoftmax',
 'tag_embed_dim': 32,
 'word_embed_dim': 300}

model version:                                         
V2_hpo_1.0                                                     
Number of parameters 10403287 trainable 10403287               
epoch   1/ 20 | train_loss 0.64693 | val_loss 0.67034 | train_score 0.81441 | val_score 0.81052 | train_time   1.46 min *
epoch   2/ 20 | train_loss 0.41799 | val_loss 0.49972 | train_score 0.87624 | val_score 0.85427 | train_time   2.91 min *
epoch   3/ 20 | train_loss 0.30594 | val_loss 0.44829 | train_score 0.90902 | val_score 0.86643 | train_time   4.37 min *
epoch   4/ 20 | train_loss 0.23014 | val_loss 0.43653 | train_score 0.93058 | val_score 0.87404 | train_time   5.82 min *
epoch   5/ 20 | train_loss 0.18785 | val_loss 0.45082 | train_score 0.94163 | val_score 0.87963 | train_time   7.28 min *
epoch   6/ 20 | train_loss 0.15673 | val_loss 0.48232 | train_score 0.95123 | val_score 0.87437 | train_time   8.73 min
epoch   7/ 20 | train_loss 0.12603 | val_loss 0.49341 | train_score 0.96168 | val_scor

{'activation': 'tanh',
 'attention': 'Multiplicative',
 'attention_dim': 300.0,
 'batch_size': 16,
 'bias': True,
 'freeze': False,
 'hidden_dim': 250.0,
 'lr_decay': 0.20709048937348762,
 'num_layers': 3.0,
 'optimizer__lr': 0.0009120247387124067,
 'optimizer__wd': 2.718449748748443e-06,
 'p_dropout': 0.45152728232954664,
 'softmax': 'LogSoftmax',
 'tag_embed_dim': 32,
 'word_embed_dim': 300}

model version:                                                                          
V2_hpo_1.0                                                                              
Number of parameters 8677836 trainable 8677836                                          
epoch   1/ 20 | train_loss 0.56802 | val_loss 0.60253 | train_score 0.83325 | val_score 0.82565 | train_time   0.32 min *
epoch   2/ 20 | train_loss 0.34190 | val_loss 0.45287 | train_score 0.89566 | val_score 0.86557 | train_time   0.65 min *
epoch   3/ 20 | train_loss 0.23812 | val_loss 0.42762 | train_score 0.92393 | val_score 0.87572 | train_time   0.97 min *
epoch   4/ 20 | train_loss 0.17725 | val_loss 0.43899 | train_score 0.94189 | val_score 0.87836 | train_time   1.30 min *
epoch   5/ 20 | train_loss 0.14049 | val_loss 0.45358 | train_score 0.95399 | val_score 0.88185 | train_time   1.62 min *
epoch   6/ 20 | train_loss 0.11512 | val_loss 0.51391 | train_score 0.96068 | val_score 0.88004 | train_time   1.95 min
epo

{'activation': 'p_relu',
 'attention': 'Multiplicative',
 'attention_dim': 200.0,
 'batch_size': 16,
 'bias': True,
 'freeze': False,
 'hidden_dim': 250.0,
 'lr_decay': 0.15360017588134772,
 'num_layers': 4.0,
 'optimizer__lr': 0.0009106221871915555,
 'optimizer__wd': 9.831547469304713e-07,
 'p_dropout': 0.338099438264459,
 'softmax': 'LogSoftmax',
 'tag_embed_dim': 32,
 'word_embed_dim': 300}

model version:                                                                            
V2_hpo_1.0                                                                                
Number of parameters 10181836 trainable 10181836                                          
epoch   1/ 20 | train_loss 0.58668 | val_loss 0.61920 | train_score 0.82878 | val_score 0.82121 | train_time   0.42 min *
epoch   2/ 20 | train_loss 0.36370 | val_loss 0.46898 | train_score 0.88808 | val_score 0.85829 | train_time   0.84 min *
epoch   3/ 20 | train_loss 0.25032 | val_loss 0.44591 | train_score 0.91955 | val_score 0.86754 | train_time   1.26 min *
epoch   4/ 20 | train_loss 0.19306 | val_loss 0.44727 | train_score 0.93754 | val_score 0.87503 | train_time   1.68 min *
epoch   5/ 20 | train_loss 0.13996 | val_loss 0.46890 | train_score 0.95357 | val_score 0.88103 | train_time   2.10 min *
epoch   6/ 20 | train_loss 0.12647 | val_loss 0.49298 | train_score 0.95868 | val_score 0.88070 | train_time   2.52 m

{'activation': 'p_relu',
 'attention': 'Additive',
 'attention_dim': 300.0,
 'batch_size': 16,
 'bias': True,
 'freeze': False,
 'hidden_dim': 250.0,
 'lr_decay': 0.24299504686523796,
 'num_layers': 3.0,
 'optimizer__lr': 0.0016816180726068677,
 'optimizer__wd': 1.4125481325876369e-06,
 'p_dropout': 0.45658280816360297,
 'softmax': 'LogSoftmax',
 'tag_embed_dim': 32,
 'word_embed_dim': 300}

model version:                                                                            
V2_hpo_1.0                                                                                
Number of parameters 8728238 trainable 8728238                                            
epoch   1/ 20 | train_loss 0.96454 | val_loss 0.99111 | train_score 0.71662 | val_score 0.71293 | train_time   2.00 min *
epoch   2/ 20 | train_loss 0.45434 | val_loss 0.53772 | train_score 0.86607 | val_score 0.84423 | train_time   3.99 min *
epoch   3/ 20 | train_loss 0.29577 | val_loss 0.47016 | train_score 0.91009 | val_score 0.86610 | train_time   5.98 min *
epoch   4/ 20 | train_loss 0.22262 | val_loss 0.46843 | train_score 0.92974 | val_score 0.87342 | train_time   7.97 min *
epoch   5/ 20 | train_loss 0.16670 | val_loss 0.47030 | train_score 0.94697 | val_score 0.88008 | train_time   9.97 min *
epoch   6/ 20 | train_loss 0.13503 | val_loss 0.49350 | train_score 0.95749 | val_score 0.88296 | train_time  11.96 m

{'activation': 'p_relu',
 'attention': 'Additive',
 'attention_dim': 300.0,
 'batch_size': 16,
 'bias': True,
 'freeze': False,
 'hidden_dim': 250.0,
 'lr_decay': 0.1847636931038067,
 'num_layers': 4.0,
 'optimizer__lr': 0.0017188952953206265,
 'optimizer__wd': 4.995869558577694e-06,
 'p_dropout': 0.5588284918374005,
 'softmax': 'LogSoftmax',
 'tag_embed_dim': 32,
 'word_embed_dim': 300}

model version:                                                                            
V2_hpo_1.0                                                                                
Number of parameters 10232238 trainable 10232238                                          
epoch   1/ 20 | train_loss 1.84529 | val_loss 1.85188 | train_score 0.40756 | val_score 0.40835 | train_time   2.09 min *
epoch   2/ 20 | train_loss 0.62279 | val_loss 0.67177 | train_score 0.82357 | val_score 0.81135 | train_time   4.18 min *
  2%|▏         | 9/500 [2:52:29<156:50:16, 1149.93s/trial, best loss: -0.9115313463514902]


KeyboardInterrupt: 

In [33]:
# Manually persist the current trials object and trial log to
# `<versions_dir>/<version>/` (the same two files init_objective writes after
# each evaluation).
with open(os.path.join(versions_dir, version, 'trials.pth'), 'wb') as f:
    dill.dump(init_trials, f)
init_log.to_csv(os.path.join(versions_dir, version, 'trials_log.csv'), index=False)

In [None]:
# checkpoint = ptu.load_model(version=version, versions_dir=versions_dir, epoch='best', seed=42)
# loss, score = checkpoint.predict(test_dataset.dataset,
#                                  batch_size=32,
#                                  device=device,
#                                  results=False,
#                                  decision_func=chu.test_chu_liu_edmonds)
# print(f'chu_liu_edmonds_UAS: {score}')

In [None]:
# %%time
# checkpoint.model = checkpoint.model.to(device)
# checkpoint.model.train()
# batch_size = 32

# loader = torch.utils.data.DataLoader(dataset=train_dataset.dataset, batch_size=batch_size, shuffle=True)
# for batch in loader:
#     loss, flat_y, flat_out, mask, out, y = utils.loss_decision_func(checkpoint, device, batch, prints=True)
#     break
# torch.cuda.empty_cache()