In [None]:
import pandas as pd
from pathlib import Path
import sys
import os
from ast import literal_eval
from matplotlib import pyplot as plt
import numpy as np

# Add parent directory to sys.path
sys.path.append(os.path.abspath(".."))
import logger
from stark_qa.load_qa import load_qa

In [None]:
# Input locations -- fill these in before running.
# root_path: base of the <dataset>/<split>/<llm>/... result tree read below.
root_path = ""
# qa_path: second argument passed to stark_qa.load_qa.
qa_path = ""

# Benchmark subsets to evaluate (also used as path components).
datasets = ["prime", "mag", "amazon"]
# LLM whose pipeline outputs are analysed (also a path component).
llm = "openai/gpt-oss-120b"  # "GPT OSS 120B"
# Metric column names.
col_names_metrics = [
    "recall@20",
    "hit@20",
    "hit@5",
    "hit@1",
    "num_samples",
]

In [None]:
def merge_dicts(dicts):
    """Collect the values of several dicts into one dict of per-key lists.

    Keys appear in first-seen order. Each list holds that key's values in the
    order the dicts were given; a dict lacking a key contributes nothing to it.

    Args:
        dicts: iterable of mappings (e.g. per-query metric dicts).

    Returns:
        A new dict mapping every key seen to the list of its values.
    """
    collected = {}
    for mapping in dicts:
        for key, value in mapping.items():
            collected.setdefault(key, []).append(value)
    return collected

In [None]:
cypher_strand_file_name = "unnamed_experiment/step6_vss.csv"
vss_strand_file_name = "unnamed_experiment/step7_vss.csv"
datasplit = "test"  # or "val"

# Length of the merged answer list; alpha sweeps 0..top_k so alpha/top_k spans [0, 1].
# NOTE: the metric keys read below ("hit_20", "recall_20", ...) are fixed at @20.
top_k = 20

# dataset -> list of per-alpha averages (list index == alpha).
datarows_r20s = {}
datarows_h20s = {}


def _combine_rankings(cypher_ranked, vss_ranked, alpha, k=top_k):
    """Merge two ranked answer lists.

    Takes the first ``alpha`` Cypher-strand answers as-is, then appends
    VSS-strand answers in order, skipping ones already present, until the
    list holds ``k`` entries or the VSS answers run out. Inputs are not modified.
    """
    combined = list(cypher_ranked[:alpha])
    for candidate in vss_ranked:
        if len(combined) >= k:
            break
        if candidate not in combined:
            combined.append(candidate)
    return combined


for dataset in datasets:
    run_dir = Path(root_path) / dataset / datasplit / llm
    ds_s6 = pd.read_csv(run_dir / cypher_strand_file_name, quotechar="|")
    ds_s7 = pd.read_csv(run_dir / vss_strand_file_name, quotechar="|")

    # q_id -> serialized answer list. Keeping the first row per q_id matches a
    # boolean-mask lookup followed by .iat[0], without an O(n) scan per query.
    s6_answers_by_qid = ds_s6.drop_duplicates("q_id").set_index("q_id")["answers_vss"]
    s7_answers_by_qid = ds_s7.drop_duplicates("q_id").set_index("q_id")["answers_vss"]

    qa = load_qa(dataset, qa_path, human_generated_eval=False)
    gt = []
    cypher_answers = []
    vss_answers = []
    for row_q_id in ds_s6["q_id"]:
        _query, qid, ground_truths, _ = qa[row_q_id]
        gt.append(ground_truths)
        # literal_eval only evaluates Python literals, so it is safe on file contents.
        cypher_answers.append(literal_eval(s6_answers_by_qid.loc[qid]))
        vss_answers.append(literal_eval(s7_answers_by_qid.loc[qid]))

    # Checked after the lists are filled: index i must refer to the same query
    # in every list consumed by the alpha sweep below.
    assert len(cypher_answers) == len(vss_answers) == len(gt) == len(ds_s6)

    hit20s = []
    hit5s = []
    r20s = []
    hit1s = []
    mrrs = []
    for alpha in range(top_k + 1):
        # alpha Cypher answers first, then VSS answers fill up to top_k.
        answers = [
            _combine_rankings(cypher_ranked, vss_ranked, alpha, top_k)
            for cypher_ranked, vss_ranked in zip(cypher_answers, vss_answers)
        ]
        dicts = [logger.calculate_metrics(pred, truth) for pred, truth in zip(answers, gt)]

        merged_dicts = merge_dicts(dicts)
        hit20s.append(np.average(merged_dicts["hit_20"]))
        hit5s.append(np.average(merged_dicts["hit_5"]))
        r20s.append(np.average(merged_dicts["recall_20"]))
        hit1s.append(np.average(merged_dicts["hit_1"]))
        mrrs.append(np.average(merged_dicts["reciprocal_rank_20"]))
    datarows_r20s[dataset] = r20s
    datarows_h20s[dataset] = hit20s

In [None]:
plt.style.use("default")

# Display name, marker and colour per dataset, in the order of `datasets`.
names = ["PRIME", "MAG", "AMAZON"]
markers = ["o", "x", "*"]
colors = ["tab:blue", "tab:orange", "tab:green"]

curves = [datarows_h20s[ds] for ds in datasets]
# alpha runs 0..k, so each curve has k + 1 points (k = 20 here).
sweep_k = len(curves[0]) - 1
# Highlighted operating point: alpha = floor(2k/3), i.e. 13 for k = 20.
alpha_mark = (2 * sweep_k) // 3

fig, ax = plt.subplots()
for curve, name, marker, color in zip(curves, names, markers, colors):
    ax.plot(curve, marker=marker, color=color, label=name)

ax.axvline(x=alpha_mark, color="red", linestyle="--",
           label=rf"$\alpha = \lfloor 2k/3 \rfloor = {alpha_mark}$")

# Full-width horizontal reference line at each dataset's value at alpha_mark.
values_at_mark = [curve[alpha_mark] for curve in curves]
print(*values_at_mark)
for value, name, color in zip(values_at_mark, names, colors):
    ax.axhline(y=value, color=color, linestyle=":", linewidth=1,
               label=rf"{name} ($\alpha={alpha_mark}$) = {value * 100:.1f}%")

# x axis shows alpha as a fraction of k.
tick_fractions = [(0, "0"), (1 / 8, "1/8"), (1 / 4, "1/4"), (1 / 3, "1/3"), (1 / 2, "1/2"),
                  (2 / 3, "2/3"), (3 / 4, "3/4"), (7 / 8, "7/8"), (1, "1")]
ax.set_xticks([frac * sweep_k for frac, _ in tick_fractions])
ax.set_xticklabels([label for _, label in tick_fractions])

yticks = np.array(range(3, 11)) / 10
ax.set_yticks(yticks)
ax.set_yticklabels([f"{y * 100:.0f}%" for y in yticks])

ax.set_xlabel(r"$\alpha / k$")
ax.set_ylabel("average hit@20")
ax.legend(loc="lower center", ncol=2)
ax.grid(True, linestyle="--", alpha=0.6)
fig.savefig("alpha_h20_test.png")
fig.savefig("alpha_h20_test.svg")
plt.show()


In [None]:
# Display name, marker and colour per dataset, in the order of `datasets`.
names = ["PRIME", "MAG", "AMAZON"]
markers = ["o", "x", "*"]
colors = ["tab:blue", "tab:orange", "tab:green"]

curves = [datarows_r20s[ds] for ds in datasets]
# alpha runs 0..k, so each curve has k + 1 points (k = 20 here).
sweep_k = len(curves[0]) - 1
# Highlighted operating point: alpha = floor(2k/3), i.e. 13 for k = 20.
alpha_mark = (2 * sweep_k) // 3

fig, ax = plt.subplots()
for curve, name, marker, color in zip(curves, names, markers, colors):
    ax.plot(curve, marker=marker, color=color, label=name)

ax.axvline(x=alpha_mark, color="red", linestyle="--",
           label=rf"$\alpha = \lfloor 2k/3 \rfloor = {alpha_mark}$")

# Full-width horizontal reference line at each dataset's value at alpha_mark.
values_at_mark = [curve[alpha_mark] for curve in curves]
print(*values_at_mark)
for value, name, color in zip(values_at_mark, names, colors):
    ax.axhline(y=value, color=color, linestyle=":", linewidth=1,
               label=rf"{name} ($\alpha={alpha_mark}$) = {value * 100:.1f}%")

# x axis shows alpha as a fraction of k.
tick_fractions = [(0, "0"), (1 / 8, "1/8"), (1 / 4, "1/4"), (1 / 3, "1/3"), (1 / 2, "1/2"),
                  (2 / 3, "2/3"), (3 / 4, "3/4"), (7 / 8, "7/8"), (1, "1")]
ax.set_xticks([frac * sweep_k for frac, _ in tick_fractions])
ax.set_xticklabels([label for _, label in tick_fractions])

yticks = np.array(range(2, 8)) / 10
ax.set_yticks(yticks)
ax.set_yticklabels([f"{y * 100:.0f}%" for y in yticks])

ax.set_xlabel(r"$\alpha / k$")
# This figure plots recall@20 (datarows_r20s), not hit@20.
ax.set_ylabel("average recall@20")
# Legend is disabled for this figure; enable to show the labels defined above:
# ax.legend(loc="lower center", ncol=2)
ax.grid(True, linestyle="--", alpha=0.6)
fig.savefig("alpha_r20_test.png")
fig.savefig("alpha_r20_test.svg")
plt.show()