In [1]:
import torch
import sys
import json
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

from src.dataset import ECommerceDS
from src.metrics import compute_recall_at_k, compute_mrr
from src.metric_handler import MetricHandler

sys.path.append("src/")
from models.sr_gnn import SRGNN

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
product2token_fp = "data/product2token.json"
weight_fp = "results/srgnn_baseline/best_model.pth"
device = "mps"
test_ds_fp = "data/splits/test.jsonl"

In [5]:
with open(product2token_fp, mode="r") as f:
    product2token = json.load(f)

In [7]:
model = SRGNN(hidden_size=128, n_node=len(product2token), num_layers=2)
model.load_state_dict(torch.load(weight_fp))
model.to(device)
model.eval()
print("Loaded model weights")

Loaded model weights


  model.load_state_dict(torch.load(weight_fp))


In [8]:
test_ds = ECommerceDS(test_ds_fp, max_len=50, product2token=product2token, mask="last")
test_dl = DataLoader(test_ds, batch_size=256, shuffle=False)

In [15]:
metric_handler = MetricHandler("results/srgnn_baseline")

all_logits = list()
all_labels = list()
all_times = list()

with torch.no_grad():  # Disable gradient computation for evaluation
    for test_batch in tqdm(test_dl):
        x_val = test_batch["graph"].to(device)
        labels_val = test_batch["products"].to(device)
        cloze_mask_val = test_batch["cloze_mask"].to(device)

        logits_flat_val = model(x_val)
        labels_flat_val = labels_val[:, -1]
        cloze_mask_flat_val = cloze_mask_val[:, -1]

        # Select only the masked positions
        valid_indices_val = cloze_mask_flat_val == 1
        logits_masked_val = logits_flat_val[valid_indices_val]
        labels_masked_val = labels_flat_val[valid_indices_val]

        # Calculate validation loss and metrics
        test_loss = F.cross_entropy(logits_masked_val, labels_masked_val)
        batch_metrics = {
            "test_loss": test_loss.item(),
            "test_recall@1": compute_recall_at_k(logits_masked_val, labels_masked_val, k=1),
            "test_recall@5": compute_recall_at_k(logits_masked_val, labels_masked_val, k=5),
            "test_recall@10": compute_recall_at_k(logits_masked_val, labels_masked_val, k=10),
            "test_recall@20": compute_recall_at_k(logits_masked_val, labels_masked_val, k=20),
            "test_mrr": compute_mrr(logits_masked_val, labels_masked_val),
        }
        metric_handler.batch_update(batch_metrics)

metric_handler.all_update_save_clear(save_name="test_results.csv")

  0%|          | 0/700 [00:00<?, ?it/s]

7
[('2020-01-05 03:49:57 UTC', '2020-02-03 15:52:45 UTC', '2020-03-02 15:48:51 UTC', '2019-10-11 04:22:31 UTC', '2020-02-29 10:32:58 UTC', '2020-04-23 09:28:15 UTC', '2019-10-19 03:52:41 UTC', '2020-02-04 09:18:15 UTC', '2019-11-03 06:08:02 UTC', '2020-01-24 23:13:16 UTC', '2019-10-03 06:09:01 UTC', '2019-10-04 16:57:38 UTC', '2019-10-08 14:51:19 UTC', '2020-04-18 07:31:53 UTC', '2019-10-01 12:23:00 UTC', '2019-11-17 12:39:39 UTC', '2019-10-15 11:42:57 UTC', '2019-10-13 19:25:39 UTC', '2019-10-13 18:07:28 UTC', '2020-01-07 04:21:17 UTC', '2020-02-08 18:55:13 UTC', '2019-10-08 18:29:09 UTC', '2020-02-14 11:03:02 UTC', '2019-11-04 09:52:34 UTC', '2019-12-15 22:34:32 UTC', '2019-12-12 11:38:57 UTC', '2020-03-19 17:56:51 UTC', '2019-10-04 13:40:37 UTC', '2019-12-16 10:01:31 UTC', '2020-04-09 21:50:50 UTC', '2020-02-15 17:29:40 UTC', '2019-12-01 05:12:38 UTC', '2019-10-21 09:12:27 UTC', '2019-10-04 10:52:39 UTC', '2019-10-20 05:17:24 UTC', '2019-12-19 11:59:33 UTC', '2019-12-08 09:17:30 UTC




### Next steps
- performance vs time variation
- by category

In [None]:
test_ds[0]

{'graph': Data(x=[4, 1], edge_index=[2, 12], edge_weights=[12]),
 'masked_products': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 2, 4, 5,
         5, 1]),
 'products': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 2, 4, 5,
         5, 6]),
 'attention_mask': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1, 1]),
 'cloze_mask': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1])}