In [8]:
import random
from pathlib import Path

import numpy as np
import torch

from src.dataloaders import Dataset100k
from src.metrics import hitratio, ndcg
from src.models import GMFBCEModel
from src.trainer import Trainer

# Single seed for every RNG that can influence data splitting, negative
# sampling or weight init (Python `random`, NumPy, PyTorch).
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the PyTorch RNG on all devices (CPU and CUDA)
# NOTE(review): bit-exact GPU reproducibility may additionally require
# deterministic cuDNN settings; not configured here.

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device=}")

device='cuda'


In [9]:
class config:
    """Hyper-parameters for this training run.

    Used as a plain attribute namespace: the notebook reads e.g.
    ``config.epochs`` directly and never instantiates the class.
    """

    # --- data ---
    data_dir: str = "ml-100k"  # root directory handed to Dataset100k

    # --- model ---
    dim: int = 16  # embedding width, passed as GMFBCEModel(embed_size=...)

    # --- optimisation ---
    lr: float = 1e-3  # Adam learning rate
    batch_size: int = 2048
    epochs: int = 50


# Load the ratings, build the user/item structure, then split train/test.
dataset = Dataset100k(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()
print(f"{dataset.train_size=}, {dataset.test_size=}")

# Ranking metrics evaluated by the Trainer: "NAME@k" -> (metric_fn, kwargs).
# Insertion order is HR@1, HR@5, HR@10, NDCG@1, NDCG@5, NDCG@10.
# NOTE(review): the logs show NDCG@1 == HR@1 every epoch, which is expected if
# each user has exactly one held-out test item (presumably leave-one-out).
metrics = {
    f"{label}@{k}": (metric_fn, {"top_n": k})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for k in (1, 5, 10)
}

dataset.train_size=198114, dataset.test_size=943


In [10]:
# One embedding table per side (users / items), sized from the dataset.
model = GMFBCEModel(
    dataset.user_count,
    dataset.item_count,
    embed_size=config.dim,
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.lr)

# The Trainer owns the train/eval loop; it evaluates `metrics` on the test split.
trainer = Trainer(
    dataset, model, optimizer, metrics,
    batch_size=config.batch_size,
    epochs=config.epochs,
    device=device,
)

In [11]:
# Run all config.epochs epochs; with evaluate=True the ranking metrics are
# computed and logged after every epoch (see output below).
trainer.train(evaluate=True, verbose=True, progressbar=True)

                                                                   

Epoch 0: Avg Loss/Batch 0.693109            


                                                   

HR@1: 0.14422057264050903
HR@5: 0.3616118769883351
HR@10: 0.5344644750795334
NDCG@1: 0.14422057264050903
NDCG@5: 0.25501661702803474
NDCG@10: 0.310695052412515


                                                                   

Epoch 1: Avg Loss/Batch 0.682623            


                                                   

HR@1: 0.1420996818663839
HR@5: 0.34676564156945916
HR@10: 0.5217391304347826
NDCG@1: 0.1420996818663839
NDCG@5: 0.2471404020658947
NDCG@10: 0.3036570128965668


                                                                   

Epoch 2: Avg Loss/Batch 0.610052            


                                                   

HR@1: 0.14316012725344646
HR@5: 0.35100742311770944
HR@10: 0.5227995758218452
NDCG@1: 0.14316012725344646
NDCG@5: 0.24890456367868136
NDCG@10: 0.304358040166351


                                                                   

Epoch 3: Avg Loss/Batch 0.516167            


                                                   

HR@1: 0.14952279957582185
HR@5: 0.3616118769883351
HR@10: 0.5270413573700954
NDCG@1: 0.14952279957582185
NDCG@5: 0.2556252778109433
NDCG@10: 0.30885253238340366


                                                                   

Epoch 4: Avg Loss/Batch 0.480456            


                                                   

HR@1: 0.14846235418875928
HR@5: 0.36585365853658536
HR@10: 0.5270413573700954
NDCG@1: 0.14846235418875928
NDCG@5: 0.256941322659347
NDCG@10: 0.3089040957582971


                                                                   

Epoch 5: Avg Loss/Batch 0.471106            


                                                   

HR@1: 0.1474019088016967
HR@5: 0.37009544008483564
HR@10: 0.5312831389183457
NDCG@1: 0.1474019088016967
NDCG@5: 0.25826497687290084
NDCG@10: 0.3099889389155207


                                                                   

Epoch 6: Avg Loss/Batch 0.467216            


                                                   

HR@1: 0.14528101802757157
HR@5: 0.37645811240721105
HR@10: 0.5323435843054083
NDCG@1: 0.14528101802757157
NDCG@5: 0.26070945381779653
NDCG@10: 0.31067984389349007


                                                                   

Epoch 7: Avg Loss/Batch 0.464806            


                                                   

HR@1: 0.14316012725344646
HR@5: 0.3753976670201485
HR@10: 0.5355249204665959
NDCG@1: 0.14316012725344646
NDCG@5: 0.26011639195537417
NDCG@10: 0.31153398471763893


                                                                   

Epoch 8: Avg Loss/Batch 0.462948            


                                                   

HR@1: 0.14316012725344646
HR@5: 0.3796394485683987
HR@10: 0.5344644750795334
NDCG@1: 0.14316012725344646
NDCG@5: 0.26210911271566345
NDCG@10: 0.31189753876551624


                                                                   

Epoch 9: Avg Loss/Batch 0.461275            


                                                   

HR@1: 0.14316012725344646
HR@5: 0.3806998939554613
HR@10: 0.5365853658536586
NDCG@1: 0.14316012725344646
NDCG@5: 0.26306405056401555
NDCG@10: 0.3132913180165578


                                                                    

Epoch 10: Avg Loss/Batch 0.459571            


                                                   

HR@1: 0.14634146341463414
HR@5: 0.38494167550371156
HR@10: 0.5440084835630965
NDCG@1: 0.14634146341463414
NDCG@5: 0.2666914218551599
NDCG@10: 0.3180060101619444


                                                                    

Epoch 11: Avg Loss/Batch 0.457664            


                                                   

HR@1: 0.1474019088016967
HR@5: 0.3860021208907741
HR@10: 0.5524920466595971
NDCG@1: 0.1474019088016967
NDCG@5: 0.267724826212437
NDCG@10: 0.3215390100568849


                                                                    

Epoch 12: Avg Loss/Batch 0.455391            


                                                   

HR@1: 0.15058324496288442
HR@5: 0.39660657476139977
HR@10: 0.5609756097560976
NDCG@1: 0.15058324496288442
NDCG@5: 0.2740535879902603
NDCG@10: 0.3270559962175553


                                                                    

Epoch 13: Avg Loss/Batch 0.452587            


                                                   

HR@1: 0.15906680805938495
HR@5: 0.40402969247083775
HR@10: 0.5705196182396607
NDCG@1: 0.15906680805938495
NDCG@5: 0.28260877650046706
NDCG@10: 0.33645833661489


                                                                    

Epoch 14: Avg Loss/Batch 0.449097            


                                                   

HR@1: 0.16542948038176034
HR@5: 0.4135737009544008
HR@10: 0.5811240721102863
NDCG@1: 0.16542948038176034
NDCG@5: 0.2902114897028803
NDCG@10: 0.3442653942476162


                                                                    

Epoch 15: Avg Loss/Batch 0.444815            


                                                   

HR@1: 0.18027571580063625
HR@5: 0.42735949098621423
HR@10: 0.5938494167550371
NDCG@1: 0.18027571580063625
NDCG@5: 0.3035022809502095
NDCG@10: 0.35706960120235637


                                                                    

Epoch 16: Avg Loss/Batch 0.439749            


                                                   

HR@1: 0.1855779427359491
HR@5: 0.44220572640509015
HR@10: 0.616118769883351
NDCG@1: 0.1855779427359491
NDCG@5: 0.31505483157396863
NDCG@10: 0.3707867506755771


                                                                    

Epoch 17: Avg Loss/Batch 0.434054            


                                                   

HR@1: 0.18981972428419935
HR@5: 0.45493107104984093
HR@10: 0.6288441145281018
NDCG@1: 0.18981972428419935
NDCG@5: 0.32587854095631913
NDCG@10: 0.3820901829367425


                                                                    

Epoch 18: Avg Loss/Batch 0.428003            


                                                   

HR@1: 0.19406150583244963
HR@5: 0.47720042417815484
HR@10: 0.6489925768822906
NDCG@1: 0.19406150583244963
NDCG@5: 0.34095678736071444
NDCG@10: 0.39615397466132024


                                                                    

Epoch 19: Avg Loss/Batch 0.421883            


                                                   

HR@1: 0.19936373276776245
HR@5: 0.48250265111346763
HR@10: 0.662778366914104
NDCG@1: 0.19936373276776245
NDCG@5: 0.34707187819663216
NDCG@10: 0.4053031135665134


                                                                    

Epoch 20: Avg Loss/Batch 0.415888            


                                                   

HR@1: 0.2067868504772004
HR@5: 0.503711558854719
HR@10: 0.6733828207847296
NDCG@1: 0.2067868504772004
NDCG@5: 0.359359369874741
NDCG@10: 0.4140890767180342


                                                                    

Epoch 21: Avg Loss/Batch 0.410071            


                                                   

HR@1: 0.21527041357370094
HR@5: 0.5185577942735949
HR@10: 0.6850477200424178
NDCG@1: 0.21527041357370094
NDCG@5: 0.36962195365413075
NDCG@10: 0.4229981348913664


                                                                    

Epoch 22: Avg Loss/Batch 0.404387            


                                                   

HR@1: 0.21951219512195122
HR@5: 0.5270413573700954
HR@10: 0.6998939554612937
NDCG@1: 0.21951219512195122
NDCG@5: 0.3768346305162065
NDCG@10: 0.432444315907094


                                                                    

Epoch 23: Avg Loss/Batch 0.398784            


                                                   

HR@1: 0.2258748674443266
HR@5: 0.5408271474019088
HR@10: 0.7020148462354189
NDCG@1: 0.2258748674443266
NDCG@5: 0.38542484889558737
NDCG@10: 0.4374505883995888


                                                                    

Epoch 24: Avg Loss/Batch 0.393279            


                                                   

HR@1: 0.22905620360551432
HR@5: 0.542948038176034
HR@10: 0.7051961823966065
NDCG@1: 0.22905620360551432
NDCG@5: 0.38910060961120085
NDCG@10: 0.4416004268649068


                                                                    

Epoch 25: Avg Loss/Batch 0.387989            


                                                   

HR@1: 0.23753976670201485
HR@5: 0.5524920466595971
HR@10: 0.7094379639448568
NDCG@1: 0.23753976670201485
NDCG@5: 0.39779272524456344
NDCG@10: 0.44870295310652447


                                                                    

Epoch 26: Avg Loss/Batch 0.383080            


                                                   

HR@1: 0.24284199363732767
HR@5: 0.5652173913043478
HR@10: 0.7126193001060446
NDCG@1: 0.24284199363732767
NDCG@5: 0.40675185474496595
NDCG@10: 0.45454240999310846


                                                                    

Epoch 27: Avg Loss/Batch 0.378682            


                                                   

HR@1: 0.24602332979851538
HR@5: 0.5737009544008483
HR@10: 0.7147401908801697
NDCG@1: 0.24602332979851538
NDCG@5: 0.4123060424373434
NDCG@10: 0.4579867567471363


                                                                    

Epoch 28: Avg Loss/Batch 0.374825            


                                                   

HR@1: 0.2492046659597031
HR@5: 0.574761399787911
HR@10: 0.7179215270413574
NDCG@1: 0.2492046659597031
NDCG@5: 0.41521524632644735
NDCG@10: 0.4615052962204848


                                                                    

Epoch 29: Avg Loss/Batch 0.371436            


                                                   

HR@1: 0.2513255567338282
HR@5: 0.574761399787911
HR@10: 0.7168610816542949
NDCG@1: 0.2513255567338282
NDCG@5: 0.41700628406464135
NDCG@10: 0.46303326789315896


                                                                    

Epoch 30: Avg Loss/Batch 0.368382            


                                                   

HR@1: 0.25344644750795337
HR@5: 0.5832449628844114
HR@10: 0.721102863202545
NDCG@1: 0.25344644750795337
NDCG@5: 0.4213768091408391
NDCG@10: 0.46601871554182916


                                                                    

Epoch 31: Avg Loss/Batch 0.365506            


                                                   

HR@1: 0.25556733828207845
HR@5: 0.5896076352067868
HR@10: 0.7264050901378579
NDCG@1: 0.25556733828207845
NDCG@5: 0.4250733148855828
NDCG@10: 0.469122381947934


                                                                    

Epoch 32: Avg Loss/Batch 0.362669            


                                                   

HR@1: 0.25344644750795337
HR@5: 0.5896076352067868
HR@10: 0.735949098621421
NDCG@1: 0.25344644750795337
NDCG@5: 0.4254906147151265
NDCG@10: 0.4723463172096978


                                                                    

Epoch 33: Avg Loss/Batch 0.359752            


                                                   

HR@1: 0.256627783669141
HR@5: 0.5896076352067868
HR@10: 0.7401908801696713
NDCG@1: 0.256627783669141
NDCG@5: 0.42716298011513537
NDCG@10: 0.4753915198954211


                                                                    

Epoch 34: Avg Loss/Batch 0.356670            


                                                   

HR@1: 0.25344644750795337
HR@5: 0.5896076352067868
HR@10: 0.7391304347826086
NDCG@1: 0.25344644750795337
NDCG@5: 0.4266718147907475
NDCG@10: 0.47471710201935957


                                                                    

Epoch 35: Avg Loss/Batch 0.353367            


                                                   

HR@1: 0.25344644750795337
HR@5: 0.5938494167550371
HR@10: 0.7391304347826086
NDCG@1: 0.25344644750795337
NDCG@5: 0.4287380492752432
NDCG@10: 0.47558875362681


                                                                    

Epoch 36: Avg Loss/Batch 0.349818            


                                                   

HR@1: 0.25874867444326616
HR@5: 0.5980911983032874
HR@10: 0.7380699893955461
NDCG@1: 0.25874867444326616
NDCG@5: 0.4332411236358836
NDCG@10: 0.47844473487748024


                                                                    

Epoch 37: Avg Loss/Batch 0.346027            


                                                   

HR@1: 0.2598091198303287
HR@5: 0.5959703075291622
HR@10: 0.7423117709437964
NDCG@1: 0.2598091198303287
NDCG@5: 0.43373803843446884
NDCG@10: 0.48121917082544907


                                                                    

Epoch 38: Avg Loss/Batch 0.342028            


                                                   

HR@1: 0.25874867444326616
HR@5: 0.5959703075291622
HR@10: 0.7465535524920467
NDCG@1: 0.25874867444326616
NDCG@5: 0.4335213023023982
NDCG@10: 0.48240269304589295


                                                                    

Epoch 39: Avg Loss/Batch 0.337877            


                                                   

HR@1: 0.264050901378579
HR@5: 0.6002120890774125
HR@10: 0.7465535524920467
NDCG@1: 0.264050901378579
NDCG@5: 0.4369614403158695
NDCG@10: 0.4845571466899414


                                                                    

Epoch 40: Avg Loss/Batch 0.333648            


                                                   

HR@1: 0.264050901378579
HR@5: 0.5980911983032874
HR@10: 0.7465535524920467
NDCG@1: 0.264050901378579
NDCG@5: 0.43732695500607577
NDCG@10: 0.4858448942230303


                                                                    

Epoch 41: Avg Loss/Batch 0.329418            


                                                   

HR@1: 0.2651113467656416
HR@5: 0.6002120890774125
HR@10: 0.7476139978791092
NDCG@1: 0.2651113467656416
NDCG@5: 0.43882410346776957
NDCG@10: 0.48704945413887846


                                                                    

Epoch 42: Avg Loss/Batch 0.325251            


                                                   

HR@1: 0.2651113467656416
HR@5: 0.6023329798515377
HR@10: 0.7486744432661718
NDCG@1: 0.2651113467656416
NDCG@5: 0.4397997870611991
NDCG@10: 0.48780603238551773


                                                                    

Epoch 43: Avg Loss/Batch 0.321188            


                                                   

HR@1: 0.2682926829268293
HR@5: 0.6044538706256628
HR@10: 0.7539766702014846
NDCG@1: 0.2682926829268293
NDCG@5: 0.4423278496360104
NDCG@10: 0.4912491140627189


                                                                    

Epoch 44: Avg Loss/Batch 0.317245            


                                                   

HR@1: 0.26935312831389185
HR@5: 0.6044538706256628
HR@10: 0.7592788971367974
NDCG@1: 0.26935312831389185
NDCG@5: 0.4430026095337948
NDCG@10: 0.4934538723127174


                                                                    

Epoch 45: Avg Loss/Batch 0.313420            


                                                   

HR@1: 0.2672322375397667
HR@5: 0.6118769883351007
HR@10: 0.7635206786850477
NDCG@1: 0.2672322375397667
NDCG@5: 0.4460639869799561
NDCG@10: 0.4951391345112799


                                                                    

Epoch 46: Avg Loss/Batch 0.309704            


                                                   

HR@1: 0.2704135737009544
HR@5: 0.616118769883351
HR@10: 0.7667020148462355
NDCG@1: 0.2704135737009544
NDCG@5: 0.44873965051729153
NDCG@10: 0.4975641822489779


                                                                    

Epoch 47: Avg Loss/Batch 0.306086            


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6139978791092259
HR@10: 0.767762460233298
NDCG@1: 0.27465535524920465
NDCG@5: 0.4498930412931957
NDCG@10: 0.49975542877040346


                                                                    

Epoch 48: Avg Loss/Batch 0.302558            


                                                   

HR@1: 0.26617179215270415
HR@5: 0.6086956521739131
HR@10: 0.7709437963944857
NDCG@1: 0.26617179215270415
NDCG@5: 0.44456131150348543
NDCG@10: 0.497216247037698


                                                                    

Epoch 49: Avg Loss/Batch 0.299113            


                                                   

HR@1: 0.264050901378579
HR@5: 0.6076352067868505
HR@10: 0.7730646871686108
NDCG@1: 0.264050901378579
NDCG@5: 0.4433412765181415
NDCG@10: 0.49699318713790863




In [12]:
# Report the epoch whose logged metrics maximise SELECTION_METRIC.
# NOTE(review): `trainer.test_log` presumably holds test-split metrics, so
# picking the best epoch from it is model selection on test data and gives an
# optimistic estimate; a separate validation split would avoid that.
SELECTION_METRIC = "NDCG@10"
scores = [entry[SELECTION_METRIC] for entry in trainer.test_log]
if not scores:
    raise RuntimeError("trainer.test_log is empty; run trainer.train(evaluate=True) first")
best_epoch = int(np.argmax(scores))  # plain int rather than np.int64
print(f"{best_epoch}: {trainer.test_log[best_epoch]}")

47: {'HR@1': 0.27465535524920465, 'HR@5': 0.6139978791092259, 'HR@10': 0.767762460233298, 'NDCG@1': 0.27465535524920465, 'NDCG@5': 0.4498930412931957, 'NDCG@10': 0.49975542877040346}


In [13]:
# Persist the model weights. NOTE(review): these are presumably the weights
# after the final epoch, not the "best" epoch reported above, unless Trainer
# restores a best checkpoint itself; verify before relying on the file.
MODEL_PATH = Path("saved_models") / "gmfbce.pt"
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)  # a fresh checkout may lack the directory
torch.save(trainer.model.state_dict(), MODEL_PATH)
# To restore: trainer.model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))  # weights_only needs torch>=1.13