In [None]:
import random
from pathlib import Path

import numpy as np
import torch
from joblib import Parallel, delayed

from src.dataloaders import PairwiseDataset
from src.metrics import hitratio, ndcg
from src.models import MatrixFactorizationBPRModel
from src.trainer import Trainer

# Fix the seed of every random number generator the notebook touches
# (PyTorch, the stdlib `random` module and NumPy's legacy global generator).
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
# (Assign "cpu" here by hand to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"{device=}")

In [None]:
class config:
    """Static settings for this notebook, used as a plain namespace (never instantiated here)."""

    # Directory handed to PairwiseDataset.
    data_dir = "ml-100k"

    # Third positional argument of MatrixFactorizationBPRModel
    # (presumably the embedding dimension -- confirm against the model).
    dim = 40

    # Forwarded to Trainer as `epochs` and `batch_size`.
    epochs = 50
    batch_size = 2048


# Build the dataset object from the configured data directory.
dataset = PairwiseDataset(config.data_dir)
# NOTE(review): the two calls below are project helpers whose internals are not
# visible here; presumably they build the user-item adjacency and then create
# the train/test split -- confirm against PairwiseDataset. Keep this call order.
dataset.gen_adjacency()
dataset.make_train_test()
# Sanity check: show the sizes of the resulting splits.
print(f"{dataset.train_size=}, {dataset.test_size=}")

# Metrics handed to Trainer: name -> (metric function, options dict).
# Only the top-10 cutoffs are enabled; HR@1, HR@5, NDCG@1 and NDCG@5 can be
# added the same way with {"top_n": 1} or {"top_n": 5}.
# The "NDCG@10" key is what the grid search reads back from the test log.
metrics = dict(
    [
        ("HR@10", (hitratio, dict(top_n=10))),
        ("NDCG@10", (ndcg, dict(top_n=10))),
    ]
)

In [None]:
# NOTE: the full grid search in this cell took about 1 hour 15 minutes to complete.

# Hyperparameter grid for the SGD optimizer: every combination of these values
# is evaluated (4 x 10 x 5 = 200 runs).
grid_params = {
    "lr": [0.001, 0.01, 0.1, 1],
    # 0.1, 0.2, ..., 1.0 as exact plain Python floats. Built from integer steps
    # instead of np.arange with a float step, which produces np.float64 values
    # carrying rounding noise (e.g. 0.30000000000000004) that would end up in
    # the optimizer settings and in the written log, and whose endpoint
    # handling depends on floating-point arithmetic.
    "momentum": [step / 10 for step in range(1, 11)],
    "weight_decay": [0.0001, 0.001, 0.01, 0.1, 1],
}


def search(lr, mom, wd, seed=42):
    """Train one MF-BPR model with the given SGD settings and report its best NDCG@10.

    Uses the notebook-level ``dataset``, ``config``, ``metrics`` and ``device``.

    Args:
        lr: Learning rate for ``torch.optim.SGD``.
        mom: Momentum for ``torch.optim.SGD`` (Nesterov momentum is enabled).
        wd: Weight decay for ``torch.optim.SGD``.
        seed: Seed applied to NumPy, ``random`` and PyTorch before the model is
            built. Defaults to 42, the value used at the top of the notebook.

    Returns:
        A tuple ``(best_ndcg, best_ndcg_epoch, (lr, mom, wd))``. ``best_ndcg``
        is the largest ``"NDCG@10"`` value found in ``trainer.test_log`` and
        ``best_ndcg_epoch`` is the index of its first occurrence in that log
        (presumably one entry per epoch -- confirm against Trainer). If no
        entry compares greater than ``-inf`` (e.g. an empty log), the result
        is ``(float("-inf"), 0, (lr, mom, wd))``.
    """
    # The seeding done at notebook level only affects the notebook process.
    # When this function runs inside a joblib worker process, that code never
    # executed there, and workers are reused across tasks, so RNG state would
    # otherwise depend on scheduling. Re-seed per call to make every
    # configuration reproducible on its own.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    model = MatrixFactorizationBPRModel(
        dataset.user_count, dataset.item_count, config.dim
    )

    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=mom, nesterov=True, weight_decay=wd
    )

    trainer = Trainer(
        dataset,
        model,
        optimizer,
        metrics,
        epochs=config.epochs,
        batch_size=config.batch_size,
        device=device,
    )

    # Run silently: many of these execute concurrently during the grid search.
    trainer.train(evaluate=True, verbose=False, progressbar=False)

    # Strict ">" keeps the earliest entry on ties and skips NaN scores
    # (comparisons with NaN are always False).
    best_ndcg, best_ndcg_epoch = float("-inf"), 0
    for epoch_idx, scores in enumerate(trainer.test_log):
        if scores["NDCG@10"] > best_ndcg:
            best_ndcg = scores["NDCG@10"]
            best_ndcg_epoch = epoch_idx

    return (best_ndcg, best_ndcg_epoch, (lr, mom, wd))


# One deferred `search` call per (lr, momentum, weight_decay) combination,
# generated lazily with lr varying slowest and weight_decay fastest.
grid_tasks = (
    delayed(search)(lr, mom, wd)
    for lr in grid_params["lr"]
    for mom in grid_params["momentum"]
    for wd in grid_params["weight_decay"]
)

# Run the tasks on 4 joblib workers and collect the result tuples.
output = Parallel(n_jobs=4)(grid_tasks)

In [None]:
# Rank the runs best-first. Each entry is
# (best NDCG@10, index of that score in the test log, (lr, momentum, weight_decay)),
# so plain tuple ordering sorts primarily by the metric.
sorted_log = sorted(output, reverse=True)

# Make sure the target directory exists before opening the file: without this
# a missing "gridsearch_logs" folder raises FileNotFoundError and the results
# of the (long) grid search are never written.
log_path = Path("gridsearch_logs/mfbpr.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)

# One result tuple per line, best configuration first.
with log_path.open("w") as f:
    for r in sorted_log:
        f.write(f"{r}\n")