In [3]:
import os

from ACL2024.modules.dataset.user_graph_dataset import UserGraphDataset
from ACL2024.modules.dataset.compile_dataset import UserGraphDatasetInMemory
from ACL2024.modules.dataset.compile_dataset import fetch_question_ids
from ACL2024.modules.util.get_root_dir import get_project_root

In [8]:
# Filesystem locations used by the rest of the notebook, each resolved against
# the directory returned by get_project_root():
#   ROOT       - dataset directory (its 'processed' subfolder holds the graph files)
#   MODEL      - trained weights loaded into the model below
#   DB_ADDRESS - SQLite database opened with sqlite3 below
ROOT, MODEL, DB_ADDRESS = (
    os.path.join(get_project_root(), *parts)
    for parts in (
        ('data_tmp',),
        ('modules', 'models', 'out', 'ALPHA.pt'),
        ('data', 'raw', 'g4so.db'),
    )
)

In [9]:
import sqlite3

# Fail fast on a wrong path: sqlite3.connect() creates a brand-new empty database
# file when the path does not exist, and the mistake would only surface later as
# "no such table" errors in the query helpers.
if not os.path.isfile(DB_ADDRESS):
    raise FileNotFoundError(f'SQLite database not found: {DB_ADDRESS}')
db = sqlite3.connect(DB_ADDRESS)

In [10]:
from ACL2024.modules.models.GNNs.hetero_GAT import HeteroGAT
import torch
# Load model (EXP 1). The constructor arguments presumably have to match the
# configuration the checkpoint was trained with - confirm against the training run.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HeteroGAT(out_channels=2, num_layers=3, hidden_channels=64, dropout=0.3, vertex_types=['question', 'answer', 'comment', 'tag', 'module'], device=DEVICE)
# strict=False tolerates key mismatches between checkpoint and model; report them
# instead of silently evaluating layers that kept their random initialisation.
incompatible_keys = model.load_state_dict(torch.load(MODEL, map_location=torch.device('cpu')), strict=False)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
    print(f'WARNING: checkpoint/model key mismatch - missing: {incompatible_keys.missing_keys}, '
          f'unexpected: {incompatible_keys.unexpected_keys}')
model.eval()

2023-09-03 20:48:48 INFO     MODEL: GAT


HeteroGAT(
  (convs): ModuleList(
    (0): HeteroConv(num_relations=10)
    (1): HeteroConv(num_relations=10)
    (2): HeteroConv(num_relations=10)
  )
  (lin1): Linear(3392, 64, bias=True)
  (lin2): Linear(64, 2, bias=True)
  (softmax): Softmax(dim=-1)
)

In [11]:
# 80/20 train/test split in the iteration order of `question_ids`.
# NOTE(review): this has to reproduce the split used when the model was trained,
# otherwise "test" questions may have been seen during training - confirm.
TRAIN_FRACTION = 0.8
question_ids = fetch_question_ids(ROOT)
ordered_ids = list(question_ids)
train_ids = ordered_ids[:int(len(ordered_ids) * TRAIN_FRACTION)]
# Set membership keeps this O(n); testing against the train_ids list was O(n^2).
train_id_set = set(train_ids)
test_ids = [x for x in ordered_ids if x not in train_id_set]
len(train_ids)

2156

In [12]:
import torch
import re
import os


def fetch_answer_files(question_id: int, root=None):
    """Yield the saved graph files that belong to one question.

    Scans ``<root>/processed`` and ``torch.load``s every file whose name
    contains ``id_<digits>`` with those digits (first match in the name)
    equal to ``question_id``.

    Files are visited in sorted filename order so that the order of the
    yielded objects - and any answer index derived from it downstream - is
    reproducible; ``os.listdir`` alone returns entries in arbitrary,
    filesystem-dependent order.

    Args:
        question_id: Question id to match against the ``id_<digits>`` part
            of each filename.
        root: Dataset directory containing ``processed/``. Defaults to the
            notebook-level ``ROOT``.

    Yields:
        The object stored in each matching file (presumably one answer
        graph per file - TODO confirm).
    """
    processed_dir = os.path.join(ROOT if root is None else root, 'processed')
    id_pattern = re.compile(r"id_(\d+)")  # compiled once, not per file
    for file_name in sorted(os.listdir(processed_dir)):
        match = id_pattern.search(file_name)
        if match and int(match.group(1)) == question_id:
            # NOTE(review): torch.load unpickles arbitrary objects; only use on trusted data.
            yield torch.load(os.path.join(processed_dir, file_name))


In [19]:
from ACL2024.modules.util.db_query import fetch_answers_for_question
from torch_geometric.loader import DataLoader
from sklearn.metrics import ndcg_score


# Precision@1 on held-out questions: score every candidate answer graph of a
# question with the model and check whether the top-scored answer is the one
# flagged as accepted in its graph file.
N_EVAL_QUESTIONS = 200  # evaluate only the first N test questions

CORRECT = 0  # questions whose top-scored answer is the flagged accepted answer
COUNT = 0    # questions that were actually evaluated (see filter below)

with torch.no_grad():  # inference only: no autograd graph needed
    for question_id in test_ids[:N_EVAL_QUESTIONS]:
        results = []            # (answer_idx, model score) per candidate answer
        accepted_answer = None  # answer_idx of the graph flagged as accepted, if any

        graph_files = list(fetch_answer_files(question_id))
        dataloader = DataLoader(graph_files, batch_size=1, shuffle=False)
        print(f'Question ID: {question_id}')
        # A separate name for the answer index keeps the question id from being
        # shadowed inside the loop.
        for answer_idx, data in enumerate(dataloader):
            if data.accepted:
                accepted_answer = answer_idx  # if several are flagged, the last one wins

            # For each answer, predict whether it is the accepted answer.
            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1)
            out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb, None, None, None)
            # Column 1 of the 2-way output - presumably the "accepted" class; confirm.
            results.append((answer_idx, out[0][1].item()))

        # Questions with a single candidate or no flagged accepted answer carry no signal.
        if len(results) > 1 and accepted_answer is not None:
            results.sort(key=lambda x: x[1], reverse=True)  # best-scored answer first
            pred_index, _ = results[0]
            if pred_index == accepted_answer:
                CORRECT += 1

            print(results, accepted_answer)
            COUNT += 1

# Precision@1 is undefined when no question passed the filter above.
if COUNT:
    print(f'Precision@1: {CORRECT / COUNT}')
else:
    print('Precision@1: undefined (no evaluable questions)')

# TODO: The `accepted` indicator stored in the answer graphs is wrong; use the database to determine which answer IDs are accepted and which are not.


Question ID: 285061
[(4, 0.9979262351989746), (3, 0.9922215342521667), (0, 0.9731692671775818), (1, 0.9618261456489563), (2, 0.21306012570858002)] 0
Question ID: 473498
[(2, 0.9921720027923584), (1, 0.9882916808128357), (3, 0.18909239768981934), (0, 0.024139879271388054), (4, 0.0017281874315813184)] 1
Question ID: 620954
[(2, 0.9984024167060852), (3, 0.997868537902832), (1, 0.951692521572113), (0, 0.2257421463727951)] 1
Question ID: 481692
[(1, 0.9989075660705566), (0, 0.9987720847129822), (3, 0.9956104159355164), (7, 0.993486225605011), (10, 0.9881060123443604), (6, 0.9614726901054382), (11, 0.9610804319381714), (5, 0.912511944770813), (4, 0.8150637149810791), (12, 0.6079466938972473), (14, 0.4361797869205475), (9, 0.30980515480041504), (8, 0.2565757930278778), (13, 0.1840936243534088), (2, 0.09556128084659576)] 0
Question ID: 219547
[(2, 0.9688692092895508), (3, 0.9649320840835571), (1, 0.958297848701477), (0, 0.5794697403907776), (4, 0.0038595835212618113)] 1
Question ID: 342434
[(7

In [9]:
"""
DUMPING GROUND
"""
results = []
for answer, graph in zip(fetch_answers_for_question(i, db).itertuples(), dataloader):

    post_emb = torch.cat([graph.question_emb, graph.answer_emb], dim=1)

    out = model(graph.x_dict, graph.edge_index_dict, graph.batch_dict, post_emb)  # Perform a single forward pass.
    results.append((out[0][1].item(), answer.Score))
if len(results) > 1:
    # calculate ndcg
    results.sort(key=lambda x: x[0], reverse=True)
    NDCGS.append(ndcg_score([[x[1] for x in results]], [[x[0] for x in results]], k=5))

    # calculate precision@1
    results.sort(key=lambda x: x[0], reverse=True)
    if results[0][1] >= max([x[1] for x in results[1:]]):
        #print('Correct')
        CORRECT += 1
    else:
        #print('Incorrect')
        pass
    COUNT += 1

'\nDUMPING GROUND\n'