In [1]:
%matplotlib inline

In [1]:
cd ..

C:\Projects\python\recommender


In [2]:
from functools import partial
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch as T
import torch.optim as optim

from utils import get_log_dir

In [3]:
from datasets import SeqKaggle
from models import FMLearner, TorchFM, TorchHrmFM, TorchPrmeFM, TorchTransFM, TransProbFM
from models.fm_learner import simple_loss, trans_loss, simple_weight_loss, trans_weight_loss

In [4]:
DEVICE = T.cuda.current_device()
BATCH = 2000
SHUFFLE = True
WORKERS = 0
NEG_SAMPLE = 5
item_path = Path("./inputs/kaggle/item.csv")

In [5]:
db = SeqKaggle(data_path=item_path, user_min=4)
db

Raw dataframe shape (476244, 8)
After drop nan shape: (429988, 8)
Original comptition size: 292
Original competitor size: 140065
Filtered competiter size: 27449
Filtered dataframe shape: (284806, 8)


<datasets.torch_kaggle.SeqKaggle at 0x2975e3eb288>

In [6]:
db.config_db(batch_size=BATCH,
             shuffle=SHUFFLE,
             num_workers=WORKERS,
             device=DEVICE,
             neg_sample=NEG_SAMPLE)

In [7]:
feat_dim = db.feat_dim
NUM_DIM = 124
INIT_MEAN = 0.1

## Create Criterion

In [8]:
# regst setting
LINEAR_REG = 1
EMB_REG = 1
TRANS_REG = 1

In [9]:
simple_loss_callback = partial(simple_loss, LINEAR_REG, EMB_REG)
simple_loss_callback

functools.partial(<function simple_loss at 0x000002975E391C18>, 1, 1)

In [10]:
trans_loss_callback = partial(trans_loss, LINEAR_REG, EMB_REG, TRANS_REG)
trans_loss_callback

functools.partial(<function trans_loss at 0x000002975E3E3318>, 1, 1, 1)

In [11]:
simple_weight_loss_callback = partial(simple_weight_loss, LINEAR_REG, EMB_REG)
simple_weight_loss_callback

functools.partial(<function simple_weight_loss at 0x000002975E3E33A8>, 1, 1)

In [12]:
trans_weight_loss_callback = partial(trans_weight_loss, LINEAR_REG, EMB_REG, TRANS_REG)
trans_weight_loss_callback

functools.partial(<function trans_weight_loss at 0x000002975E3E3438>, 1, 1, 1)

## Train Model

### Hyper-parameter

In [27]:
feat_dim = db.feat_dim
NUM_DIM = 124
INIT_MEAN = 0.1
INIT_SCALE = 0.001

### Train FM Model

In [14]:
LEARNING_RATE = 0.001
DECAY_FREQ = 1000
DECAY_GAMMA = 1

In [15]:
fm_model = TorchFM(feature_dim=feat_dim, num_dim=NUM_DIM, init_mean=INIT_MEAN)
fm_model

TorchFM()

In [16]:
adam_opt = optim.Adam(fm_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt, step_size=DECAY_FREQ, gamma=DECAY_GAMMA)

In [17]:
fm_learner = FMLearner(fm_model, adam_opt, schedular, db)
fm_learner

<models.fm_learner.FMLearner at 0x22f8b470588>

In [18]:
fm_learner.compile(train_col='base',
                   valid_col='base',
                   test_col='base',
                   loss_callback=simple_loss_callback)

In [19]:
fm_learner.compile(train_col='seq',
                   valid_col='seq',
                   test_col='seq',
                   loss_callback=simple_weight_loss_callback)

In [25]:
fm_learner.compile(train_col='seq',
                   valid_col='seq',
                   test_col='seq',
                   loss_callback=simple_loss_callback)

In [19]:
fm_learner.fit(epoch=2,
               log_dir=get_log_dir(ds_type='simple_kaggle', model_type='fm'))

  0%|                                                                                                                                                           | 0/2 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 11965.26566334206
Epoch 0 step 1: training accuarcy: 0.8755000000000001
Epoch 0 step 1: training loss: 11600.11047899487
Epoch 0 step 2: training accuarcy: 0.8915000000000001
Epoch 0 step 2: training loss: 11264.478091427358
Epoch 0 step 3: training accuarcy: 0.8985
Epoch 0 step 3: training loss: 10915.713287832556
Epoch 0 step 4: training accuarcy: 0.9075
Epoch 0 step 4: training loss: 10571.656815091415
Epoch 0 step 5: training accuarcy: 0.923
Epoch 0 step 5: training loss: 10256.229579476176
Epoch 0 step 6: training accuarcy: 0.926
Epoch 0 step 6: training loss: 9944.374015567542
Epoch 0 step 7: training accuarcy: 0.926
Epoch 0 step 7: training loss: 9649.853580218305
Epoch 0 step 8: training accuarcy: 0.929
Epoch 0 step 8: training loss: 9349.899384050676
Epoch 0 step 9: training accuarcy: 0.9390000000000001
Epoch 0 step 9: training loss: 9048.757016595198
Epoch 0 step 10: training accuarcy: 0.9460000000000001
Epoch 0 step 10: training loss: 

KeyboardInterrupt: 

In [20]:
del fm_model
T.cuda.empty_cache()

### Train HRM FM Model

In [21]:
hrm_model = TorchHrmFM(feature_dim=feat_dim, num_dim=NUM_DIM, init_mean=INIT_MEAN)
hrm_model

TorchHrmFM()

In [22]:
adam_opt = optim.Adam(hrm_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt,
                                      step_size=DECAY_FREQ,
                                      gamma=DECAY_GAMMA)

In [23]:
hrm_learner = FMLearner(hrm_model, adam_opt, schedular, db)
hrm_learner

<models.fm_learner.FMLearner at 0x22f83fd1588>

In [24]:
hrm_learner.compile(train_col='base',
                    valid_col='base',
                    test_col='base',
                    loss_callback=simple_loss_callback)

In [19]:
hrm_learner.compile(train_col='seq',
                    valid_col='seq',
                    test_col='seq',
                    loss_callback=simple_loss_callback)

In [25]:
hrm_learner.compile(train_col='seq',
                    valid_col='seq',
                    test_col='seq',
                    loss_callback=simple_weight_loss_callback)

In [25]:
hrm_learner.fit(epoch=4,
                log_dir=get_log_dir('kaggle', 'hrm'))


  0%|                                                                                                                                                           | 0/4 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 50276.187795471065
Epoch 0 step 1: training accuarcy: 0.07150000000000001
Epoch 0 step 1: training loss: 49092.85216770522
Epoch 0 step 2: training accuarcy: 0.095
Epoch 0 step 2: training loss: 48660.836132789926
Epoch 0 step 3: training accuarcy: 0.106
Epoch 0 step 3: training loss: 48686.83333892614
Epoch 0 step 4: training accuarcy: 0.12
Epoch 0 step 4: training loss: 47837.53223607579
Epoch 0 step 5: training accuarcy: 0.131
Epoch 0 step 5: training loss: 47555.4556557263
Epoch 0 step 6: training accuarcy: 0.1365
Epoch 0 step 6: training loss: 47841.09369064105
Epoch 0 step 7: training accuarcy: 0.133
Epoch 0 step 7: training loss: 47064.64672136519
Epoch 0 step 8: training accuarcy: 0.14250000000000002
Epoch 0 step 8: training loss: 46181.394246739845
Epoch 0 step 9: training accuarcy: 0.151
Epoch 0 step 9: training loss: 46745.285647654324
Epoch 0 step 10: training accuarcy: 0.131
Epoch 0 step 10: training loss: 46020.389417301805
Epoch 0 


 25%|████████████████████████████████████▊                                                                                                              | 1/4 [00:55<02:46, 55.63s/it]

Epoch: 1
Epoch 1 step 115: training loss: 9783.069492354141
Epoch 1 step 116: training accuarcy: 0.787
Epoch 1 step 116: training loss: 9285.822984497869
Epoch 1 step 117: training accuarcy: 0.794
Epoch 1 step 117: training loss: 9959.542070840522
Epoch 1 step 118: training accuarcy: 0.7815
Epoch 1 step 118: training loss: 9727.889155009323
Epoch 1 step 119: training accuarcy: 0.7885
Epoch 1 step 119: training loss: 9281.914538185036
Epoch 1 step 120: training accuarcy: 0.793
Epoch 1 step 120: training loss: 8656.291835320066
Epoch 1 step 121: training accuarcy: 0.811
Epoch 1 step 121: training loss: 9049.275131110622
Epoch 1 step 122: training accuarcy: 0.801
Epoch 1 step 122: training loss: 8965.234382037223
Epoch 1 step 123: training accuarcy: 0.804
Epoch 1 step 123: training loss: 10029.384218381307
Epoch 1 step 124: training accuarcy: 0.78
Epoch 1 step 124: training loss: 8784.009006360808
Epoch 1 step 125: training accuarcy: 0.8045
Epoch 1 step 125: training loss: 8570.6364062065


 50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 2/4 [01:52<01:51, 55.85s/it]

Epoch: 2
Epoch 2 step 230: training loss: 5226.899802472058
Epoch 2 step 231: training accuarcy: 0.886
Epoch 2 step 231: training loss: 5726.213888306911
Epoch 2 step 232: training accuarcy: 0.8755000000000001
Epoch 2 step 232: training loss: 6047.060764773698
Epoch 2 step 233: training accuarcy: 0.8665
Epoch 2 step 233: training loss: 5752.312034754913
Epoch 2 step 234: training accuarcy: 0.8735
Epoch 2 step 234: training loss: 5925.781889184981
Epoch 2 step 235: training accuarcy: 0.869
Epoch 2 step 235: training loss: 5619.384699222382
Epoch 2 step 236: training accuarcy: 0.8775000000000001
Epoch 2 step 236: training loss: 5708.884776051957
Epoch 2 step 237: training accuarcy: 0.8735
Epoch 2 step 237: training loss: 5360.040650722991
Epoch 2 step 238: training accuarcy: 0.8795000000000001
Epoch 2 step 238: training loss: 5896.273762892348
Epoch 2 step 239: training accuarcy: 0.868
Epoch 2 step 239: training loss: 5920.223489516687
Epoch 2 step 240: training accuarcy: 0.872
Epoch 2 s


 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 3/4 [02:46<00:55, 55.54s/it]

Epoch: 3
Epoch 3 step 345: training loss: 4209.820731556838
Epoch 3 step 346: training accuarcy: 0.9065
Epoch 3 step 346: training loss: 4103.589800701508
Epoch 3 step 347: training accuarcy: 0.91
Epoch 3 step 347: training loss: 4403.055092031274
Epoch 3 step 348: training accuarcy: 0.9015
Epoch 3 step 348: training loss: 4289.640247277833
Epoch 3 step 349: training accuarcy: 0.9005
Epoch 3 step 349: training loss: 5040.323033433302
Epoch 3 step 350: training accuarcy: 0.892
Epoch 3 step 350: training loss: 4580.412319892532
Epoch 3 step 351: training accuarcy: 0.8975
Epoch 3 step 351: training loss: 4756.74782603592
Epoch 3 step 352: training accuarcy: 0.893
Epoch 3 step 352: training loss: 3750.5678628948795
Epoch 3 step 353: training accuarcy: 0.916
Epoch 3 step 353: training loss: 4023.515744376226
Epoch 3 step 354: training accuarcy: 0.9125
Epoch 3 step 354: training loss: 4369.428928013831
Epoch 3 step 355: training accuarcy: 0.9055
Epoch 3 step 355: training loss: 4441.21792644


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:42<00:00, 55.55s/it]


In [27]:
hrm_learner.fit(epoch=2,
                log_dir=get_log_dir('weight_kaggle', 'hrm'))


  0%|                                                                                                                                                                                     | 0/2 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 476156.54998684255
Epoch 0 step 1: training accuarcy: 0.24070000000000003
Epoch 0 step 1: training loss: 475096.2517683112
Epoch 0 step 2: training accuarcy: 0.2451
Epoch 0 step 2: training loss: 476683.3893662889
Epoch 0 step 3: training accuarcy: 0.2388
Epoch 0 step 3: training loss: 476621.68926722434
Epoch 0 step 4: training accuarcy: 0.2424
Epoch 0 step 4: training loss: 474116.48489528266
Epoch 0 step 5: training accuarcy: 0.24250000000000002
Epoch 0 step 5: training loss: 472665.09536336595
Epoch 0 step 6: training accuarcy: 0.2479
Epoch 0 step 6: training loss: 478215.3165373088
Epoch 0 step 7: training accuarcy: 0.23240000000000002
Epoch 0 step 7: training loss: 476077.3696475439
Epoch 0 step 8: training accuarcy: 0.2441
Epoch 0 step 8: training loss: 473822.3694914146
Epoch 0 step 9: training accuarcy: 0.2459
Epoch 0 step 9: training loss: 476906.9368236094
Epoch 0 step 10: training accuarcy: 0.2399
Epoch 0 step 10: training loss: 47809


 50%|██████████████████████████████████████████████████████████████████████████████████████                                                                                      | 1/2 [11:03<11:03, 663.57s/it]

Epoch: 1
Epoch 1 step 115: training loss: 450378.8645166672
Epoch 1 step 116: training accuarcy: 0.29700000000000004
Epoch 1 step 116: training loss: 458080.3530401238
Epoch 1 step 117: training accuarcy: 0.2836
Epoch 1 step 117: training loss: 456139.74620575464
Epoch 1 step 118: training accuarcy: 0.2854
Epoch 1 step 118: training loss: 459225.96323987877
Epoch 1 step 119: training accuarcy: 0.2836
Epoch 1 step 119: training loss: 463845.515827573
Epoch 1 step 120: training accuarcy: 0.2772
Epoch 1 step 120: training loss: 461358.83543013706
Epoch 1 step 121: training accuarcy: 0.2771
Epoch 1 step 121: training loss: 458276.1446818058
Epoch 1 step 122: training accuarcy: 0.28040000000000004
Epoch 1 step 122: training loss: 462817.93154526915
Epoch 1 step 123: training accuarcy: 0.27940000000000004
Epoch 1 step 123: training loss: 455134.1138683906
Epoch 1 step 124: training accuarcy: 0.2898
Epoch 1 step 124: training loss: 466948.87889836164
Epoch 1 step 125: training accuarcy: 0.272


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [22:03<00:00, 662.51s/it]

In [28]:
del hrm_model
T.cuda.empty_cache()

### Train PRME FM Model

In [15]:
prme_model = TorchPrmeFM(feature_dim=feat_dim, num_dim=NUM_DIM, init_mean=INIT_MEAN)
prme_model

TorchPrmeFM()

In [16]:
adam_opt = optim.Adam(prme_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt,
                                      step_size=DECAY_FREQ,
                                      gamma=DECAY_GAMMA)

In [17]:
prme_learner = FMLearner(prme_model, adam_opt, schedular, db)
prme_learner

<models.fm_learner.FMLearner at 0x29768af3688>

In [18]:
prme_learner.compile(train_col='base',
                     valid_col='base',
                     test_col='base',
                     loss_callback=simple_loss_callback)

In [32]:
prme_learner.compile(train_col='seq',
                     valid_col='seq',
                     test_col='seq',
                     loss_callback=simple_loss_callback)

In [38]:
prme_learner.compile(train_col='seq',
                     valid_col='seq',
                     test_col='seq',
                     loss_callback=simple_weight_loss_callback)

In [19]:
prme_learner.fit(epoch=4, log_dir=get_log_dir('kaggle', 'prme'))

  0%|                                                                                                                                                           | 0/4 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 49334.19990822089
Epoch 0 step 1: training accuarcy: 0.1055
Epoch 0 step 1: training loss: 48499.95265345877
Epoch 0 step 2: training accuarcy: 0.1225
Epoch 0 step 2: training loss: 48508.58482851556
Epoch 0 step 3: training accuarcy: 0.14
Epoch 0 step 3: training loss: 48355.330318926666
Epoch 0 step 4: training accuarcy: 0.1325
Epoch 0 step 4: training loss: 48412.418119987364
Epoch 0 step 5: training accuarcy: 0.1335
Epoch 0 step 5: training loss: 48427.49356719893
Epoch 0 step 6: training accuarcy: 0.134
Epoch 0 step 6: training loss: 47088.98339314324
Epoch 0 step 7: training accuarcy: 0.1565
Epoch 0 step 7: training loss: 47545.87002208854
Epoch 0 step 8: training accuarcy: 0.15
Epoch 0 step 8: training loss: 46495.594360983
Epoch 0 step 9: training accuarcy: 0.163
Epoch 0 step 9: training loss: 46577.97881703819
Epoch 0 step 10: training accuarcy: 0.151
Epoch 0 step 10: training loss: 46024.968259917514
Epoch 0 step 11: training accuarcy: 

Epoch 0 step 86: training accuarcy: 0.185
Epoch 0 step 86: training loss: 37909.920909298824
Epoch 0 step 87: training accuarcy: 0.1805
Epoch 0 step 87: training loss: 38328.33527743073
Epoch 0 step 88: training accuarcy: 0.1705
Epoch 0 step 88: training loss: 37863.5617340094
Epoch 0 step 89: training accuarcy: 0.1825
Epoch 0 step 89: training loss: 38217.07829109385
Epoch 0 step 90: training accuarcy: 0.17500000000000002
Epoch 0 step 90: training loss: 37670.53862926792
Epoch 0 step 91: training accuarcy: 0.1805
Epoch 0 step 91: training loss: 38532.58043697983
Epoch 0 step 92: training accuarcy: 0.166
Epoch 0 step 92: training loss: 37308.27098170765
Epoch 0 step 93: training accuarcy: 0.1885
Epoch 0 step 93: training loss: 38259.95704552709
Epoch 0 step 94: training accuarcy: 0.17300000000000001
Epoch 0 step 94: training loss: 37740.50388142561
Epoch 0 step 95: training accuarcy: 0.181
Epoch 0 step 95: training loss: 37828.88628974144
Epoch 0 step 96: training accuarcy: 0.1785
Epoc

 25%|████████████████████████████████████▊                                                                                                              | 1/4 [00:53<02:41, 53.78s/it]

Epoch: 1
Epoch 1 step 115: training loss: 33691.97828820489
Epoch 1 step 116: training accuarcy: 0.254
Epoch 1 step 116: training loss: 33250.799956758834
Epoch 1 step 117: training accuarcy: 0.258
Epoch 1 step 117: training loss: 32145.618649127533
Epoch 1 step 118: training accuarcy: 0.28350000000000003
Epoch 1 step 118: training loss: 31365.668929555828
Epoch 1 step 119: training accuarcy: 0.306
Epoch 1 step 119: training loss: 29255.60404049649
Epoch 1 step 120: training accuarcy: 0.3525
Epoch 1 step 120: training loss: 29125.349789618722
Epoch 1 step 121: training accuarcy: 0.3635
Epoch 1 step 121: training loss: 29120.061800514333
Epoch 1 step 122: training accuarcy: 0.362
Epoch 1 step 122: training loss: 28000.451613361223
Epoch 1 step 123: training accuarcy: 0.387
Epoch 1 step 123: training loss: 26995.08051730424
Epoch 1 step 124: training accuarcy: 0.40850000000000003
Epoch 1 step 124: training loss: 27606.777872646675
Epoch 1 step 125: training accuarcy: 0.397
Epoch 1 step 1

Epoch 1 step 200: training accuarcy: 0.7325
Epoch 1 step 200: training loss: 12554.638665241568
Epoch 1 step 201: training accuarcy: 0.726
Epoch 1 step 201: training loss: 11417.443986023634
Epoch 1 step 202: training accuarcy: 0.7515000000000001
Epoch 1 step 202: training loss: 11838.512200656907
Epoch 1 step 203: training accuarcy: 0.739
Epoch 1 step 203: training loss: 11611.631980870925
Epoch 1 step 204: training accuarcy: 0.7455
Epoch 1 step 204: training loss: 12390.983897685875
Epoch 1 step 205: training accuarcy: 0.7305
Epoch 1 step 205: training loss: 11663.013384012615
Epoch 1 step 206: training accuarcy: 0.744
Epoch 1 step 206: training loss: 11378.500968608705
Epoch 1 step 207: training accuarcy: 0.75
Epoch 1 step 207: training loss: 12539.517852124003
Epoch 1 step 208: training accuarcy: 0.7265
Epoch 1 step 208: training loss: 12195.36109501372
Epoch 1 step 209: training accuarcy: 0.7335
Epoch 1 step 209: training loss: 11934.890524857947
Epoch 1 step 210: training accuarc

 50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 2/4 [01:46<01:46, 53.35s/it]

Epoch: 2
Epoch 2 step 230: training loss: 11461.225272801685
Epoch 2 step 231: training accuarcy: 0.751
Epoch 2 step 231: training loss: 12042.225337009419
Epoch 2 step 232: training accuarcy: 0.738
Epoch 2 step 232: training loss: 12851.898976717492
Epoch 2 step 233: training accuarcy: 0.72
Epoch 2 step 233: training loss: 11946.302059449681
Epoch 2 step 234: training accuarcy: 0.7405
Epoch 2 step 234: training loss: 12466.06910057704
Epoch 2 step 235: training accuarcy: 0.727
Epoch 2 step 235: training loss: 11405.210724082077
Epoch 2 step 236: training accuarcy: 0.7515000000000001
Epoch 2 step 236: training loss: 10762.11719960668
Epoch 2 step 237: training accuarcy: 0.764
Epoch 2 step 237: training loss: 11542.333336040258
Epoch 2 step 238: training accuarcy: 0.748
Epoch 2 step 238: training loss: 11679.189580787697
Epoch 2 step 239: training accuarcy: 0.746
Epoch 2 step 239: training loss: 11653.99132591041
Epoch 2 step 240: training accuarcy: 0.7465
Epoch 2 step 240: training los

Epoch 2 step 314: training accuarcy: 0.7625000000000001
Epoch 2 step 314: training loss: 9950.959794437247
Epoch 2 step 315: training accuarcy: 0.783
Epoch 2 step 315: training loss: 10788.78167655488
Epoch 2 step 316: training accuarcy: 0.765
Epoch 2 step 316: training loss: 10872.875124457212
Epoch 2 step 317: training accuarcy: 0.762
Epoch 2 step 317: training loss: 10753.530004744322
Epoch 2 step 318: training accuarcy: 0.765
Epoch 2 step 318: training loss: 10315.6743134588
Epoch 2 step 319: training accuarcy: 0.776
Epoch 2 step 319: training loss: 10599.851664752417
Epoch 2 step 320: training accuarcy: 0.771
Epoch 2 step 320: training loss: 10345.317259783438
Epoch 2 step 321: training accuarcy: 0.7745
Epoch 2 step 321: training loss: 10268.748956381303
Epoch 2 step 322: training accuarcy: 0.777
Epoch 2 step 322: training loss: 10865.613482324798
Epoch 2 step 323: training accuarcy: 0.762
Epoch 2 step 323: training loss: 11127.270658321255
Epoch 2 step 324: training accuarcy: 0.7

 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 3/4 [02:38<00:53, 53.12s/it]

Epoch: 3
Epoch 3 step 345: training loss: 10645.054354893087
Epoch 3 step 346: training accuarcy: 0.768
Epoch 3 step 346: training loss: 10202.75601461524
Epoch 3 step 347: training accuarcy: 0.776
Epoch 3 step 347: training loss: 10593.889347493736
Epoch 3 step 348: training accuarcy: 0.77
Epoch 3 step 348: training loss: 9738.516907726576
Epoch 3 step 349: training accuarcy: 0.7885
Epoch 3 step 349: training loss: 9771.973639609923
Epoch 3 step 350: training accuarcy: 0.7855
Epoch 3 step 350: training loss: 10567.445600058716
Epoch 3 step 351: training accuarcy: 0.771
Epoch 3 step 351: training loss: 10107.13043342131
Epoch 3 step 352: training accuarcy: 0.779
Epoch 3 step 352: training loss: 10892.240913615742
Epoch 3 step 353: training accuarcy: 0.761
Epoch 3 step 353: training loss: 9945.983062003495
Epoch 3 step 354: training accuarcy: 0.7805
Epoch 3 step 354: training loss: 10526.602242606634
Epoch 3 step 355: training accuarcy: 0.771
Epoch 3 step 355: training loss: 10063.46018

Epoch 3 step 431: training loss: 9210.401649727937
Epoch 3 step 432: training accuarcy: 0.799
Epoch 3 step 432: training loss: 9421.11035907174
Epoch 3 step 433: training accuarcy: 0.795
Epoch 3 step 433: training loss: 8405.260537505179
Epoch 3 step 434: training accuarcy: 0.8165
Epoch 3 step 434: training loss: 9395.804761671357
Epoch 3 step 435: training accuarcy: 0.796
Epoch 3 step 435: training loss: 9157.889041793349
Epoch 3 step 436: training accuarcy: 0.8
Epoch 3 step 436: training loss: 9120.387156677363
Epoch 3 step 437: training accuarcy: 0.8005
Epoch 3 step 437: training loss: 9530.100683097251
Epoch 3 step 438: training accuarcy: 0.79
Epoch 3 step 438: training loss: 9142.133654339466
Epoch 3 step 439: training accuarcy: 0.8025
Epoch 3 step 439: training loss: 10013.913937748905
Epoch 3 step 440: training accuarcy: 0.7825
Epoch 3 step 440: training loss: 10130.178425686729
Epoch 3 step 441: training accuarcy: 0.779
Epoch 3 step 441: training loss: 9976.083101387116
Epoch 3

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:30<00:00, 52.67s/it]


In [39]:
prme_learner.fit(epoch=2, log_dir=get_log_dir('weight_kaggle', 'prme'))


  0%|                                                                                                                                                                                     | 0/2 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 487422.82394555013
Epoch 0 step 1: training accuarcy: 0.12040000000000001
Epoch 0 step 1: training loss: 489450.34019803803
Epoch 0 step 2: training accuarcy: 0.1499
Epoch 0 step 2: training loss: 488016.4392187835
Epoch 0 step 3: training accuarcy: 0.1763
Epoch 0 step 3: training loss: 491882.392495647
Epoch 0 step 4: training accuarcy: 0.1903
Epoch 0 step 4: training loss: 490320.818300761
Epoch 0 step 5: training accuarcy: 0.2059
Epoch 0 step 5: training loss: 490525.54563383816
Epoch 0 step 6: training accuarcy: 0.215
Epoch 0 step 6: training loss: 485183.02029845637
Epoch 0 step 7: training accuarcy: 0.22240000000000001
Epoch 0 step 7: training loss: 487375.3059561003
Epoch 0 step 8: training accuarcy: 0.2267
Epoch 0 step 8: training loss: 489671.4944303962
Epoch 0 step 9: training accuarcy: 0.22710000000000002
Epoch 0 step 9: training loss: 492945.4432845665
Epoch 0 step 10: training accuarcy: 0.21630000000000002
Epoch 0 step 10: training l


 50%|██████████████████████████████████████████████████████████████████████████████████████                                                                                      | 1/2 [11:09<11:09, 669.07s/it]

Epoch: 1
Epoch 1 step 115: training loss: 474542.40284475335
Epoch 1 step 116: training accuarcy: 0.24000000000000002
Epoch 1 step 116: training loss: 473252.3797201521
Epoch 1 step 117: training accuarcy: 0.24650000000000002
Epoch 1 step 117: training loss: 477294.37638240063
Epoch 1 step 118: training accuarcy: 0.2396
Epoch 1 step 118: training loss: 476808.94339811546
Epoch 1 step 119: training accuarcy: 0.2421
Epoch 1 step 119: training loss: 471741.0805636938
Epoch 1 step 120: training accuarcy: 0.2466
Epoch 1 step 120: training loss: 473419.45447076997
Epoch 1 step 121: training accuarcy: 0.2457
Epoch 1 step 121: training loss: 476231.16271982976
Epoch 1 step 122: training accuarcy: 0.2431
Epoch 1 step 122: training loss: 478678.3387199232
Epoch 1 step 123: training accuarcy: 0.24000000000000002
Epoch 1 step 123: training loss: 476005.4940440219
Epoch 1 step 124: training accuarcy: 0.2411
Epoch 1 step 124: training loss: 475530.66795238934
Epoch 1 step 125: training accuarcy: 0.2


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [22:23<00:00, 670.79s/it]

In [20]:
del prme_model
T.cuda.empty_cache()

### Train Trans FM Model

In [21]:
trans_model = TorchTransFM(feature_dim=feat_dim,
                           num_dim=NUM_DIM,
                           init_mean=INIT_MEAN)
trans_model

TorchTransFM()

In [22]:
adam_opt = optim.Adam(trans_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt,
                                      step_size=DECAY_FREQ,
                                      gamma=DECAY_GAMMA)

In [23]:
trans_learner = FMLearner(trans_model, adam_opt, schedular, db)
trans_learner

<models.fm_learner.FMLearner at 0x297645f2308>

In [24]:
trans_learner.compile(train_col='base',
                      valid_col='base',
                      test_col='base',
                      loss_callback=simple_loss_callback)

In [45]:
trans_learner.compile(train_col='seq',
                      valid_col='seq',
                      test_col='seq',
                      loss_callback=simple_loss_callback)

In [51]:
trans_learner.compile(train_col='seq',
                      valid_col='seq',
                      test_col='seq',
                      loss_callback=simple_weight_loss_callback)

In [25]:
trans_learner.fit(epoch=4,
                  log_dir=get_log_dir('kaggle', 'trans'))

  0%|                                                                                                                                                           | 0/4 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 12120.69442675773
Epoch 0 step 1: training accuarcy: 0.909
Epoch 0 step 1: training loss: 11755.97217997648
Epoch 0 step 2: training accuarcy: 0.921
Epoch 0 step 2: training loss: 11363.112025436305
Epoch 0 step 3: training accuarcy: 0.9390000000000001
Epoch 0 step 3: training loss: 11020.544029804794
Epoch 0 step 4: training accuarcy: 0.9470000000000001
Epoch 0 step 4: training loss: 10638.651725567826
Epoch 0 step 5: training accuarcy: 0.964
Epoch 0 step 5: training loss: 10311.73452653396
Epoch 0 step 6: training accuarcy: 0.9645
Epoch 0 step 6: training loss: 9958.153664576823
Epoch 0 step 7: training accuarcy: 0.9715
Epoch 0 step 7: training loss: 9642.351779384368
Epoch 0 step 8: training accuarcy: 0.973
Epoch 0 step 8: training loss: 9334.707186874006
Epoch 0 step 9: training accuarcy: 0.967
Epoch 0 step 9: training loss: 9041.784337774838
Epoch 0 step 10: training accuarcy: 0.97
Epoch 0 step 10: training loss: 8742.049843769144
Epoch 0 st

Epoch 0 step 88: training loss: 444.0029686884456
Epoch 0 step 89: training accuarcy: 0.998
Epoch 0 step 89: training loss: 423.3630336746897
Epoch 0 step 90: training accuarcy: 0.9965
Epoch 0 step 90: training loss: 406.3591396754678
Epoch 0 step 91: training accuarcy: 0.9955
Epoch 0 step 91: training loss: 391.17137882501254
Epoch 0 step 92: training accuarcy: 0.9965
Epoch 0 step 92: training loss: 376.33719870501375
Epoch 0 step 93: training accuarcy: 0.996
Epoch 0 step 93: training loss: 366.65489975623143
Epoch 0 step 94: training accuarcy: 0.994
Epoch 0 step 94: training loss: 348.896664367879
Epoch 0 step 95: training accuarcy: 0.996
Epoch 0 step 95: training loss: 330.2855709816177
Epoch 0 step 96: training accuarcy: 0.9965
Epoch 0 step 96: training loss: 321.7304672555529
Epoch 0 step 97: training accuarcy: 0.9965
Epoch 0 step 97: training loss: 311.165474827428
Epoch 0 step 98: training accuarcy: 0.9975
Epoch 0 step 98: training loss: 301.134339541647
Epoch 0 step 99: trainin

 25%|████████████████████████████████████▊                                                                                                              | 1/4 [00:55<02:46, 55.62s/it]

Epoch: 1
Epoch 1 step 115: training loss: 154.57467392537114
Epoch 1 step 116: training accuarcy: 0.998
Epoch 1 step 116: training loss: 148.12490804914637
Epoch 1 step 117: training accuarcy: 0.997
Epoch 1 step 117: training loss: 144.1102572789257
Epoch 1 step 118: training accuarcy: 0.996
Epoch 1 step 118: training loss: 137.90362129630878
Epoch 1 step 119: training accuarcy: 0.9945
Epoch 1 step 119: training loss: 134.0115004666773
Epoch 1 step 120: training accuarcy: 0.9935
Epoch 1 step 120: training loss: 129.555707996564
Epoch 1 step 121: training accuarcy: 0.997
Epoch 1 step 121: training loss: 123.52978110063802
Epoch 1 step 122: training accuarcy: 0.997
Epoch 1 step 122: training loss: 123.11005586240319
Epoch 1 step 123: training accuarcy: 0.9965
Epoch 1 step 123: training loss: 114.43196328240066
Epoch 1 step 124: training accuarcy: 0.998
Epoch 1 step 124: training loss: 112.01237662976723
Epoch 1 step 125: training accuarcy: 0.9985
Epoch 1 step 125: training loss: 111.6530

Epoch 1 step 201: training loss: 34.09346659685096
Epoch 1 step 202: training accuarcy: 0.998
Epoch 1 step 202: training loss: 38.42954912794672
Epoch 1 step 203: training accuarcy: 0.997
Epoch 1 step 203: training loss: 32.612069860379975
Epoch 1 step 204: training accuarcy: 0.998
Epoch 1 step 204: training loss: 28.53018061126378
Epoch 1 step 205: training accuarcy: 0.9985
Epoch 1 step 205: training loss: 32.1554372233377
Epoch 1 step 206: training accuarcy: 0.9975
Epoch 1 step 206: training loss: 31.55647726611369
Epoch 1 step 207: training accuarcy: 0.997
Epoch 1 step 207: training loss: 30.90932976302847
Epoch 1 step 208: training accuarcy: 0.999
Epoch 1 step 208: training loss: 25.94356472119184
Epoch 1 step 209: training accuarcy: 1.0
Epoch 1 step 209: training loss: 34.92178104821672
Epoch 1 step 210: training accuarcy: 0.998
Epoch 1 step 210: training loss: 31.180555829981177
Epoch 1 step 211: training accuarcy: 0.9975
Epoch 1 step 211: training loss: 27.347888834607005
Epoch 

 50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 2/4 [01:50<01:50, 55.33s/it]

Epoch: 2
Epoch 2 step 230: training loss: 24.97570974296729
Epoch 2 step 231: training accuarcy: 0.9975
Epoch 2 step 231: training loss: 26.02795086507966
Epoch 2 step 232: training accuarcy: 0.9985
Epoch 2 step 232: training loss: 31.484587643251423
Epoch 2 step 233: training accuarcy: 0.997
Epoch 2 step 233: training loss: 27.720781637800517
Epoch 2 step 234: training accuarcy: 0.9975
Epoch 2 step 234: training loss: 23.525668269602946
Epoch 2 step 235: training accuarcy: 0.999
Epoch 2 step 235: training loss: 27.414103376361027
Epoch 2 step 236: training accuarcy: 0.9975
Epoch 2 step 236: training loss: 24.202430164707508
Epoch 2 step 237: training accuarcy: 0.9995
Epoch 2 step 237: training loss: 27.199855352953065
Epoch 2 step 238: training accuarcy: 0.9985
Epoch 2 step 238: training loss: 33.115321588662404
Epoch 2 step 239: training accuarcy: 0.9965
Epoch 2 step 239: training loss: 29.90658724576106
Epoch 2 step 240: training accuarcy: 0.9965
Epoch 2 step 240: training loss: 37.

Epoch 2 step 316: training loss: 19.91807344834281
Epoch 2 step 317: training accuarcy: 0.9985
Epoch 2 step 317: training loss: 18.603275901646608
Epoch 2 step 318: training accuarcy: 0.998
Epoch 2 step 318: training loss: 21.04759701936517
Epoch 2 step 319: training accuarcy: 0.9965
Epoch 2 step 319: training loss: 19.56508030484151
Epoch 2 step 320: training accuarcy: 0.9975
Epoch 2 step 320: training loss: 24.092901293421615
Epoch 2 step 321: training accuarcy: 0.997
Epoch 2 step 321: training loss: 17.643176617894564
Epoch 2 step 322: training accuarcy: 0.999
Epoch 2 step 322: training loss: 19.304586890110773
Epoch 2 step 323: training accuarcy: 0.998
Epoch 2 step 323: training loss: 23.027645738699157
Epoch 2 step 324: training accuarcy: 0.9965
Epoch 2 step 324: training loss: 17.31866879527546
Epoch 2 step 325: training accuarcy: 0.9995
Epoch 2 step 325: training loss: 17.49401634798469
Epoch 2 step 326: training accuarcy: 0.9985
Epoch 2 step 326: training loss: 17.1087825183284

 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 3/4 [02:45<00:55, 55.34s/it]

Epoch: 3
Epoch 3 step 345: training loss: 15.997208743232784
Epoch 3 step 346: training accuarcy: 0.9985
Epoch 3 step 346: training loss: 14.181851479803699
Epoch 3 step 347: training accuarcy: 0.9995
Epoch 3 step 347: training loss: 19.017050215980134
Epoch 3 step 348: training accuarcy: 0.9975
Epoch 3 step 348: training loss: 18.729178142864335
Epoch 3 step 349: training accuarcy: 0.999
Epoch 3 step 349: training loss: 19.098192413545533
Epoch 3 step 350: training accuarcy: 0.9975
Epoch 3 step 350: training loss: 19.938087548925996
Epoch 3 step 351: training accuarcy: 0.997
Epoch 3 step 351: training loss: 20.85954512571747
Epoch 3 step 352: training accuarcy: 0.9965
Epoch 3 step 352: training loss: 19.749529773300456
Epoch 3 step 353: training accuarcy: 0.9975
Epoch 3 step 353: training loss: 18.919647260539737
Epoch 3 step 354: training accuarcy: 0.9975
Epoch 3 step 354: training loss: 22.941329472177173
Epoch 3 step 355: training accuarcy: 0.9965
Epoch 3 step 355: training loss: 1

Epoch 3 step 431: training accuarcy: 0.998
Epoch 3 step 431: training loss: 13.417540286069217
Epoch 3 step 432: training accuarcy: 0.9985
Epoch 3 step 432: training loss: 15.194960803973448
Epoch 3 step 433: training accuarcy: 1.0
Epoch 3 step 433: training loss: 12.899694143897362
Epoch 3 step 434: training accuarcy: 0.999
Epoch 3 step 434: training loss: 16.964686648801173
Epoch 3 step 435: training accuarcy: 0.998
Epoch 3 step 435: training loss: 15.66204517716525
Epoch 3 step 436: training accuarcy: 0.997
Epoch 3 step 436: training loss: 12.9852995993046
Epoch 3 step 437: training accuarcy: 0.998
Epoch 3 step 437: training loss: 17.502824110445395
Epoch 3 step 438: training accuarcy: 0.9975
Epoch 3 step 438: training loss: 13.907046255573103
Epoch 3 step 439: training accuarcy: 0.998
Epoch 3 step 439: training loss: 15.575418875891305
Epoch 3 step 440: training accuarcy: 0.998
Epoch 3 step 440: training loss: 15.123096212713664
Epoch 3 step 441: training accuarcy: 0.998
Epoch 3 st

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:41<00:00, 55.33s/it]


In [52]:
trans_learner.fit(epoch=2, log_dir=get_log_dir('weight_kaggle', 'trans'))


  0%|                                                                                                                                                                                     | 0/2 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 17381.002641464496
Epoch 0 step 1: training accuarcy: 0.8965000000000001
Epoch 0 step 1: training loss: 16496.53774284842
Epoch 0 step 2: training accuarcy: 0.92
Epoch 0 step 2: training loss: 15744.861349360843
Epoch 0 step 3: training accuarcy: 0.9458000000000001
Epoch 0 step 3: training loss: 14956.835339405248
Epoch 0 step 4: training accuarcy: 0.9612
Epoch 0 step 4: training loss: 14224.434288367096
Epoch 0 step 5: training accuarcy: 0.9690000000000001
Epoch 0 step 5: training loss: 13756.064527740442
Epoch 0 step 6: training accuarcy: 0.9709000000000001
Epoch 0 step 6: training loss: 12986.189324055866
Epoch 0 step 7: training accuarcy: 0.9756
Epoch 0 step 7: training loss: 12541.985667730762
Epoch 0 step 8: training accuarcy: 0.9728
Epoch 0 step 8: training loss: 12022.593255255164
Epoch 0 step 9: training accuarcy: 0.9726
Epoch 0 step 9: training loss: 11626.595902241628
Epoch 0 step 10: training accuarcy: 0.9755
Epoch 0 step 10: training


 50%|██████████████████████████████████████████████████████████████████████████████████████                                                                                      | 1/2 [11:15<11:15, 675.75s/it]

Epoch: 1
Epoch 1 step 115: training loss: 664.9170099245275
Epoch 1 step 116: training accuarcy: 0.9927
Epoch 1 step 116: training loss: 599.046155383268
Epoch 1 step 117: training accuarcy: 0.9926
Epoch 1 step 117: training loss: 615.8627759908185
Epoch 1 step 118: training accuarcy: 0.9916
Epoch 1 step 118: training loss: 631.6514158098196
Epoch 1 step 119: training accuarcy: 0.9922000000000001
Epoch 1 step 119: training loss: 657.8047752044704
Epoch 1 step 120: training accuarcy: 0.9911000000000001
Epoch 1 step 120: training loss: 606.0884341488816
Epoch 1 step 121: training accuarcy: 0.9927
Epoch 1 step 121: training loss: 530.9986929363481
Epoch 1 step 122: training accuarcy: 0.9939
Epoch 1 step 122: training loss: 563.6020754520605
Epoch 1 step 123: training accuarcy: 0.9926
Epoch 1 step 123: training loss: 556.0102186648044
Epoch 1 step 124: training accuarcy: 0.9933000000000001
Epoch 1 step 124: training loss: 534.6703722445317
Epoch 1 step 125: training accuarcy: 0.9935
Epoch 


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [22:26<00:00, 674.30s/it]

In [26]:
del trans_model
T.cuda.empty_cache()

### Train Trans Probs Model

In [28]:
prob_model = TransProbFM(feature_dim=feat_dim,
                         num_dim=NUM_DIM,
                         init_mean=INIT_MEAN,
                         init_scale=INIT_SCALE)

In [30]:
adam_opt = optim.Adam(prob_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt,
                                      step_size=DECAY_FREQ,
                                      gamma=DECAY_GAMMA)

In [31]:
prob_learner = FMLearner(prob_model, adam_opt, schedular, db)
prob_learner

<models.fm_learner.FMLearner at 0x2976465dbc8>

In [32]:
prob_learner.compile(train_col='base',
                     valid_col='base',
                     test_col='base',
                     loss_callback=trans_loss_callback)

In [33]:
prob_learner.fit(epoch=4, log_dir=get_log_dir('kaggle', 'prob'))

  0%|                                                                                                                                                           | 0/4 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 12194.163369137419
Epoch 0 step 1: training accuarcy: 0.923
Epoch 0 step 1: training loss: 11788.292655826503
Epoch 0 step 2: training accuarcy: 0.936
Epoch 0 step 2: training loss: 11485.60574312673
Epoch 0 step 3: training accuarcy: 0.9440000000000001
Epoch 0 step 3: training loss: 11091.79671280051
Epoch 0 step 4: training accuarcy: 0.9580000000000001
Epoch 0 step 4: training loss: 10748.554184125536
Epoch 0 step 5: training accuarcy: 0.961
Epoch 0 step 5: training loss: 10384.2347492197
Epoch 0 step 6: training accuarcy: 0.967
Epoch 0 step 6: training loss: 10065.603850427084
Epoch 0 step 7: training accuarcy: 0.97
Epoch 0 step 7: training loss: 9748.256742602522
Epoch 0 step 8: training accuarcy: 0.9695
Epoch 0 step 8: training loss: 9433.56933070879
Epoch 0 step 9: training accuarcy: 0.9715
Epoch 0 step 9: training loss: 9133.585482925935
Epoch 0 step 10: training accuarcy: 0.969
Epoch 0 step 10: training loss: 8828.149059173327
Epoch 0 ste

Epoch 0 step 88: training loss: 420.50328778802725
Epoch 0 step 89: training accuarcy: 0.996
Epoch 0 step 89: training loss: 403.6764179081007
Epoch 0 step 90: training accuarcy: 0.9975
Epoch 0 step 90: training loss: 384.88072220290627
Epoch 0 step 91: training accuarcy: 0.996
Epoch 0 step 91: training loss: 365.3899781647972
Epoch 0 step 92: training accuarcy: 0.9975
Epoch 0 step 92: training loss: 358.5247838138228
Epoch 0 step 93: training accuarcy: 0.993
Epoch 0 step 93: training loss: 344.47469422010823
Epoch 0 step 94: training accuarcy: 0.992
Epoch 0 step 94: training loss: 325.6027075721297
Epoch 0 step 95: training accuarcy: 0.996
Epoch 0 step 95: training loss: 311.37194054525526
Epoch 0 step 96: training accuarcy: 0.995
Epoch 0 step 96: training loss: 303.2060103375603
Epoch 0 step 97: training accuarcy: 0.9945
Epoch 0 step 97: training loss: 283.6647734552223
Epoch 0 step 98: training accuarcy: 0.9965
Epoch 0 step 98: training loss: 268.14766932260716
Epoch 0 step 99: trai

 25%|████████████████████████████████████▊                                                                                                              | 1/4 [00:56<02:50, 56.88s/it]

Epoch: 1
Epoch 1 step 115: training loss: 135.19157358234324
Epoch 1 step 116: training accuarcy: 0.997
Epoch 1 step 116: training loss: 133.36063896681347
Epoch 1 step 117: training accuarcy: 0.9945
Epoch 1 step 117: training loss: 123.26658354238363
Epoch 1 step 118: training accuarcy: 0.996
Epoch 1 step 118: training loss: 113.00822771400772
Epoch 1 step 119: training accuarcy: 0.997
Epoch 1 step 119: training loss: 116.44605754694356
Epoch 1 step 120: training accuarcy: 0.9965
Epoch 1 step 120: training loss: 112.27711643454408
Epoch 1 step 121: training accuarcy: 0.9975
Epoch 1 step 121: training loss: 104.43137312498214
Epoch 1 step 122: training accuarcy: 0.9965
Epoch 1 step 122: training loss: 103.9855790741097
Epoch 1 step 123: training accuarcy: 0.9975
Epoch 1 step 123: training loss: 96.27883092856874
Epoch 1 step 124: training accuarcy: 0.996
Epoch 1 step 124: training loss: 93.44020312997716
Epoch 1 step 125: training accuarcy: 0.9975
Epoch 1 step 125: training loss: 95.97

Epoch 1 step 201: training loss: 25.899138844067775
Epoch 1 step 202: training accuarcy: 0.9965
Epoch 1 step 202: training loss: 24.238712750756356
Epoch 1 step 203: training accuarcy: 0.996
Epoch 1 step 203: training loss: 26.300794380856892
Epoch 1 step 204: training accuarcy: 0.996
Epoch 1 step 204: training loss: 23.412374490438435
Epoch 1 step 205: training accuarcy: 0.996
Epoch 1 step 205: training loss: 20.871007772024853
Epoch 1 step 206: training accuarcy: 0.997
Epoch 1 step 206: training loss: 20.571330800266637
Epoch 1 step 207: training accuarcy: 0.9975
Epoch 1 step 207: training loss: 26.21798546852835
Epoch 1 step 208: training accuarcy: 0.9975
Epoch 1 step 208: training loss: 27.96668958206214
Epoch 1 step 209: training accuarcy: 0.9955
Epoch 1 step 209: training loss: 19.984054082450307
Epoch 1 step 210: training accuarcy: 0.9975
Epoch 1 step 210: training loss: 24.10461899742993
Epoch 1 step 211: training accuarcy: 0.998
Epoch 1 step 211: training loss: 24.719385108541

 50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 2/4 [01:53<01:53, 56.89s/it]

Epoch: 2
Epoch 2 step 230: training loss: 20.38661249279255
Epoch 2 step 231: training accuarcy: 0.996
Epoch 2 step 231: training loss: 19.394001183685873
Epoch 2 step 232: training accuarcy: 0.997
Epoch 2 step 232: training loss: 20.880198994447298
Epoch 2 step 233: training accuarcy: 0.998
Epoch 2 step 233: training loss: 25.728127148123438
Epoch 2 step 234: training accuarcy: 0.995
Epoch 2 step 234: training loss: 25.38029648408752
Epoch 2 step 235: training accuarcy: 0.996
Epoch 2 step 235: training loss: 28.294512066898005
Epoch 2 step 236: training accuarcy: 0.995
Epoch 2 step 236: training loss: 23.64471241048571
Epoch 2 step 237: training accuarcy: 0.996
Epoch 2 step 237: training loss: 19.320830369661575
Epoch 2 step 238: training accuarcy: 0.996
Epoch 2 step 238: training loss: 15.985396806465594
Epoch 2 step 239: training accuarcy: 0.999
Epoch 2 step 239: training loss: 17.708924755130056
Epoch 2 step 240: training accuarcy: 0.998
Epoch 2 step 240: training loss: 23.88983730

Epoch 2 step 316: training loss: 23.98499358765241
Epoch 2 step 317: training accuarcy: 0.9945
Epoch 2 step 317: training loss: 17.918335799033272
Epoch 2 step 318: training accuarcy: 0.9955
Epoch 2 step 318: training loss: 18.455474983675664
Epoch 2 step 319: training accuarcy: 0.9965
Epoch 2 step 319: training loss: 14.831183122824172
Epoch 2 step 320: training accuarcy: 0.997
Epoch 2 step 320: training loss: 19.315194782721882
Epoch 2 step 321: training accuarcy: 0.997
Epoch 2 step 321: training loss: 16.745844303849783
Epoch 2 step 322: training accuarcy: 0.997
Epoch 2 step 322: training loss: 16.23507972737632
Epoch 2 step 323: training accuarcy: 0.997
Epoch 2 step 323: training loss: 14.056652549268978
Epoch 2 step 324: training accuarcy: 0.9975
Epoch 2 step 324: training loss: 17.162752728745755
Epoch 2 step 325: training accuarcy: 0.997
Epoch 2 step 325: training loss: 13.31282559845185
Epoch 2 step 326: training accuarcy: 0.998
Epoch 2 step 326: training loss: 11.3708734727123

 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                    | 3/4 [02:50<00:56, 56.75s/it]

Epoch: 3
Epoch 3 step 345: training loss: 15.928547654215013
Epoch 3 step 346: training accuarcy: 0.996
Epoch 3 step 346: training loss: 16.092869669068072
Epoch 3 step 347: training accuarcy: 0.9965
Epoch 3 step 347: training loss: 19.253864257086292
Epoch 3 step 348: training accuarcy: 0.997
Epoch 3 step 348: training loss: 22.01725740779925
Epoch 3 step 349: training accuarcy: 0.9955
Epoch 3 step 349: training loss: 22.960681435858135
Epoch 3 step 350: training accuarcy: 0.995
Epoch 3 step 350: training loss: 15.803787010059628
Epoch 3 step 351: training accuarcy: 0.9965
Epoch 3 step 351: training loss: 16.349903645608904
Epoch 3 step 352: training accuarcy: 0.997
Epoch 3 step 352: training loss: 12.653509863071193
Epoch 3 step 353: training accuarcy: 0.998
Epoch 3 step 353: training loss: 14.641588745273296
Epoch 3 step 354: training accuarcy: 0.997
Epoch 3 step 354: training loss: 13.489415076805052
Epoch 3 step 355: training accuarcy: 0.9985
Epoch 3 step 355: training loss: 14.80

Epoch 3 step 431: training accuarcy: 0.998
Epoch 3 step 431: training loss: 10.795075574768571
Epoch 3 step 432: training accuarcy: 0.9985
Epoch 3 step 432: training loss: 13.228407235583443
Epoch 3 step 433: training accuarcy: 0.998
Epoch 3 step 433: training loss: 15.038142574235556
Epoch 3 step 434: training accuarcy: 0.9975
Epoch 3 step 434: training loss: 22.882799851829162
Epoch 3 step 435: training accuarcy: 0.993
Epoch 3 step 435: training loss: 22.767968858480078
Epoch 3 step 436: training accuarcy: 0.9945
Epoch 3 step 436: training loss: 14.309688019793635
Epoch 3 step 437: training accuarcy: 0.997
Epoch 3 step 437: training loss: 14.95515944708621
Epoch 3 step 438: training accuarcy: 0.9965
Epoch 3 step 438: training loss: 23.316407759464802
Epoch 3 step 439: training accuarcy: 0.993
Epoch 3 step 439: training loss: 14.7532363040482
Epoch 3 step 440: training accuarcy: 0.9975
Epoch 3 step 440: training loss: 18.645466970465797
Epoch 3 step 441: training accuarcy: 0.997
Epoch

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:46<00:00, 56.62s/it]
