#### Copyright (c) 2025 Graphcore Ltd. All rights reserved.

## GTSQA

Post-processing used to produce the final version of GTSQA, from the data generated with `synth_kgqa`.

In [1]:
import json
import os.path as osp
import pickle
import re

import numpy as np
import pandas as pd
from tqdm import tqdm

## Load data

ogbl-wikikg2 (preprocessed as in `notebooks/preprocess_wikikg2.ipynb`)

In [3]:
# Directory that holds the preprocessed ogbl-wikikg2 arrays.
wikikg2_path = "../data/ogbl_wikikg2/"

# Load the four lookup arrays in one pass: identifiers first, then labels.
# They are later indexed with the integer node / relation ids found in the
# datasets to build human-readable "label (id)" strings.
# NOTE: allow_pickle=True lets NumPy unpickle object arrays; only load files
# from a trusted source.
node_qids, relation_pids, node_labels, relation_labels = (
    np.load(wikikg2_path + array_file, allow_pickle=True)
    for array_file in (
        "node_ids.npy",
        "relation_ids.npy",
        "node_labels.npy",
        "relation_labels.npy",
    )
)

Train and test datasets, in the structure of the outputs of `synth-kgqa/process_qa.py`

In [None]:
# Each file is parsed with json.load as ONE JSON document (despite the
# ".jsonl" extension). Context managers make sure the file handles are closed
# as soon as parsing is done.
with open("../data/train.jsonl", "rb") as train_file:
    train_ds = json.load(train_file)

with open("../data/test.jsonl", "rb") as test_file:
    test_ds = json.load(test_file)

# Displayed as the cell output: number of samples per split.
len(train_ds), len(test_ds)

(30477, 1622)

## Process datasets

In [None]:
def process_dataset(ds, split):
    """Convert raw QA samples into the cleaned records published in GTSQA.

    Relies on the module-level lookup arrays ``node_qids``, ``node_labels``,
    ``relation_pids`` and ``relation_labels`` (indexed by integer node /
    relation id).

    Args:
        ds: iterable of sample dicts. Each sample must provide the keys read
            below (``id``, ``question``, ``seed_nodes_id``,
            ``minimal_query_per_seed``, ``answer_subgraph``, ...).
        split: ``"train"`` keeps ``paraphrased_question`` (any other value
            stores ``None``); ``"test"`` keeps the sample's ``test_type``
            (any other value stores ``["training"]``).

    Returns:
        A list with one cleaned dict per input sample, in input order.
    """
    ds_processed = []
    for dp in tqdm(ds):
        # De-duplicate the seed entities while keeping their original order, so
        # the "-"-joined keys built below are reproducible across runs (the
        # iteration order of a set of strings depends on the hash seed).
        seeds = list(dict.fromkeys(node_qids[i] for i in dp["seed_nodes_id"]))
        # A seed counts as used by a query only if "wd:<id>" is NOT followed by
        # another digit; a plain substring test would let "wd:Q5" match inside
        # "wd:Q51".
        seed_patterns = [
            (seed, re.compile("wd:" + re.escape(str(seed)) + r"(?![0-9])"))
            for seed in seeds
        ]

        # Keep, among all candidate queries, those that mention the fewest
        # seed entities; map each such seed combination to its query (the last
        # query seen wins when several share the same combination).
        minimal_s_and_q = dict()
        minimal_len = None
        for queries in dp["minimal_query_per_seed"].values():
            for q in queries:
                q_seeds = tuple(
                    seed for seed, pattern in seed_patterns if pattern.search(q)
                )
                if minimal_len is None or len(q_seeds) < minimal_len:
                    # Strictly better candidate: discard everything kept so far.
                    minimal_len = len(q_seeds)
                    minimal_s_and_q = dict()
                if len(q_seeds) == minimal_len:
                    minimal_s_and_q[q_seeds] = q

        # Render every (head, relation, tail) id triple as "label (id)" strings.
        ans_subgraph = [
            [
                node_labels[h] + f" ({node_qids[h]})",
                relation_labels[r] + f" ({relation_pids[r]})",
                node_labels[t] + f" ({node_qids[t]})",
            ]
            for h, r, t in dp["answer_subgraph"]
        ]

        dp_clean = {
            "id": dp["id"],
            "question": dp["question"],
            "paraphrased_question": (
                dp["paraphrased_question"] if split == "train" else None
            ),
            "seed_entities": dp["seed_nodes"],
            "answer_node": dp["answer_node"],
            "answer_subgraph": ans_subgraph,
            "sparql_query": dp["sparql_query"],
            "all_answers_wikidata": dp["all_answers"],
            "full_answer_subgraph_wikidata": dp["full_subgraph"],
            "all_answers_wikikg2": dp["all_answers_wikikg2"],
            "full_answer_subgraph_wikikg2": dp["full_subgraph_wikikg2"],
            "n_hops": dp["n_hops"],
            "graph_isomorphism": dp["graph_template"],
            "redundant": dp["redundant"],
            "minimal_graph_isomorphism": dp["minimal_graph_templates"],
            # Tuple keys are flattened to "Q1-Q2" style strings.
            "minimal_seeds_and_queries": {
                "-".join(k): v for k, v in minimal_s_and_q.items()
            },
            "test_type": dp["test_type"] if split == "test" else ["training"],
        }

        ds_processed.append(dp_clean)
    return ds_processed

### Train set
Processed in chunks

In [None]:
# Clean the training split and save the version without graphs as one file.
train_clean = process_dataset(train_ds, split="train")
train_df = pd.DataFrame(train_clean)
# The dict-valued column is stored as its string representation.
train_df["minimal_seeds_and_queries"] = train_df["minimal_seeds_and_queries"].astype(
    str
)
train_df.to_parquet("../HF_GTSQA/gtsqa/train.parquet")


# The version with graphs is written as N_BLOCKS parquet shards; each shard
# gathers the samples of `block_size` consecutive score chunk files.
N_BLOCKS = 22
# 88 is presumably the number of chunked score files -- TODO confirm.
block_size = int(np.ceil(88 / N_BLOCKS))

k = 0  # position in train_clean of the next sample that receives its graph
block_start = 0  # position in train_clean of the first sample of the shard
for b in range(N_BLOCKS):
    block_tot = 0  # number of samples gathered for the current shard
    for i in range(block_size):
        file_name = f"../data/train.jsonl_scores_{b*block_size + i}.pkl"  # outputs of synth-kgqa/compute_neighs_and_sp.py (chunked)
        if not osp.exists(file_name):
            # Missing chunk files are tolerated and simply skipped.
            continue
        print(f"Processing {file_name}")
        # NOTE(security): pickle.load can execute arbitrary code; only use it
        # on files produced by the trusted pipeline. The context manager closes
        # the file handle once the chunk is loaded.
        with open(file_name, "rb") as scores_file:
            partial_res = list(pickle.load(scores_file))
        for sample_scores in tqdm(partial_res):
            # Scores must come in exactly the same order as the cleaned samples.
            assert (
                train_clean[k]["id"] == sample_scores["id"]
            ), f"id mismatch at position {k} while reading {file_name}"
            # Render every (head, relation, tail) id triple as "label (id)".
            text_list = [
                [
                    node_labels[h] + f" ({node_qids[h]})",
                    relation_labels[r] + f" ({relation_pids[r]})",
                    node_labels[t] + f" ({node_qids[t]})",
                ]
                for h, r, t in zip(
                    sample_scores["h_id_list"],
                    sample_scores["r_id_list"],
                    sample_scores["t_id_list"],
                )
            ]
            train_clean[k]["graph"] = text_list
            k += 1
        block_tot += len(partial_res)
    block_clean = train_clean[block_start : block_start + block_tot]
    save_path = (
        f"../HF_GTSQA/gtsqa-with-graphs/train-{b:05d}-of-{N_BLOCKS:05d}.parquet"
    )
    print(f"Saving span {block_start}--{block_start+block_tot} to {save_path}")
    block_df = pd.DataFrame(block_clean)
    block_df["minimal_seeds_and_queries"] = block_df[
        "minimal_seeds_and_queries"
    ].astype(str)
    block_df.to_parquet(save_path)
    block_start = block_start + block_tot

# Samples beyond position k never received a graph and are absent from every
# shard; make that visible instead of letting it pass silently.
if k != len(train_clean):
    print(
        f"WARNING: only {k} of {len(train_clean)} training samples were matched "
        "with a graph and saved to the shards"
    )

100%|██████████| 30477/30477 [00:01<00:00, 15704.42it/s]


### Test set

In [None]:
test_clean = process_dataset(test_ds, split="test")

# NOTE(security): pickle.load can execute arbitrary code; only use it on files
# produced by the trusted pipeline. The context manager closes the file handle.
with open("../data/test_scores.pkl", "rb") as scores_file:
    test_ds_scores = pickle.load(
        scores_file
    )  # output of synth-kgqa/compute_neighs_and_sp.py

# Version without graphs: the frame is built BEFORE the "graph" key is added to
# the samples below, so this file carries no graph column.
test_df = pd.DataFrame(test_clean)
# Stringify the test frame's own column (it was previously taken from train_df,
# which wrote training-set values into the test file).
test_df["minimal_seeds_and_queries"] = test_df["minimal_seeds_and_queries"].astype(str)
test_df.to_parquet("../HF_GTSQA/gtsqa/test.parquet")

for sample, sample_scores in tqdm(
    zip(test_clean, test_ds_scores), total=len(test_clean)
):
    # Scores must come in exactly the same order as the cleaned samples.
    assert sample["id"] == sample_scores["id"]
    # Render every (head, relation, tail) id triple as "label (id)" strings.
    text_list = [
        [
            node_labels[h] + f" ({node_qids[h]})",
            relation_labels[r] + f" ({relation_pids[r]})",
            node_labels[t] + f" ({node_qids[t]})",
        ]
        for h, r, t in zip(
            sample_scores["h_id_list"],
            sample_scores["r_id_list"],
            sample_scores["t_id_list"],
        )
    ]
    sample["graph"] = text_list

# Version with graphs: rebuild the frame now that every sample has its graph.
test_df = pd.DataFrame(test_clean)
test_df["minimal_seeds_and_queries"] = test_df["minimal_seeds_and_queries"].astype(str)
test_df.to_parquet("../HF_GTSQA/gtsqa-with-graphs/test-00000-of-00001.parquet")

100%|██████████| 1622/1622 [00:00<00:00, 66666.94it/s]
