In [9]:
import torch
import sys
import json
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

from src.dataset import ECommerceDS
from src.metrics import compute_recall_at_k, compute_mrr
from src.metric_handler import MetricHandler

sys.path.append("src/")
from models.sr_gnn import SRGNN

In [10]:
product2token_fp = "data/product2token.json"
weight_fp = "results/srgnn_baseline/best_model.pth"
device = "mps"
test_ds_fp = "data/splits/test.jsonl"

In [11]:
with open(product2token_fp, mode="r") as f:
    product2token = json.load(f)

In [12]:
model = SRGNN(hidden_size=128, n_node=len(product2token), num_layers=2)
model.load_state_dict(torch.load(weight_fp))
model.to(device)
model.eval()
print("Loaded model weights")

Loaded model weights


  model.load_state_dict(torch.load(weight_fp))


In [13]:
test_ds = ECommerceDS(test_ds_fp, max_len=50, product2token=product2token, mask="last")
test_dl = DataLoader(test_ds, batch_size=256, shuffle=False)

In [14]:
metric_handler = MetricHandler("results/srgnn_baseline")

all_logits = list()
all_labels = list()
all_times = list()

with torch.no_grad():  # Disable gradient computation for evaluation
    for test_batch in tqdm(test_dl):
        x_val = test_batch["graph"].to(device)
        labels_val = test_batch["products"].to(device)
        cloze_mask_val = test_batch["cloze_mask"].to(device)

        logits_flat_val = model(x_val)
        labels_flat_val = labels_val[:, -1]
        cloze_mask_flat_val = cloze_mask_val[:, -1]

        # Select only the masked positions
        valid_indices_val = cloze_mask_flat_val == 1
        logits_masked_val = logits_flat_val[valid_indices_val]
        labels_masked_val = labels_flat_val[valid_indices_val]

        # Calculate validation loss and metrics
        test_loss = F.cross_entropy(logits_masked_val, labels_masked_val)
        batch_metrics = {
            "test_loss": test_loss.item(),
            "test_recall@1": compute_recall_at_k(logits_masked_val, labels_masked_val, k=1),
            "test_recall@5": compute_recall_at_k(logits_masked_val, labels_masked_val, k=5),
            "test_recall@10": compute_recall_at_k(logits_masked_val, labels_masked_val, k=10),
            "test_recall@20": compute_recall_at_k(logits_masked_val, labels_masked_val, k=20),
            "test_mrr": compute_mrr(logits_masked_val, labels_masked_val),
        }
        metric_handler.batch_update(batch_metrics)

metric_handler.all_update_save_clear(save_name="test_results.csv")

  logits_masked_val = logits_flat_val[valid_indices_val]
100%|██████████| 700/700 [02:59<00:00,  3.91it/s]


### Next steps
- performance vs time variation
- by category

In [15]:
test_ds[0]

{'graph': Data(x=[4, 1], edge_index=[2, 12], edge_weights=[12]),
 'masked_products': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 2, 4, 5,
         5, 1]),
 'products': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 2, 4, 5,
         5, 6]),
 'times': ['2020-01-05 03:49:57 UTC',
  '2020-01-21 06:20:35 UTC',
  '2020-01-21 06:24:57 UTC',
  '2020-01-23 04:00:00 UTC',
  '2020-01-28 08:08:00 UTC',
  '2020-01-31 09:46:24 UTC',
  '2020-02-03 04:27:15 UTC',
  '2020-02-05 09:12:57 UTC'],
 'attention_mask': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
         1, 1]),
 'cloze_mask': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 