In [1]:
# Select matplotlib's inline backend so figures render in the cell output.
# NOTE(review): no plotting call is visible in this part of the notebook — confirm it is still needed.
%matplotlib inline

In [2]:
cd ..

C:\Projects\python\recommender


In [3]:
from functools import partial
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch as T
import torch.optim as optim

from utils import get_log_dir

In [4]:
from datasets import TorchStackOverflow
from models import FMLearner, TorchFM, TorchHrmFM, TorchPrmeFM, TorchTransFM
from models.fm_learner import simple_loss, trans_loss

In [5]:
# Seed the NumPy and PyTorch generators so that randomness drawn from them is
# repeatable across "Restart & Run All".
# NOTE(review): presumably covers parameter init and batch shuffling — confirm
# the project code does not use another RNG.
SEED = 42
np.random.seed(SEED)
T.manual_seed(SEED)

# Use the current CUDA device when one is available; otherwise fall back to the
# CPU instead of raising on machines without a GPU.
# NOTE(review): assumes db.config_db accepts a torch.device as well as a CUDA
# device index — confirm against TorchStackOverflow.
DEVICE = T.cuda.current_device() if T.cuda.is_available() else T.device("cpu")

# Data-loading options forwarded to db.config_db further down.
BATCH = 3000
SHUFFLE = True
WORKERS = 0

# CSV handed to TorchStackOverflow as data_path (relative to the working directory).
item_path = Path("./inputs/stackoverflow/item.csv")

In [6]:
# Build the dataset wrapper from the item CSV.
# NOTE(review): min_user presumably is the minimum activity threshold behind the
# "Filtered user size" line printed below — confirm in TorchStackOverflow.
MIN_USER = 4
db = TorchStackOverflow(data_path=item_path, min_user=MIN_USER)
db

Original user size: 75572
Filtered user size: 14795
Original item dataframe shape: (299845, 8)
Filtered item dataframe shape: (217223, 8)


<datasets.torch_stackoverflow.TorchStackOverflow at 0x1b91f6f97b8>

In [7]:
# Collect the batching options defined in the configuration cell and hand them
# to the dataset in a single call.
loader_options = dict(
    batch_size=BATCH,
    shuffle=SHUFFLE,
    num_workers=WORKERS,
    device=DEVICE,
)
db.config_db(**loader_options)

## Create Criterion

In [8]:
# Regularisation settings: the weights bound into the loss callbacks below
# (LINEAR_REG and EMB_REG into simple_loss; all three into trans_loss).
LINEAR_REG = EMB_REG = TRANS_REG = 1

In [9]:
# Pre-bind the regularisation weights as the first two positional arguments of
# simple_loss; the resulting callable is passed to FMLearner.fit as
# loss_callback for the FM, HRM and PRME models.
simple_loss_callback = partial(simple_loss, LINEAR_REG, EMB_REG)
# Echo the partial so the bound values show up in the cell output.
simple_loss_callback

functools.partial(<function simple_loss at 0x000001B91F6AAA60>, 1, 1)

In [10]:
# Pre-bind the three regularisation weights as the leading positional arguments
# of trans_loss; used as loss_callback when fitting the Trans FM model.
trans_loss_callback = partial(trans_loss, LINEAR_REG, EMB_REG, TRANS_REG)
# Echo the partial so the bound values show up in the cell output.
trans_loss_callback

functools.partial(<function trans_loss at 0x000001B91F704B70>, 1, 1, 1)

## Train model

### Hyper-parameters

In [11]:
# Constructor arguments shared by every model built in this notebook.
NUM_DIM, INIT_MEAN = 124, 0.1
# Input feature dimension as reported by the dataset object.
feat_dim = db.feat_dim

### Train FM Model

In [14]:
# Step size for every Adam optimiser created below.
LEARNING_RATE = 1e-3
# StepLR settings: the learning rate is multiplied by DECAY_GAMMA once every
# DECAY_FREQ scheduler steps. With a factor of 1 the rate stays constant.
DECAY_GAMMA = 1
DECAY_FREQ = 1000

In [13]:
# Instantiate the FM model with the shared hyper-parameters.
fm_model = TorchFM(
    feature_dim=feat_dim,
    num_dim=NUM_DIM,
    init_mean=INIT_MEAN,
)
fm_model

TorchFM()

In [14]:
# Adam over the FM model's parameters, paired with a step-wise LR schedule.
adam_opt = optim.Adam(params=fm_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(
    optimizer=adam_opt,
    step_size=DECAY_FREQ,
    gamma=DECAY_GAMMA,
)

In [15]:
# Bundle the FM model, its optimiser, the LR scheduler and the dataset into a learner.
fm_learner = FMLearner(
    fm_model,
    adam_opt,
    schedular,
    db,
)
fm_learner

<models.fm_learner.FMLearner at 0x1de266dd5c0>

In [16]:
# Train the FM model for 8 epochs with the simple loss, logging under the 'fm' directory.
fm_log_dir = get_log_dir('fm')
fm_learner.fit(
    epoch=8,
    loss_callback=simple_loss_callback,
    log_dir=fm_log_dir,
)

  0%|                                                                                                                                                                  | 0/8 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 157572.1662900254
Epoch 0 step 1: training accuarcy: 0.30966666666666665
Epoch 0 step 1: training loss: 152803.85492885328
Epoch 0 step 2: training accuarcy: 0.36633333333333334
Epoch 0 step 2: training loss: 148239.5385377525
Epoch 0 step 3: training accuarcy: 0.416
Epoch 0 step 3: training loss: 143949.19111709765
Epoch 0 step 4: training accuarcy: 0.43766666666666665
Epoch 0 step 4: training loss: 139690.8817278023
Epoch 0 step 5: training accuarcy: 0.48266666666666663
Epoch 0 step 5: training loss: 135535.64838294886
Epoch 0 step 6: training accuarcy: 0.5173333333333333
Epoch 0 step 6: training loss: 131506.05680230522
Epoch 0 step 7: training accuarcy: 0.5696666666666667
Epoch 0 step 7: training loss: 127644.8764366617
Epoch 0 step 8: training accuarcy: 0.5793333333333334
Epoch 0 step 8: training loss: 123925.68895784377
Epoch 0 step 9: training accuarcy: 0.6046666666666667
Epoch 0 step 9: training loss: 120230.83929059228
Epoch 0 step 10: t

 12%|███████████████████▎                                                                                                                                      | 1/8 [00:49<05:44, 49.19s/it]

Epoch: 1
Epoch 1 step 63: training loss: 20883.365795839152
Epoch 1 step 64: training accuarcy: 0.9
Epoch 1 step 64: training loss: 20124.462548850322
Epoch 1 step 65: training accuarcy: 0.9146666666666666
Epoch 1 step 65: training loss: 19443.443708299234
Epoch 1 step 66: training accuarcy: 0.9096666666666666
Epoch 1 step 66: training loss: 18777.48370399148
Epoch 1 step 67: training accuarcy: 0.9166666666666666
Epoch 1 step 67: training loss: 18151.625813766914
Epoch 1 step 68: training accuarcy: 0.9126666666666666
Epoch 1 step 68: training loss: 17540.09873886304
Epoch 1 step 69: training accuarcy: 0.9129999999999999
Epoch 1 step 69: training loss: 16938.938998389403
Epoch 1 step 70: training accuarcy: 0.9116666666666666
Epoch 1 step 70: training loss: 16354.709566784144
Epoch 1 step 71: training accuarcy: 0.9173333333333333
Epoch 1 step 71: training loss: 15805.788902313036
Epoch 1 step 72: training accuarcy: 0.9109999999999999
Epoch 1 step 72: training loss: 15276.489009289911
Epo

 25%|██████████████████████████████████████▌                                                                                                                   | 2/8 [01:38<04:55, 49.18s/it]

Epoch: 2
Epoch 2 step 126: training loss: 2476.917759485972
Epoch 2 step 127: training accuarcy: 0.9543333333333333
Epoch 2 step 127: training loss: 2392.900186983681
Epoch 2 step 128: training accuarcy: 0.961
Epoch 2 step 128: training loss: 2333.9874427028235
Epoch 2 step 129: training accuarcy: 0.9563333333333333
Epoch 2 step 129: training loss: 2291.2126131807127
Epoch 2 step 130: training accuarcy: 0.9466666666666667
Epoch 2 step 130: training loss: 2183.8065043463102
Epoch 2 step 131: training accuarcy: 0.961
Epoch 2 step 131: training loss: 2127.3960673216397
Epoch 2 step 132: training accuarcy: 0.9603333333333333
Epoch 2 step 132: training loss: 2082.945412173983
Epoch 2 step 133: training accuarcy: 0.955
Epoch 2 step 133: training loss: 2009.2174948637783
Epoch 2 step 134: training accuarcy: 0.9616666666666667
Epoch 2 step 134: training loss: 1940.6799509965444
Epoch 2 step 135: training accuarcy: 0.9636666666666667
Epoch 2 step 135: training loss: 1890.985007832837
Epoch 2 st

 38%|█████████████████████████████████████████████████████████▊                                                                                                | 3/8 [02:26<04:05, 49.00s/it]

Epoch: 3
Epoch 3 step 189: training loss: 582.4153611161112
Epoch 3 step 190: training accuarcy: 0.9813333333333333
Epoch 3 step 190: training loss: 588.6315151338849
Epoch 3 step 191: training accuarcy: 0.9786666666666667
Epoch 3 step 191: training loss: 573.8015458072387
Epoch 3 step 192: training accuarcy: 0.9793333333333333
Epoch 3 step 192: training loss: 554.811425516587
Epoch 3 step 193: training accuarcy: 0.9833333333333333
Epoch 3 step 193: training loss: 546.7452960048831
Epoch 3 step 194: training accuarcy: 0.98
Epoch 3 step 194: training loss: 529.0157869592573
Epoch 3 step 195: training accuarcy: 0.9833333333333333
Epoch 3 step 195: training loss: 553.3649481076819
Epoch 3 step 196: training accuarcy: 0.9783333333333333
Epoch 3 step 196: training loss: 533.0655274527721
Epoch 3 step 197: training accuarcy: 0.977
Epoch 3 step 197: training loss: 516.295445928737
Epoch 3 step 198: training accuarcy: 0.9856666666666666
Epoch 3 step 198: training loss: 520.2709923334427
Epoch 

 50%|█████████████████████████████████████████████████████████████████████████████                                                                             | 4/8 [03:15<03:15, 48.91s/it]

Epoch: 4
Epoch 4 step 252: training loss: 290.1296873859424
Epoch 4 step 253: training accuarcy: 0.988
Epoch 4 step 253: training loss: 285.74438817186024
Epoch 4 step 254: training accuarcy: 0.989
Epoch 4 step 254: training loss: 284.47509925132834
Epoch 4 step 255: training accuarcy: 0.9906666666666666
Epoch 4 step 255: training loss: 287.1715324922666
Epoch 4 step 256: training accuarcy: 0.9866666666666666
Epoch 4 step 256: training loss: 277.99165249140225
Epoch 4 step 257: training accuarcy: 0.9886666666666666
Epoch 4 step 257: training loss: 280.0855822760874
Epoch 4 step 258: training accuarcy: 0.987
Epoch 4 step 258: training loss: 278.4789797997187
Epoch 4 step 259: training accuarcy: 0.9866666666666666
Epoch 4 step 259: training loss: 268.62218559116803
Epoch 4 step 260: training accuarcy: 0.9896666666666666
Epoch 4 step 260: training loss: 265.0239491122217
Epoch 4 step 261: training accuarcy: 0.9886666666666666
Epoch 4 step 261: training loss: 277.2196778996121
Epoch 4 step

 62%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                                         | 5/8 [04:04<02:26, 48.90s/it]

Epoch: 5
Epoch 5 step 315: training loss: 187.92163511340027
Epoch 5 step 316: training accuarcy: 0.9936666666666666
Epoch 5 step 316: training loss: 191.84299415012896
Epoch 5 step 317: training accuarcy: 0.992
Epoch 5 step 317: training loss: 189.54593085380853
Epoch 5 step 318: training accuarcy: 0.992
Epoch 5 step 318: training loss: 186.9720636270959
Epoch 5 step 319: training accuarcy: 0.992
Epoch 5 step 319: training loss: 181.9265300492384
Epoch 5 step 320: training accuarcy: 0.9943333333333333
Epoch 5 step 320: training loss: 181.1623807384513
Epoch 5 step 321: training accuarcy: 0.9933333333333333
Epoch 5 step 321: training loss: 181.4751273235009
Epoch 5 step 322: training accuarcy: 0.9926666666666666
Epoch 5 step 322: training loss: 177.17487184990443
Epoch 5 step 323: training accuarcy: 0.9953333333333333
Epoch 5 step 323: training loss: 180.72534644281333
Epoch 5 step 324: training accuarcy: 0.99
Epoch 5 step 324: training loss: 178.8352003939671
Epoch 5 step 325: trainin

 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 6/8 [04:53<01:37, 48.81s/it]

Epoch: 6
Epoch 6 step 378: training loss: 134.39942405316273
Epoch 6 step 379: training accuarcy: 0.9966666666666666
Epoch 6 step 379: training loss: 143.93534600999658
Epoch 6 step 380: training accuarcy: 0.994
Epoch 6 step 380: training loss: 142.43819564081656
Epoch 6 step 381: training accuarcy: 0.996
Epoch 6 step 381: training loss: 145.87698197705072
Epoch 6 step 382: training accuarcy: 0.993
Epoch 6 step 382: training loss: 139.60891512436964
Epoch 6 step 383: training accuarcy: 0.9963333333333333
Epoch 6 step 383: training loss: 149.25016637481065
Epoch 6 step 384: training accuarcy: 0.9926666666666666
Epoch 6 step 384: training loss: 142.18873267434103
Epoch 6 step 385: training accuarcy: 0.9946666666666666
Epoch 6 step 385: training loss: 141.72630611537724
Epoch 6 step 386: training accuarcy: 0.995
Epoch 6 step 386: training loss: 141.51432784574277
Epoch 6 step 387: training accuarcy: 0.9933333333333333
Epoch 6 step 387: training loss: 139.68929663926764
Epoch 6 step 388: t

 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                   | 7/8 [05:42<00:48, 48.84s/it]

Epoch: 7
Epoch 7 step 441: training loss: 126.93586574471033
Epoch 7 step 442: training accuarcy: 0.9946666666666666
Epoch 7 step 442: training loss: 116.24161802498695
Epoch 7 step 443: training accuarcy: 0.9966666666666666
Epoch 7 step 443: training loss: 122.3233376558658
Epoch 7 step 444: training accuarcy: 0.9963333333333333
Epoch 7 step 444: training loss: 120.98224268883112
Epoch 7 step 445: training accuarcy: 0.9956666666666666
Epoch 7 step 445: training loss: 121.67969765727265
Epoch 7 step 446: training accuarcy: 0.997
Epoch 7 step 446: training loss: 128.5335223681605
Epoch 7 step 447: training accuarcy: 0.9936666666666666
Epoch 7 step 447: training loss: 114.32640492306439
Epoch 7 step 448: training accuarcy: 0.997
Epoch 7 step 448: training loss: 118.43309591184185
Epoch 7 step 449: training accuarcy: 0.9966666666666666
Epoch 7 step 449: training loss: 116.85175543654111
Epoch 7 step 450: training accuarcy: 0.9976666666666666
Epoch 7 step 450: training loss: 118.7723864524

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [06:32<00:00, 49.34s/it]


### Train HRM FM Model

In [23]:
# Instantiate the HRM variant with the shared hyper-parameters.
hrm_model = TorchHrmFM(
    feature_dim=feat_dim,
    num_dim=NUM_DIM,
    init_mean=INIT_MEAN,
)
hrm_model

TorchHrmFM()

In [24]:
# Fresh Adam optimiser and step-wise LR schedule bound to the HRM model.
# These rebind the notebook-wide names adam_opt / schedular.
adam_opt = optim.Adam(params=hrm_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt, step_size=DECAY_FREQ, gamma=DECAY_GAMMA)

In [25]:
# Bundle the HRM model, its optimiser, the LR scheduler and the dataset into a learner.
hrm_learner = FMLearner(
    hrm_model,
    adam_opt,
    schedular,
    db,
)
hrm_learner

<models.fm_learner.FMLearner at 0x1de234217b8>

In [26]:
# Train the HRM model for 8 epochs with the simple loss.
# NOTE(review): this is the only training cell that passes two arguments to
# get_log_dir; the other models pass just the model name — confirm the intended form.
hrm_log_dir = get_log_dir('stackoverflow', 'hrm')
hrm_learner.fit(epoch=8, loss_callback=simple_loss_callback, log_dir=hrm_log_dir)


  0%|                                                                                                                                                                  | 0/8 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 161424.97674656817
Epoch 0 step 1: training accuarcy: 0.5186666666666666
Epoch 0 step 1: training loss: 157219.70639475313
Epoch 0 step 2: training accuarcy: 0.4943333333333333
Epoch 0 step 2: training loss: 152539.92346734728
Epoch 0 step 3: training accuarcy: 0.5143333333333333
Epoch 0 step 3: training loss: 148512.27948702226
Epoch 0 step 4: training accuarcy: 0.495
Epoch 0 step 4: training loss: 143848.96650937246
Epoch 0 step 5: training accuarcy: 0.5053333333333333
Epoch 0 step 5: training loss: 139727.92820349155
Epoch 0 step 6: training accuarcy: 0.48566666666666664
Epoch 0 step 6: training loss: 135480.55833961524
Epoch 0 step 7: training accuarcy: 0.496
Epoch 0 step 7: training loss: 131344.59307532327
Epoch 0 step 8: training accuarcy: 0.49866666666666665
Epoch 0 step 8: training loss: 127493.79173152546
Epoch 0 step 9: training accuarcy: 0.49466666666666664
Epoch 0 step 9: training loss: 123907.64786004399
Epoch 0 step 10: training ac


 12%|███████████████████▎                                                                                                                                      | 1/8 [00:53<06:16, 53.75s/it]

Epoch: 1
Epoch 1 step 63: training loss: 20099.349826608475
Epoch 1 step 64: training accuarcy: 0.5393333333333333
Epoch 1 step 64: training loss: 19525.53089410358
Epoch 1 step 65: training accuarcy: 0.5236666666666666
Epoch 1 step 65: training loss: 18792.944452993535
Epoch 1 step 66: training accuarcy: 0.5313333333333333
Epoch 1 step 66: training loss: 18171.69781992276
Epoch 1 step 67: training accuarcy: 0.5356666666666666
Epoch 1 step 67: training loss: 17505.967809260328
Epoch 1 step 68: training accuarcy: 0.5376666666666666
Epoch 1 step 68: training loss: 16968.548029772355
Epoch 1 step 69: training accuarcy: 0.5266666666666666
Epoch 1 step 69: training loss: 16361.836175799031
Epoch 1 step 70: training accuarcy: 0.5509999999999999
Epoch 1 step 70: training loss: 15817.677259059481
Epoch 1 step 71: training accuarcy: 0.5406666666666666
Epoch 1 step 71: training loss: 15305.414113810642
Epoch 1 step 72: training accuarcy: 0.534
Epoch 1 step 72: training loss: 14814.313165580057
E


 25%|██████████████████████████████████████▌                                                                                                                   | 2/8 [01:46<05:21, 53.52s/it]

Epoch: 2
Epoch 2 step 126: training loss: 3373.650143752887
Epoch 2 step 127: training accuarcy: 0.591
Epoch 2 step 127: training loss: 3369.37813073328
Epoch 2 step 128: training accuarcy: 0.565
Epoch 2 step 128: training loss: 3267.9593970594433
Epoch 2 step 129: training accuarcy: 0.5843333333333333
Epoch 2 step 129: training loss: 3254.730579994797
Epoch 2 step 130: training accuarcy: 0.5619999999999999
Epoch 2 step 130: training loss: 3196.435402231258
Epoch 2 step 131: training accuarcy: 0.575
Epoch 2 step 131: training loss: 3124.948425416194
Epoch 2 step 132: training accuarcy: 0.5733333333333334
Epoch 2 step 132: training loss: 3094.0870894866985
Epoch 2 step 133: training accuarcy: 0.572
Epoch 2 step 133: training loss: 3052.9544829416677
Epoch 2 step 134: training accuarcy: 0.5803333333333334
Epoch 2 step 134: training loss: 3020.0289250244755
Epoch 2 step 135: training accuarcy: 0.58
Epoch 2 step 135: training loss: 2961.8340104790213
Epoch 2 step 136: training accuarcy: 0.


 38%|█████████████████████████████████████████████████████████▊                                                                                                | 3/8 [02:38<04:25, 53.13s/it]

Epoch: 3
Epoch 3 step 189: training loss: 2121.8349365938698
Epoch 3 step 190: training accuarcy: 0.5893333333333333
Epoch 3 step 190: training loss: 2145.52977590754
Epoch 3 step 191: training accuarcy: 0.579
Epoch 3 step 191: training loss: 2122.828718472727
Epoch 3 step 192: training accuarcy: 0.604
Epoch 3 step 192: training loss: 2131.5219390721227
Epoch 3 step 193: training accuarcy: 0.6
Epoch 3 step 193: training loss: 2142.333434279857
Epoch 3 step 194: training accuarcy: 0.5816666666666667
Epoch 3 step 194: training loss: 2117.216293413301
Epoch 3 step 195: training accuarcy: 0.5886666666666667
Epoch 3 step 195: training loss: 2122.320449036417
Epoch 3 step 196: training accuarcy: 0.5866666666666667
Epoch 3 step 196: training loss: 2101.039516466553
Epoch 3 step 197: training accuarcy: 0.6036666666666667
Epoch 3 step 197: training loss: 2127.127066332417
Epoch 3 step 198: training accuarcy: 0.5976666666666667
Epoch 3 step 198: training loss: 2100.5493767997177
Epoch 3 step 199


 50%|█████████████████████████████████████████████████████████████████████████████                                                                             | 4/8 [03:33<03:33, 53.41s/it]

Epoch: 4
Epoch 4 step 252: training loss: 2016.6664329372459
Epoch 4 step 253: training accuarcy: 0.6016666666666667
Epoch 4 step 253: training loss: 2020.384434311287
Epoch 4 step 254: training accuarcy: 0.5963333333333333
Epoch 4 step 254: training loss: 2042.1156725311998
Epoch 4 step 255: training accuarcy: 0.5863333333333333
Epoch 4 step 255: training loss: 2023.1849259850521
Epoch 4 step 256: training accuarcy: 0.6
Epoch 4 step 256: training loss: 2024.5826529372928
Epoch 4 step 257: training accuarcy: 0.6043333333333333
Epoch 4 step 257: training loss: 2029.9795135063266
Epoch 4 step 258: training accuarcy: 0.592
Epoch 4 step 258: training loss: 2063.82668703091
Epoch 4 step 259: training accuarcy: 0.5866666666666667
Epoch 4 step 259: training loss: 2019.3598024541072
Epoch 4 step 260: training accuarcy: 0.597
Epoch 4 step 260: training loss: 2037.3944275301772
Epoch 4 step 261: training accuarcy: 0.5873333333333333
Epoch 4 step 261: training loss: 2023.1076321915436
Epoch 4 ste


 62%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                                         | 5/8 [04:27<02:41, 53.71s/it]

Epoch: 5
Epoch 5 step 315: training loss: 1983.6355215660808
Epoch 5 step 316: training accuarcy: 0.61
Epoch 5 step 316: training loss: 2023.0480678469223
Epoch 5 step 317: training accuarcy: 0.59
Epoch 5 step 317: training loss: 2015.4908571412334
Epoch 5 step 318: training accuarcy: 0.5906666666666667
Epoch 5 step 318: training loss: 2029.086197218268
Epoch 5 step 319: training accuarcy: 0.5796666666666667
Epoch 5 step 319: training loss: 2029.1023183934656
Epoch 5 step 320: training accuarcy: 0.5903333333333333
Epoch 5 step 320: training loss: 2020.6801811379019
Epoch 5 step 321: training accuarcy: 0.5913333333333333
Epoch 5 step 321: training loss: 1993.0565412942449
Epoch 5 step 322: training accuarcy: 0.5993333333333333
Epoch 5 step 322: training loss: 1985.5952423123474
Epoch 5 step 323: training accuarcy: 0.5993333333333333
Epoch 5 step 323: training loss: 2010.7098290802792
Epoch 5 step 324: training accuarcy: 0.5953333333333333
Epoch 5 step 324: training loss: 2004.5888035529


 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 6/8 [05:20<01:47, 53.59s/it]

Epoch: 6
Epoch 6 step 378: training loss: 2011.4703182433225
Epoch 6 step 379: training accuarcy: 0.5976666666666667
Epoch 6 step 379: training loss: 1992.1870746287473
Epoch 6 step 380: training accuarcy: 0.606
Epoch 6 step 380: training loss: 1996.6260246726808
Epoch 6 step 381: training accuarcy: 0.6013333333333333
Epoch 6 step 381: training loss: 1997.9918530451678
Epoch 6 step 382: training accuarcy: 0.609
Epoch 6 step 382: training loss: 2007.524881895957
Epoch 6 step 383: training accuarcy: 0.5936666666666667
Epoch 6 step 383: training loss: 1992.2203747422484
Epoch 6 step 384: training accuarcy: 0.599
Epoch 6 step 384: training loss: 1989.8164323455285
Epoch 6 step 385: training accuarcy: 0.6016666666666667
Epoch 6 step 385: training loss: 2007.1022148668287
Epoch 6 step 386: training accuarcy: 0.5833333333333333
Epoch 6 step 386: training loss: 1983.1376794180271
Epoch 6 step 387: training accuarcy: 0.6053333333333333
Epoch 6 step 387: training loss: 2000.5422643004083
Epoch 6


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                   | 7/8 [06:13<00:53, 53.44s/it]

Epoch: 7
Epoch 7 step 441: training loss: 1999.778280009805
Epoch 7 step 442: training accuarcy: 0.5953333333333333
Epoch 7 step 442: training loss: 1994.0888515732663
Epoch 7 step 443: training accuarcy: 0.5926666666666667
Epoch 7 step 443: training loss: 2002.648652882715
Epoch 7 step 444: training accuarcy: 0.5906666666666667
Epoch 7 step 444: training loss: 1988.454796283186
Epoch 7 step 445: training accuarcy: 0.607
Epoch 7 step 445: training loss: 1993.1667836226102
Epoch 7 step 446: training accuarcy: 0.5913333333333333
Epoch 7 step 446: training loss: 1995.998275970917
Epoch 7 step 447: training accuarcy: 0.5953333333333333
Epoch 7 step 447: training loss: 1996.9595535129556
Epoch 7 step 448: training accuarcy: 0.6053333333333333
Epoch 7 step 448: training loss: 2012.9855808796176
Epoch 7 step 449: training accuarcy: 0.5893333333333333
Epoch 7 step 449: training loss: 2026.399698873912
Epoch 7 step 450: training accuarcy: 0.588
Epoch 7 step 450: training loss: 2002.681819573651


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [07:06<00:00, 53.09s/it]

### Train PRME FM Model

In [27]:
# Instantiate the PRME variant with the shared hyper-parameters.
prme_model = TorchPrmeFM(
    feature_dim=feat_dim,
    num_dim=NUM_DIM,
    init_mean=INIT_MEAN,
)
prme_model

TorchPrmeFM()

In [28]:
# Fresh Adam optimiser and step-wise LR schedule bound to the PRME model.
# These rebind the notebook-wide names adam_opt / schedular.
adam_opt = optim.Adam(params=prme_model.parameters(), lr=LEARNING_RATE)
schedular = optim.lr_scheduler.StepLR(adam_opt, step_size=DECAY_FREQ, gamma=DECAY_GAMMA)

In [30]:
# Bundle the PRME model, its optimiser, the LR scheduler and the dataset into a learner.
prme_learner = FMLearner(
    prme_model,
    adam_opt,
    schedular,
    db,
)
prme_learner

<models.fm_learner.FMLearner at 0x1de1f4afb00>

In [31]:
# Train the PRME model for 8 epochs with the simple loss, logging under the 'prme' directory.
prme_learner.fit(
    epoch=8,
    loss_callback=simple_loss_callback,
    log_dir=get_log_dir('prme'),
)


  0%|                                                                                                                                                                  | 0/8 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 168569.2936286707
Epoch 0 step 1: training accuarcy: 0.4993333333333333
Epoch 0 step 1: training loss: 163485.7659552902
Epoch 0 step 2: training accuarcy: 0.4983333333333333
Epoch 0 step 2: training loss: 159378.10168334848
Epoch 0 step 3: training accuarcy: 0.5073333333333333
Epoch 0 step 3: training loss: 154501.4194061317
Epoch 0 step 4: training accuarcy: 0.5146666666666666
Epoch 0 step 4: training loss: 149650.16897265305
Epoch 0 step 5: training accuarcy: 0.5113333333333333
Epoch 0 step 5: training loss: 145888.4337116899
Epoch 0 step 6: training accuarcy: 0.5026666666666666
Epoch 0 step 6: training loss: 141368.1092794052
Epoch 0 step 7: training accuarcy: 0.5093333333333333
Epoch 0 step 7: training loss: 136958.64898813626
Epoch 0 step 8: training accuarcy: 0.5126666666666666
Epoch 0 step 8: training loss: 133736.63994669836
Epoch 0 step 9: training accuarcy: 0.5073333333333333
Epoch 0 step 9: training loss: 129658.23196107848
Epoch 0 st


 12%|███████████████████▎                                                                                                                                      | 1/8 [00:54<06:20, 54.32s/it]

Epoch: 1
Epoch 1 step 63: training loss: 25264.932058651277
Epoch 1 step 64: training accuarcy: 0.5313333333333333
Epoch 1 step 64: training loss: 24563.6291795637
Epoch 1 step 65: training accuarcy: 0.5233333333333333
Epoch 1 step 65: training loss: 23837.740697772657
Epoch 1 step 66: training accuarcy: 0.5263333333333333
Epoch 1 step 66: training loss: 23137.5456343006
Epoch 1 step 67: training accuarcy: 0.519
Epoch 1 step 67: training loss: 22305.188506839324
Epoch 1 step 68: training accuarcy: 0.5383333333333333
Epoch 1 step 68: training loss: 21865.87371189461
Epoch 1 step 69: training accuarcy: 0.5186666666666666
Epoch 1 step 69: training loss: 21126.914611093707
Epoch 1 step 70: training accuarcy: 0.5413333333333333
Epoch 1 step 70: training loss: 20507.082900245856
Epoch 1 step 71: training accuarcy: 0.5283333333333333
Epoch 1 step 71: training loss: 19900.68308730873
Epoch 1 step 72: training accuarcy: 0.5353333333333333
Epoch 1 step 72: training loss: 19344.148694673775
Epoch


 25%|██████████████████████████████████████▌                                                                                                                   | 2/8 [01:47<05:23, 53.96s/it]

Epoch: 2
Epoch 2 step 126: training loss: 5184.593375826083
Epoch 2 step 127: training accuarcy: 0.5393333333333333
Epoch 2 step 127: training loss: 5008.555125339645
Epoch 2 step 128: training accuarcy: 0.568
Epoch 2 step 128: training loss: 4956.185804356794
Epoch 2 step 129: training accuarcy: 0.5549999999999999
Epoch 2 step 129: training loss: 4921.405573011365
Epoch 2 step 130: training accuarcy: 0.5243333333333333
Epoch 2 step 130: training loss: 4805.803256062902
Epoch 2 step 131: training accuarcy: 0.5593333333333333
Epoch 2 step 131: training loss: 4734.823993431066
Epoch 2 step 132: training accuarcy: 0.5496666666666666
Epoch 2 step 132: training loss: 4652.349268151687
Epoch 2 step 133: training accuarcy: 0.5559999999999999
Epoch 2 step 133: training loss: 4600.53122568235
Epoch 2 step 134: training accuarcy: 0.5443333333333333
Epoch 2 step 134: training loss: 4482.449525876415
Epoch 2 step 135: training accuarcy: 0.5603333333333333
Epoch 2 step 135: training loss: 4417.0145


 38%|█████████████████████████████████████████████████████████▊                                                                                                | 3/8 [02:39<04:27, 53.45s/it]

Epoch: 3
Epoch 3 step 189: training loss: 2598.2127461305076
Epoch 3 step 190: training accuarcy: 0.5813333333333334
Epoch 3 step 190: training loss: 2581.033499826791
Epoch 3 step 191: training accuarcy: 0.5619999999999999
Epoch 3 step 191: training loss: 2566.510985635582
Epoch 3 step 192: training accuarcy: 0.5856666666666667
Epoch 3 step 192: training loss: 2531.016741107579
Epoch 3 step 193: training accuarcy: 0.592
Epoch 3 step 193: training loss: 2541.562224420147
Epoch 3 step 194: training accuarcy: 0.5803333333333334
Epoch 3 step 194: training loss: 2530.739546260867
Epoch 3 step 195: training accuarcy: 0.5863333333333333
Epoch 3 step 195: training loss: 2499.440603219685
Epoch 3 step 196: training accuarcy: 0.5966666666666667
Epoch 3 step 196: training loss: 2511.6160976224837
Epoch 3 step 197: training accuarcy: 0.5626666666666666
Epoch 3 step 197: training loss: 2467.7861841889912
Epoch 3 step 198: training accuarcy: 0.5913333333333333
Epoch 3 step 198: training loss: 2487.


 50%|█████████████████████████████████████████████████████████████████████████████                                                                             | 4/8 [03:32<03:33, 53.33s/it]

Epoch: 4
Epoch 4 step 252: training loss: 2136.8685482048395
Epoch 4 step 253: training accuarcy: 0.6026666666666667
Epoch 4 step 253: training loss: 2145.6143147362563
Epoch 4 step 254: training accuarcy: 0.5953333333333333
Epoch 4 step 254: training loss: 2162.6533408941864
Epoch 4 step 255: training accuarcy: 0.5783333333333334
Epoch 4 step 255: training loss: 2142.784330519935
Epoch 4 step 256: training accuarcy: 0.593
Epoch 4 step 256: training loss: 2120.76231227398
Epoch 4 step 257: training accuarcy: 0.6026666666666667
Epoch 4 step 257: training loss: 2150.1939646126716
Epoch 4 step 258: training accuarcy: 0.5946666666666667
Epoch 4 step 258: training loss: 2148.654510259628
Epoch 4 step 259: training accuarcy: 0.5943333333333333
Epoch 4 step 259: training loss: 2139.855315930714
Epoch 4 step 260: training accuarcy: 0.5916666666666667
Epoch 4 step 260: training loss: 2160.752932330425
Epoch 4 step 261: training accuarcy: 0.578
Epoch 4 step 261: training loss: 2129.6064059452274


 62%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                                         | 5/8 [04:27<02:41, 53.68s/it]

Epoch: 5
Epoch 5 step 315: training loss: 2054.7393343226613
Epoch 5 step 316: training accuarcy: 0.591
Epoch 5 step 316: training loss: 2043.4029467919302
Epoch 5 step 317: training accuarcy: 0.5976666666666667
Epoch 5 step 317: training loss: 2047.5563513050395
Epoch 5 step 318: training accuarcy: 0.598
Epoch 5 step 318: training loss: 2065.2450572445873
Epoch 5 step 319: training accuarcy: 0.583
Epoch 5 step 319: training loss: 2033.6200819081823
Epoch 5 step 320: training accuarcy: 0.597
Epoch 5 step 320: training loss: 2048.1190641615126
Epoch 5 step 321: training accuarcy: 0.6046666666666667
Epoch 5 step 321: training loss: 2045.3066730305654
Epoch 5 step 322: training accuarcy: 0.5993333333333333
Epoch 5 step 322: training loss: 2036.8918369335117
Epoch 5 step 323: training accuarcy: 0.5976666666666667
Epoch 5 step 323: training loss: 2042.5031042932249
Epoch 5 step 324: training accuarcy: 0.6
Epoch 5 step 324: training loss: 2045.3018582227803
Epoch 5 step 325: training accuarc


 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 6/8 [05:19<01:46, 53.35s/it]

Epoch: 6
Epoch 6 step 378: training loss: 2007.3039918269963
Epoch 6 step 379: training accuarcy: 0.6063333333333333
Epoch 6 step 379: training loss: 2014.3587262542972
Epoch 6 step 380: training accuarcy: 0.5936666666666667
Epoch 6 step 380: training loss: 2032.9607512739262
Epoch 6 step 381: training accuarcy: 0.5926666666666667
Epoch 6 step 381: training loss: 2023.6147372454552
Epoch 6 step 382: training accuarcy: 0.5916666666666667
Epoch 6 step 382: training loss: 2001.440187290323
Epoch 6 step 383: training accuarcy: 0.6113333333333333
Epoch 6 step 383: training loss: 2038.911427772406
Epoch 6 step 384: training accuarcy: 0.577
Epoch 6 step 384: training loss: 2042.6050643427182
Epoch 6 step 385: training accuarcy: 0.5856666666666667
Epoch 6 step 385: training loss: 2019.788289053153
Epoch 6 step 386: training accuarcy: 0.6026666666666667
Epoch 6 step 386: training loss: 2033.5953070702801
Epoch 6 step 387: training accuarcy: 0.588
Epoch 6 step 387: training loss: 2018.3627146785


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                   | 7/8 [06:12<00:53, 53.12s/it]

Epoch: 7
Epoch 7 step 441: training loss: 2020.356170581338
Epoch 7 step 442: training accuarcy: 0.5926666666666667
Epoch 7 step 442: training loss: 2017.4887357087612
Epoch 7 step 443: training accuarcy: 0.5933333333333333
Epoch 7 step 443: training loss: 1996.4614705658942
Epoch 7 step 444: training accuarcy: 0.6096666666666667
Epoch 7 step 444: training loss: 2010.033932801882
Epoch 7 step 445: training accuarcy: 0.599
Epoch 7 step 445: training loss: 2016.8442833552137
Epoch 7 step 446: training accuarcy: 0.5956666666666667
Epoch 7 step 446: training loss: 2015.6366536854407
Epoch 7 step 447: training accuarcy: 0.5863333333333333
Epoch 7 step 447: training loss: 2008.5217373952091
Epoch 7 step 448: training accuarcy: 0.5986666666666667
Epoch 7 step 448: training loss: 2006.8401407849979
Epoch 7 step 449: training accuarcy: 0.5993333333333333
Epoch 7 step 449: training loss: 2019.0411705627255
Epoch 7 step 450: training accuarcy: 0.5976666666666667
Epoch 7 step 450: training loss: 1


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [07:04<00:00, 52.95s/it]

### Train Trans FM Model

In [15]:
# Build the TorchTransFM model with the same feature_dim / num_dim / init_mean
# settings that were passed to TorchFM earlier in the notebook.
trans_model = TorchTransFM(
    feature_dim=feat_dim,
    num_dim=NUM_DIM,
    init_mean=INIT_MEAN,
)
# Bare trailing expression so the notebook renders the model's repr.
trans_model

TorchTransFM()

In [16]:
# Fresh Adam optimizer over the TransFM parameters. This rebinds the
# `adam_opt` / `schedular` names that the FM section also used, so this cell
# must run before the learner cell that follows it.
adam_opt = optim.Adam(params=trans_model.parameters(), lr=LEARNING_RATE)

# Step decay: the learning rate is multiplied by DECAY_GAMMA every DECAY_FREQ
# scheduler steps.
# NOTE(review): the config cell shown earlier sets DECAY_GAMMA = 1, which would
# leave the learning rate unchanged - confirm that a no-op schedule is intended.
schedular = optim.lr_scheduler.StepLR(optimizer=adam_opt, step_size=DECAY_FREQ, gamma=DECAY_GAMMA)

In [17]:
# Bundle the TransFM model with its optimizer, scheduler and the dataset.
# `adam_opt` and `schedular` are whatever the most recently executed optimizer
# cell bound them to, so run the TransFM optimizer cell immediately before this.
trans_learner = FMLearner(
    trans_model,
    adam_opt,
    schedular,
    db,
)
# Bare trailing expression so the notebook renders the learner's repr.
trans_learner

<models.fm_learner.FMLearner at 0x1b929682160>

In [18]:
# Train for 8 epochs using the trans_loss partial built in the criterion
# section; log_dir is whatever get_log_dir returns for the 'trans' tag.
trans_learner.fit(
    epoch=8,
    loss_callback=trans_loss_callback,
    log_dir=get_log_dir('trans'),
)

  0%|                                                                                                                                                                  | 0/8 [00:00<?, ?it/s]

Epoch: 0
Epoch 0 step 0: training loss: 171935.10777800047
Epoch 0 step 1: training accuarcy: 0.4983333333333333
Epoch 0 step 1: training loss: 167236.04864530085
Epoch 0 step 2: training accuarcy: 0.5033333333333333
Epoch 0 step 2: training loss: 162522.13603857072
Epoch 0 step 3: training accuarcy: 0.506
Epoch 0 step 3: training loss: 158037.64368668137
Epoch 0 step 4: training accuarcy: 0.4943333333333333
Epoch 0 step 4: training loss: 153135.47970628904
Epoch 0 step 5: training accuarcy: 0.511
Epoch 0 step 5: training loss: 149167.705830932
Epoch 0 step 6: training accuarcy: 0.5096666666666666
Epoch 0 step 6: training loss: 144981.4878073405
Epoch 0 step 7: training accuarcy: 0.4993333333333333
Epoch 0 step 7: training loss: 140899.25312771543
Epoch 0 step 8: training accuarcy: 0.5026666666666666
Epoch 0 step 8: training loss: 136904.69567565294
Epoch 0 step 9: training accuarcy: 0.5086666666666666
Epoch 0 step 9: training loss: 132839.0323278288
Epoch 0 step 10: training accuarcy:

 12%|███████████████████▎                                                                                                                                      | 1/8 [00:57<06:39, 57.02s/it]

Epoch: 1
Epoch 1 step 63: training loss: 24765.28876517985
Epoch 1 step 64: training accuarcy: 0.5439999999999999
Epoch 1 step 64: training loss: 24343.362802849264
Epoch 1 step 65: training accuarcy: 0.533
Epoch 1 step 65: training loss: 23330.051831719284
Epoch 1 step 66: training accuarcy: 0.5363333333333333
Epoch 1 step 66: training loss: 22676.773948953753
Epoch 1 step 67: training accuarcy: 0.5393333333333333
Epoch 1 step 67: training loss: 21755.683702378883
Epoch 1 step 68: training accuarcy: 0.5483333333333333
Epoch 1 step 68: training loss: 21054.956681211257
Epoch 1 step 69: training accuarcy: 0.5553333333333333
Epoch 1 step 69: training loss: 20870.62700801407
Epoch 1 step 70: training accuarcy: 0.5376666666666666
Epoch 1 step 70: training loss: 19935.317096798874
Epoch 1 step 71: training accuarcy: 0.5493333333333333
Epoch 1 step 71: training loss: 19713.50639924015
Epoch 1 step 72: training accuarcy: 0.5473333333333333
Epoch 1 step 72: training loss: 18888.050174725067
Ep

 25%|██████████████████████████████████████▌                                                                                                                   | 2/8 [01:54<05:42, 57.03s/it]

Epoch: 2
Epoch 2 step 126: training loss: 5341.7469564397115
Epoch 2 step 127: training accuarcy: 0.6033333333333333
Epoch 2 step 127: training loss: 5206.677189602455
Epoch 2 step 128: training accuarcy: 0.6073333333333333
Epoch 2 step 128: training loss: 5040.0380092549985
Epoch 2 step 129: training accuarcy: 0.6306666666666666
Epoch 2 step 129: training loss: 4982.010231246841
Epoch 2 step 130: training accuarcy: 0.6176666666666667
Epoch 2 step 130: training loss: 5006.833278831855
Epoch 2 step 131: training accuarcy: 0.61
Epoch 2 step 131: training loss: 4932.096732335662
Epoch 2 step 132: training accuarcy: 0.607
Epoch 2 step 132: training loss: 4862.2319357641345
Epoch 2 step 133: training accuarcy: 0.6176666666666667
Epoch 2 step 133: training loss: 4932.976106528808
Epoch 2 step 134: training accuarcy: 0.599
Epoch 2 step 134: training loss: 4666.74127676386
Epoch 2 step 135: training accuarcy: 0.6236666666666666
Epoch 2 step 135: training loss: 4711.189458596309
Epoch 2 step 13

 38%|█████████████████████████████████████████████████████████▊                                                                                                | 3/8 [02:51<04:45, 57.13s/it]

Epoch: 3
Epoch 3 step 189: training loss: 3100.6241967996957
Epoch 3 step 190: training accuarcy: 0.6476666666666666
Epoch 3 step 190: training loss: 3065.905978395724
Epoch 3 step 191: training accuarcy: 0.6506666666666666
Epoch 3 step 191: training loss: 3051.055718787956
Epoch 3 step 192: training accuarcy: 0.6586666666666666
Epoch 3 step 192: training loss: 2995.7153299278502
Epoch 3 step 193: training accuarcy: 0.6573333333333333
Epoch 3 step 193: training loss: 3002.588309603074
Epoch 3 step 194: training accuarcy: 0.6646666666666666
Epoch 3 step 194: training loss: 3020.4793526827852
Epoch 3 step 195: training accuarcy: 0.656
Epoch 3 step 195: training loss: 3092.5020863427226
Epoch 3 step 196: training accuarcy: 0.6366666666666666
Epoch 3 step 196: training loss: 3080.4067489374947
Epoch 3 step 197: training accuarcy: 0.6296666666666666
Epoch 3 step 197: training loss: 3061.1763300382936
Epoch 3 step 198: training accuarcy: 0.6443333333333333
Epoch 3 step 198: training loss: 30

 50%|█████████████████████████████████████████████████████████████████████████████                                                                             | 4/8 [03:49<03:49, 57.40s/it]

Epoch: 4
Epoch 4 step 252: training loss: 2735.7327254946854
Epoch 4 step 253: training accuarcy: 0.6443333333333333
Epoch 4 step 253: training loss: 2612.4413435358856
Epoch 4 step 254: training accuarcy: 0.6693333333333333
Epoch 4 step 254: training loss: 2602.902898685944
Epoch 4 step 255: training accuarcy: 0.6689999999999999
Epoch 4 step 255: training loss: 2509.7742310637746
Epoch 4 step 256: training accuarcy: 0.6813333333333333
Epoch 4 step 256: training loss: 2602.060543348648
Epoch 4 step 257: training accuarcy: 0.6663333333333333
Epoch 4 step 257: training loss: 2600.6538045840157
Epoch 4 step 258: training accuarcy: 0.6766666666666666
Epoch 4 step 258: training loss: 2609.456664274499
Epoch 4 step 259: training accuarcy: 0.6719999999999999
Epoch 4 step 259: training loss: 2560.6986074286774
Epoch 4 step 260: training accuarcy: 0.6706666666666666
Epoch 4 step 260: training loss: 2660.5885387430635
Epoch 4 step 261: training accuarcy: 0.657
Epoch 4 step 261: training loss: 25

 62%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                                         | 5/8 [04:47<02:53, 57.69s/it]

Epoch: 5
Epoch 5 step 315: training loss: 2412.2438616002646
Epoch 5 step 316: training accuarcy: 0.6846666666666666
Epoch 5 step 316: training loss: 2336.35504225671
Epoch 5 step 317: training accuarcy: 0.6956666666666667
Epoch 5 step 317: training loss: 2441.457771936631
Epoch 5 step 318: training accuarcy: 0.6689999999999999
Epoch 5 step 318: training loss: 2406.507540120618
Epoch 5 step 319: training accuarcy: 0.6816666666666666
Epoch 5 step 319: training loss: 2433.25940751863
Epoch 5 step 320: training accuarcy: 0.6833333333333333
Epoch 5 step 320: training loss: 2366.3120855902403
Epoch 5 step 321: training accuarcy: 0.6789999999999999
Epoch 5 step 321: training loss: 2360.9960344062156
Epoch 5 step 322: training accuarcy: 0.6873333333333334
Epoch 5 step 322: training loss: 2402.5369292450096
Epoch 5 step 323: training accuarcy: 0.6816666666666666
Epoch 5 step 323: training loss: 2404.133115327325
Epoch 5 step 324: training accuarcy: 0.6853333333333333
Epoch 5 step 324: training

 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 6/8 [05:45<01:55, 57.58s/it]

Epoch: 6
Epoch 6 step 378: training loss: 2263.5935061250066
Epoch 6 step 379: training accuarcy: 0.6876666666666666
Epoch 6 step 379: training loss: 2287.9425524954586
Epoch 6 step 380: training accuarcy: 0.6923333333333334
Epoch 6 step 380: training loss: 2306.9840913737758
Epoch 6 step 381: training accuarcy: 0.695
Epoch 6 step 381: training loss: 2234.2039193461146
Epoch 6 step 382: training accuarcy: 0.7003333333333334
Epoch 6 step 382: training loss: 2269.0065718211335
Epoch 6 step 383: training accuarcy: 0.713
Epoch 6 step 383: training loss: 2270.095949674769
Epoch 6 step 384: training accuarcy: 0.6976666666666667
Epoch 6 step 384: training loss: 2308.040924046886
Epoch 6 step 385: training accuarcy: 0.6976666666666667
Epoch 6 step 385: training loss: 2332.956090578001
Epoch 6 step 386: training accuarcy: 0.6769999999999999
Epoch 6 step 386: training loss: 2311.3791655096375
Epoch 6 step 387: training accuarcy: 0.6913333333333334
Epoch 6 step 387: training loss: 2260.6101782034

 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                   | 7/8 [06:43<00:57, 57.68s/it]

Epoch: 7
Epoch 7 step 441: training loss: 2165.3529943064427
Epoch 7 step 442: training accuarcy: 0.7133333333333333
Epoch 7 step 442: training loss: 2116.9963831491973
Epoch 7 step 443: training accuarcy: 0.719
Epoch 7 step 443: training loss: 2151.0668768045743
Epoch 7 step 444: training accuarcy: 0.7093333333333333
Epoch 7 step 444: training loss: 2132.4021095970916
Epoch 7 step 445: training accuarcy: 0.7253333333333333
Epoch 7 step 445: training loss: 2104.2357112608065
Epoch 7 step 446: training accuarcy: 0.7156666666666667
Epoch 7 step 446: training loss: 2126.812333236503
Epoch 7 step 447: training accuarcy: 0.721
Epoch 7 step 447: training loss: 2167.1015551817395
Epoch 7 step 448: training accuarcy: 0.7073333333333333
Epoch 7 step 448: training loss: 2180.3855493158503
Epoch 7 step 449: training accuarcy: 0.6996666666666667
Epoch 7 step 449: training loss: 2203.771525998578
Epoch 7 step 450: training accuarcy: 0.7096666666666667
Epoch 7 step 450: training loss: 2124.226691504

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [07:39<00:00, 57.32s/it]
