In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import pandas as pd
import numpy as np
import pickle
import torch
from torch import nn
import yaml
from pathlib import Path
from collections import OrderedDict

import proxyClassifier
import perturbation_model
from proxyClassifier import proxy_clf_network, proxy_clf

from perturbation_model import perturb_network, perturb_clf

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

try:
    file_path = os.path.realpath(__file__)
    base_path =  os.path.dirname(basepath)
except:
    base_path = './'
CONFIG_FILE = os.path.join(base_path, 'config.yaml')
ID_COL = 'PanjivaRecordID'


# Globals
DIR = None
Data_Dir = None
Anom_Dir = None
data_count = None 
clf_batch_size = None
clf_model_save_dir = None 
perturbModel_train_epochs = None
def setup_config(_dir):
    global DIR, Data_Dir, Anom_Dir
    global CONFIG_FILE
    global data_count
    global clf_batch_size
    global clf_model_save_dir
    global base_path
    global perturbModel_train_epochs
    
    with open(CONFIG_FILE,'r') as fh:
        CONFIG = yaml.safe_load(fh)
    Data_Dir = CONFIG['data_loc']
    Anom_Dir = CONFIG['anom_data_path']
    DIR = _dir
    data_count = CONFIG['data_count']
    clf_batch_size = CONFIG['clf_batch_size']
    clf_model_save_dir = os.path.join(base_path, './clf_model_save_dir', DIR)
    Path(clf_model_save_dir).mkdir(exist_ok=True, parents=True)
    Path(clf_model_save_dir).mkdir(exist_ok=True, parents=True)
    perturbModel_train_epochs = CONFIG['perturbModel_train_epochs']
    return

setup_config('us_import1')

In [2]:
def get_domain_dims():
    global DIR, Data_Dir
    with open(os.path.join(Data_Dir, DIR, 'domain_dims.pkl'), 'rb') as fh:
        domain_dims = OrderedDict(pickle.load(fh))
    return domain_dims

In [3]:
def proxy_training_data():
    global DIR, Data_Dir, Anom_Dir, data_points, ID_COL, data_count
    # normal_instances
    path_normal = os.path.join(Data_Dir, DIR, 'train_data.csv')
    df_normal = pd.read_csv(path_normal,index_col=None)
    
    path_anom = os.path.join(Anom_Dir, DIR, 'anomalies_1_1.csv')
    df_anomalies_1 = pd.read_csv(path_normal,index_col=None)
    
    path_anom = os.path.join(Anom_Dir, DIR, 'anomalies_1_2.csv')
    df_anomalies_2 = pd.read_csv(path_normal,index_col=None)
    
    df_anomalies = df_anomalies_1.append(df_anomalies_2,ignore_index=True)
    
    
    if data_count > 0:
        if len(df_anomalies) < data_count:
            df_anomalies = df_anomalies.sample(n = data_count, replace=True)
        else:
            df_anomalies = df_anomalies.sample(n = data_count)

        df_normal = df_normal.sample(n = data_count) 
        
    # create X, Y
    Y1 = np.ones([len(df_normal)])
    Y2 = np.zeros([len(df_anomalies)])
    
    try: del df_normal[ID_COL]
    except: pass
    try: del df_anomalies[ID_COL]
    except: pass
    X1 = df_normal.values
    X2 = df_anomalies.values
    
    X = np.vstack([X1,X2])
    Y = np.hstack([Y1,Y2])            
    
    return X,Y
    

In [4]:
def get_proxy_clf():
    """
    Train model  and store
    """
    global clf_batch_size, DIR
    global DEVICE
    global clf_model_save_dir
    domain_dims = get_domain_dims()
    
    network = proxy_clf_network(
        list(domain_dims.values())
    )
    clf_obj = proxy_clf(
            model = network,
            dataset = DIR, 
            save_dir = clf_model_save_dir,
            batch_size= clf_batch_size,
            device = DEVICE    
        )
    
    try:
        clf_obj.load_model()
    except:
        
        train_X, train_Y = proxy_training_data()
        clf_obj.fit(train_X, train_Y.reshape([-1,1]), num_epochs=200)
        clf_obj.save_model()
    
    return clf_obj

In [5]:
clf_obj = get_proxy_clf()

Device cuda:0


In [6]:
def train_pert_model():
    global DEVICE, DIR
    domain_dims = get_domain_dims()
    clf_obj = get_proxy_clf()
    clf_network = clf_obj.model
    
    perturb_network_obj  = perturb_network(
        list(domain_dims.values())
    )
    
    perturb_clf_obj = perturb_clf(
        perturb_model = perturb_network_obj,
        clf_model = clf_network,
        dataset = DIR,
        device = DEVICE
    )
    # Train the model 
    
    return perturb_clf_obj

In [7]:
perturb_clf_obj = train_pert_model()

Device cuda:0


In [8]:
train_X, train_Y = proxy_training_data()

train_X.shape

train_Y.shape

(20000,)

In [9]:
loss_values = perturb_clf_obj.fit(X = train_X[-1000:], Y = np.abs(1 - train_Y[-1000:]).reshape([-1,1]), num_epochs = perturbModel_train_epochs)

  1%|          | 10/1000 [00:00<00:36, 26.96it/s]

[Epoch] 0  | batch 0 | Loss 1.748657
[Epoch] 1  | batch 0 | Loss 1.781845
[Epoch] 2  | batch 0 | Loss 1.983910
[Epoch] 3  | batch 0 | Loss 1.955962
[Epoch] 4  | batch 0 | Loss 1.920985
[Epoch] 5  | batch 0 | Loss 2.030923
[Epoch] 6  | batch 0 | Loss 2.083749
[Epoch] 7  | batch 0 | Loss 2.339351
[Epoch] 8  | batch 0 | Loss 2.233174
[Epoch] 9  | batch 0 | Loss 2.075572
[Epoch] 10  | batch 0 | Loss 1.855489
[Epoch] 11  | batch 0 | Loss 2.308291
[Epoch] 12  | batch 0 | Loss 2.140818
[Epoch] 13  | batch 0 | Loss 1.858302
[Epoch] 14  | batch 0 | Loss 2.053906
[Epoch] 15  | batch 0 | Loss 2.384487
[Epoch] 16  | batch 0 | Loss 2.008908


  3%|▎         | 26/1000 [00:00<00:18, 52.49it/s]

[Epoch] 17  | batch 0 | Loss 1.966647
[Epoch] 18  | batch 0 | Loss 2.172150
[Epoch] 19  | batch 0 | Loss 2.151794
[Epoch] 20  | batch 0 | Loss 2.093338
[Epoch] 21  | batch 0 | Loss 1.764663
[Epoch] 22  | batch 0 | Loss 1.881285
[Epoch] 23  | batch 0 | Loss 1.833573
[Epoch] 24  | batch 0 | Loss 2.060652
[Epoch] 25  | batch 0 | Loss 2.118948
[Epoch] 26  | batch 0 | Loss 1.873251
[Epoch] 27  | batch 0 | Loss 1.937286
[Epoch] 28  | batch 0 | Loss 1.927108
[Epoch] 29  | batch 0 | Loss 2.081978
[Epoch] 30  | batch 0 | Loss 1.942453
[Epoch] 31  | batch 0 | Loss 1.958050
[Epoch] 32  | batch 0 | Loss 1.630450
[Epoch] 33  | batch 0 | Loss 1.924700


  4%|▍         | 44/1000 [00:00<00:14, 67.87it/s]

[Epoch] 34  | batch 0 | Loss 2.284096
[Epoch] 35  | batch 0 | Loss 1.876199
[Epoch] 36  | batch 0 | Loss 2.122646
[Epoch] 37  | batch 0 | Loss 1.954559
[Epoch] 38  | batch 0 | Loss 2.249629
[Epoch] 39  | batch 0 | Loss 2.000776
[Epoch] 40  | batch 0 | Loss 1.977582
[Epoch] 41  | batch 0 | Loss 1.947464
[Epoch] 42  | batch 0 | Loss 1.755363
[Epoch] 43  | batch 0 | Loss 2.324058
[Epoch] 44  | batch 0 | Loss 1.806571
[Epoch] 45  | batch 0 | Loss 1.740412
[Epoch] 46  | batch 0 | Loss 2.067109
[Epoch] 47  | batch 0 | Loss 2.318245
[Epoch] 48  | batch 0 | Loss 2.216286
[Epoch] 49  | batch 0 | Loss 1.805020
[Epoch] 50  | batch 0 | Loss 2.093611


  6%|▌         | 61/1000 [00:01<00:12, 74.19it/s]

[Epoch] 51  | batch 0 | Loss 1.451122
[Epoch] 52  | batch 0 | Loss 1.839589
[Epoch] 53  | batch 0 | Loss 1.877357
[Epoch] 54  | batch 0 | Loss 2.143271
[Epoch] 55  | batch 0 | Loss 2.006726
[Epoch] 56  | batch 0 | Loss 2.115196
[Epoch] 57  | batch 0 | Loss 2.121379
[Epoch] 58  | batch 0 | Loss 2.112063
[Epoch] 59  | batch 0 | Loss 1.885364
[Epoch] 60  | batch 0 | Loss 1.938189
[Epoch] 61  | batch 0 | Loss 1.711291
[Epoch] 62  | batch 0 | Loss 1.959112
[Epoch] 63  | batch 0 | Loss 1.567789
[Epoch] 64  | batch 0 | Loss 2.096354
[Epoch] 65  | batch 0 | Loss 2.272190
[Epoch] 66  | batch 0 | Loss 1.643578
[Epoch] 67  | batch 0 | Loss 2.283795


  8%|▊         | 79/1000 [00:01<00:11, 78.19it/s]

[Epoch] 68  | batch 0 | Loss 2.717146
[Epoch] 69  | batch 0 | Loss 2.062598
[Epoch] 70  | batch 0 | Loss 1.691499
[Epoch] 71  | batch 0 | Loss 2.309451
[Epoch] 72  | batch 0 | Loss 1.829885
[Epoch] 73  | batch 0 | Loss 2.269151
[Epoch] 74  | batch 0 | Loss 2.135408
[Epoch] 75  | batch 0 | Loss 1.667966
[Epoch] 76  | batch 0 | Loss 1.749386
[Epoch] 77  | batch 0 | Loss 1.937741
[Epoch] 78  | batch 0 | Loss 1.726317
[Epoch] 79  | batch 0 | Loss 2.128721
[Epoch] 80  | batch 0 | Loss 1.872778
[Epoch] 81  | batch 0 | Loss 2.230427
[Epoch] 82  | batch 0 | Loss 1.910013
[Epoch] 83  | batch 0 | Loss 1.955326
[Epoch] 84  | batch 0 | Loss 1.793359


 10%|▉         | 97/1000 [00:01<00:11, 80.07it/s]

[Epoch] 85  | batch 0 | Loss 1.899038
[Epoch] 86  | batch 0 | Loss 1.709635
[Epoch] 87  | batch 0 | Loss 2.293711
[Epoch] 88  | batch 0 | Loss 2.317099
[Epoch] 89  | batch 0 | Loss 1.870958
[Epoch] 90  | batch 0 | Loss 1.909663
[Epoch] 91  | batch 0 | Loss 1.726702
[Epoch] 92  | batch 0 | Loss 1.882707
[Epoch] 93  | batch 0 | Loss 1.849147
[Epoch] 94  | batch 0 | Loss 1.717803
[Epoch] 95  | batch 0 | Loss 2.146178
[Epoch] 96  | batch 0 | Loss 1.970565
[Epoch] 97  | batch 0 | Loss 1.983007
[Epoch] 98  | batch 0 | Loss 2.138079
[Epoch] 99  | batch 0 | Loss 2.312295
[Epoch] 100  | batch 0 | Loss 1.563949
[Epoch] 101  | batch 0 | Loss 2.324348


 12%|█▏        | 115/1000 [00:01<00:10, 80.70it/s]

[Epoch] 102  | batch 0 | Loss 2.458604
[Epoch] 103  | batch 0 | Loss 1.764577
[Epoch] 104  | batch 0 | Loss 1.952659
[Epoch] 105  | batch 0 | Loss 1.739787
[Epoch] 106  | batch 0 | Loss 1.968844
[Epoch] 107  | batch 0 | Loss 2.143780
[Epoch] 108  | batch 0 | Loss 1.786493
[Epoch] 109  | batch 0 | Loss 1.976324
[Epoch] 110  | batch 0 | Loss 2.031797
[Epoch] 111  | batch 0 | Loss 1.926511
[Epoch] 112  | batch 0 | Loss 2.347185
[Epoch] 113  | batch 0 | Loss 2.079448
[Epoch] 114  | batch 0 | Loss 1.708286
[Epoch] 115  | batch 0 | Loss 2.077670
[Epoch] 116  | batch 0 | Loss 2.396915
[Epoch] 117  | batch 0 | Loss 2.079053
[Epoch] 118  | batch 0 | Loss 1.637566


 13%|█▎        | 133/1000 [00:01<00:10, 81.09it/s]

[Epoch] 119  | batch 0 | Loss 2.270658
[Epoch] 120  | batch 0 | Loss 1.519185
[Epoch] 121  | batch 0 | Loss 1.872205
[Epoch] 122  | batch 0 | Loss 1.885167
[Epoch] 123  | batch 0 | Loss 2.056708
[Epoch] 124  | batch 0 | Loss 2.049770
[Epoch] 125  | batch 0 | Loss 1.873973
[Epoch] 126  | batch 0 | Loss 1.784019
[Epoch] 127  | batch 0 | Loss 1.911100
[Epoch] 128  | batch 0 | Loss 2.156027
[Epoch] 129  | batch 0 | Loss 1.875283
[Epoch] 130  | batch 0 | Loss 1.873952
[Epoch] 131  | batch 0 | Loss 1.831684
[Epoch] 132  | batch 0 | Loss 1.642106
[Epoch] 133  | batch 0 | Loss 1.927003
[Epoch] 134  | batch 0 | Loss 2.409115
[Epoch] 135  | batch 0 | Loss 2.153103


 15%|█▌        | 151/1000 [00:02<00:10, 82.43it/s]

[Epoch] 136  | batch 0 | Loss 2.110378
[Epoch] 137  | batch 0 | Loss 2.492218
[Epoch] 138  | batch 0 | Loss 2.181613
[Epoch] 139  | batch 0 | Loss 1.798515
[Epoch] 140  | batch 0 | Loss 1.710870
[Epoch] 141  | batch 0 | Loss 1.717960
[Epoch] 142  | batch 0 | Loss 2.282274
[Epoch] 143  | batch 0 | Loss 1.757974
[Epoch] 144  | batch 0 | Loss 1.697410
[Epoch] 145  | batch 0 | Loss 1.815692
[Epoch] 146  | batch 0 | Loss 2.194921
[Epoch] 147  | batch 0 | Loss 2.088825
[Epoch] 148  | batch 0 | Loss 1.486365
[Epoch] 149  | batch 0 | Loss 2.017501
[Epoch] 150  | batch 0 | Loss 1.797462
[Epoch] 151  | batch 0 | Loss 2.138112
[Epoch] 152  | batch 0 | Loss 2.131607


 17%|█▋        | 169/1000 [00:02<00:10, 82.86it/s]

[Epoch] 153  | batch 0 | Loss 2.243758
[Epoch] 154  | batch 0 | Loss 2.017107
[Epoch] 155  | batch 0 | Loss 1.985255
[Epoch] 156  | batch 0 | Loss 1.988942
[Epoch] 157  | batch 0 | Loss 1.872020
[Epoch] 158  | batch 0 | Loss 2.408174
[Epoch] 159  | batch 0 | Loss 1.683770
[Epoch] 160  | batch 0 | Loss 1.955487
[Epoch] 161  | batch 0 | Loss 2.063469
[Epoch] 162  | batch 0 | Loss 1.808284
[Epoch] 163  | batch 0 | Loss 2.122750
[Epoch] 164  | batch 0 | Loss 2.251623
[Epoch] 165  | batch 0 | Loss 1.806134
[Epoch] 166  | batch 0 | Loss 1.868064
[Epoch] 167  | batch 0 | Loss 1.872496
[Epoch] 168  | batch 0 | Loss 2.610860
[Epoch] 169  | batch 0 | Loss 1.946749


 19%|█▊        | 187/1000 [00:02<00:09, 82.97it/s]

[Epoch] 170  | batch 0 | Loss 1.850542
[Epoch] 171  | batch 0 | Loss 2.012417
[Epoch] 172  | batch 0 | Loss 1.911162
[Epoch] 173  | batch 0 | Loss 2.176709
[Epoch] 174  | batch 0 | Loss 1.849375
[Epoch] 175  | batch 0 | Loss 2.067003
[Epoch] 176  | batch 0 | Loss 2.102139
[Epoch] 177  | batch 0 | Loss 1.803343
[Epoch] 178  | batch 0 | Loss 2.004902
[Epoch] 179  | batch 0 | Loss 2.173371
[Epoch] 180  | batch 0 | Loss 1.590119
[Epoch] 181  | batch 0 | Loss 2.017120
[Epoch] 182  | batch 0 | Loss 1.745449
[Epoch] 183  | batch 0 | Loss 2.113071
[Epoch] 184  | batch 0 | Loss 2.112973
[Epoch] 185  | batch 0 | Loss 2.039044
[Epoch] 186  | batch 0 | Loss 1.923859


 20%|█▉        | 196/1000 [00:02<00:09, 83.03it/s]

[Epoch] 187  | batch 0 | Loss 1.671047
[Epoch] 188  | batch 0 | Loss 2.072486
[Epoch] 189  | batch 0 | Loss 1.950631
[Epoch] 190  | batch 0 | Loss 1.976294
[Epoch] 191  | batch 0 | Loss 2.228884
[Epoch] 192  | batch 0 | Loss 1.625286
[Epoch] 193  | batch 0 | Loss 2.257868
[Epoch] 194  | batch 0 | Loss 2.019538
[Epoch] 195  | batch 0 | Loss 1.796678
[Epoch] 196  | batch 0 | Loss 2.138266
[Epoch] 197  | batch 0 | Loss 2.084387
[Epoch] 198  | batch 0 | Loss 1.892818
[Epoch] 199  | batch 0 | Loss 1.887782
[Epoch] 200  | batch 0 | Loss 2.054664
[Epoch] 201  | batch 0 | Loss 1.789881
[Epoch] 202  | batch 0 | Loss 2.133473
[Epoch] 203  | batch 0 | Loss 1.977168


 21%|██▏       | 214/1000 [00:02<00:09, 83.13it/s]

[Epoch] 204  | batch 0 | Loss 1.752879
[Epoch] 205  | batch 0 | Loss 2.117376
[Epoch] 206  | batch 0 | Loss 1.935241
[Epoch] 207  | batch 0 | Loss 1.999562
[Epoch] 208  | batch 0 | Loss 2.029188
[Epoch] 209  | batch 0 | Loss 1.456415
[Epoch] 210  | batch 0 | Loss 1.720277
[Epoch] 211  | batch 0 | Loss 2.162321
[Epoch] 212  | batch 0 | Loss 2.222676
[Epoch] 213  | batch 0 | Loss 2.344067
[Epoch] 214  | batch 0 | Loss 2.103444
[Epoch] 215  | batch 0 | Loss 1.820099
[Epoch] 216  | batch 0 | Loss 2.137732
[Epoch] 217  | batch 0 | Loss 2.065006
[Epoch] 218  | batch 0 | Loss 1.697846
[Epoch] 219  | batch 0 | Loss 2.206142
[Epoch] 220  | batch 0 | Loss 1.930061


 23%|██▎       | 232/1000 [00:03<00:09, 82.78it/s]

[Epoch] 221  | batch 0 | Loss 1.974499
[Epoch] 222  | batch 0 | Loss 1.843815
[Epoch] 223  | batch 0 | Loss 2.103714
[Epoch] 224  | batch 0 | Loss 1.825418
[Epoch] 225  | batch 0 | Loss 2.051914
[Epoch] 226  | batch 0 | Loss 2.251514
[Epoch] 227  | batch 0 | Loss 2.013901
[Epoch] 228  | batch 0 | Loss 1.815034
[Epoch] 229  | batch 0 | Loss 1.641472
[Epoch] 230  | batch 0 | Loss 2.275643
[Epoch] 231  | batch 0 | Loss 1.903701
[Epoch] 232  | batch 0 | Loss 2.160292
[Epoch] 233  | batch 0 | Loss 2.182992
[Epoch] 234  | batch 0 | Loss 2.102954
[Epoch] 235  | batch 0 | Loss 2.008776
[Epoch] 236  | batch 0 | Loss 2.047871
[Epoch] 237  | batch 0 | Loss 1.814857


 25%|██▌       | 250/1000 [00:03<00:09, 83.18it/s]

[Epoch] 238  | batch 0 | Loss 2.133303
[Epoch] 239  | batch 0 | Loss 2.135068
[Epoch] 240  | batch 0 | Loss 1.856194
[Epoch] 241  | batch 0 | Loss 1.834746
[Epoch] 242  | batch 0 | Loss 2.381221
[Epoch] 243  | batch 0 | Loss 1.867001
[Epoch] 244  | batch 0 | Loss 1.933463
[Epoch] 245  | batch 0 | Loss 1.975540
[Epoch] 246  | batch 0 | Loss 2.041499
[Epoch] 247  | batch 0 | Loss 1.905690
[Epoch] 248  | batch 0 | Loss 2.530994
[Epoch] 249  | batch 0 | Loss 1.999953
[Epoch] 250  | batch 0 | Loss 2.166004
[Epoch] 251  | batch 0 | Loss 1.861163
[Epoch] 252  | batch 0 | Loss 1.836852
[Epoch] 253  | batch 0 | Loss 2.032548
[Epoch] 254  | batch 0 | Loss 2.235283


 27%|██▋       | 268/1000 [00:03<00:08, 82.86it/s]

[Epoch] 255  | batch 0 | Loss 1.960178
[Epoch] 256  | batch 0 | Loss 1.911451
[Epoch] 257  | batch 0 | Loss 1.695372
[Epoch] 258  | batch 0 | Loss 1.844858
[Epoch] 259  | batch 0 | Loss 1.951168
[Epoch] 260  | batch 0 | Loss 2.077029
[Epoch] 261  | batch 0 | Loss 2.009527
[Epoch] 262  | batch 0 | Loss 1.866753
[Epoch] 263  | batch 0 | Loss 2.332707
[Epoch] 264  | batch 0 | Loss 2.261576
[Epoch] 265  | batch 0 | Loss 1.882490
[Epoch] 266  | batch 0 | Loss 1.953040
[Epoch] 267  | batch 0 | Loss 1.986447
[Epoch] 268  | batch 0 | Loss 1.577904
[Epoch] 269  | batch 0 | Loss 2.135493
[Epoch] 270  | batch 0 | Loss 1.868956
[Epoch] 271  | batch 0 | Loss 1.769852


 29%|██▊       | 286/1000 [00:03<00:08, 82.36it/s]

[Epoch] 272  | batch 0 | Loss 2.099360
[Epoch] 273  | batch 0 | Loss 2.169394
[Epoch] 274  | batch 0 | Loss 1.712707
[Epoch] 275  | batch 0 | Loss 1.975949
[Epoch] 276  | batch 0 | Loss 1.770981
[Epoch] 277  | batch 0 | Loss 1.931496
[Epoch] 278  | batch 0 | Loss 1.918034
[Epoch] 279  | batch 0 | Loss 1.768835
[Epoch] 280  | batch 0 | Loss 2.028892
[Epoch] 281  | batch 0 | Loss 1.903781
[Epoch] 282  | batch 0 | Loss 1.804347
[Epoch] 283  | batch 0 | Loss 2.052777
[Epoch] 284  | batch 0 | Loss 2.280676
[Epoch] 285  | batch 0 | Loss 1.974956
[Epoch] 286  | batch 0 | Loss 2.017184
[Epoch] 287  | batch 0 | Loss 1.883153
[Epoch] 288  | batch 0 | Loss 1.892027


 30%|███       | 304/1000 [00:04<00:08, 82.93it/s]

[Epoch] 289  | batch 0 | Loss 2.042597
[Epoch] 290  | batch 0 | Loss 1.875266
[Epoch] 291  | batch 0 | Loss 1.854175
[Epoch] 292  | batch 0 | Loss 1.945483
[Epoch] 293  | batch 0 | Loss 1.953383
[Epoch] 294  | batch 0 | Loss 2.142420
[Epoch] 295  | batch 0 | Loss 1.988806
[Epoch] 296  | batch 0 | Loss 2.595652
[Epoch] 297  | batch 0 | Loss 1.994991
[Epoch] 298  | batch 0 | Loss 2.271371
[Epoch] 299  | batch 0 | Loss 2.274279
[Epoch] 300  | batch 0 | Loss 1.834236
[Epoch] 301  | batch 0 | Loss 1.901433
[Epoch] 302  | batch 0 | Loss 2.062754
[Epoch] 303  | batch 0 | Loss 1.720899
[Epoch] 304  | batch 0 | Loss 1.839523
[Epoch] 305  | batch 0 | Loss 1.858471


 32%|███▏      | 322/1000 [00:04<00:08, 82.67it/s]

[Epoch] 306  | batch 0 | Loss 1.937052
[Epoch] 307  | batch 0 | Loss 1.935364
[Epoch] 308  | batch 0 | Loss 1.868043
[Epoch] 309  | batch 0 | Loss 2.260275
[Epoch] 310  | batch 0 | Loss 1.704201
[Epoch] 311  | batch 0 | Loss 1.868093
[Epoch] 312  | batch 0 | Loss 1.612069
[Epoch] 313  | batch 0 | Loss 1.706673
[Epoch] 314  | batch 0 | Loss 2.282494
[Epoch] 315  | batch 0 | Loss 1.770469
[Epoch] 316  | batch 0 | Loss 1.703729
[Epoch] 317  | batch 0 | Loss 1.847435
[Epoch] 318  | batch 0 | Loss 1.782204
[Epoch] 319  | batch 0 | Loss 2.067195
[Epoch] 320  | batch 0 | Loss 2.118377
[Epoch] 321  | batch 0 | Loss 2.054613
[Epoch] 322  | batch 0 | Loss 1.857064


 34%|███▍      | 340/1000 [00:04<00:07, 82.54it/s]

[Epoch] 323  | batch 0 | Loss 2.031380
[Epoch] 324  | batch 0 | Loss 1.904765
[Epoch] 325  | batch 0 | Loss 2.074281
[Epoch] 326  | batch 0 | Loss 2.267247
[Epoch] 327  | batch 0 | Loss 1.833533
[Epoch] 328  | batch 0 | Loss 2.017368
[Epoch] 329  | batch 0 | Loss 1.811884
[Epoch] 330  | batch 0 | Loss 1.897895
[Epoch] 331  | batch 0 | Loss 2.031035
[Epoch] 332  | batch 0 | Loss 2.137798
[Epoch] 333  | batch 0 | Loss 1.986897
[Epoch] 334  | batch 0 | Loss 2.092024
[Epoch] 335  | batch 0 | Loss 2.187255
[Epoch] 336  | batch 0 | Loss 1.725961
[Epoch] 337  | batch 0 | Loss 2.148142
[Epoch] 338  | batch 0 | Loss 2.154917
[Epoch] 339  | batch 0 | Loss 2.231649


 35%|███▍      | 349/1000 [00:04<00:07, 82.61it/s]

[Epoch] 340  | batch 0 | Loss 2.016230
[Epoch] 341  | batch 0 | Loss 2.030120
[Epoch] 342  | batch 0 | Loss 2.178538
[Epoch] 343  | batch 0 | Loss 1.549946
[Epoch] 344  | batch 0 | Loss 1.955418
[Epoch] 345  | batch 0 | Loss 2.330142
[Epoch] 346  | batch 0 | Loss 1.816039
[Epoch] 347  | batch 0 | Loss 1.974024
[Epoch] 348  | batch 0 | Loss 2.151011
[Epoch] 349  | batch 0 | Loss 1.981670
[Epoch] 350  | batch 0 | Loss 1.882183
[Epoch] 351  | batch 0 | Loss 2.200510
[Epoch] 352  | batch 0 | Loss 2.090150
[Epoch] 353  | batch 0 | Loss 1.846971
[Epoch] 354  | batch 0 | Loss 2.217379
[Epoch] 355  | batch 0 | Loss 1.892644
[Epoch] 356  | batch 0 | Loss 1.999570


 37%|███▋      | 367/1000 [00:04<00:07, 82.47it/s]

[Epoch] 357  | batch 0 | Loss 1.750506
[Epoch] 358  | batch 0 | Loss 2.465342
[Epoch] 359  | batch 0 | Loss 1.738786
[Epoch] 360  | batch 0 | Loss 2.098226
[Epoch] 361  | batch 0 | Loss 1.778460
[Epoch] 362  | batch 0 | Loss 2.255298
[Epoch] 363  | batch 0 | Loss 2.104490
[Epoch] 364  | batch 0 | Loss 1.982795
[Epoch] 365  | batch 0 | Loss 2.221485
[Epoch] 366  | batch 0 | Loss 2.008557
[Epoch] 367  | batch 0 | Loss 2.068733
[Epoch] 368  | batch 0 | Loss 1.838927
[Epoch] 369  | batch 0 | Loss 2.112319
[Epoch] 370  | batch 0 | Loss 2.043877
[Epoch] 371  | batch 0 | Loss 2.072080
[Epoch] 372  | batch 0 | Loss 1.591829
[Epoch] 373  | batch 0 | Loss 2.108844


 38%|███▊      | 385/1000 [00:05<00:07, 82.88it/s]

[Epoch] 374  | batch 0 | Loss 2.220548
[Epoch] 375  | batch 0 | Loss 1.676273
[Epoch] 376  | batch 0 | Loss 2.143736
[Epoch] 377  | batch 0 | Loss 2.083928
[Epoch] 378  | batch 0 | Loss 1.787039
[Epoch] 379  | batch 0 | Loss 1.661496
[Epoch] 380  | batch 0 | Loss 1.890863
[Epoch] 381  | batch 0 | Loss 2.075548
[Epoch] 382  | batch 0 | Loss 1.930327
[Epoch] 383  | batch 0 | Loss 2.063570
[Epoch] 384  | batch 0 | Loss 1.992084
[Epoch] 385  | batch 0 | Loss 2.110180
[Epoch] 386  | batch 0 | Loss 1.855297
[Epoch] 387  | batch 0 | Loss 2.167447
[Epoch] 388  | batch 0 | Loss 1.902871
[Epoch] 389  | batch 0 | Loss 1.880473
[Epoch] 390  | batch 0 | Loss 1.947962


 40%|████      | 403/1000 [00:05<00:07, 82.07it/s]

[Epoch] 391  | batch 0 | Loss 1.926915
[Epoch] 392  | batch 0 | Loss 1.918618
[Epoch] 393  | batch 0 | Loss 1.998243
[Epoch] 394  | batch 0 | Loss 1.998692
[Epoch] 395  | batch 0 | Loss 1.755305
[Epoch] 396  | batch 0 | Loss 1.806629
[Epoch] 397  | batch 0 | Loss 1.994242
[Epoch] 398  | batch 0 | Loss 2.203270
[Epoch] 399  | batch 0 | Loss 2.086180
[Epoch] 400  | batch 0 | Loss 2.098683
[Epoch] 401  | batch 0 | Loss 2.088295
[Epoch] 402  | batch 0 | Loss 2.459814
[Epoch] 403  | batch 0 | Loss 2.090033
[Epoch] 404  | batch 0 | Loss 2.148889
[Epoch] 405  | batch 0 | Loss 1.759317
[Epoch] 406  | batch 0 | Loss 2.098374
[Epoch] 407  | batch 0 | Loss 1.606125


 42%|████▏     | 421/1000 [00:05<00:07, 82.64it/s]

[Epoch] 408  | batch 0 | Loss 2.126462
[Epoch] 409  | batch 0 | Loss 2.301436
[Epoch] 410  | batch 0 | Loss 1.962626
[Epoch] 411  | batch 0 | Loss 1.964153
[Epoch] 412  | batch 0 | Loss 1.862544
[Epoch] 413  | batch 0 | Loss 2.141797
[Epoch] 414  | batch 0 | Loss 2.382297
[Epoch] 415  | batch 0 | Loss 1.658584
[Epoch] 416  | batch 0 | Loss 2.103822
[Epoch] 417  | batch 0 | Loss 2.148164
[Epoch] 418  | batch 0 | Loss 1.794024
[Epoch] 419  | batch 0 | Loss 1.662351
[Epoch] 420  | batch 0 | Loss 2.141808
[Epoch] 421  | batch 0 | Loss 2.003173
[Epoch] 422  | batch 0 | Loss 1.884444
[Epoch] 423  | batch 0 | Loss 2.227367
[Epoch] 424  | batch 0 | Loss 2.099643


 44%|████▍     | 439/1000 [00:05<00:06, 82.92it/s]

[Epoch] 425  | batch 0 | Loss 1.999591
[Epoch] 426  | batch 0 | Loss 2.329503
[Epoch] 427  | batch 0 | Loss 2.070854
[Epoch] 428  | batch 0 | Loss 1.810045
[Epoch] 429  | batch 0 | Loss 1.912023
[Epoch] 430  | batch 0 | Loss 1.884983
[Epoch] 431  | batch 0 | Loss 1.962653
[Epoch] 432  | batch 0 | Loss 1.926894
[Epoch] 433  | batch 0 | Loss 1.720513
[Epoch] 434  | batch 0 | Loss 1.623028
[Epoch] 435  | batch 0 | Loss 1.807090
[Epoch] 436  | batch 0 | Loss 1.873561
[Epoch] 437  | batch 0 | Loss 2.272169
[Epoch] 438  | batch 0 | Loss 2.135341
[Epoch] 439  | batch 0 | Loss 2.065232
[Epoch] 440  | batch 0 | Loss 2.055897
[Epoch] 441  | batch 0 | Loss 1.838704


 46%|████▌     | 457/1000 [00:05<00:06, 83.11it/s]

[Epoch] 442  | batch 0 | Loss 2.093066
[Epoch] 443  | batch 0 | Loss 2.264799
[Epoch] 444  | batch 0 | Loss 1.994846
[Epoch] 445  | batch 0 | Loss 1.927626
[Epoch] 446  | batch 0 | Loss 2.245879
[Epoch] 447  | batch 0 | Loss 2.029412
[Epoch] 448  | batch 0 | Loss 1.947220
[Epoch] 449  | batch 0 | Loss 1.855797
[Epoch] 450  | batch 0 | Loss 2.211923
[Epoch] 451  | batch 0 | Loss 1.907754
[Epoch] 452  | batch 0 | Loss 2.031994
[Epoch] 453  | batch 0 | Loss 2.181577
[Epoch] 454  | batch 0 | Loss 1.624923
[Epoch] 455  | batch 0 | Loss 2.201126
[Epoch] 456  | batch 0 | Loss 2.174565
[Epoch] 457  | batch 0 | Loss 2.251524
[Epoch] 458  | batch 0 | Loss 2.155684


 48%|████▊     | 475/1000 [00:06<00:06, 83.33it/s]

[Epoch] 459  | batch 0 | Loss 1.855381
[Epoch] 460  | batch 0 | Loss 1.872049
[Epoch] 461  | batch 0 | Loss 1.917573
[Epoch] 462  | batch 0 | Loss 1.905231
[Epoch] 463  | batch 0 | Loss 2.032131
[Epoch] 464  | batch 0 | Loss 2.115196
[Epoch] 465  | batch 0 | Loss 1.870785
[Epoch] 466  | batch 0 | Loss 1.643039
[Epoch] 467  | batch 0 | Loss 2.141801
[Epoch] 468  | batch 0 | Loss 2.063557
[Epoch] 469  | batch 0 | Loss 1.770820
[Epoch] 470  | batch 0 | Loss 1.677899
[Epoch] 471  | batch 0 | Loss 2.035264
[Epoch] 472  | batch 0 | Loss 2.117711
[Epoch] 473  | batch 0 | Loss 2.087484
[Epoch] 474  | batch 0 | Loss 1.850622
[Epoch] 475  | batch 0 | Loss 1.886138


 49%|████▉     | 493/1000 [00:06<00:06, 83.34it/s]

[Epoch] 476  | batch 0 | Loss 1.830854
[Epoch] 477  | batch 0 | Loss 1.644964
[Epoch] 478  | batch 0 | Loss 1.790261
[Epoch] 479  | batch 0 | Loss 1.928570
[Epoch] 480  | batch 0 | Loss 2.158553
[Epoch] 481  | batch 0 | Loss 2.045571
[Epoch] 482  | batch 0 | Loss 1.737556
[Epoch] 483  | batch 0 | Loss 1.614038
[Epoch] 484  | batch 0 | Loss 1.989617
[Epoch] 485  | batch 0 | Loss 1.910609
[Epoch] 486  | batch 0 | Loss 2.140390
[Epoch] 487  | batch 0 | Loss 1.713220
[Epoch] 488  | batch 0 | Loss 2.209186
[Epoch] 489  | batch 0 | Loss 2.071695
[Epoch] 490  | batch 0 | Loss 1.947484
[Epoch] 491  | batch 0 | Loss 1.805833
[Epoch] 492  | batch 0 | Loss 1.902303


 50%|█████     | 502/1000 [00:06<00:05, 83.16it/s]

[Epoch] 493  | batch 0 | Loss 1.882492
[Epoch] 494  | batch 0 | Loss 1.443259
[Epoch] 495  | batch 0 | Loss 1.837534
[Epoch] 496  | batch 0 | Loss 1.897226
[Epoch] 497  | batch 0 | Loss 2.166750
[Epoch] 498  | batch 0 | Loss 1.693804
[Epoch] 499  | batch 0 | Loss 1.636480
[Epoch] 500  | batch 0 | Loss 1.978648
[Epoch] 501  | batch 0 | Loss 2.167118
[Epoch] 502  | batch 0 | Loss 2.204821
[Epoch] 503  | batch 0 | Loss 1.751161
[Epoch] 504  | batch 0 | Loss 2.134079
[Epoch] 505  | batch 0 | Loss 1.756830
[Epoch] 506  | batch 0 | Loss 1.946645
[Epoch] 507  | batch 0 | Loss 2.282290
[Epoch] 508  | batch 0 | Loss 1.474746
[Epoch] 509  | batch 0 | Loss 2.166022


 52%|█████▏    | 520/1000 [00:06<00:05, 83.18it/s]

[Epoch] 510  | batch 0 | Loss 2.138618
[Epoch] 511  | batch 0 | Loss 1.539756
[Epoch] 512  | batch 0 | Loss 2.130286
[Epoch] 513  | batch 0 | Loss 1.857362
[Epoch] 514  | batch 0 | Loss 2.237377
[Epoch] 515  | batch 0 | Loss 1.649559
[Epoch] 516  | batch 0 | Loss 1.830149
[Epoch] 517  | batch 0 | Loss 2.364978
[Epoch] 518  | batch 0 | Loss 2.039882
[Epoch] 519  | batch 0 | Loss 1.667134
[Epoch] 520  | batch 0 | Loss 1.623147
[Epoch] 521  | batch 0 | Loss 2.096159
[Epoch] 522  | batch 0 | Loss 2.171351
[Epoch] 523  | batch 0 | Loss 2.399137
[Epoch] 524  | batch 0 | Loss 2.203157
[Epoch] 525  | batch 0 | Loss 1.813021
[Epoch] 526  | batch 0 | Loss 1.713892


 54%|█████▍    | 538/1000 [00:06<00:05, 83.36it/s]

[Epoch] 527  | batch 0 | Loss 2.333028
[Epoch] 528  | batch 0 | Loss 1.696873
[Epoch] 529  | batch 0 | Loss 1.940562
[Epoch] 530  | batch 0 | Loss 1.986753
[Epoch] 531  | batch 0 | Loss 1.912041
[Epoch] 532  | batch 0 | Loss 2.258240
[Epoch] 533  | batch 0 | Loss 2.220645
[Epoch] 534  | batch 0 | Loss 1.958641
[Epoch] 535  | batch 0 | Loss 1.685529
[Epoch] 536  | batch 0 | Loss 1.765656
[Epoch] 537  | batch 0 | Loss 2.195030
[Epoch] 538  | batch 0 | Loss 1.554508
[Epoch] 539  | batch 0 | Loss 2.584184
[Epoch] 540  | batch 0 | Loss 1.982780
[Epoch] 541  | batch 0 | Loss 1.868569
[Epoch] 542  | batch 0 | Loss 2.265083
[Epoch] 543  | batch 0 | Loss 1.583647


 56%|█████▌    | 556/1000 [00:07<00:05, 83.36it/s]

[Epoch] 544  | batch 0 | Loss 2.058643
[Epoch] 545  | batch 0 | Loss 1.808492
[Epoch] 546  | batch 0 | Loss 1.972888
[Epoch] 547  | batch 0 | Loss 1.869487
[Epoch] 548  | batch 0 | Loss 2.021562
[Epoch] 549  | batch 0 | Loss 1.892815
[Epoch] 550  | batch 0 | Loss 1.809561
[Epoch] 551  | batch 0 | Loss 1.937239
[Epoch] 552  | batch 0 | Loss 1.416234
[Epoch] 553  | batch 0 | Loss 2.512812
[Epoch] 554  | batch 0 | Loss 1.741952
[Epoch] 555  | batch 0 | Loss 2.329866
[Epoch] 556  | batch 0 | Loss 1.882998
[Epoch] 557  | batch 0 | Loss 1.907738
[Epoch] 558  | batch 0 | Loss 1.858832
[Epoch] 559  | batch 0 | Loss 1.786804
[Epoch] 560  | batch 0 | Loss 2.151062


 57%|█████▋    | 574/1000 [00:07<00:05, 83.53it/s]

[Epoch] 561  | batch 0 | Loss 2.181426
[Epoch] 562  | batch 0 | Loss 1.777847
[Epoch] 563  | batch 0 | Loss 2.053069
[Epoch] 564  | batch 0 | Loss 1.864576
[Epoch] 565  | batch 0 | Loss 1.938997
[Epoch] 566  | batch 0 | Loss 1.909863
[Epoch] 567  | batch 0 | Loss 1.960932
[Epoch] 568  | batch 0 | Loss 1.973356
[Epoch] 569  | batch 0 | Loss 2.340777
[Epoch] 570  | batch 0 | Loss 2.101221
[Epoch] 571  | batch 0 | Loss 2.129456
[Epoch] 572  | batch 0 | Loss 1.827162
[Epoch] 573  | batch 0 | Loss 2.184477
[Epoch] 574  | batch 0 | Loss 1.807700
[Epoch] 575  | batch 0 | Loss 1.519848
[Epoch] 576  | batch 0 | Loss 1.984506
[Epoch] 577  | batch 0 | Loss 1.967035


 59%|█████▉    | 592/1000 [00:07<00:04, 83.59it/s]

[Epoch] 578  | batch 0 | Loss 2.457854
[Epoch] 579  | batch 0 | Loss 1.843225
[Epoch] 580  | batch 0 | Loss 2.203477
[Epoch] 581  | batch 0 | Loss 2.289524
[Epoch] 582  | batch 0 | Loss 1.920924
[Epoch] 583  | batch 0 | Loss 1.642537
[Epoch] 584  | batch 0 | Loss 2.380550
[Epoch] 585  | batch 0 | Loss 1.964257
[Epoch] 586  | batch 0 | Loss 2.042834
[Epoch] 587  | batch 0 | Loss 1.836321
[Epoch] 588  | batch 0 | Loss 1.816379
[Epoch] 589  | batch 0 | Loss 2.034388
[Epoch] 590  | batch 0 | Loss 2.102288
[Epoch] 591  | batch 0 | Loss 1.902202
[Epoch] 592  | batch 0 | Loss 1.823181
[Epoch] 593  | batch 0 | Loss 2.161340
[Epoch] 594  | batch 0 | Loss 2.202232


 61%|██████    | 610/1000 [00:07<00:04, 83.30it/s]

[Epoch] 595  | batch 0 | Loss 1.749441
[Epoch] 596  | batch 0 | Loss 1.961434
[Epoch] 597  | batch 0 | Loss 2.217582
[Epoch] 598  | batch 0 | Loss 1.707799
[Epoch] 599  | batch 0 | Loss 1.811428
[Epoch] 600  | batch 0 | Loss 2.113304
[Epoch] 601  | batch 0 | Loss 2.022297
[Epoch] 602  | batch 0 | Loss 1.788321
[Epoch] 603  | batch 0 | Loss 1.893065
[Epoch] 604  | batch 0 | Loss 1.954483
[Epoch] 605  | batch 0 | Loss 1.713464
[Epoch] 606  | batch 0 | Loss 1.782069
[Epoch] 607  | batch 0 | Loss 1.906011
[Epoch] 608  | batch 0 | Loss 1.830875
[Epoch] 609  | batch 0 | Loss 1.745914
[Epoch] 610  | batch 0 | Loss 2.029189
[Epoch] 611  | batch 0 | Loss 2.021531


 63%|██████▎   | 628/1000 [00:07<00:04, 83.50it/s]

[Epoch] 612  | batch 0 | Loss 1.974208
[Epoch] 613  | batch 0 | Loss 1.938184
[Epoch] 614  | batch 0 | Loss 1.932230
[Epoch] 615  | batch 0 | Loss 1.991701
[Epoch] 616  | batch 0 | Loss 1.897905
[Epoch] 617  | batch 0 | Loss 1.821106
[Epoch] 618  | batch 0 | Loss 2.169694
[Epoch] 619  | batch 0 | Loss 2.104090
[Epoch] 620  | batch 0 | Loss 2.365367
[Epoch] 621  | batch 0 | Loss 1.930265
[Epoch] 622  | batch 0 | Loss 1.937601
[Epoch] 623  | batch 0 | Loss 1.962141
[Epoch] 624  | batch 0 | Loss 2.329744
[Epoch] 625  | batch 0 | Loss 1.689151
[Epoch] 626  | batch 0 | Loss 1.846013
[Epoch] 627  | batch 0 | Loss 1.855799
[Epoch] 628  | batch 0 | Loss 2.198356


 65%|██████▍   | 646/1000 [00:08<00:04, 83.69it/s]

[Epoch] 629  | batch 0 | Loss 2.015197
[Epoch] 630  | batch 0 | Loss 1.794880
[Epoch] 631  | batch 0 | Loss 2.093280
[Epoch] 632  | batch 0 | Loss 1.903625
[Epoch] 633  | batch 0 | Loss 1.951229
[Epoch] 634  | batch 0 | Loss 2.426749
[Epoch] 635  | batch 0 | Loss 1.534991
[Epoch] 636  | batch 0 | Loss 2.047996
[Epoch] 637  | batch 0 | Loss 1.728201
[Epoch] 638  | batch 0 | Loss 2.100443
[Epoch] 639  | batch 0 | Loss 1.904452
[Epoch] 640  | batch 0 | Loss 2.087887
[Epoch] 641  | batch 0 | Loss 2.040543
[Epoch] 642  | batch 0 | Loss 2.135787
[Epoch] 643  | batch 0 | Loss 1.667621
[Epoch] 644  | batch 0 | Loss 2.156522
[Epoch] 645  | batch 0 | Loss 2.167798


 66%|██████▌   | 655/1000 [00:08<00:04, 83.58it/s]

[Epoch] 646  | batch 0 | Loss 1.792618
[Epoch] 647  | batch 0 | Loss 2.281659
[Epoch] 648  | batch 0 | Loss 2.209342
[Epoch] 649  | batch 0 | Loss 2.146027
[Epoch] 650  | batch 0 | Loss 2.142230
[Epoch] 651  | batch 0 | Loss 2.231529
[Epoch] 652  | batch 0 | Loss 1.952319
[Epoch] 653  | batch 0 | Loss 1.920014
[Epoch] 654  | batch 0 | Loss 2.046208
[Epoch] 655  | batch 0 | Loss 1.762725
[Epoch] 656  | batch 0 | Loss 1.910794
[Epoch] 657  | batch 0 | Loss 2.017651
[Epoch] 658  | batch 0 | Loss 1.934461
[Epoch] 659  | batch 0 | Loss 1.867219
[Epoch] 660  | batch 0 | Loss 1.547850
[Epoch] 661  | batch 0 | Loss 1.945541
[Epoch] 662  | batch 0 | Loss 1.826193


 67%|██████▋   | 673/1000 [00:08<00:03, 83.43it/s]

[Epoch] 663  | batch 0 | Loss 2.103679
[Epoch] 664  | batch 0 | Loss 1.825929
[Epoch] 665  | batch 0 | Loss 2.004604
[Epoch] 666  | batch 0 | Loss 1.924148
[Epoch] 667  | batch 0 | Loss 1.949780
[Epoch] 668  | batch 0 | Loss 2.023520
[Epoch] 669  | batch 0 | Loss 1.945501
[Epoch] 670  | batch 0 | Loss 1.884032
[Epoch] 671  | batch 0 | Loss 1.921133
[Epoch] 672  | batch 0 | Loss 2.000339
[Epoch] 673  | batch 0 | Loss 2.286472
[Epoch] 674  | batch 0 | Loss 2.067083
[Epoch] 675  | batch 0 | Loss 1.939597
[Epoch] 676  | batch 0 | Loss 1.940124
[Epoch] 677  | batch 0 | Loss 2.019231
[Epoch] 678  | batch 0 | Loss 1.356766
[Epoch] 679  | batch 0 | Loss 2.007332


 69%|██████▉   | 691/1000 [00:08<00:03, 83.37it/s]

[Epoch] 680  | batch 0 | Loss 1.865920
[Epoch] 681  | batch 0 | Loss 2.231432
[Epoch] 682  | batch 0 | Loss 1.989669
[Epoch] 683  | batch 0 | Loss 1.952666
[Epoch] 684  | batch 0 | Loss 1.941571
[Epoch] 685  | batch 0 | Loss 1.708428
[Epoch] 686  | batch 0 | Loss 1.835149
[Epoch] 687  | batch 0 | Loss 1.817199
[Epoch] 688  | batch 0 | Loss 2.324153
[Epoch] 689  | batch 0 | Loss 2.047696
[Epoch] 690  | batch 0 | Loss 2.060220
[Epoch] 691  | batch 0 | Loss 2.135280
[Epoch] 692  | batch 0 | Loss 2.210482
[Epoch] 693  | batch 0 | Loss 2.037083
[Epoch] 694  | batch 0 | Loss 1.768288
[Epoch] 695  | batch 0 | Loss 1.933173
[Epoch] 696  | batch 0 | Loss 1.821604


 71%|███████   | 709/1000 [00:08<00:03, 83.47it/s]

[Epoch] 697  | batch 0 | Loss 2.168729
[Epoch] 698  | batch 0 | Loss 1.587413
[Epoch] 699  | batch 0 | Loss 2.263803
[Epoch] 700  | batch 0 | Loss 1.991989
[Epoch] 701  | batch 0 | Loss 2.002326
[Epoch] 702  | batch 0 | Loss 2.221361
[Epoch] 703  | batch 0 | Loss 1.779596
[Epoch] 704  | batch 0 | Loss 1.918155
[Epoch] 705  | batch 0 | Loss 1.921922
[Epoch] 706  | batch 0 | Loss 1.906023
[Epoch] 707  | batch 0 | Loss 1.715503
[Epoch] 708  | batch 0 | Loss 2.147840
[Epoch] 709  | batch 0 | Loss 2.139563
[Epoch] 710  | batch 0 | Loss 1.974888
[Epoch] 711  | batch 0 | Loss 2.135487
[Epoch] 712  | batch 0 | Loss 2.420989
[Epoch] 713  | batch 0 | Loss 1.774047


 73%|███████▎  | 727/1000 [00:09<00:03, 83.36it/s]

[Epoch] 714  | batch 0 | Loss 1.865473
[Epoch] 715  | batch 0 | Loss 1.956617
[Epoch] 716  | batch 0 | Loss 1.754615
[Epoch] 717  | batch 0 | Loss 2.234873
[Epoch] 718  | batch 0 | Loss 1.756542
[Epoch] 719  | batch 0 | Loss 2.130489
[Epoch] 720  | batch 0 | Loss 1.912764
[Epoch] 721  | batch 0 | Loss 2.051014
[Epoch] 722  | batch 0 | Loss 1.724262
[Epoch] 723  | batch 0 | Loss 1.676129
[Epoch] 724  | batch 0 | Loss 1.581601
[Epoch] 725  | batch 0 | Loss 1.990598
[Epoch] 726  | batch 0 | Loss 2.134917
[Epoch] 727  | batch 0 | Loss 1.966460
[Epoch] 728  | batch 0 | Loss 2.666321
[Epoch] 729  | batch 0 | Loss 1.911049
[Epoch] 730  | batch 0 | Loss 2.070142


 74%|███████▍  | 745/1000 [00:09<00:03, 81.89it/s]

[Epoch] 731  | batch 0 | Loss 2.196235
[Epoch] 732  | batch 0 | Loss 2.012949
[Epoch] 733  | batch 0 | Loss 1.911001
[Epoch] 734  | batch 0 | Loss 2.000778
[Epoch] 735  | batch 0 | Loss 2.215657
[Epoch] 736  | batch 0 | Loss 1.858157
[Epoch] 737  | batch 0 | Loss 2.068646
[Epoch] 738  | batch 0 | Loss 1.950631
[Epoch] 739  | batch 0 | Loss 1.956011
[Epoch] 740  | batch 0 | Loss 2.172298
[Epoch] 741  | batch 0 | Loss 1.820642
[Epoch] 742  | batch 0 | Loss 1.850482
[Epoch] 743  | batch 0 | Loss 2.133536
[Epoch] 744  | batch 0 | Loss 1.911100
[Epoch] 745  | batch 0 | Loss 2.172663
[Epoch] 746  | batch 0 | Loss 1.728150
[Epoch] 747  | batch 0 | Loss 1.927811


 76%|███████▋  | 763/1000 [00:09<00:02, 82.59it/s]

[Epoch] 748  | batch 0 | Loss 2.110931
[Epoch] 749  | batch 0 | Loss 2.102181
[Epoch] 750  | batch 0 | Loss 1.926842
[Epoch] 751  | batch 0 | Loss 1.963519
[Epoch] 752  | batch 0 | Loss 2.064060
[Epoch] 753  | batch 0 | Loss 2.127843
[Epoch] 754  | batch 0 | Loss 2.006841
[Epoch] 755  | batch 0 | Loss 2.065466
[Epoch] 756  | batch 0 | Loss 2.088121
[Epoch] 757  | batch 0 | Loss 2.273294
[Epoch] 758  | batch 0 | Loss 2.238426
[Epoch] 759  | batch 0 | Loss 1.973683
[Epoch] 760  | batch 0 | Loss 2.172070
[Epoch] 761  | batch 0 | Loss 1.879751
[Epoch] 762  | batch 0 | Loss 2.095701
[Epoch] 763  | batch 0 | Loss 2.303127
[Epoch] 764  | batch 0 | Loss 2.133508


 78%|███████▊  | 781/1000 [00:09<00:02, 83.00it/s]

[Epoch] 765  | batch 0 | Loss 1.891112
[Epoch] 766  | batch 0 | Loss 1.824162
[Epoch] 767  | batch 0 | Loss 2.265513
[Epoch] 768  | batch 0 | Loss 2.340797
[Epoch] 769  | batch 0 | Loss 1.766320
[Epoch] 770  | batch 0 | Loss 1.945146
[Epoch] 771  | batch 0 | Loss 2.000959
[Epoch] 772  | batch 0 | Loss 1.931042
[Epoch] 773  | batch 0 | Loss 1.861336
[Epoch] 774  | batch 0 | Loss 1.876599
[Epoch] 775  | batch 0 | Loss 1.856299
[Epoch] 776  | batch 0 | Loss 1.855459
[Epoch] 777  | batch 0 | Loss 2.123008
[Epoch] 778  | batch 0 | Loss 1.837310
[Epoch] 779  | batch 0 | Loss 2.133786
[Epoch] 780  | batch 0 | Loss 1.681384
[Epoch] 781  | batch 0 | Loss 1.975272


 80%|███████▉  | 799/1000 [00:10<00:02, 82.93it/s]

[Epoch] 782  | batch 0 | Loss 2.092283
[Epoch] 783  | batch 0 | Loss 1.626835
[Epoch] 784  | batch 0 | Loss 1.997230
[Epoch] 785  | batch 0 | Loss 2.124466
[Epoch] 786  | batch 0 | Loss 1.854637
[Epoch] 787  | batch 0 | Loss 2.174959
[Epoch] 788  | batch 0 | Loss 1.662592
[Epoch] 789  | batch 0 | Loss 2.140189
[Epoch] 790  | batch 0 | Loss 1.682601
[Epoch] 791  | batch 0 | Loss 2.046948
[Epoch] 792  | batch 0 | Loss 2.037759
[Epoch] 793  | batch 0 | Loss 2.136212
[Epoch] 794  | batch 0 | Loss 1.930922
[Epoch] 795  | batch 0 | Loss 1.795611
[Epoch] 796  | batch 0 | Loss 1.677368
[Epoch] 797  | batch 0 | Loss 1.956648
[Epoch] 798  | batch 0 | Loss 2.039328


 81%|████████  | 808/1000 [00:10<00:02, 83.05it/s]

[Epoch] 799  | batch 0 | Loss 2.472255
[Epoch] 800  | batch 0 | Loss 2.210989
[Epoch] 801  | batch 0 | Loss 2.044737
[Epoch] 802  | batch 0 | Loss 1.844622
[Epoch] 803  | batch 0 | Loss 2.190614
[Epoch] 804  | batch 0 | Loss 1.782872
[Epoch] 805  | batch 0 | Loss 2.186255
[Epoch] 806  | batch 0 | Loss 2.067484
[Epoch] 807  | batch 0 | Loss 1.969438
[Epoch] 808  | batch 0 | Loss 1.928686
[Epoch] 809  | batch 0 | Loss 1.799802
[Epoch] 810  | batch 0 | Loss 2.245427
[Epoch] 811  | batch 0 | Loss 1.893561
[Epoch] 812  | batch 0 | Loss 2.261734
[Epoch] 813  | batch 0 | Loss 2.124953
[Epoch] 814  | batch 0 | Loss 1.665737
[Epoch] 815  | batch 0 | Loss 1.983112


 83%|████████▎ | 826/1000 [00:10<00:02, 83.26it/s]

[Epoch] 816  | batch 0 | Loss 2.153615
[Epoch] 817  | batch 0 | Loss 1.788841
[Epoch] 818  | batch 0 | Loss 2.073813
[Epoch] 819  | batch 0 | Loss 1.775188
[Epoch] 820  | batch 0 | Loss 1.721065
[Epoch] 821  | batch 0 | Loss 1.686099
[Epoch] 822  | batch 0 | Loss 1.917596
[Epoch] 823  | batch 0 | Loss 2.607499
[Epoch] 824  | batch 0 | Loss 2.257554
[Epoch] 825  | batch 0 | Loss 2.336727
[Epoch] 826  | batch 0 | Loss 1.683770
[Epoch] 827  | batch 0 | Loss 2.313606
[Epoch] 828  | batch 0 | Loss 1.885352
[Epoch] 829  | batch 0 | Loss 2.315949
[Epoch] 830  | batch 0 | Loss 2.001852
[Epoch] 831  | batch 0 | Loss 1.852312
[Epoch] 832  | batch 0 | Loss 1.797482


 84%|████████▍ | 844/1000 [00:10<00:01, 83.16it/s]

[Epoch] 833  | batch 0 | Loss 2.548412
[Epoch] 834  | batch 0 | Loss 1.805941
[Epoch] 835  | batch 0 | Loss 2.112507
[Epoch] 836  | batch 0 | Loss 2.323193
[Epoch] 837  | batch 0 | Loss 1.930389
[Epoch] 838  | batch 0 | Loss 2.136830
[Epoch] 839  | batch 0 | Loss 1.844801
[Epoch] 840  | batch 0 | Loss 1.628254
[Epoch] 841  | batch 0 | Loss 1.811540
[Epoch] 842  | batch 0 | Loss 2.160929
[Epoch] 843  | batch 0 | Loss 2.049681
[Epoch] 844  | batch 0 | Loss 2.128045
[Epoch] 845  | batch 0 | Loss 2.285294
[Epoch] 846  | batch 0 | Loss 2.625828
[Epoch] 847  | batch 0 | Loss 1.996703
[Epoch] 848  | batch 0 | Loss 1.762694
[Epoch] 849  | batch 0 | Loss 1.582795


 86%|████████▌ | 862/1000 [00:10<00:01, 83.25it/s]

[Epoch] 850  | batch 0 | Loss 1.736845
[Epoch] 851  | batch 0 | Loss 1.902606
[Epoch] 852  | batch 0 | Loss 2.008889
[Epoch] 853  | batch 0 | Loss 2.144031
[Epoch] 854  | batch 0 | Loss 1.990549
[Epoch] 855  | batch 0 | Loss 1.887630
[Epoch] 856  | batch 0 | Loss 2.105638
[Epoch] 857  | batch 0 | Loss 1.774415
[Epoch] 858  | batch 0 | Loss 2.239830
[Epoch] 859  | batch 0 | Loss 1.582149
[Epoch] 860  | batch 0 | Loss 1.839932
[Epoch] 861  | batch 0 | Loss 1.770580
[Epoch] 862  | batch 0 | Loss 2.011213
[Epoch] 863  | batch 0 | Loss 1.727744
[Epoch] 864  | batch 0 | Loss 2.037736
[Epoch] 865  | batch 0 | Loss 1.806271
[Epoch] 866  | batch 0 | Loss 2.486876


 88%|████████▊ | 880/1000 [00:10<00:01, 83.38it/s]

[Epoch] 867  | batch 0 | Loss 1.725221
[Epoch] 868  | batch 0 | Loss 1.953711
[Epoch] 869  | batch 0 | Loss 1.806835
[Epoch] 870  | batch 0 | Loss 2.100109
[Epoch] 871  | batch 0 | Loss 1.862106
[Epoch] 872  | batch 0 | Loss 2.168667
[Epoch] 873  | batch 0 | Loss 1.916929
[Epoch] 874  | batch 0 | Loss 2.143170
[Epoch] 875  | batch 0 | Loss 1.601714
[Epoch] 876  | batch 0 | Loss 2.118413
[Epoch] 877  | batch 0 | Loss 2.111158
[Epoch] 878  | batch 0 | Loss 2.039806
[Epoch] 879  | batch 0 | Loss 2.111898
[Epoch] 880  | batch 0 | Loss 2.112863
[Epoch] 881  | batch 0 | Loss 2.041208
[Epoch] 882  | batch 0 | Loss 1.733431
[Epoch] 883  | batch 0 | Loss 1.737093


 90%|████████▉ | 898/1000 [00:11<00:01, 83.56it/s]

[Epoch] 884  | batch 0 | Loss 2.162158
[Epoch] 885  | batch 0 | Loss 2.112472
[Epoch] 886  | batch 0 | Loss 1.851798
[Epoch] 887  | batch 0 | Loss 1.754578
[Epoch] 888  | batch 0 | Loss 2.066627
[Epoch] 889  | batch 0 | Loss 2.454998
[Epoch] 890  | batch 0 | Loss 2.057125
[Epoch] 891  | batch 0 | Loss 1.667840
[Epoch] 892  | batch 0 | Loss 2.008941
[Epoch] 893  | batch 0 | Loss 1.764520
[Epoch] 894  | batch 0 | Loss 1.678779
[Epoch] 895  | batch 0 | Loss 2.023775
[Epoch] 896  | batch 0 | Loss 1.859396
[Epoch] 897  | batch 0 | Loss 2.122180
[Epoch] 898  | batch 0 | Loss 1.938975
[Epoch] 899  | batch 0 | Loss 1.911370
[Epoch] 900  | batch 0 | Loss 2.054240


 92%|█████████▏| 916/1000 [00:11<00:01, 83.49it/s]

[Epoch] 901  | batch 0 | Loss 1.574341
[Epoch] 902  | batch 0 | Loss 1.894788
[Epoch] 903  | batch 0 | Loss 2.023317
[Epoch] 904  | batch 0 | Loss 1.751349
[Epoch] 905  | batch 0 | Loss 1.885630
[Epoch] 906  | batch 0 | Loss 2.068394
[Epoch] 907  | batch 0 | Loss 2.228554
[Epoch] 908  | batch 0 | Loss 1.748150
[Epoch] 909  | batch 0 | Loss 1.915088
[Epoch] 910  | batch 0 | Loss 1.907520
[Epoch] 911  | batch 0 | Loss 1.873987
[Epoch] 912  | batch 0 | Loss 1.990652
[Epoch] 913  | batch 0 | Loss 2.277930
[Epoch] 914  | batch 0 | Loss 1.992876
[Epoch] 915  | batch 0 | Loss 2.084332
[Epoch] 916  | batch 0 | Loss 1.717236
[Epoch] 917  | batch 0 | Loss 1.655604


 93%|█████████▎| 934/1000 [00:11<00:00, 83.47it/s]

[Epoch] 918  | batch 0 | Loss 1.783804
[Epoch] 919  | batch 0 | Loss 1.712767
[Epoch] 920  | batch 0 | Loss 1.791541
[Epoch] 921  | batch 0 | Loss 1.976798
[Epoch] 922  | batch 0 | Loss 1.723701
[Epoch] 923  | batch 0 | Loss 1.979066
[Epoch] 924  | batch 0 | Loss 1.802494
[Epoch] 925  | batch 0 | Loss 1.936019
[Epoch] 926  | batch 0 | Loss 1.964295
[Epoch] 927  | batch 0 | Loss 1.953289
[Epoch] 928  | batch 0 | Loss 1.948270
[Epoch] 929  | batch 0 | Loss 2.223064
[Epoch] 930  | batch 0 | Loss 1.749786
[Epoch] 931  | batch 0 | Loss 2.161515
[Epoch] 932  | batch 0 | Loss 1.816214
[Epoch] 933  | batch 0 | Loss 1.800693
[Epoch] 934  | batch 0 | Loss 1.766381


 95%|█████████▌| 952/1000 [00:11<00:00, 83.41it/s]

[Epoch] 935  | batch 0 | Loss 2.391405
[Epoch] 936  | batch 0 | Loss 2.248093
[Epoch] 937  | batch 0 | Loss 1.989581
[Epoch] 938  | batch 0 | Loss 2.389271
[Epoch] 939  | batch 0 | Loss 1.876164
[Epoch] 940  | batch 0 | Loss 1.939524
[Epoch] 941  | batch 0 | Loss 1.626540
[Epoch] 942  | batch 0 | Loss 1.651448
[Epoch] 943  | batch 0 | Loss 2.087681
[Epoch] 944  | batch 0 | Loss 2.170242
[Epoch] 945  | batch 0 | Loss 2.016423
[Epoch] 946  | batch 0 | Loss 1.986102
[Epoch] 947  | batch 0 | Loss 1.884727
[Epoch] 948  | batch 0 | Loss 2.008531
[Epoch] 949  | batch 0 | Loss 1.863124
[Epoch] 950  | batch 0 | Loss 2.018269
[Epoch] 951  | batch 0 | Loss 1.828248


 96%|█████████▌| 961/1000 [00:11<00:00, 83.44it/s]

[Epoch] 952  | batch 0 | Loss 1.828860
[Epoch] 953  | batch 0 | Loss 2.024080
[Epoch] 954  | batch 0 | Loss 1.723251
[Epoch] 955  | batch 0 | Loss 1.944839
[Epoch] 956  | batch 0 | Loss 1.974316
[Epoch] 957  | batch 0 | Loss 2.199961
[Epoch] 958  | batch 0 | Loss 1.718175
[Epoch] 959  | batch 0 | Loss 1.775737
[Epoch] 960  | batch 0 | Loss 1.982655
[Epoch] 961  | batch 0 | Loss 2.073966
[Epoch] 962  | batch 0 | Loss 2.032851
[Epoch] 963  | batch 0 | Loss 1.784951
[Epoch] 964  | batch 0 | Loss 1.760723
[Epoch] 965  | batch 0 | Loss 2.356742
[Epoch] 966  | batch 0 | Loss 1.839314
[Epoch] 967  | batch 0 | Loss 1.957399
[Epoch] 968  | batch 0 | Loss 1.747418


 98%|█████████▊| 979/1000 [00:12<00:00, 82.61it/s]

[Epoch] 969  | batch 0 | Loss 1.843012
[Epoch] 970  | batch 0 | Loss 1.915320
[Epoch] 971  | batch 0 | Loss 2.357218
[Epoch] 972  | batch 0 | Loss 2.026498
[Epoch] 973  | batch 0 | Loss 1.895410
[Epoch] 974  | batch 0 | Loss 2.021171
[Epoch] 975  | batch 0 | Loss 1.839972
[Epoch] 976  | batch 0 | Loss 1.886445
[Epoch] 977  | batch 0 | Loss 1.843869
[Epoch] 978  | batch 0 | Loss 2.313621
[Epoch] 979  | batch 0 | Loss 1.896933
[Epoch] 980  | batch 0 | Loss 1.876756
[Epoch] 981  | batch 0 | Loss 1.911339
[Epoch] 982  | batch 0 | Loss 2.164487
[Epoch] 983  | batch 0 | Loss 1.815938
[Epoch] 984  | batch 0 | Loss 2.153142
[Epoch] 985  | batch 0 | Loss 1.819600


100%|██████████| 1000/1000 [00:12<00:00, 80.49it/s]

[Epoch] 986  | batch 0 | Loss 1.637998
[Epoch] 987  | batch 0 | Loss 1.958771
[Epoch] 988  | batch 0 | Loss 1.965782
[Epoch] 989  | batch 0 | Loss 2.002327
[Epoch] 990  | batch 0 | Loss 2.120966
[Epoch] 991  | batch 0 | Loss 2.193309
[Epoch] 992  | batch 0 | Loss 1.825473
[Epoch] 993  | batch 0 | Loss 1.676007
[Epoch] 994  | batch 0 | Loss 1.581259
[Epoch] 995  | batch 0 | Loss 2.015757
[Epoch] 996  | batch 0 | Loss 2.100486
[Epoch] 997  | batch 0 | Loss 2.099599
[Epoch] 998  | batch 0 | Loss 2.374251
[Epoch] 999  | batch 0 | Loss 1.763730





In [11]:
perturb_clf_obj.predict(train_X[-2:])

[array([ 188, 1856,  608,  196,   29,   35,   40,  956]),
 array([ 152, 2281,  301,  167,   22,   81,   24, 1040])]