In [51]:
import os

from dataset import UserGraphDataset
from dataset_in_memory import UserGraphDatasetInMemory
from dataset_in_memory import fetch_question_ids

In [52]:
ROOT = '../data/'
MODEL = '../models/SAGE_3l_60e_64h.pt'

In [53]:
from hetero_GAT import HeteroGAT
import torch
# Load model (EXP 1)
model = HeteroGAT(out_channels=2, num_layers=2, hidden_channels=64)
model.load_state_dict(torch.load(MODEL, map_location=torch.device('cpu')), strict=False)
model.eval()

HeteroGAT(
  (convs): ModuleList(
    (0): HeteroConv(num_relations=10)
    (1): HeteroConv(num_relations=10)
  )
  (lin1): Linear(3392, 64, bias=True)
  (lin2): Linear(64, 2, bias=True)
  (softmax): Softmax(dim=-1)
)

In [54]:
question_ids = fetch_question_ids(ROOT)
train_ids = list(question_ids)[:int(len(question_ids) * 0.8)]
test_ids = [x for x in question_ids if x not in train_ids]

In [55]:
# Instantiate the dataset class for access to database
ugd = UserGraphDataset(ROOT, db_address='../stackoverflow.db', skip_processing=True)

2023-04-16 20:08:39 INFO     PostEmbedding instantiated!


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [56]:
import torch
import re
import os


def fetch_answer_files(question_id: int):
    for f in os.listdir(os.path.join(ROOT, 'processed')):
        question_id_search = re.search(r"id_(\d+)", f)
        if question_id_search:
            if int(question_id_search.group(1)) == question_id:
                yield torch.load(os.path.join(ROOT, 'processed', f))


In [57]:
from torch_geometric.loader import DataLoader
from sklearn.metrics import ndcg_score

for _ in range(1):
    NDCGS = []
    CORRECT = 0
    COUNT = 0
    for i in test_ids[:600]:
        graph_files = list(fetch_answer_files(i))
        dataloader = DataLoader(graph_files, batch_size=1, shuffle=False)

        results = []
        for answer, graph in zip(ugd.fetch_answers_for_question(i).itertuples(), dataloader):

            post_emb = torch.cat([graph.question_emb, graph.answer_emb], dim=1)

            out = model(graph.x_dict, graph.edge_index_dict, graph.batch_dict, post_emb)  # Perform a single forward pass.
            results.append((out[0][1].item(), answer.Score))
        if len(results) > 1:
            # calculate ndcg
            results.sort(key=lambda x: x[0], reverse=True)
            NDCGS.append(ndcg_score([[x[1] for x in results]], [[x[0] for x in results]], k=5))

            # calculate precision@1
            results.sort(key=lambda x: x[0], reverse=True)
            if results[0][1] >= max([x[1] for x in results[1:]]):
                #print('Correct')
                CORRECT += 1
            else:
                #print('Incorrect')
                pass
            COUNT += 1
    print(f'Precision@1: {CORRECT / COUNT}')



Precision@1: 0.36915077989601386


In [58]:
sum(NDCGS)/ len(NDCGS)

0.7361403483124718

In [59]:
CORRECT

213

In [60]:
CORRECT/ COUNT

0.36915077989601386