In [2]:
import os

from ACL2024.modules.dataset.user_graph_dataset import UserGraphDataset
from ACL2024.modules.dataset.compile_dataset import UserGraphDatasetInMemory
from ACL2024.modules.dataset.compile_dataset import fetch_question_ids
from ACL2024.modules.util.get_root_dir import get_project_root

In [3]:
# Resolve the project root once so every path below is derived from the same base
# and the shared 'data' segment cannot drift between constants.
PROJECT_ROOT = get_project_root()

# Data directory; its 'processed' sub-folder holds the graph files loaded later on.
ROOT = os.path.join(PROJECT_ROOT, 'data')
# Checkpoint whose state dict is restored into the model below.
MODEL = os.path.join(PROJECT_ROOT, 'modules', 'models', 'out', 'MODEL_OUT.pt')
# SQLite database file, located under the data directory.
DB_ADDRESS = os.path.join(ROOT, 'raw', 'g4so.db')

In [4]:
import sqlite3

# sqlite3.connect() creates a new, empty database when the file does not exist,
# which would only surface later as failing queries. Fail fast with a clear error instead.
if not os.path.isfile(DB_ADDRESS):
    raise FileNotFoundError(f'SQLite database not found: {DB_ADDRESS}')

db = sqlite3.connect(DB_ADDRESS)

In [5]:
from ACL2024.modules.models.GNNs.hetero_GAT import HeteroGAT
import torch
# Load model (EXP 1)
# Build the architecture with the hyper-parameters of the experiment, then restore
# the trained weights from the checkpoint at MODEL.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HeteroGAT(
    out_channels=2,
    num_layers=2,
    hidden_channels=64,
    dropout=0.3,
    vertex_types=['question', 'answer', 'comment', 'tag', 'module'],
    device=device,
)

# The checkpoint is deserialised onto the CPU regardless of where it was saved.
state_dict = torch.load(MODEL, map_location=torch.device('cpu'))

# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of ignoring them, otherwise a wrong checkpoint would leave layers at
# their random initialisation without any visible sign.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f'WARNING: checkpoint does not fully match the model; '
          f'missing keys: {load_report.missing_keys}, '
          f'unexpected keys: {load_report.unexpected_keys}')

# Switch to inference mode (disables dropout); as the last expression this also
# displays the module summary as the cell output.
model.eval()

2023-08-27 20:31:32 INFO     MODEL: GAT


HeteroGAT(
  (convs): ModuleList(
    (0): HeteroConv(num_relations=10)
    (1): HeteroConv(num_relations=10)
  )
  (lin1): Linear(4992, 64, bias=True)
  (lin2): Linear(64, 2, bias=True)
  (softmax): Softmax(dim=-1)
)

In [6]:
question_ids = fetch_question_ids(ROOT)

# Order-based 80/20 split: the first 80% of the IDs are used for training,
# everything that is not a training ID is held out for testing.
train_ids = list(question_ids)[:int(len(question_ids) * 0.8)]

# Membership tests against a list are linear, which made the filter below quadratic;
# a set gives the same result in linear time (assumes the IDs are hashable, e.g. ints).
train_id_set = set(train_ids)
test_ids = [x for x in question_ids if x not in train_id_set]

len(train_ids)

5956

In [7]:
import torch
import re
import os


def fetch_answer_files(question_id: int, processed_dir=None):
    """Yield the loaded contents of every processed file that belongs to a question.

    A file belongs to the question when the first ``id_<digits>`` occurrence in its
    name equals ``question_id``. Files whose name contains no such pattern are skipped.

    Args:
        question_id: ID of the question whose files should be loaded.
        processed_dir: Directory to scan. Defaults to the ``processed`` folder
            inside ``ROOT``.

    Yields:
        The object returned by ``torch.load`` for each matching file, in
        alphabetical order of the file names.
    """
    if processed_dir is None:
        processed_dir = os.path.join(ROOT, 'processed')

    # os.listdir returns entries in arbitrary order; sorting keeps runs reproducible.
    for file_name in sorted(os.listdir(processed_dir)):
        question_id_search = re.search(r"id_(\d+)", file_name)
        if question_id_search and int(question_id_search.group(1)) == question_id:
            # NOTE(review): torch.load unpickles the file, so only load trusted data.
            yield torch.load(os.path.join(processed_dir, file_name))


In [30]:
from ACL2024.modules.util.db_query import fetch_answers_for_question
from torch_geometric.loader import DataLoader
from sklearn.metrics import ndcg_score


# Precision@1 over the first 200 test questions.
# For every question the answers are ranked by the model output at index 1; the
# question counts as correct when the top-ranked answer has a score at least as
# high as every other answer of that same question.
CORRECT = 0
COUNT = 0

# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    # `i` is the question ID. The inner loop must not reuse this name, otherwise
    # the ID is overwritten for any later cell that reads it.
    for i in test_ids[:200]:
        graph_files = list(fetch_answer_files(i))
        dataloader = DataLoader(graph_files, batch_size=1, shuffle=False)
        print(f'Question ID: {i}')

        # (model output, stored score) pairs for the answers of THIS question only.
        results = []
        for data in dataloader:
            # For each answer, predict whether it is the accepted answer
            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1)

            out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
            results.append((out[0][1].item(), data.score.item()))

        # Ranking is only meaningful with at least two candidate answers.
        if len(results) > 1:
            # Rank by model output, best first.
            results.sort(key=lambda x: x[0], reverse=True)

            if results[0][1] >= max(score for _, score in results[1:]):
                CORRECT += 1

            COUNT += 1

# Precision@1
if COUNT:
    print(f'Precision@1: {CORRECT / COUNT}')
else:
    # Avoid a ZeroDivisionError when no question had more than one answer.
    print('Precision@1: undefined (no question with more than one answer)')

# TODO: The accepted-answer indicator is wrong, so look up which answer IDs are accepted (and which are not) in the database instead.


Question ID: 2319838
Question ID: 2090464
Question ID: 648675
Question ID: 1893867
Question ID: 2090479
Question ID: 2483696
Question ID: 255476
Question ID: 189943
Question ID: 1697273
Question ID: 1828345
Question ID: 321024
Question ID: 714242
Question ID: 419334
Question ID: 452104
Question ID: 976395
Question ID: 1762831
Question ID: 1271320
Question ID: 157211
Question ID: 1828379
Question ID: 222752
Question ID: 1271337
Question ID: 550446
Question ID: 1500718
Question ID: 583216
Question ID: 2319928
Question ID: 1566266
Question ID: 190010
Question ID: 878143
Question ID: 2090564
Question ID: 550474
Question ID: 910930
Question ID: 1271378
Question ID: 2090582
Question ID: 1369697
Question ID: 1205863
Question ID: 2123369
Question ID: 1631855
Question ID: 2582138
Question ID: 222877
Question ID: 91810
Question ID: 157359
Question ID: 452283
Question ID: 976577
Question ID: 386753
Question ID: 452300
Question ID: 452305
Question ID: 2483924
Question ID: 779989
Question ID: 25570

In [9]:
"""
DUMPING GROUND
"""
# Scratch evaluation of a single question: pairs each database answer with one
# graph from `dataloader`, scores it with the model and updates NDCG@5 / precision@1.
# Relies on `i`, `dataloader`, `db`, `CORRECT` and `COUNT` already existing in the kernel.
# NOTE(review): `i` is expected to be a question ID — verify it has not been
# overwritten by a loop index before trusting the numbers.
# NOTE(review): zip() assumes the database rows and the graphs come in the same
# answer order — TODO confirm.

# NDCGS is not initialised anywhere else; create it on first use and keep
# accumulating into it on later runs.
NDCGS = globals().get('NDCGS', [])

results = []
# Inference only: no gradients are needed.
with torch.no_grad():
    for answer, graph in zip(fetch_answers_for_question(i, db).itertuples(), dataloader):
        post_emb = torch.cat([graph.question_emb, graph.answer_emb], dim=1)

        out = model(graph.x_dict, graph.edge_index_dict, graph.batch_dict, post_emb)  # Perform a single forward pass.
        results.append((out[0][1].item(), answer.Score))

if len(results) > 1:
    # Rank by model output, best first; this single sort serves both metrics.
    results.sort(key=lambda x: x[0], reverse=True)

    # calculate ndcg
    NDCGS.append(ndcg_score([[x[1] for x in results]], [[x[0] for x in results]], k=5))

    # calculate precision@1
    if results[0][1] >= max(x[1] for x in results[1:]):
        CORRECT += 1
    COUNT += 1

'\nDUMPING GROUND\n'