In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from model.pytorchtools import EarlyStopping
import model.net as models 
from model.dataset import SurfaceComplexationDataset
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import numpy as np

In [2]:
def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    return optimizer, scheduler

In [3]:
def load_data(data_dir): 
    train_set = SurfaceComplexationDataset(root_dir=data_dir)
    test_set = SurfaceComplexationDataset(root_dir=data_dir, split='test')
    val_set = SurfaceComplexationDataset(root_dir=data_dir, split='val')

    return train_set, test_set, val_set

In [4]:
def train_epoch(train_loader, model, optimizer, device, epoch):
    """ Train the model on num_steps batches 
    Args: 
        train_loader: a torch.utils.data.DataLoader object that fetches the data
        model: the neural network 
        optimizer: adams 
    """
    model.train()
    running_loss = 0.0
    num_batch = len(train_loader)

    for i, (inputs, targets) in enumerate(train_loader): 
        inputs, targets = inputs.to(device), targets.to(device)
        # zero the paramter gradients 
        optimizer.zero_grad()

        # forward + backward + optimize 
        pred = model(inputs)
        loss = F.mse_loss(pred, targets)
        loss.backward()
        optimizer.step()

        # print statistics 
        running_loss += loss.item()
        # if i % 300 == 0: 
        #     print('[%d: %d/%d] train loss: %f ' % (epoch, i, num_batch, loss.item()))
        # if i % 300 == 0: 
        #     print('[%d: %d/%d] train loss: %f lr = %f' % (epoch, i, num_batch, loss.item(), optimizer.param_groups[0]["lr"]))

    return running_loss / num_batch 

In [5]:
def validate(val_dataloader, model, device): 
    model.eval()
    val_running_loss = 0.0 

    with torch.no_grad(): 
        for inputs, targets in val_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = F.mse_loss(outputs, targets)

            val_running_loss += loss.item() * inputs.size(0)

    return val_running_loss / len(val_dataloader.dataset)

In [6]:
def plot_pramas(test_y, test_pred, foldername, outfile): 
    print("R2 of test is: ", r2_score(test_y, test_pred))

    np.savetxt(f'{foldername}/test_pred_{outfile}.txt', test_pred)
    np.savetxt(f'{foldername}/test_y_{outfile}.txt', test_y)

    test_mse = mean_squared_error(test_y, test_pred)
    test_mae = mean_absolute_error(test_y, test_pred)

    print('Test set results for %i samples:' % test_pred.shape[0])
    print('MSE:', test_mse)
    print('MAE:', test_mae)

In [7]:
def test_accuracy(net, testloader, foldername, outfile, device): 
    test_pred = []
    test_y = []
    running_loss = 0 

    with torch.no_grad():
        for data in testloader:
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = F.mse_loss(outputs, targets)

            running_loss += loss.item() 

            pred_val_numpy = outputs.data.cpu().numpy()
            target_val_numpy = targets.data.cpu().numpy()

            test_pred.append(pred_val_numpy)
            test_y.append(target_val_numpy)

    test_pred = np.concatenate(test_pred, axis=0)
    test_y = np.concatenate(test_y, axis=0)

    plot_pramas(test_y, test_pred, foldername, outfile)

    print('MSE loss on test set is:', running_loss / len(testloader.dataset)) 

In [8]:
def train_model(model, device, train_loader, val_loader, test_loader, optimizer, lr_scheduler, isSch, res_dir, name, patience = 20, n_epochs = 100): 
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 

    blue = lambda x: '\033[94m' + x + '\033[0m'
    
    checkpoint_dir = os.path.join(res_dir, 'checkpoints')
    try:
        os.makedirs(res_dir)
        os.makedirs(checkpoint_dir)
    except OSError:
        pass

    checkpoint_path = os.path.join(checkpoint_dir, f'{name}.pt')
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True, path = checkpoint_path)

    for epoch in tqdm(range(1, n_epochs + 1)):
        ###################
        # train the model #
        ###################
        train_epoch_loss = train_epoch(train_loader, model, optimizer, device, epoch)
        val_epoch_loss = validate(val_loader, model, device)

        # print loss every epoch 
        print('[%d] train loss: %f ' % (epoch, train_epoch_loss))
        print('[%d] %s loss: %f' % (epoch, blue('validate'), val_epoch_loss))

        avg_train_losses.append(train_epoch_loss)
        avg_valid_losses.append(val_epoch_loss)
        
        if isSch: 
            lr_scheduler.step(val_epoch_loss) 
        
        # add early stopping 
        # early_stopping(val_epoch_loss, model)
        early_stopping(train_epoch_loss, model)
        if early_stopping.early_stop: 
            print("Early stopping")
            break 

    np.savetxt(os.path.join(res_dir, f'train_loss_{name}.csv'), avg_train_losses)
    np.savetxt(os.path.join(res_dir, f'val_loss_{name}.csv'), avg_valid_losses) 

    # load the last checkpoint with the best model
    model.load_state_dict(torch.load(checkpoint_path)) 

    # test on test set 
    test_accuracy(model, test_loader, res_dir, name, device)
    # print(optimizer.state_dict())

In [9]:
def train_main(config): 
    data_dir = 'datasets/train/'

    # get dataset 
    train_set, val_set, test_set = load_data(data_dir)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=4, 
        pin_memory=False)

    val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=int(config["batch_size"]),
            shuffle=True,
            num_workers=4, 
            pin_memory=False)

    test_loader = torch.utils.data.DataLoader(
            test_set, 
            batch_size=int(config["batch_size"]), 
            shuffle=True,
            num_workers=4, 
            pin_memory=False)
            
    print("Creating model")
    Model = getattr(models, config['model'])
    print('created model is: ', Model)
    
    if config['model'] == 'DeepNet4LayerTune': 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                      config["l1"], config["l2"], config["l3"],)
    elif config['model'] == 'DeepNet5LayerTune': 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                     config["l1"], config["l2"], config["l3"], config["l4"])
    else: 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                     config["l1"], config["l2"], config["l3"], config["l4"], config["l5"])
        
    name = f"{config['model']}_171inputs_{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['batch_norm']}ln{config['layer_norm']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}"
    

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cuda")
    # if torch.cuda.device_count() > 1:
    #     model = nn.DataParallel(model)
    model.to(device)

    optimizer, lr_scheduler = build_optimizer(model, config['optimizer'], config['lr'])
    res_dir = 'simpleDNN/res'
    
    train_model(model, device, train_loader, val_loader, test_loader, optimizer, lr_scheduler, config['lr_scheduler'], res_dir, name, 40, 5000)

# train DNN 

In [10]:
config = {'l1': 256, 'l2': 256, 'l3': 256, 'l4': 64, 'l5': 16, 
          'lr': 0.001, 'batch_size': 128, 'model': 'DeepNet5LayerTune', 'batch_norm': False, 
          'layer_norm': True, 'lr_scheduler': True, 'constraint': False, 'optimizer': 'adam'}

In [11]:
import time 
start_time = time.perf_counter()
train_main(config) 
end_time = time.perf_counter() 
print('time used to train model with 40/1000 patience is: ', (end_time - start_time)/60, 'mins')

Creating model
created model is:  <class 'model.net.DeepNet5LayerTune'>


  0%|                                                                                                                                  | 1/5000 [00:03<5:32:28,  3.99s/it]

[1] train loss: 0.029184 
[1] [94mvalidate[0m loss: 0.025162
Validation loss decreased (inf --> 0.029184).  Saving model ...


  0%|                                                                                                                                  | 2/5000 [00:07<5:06:54,  3.68s/it]

[2] train loss: 0.020917 
[2] [94mvalidate[0m loss: 0.017738
Validation loss decreased (0.029184 --> 0.020917).  Saving model ...


  0%|                                                                                                                                  | 3/5000 [00:11<5:02:18,  3.63s/it]

[3] train loss: 0.017796 
[3] [94mvalidate[0m loss: 0.014830
Validation loss decreased (0.020917 --> 0.017796).  Saving model ...


  0%|                                                                                                                                  | 4/5000 [00:14<4:58:05,  3.58s/it]

[4] train loss: 0.016424 
[4] [94mvalidate[0m loss: 0.017269
Validation loss decreased (0.017796 --> 0.016424).  Saving model ...


  0%|▏                                                                                                                                 | 5/5000 [00:18<4:55:26,  3.55s/it]

[5] train loss: 0.015590 
[5] [94mvalidate[0m loss: 0.017308
Validation loss decreased (0.016424 --> 0.015590).  Saving model ...


  0%|▏                                                                                                                                 | 6/5000 [00:21<4:54:29,  3.54s/it]

[6] train loss: 0.014566 
[6] [94mvalidate[0m loss: 0.013246
Validation loss decreased (0.015590 --> 0.014566).  Saving model ...


  0%|▏                                                                                                                                 | 7/5000 [00:25<4:56:47,  3.57s/it]

[7] train loss: 0.014194 
[7] [94mvalidate[0m loss: 0.014126
Validation loss decreased (0.014566 --> 0.014194).  Saving model ...


  0%|▏                                                                                                                                 | 8/5000 [00:28<4:53:01,  3.52s/it]

[8] train loss: 0.013604 
[8] [94mvalidate[0m loss: 0.015745
Validation loss decreased (0.014194 --> 0.013604).  Saving model ...


  0%|▏                                                                                                                                 | 9/5000 [00:31<4:48:15,  3.47s/it]

[9] train loss: 0.013263 
[9] [94mvalidate[0m loss: 0.013065
Validation loss decreased (0.013604 --> 0.013263).  Saving model ...


  0%|▎                                                                                                                                | 10/5000 [00:35<4:51:02,  3.50s/it]

[10] train loss: 0.013174 
[10] [94mvalidate[0m loss: 0.011421
Validation loss decreased (0.013263 --> 0.013174).  Saving model ...


  0%|▎                                                                                                                                | 11/5000 [00:38<4:49:32,  3.48s/it]

[11] train loss: 0.012975 
[11] [94mvalidate[0m loss: 0.012769
Validation loss decreased (0.013174 --> 0.012975).  Saving model ...


  0%|▎                                                                                                                                | 12/5000 [00:42<4:52:27,  3.52s/it]

[12] train loss: 0.012989 
[12] [94mvalidate[0m loss: 0.012607
EarlyStopping counter: 1 out of 40


  0%|▎                                                                                                                                | 13/5000 [00:46<4:54:19,  3.54s/it]

[13] train loss: 0.012699 
[13] [94mvalidate[0m loss: 0.012033
Validation loss decreased (0.012975 --> 0.012699).  Saving model ...


  0%|▎                                                                                                                                | 14/5000 [00:49<4:54:41,  3.55s/it]

[14] train loss: 0.012295 
[14] [94mvalidate[0m loss: 0.011724
Validation loss decreased (0.012699 --> 0.012295).  Saving model ...


  0%|▍                                                                                                                                | 15/5000 [00:53<5:03:11,  3.65s/it]

[15] train loss: 0.012121 
[15] [94mvalidate[0m loss: 0.013016
Validation loss decreased (0.012295 --> 0.012121).  Saving model ...


  0%|▍                                                                                                                                | 16/5000 [00:57<5:06:05,  3.68s/it]

[16] train loss: 0.012505 
[16] [94mvalidate[0m loss: 0.011536
EarlyStopping counter: 1 out of 40


  0%|▍                                                                                                                                | 17/5000 [01:01<5:05:28,  3.68s/it]

[17] train loss: 0.012003 
[17] [94mvalidate[0m loss: 0.012286
Validation loss decreased (0.012121 --> 0.012003).  Saving model ...


  0%|▍                                                                                                                                | 18/5000 [01:04<5:02:11,  3.64s/it]

[18] train loss: 0.011953 
[18] [94mvalidate[0m loss: 0.011108
Validation loss decreased (0.012003 --> 0.011953).  Saving model ...


  0%|▍                                                                                                                                | 19/5000 [01:07<4:37:53,  3.35s/it]

[19] train loss: 0.011696 
[19] [94mvalidate[0m loss: 0.011242
Validation loss decreased (0.011953 --> 0.011696).  Saving model ...


  0%|▌                                                                                                                                | 20/5000 [01:10<4:40:36,  3.38s/it]

[20] train loss: 0.011756 
[20] [94mvalidate[0m loss: 0.011447
EarlyStopping counter: 1 out of 40


  0%|▌                                                                                                                                | 21/5000 [01:14<4:49:27,  3.49s/it]

[21] train loss: 0.011275 
[21] [94mvalidate[0m loss: 0.011416
Validation loss decreased (0.011696 --> 0.011275).  Saving model ...


  0%|▌                                                                                                                                | 22/5000 [01:17<4:41:36,  3.39s/it]

[22] train loss: 0.011607 
[22] [94mvalidate[0m loss: 0.010765
EarlyStopping counter: 1 out of 40


  0%|▌                                                                                                                                | 23/5000 [01:21<4:45:24,  3.44s/it]

[23] train loss: 0.011407 
[23] [94mvalidate[0m loss: 0.010660
EarlyStopping counter: 2 out of 40


  0%|▌                                                                                                                                | 24/5000 [01:24<4:50:48,  3.51s/it]

[24] train loss: 0.011264 
[24] [94mvalidate[0m loss: 0.010598
Validation loss decreased (0.011275 --> 0.011264).  Saving model ...


  0%|▋                                                                                                                                | 25/5000 [01:27<4:42:03,  3.40s/it]

[25] train loss: 0.011158 
[25] [94mvalidate[0m loss: 0.011327
Validation loss decreased (0.011264 --> 0.011158).  Saving model ...


  1%|▋                                                                                                                                | 26/5000 [01:31<4:43:39,  3.42s/it]

[26] train loss: 0.010987 
[26] [94mvalidate[0m loss: 0.012251
Validation loss decreased (0.011158 --> 0.010987).  Saving model ...


  1%|▋                                                                                                                                | 27/5000 [01:34<4:42:32,  3.41s/it]

[27] train loss: 0.010899 
[27] [94mvalidate[0m loss: 0.012013
Validation loss decreased (0.010987 --> 0.010899).  Saving model ...


  1%|▋                                                                                                                                | 28/5000 [01:38<4:45:53,  3.45s/it]

[28] train loss: 0.010439 
[28] [94mvalidate[0m loss: 0.010996
Validation loss decreased (0.010899 --> 0.010439).  Saving model ...


  1%|▋                                                                                                                                | 29/5000 [01:42<4:53:34,  3.54s/it]

[29] train loss: 0.010590 
[29] [94mvalidate[0m loss: 0.011578
EarlyStopping counter: 1 out of 40


  1%|▊                                                                                                                                | 30/5000 [01:45<4:56:49,  3.58s/it]

[30] train loss: 0.010421 
[30] [94mvalidate[0m loss: 0.010490
Validation loss decreased (0.010439 --> 0.010421).  Saving model ...


  1%|▊                                                                                                                                | 31/5000 [01:49<4:58:17,  3.60s/it]

[31] train loss: 0.010538 
[31] [94mvalidate[0m loss: 0.015333
EarlyStopping counter: 1 out of 40


  1%|▊                                                                                                                                | 32/5000 [01:52<4:53:44,  3.55s/it]

[32] train loss: 0.010214 
[32] [94mvalidate[0m loss: 0.009273
Validation loss decreased (0.010421 --> 0.010214).  Saving model ...


  1%|▊                                                                                                                                | 33/5000 [01:55<4:19:23,  3.13s/it]

[33] train loss: 0.010292 
[33] [94mvalidate[0m loss: 0.009534
EarlyStopping counter: 1 out of 40


  1%|▉                                                                                                                                | 34/5000 [01:58<4:26:44,  3.22s/it]

[34] train loss: 0.010139 
[34] [94mvalidate[0m loss: 0.009203
Validation loss decreased (0.010214 --> 0.010139).  Saving model ...


  1%|▉                                                                                                                                | 35/5000 [02:01<4:21:07,  3.16s/it]

[35] train loss: 0.010281 
[35] [94mvalidate[0m loss: 0.008462
EarlyStopping counter: 1 out of 40


  1%|▉                                                                                                                                | 36/5000 [02:05<4:31:23,  3.28s/it]

[36] train loss: 0.009952 
[36] [94mvalidate[0m loss: 0.009962
Validation loss decreased (0.010139 --> 0.009952).  Saving model ...


  1%|▉                                                                                                                                | 37/5000 [02:08<4:38:37,  3.37s/it]

[37] train loss: 0.009745 
[37] [94mvalidate[0m loss: 0.009413
Validation loss decreased (0.009952 --> 0.009745).  Saving model ...


  1%|▉                                                                                                                                | 38/5000 [02:12<4:45:00,  3.45s/it]

[38] train loss: 0.010275 
[38] [94mvalidate[0m loss: 0.014283
EarlyStopping counter: 1 out of 40


  1%|█                                                                                                                                | 39/5000 [02:15<4:46:31,  3.47s/it]

[39] train loss: 0.009774 
[39] [94mvalidate[0m loss: 0.010344
EarlyStopping counter: 2 out of 40


  1%|█                                                                                                                                | 40/5000 [02:19<4:46:34,  3.47s/it]

[40] train loss: 0.009946 
[40] [94mvalidate[0m loss: 0.009729
EarlyStopping counter: 3 out of 40


  1%|█                                                                                                                                | 41/5000 [02:22<4:49:08,  3.50s/it]

[41] train loss: 0.009827 
[41] [94mvalidate[0m loss: 0.012733
EarlyStopping counter: 4 out of 40


  1%|█                                                                                                                                | 42/5000 [02:26<4:53:29,  3.55s/it]

[42] train loss: 0.009638 
[42] [94mvalidate[0m loss: 0.009160
Validation loss decreased (0.009745 --> 0.009638).  Saving model ...


  1%|█                                                                                                                                | 43/5000 [02:29<4:52:40,  3.54s/it]

[43] train loss: 0.009811 
[43] [94mvalidate[0m loss: 0.012106
EarlyStopping counter: 1 out of 40


  1%|█▏                                                                                                                               | 44/5000 [02:33<4:52:39,  3.54s/it]

[44] train loss: 0.009570 
[44] [94mvalidate[0m loss: 0.008995
Validation loss decreased (0.009638 --> 0.009570).  Saving model ...


  1%|█▏                                                                                                                               | 45/5000 [02:37<4:52:25,  3.54s/it]

[45] train loss: 0.009551 
[45] [94mvalidate[0m loss: 0.010555
Validation loss decreased (0.009570 --> 0.009551).  Saving model ...


  1%|█▏                                                                                                                               | 46/5000 [02:40<4:55:26,  3.58s/it]

[46] train loss: 0.009662 
[46] [94mvalidate[0m loss: 0.011046
EarlyStopping counter: 1 out of 40


  1%|█▏                                                                                                                               | 47/5000 [02:44<4:52:30,  3.54s/it]

[47] train loss: 0.006926 
[47] [94mvalidate[0m loss: 0.007094
Validation loss decreased (0.009551 --> 0.006926).  Saving model ...


  1%|█▏                                                                                                                               | 48/5000 [02:47<4:58:05,  3.61s/it]

[48] train loss: 0.006643 
[48] [94mvalidate[0m loss: 0.006852
Validation loss decreased (0.006926 --> 0.006643).  Saving model ...


  1%|█▎                                                                                                                               | 49/5000 [02:51<4:54:09,  3.56s/it]

[49] train loss: 0.006531 
[49] [94mvalidate[0m loss: 0.007345
Validation loss decreased (0.006643 --> 0.006531).  Saving model ...


  1%|█▎                                                                                                                               | 50/5000 [02:54<4:38:32,  3.38s/it]

[50] train loss: 0.006435 
[50] [94mvalidate[0m loss: 0.006907
Validation loss decreased (0.006531 --> 0.006435).  Saving model ...


  1%|█▎                                                                                                                               | 51/5000 [02:57<4:32:08,  3.30s/it]

[51] train loss: 0.006379 
[51] [94mvalidate[0m loss: 0.007258
Validation loss decreased (0.006435 --> 0.006379).  Saving model ...


  1%|█▎                                                                                                                               | 52/5000 [03:01<4:45:49,  3.47s/it]

[52] train loss: 0.006330 
[52] [94mvalidate[0m loss: 0.006447
Validation loss decreased (0.006379 --> 0.006330).  Saving model ...


  1%|█▎                                                                                                                               | 53/5000 [03:04<4:48:45,  3.50s/it]

[53] train loss: 0.006291 
[53] [94mvalidate[0m loss: 0.006314
Validation loss decreased (0.006330 --> 0.006291).  Saving model ...


  1%|█▍                                                                                                                               | 54/5000 [03:08<4:52:05,  3.54s/it]

[54] train loss: 0.006228 
[54] [94mvalidate[0m loss: 0.006211
Validation loss decreased (0.006291 --> 0.006228).  Saving model ...


  1%|█▍                                                                                                                               | 55/5000 [03:12<4:55:46,  3.59s/it]

[55] train loss: 0.006291 
[55] [94mvalidate[0m loss: 0.006267
EarlyStopping counter: 1 out of 40


  1%|█▍                                                                                                                               | 56/5000 [03:15<4:56:11,  3.59s/it]

[56] train loss: 0.006148 
[56] [94mvalidate[0m loss: 0.006572
Validation loss decreased (0.006228 --> 0.006148).  Saving model ...


  1%|█▍                                                                                                                               | 57/5000 [03:19<4:56:31,  3.60s/it]

[57] train loss: 0.006160 
[57] [94mvalidate[0m loss: 0.006346
EarlyStopping counter: 1 out of 40


  1%|█▍                                                                                                                               | 58/5000 [03:22<4:50:03,  3.52s/it]

[58] train loss: 0.006155 
[58] [94mvalidate[0m loss: 0.006554
EarlyStopping counter: 2 out of 40


  1%|█▌                                                                                                                               | 59/5000 [03:26<4:51:50,  3.54s/it]

[59] train loss: 0.006118 
[59] [94mvalidate[0m loss: 0.006503
Validation loss decreased (0.006148 --> 0.006118).  Saving model ...


  1%|█▌                                                                                                                               | 60/5000 [03:29<4:52:29,  3.55s/it]

[60] train loss: 0.006062 
[60] [94mvalidate[0m loss: 0.007225
Validation loss decreased (0.006118 --> 0.006062).  Saving model ...


  1%|█▌                                                                                                                               | 61/5000 [03:33<4:51:35,  3.54s/it]

[61] train loss: 0.006034 
[61] [94mvalidate[0m loss: 0.006690
Validation loss decreased (0.006062 --> 0.006034).  Saving model ...


  1%|█▌                                                                                                                               | 62/5000 [03:37<4:51:52,  3.55s/it]

[62] train loss: 0.006065 
[62] [94mvalidate[0m loss: 0.006603
EarlyStopping counter: 1 out of 40


  1%|█▋                                                                                                                               | 63/5000 [03:40<4:51:29,  3.54s/it]

[63] train loss: 0.006046 
[63] [94mvalidate[0m loss: 0.006630
EarlyStopping counter: 2 out of 40


  1%|█▋                                                                                                                               | 64/5000 [03:44<4:57:31,  3.62s/it]

[64] train loss: 0.005942 
[64] [94mvalidate[0m loss: 0.006386
Validation loss decreased (0.006034 --> 0.005942).  Saving model ...


  1%|█▋                                                                                                                               | 65/5000 [03:47<4:55:40,  3.59s/it]

[65] train loss: 0.005908 
[65] [94mvalidate[0m loss: 0.006173
Validation loss decreased (0.005942 --> 0.005908).  Saving model ...


  1%|█▋                                                                                                                               | 66/5000 [03:51<4:54:44,  3.58s/it]

[66] train loss: 0.005931 
[66] [94mvalidate[0m loss: 0.006247
EarlyStopping counter: 1 out of 40


  1%|█▋                                                                                                                               | 67/5000 [03:54<4:45:15,  3.47s/it]

[67] train loss: 0.005873 
[67] [94mvalidate[0m loss: 0.006313
Validation loss decreased (0.005908 --> 0.005873).  Saving model ...


  1%|█▊                                                                                                                               | 68/5000 [03:58<4:47:54,  3.50s/it]

[68] train loss: 0.005849 
[68] [94mvalidate[0m loss: 0.006345
Validation loss decreased (0.005873 --> 0.005849).  Saving model ...


  1%|█▊                                                                                                                               | 69/5000 [04:01<4:50:41,  3.54s/it]

[69] train loss: 0.005855 
[69] [94mvalidate[0m loss: 0.006046
EarlyStopping counter: 1 out of 40


  1%|█▊                                                                                                                               | 70/5000 [04:04<4:24:05,  3.21s/it]

[70] train loss: 0.005954 
[70] [94mvalidate[0m loss: 0.006278
EarlyStopping counter: 2 out of 40


  1%|█▊                                                                                                                               | 71/5000 [04:07<4:32:14,  3.31s/it]

[71] train loss: 0.005844 
[71] [94mvalidate[0m loss: 0.005865
Validation loss decreased (0.005849 --> 0.005844).  Saving model ...


  1%|█▊                                                                                                                               | 72/5000 [04:11<4:40:04,  3.41s/it]

[72] train loss: 0.005899 
[72] [94mvalidate[0m loss: 0.006093
EarlyStopping counter: 1 out of 40


  1%|█▉                                                                                                                               | 73/5000 [04:15<4:42:27,  3.44s/it]

[73] train loss: 0.005852 
[73] [94mvalidate[0m loss: 0.005956
EarlyStopping counter: 2 out of 40


  1%|█▉                                                                                                                               | 74/5000 [04:18<4:47:39,  3.50s/it]

[74] train loss: 0.005792 
[74] [94mvalidate[0m loss: 0.006339
Validation loss decreased (0.005844 --> 0.005792).  Saving model ...


  2%|█▉                                                                                                                               | 75/5000 [04:22<4:46:54,  3.50s/it]

[75] train loss: 0.005802 
[75] [94mvalidate[0m loss: 0.006091
EarlyStopping counter: 1 out of 40


  2%|█▉                                                                                                                               | 76/5000 [04:25<4:50:28,  3.54s/it]

[76] train loss: 0.005735 
[76] [94mvalidate[0m loss: 0.005866
Validation loss decreased (0.005792 --> 0.005735).  Saving model ...


  2%|█▉                                                                                                                               | 77/5000 [04:29<4:47:38,  3.51s/it]

[77] train loss: 0.005751 
[77] [94mvalidate[0m loss: 0.005932
EarlyStopping counter: 1 out of 40


  2%|██                                                                                                                               | 78/5000 [04:32<4:48:21,  3.52s/it]

[78] train loss: 0.005715 
[78] [94mvalidate[0m loss: 0.005985
Validation loss decreased (0.005735 --> 0.005715).  Saving model ...


  2%|██                                                                                                                               | 79/5000 [04:36<4:49:34,  3.53s/it]

[79] train loss: 0.005776 
[79] [94mvalidate[0m loss: 0.006146
EarlyStopping counter: 1 out of 40


  2%|██                                                                                                                               | 80/5000 [04:39<4:48:54,  3.52s/it]

[80] train loss: 0.005710 
[80] [94mvalidate[0m loss: 0.005791
Validation loss decreased (0.005715 --> 0.005710).  Saving model ...


  2%|██                                                                                                                               | 81/5000 [04:43<4:52:15,  3.56s/it]

[81] train loss: 0.005843 
[81] [94mvalidate[0m loss: 0.005951
EarlyStopping counter: 1 out of 40


  2%|██                                                                                                                               | 82/5000 [04:47<4:52:56,  3.57s/it]

[82] train loss: 0.005702 
[82] [94mvalidate[0m loss: 0.005837
Validation loss decreased (0.005710 --> 0.005702).  Saving model ...


  2%|██▏                                                                                                                              | 83/5000 [04:49<4:23:21,  3.21s/it]

[83] train loss: 0.005693 
[83] [94mvalidate[0m loss: 0.005853
Validation loss decreased (0.005702 --> 0.005693).  Saving model ...


  2%|██▏                                                                                                                              | 84/5000 [04:53<4:32:20,  3.32s/it]

[84] train loss: 0.005722 
[84] [94mvalidate[0m loss: 0.005988
EarlyStopping counter: 1 out of 40


  2%|██▏                                                                                                                              | 85/5000 [04:56<4:39:20,  3.41s/it]

[85] train loss: 0.005697 
[85] [94mvalidate[0m loss: 0.006332
EarlyStopping counter: 2 out of 40


  2%|██▏                                                                                                                              | 86/5000 [05:00<4:47:29,  3.51s/it]

[86] train loss: 0.005658 
[86] [94mvalidate[0m loss: 0.006075
Validation loss decreased (0.005693 --> 0.005658).  Saving model ...


  2%|██▏                                                                                                                              | 87/5000 [05:04<4:50:16,  3.55s/it]

[87] train loss: 0.005629 
[87] [94mvalidate[0m loss: 0.006102
Validation loss decreased (0.005658 --> 0.005629).  Saving model ...


  2%|██▎                                                                                                                              | 88/5000 [05:07<4:53:18,  3.58s/it]

[88] train loss: 0.005661 
[88] [94mvalidate[0m loss: 0.006013
EarlyStopping counter: 1 out of 40


  2%|██▎                                                                                                                              | 89/5000 [05:11<4:53:24,  3.58s/it]

[89] train loss: 0.005653 
[89] [94mvalidate[0m loss: 0.006057
EarlyStopping counter: 2 out of 40


  2%|██▎                                                                                                                              | 90/5000 [05:14<4:31:41,  3.32s/it]

[90] train loss: 0.005689 
[90] [94mvalidate[0m loss: 0.007360
EarlyStopping counter: 3 out of 40


  2%|██▎                                                                                                                              | 91/5000 [05:17<4:37:53,  3.40s/it]

[91] train loss: 0.005634 
[91] [94mvalidate[0m loss: 0.005983
EarlyStopping counter: 4 out of 40


  2%|██▎                                                                                                                              | 92/5000 [05:20<4:37:00,  3.39s/it]

[92] train loss: 0.004985 
[92] [94mvalidate[0m loss: 0.005610
Validation loss decreased (0.005629 --> 0.004985).  Saving model ...


  2%|██▍                                                                                                                              | 93/5000 [05:24<4:39:27,  3.42s/it]

[93] train loss: 0.004930 
[93] [94mvalidate[0m loss: 0.005899
Validation loss decreased (0.004985 --> 0.004930).  Saving model ...


  2%|██▍                                                                                                                              | 94/5000 [05:28<4:46:08,  3.50s/it]

[94] train loss: 0.004901 
[94] [94mvalidate[0m loss: 0.005526
Validation loss decreased (0.004930 --> 0.004901).  Saving model ...


  2%|██▍                                                                                                                              | 95/5000 [05:31<4:43:48,  3.47s/it]

[95] train loss: 0.004889 
[95] [94mvalidate[0m loss: 0.005519
Validation loss decreased (0.004901 --> 0.004889).  Saving model ...


  2%|██▍                                                                                                                              | 96/5000 [05:35<4:49:50,  3.55s/it]

[96] train loss: 0.004884 
[96] [94mvalidate[0m loss: 0.005609
Validation loss decreased (0.004889 --> 0.004884).  Saving model ...


  2%|██▌                                                                                                                              | 97/5000 [05:38<4:54:42,  3.61s/it]

[97] train loss: 0.004865 
[97] [94mvalidate[0m loss: 0.005541
Validation loss decreased (0.004884 --> 0.004865).  Saving model ...


  2%|██▌                                                                                                                              | 98/5000 [05:41<4:21:24,  3.20s/it]

[98] train loss: 0.004854 
[98] [94mvalidate[0m loss: 0.005567
Validation loss decreased (0.004865 --> 0.004854).  Saving model ...


  2%|██▌                                                                                                                              | 99/5000 [05:44<4:30:56,  3.32s/it]

[99] train loss: 0.004855 
[99] [94mvalidate[0m loss: 0.005501
EarlyStopping counter: 1 out of 40


  2%|██▌                                                                                                                             | 100/5000 [05:47<4:05:15,  3.00s/it]

[100] train loss: 0.004840 
[100] [94mvalidate[0m loss: 0.005429
Validation loss decreased (0.004854 --> 0.004840).  Saving model ...


  2%|██▌                                                                                                                             | 101/5000 [05:50<4:19:19,  3.18s/it]

[101] train loss: 0.004839 
[101] [94mvalidate[0m loss: 0.005526
Validation loss decreased (0.004840 --> 0.004839).  Saving model ...


  2%|██▌                                                                                                                             | 102/5000 [05:54<4:23:35,  3.23s/it]

[102] train loss: 0.004840 
[102] [94mvalidate[0m loss: 0.005463
EarlyStopping counter: 1 out of 40


  2%|██▋                                                                                                                             | 103/5000 [05:57<4:34:28,  3.36s/it]

[103] train loss: 0.004821 
[103] [94mvalidate[0m loss: 0.005470
Validation loss decreased (0.004839 --> 0.004821).  Saving model ...


  2%|██▋                                                                                                                             | 104/5000 [06:01<4:40:20,  3.44s/it]

[104] train loss: 0.004823 
[104] [94mvalidate[0m loss: 0.005427
EarlyStopping counter: 1 out of 40


  2%|██▋                                                                                                                             | 105/5000 [06:04<4:44:40,  3.49s/it]

[105] train loss: 0.004812 
[105] [94mvalidate[0m loss: 0.005568
Validation loss decreased (0.004821 --> 0.004812).  Saving model ...


  2%|██▋                                                                                                                             | 106/5000 [06:08<4:49:24,  3.55s/it]

[106] train loss: 0.004807 
[106] [94mvalidate[0m loss: 0.005417
Validation loss decreased (0.004812 --> 0.004807).  Saving model ...


  2%|██▋                                                                                                                             | 107/5000 [06:12<4:50:12,  3.56s/it]

[107] train loss: 0.004790 
[107] [94mvalidate[0m loss: 0.005582
Validation loss decreased (0.004807 --> 0.004790).  Saving model ...


  2%|██▊                                                                                                                             | 108/5000 [06:15<4:48:14,  3.54s/it]

[108] train loss: 0.004790 
[108] [94mvalidate[0m loss: 0.005515
Validation loss decreased (0.004790 --> 0.004790).  Saving model ...


  2%|██▊                                                                                                                             | 109/5000 [06:19<4:48:38,  3.54s/it]

[109] train loss: 0.004786 
[109] [94mvalidate[0m loss: 0.005550
Validation loss decreased (0.004790 --> 0.004786).  Saving model ...


  2%|██▊                                                                                                                             | 110/5000 [06:22<4:46:55,  3.52s/it]

[110] train loss: 0.004797 
[110] [94mvalidate[0m loss: 0.005430
EarlyStopping counter: 1 out of 40


  2%|██▊                                                                                                                             | 111/5000 [06:25<4:29:08,  3.30s/it]

[111] train loss: 0.004774 
[111] [94mvalidate[0m loss: 0.005484
Validation loss decreased (0.004786 --> 0.004774).  Saving model ...


  2%|██▊                                                                                                                             | 112/5000 [06:27<4:03:58,  2.99s/it]

[112] train loss: 0.004775 
[112] [94mvalidate[0m loss: 0.005425
EarlyStopping counter: 1 out of 40


  2%|██▉                                                                                                                             | 113/5000 [06:31<4:22:10,  3.22s/it]

[113] train loss: 0.004786 
[113] [94mvalidate[0m loss: 0.005546
EarlyStopping counter: 2 out of 40


  2%|██▉                                                                                                                             | 114/5000 [06:35<4:28:48,  3.30s/it]

[114] train loss: 0.004767 
[114] [94mvalidate[0m loss: 0.005359
Validation loss decreased (0.004774 --> 0.004767).  Saving model ...


  2%|██▉                                                                                                                             | 115/5000 [06:38<4:32:00,  3.34s/it]

[115] train loss: 0.004772 
[115] [94mvalidate[0m loss: 0.005622
EarlyStopping counter: 1 out of 40


  2%|██▉                                                                                                                             | 116/5000 [06:41<4:35:03,  3.38s/it]

[116] train loss: 0.004774 
[116] [94mvalidate[0m loss: 0.005380
EarlyStopping counter: 2 out of 40


  2%|██▉                                                                                                                             | 117/5000 [06:45<4:36:04,  3.39s/it]

[117] train loss: 0.004752 
[117] [94mvalidate[0m loss: 0.005447
Validation loss decreased (0.004767 --> 0.004752).  Saving model ...


  2%|███                                                                                                                             | 118/5000 [06:48<4:39:48,  3.44s/it]

[118] train loss: 0.004757 
[118] [94mvalidate[0m loss: 0.005483
EarlyStopping counter: 1 out of 40


  2%|███                                                                                                                             | 119/5000 [06:52<4:40:16,  3.45s/it]

[119] train loss: 0.004767 
[119] [94mvalidate[0m loss: 0.005422
EarlyStopping counter: 2 out of 40


  2%|███                                                                                                                             | 120/5000 [06:55<4:43:46,  3.49s/it]

[120] train loss: 0.004749 
[120] [94mvalidate[0m loss: 0.005412
Validation loss decreased (0.004752 --> 0.004749).  Saving model ...


  2%|███                                                                                                                             | 121/5000 [06:59<4:43:46,  3.49s/it]

[121] train loss: 0.004733 
[121] [94mvalidate[0m loss: 0.005377
Validation loss decreased (0.004749 --> 0.004733).  Saving model ...


  2%|███                                                                                                                             | 122/5000 [07:03<4:48:17,  3.55s/it]

[122] train loss: 0.004747 
[122] [94mvalidate[0m loss: 0.005430
EarlyStopping counter: 1 out of 40


  2%|███▏                                                                                                                            | 123/5000 [07:06<4:49:51,  3.57s/it]

[123] train loss: 0.004735 
[123] [94mvalidate[0m loss: 0.005590
EarlyStopping counter: 2 out of 40


  2%|███▏                                                                                                                            | 124/5000 [07:10<4:47:12,  3.53s/it]

[124] train loss: 0.004730 
[124] [94mvalidate[0m loss: 0.005382
Validation loss decreased (0.004733 --> 0.004730).  Saving model ...


  2%|███▏                                                                                                                            | 125/5000 [07:13<4:51:03,  3.58s/it]

[125] train loss: 0.004746 
[125] [94mvalidate[0m loss: 0.005341
EarlyStopping counter: 1 out of 40


  3%|███▏                                                                                                                            | 126/5000 [07:17<4:52:36,  3.60s/it]

[126] train loss: 0.004728 
[126] [94mvalidate[0m loss: 0.005438
Validation loss decreased (0.004730 --> 0.004728).  Saving model ...


  3%|███▎                                                                                                                            | 127/5000 [07:21<4:50:56,  3.58s/it]

[127] train loss: 0.004728 
[127] [94mvalidate[0m loss: 0.005390
EarlyStopping counter: 1 out of 40


  3%|███▎                                                                                                                            | 128/5000 [07:24<4:58:55,  3.68s/it]

[128] train loss: 0.004704 
[128] [94mvalidate[0m loss: 0.005345
Validation loss decreased (0.004728 --> 0.004704).  Saving model ...


  3%|███▎                                                                                                                            | 129/5000 [07:28<4:59:15,  3.69s/it]

[129] train loss: 0.004711 
[129] [94mvalidate[0m loss: 0.005404
EarlyStopping counter: 1 out of 40


  3%|███▎                                                                                                                            | 130/5000 [07:31<4:27:45,  3.30s/it]

[130] train loss: 0.004724 
[130] [94mvalidate[0m loss: 0.005441
EarlyStopping counter: 2 out of 40


  3%|███▎                                                                                                                            | 131/5000 [07:33<4:15:58,  3.15s/it]

[131] train loss: 0.004723 
[131] [94mvalidate[0m loss: 0.005335
EarlyStopping counter: 3 out of 40


  3%|███▍                                                                                                                            | 132/5000 [07:37<4:26:02,  3.28s/it]

[132] train loss: 0.004711 
[132] [94mvalidate[0m loss: 0.005466
EarlyStopping counter: 4 out of 40


  3%|███▍                                                                                                                            | 133/5000 [07:40<4:28:57,  3.32s/it]

[133] train loss: 0.004705 
[133] [94mvalidate[0m loss: 0.005407
EarlyStopping counter: 5 out of 40


  3%|███▍                                                                                                                            | 134/5000 [07:44<4:28:31,  3.31s/it]

[134] train loss: 0.004713 
[134] [94mvalidate[0m loss: 0.005350
EarlyStopping counter: 6 out of 40


  3%|███▍                                                                                                                            | 135/5000 [07:47<4:35:57,  3.40s/it]

[135] train loss: 0.004684 
[135] [94mvalidate[0m loss: 0.005467
Validation loss decreased (0.004704 --> 0.004684).  Saving model ...


  3%|███▍                                                                                                                            | 136/5000 [07:51<4:41:16,  3.47s/it]

[136] train loss: 0.004705 
[136] [94mvalidate[0m loss: 0.005547
EarlyStopping counter: 1 out of 40


  3%|███▌                                                                                                                            | 137/5000 [07:54<4:33:38,  3.38s/it]

[137] train loss: 0.004678 
[137] [94mvalidate[0m loss: 0.005457
Validation loss decreased (0.004684 --> 0.004678).  Saving model ...


  3%|███▌                                                                                                                            | 138/5000 [07:58<4:39:09,  3.44s/it]

[138] train loss: 0.004687 
[138] [94mvalidate[0m loss: 0.005389
EarlyStopping counter: 1 out of 40


  3%|███▌                                                                                                                            | 139/5000 [08:01<4:42:08,  3.48s/it]

[139] train loss: 0.004689 
[139] [94mvalidate[0m loss: 0.005435
EarlyStopping counter: 2 out of 40


  3%|███▌                                                                                                                            | 140/5000 [08:05<4:52:38,  3.61s/it]

[140] train loss: 0.004700 
[140] [94mvalidate[0m loss: 0.005610
EarlyStopping counter: 3 out of 40


  3%|███▌                                                                                                                            | 141/5000 [08:09<4:51:32,  3.60s/it]

[141] train loss: 0.004667 
[141] [94mvalidate[0m loss: 0.005300
Validation loss decreased (0.004678 --> 0.004667).  Saving model ...


  3%|███▋                                                                                                                            | 142/5000 [08:12<4:49:43,  3.58s/it]

[142] train loss: 0.004678 
[142] [94mvalidate[0m loss: 0.005443
EarlyStopping counter: 1 out of 40


  3%|███▋                                                                                                                            | 143/5000 [08:16<4:49:33,  3.58s/it]

[143] train loss: 0.004662 
[143] [94mvalidate[0m loss: 0.005435
Validation loss decreased (0.004667 --> 0.004662).  Saving model ...


  3%|███▋                                                                                                                            | 144/5000 [08:19<4:49:35,  3.58s/it]

[144] train loss: 0.004665 
[144] [94mvalidate[0m loss: 0.005515
EarlyStopping counter: 1 out of 40


  3%|███▋                                                                                                                            | 145/5000 [08:23<4:48:35,  3.57s/it]

[145] train loss: 0.004687 
[145] [94mvalidate[0m loss: 0.005441
EarlyStopping counter: 2 out of 40


  3%|███▋                                                                                                                            | 146/5000 [08:27<4:52:08,  3.61s/it]

[146] train loss: 0.004655 
[146] [94mvalidate[0m loss: 0.005360
Validation loss decreased (0.004662 --> 0.004655).  Saving model ...


  3%|███▊                                                                                                                            | 147/5000 [08:30<4:47:35,  3.56s/it]

[147] train loss: 0.004656 
[147] [94mvalidate[0m loss: 0.005351
EarlyStopping counter: 1 out of 40


  3%|███▊                                                                                                                            | 148/5000 [08:34<4:52:23,  3.62s/it]

[148] train loss: 0.004665 
[148] [94mvalidate[0m loss: 0.005517
EarlyStopping counter: 2 out of 40


  3%|███▊                                                                                                                            | 149/5000 [08:37<4:47:07,  3.55s/it]

[149] train loss: 0.004667 
[149] [94mvalidate[0m loss: 0.005322
EarlyStopping counter: 3 out of 40


  3%|███▊                                                                                                                            | 150/5000 [08:41<4:43:33,  3.51s/it]

[150] train loss: 0.004651 
[150] [94mvalidate[0m loss: 0.005354
Validation loss decreased (0.004655 --> 0.004651).  Saving model ...


  3%|███▊                                                                                                                            | 151/5000 [08:44<4:44:41,  3.52s/it]

[151] train loss: 0.004636 
[151] [94mvalidate[0m loss: 0.005389
Validation loss decreased (0.004651 --> 0.004636).  Saving model ...


  3%|███▉                                                                                                                            | 152/5000 [08:46<4:12:51,  3.13s/it]

[152] train loss: 0.004651 
[152] [94mvalidate[0m loss: 0.005376
EarlyStopping counter: 1 out of 40


  3%|███▉                                                                                                                            | 153/5000 [08:50<4:18:37,  3.20s/it]

[153] train loss: 0.004537 
[153] [94mvalidate[0m loss: 0.005295
Validation loss decreased (0.004636 --> 0.004537).  Saving model ...


  3%|███▉                                                                                                                            | 154/5000 [08:53<4:28:36,  3.33s/it]

[154] train loss: 0.004530 
[154] [94mvalidate[0m loss: 0.005278
Validation loss decreased (0.004537 --> 0.004530).  Saving model ...


  3%|███▉                                                                                                                            | 155/5000 [08:57<4:36:47,  3.43s/it]

[155] train loss: 0.004531 
[155] [94mvalidate[0m loss: 0.005280
EarlyStopping counter: 1 out of 40


  3%|███▉                                                                                                                            | 156/5000 [09:01<4:50:24,  3.60s/it]

[156] train loss: 0.004526 
[156] [94mvalidate[0m loss: 0.005274
Validation loss decreased (0.004530 --> 0.004526).  Saving model ...


  3%|████                                                                                                                            | 157/5000 [09:05<4:48:09,  3.57s/it]

[157] train loss: 0.004530 
[157] [94mvalidate[0m loss: 0.005286
EarlyStopping counter: 1 out of 40


  3%|████                                                                                                                            | 158/5000 [09:08<4:39:02,  3.46s/it]

[158] train loss: 0.004528 
[158] [94mvalidate[0m loss: 0.005283
EarlyStopping counter: 2 out of 40


  3%|████                                                                                                                            | 159/5000 [09:11<4:40:27,  3.48s/it]

[159] train loss: 0.004528 
[159] [94mvalidate[0m loss: 0.005293
EarlyStopping counter: 3 out of 40


  3%|████                                                                                                                            | 160/5000 [09:15<4:41:44,  3.49s/it]

[160] train loss: 0.004527 
[160] [94mvalidate[0m loss: 0.005276
EarlyStopping counter: 4 out of 40


  3%|████                                                                                                                            | 161/5000 [09:18<4:40:19,  3.48s/it]

[161] train loss: 0.004525 
[161] [94mvalidate[0m loss: 0.005274
Validation loss decreased (0.004526 --> 0.004525).  Saving model ...


  3%|████▏                                                                                                                           | 162/5000 [09:22<4:43:39,  3.52s/it]

[162] train loss: 0.004525 
[162] [94mvalidate[0m loss: 0.005289
EarlyStopping counter: 1 out of 40


  3%|████▏                                                                                                                           | 163/5000 [09:25<4:41:58,  3.50s/it]

[163] train loss: 0.004524 
[163] [94mvalidate[0m loss: 0.005273
Validation loss decreased (0.004525 --> 0.004524).  Saving model ...


  3%|████▏                                                                                                                           | 164/5000 [09:29<4:47:01,  3.56s/it]

[164] train loss: 0.004525 
[164] [94mvalidate[0m loss: 0.005275
EarlyStopping counter: 1 out of 40


  3%|████▏                                                                                                                           | 165/5000 [09:32<4:39:35,  3.47s/it]

[165] train loss: 0.004525 
[165] [94mvalidate[0m loss: 0.005299
EarlyStopping counter: 2 out of 40


  3%|████▏                                                                                                                           | 166/5000 [09:35<4:30:23,  3.36s/it]

[166] train loss: 0.004526 
[166] [94mvalidate[0m loss: 0.005287
EarlyStopping counter: 3 out of 40


  3%|████▎                                                                                                                           | 167/5000 [09:39<4:36:36,  3.43s/it]

[167] train loss: 0.004524 
[167] [94mvalidate[0m loss: 0.005274
EarlyStopping counter: 4 out of 40


  3%|████▎                                                                                                                           | 168/5000 [09:43<4:39:41,  3.47s/it]

[168] train loss: 0.004520 
[168] [94mvalidate[0m loss: 0.005300
Validation loss decreased (0.004524 --> 0.004520).  Saving model ...


  3%|████▎                                                                                                                           | 169/5000 [09:46<4:45:00,  3.54s/it]

[169] train loss: 0.004524 
[169] [94mvalidate[0m loss: 0.005272
EarlyStopping counter: 1 out of 40


  3%|████▎                                                                                                                           | 170/5000 [09:50<4:48:06,  3.58s/it]

[170] train loss: 0.004519 
[170] [94mvalidate[0m loss: 0.005294
Validation loss decreased (0.004520 --> 0.004519).  Saving model ...


  3%|████▍                                                                                                                           | 171/5000 [09:53<4:36:01,  3.43s/it]

[171] train loss: 0.004520 
[171] [94mvalidate[0m loss: 0.005270
EarlyStopping counter: 1 out of 40


  3%|████▍                                                                                                                           | 172/5000 [09:56<4:37:45,  3.45s/it]

[172] train loss: 0.004523 
[172] [94mvalidate[0m loss: 0.005301
EarlyStopping counter: 2 out of 40


  3%|████▍                                                                                                                           | 173/5000 [10:00<4:39:32,  3.47s/it]

[173] train loss: 0.004520 
[173] [94mvalidate[0m loss: 0.005276
EarlyStopping counter: 3 out of 40


  3%|████▍                                                                                                                           | 174/5000 [10:04<4:39:54,  3.48s/it]

[174] train loss: 0.004521 
[174] [94mvalidate[0m loss: 0.005274
EarlyStopping counter: 4 out of 40


  4%|████▍                                                                                                                           | 175/5000 [10:07<4:38:33,  3.46s/it]

[175] train loss: 0.004521 
[175] [94mvalidate[0m loss: 0.005279
EarlyStopping counter: 5 out of 40


  4%|████▌                                                                                                                           | 176/5000 [10:10<4:36:05,  3.43s/it]

[176] train loss: 0.004521 
[176] [94mvalidate[0m loss: 0.005269
EarlyStopping counter: 6 out of 40


  4%|████▌                                                                                                                           | 177/5000 [10:14<4:33:26,  3.40s/it]

[177] train loss: 0.004520 
[177] [94mvalidate[0m loss: 0.005276
EarlyStopping counter: 7 out of 40


  4%|████▌                                                                                                                           | 178/5000 [10:17<4:43:21,  3.53s/it]

[178] train loss: 0.004520 
[178] [94mvalidate[0m loss: 0.005298
EarlyStopping counter: 8 out of 40


  4%|████▌                                                                                                                           | 179/5000 [10:21<4:40:26,  3.49s/it]

[179] train loss: 0.004516 
[179] [94mvalidate[0m loss: 0.005278
Validation loss decreased (0.004519 --> 0.004516).  Saving model ...


  4%|████▌                                                                                                                           | 180/5000 [10:24<4:38:39,  3.47s/it]

[180] train loss: 0.004515 
[180] [94mvalidate[0m loss: 0.005269
Validation loss decreased (0.004516 --> 0.004515).  Saving model ...


  4%|████▋                                                                                                                           | 181/5000 [10:28<4:36:15,  3.44s/it]

[181] train loss: 0.004516 
[181] [94mvalidate[0m loss: 0.005272
EarlyStopping counter: 1 out of 40


  4%|████▋                                                                                                                           | 182/5000 [10:32<4:48:52,  3.60s/it]

[182] train loss: 0.004517 
[182] [94mvalidate[0m loss: 0.005279
EarlyStopping counter: 2 out of 40


  4%|████▋                                                                                                                           | 183/5000 [10:35<4:46:21,  3.57s/it]

[183] train loss: 0.004516 
[183] [94mvalidate[0m loss: 0.005271
EarlyStopping counter: 3 out of 40


  4%|████▋                                                                                                                           | 184/5000 [10:39<4:49:45,  3.61s/it]

[184] train loss: 0.004514 
[184] [94mvalidate[0m loss: 0.005285
Validation loss decreased (0.004515 --> 0.004514).  Saving model ...


  4%|████▋                                                                                                                           | 185/5000 [10:42<4:50:15,  3.62s/it]

[185] train loss: 0.004514 
[185] [94mvalidate[0m loss: 0.005281
Validation loss decreased (0.004514 --> 0.004514).  Saving model ...


  4%|████▊                                                                                                                           | 186/5000 [10:46<4:45:34,  3.56s/it]

[186] train loss: 0.004512 
[186] [94mvalidate[0m loss: 0.005269
Validation loss decreased (0.004514 --> 0.004512).  Saving model ...


  4%|████▊                                                                                                                           | 187/5000 [10:49<4:45:39,  3.56s/it]

[187] train loss: 0.004510 
[187] [94mvalidate[0m loss: 0.005267
Validation loss decreased (0.004512 --> 0.004510).  Saving model ...


  4%|████▊                                                                                                                           | 188/5000 [10:53<4:40:06,  3.49s/it]

[188] train loss: 0.004512 
[188] [94mvalidate[0m loss: 0.005276
EarlyStopping counter: 1 out of 40


  4%|████▊                                                                                                                           | 189/5000 [10:56<4:41:51,  3.52s/it]

[189] train loss: 0.004512 
[189] [94mvalidate[0m loss: 0.005304
EarlyStopping counter: 2 out of 40


  4%|████▊                                                                                                                           | 190/5000 [10:59<4:21:15,  3.26s/it]

[190] train loss: 0.004513 
[190] [94mvalidate[0m loss: 0.005284
EarlyStopping counter: 3 out of 40


  4%|████▉                                                                                                                           | 191/5000 [11:03<4:31:21,  3.39s/it]

[191] train loss: 0.004510 
[191] [94mvalidate[0m loss: 0.005273
Validation loss decreased (0.004510 --> 0.004510).  Saving model ...


  4%|████▉                                                                                                                           | 192/5000 [11:06<4:35:51,  3.44s/it]

[192] train loss: 0.004510 
[192] [94mvalidate[0m loss: 0.005280
Validation loss decreased (0.004510 --> 0.004510).  Saving model ...


  4%|████▉                                                                                                                           | 193/5000 [11:10<4:51:02,  3.63s/it]

[193] train loss: 0.004512 
[193] [94mvalidate[0m loss: 0.005275
EarlyStopping counter: 1 out of 40


  4%|████▉                                                                                                                           | 194/5000 [11:14<4:46:25,  3.58s/it]

[194] train loss: 0.004510 
[194] [94mvalidate[0m loss: 0.005287
Validation loss decreased (0.004510 --> 0.004510).  Saving model ...


  4%|████▉                                                                                                                           | 195/5000 [11:17<4:47:53,  3.59s/it]

[195] train loss: 0.004509 
[195] [94mvalidate[0m loss: 0.005275
Validation loss decreased (0.004510 --> 0.004509).  Saving model ...


  4%|█████                                                                                                                           | 196/5000 [11:20<4:35:08,  3.44s/it]

[196] train loss: 0.004510 
[196] [94mvalidate[0m loss: 0.005296
EarlyStopping counter: 1 out of 40


  4%|█████                                                                                                                           | 197/5000 [11:24<4:38:10,  3.47s/it]

[197] train loss: 0.004513 
[197] [94mvalidate[0m loss: 0.005271
EarlyStopping counter: 2 out of 40


  4%|█████                                                                                                                           | 198/5000 [11:28<4:46:38,  3.58s/it]

[198] train loss: 0.004507 
[198] [94mvalidate[0m loss: 0.005271
Validation loss decreased (0.004509 --> 0.004507).  Saving model ...


  4%|█████                                                                                                                           | 199/5000 [11:31<4:47:07,  3.59s/it]

[199] train loss: 0.004494 
[199] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004507 --> 0.004494).  Saving model ...


  4%|█████                                                                                                                           | 200/5000 [11:35<4:53:10,  3.66s/it]

[200] train loss: 0.004495 
[200] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  4%|█████▏                                                                                                                          | 201/5000 [11:39<4:51:48,  3.65s/it]

[201] train loss: 0.004494 
[201] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004494 --> 0.004494).  Saving model ...


  4%|█████▏                                                                                                                          | 202/5000 [11:43<4:50:25,  3.63s/it]

[202] train loss: 0.004494 
[202] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  4%|█████▏                                                                                                                          | 203/5000 [11:46<4:49:07,  3.62s/it]

[203] train loss: 0.004493 
[203] [94mvalidate[0m loss: 0.005267
Validation loss decreased (0.004494 --> 0.004493).  Saving model ...


  4%|█████▏                                                                                                                          | 204/5000 [11:50<4:47:00,  3.59s/it]

[204] train loss: 0.004494 
[204] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  4%|█████▏                                                                                                                          | 205/5000 [11:53<4:43:24,  3.55s/it]

[205] train loss: 0.004494 
[205] [94mvalidate[0m loss: 0.005264
EarlyStopping counter: 2 out of 40


  4%|█████▎                                                                                                                          | 206/5000 [11:57<4:41:46,  3.53s/it]

[206] train loss: 0.004493 
[206] [94mvalidate[0m loss: 0.005264
EarlyStopping counter: 3 out of 40


  4%|█████▎                                                                                                                          | 207/5000 [12:00<4:48:20,  3.61s/it]

[207] train loss: 0.004494 
[207] [94mvalidate[0m loss: 0.005268
EarlyStopping counter: 4 out of 40


  4%|█████▎                                                                                                                          | 208/5000 [12:04<4:47:49,  3.60s/it]

[208] train loss: 0.004494 
[208] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  4%|█████▎                                                                                                                          | 209/5000 [12:07<4:45:50,  3.58s/it]

[209] train loss: 0.004494 
[209] [94mvalidate[0m loss: 0.005267
EarlyStopping counter: 6 out of 40


  4%|█████▍                                                                                                                          | 210/5000 [12:11<4:36:32,  3.46s/it]

[210] train loss: 0.004494 
[210] [94mvalidate[0m loss: 0.005269
EarlyStopping counter: 7 out of 40


  4%|█████▍                                                                                                                          | 211/5000 [12:14<4:40:01,  3.51s/it]

[211] train loss: 0.004494 
[211] [94mvalidate[0m loss: 0.005267
EarlyStopping counter: 8 out of 40


  4%|█████▍                                                                                                                          | 212/5000 [12:18<4:43:22,  3.55s/it]

[212] train loss: 0.004492 
[212] [94mvalidate[0m loss: 0.005266
Validation loss decreased (0.004493 --> 0.004492).  Saving model ...


  4%|█████▍                                                                                                                          | 213/5000 [12:22<4:44:02,  3.56s/it]

[213] train loss: 0.004492 
[213] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004492 --> 0.004492).  Saving model ...


  4%|█████▍                                                                                                                          | 214/5000 [12:25<4:39:29,  3.50s/it]

[214] train loss: 0.004491 
[214] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004492 --> 0.004491).  Saving model ...


  4%|█████▌                                                                                                                          | 215/5000 [12:28<4:28:51,  3.37s/it]

[215] train loss: 0.004491 
[215] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  4%|█████▌                                                                                                                          | 216/5000 [12:32<4:33:23,  3.43s/it]

[216] train loss: 0.004492 
[216] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  4%|█████▌                                                                                                                          | 217/5000 [12:35<4:34:34,  3.44s/it]

[217] train loss: 0.004492 
[217] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  4%|█████▌                                                                                                                          | 218/5000 [12:38<4:33:40,  3.43s/it]

[218] train loss: 0.004491 
[218] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  4%|█████▌                                                                                                                          | 219/5000 [12:42<4:37:43,  3.49s/it]

[219] train loss: 0.004492 
[219] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  4%|█████▋                                                                                                                          | 220/5000 [12:46<4:42:44,  3.55s/it]

[220] train loss: 0.004491 
[220] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  4%|█████▋                                                                                                                          | 221/5000 [12:49<4:43:14,  3.56s/it]

[221] train loss: 0.004492 
[221] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  4%|█████▋                                                                                                                          | 222/5000 [12:53<4:44:15,  3.57s/it]

[222] train loss: 0.004491 
[222] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  4%|█████▋                                                                                                                          | 223/5000 [12:56<4:41:08,  3.53s/it]

[223] train loss: 0.004491 
[223] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  4%|█████▋                                                                                                                          | 224/5000 [13:00<4:42:01,  3.54s/it]

[224] train loss: 0.004491 
[224] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  4%|█████▊                                                                                                                          | 225/5000 [13:03<4:29:20,  3.38s/it]

[225] train loss: 0.004491 
[225] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  5%|█████▊                                                                                                                          | 226/5000 [13:06<4:28:42,  3.38s/it]

[226] train loss: 0.004491 
[226] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  5%|█████▊                                                                                                                          | 227/5000 [13:10<4:32:24,  3.42s/it]

[227] train loss: 0.004492 
[227] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  5%|█████▊                                                                                                                          | 228/5000 [13:14<4:47:39,  3.62s/it]

[228] train loss: 0.004492 
[228] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  5%|█████▊                                                                                                                          | 229/5000 [13:17<4:45:18,  3.59s/it]

[229] train loss: 0.004491 
[229] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  5%|█████▉                                                                                                                          | 230/5000 [13:21<4:44:25,  3.58s/it]

[230] train loss: 0.004492 
[230] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 12 out of 40


  5%|█████▉                                                                                                                          | 231/5000 [13:24<4:42:18,  3.55s/it]

[231] train loss: 0.004492 
[231] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 13 out of 40


  5%|█████▉                                                                                                                          | 232/5000 [13:28<4:44:05,  3.57s/it]

[232] train loss: 0.004491 
[232] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 14 out of 40


  5%|█████▉                                                                                                                          | 233/5000 [13:31<4:40:34,  3.53s/it]

[233] train loss: 0.004492 
[233] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 15 out of 40


  5%|█████▉                                                                                                                          | 234/5000 [13:35<4:45:11,  3.59s/it]

[234] train loss: 0.004491 
[234] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 16 out of 40


  5%|██████                                                                                                                          | 235/5000 [13:39<4:43:10,  3.57s/it]

[235] train loss: 0.004491 
[235] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 17 out of 40


  5%|██████                                                                                                                          | 236/5000 [13:42<4:41:01,  3.54s/it]

[236] train loss: 0.004492 
[236] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 18 out of 40


  5%|██████                                                                                                                          | 237/5000 [13:46<4:40:28,  3.53s/it]

[237] train loss: 0.004491 
[237] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 19 out of 40


  5%|██████                                                                                                                          | 238/5000 [13:49<4:36:18,  3.48s/it]

[238] train loss: 0.004492 
[238] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 20 out of 40


  5%|██████                                                                                                                          | 239/5000 [13:52<4:33:57,  3.45s/it]

[239] train loss: 0.004492 
[239] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 21 out of 40


  5%|██████▏                                                                                                                         | 240/5000 [13:56<4:38:14,  3.51s/it]

[240] train loss: 0.004492 
[240] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 22 out of 40


  5%|██████▏                                                                                                                         | 241/5000 [14:00<4:36:16,  3.48s/it]

[241] train loss: 0.004492 
[241] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 23 out of 40


  5%|██████▏                                                                                                                         | 242/5000 [14:03<4:39:23,  3.52s/it]

[242] train loss: 0.004492 
[242] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 24 out of 40


  5%|██████▏                                                                                                                         | 243/5000 [14:07<4:39:05,  3.52s/it]

[243] train loss: 0.004492 
[243] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 25 out of 40


  5%|██████▏                                                                                                                         | 244/5000 [14:10<4:36:42,  3.49s/it]

[244] train loss: 0.004491 
[244] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 26 out of 40


  5%|██████▎                                                                                                                         | 245/5000 [14:14<4:39:19,  3.52s/it]

[245] train loss: 0.004492 
[245] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 27 out of 40


  5%|██████▎                                                                                                                         | 246/5000 [14:17<4:43:59,  3.58s/it]

[246] train loss: 0.004492 
[246] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 28 out of 40


  5%|██████▎                                                                                                                         | 247/5000 [14:21<4:41:03,  3.55s/it]

[247] train loss: 0.004492 
[247] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 29 out of 40


  5%|██████▎                                                                                                                         | 248/5000 [14:24<4:39:58,  3.54s/it]

[248] train loss: 0.004491 
[248] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 30 out of 40


  5%|██████▎                                                                                                                         | 249/5000 [14:28<4:44:13,  3.59s/it]

[249] train loss: 0.004491 
[249] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  5%|██████▍                                                                                                                         | 250/5000 [14:32<4:48:21,  3.64s/it]

[250] train loss: 0.004492 
[250] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  5%|██████▍                                                                                                                         | 251/5000 [14:35<4:47:31,  3.63s/it]

[251] train loss: 0.004491 
[251] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  5%|██████▍                                                                                                                         | 252/5000 [14:39<4:44:03,  3.59s/it]

[252] train loss: 0.004492 
[252] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  5%|██████▍                                                                                                                         | 253/5000 [14:43<4:42:58,  3.58s/it]

[253] train loss: 0.004491 
[253] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  5%|██████▌                                                                                                                         | 254/5000 [14:46<4:45:43,  3.61s/it]

[254] train loss: 0.004491 
[254] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  5%|██████▌                                                                                                                         | 255/5000 [14:50<4:40:07,  3.54s/it]

[255] train loss: 0.004491 
[255] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  5%|██████▌                                                                                                                         | 256/5000 [14:53<4:38:17,  3.52s/it]

[256] train loss: 0.004492 
[256] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  5%|██████▌                                                                                                                         | 257/5000 [14:57<4:44:47,  3.60s/it]

[257] train loss: 0.004491 
[257] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  5%|██████▌                                                                                                                         | 258/5000 [15:01<4:52:27,  3.70s/it]

[258] train loss: 0.004491 
[258] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  5%|██████▋                                                                                                                         | 259/5000 [15:05<4:59:33,  3.79s/it]

[259] train loss: 0.004492 
[259] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  5%|██████▋                                                                                                                         | 260/5000 [15:09<5:04:55,  3.86s/it]

[260] train loss: 0.004491 
[260] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  5%|██████▋                                                                                                                         | 261/5000 [15:13<5:06:14,  3.88s/it]

[261] train loss: 0.004491 
[261] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 12 out of 40


  5%|██████▋                                                                                                                         | 262/5000 [15:17<5:07:13,  3.89s/it]

[262] train loss: 0.004491 
[262] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 13 out of 40


  5%|██████▋                                                                                                                         | 263/5000 [15:21<5:07:58,  3.90s/it]

[263] train loss: 0.004492 
[263] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 14 out of 40


  5%|██████▊                                                                                                                         | 264/5000 [15:24<5:04:43,  3.86s/it]

[264] train loss: 0.004491 
[264] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  5%|██████▊                                                                                                                         | 265/5000 [15:28<5:10:09,  3.93s/it]

[265] train loss: 0.004491 
[265] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  5%|██████▊                                                                                                                         | 266/5000 [15:32<4:59:46,  3.80s/it]

[266] train loss: 0.004491 
[266] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  5%|██████▊                                                                                                                         | 267/5000 [15:36<5:02:05,  3.83s/it]

[267] train loss: 0.004491 
[267] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  5%|██████▊                                                                                                                         | 268/5000 [15:39<4:57:01,  3.77s/it]

[268] train loss: 0.004492 
[268] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  5%|██████▉                                                                                                                         | 269/5000 [15:43<4:58:52,  3.79s/it]

[269] train loss: 0.004491 
[269] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  5%|██████▉                                                                                                                         | 270/5000 [15:47<5:02:32,  3.84s/it]

[270] train loss: 0.004491 
[270] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  5%|██████▉                                                                                                                         | 271/5000 [15:51<5:06:11,  3.88s/it]

[271] train loss: 0.004491 
[271] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  5%|██████▉                                                                                                                         | 272/5000 [15:55<5:14:06,  3.99s/it]

[272] train loss: 0.004491 
[272] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  5%|██████▉                                                                                                                         | 273/5000 [15:59<5:09:00,  3.92s/it]

[273] train loss: 0.004492 
[273] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  5%|███████                                                                                                                         | 274/5000 [16:03<5:07:10,  3.90s/it]

[274] train loss: 0.004491 
[274] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  6%|███████                                                                                                                         | 275/5000 [16:07<5:10:54,  3.95s/it]

[275] train loss: 0.004492 
[275] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  6%|███████                                                                                                                         | 276/5000 [16:11<5:11:19,  3.95s/it]

[276] train loss: 0.004491 
[276] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  6%|███████                                                                                                                         | 277/5000 [16:15<5:10:18,  3.94s/it]

[277] train loss: 0.004492 
[277] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  6%|███████                                                                                                                         | 278/5000 [16:19<5:04:21,  3.87s/it]

[278] train loss: 0.004491 
[278] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  6%|███████▏                                                                                                                        | 279/5000 [16:22<4:59:47,  3.81s/it]

[279] train loss: 0.004492 
[279] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  6%|███████▏                                                                                                                        | 280/5000 [16:26<4:57:17,  3.78s/it]

[280] train loss: 0.004491 
[280] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  6%|███████▏                                                                                                                        | 281/5000 [16:30<5:03:01,  3.85s/it]

[281] train loss: 0.004491 
[281] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  6%|███████▏                                                                                                                        | 282/5000 [16:34<4:58:24,  3.79s/it]

[282] train loss: 0.004492 
[282] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  6%|███████▏                                                                                                                        | 283/5000 [16:37<4:53:16,  3.73s/it]

[283] train loss: 0.004491 
[283] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  6%|███████▎                                                                                                                        | 284/5000 [16:42<5:04:05,  3.87s/it]

[284] train loss: 0.004491 
[284] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  6%|███████▎                                                                                                                        | 285/5000 [16:45<5:04:51,  3.88s/it]

[285] train loss: 0.004492 
[285] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  6%|███████▎                                                                                                                        | 286/5000 [16:49<5:05:20,  3.89s/it]

[286] train loss: 0.004492 
[286] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  6%|███████▎                                                                                                                        | 287/5000 [16:53<5:06:40,  3.90s/it]

[287] train loss: 0.004491 
[287] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  6%|███████▎                                                                                                                        | 288/5000 [16:57<5:02:29,  3.85s/it]

[288] train loss: 0.004491 
[288] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  6%|███████▍                                                                                                                        | 289/5000 [17:01<4:57:33,  3.79s/it]

[289] train loss: 0.004491 
[289] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  6%|███████▍                                                                                                                        | 290/5000 [17:05<4:59:10,  3.81s/it]

[290] train loss: 0.004491 
[290] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  6%|███████▍                                                                                                                        | 291/5000 [17:08<5:00:57,  3.83s/it]

[291] train loss: 0.004491 
[291] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  6%|███████▍                                                                                                                        | 292/5000 [17:12<5:04:26,  3.88s/it]

[292] train loss: 0.004491 
[292] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  6%|███████▌                                                                                                                        | 293/5000 [17:16<5:05:14,  3.89s/it]

[293] train loss: 0.004492 
[293] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  6%|███████▌                                                                                                                        | 294/5000 [17:20<5:02:41,  3.86s/it]

[294] train loss: 0.004491 
[294] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  6%|███████▌                                                                                                                        | 295/5000 [17:24<5:01:11,  3.84s/it]

[295] train loss: 0.004491 
[295] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  6%|███████▌                                                                                                                        | 296/5000 [17:28<5:00:11,  3.83s/it]

[296] train loss: 0.004492 
[296] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  6%|███████▌                                                                                                                        | 297/5000 [17:32<5:15:10,  4.02s/it]

[297] train loss: 0.004491 
[297] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  6%|███████▋                                                                                                                        | 298/5000 [17:36<5:15:22,  4.02s/it]

[298] train loss: 0.004491 
[298] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  6%|███████▋                                                                                                                        | 299/5000 [17:40<5:15:22,  4.03s/it]

[299] train loss: 0.004491 
[299] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  6%|███████▋                                                                                                                        | 300/5000 [17:44<5:08:14,  3.93s/it]

[300] train loss: 0.004492 
[300] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 12 out of 40


  6%|███████▋                                                                                                                        | 301/5000 [17:46<4:34:39,  3.51s/it]

[301] train loss: 0.004492 
[301] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 13 out of 40


  6%|███████▋                                                                                                                        | 302/5000 [17:50<4:45:19,  3.64s/it]

[302] train loss: 0.004491 
[302] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 14 out of 40


  6%|███████▊                                                                                                                        | 303/5000 [17:54<4:46:23,  3.66s/it]

[303] train loss: 0.004491 
[303] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 15 out of 40


  6%|███████▊                                                                                                                        | 304/5000 [17:57<4:23:18,  3.36s/it]

[304] train loss: 0.004491 
[304] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 16 out of 40


  6%|███████▊                                                                                                                        | 305/5000 [18:01<4:46:35,  3.66s/it]

[305] train loss: 0.004491 
[305] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 17 out of 40


  6%|███████▊                                                                                                                        | 306/5000 [18:05<4:50:37,  3.71s/it]

[306] train loss: 0.004491 
[306] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 18 out of 40


  6%|███████▊                                                                                                                        | 307/5000 [18:09<4:56:48,  3.79s/it]

[307] train loss: 0.004491 
[307] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 19 out of 40


  6%|███████▉                                                                                                                        | 308/5000 [18:13<5:00:55,  3.85s/it]

[308] train loss: 0.004491 
[308] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 20 out of 40


  6%|███████▉                                                                                                                        | 309/5000 [18:17<5:01:42,  3.86s/it]

[309] train loss: 0.004492 
[309] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 21 out of 40


  6%|███████▉                                                                                                                        | 310/5000 [18:21<5:03:12,  3.88s/it]

[310] train loss: 0.004491 
[310] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 22 out of 40


  6%|███████▉                                                                                                                        | 311/5000 [18:24<4:57:28,  3.81s/it]

[311] train loss: 0.004491 
[311] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 23 out of 40


  6%|███████▉                                                                                                                        | 312/5000 [18:28<4:54:59,  3.78s/it]

[312] train loss: 0.004491 
[312] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 24 out of 40


  6%|████████                                                                                                                        | 313/5000 [18:32<4:54:32,  3.77s/it]

[313] train loss: 0.004491 
[313] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 25 out of 40


  6%|████████                                                                                                                        | 314/5000 [18:36<4:57:36,  3.81s/it]

[314] train loss: 0.004491 
[314] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 26 out of 40


  6%|████████                                                                                                                        | 315/5000 [18:40<5:03:49,  3.89s/it]

[315] train loss: 0.004491 
[315] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 27 out of 40


  6%|████████                                                                                                                        | 316/5000 [18:44<4:59:55,  3.84s/it]

[316] train loss: 0.004491 
[316] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  6%|████████                                                                                                                        | 317/5000 [18:47<4:56:24,  3.80s/it]

[317] train loss: 0.004491 
[317] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  6%|████████▏                                                                                                                       | 318/5000 [18:49<4:15:15,  3.27s/it]

[318] train loss: 0.004492 
[318] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  6%|████████▏                                                                                                                       | 319/5000 [18:51<3:46:37,  2.90s/it]

[319] train loss: 0.004491 
[319] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  6%|████████▏                                                                                                                       | 320/5000 [18:54<3:39:16,  2.81s/it]

[320] train loss: 0.004492 
[320] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  6%|████████▏                                                                                                                       | 321/5000 [18:58<4:07:16,  3.17s/it]

[321] train loss: 0.004491 
[321] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  6%|████████▏                                                                                                                       | 322/5000 [19:02<4:28:21,  3.44s/it]

[322] train loss: 0.004491 
[322] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  6%|████████▎                                                                                                                       | 323/5000 [19:06<4:39:22,  3.58s/it]

[323] train loss: 0.004491 
[323] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  6%|████████▎                                                                                                                       | 324/5000 [19:10<4:48:17,  3.70s/it]

[324] train loss: 0.004491 
[324] [94mvalidate[0m loss: 0.005265
Validation loss decreased (0.004491 --> 0.004491).  Saving model ...


  6%|████████▎                                                                                                                       | 325/5000 [19:14<4:50:44,  3.73s/it]

[325] train loss: 0.004491 
[325] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 1 out of 40


  7%|████████▎                                                                                                                       | 326/5000 [19:18<4:59:45,  3.85s/it]

[326] train loss: 0.004491 
[326] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 2 out of 40


  7%|████████▎                                                                                                                       | 327/5000 [19:22<4:56:57,  3.81s/it]

[327] train loss: 0.004491 
[327] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 3 out of 40


  7%|████████▍                                                                                                                       | 328/5000 [19:26<5:00:21,  3.86s/it]

[328] train loss: 0.004491 
[328] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 4 out of 40


  7%|████████▍                                                                                                                       | 329/5000 [19:30<5:13:01,  4.02s/it]

[329] train loss: 0.004491 
[329] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 5 out of 40


  7%|████████▍                                                                                                                       | 330/5000 [19:34<5:12:24,  4.01s/it]

[330] train loss: 0.004491 
[330] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 6 out of 40


  7%|████████▍                                                                                                                       | 331/5000 [19:38<5:10:28,  3.99s/it]

[331] train loss: 0.004491 
[331] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 7 out of 40


  7%|████████▍                                                                                                                       | 332/5000 [19:42<5:11:32,  4.00s/it]

[332] train loss: 0.004491 
[332] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 8 out of 40


  7%|████████▌                                                                                                                       | 333/5000 [19:46<5:10:24,  3.99s/it]

[333] train loss: 0.004491 
[333] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 9 out of 40


  7%|████████▌                                                                                                                       | 334/5000 [19:50<5:12:36,  4.02s/it]

[334] train loss: 0.004491 
[334] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 10 out of 40


  7%|████████▌                                                                                                                       | 335/5000 [19:54<5:05:59,  3.94s/it]

[335] train loss: 0.004491 
[335] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 11 out of 40


  7%|████████▌                                                                                                                       | 336/5000 [19:58<5:05:29,  3.93s/it]

[336] train loss: 0.004491 
[336] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 12 out of 40


  7%|████████▋                                                                                                                       | 337/5000 [20:02<5:16:37,  4.07s/it]

[337] train loss: 0.004491 
[337] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 13 out of 40


  7%|████████▋                                                                                                                       | 338/5000 [20:06<5:13:37,  4.04s/it]

[338] train loss: 0.004491 
[338] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 14 out of 40


  7%|████████▋                                                                                                                       | 339/5000 [20:10<5:13:59,  4.04s/it]

[339] train loss: 0.004491 
[339] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 15 out of 40


  7%|████████▋                                                                                                                       | 340/5000 [20:14<5:11:28,  4.01s/it]

[340] train loss: 0.004491 
[340] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 16 out of 40


  7%|████████▋                                                                                                                       | 341/5000 [20:18<5:12:04,  4.02s/it]

[341] train loss: 0.004491 
[341] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 17 out of 40


  7%|████████▊                                                                                                                       | 342/5000 [20:22<5:06:00,  3.94s/it]

[342] train loss: 0.004491 
[342] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 18 out of 40


  7%|████████▊                                                                                                                       | 343/5000 [20:26<5:02:14,  3.89s/it]

[343] train loss: 0.004491 
[343] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 19 out of 40


  7%|████████▊                                                                                                                       | 344/5000 [20:30<5:04:39,  3.93s/it]

[344] train loss: 0.004491 
[344] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 20 out of 40


  7%|████████▊                                                                                                                       | 345/5000 [20:34<5:05:47,  3.94s/it]

[345] train loss: 0.004491 
[345] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 21 out of 40


  7%|████████▊                                                                                                                       | 346/5000 [20:38<5:13:07,  4.04s/it]

[346] train loss: 0.004491 
[346] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 22 out of 40


  7%|████████▉                                                                                                                       | 347/5000 [20:42<5:12:33,  4.03s/it]

[347] train loss: 0.004491 
[347] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 23 out of 40


  7%|████████▉                                                                                                                       | 348/5000 [20:46<5:09:33,  3.99s/it]

[348] train loss: 0.004491 
[348] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 24 out of 40


  7%|████████▉                                                                                                                       | 349/5000 [20:50<5:09:53,  4.00s/it]

[349] train loss: 0.004491 
[349] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 25 out of 40


  7%|████████▉                                                                                                                       | 350/5000 [20:54<5:08:26,  3.98s/it]

[350] train loss: 0.004491 
[350] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 26 out of 40


  7%|████████▉                                                                                                                       | 351/5000 [20:58<5:09:54,  4.00s/it]

[351] train loss: 0.004491 
[351] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 27 out of 40


  7%|█████████                                                                                                                       | 352/5000 [21:02<5:08:41,  3.98s/it]

[352] train loss: 0.004491 
[352] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 28 out of 40


  7%|█████████                                                                                                                       | 353/5000 [21:06<5:06:15,  3.95s/it]

[353] train loss: 0.004491 
[353] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 29 out of 40


  7%|█████████                                                                                                                       | 354/5000 [21:09<5:01:10,  3.89s/it]

[354] train loss: 0.004491 
[354] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 30 out of 40


  7%|█████████                                                                                                                       | 355/5000 [21:13<5:01:48,  3.90s/it]

[355] train loss: 0.004491 
[355] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 31 out of 40


  7%|█████████                                                                                                                       | 356/5000 [21:17<4:55:55,  3.82s/it]

[356] train loss: 0.004491 
[356] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 32 out of 40


  7%|█████████▏                                                                                                                      | 357/5000 [21:21<4:57:27,  3.84s/it]

[357] train loss: 0.004491 
[357] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 33 out of 40


  7%|█████████▏                                                                                                                      | 358/5000 [21:25<4:57:43,  3.85s/it]

[358] train loss: 0.004491 
[358] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 34 out of 40


  7%|█████████▏                                                                                                                      | 359/5000 [21:28<4:58:31,  3.86s/it]

[359] train loss: 0.004491 
[359] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 35 out of 40


  7%|█████████▏                                                                                                                      | 360/5000 [21:33<5:05:58,  3.96s/it]

[360] train loss: 0.004491 
[360] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 36 out of 40


  7%|█████████▏                                                                                                                      | 361/5000 [21:37<5:03:27,  3.92s/it]

[361] train loss: 0.004491 
[361] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 37 out of 40


  7%|█████████▎                                                                                                                      | 362/5000 [21:40<4:59:41,  3.88s/it]

[362] train loss: 0.004491 
[362] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 38 out of 40


  7%|█████████▎                                                                                                                      | 363/5000 [21:44<5:02:10,  3.91s/it]

[363] train loss: 0.004491 
[363] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 39 out of 40


  7%|█████████▎                                                                                                                      | 363/5000 [21:48<4:38:35,  3.60s/it]

[364] train loss: 0.004491 
[364] [94mvalidate[0m loss: 0.005265
EarlyStopping counter: 40 out of 40
Early stopping





R2 of test is:  0.9247267271910061
Test set results for 11309 samples:
MSE: 0.0049529984
MAE: 0.042119555
MSE loss on test set is: 3.893242470174962e-05
time used to train model with 40/1000 patience is:  21.9355857372672 mins


# get prediction on training/experiment data

In [12]:
def predict_on_test(config): 
    
    data_dir = 'datasets/train/'
    exp_cscl_dir = 'datasets/exp/rutitle_cscl'
    exp_licl_dir = 'datasets/exp/rutitle_LiCl'

    # get training dataset 
    train_set, _, _ = load_data(exp_cscl_dir)
    # train_set, _, _ = load_data(exp_licl_dir)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=int(config["batch_size"]),
        shuffle=True)
            
    print("Creating model")
    Model = getattr(models, config['model'])
    print('created model is: ', Model)
    
    if config['model'] == 'DeepNet4LayerTune': 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                      config["l1"], config["l2"], config["l3"],)
    elif config['model'] == 'DeepNet5LayerTune': 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                     config["l1"], config["l2"], config["l3"], config["l4"])
    else: 
        model = Model(config['batch_norm'], config['layer_norm'], config['constraint'],
                     config["l1"], config["l2"], config["l3"], config["l4"], config["l5"])
        
    name = f"{config['model']}_171inputs_{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['batch_norm']}ln{config['layer_norm']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}"
    name1 = f"cscl_{config['model']}_171inputs_{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['batch_norm']}ln{config['layer_norm']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}"
   
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cuda")
    # if torch.cuda.device_count() > 1:
    #     model = nn.DataParallel(model)
    model.to(device)

    optimizer, lr_scheduler = build_optimizer(model, config['optimizer'], config['lr'])
    model.to(device)

    res_dir = 'simpleDNN/res/'
    
    checkpoint_dir = os.path.join(res_dir, 'checkpoints')
    checkpoint_path = os.path.join(checkpoint_dir, f'{name}.pt')

    model.load_state_dict(torch.load(checkpoint_path))

    # test on test set 
    test_accuracy(model, train_loader, res_dir, name1, device)   

In [13]:
import time  
start_time = time.perf_counter()
predict_on_test(config)
end_time = time.perf_counter() 
print('time used to train model with 20/250 patience is: ', (end_time - start_time), 'secs')

Creating model
created model is:  <class 'model.net.DeepNet5LayerTune'>
R2 of test is:  nan
Test set results for 1 samples:
MSE: 0.11755475
MAE: 0.29246107
MSE loss on test set is: 0.11755474656820297
time used to train model with 20/250 patience is:  0.028642237186431885 secs


