**Imports**

In [None]:
import pandas as pd
import networkx as nx
import json
import os
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

Load KG

In [None]:
# Read the knowledge-graph triples and preview the first five as readable chains.
KG_path = r"../data/arch_kg.csv"
df = pd.read_csv(KG_path)

for _, triple in df.head(5).iterrows():
    head = f'{triple["Entity1"]} ({triple["Label1"]})'
    tail = f'{triple["Entity2"]} ({triple["Label2"]})'
    print(f'{head}->[{triple["Relationship"]}]->{tail}')

Create DiGraph (Knowledge Graph)

In [None]:
# Build the knowledge graph: one node per entity (carrying its label) and one
# directed edge per CSV row (carrying relationship and provenance attributes).
G = nx.DiGraph()

for _, row in df.iterrows():
    # Re-adding an existing node only updates its ``label`` attribute.
    G.add_node(row["Entity1"], label=row["Label1"])
    G.add_node(row["Entity2"], label=row["Label2"])

    # A DiGraph stores a single edge per (source, target) pair, so a later row
    # between the same two entities overwrites the attributes written here.
    G.add_edge(
        row["Entity1"],
        row["Entity2"],
        relationship=row["Relationship"],
        source=row["SourceFilename"],
        chunk=row["ChunkID"]
    )

print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

Create MMLU Dataframe

In [None]:
import json
benchmark_path = r"../data/benchmark.json"

# Decode explicitly as UTF-8 (as every other JSON read in this notebook does)
# instead of relying on the platform's default codec.
with open(benchmark_path, "r", encoding="utf-8") as f:
    data = json.load(f)

mmlu_data = data["mmlu"]

# One record per question; the benchmark's question id becomes the "id" column.
flat_data = [
    {
        "id": qid,
        "question": qdata["question"],
        "options": qdata["options"],
        "answer": qdata["answer"],
    }
    for qid, qdata in mmlu_data.items()
]

mmlu_df = pd.DataFrame(flat_data)
print(mmlu_df.head())

Create MedMCQA Dataframe

In [None]:
import json
benchmark_path = r"../data/benchmark.json"

# Decode explicitly as UTF-8 (as every other JSON read in this notebook does)
# instead of relying on the platform's default codec.
with open(benchmark_path, "r", encoding="utf-8") as f:
    data = json.load(f)

medmcqa_data = data["medmcqa"]

# One record per question; the benchmark's question id becomes the "id" column.
flat_data = [
    {
        "id": qid,
        "question": qdata["question"],
        "options": qdata["options"],
        "answer": qdata["answer"],
    }
    for qid, qdata in medmcqa_data.items()
]

medmcqa_df = pd.DataFrame(flat_data)
print(medmcqa_df.head())

Pydantic Related Methods

In [None]:
from langchain_ollama import OllamaLLM
from langchain.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field, field_validator
from typing import List, Tuple

class EntityPair(BaseModel):
    """A pair of entity names extracted from a question.

    Both fields are required. Each value is whitespace-stripped during
    validation and rejected if nothing remains.
    """

    entity1: str = Field(...)
    entity2: str = Field(...)

    @field_validator("entity1", "entity2")
    @classmethod
    def not_empty(cls, v):
        """Return ``v`` stripped of surrounding whitespace; raise ``ValueError`` if it is blank."""
        stripped = v.strip()
        if not stripped:
            raise ValueError("Entity cannot be empty")
        return stripped

def preprocess_line(line: str) -> str:
    """Strip list markers and wrapper characters from one line of LLM output.

    Applied in order: surrounding whitespace, leading ``*`` bullet markers, one
    enclosing ``(...)`` pair, one enclosing ``<...>`` pair, and one enclosing
    pair of matching single or double quotes.

    The ``<...>`` case matters because the extraction prompt asks the model to
    answer as ``<entity1, entity2>``; without it the brackets stay glued to the
    entity names.

    Args:
        line: A single raw line of model output.

    Returns:
        The unwrapped line, possibly empty.
    """
    line = line.strip()

    if line.startswith("*"):
        line = line.lstrip("*").strip()

    for opening, closing in (("(", ")"), ("<", ">")):
        if line.startswith(opening) and line.endswith(closing):
            line = line[1:-1].strip()

    if (line.startswith("'") and line.endswith("'")) or (line.startswith('"') and line.endswith('"')):
        line = line[1:-1].strip()

    return line

def clean_entity(entity: str) -> str:
    """Trim whitespace and any run of quote characters from both ends of ``entity``.

    Quotes and whitespace may be interleaved (e.g. ``" 'x' "``); trimming
    repeats until neither end of the string is a quote.
    """
    quote_chars = ("'", '"')
    cleaned = entity.strip()
    while cleaned and (cleaned[0] in quote_chars or cleaned[-1] in quote_chars):
        cleaned = cleaned.strip("'\"").strip()
    return cleaned

def parse_pairs(raw_text: str) -> List[EntityPair]:
    """Parse raw LLM output into validated entity pairs.

    Each line is normalised with ``preprocess_line``. Lines without a comma are
    ignored. The remaining lines are split on their first comma and both halves
    are cleaned with ``clean_entity``.

    Args:
        raw_text: Multi-line model response, ideally one pair per line.

    Returns:
        One ``EntityPair`` per usable line, in input order. Pairs with an empty
        side are reported via ``print`` and skipped.
    """
    pairs = []

    for raw_line in raw_text.strip().splitlines():
        line = preprocess_line(raw_line)

        # Covers blank lines too: an empty string contains no comma.
        if "," not in line:
            continue

        # Split on the first comma only, so the second entity may itself contain commas.
        left, _, right = line.partition(",")
        entity1 = clean_entity(left)
        entity2 = clean_entity(right)

        if not entity1 or not entity2:
            print(f"Skipping invalid pair with empty entity: '{entity1}', '{entity2}'")
            continue

        pairs.append(EntityPair(entity1=entity1, entity2=entity2))

    return pairs

Helper Functions

In [None]:
def remove_duplicates_except_underscores(lst):
    """Drop repeated items from ``lst`` while keeping every ``"__"`` placeholder.

    The first occurrence of each non-placeholder item is kept, in order.
    """
    already_kept = set()
    deduped = []
    for element in lst:
        is_placeholder = element == "__"
        if not is_placeholder:
            if element in already_kept:
                continue
            already_kept.add(element)
        deduped.append(element)
    return deduped

def format_edge_path(path):
    """Render an alternating ``[node, relation, node, ...]`` list as a chain string.

    Args:
        path: Flat list where even positions are nodes and odd positions are
            the relations that follow them.

    Returns:
        ``"n0->[r0]->n1->[r1]->...->last"``, or ``""`` for an empty list.
    """
    if not path:
        return ""

    # Step two at a time over (node, relation) pairs; the final element is
    # appended separately as the closing node.
    hops = [f"{path[i]}->[{path[i + 1]}]->" for i in range(0, len(path) - 2, 2)]
    return "".join(hops) + path[-1]

def remove_direct_neighbors_only(data):
    """Remove every non-placeholder item that sits directly next to a ``'__'``.

    The placeholders themselves are always kept.
    """
    last_index = len(data) - 1

    def touches_placeholder(pos):
        left = pos > 0 and data[pos - 1] == '__'
        right = pos < last_index and data[pos + 1] == '__'
        return left or right

    return [
        item
        for pos, item in enumerate(data)
        if item == '__' or not touches_placeholder(pos)
    ]

RAG and Perturbation Methods Definition

In [None]:
# Single LLM instance shared by every chain below; temperature is pinned to 0.
model = OllamaLLM(model="llama3.2:3b-instruct-fp16", temperature=0)

# Entity-extraction prompt. The reply is parsed line by line by parse_pairs()
# (via extract_entities), so the model is told to emit nothing but the pairs.
prompt = ChatPromptTemplate.from_template('''
    You are given the following question in a medical context:
                                          
        {question}
                                          
    Your task is to extract medically relevant entity pairs from the text.
    Return the output strictly in the format:
        
        <entity1, entity2>
                                          
    Make sure each pair is unique and avoid repeating the same entity in both positions.
                                          
    Do not explain or add anything else - just return the entity pairs.
''')

# Verbalisation prompt: turns one relation chain (a path string as built by
# find_shortest_path / generate_perturbations) into a short paragraph.
prompt_pseudo_paragraph = ChatPromptTemplate.from_template("""
You are a medical instructor helping to generate educational materials.
                                                    
Given a relationship chain connecting medical concepts, write a short, clear and medically accurate paragraph that explains the connections in a way understandable to students.
                                                           
Input chain:
{relation_chain}
                                                           
Output:
A single, coherent paragraph explaining the relationships. Do not include any further comments, only your generated answer.
""")

# Answering prompt: the generated paragraph(s) are the only context offered for
# the multiple-choice question. run_rag() reduces the reply to one character.
RAG_prompt = ChatPromptTemplate.from_template("""
You are a knowledgeable medical assistant.

Use the following medical paragraph to answer the multiple-choice question below.
Choose the best answer from the provided options based solely on the paragraph.
                                              
Medical Paragraph:
{paragraph}
                                              
Question:
{question}
                                              
Options:
{options}
                                              
Answer:
Provide only the letter (A, B, C or D) corresponding to the corrected choice.
Do not provide any additional explanation or commentary.
""")

# Compose each prompt with the shared model; the functions below call
# ``.invoke`` on these with a dict of the template variables.
chain = prompt | model
chain_par = prompt_pseudo_paragraph | model
rag_chain = RAG_prompt | model 

def extract_entities(question):
    """Ask the extraction chain for entity pairs in ``question`` and parse its reply.

    Returns the list produced by ``parse_pairs``.
    """
    return parse_pairs(chain.invoke({"question": question}))

def find_shortest_path(entity1, entity2):
    """Look up the shortest directed path between two entities in the graph ``G``.

    Returns:
        ``None`` when either entity is missing from the graph or no path
        exists. Otherwise a four-element list
        ``[node_list, edge_list, subpath_list, path]`` where

        * ``node_list`` holds the nodes along the path,
        * ``edge_list`` holds the relationship of each hop,
        * ``subpath_list`` holds one ``(source, relationship, target)`` triple per hop,
        * ``path`` is the ``"a->[rel]->b->...->z"`` string rendering.
    """
    try:
        hops = nx.shortest_path(G, entity1, entity2)
    except nx.NetworkXNoPath:
        return None
    except nx.NodeNotFound:
        return None

    edge_list = []
    subpath_list = []
    for source, target in zip(hops, hops[1:]):
        attributes = G.get_edge_data(source, target)
        # "unknown" marks an edge lacking the attribute, "no_relationship" an
        # edge without attribute data.
        if attributes:
            relationship = attributes.get("relationship", "unknown")
        else:
            relationship = "no_relationship"
        edge_list.append(relationship)
        subpath_list.append((source, relationship, target))

    node_list = list(hops)
    rendered_hops = "".join(f"{source}->[{relationship}]->" for source, relationship, _ in subpath_list)
    path = rendered_hops + f"{hops[-1]}"

    return [node_list, edge_list, subpath_list, path]

def generate_pseudo_paragraph(path):
    """Return the paragraph the verbalisation chain writes for the relation chain ``path``."""
    return chain_par.invoke({"relation_chain": path})

def run_rag(paragraph, question, options):
    """Answer a multiple-choice question from ``paragraph`` and return the option letter.

    Args:
        paragraph: Context text inserted into the RAG prompt.
        question: Question text.
        options: Answer options, inserted into the prompt as given.

    Returns:
        The first non-whitespace character of the model's reply, or ``""``
        when the reply is empty or blank.
    """
    response = rag_chain.invoke({"paragraph": paragraph, "question": question, "options": options})
    # Strip first: a reply such as "\nB" must yield "B", not the newline.
    # Slicing instead of indexing keeps an empty reply from raising IndexError.
    return response.strip()[:1]

def _render_chain(nodes, edges):
    """Render parallel node/edge lists as ``"n0->[e0]->n1->...->last"``."""
    return "".join(f"{node}->[{edge}]->" for node, edge in zip(nodes, edges)) + f"{nodes[-1]}"


def generate_perturbations(node_list, edge_list, subpath_list):
    """Create every single-element masking of a path, using ``"__"`` as the blank.

    Args:
        node_list: Nodes along the path.
        edge_list: Relationship of each hop (one fewer than the nodes).
        subpath_list: ``(source, relationship, target)`` triple per hop.

    Returns:
        A dict with three lists:

        * ``"node_perturbations"``: ``(path_string, removed_node, removed_node_label)``
          tuples, one per node; the label is read from the graph ``G``.
        * ``"edge_perturbations"``: path strings with one relationship blanked.
        * ``"subpath_perturbations"``: path strings with one whole hop blanked
          (its source node, relationship and target node).
    """
    node_perturbations = []
    for i, removed_node in enumerate(node_list):
        masked_nodes = node_list[:i] + ["__"] + node_list[i + 1:]
        node_perturbations.append(
            (_render_chain(masked_nodes, edge_list), removed_node, G.nodes[removed_node].get("label"))
        )

    edge_perturbations = []
    for i in range(len(edge_list)):
        masked_edges = edge_list[:i] + ["__"] + edge_list[i + 1:]
        edge_perturbations.append(_render_chain(node_list, masked_edges))

    subpath_perturbations = []
    if subpath_list:
        # Rebuild the node/edge sequences from the triples and mask by position.
        # Masking by position (rather than flattening and de-duplicating the
        # triples) keeps paths intact when a relationship name occurs more than once.
        hop_nodes = [subpath_list[0][0]] + [triple[2] for triple in subpath_list]
        hop_edges = [triple[1] for triple in subpath_list]
        for i in range(len(subpath_list)):
            masked_nodes = hop_nodes[:i] + ["__", "__"] + hop_nodes[i + 2:]
            masked_edges = hop_edges[:i] + ["__"] + hop_edges[i + 1:]
            subpath_perturbations.append(_render_chain(masked_nodes, masked_edges))

    return {
        "node_perturbations": node_perturbations,
        "edge_perturbations": edge_perturbations,
        "subpath_perturbations": subpath_perturbations
        }

def produce_default_answer(question, options):
    """Answer ``question`` from the unperturbed knowledge-graph paths.

    Entity pairs are extracted from the question, each pair with a shortest
    path is verbalised into a paragraph, and the joined paragraphs are passed
    to ``run_rag`` as context.

    Returns:
        ``(response, word_count)``: the value returned by ``run_rag`` and the
        number of whitespace-separated words in the joined paragraphs.
    """
    found_paths = []
    for pair in extract_entities(question):
        lookup = find_shortest_path(pair.entity1, pair.entity2)
        if lookup:
            found_paths.append(lookup[3])

    context = "\n".join([generate_pseudo_paragraph(path) for path in found_paths])
    answer = run_rag(context, question, options)

    return answer, len(context.split())


Application of Perturbations

In [None]:
responses = {}

# Results are written to a temporary file and then moved over the final file.
TEMP_RESULTS_PATH = "../data/temp_results1.json"
FINAL_RESULTS_PATH = "../results/medmcqa.json"


def _evaluate_perturbations(bucket, kind, path_index, perturbations, paths,
                            question, options, default_response):
    """Re-answer the question once per perturbation and record each outcome.

    Args:
        bucket: Result dict for this perturbation kind; entries are stored
            under ``"{path_index}_{i}_{kind}"``.
        kind: ``"node"``, ``"edge"`` or ``"subpath"``.
        path_index: Position in ``paths`` of the path being perturbed.
        perturbations: Perturbations of that path. Node perturbations are
            tuples whose first element is the path string; the other kinds are
            plain path strings.
        paths: Unperturbed path strings of the question.
        question, options: Forwarded to ``run_rag``.
        default_response: Answer obtained with the unperturbed paths.

    Returns:
        Indices into ``perturbations`` whose answer differed from
        ``default_response``.
    """
    changed_indices = []
    for i, perturbation in enumerate(perturbations):
        perturbed_path = perturbation[0] if kind == "node" else perturbation
        candidate_paths = paths[:path_index] + [perturbed_path] + paths[path_index + 1:]

        paragraphs = "\n".join(generate_pseudo_paragraph(p) for p in candidate_paths)
        response = run_rag(paragraphs, question, options)

        changed = int(response != default_response)
        bucket[f"{path_index}_{i}_{kind}"] = {
            "perturbation": perturbation,
            "paragraphs": paragraphs,
            "answer": response,
            "changed": changed,
        }
        if changed:
            changed_indices.append(i)

    # Accumulate across paths: the bucket holds entries of every path of the
    # question, so the amount must count all of them, not just the last path.
    bucket["perturbation_amount"] = bucket.get("perturbation_amount", 0) + len(perturbations)
    return changed_indices


for test_idx, example in tqdm(medmcqa_df.iterrows(), desc="Apply Perturbations", total=len(medmcqa_df)):
    question = example["question"]
    options = example["options"]

    path_lists = []
    paths = []
    for pair in extract_entities(question):
        shortest_path = find_shortest_path(pair.entity1, pair.entity2)
        if shortest_path:
            path_lists.append(shortest_path[:3])
            paths.append(shortest_path[3])

    # Nothing to perturb; skip before spending LLM calls on a default answer.
    if not paths:
        continue

    default_response, _ = produce_default_answer(question, options)

    result = {"node": {}, "edge": {}, "subpath": {}}
    responses[f"test_{test_idx}"] = result

    label_list = []
    node_positions = []
    edge_positions = []
    subpath_positions = []

    for index, (node_list, edge_list, subpath_list) in enumerate(path_lists):
        perturbations = generate_perturbations(node_list, edge_list, subpath_list)
        node_perturbations = perturbations["node_perturbations"]
        edge_perturbations = perturbations["edge_perturbations"]
        subpath_perturbations = perturbations["subpath_perturbations"]

        for i in _evaluate_perturbations(result["node"], "node", index, node_perturbations,
                                         paths, question, options, default_response):
            # Last tuple element is the label of the removed node.
            label_list.append(node_perturbations[i][-1])
            node_positions.append(i / len(node_perturbations))

        for i in _evaluate_perturbations(result["edge"], "edge", index, edge_perturbations,
                                         paths, question, options, default_response):
            edge_positions.append(i / len(edge_perturbations))

        for i in _evaluate_perturbations(result["subpath"], "subpath", index, subpath_perturbations,
                                         paths, question, options, default_response):
            subpath_positions.append(i / len(subpath_perturbations))

    # Exactly one position is recorded per changed answer.
    node_counter = len(node_positions)
    edge_counter = len(edge_positions)
    subpath_counter = len(subpath_positions)

    result["node_counter"] = node_counter
    result["edge_counter"] = edge_counter
    result["subpath_counter"] = subpath_counter

    result["labels"] = label_list
    result["relative_node_positions"] = node_positions
    result["relative_edge_positions"] = edge_positions
    result["relative_subpath_positions"] = subpath_positions

    # Checkpoint only after a question that changed at least one answer.
    if node_counter or edge_counter or subpath_counter:
        with open(TEMP_RESULTS_PATH, "w", encoding="utf-8") as f:
            json.dump(responses, f, indent=4, ensure_ascii=False)

        os.replace(TEMP_RESULTS_PATH, FINAL_RESULTS_PATH)

Produce LLM call counts and token amounts

In [None]:
# Number of benchmark questions included in the estimate.
MAX_EXAMPLES = 1000

responses = {}

examples = medmcqa_df.head(MAX_EXAMPLES)

for test_idx, example in tqdm(examples.iterrows(), desc="Apply Perturbations", total=len(examples)):
    question = example["question"]
    options = example["options"]

    path_lists = []
    for pair in extract_entities(question):
        shortest_path = find_shortest_path(pair.entity1, pair.entity2)
        if shortest_path:
            path_lists.append(shortest_path[:3])

    # Questions without any path contribute nothing; skip them before calling
    # the LLM for a default answer.
    if not path_lists:
        continue

    # Only the word count of the default context is needed here.
    _, paragraph_length = produce_default_answer(question, options)

    all_perturbations_counter = 0
    for node_list, edge_list, subpath_list in path_lists:
        perturbations = generate_perturbations(node_list, edge_list, subpath_list)
        all_perturbations_counter += (
            len(perturbations["node_perturbations"])
            + len(perturbations["edge_perturbations"])
            + len(perturbations["subpath_perturbations"])
        )

    responses[f"test_{test_idx}"] = {
        "llm_calls": all_perturbations_counter,
        # Estimate: perturbation count times the word count of the default context.
        "total_tokens": all_perturbations_counter * paragraph_length,
    }

with open(r"../results/medmcqa_calls_amount.json", "w", encoding="utf-8") as f:
    json.dump(responses, f, indent=2, ensure_ascii=False)


Clean/remove examples that did not trigger ANY change

In [None]:
# Drop every test whose three change counters are all zero, then save the rest.
with open(r"../results/medmcqa.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

rows_to_remove = [
    name
    for name, result in data.items()
    if (result["node_counter"], result["edge_counter"], result["subpath_counter"]) == (0, 0, 0)
]

for name in rows_to_remove:
    del data[name]

with open(r"../cleaned_results/cleaned_results_medmcqa.json", "w", encoding="utf-8") as json_file:
    json.dump(data, json_file, indent=4)

Number of Perturbation Examples

In [None]:
# Each top-level key of the cleaned results file is one perturbation example.
cleaned_results_file = r"../cleaned_results/cleaned_results_mmlu.json"
with open(cleaned_results_file, "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

example_count = len(data)
print(f"Number of Perturbation Examples: {example_count}")

Calculate **Impact**

In [None]:
with open(r"../cleaned_results/cleaned_results_mmlu.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


def _changed_ratio(group):
    """Return the share of a group's perturbations that changed the answer.

    ``group`` maps perturbation keys to result dicts and also carries an
    integer ``"perturbation_amount"`` entry, which is skipped when counting.
    A group with zero perturbations yields ``0.0``.
    """
    changed = sum(
        1
        for entry in group.values()
        if isinstance(entry, dict) and entry.get("changed") == 1
    )
    total = group["perturbation_amount"]
    return changed / total if total else 0.0


node_impactfulness = 0
edge_impactfulness = 0
subpath_impactfulness = 0

for row in data:
    n_counter = _changed_ratio(data[row]["node"])
    e_counter = _changed_ratio(data[row]["edge"])
    s_counter = _changed_ratio(data[row]["subpath"])

    # Each test credits one category. Ties resolve in the order node, edge, sub-path.
    if n_counter >= e_counter and n_counter >= s_counter:
        node_impactfulness += 1
    elif e_counter > n_counter and e_counter >= s_counter:
        edge_impactfulness += 1
    elif s_counter > n_counter and s_counter > e_counter:
        subpath_impactfulness += 1

print(f"Node Impactfulness: {node_impactfulness}")
print(f"Edge Impactfulness: {edge_impactfulness}")
print(f"Sub-path Impactfulness: {subpath_impactfulness}")

Generate Position Distribution Graphs

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json

with open(r"../cleaned_results/cleaned_results_medmcqa.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


def _save_position_histogram(positions, color, out_path, ylabel=""):
    """Save a 20-bin histogram (with KDE) of relative positions to ``out_path``.

    Args:
        positions: Relative positions of the perturbations that changed the answer.
        color: Bar colour passed to seaborn.
        out_path: Destination file of the figure.
        ylabel: Y-axis label; empty to leave the axis unlabelled.
    """
    plt.figure(figsize=(6, 5))
    sns.histplot(positions, kde=True, bins=20, color=color)

    plt.xlabel("Relative Position", fontsize=20, fontweight='bold')
    plt.xticks(fontsize=15, fontweight='bold')

    plt.ylabel(ylabel, fontsize=20, fontweight='bold')
    plt.yticks(fontsize=15, fontweight='bold')

    plt.tight_layout()
    plt.savefig(out_path)
    # Close the figure so later plots start from a clean state.
    plt.close()


all_node_positions = []
all_edge_positions = []
all_subpath_positions = []

for row in data:
    all_node_positions.extend(data[row]["relative_node_positions"])
    all_edge_positions.extend(data[row]["relative_edge_positions"])
    all_subpath_positions.extend(data[row]["relative_subpath_positions"])

print(f"Node positions: {all_node_positions}")
print(f"Edge positions: {all_edge_positions}")
print(f"Subpath positions: {all_subpath_positions}")

# Only the node plot carries the y-axis label.
_save_position_histogram(all_node_positions, 'skyblue', r"../medmcqa_plots_pdf/node_positions.pdf",
                         ylabel="Critical Change Count")
_save_position_histogram(all_edge_positions, 'salmon', r"../medmcqa_plots_pdf/edge_positions.pdf")
_save_position_histogram(all_subpath_positions, 'lightgreen', r"../medmcqa_plots_pdf/subpath_positions.pdf")

Generate Node Label Distribution Graphs

In [None]:
with open(r"../cleaned_results/cleaned_results_medmcqa.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


def _normalize_label(label):
    """Normalise one node label for counting, or return ``None`` to skip it.

    A trailing ``"s"`` is removed; otherwise a trailing ``"**"`` is removed
    together with a then-trailing ``"s"``. The literal ``"Label2"`` is dropped,
    as are labels that are not non-empty strings or that normalise to an empty
    string.
    """
    if not isinstance(label, str) or not label:
        return None

    if label.endswith("s"):
        normalized = label[:-1]
    elif label.endswith("**"):
        stem = label[:-2]
        normalized = stem[:-1] if stem.endswith("s") else stem
    elif label != "Label2":
        normalized = label
    else:
        normalized = None

    return normalized or None


node_labels = []

for row in data:
    node_labels.extend(data[row]["labels"])

cleaned_labels = [
    normalized
    for normalized in (_normalize_label(label) for label in node_labels)
    if normalized is not None
]

plt.figure(figsize=(10, 6))
sns.countplot(x=cleaned_labels, order=sorted(set(cleaned_labels)))  # Sorted for readability
plt.xlabel("Node Label", fontsize=25, fontweight='bold')
plt.ylabel("")
plt.xticks(rotation=45, fontsize=20, fontweight='bold')
plt.yticks(fontsize=15, fontweight='bold')
plt.tight_layout()
plt.savefig(r"../medmcqa_plots_pdf/node_label_distribution.pdf")

plt.show()

Degree Distribution

In [None]:
with open(r"../cleaned_results/cleaned_results_mmlu.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


# Degrees (in the graph ``G``) of the nodes whose removal changed the answer.
node_degrees = []

for row in data:
    node_data = data[row]["node"]

    for current_perturbation in node_data.values():
        # The integer "perturbation_amount" entry is not a perturbation record.
        if not isinstance(current_perturbation, dict):
            continue
        if current_perturbation.get("changed") != 1:
            continue

        # Node perturbations are stored as [path_string, removed_node, label].
        node = current_perturbation["perturbation"][-2]
        if node not in G:
            continue

        node_degrees.append(G.degree[node])

plt.figure(figsize=(10, 6))
sns.histplot(node_degrees, bins=30, kde=True, color='skyblue')
plt.xlabel("Degree", fontsize=25, fontweight='bold')
plt.ylabel("Frequency", fontsize=25, fontweight='bold')

plt.xticks(fontsize=20, fontweight='bold')
plt.yticks(fontsize=20, fontweight='bold')

plt.grid(True)

plt.tight_layout()

plt.savefig(r"../mmlu_plots_pdf/node_degree_distribution.pdf")
plt.show()

***Cases where important nodes had the maximum degree within the derived path (not used in experiments)***

In [None]:
with open(r"../cleaned_results/cleaned_results_mmlu.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


max_node_degrees_counter = 0
# Counts tests, not individual perturbations.
total_perturbations = 0

for row in data:
    node_data = data[row]["node"]

    total_perturbations += 1

    temp_node_degrees = []
    temp_important_node_degrees = []

    for current_perturbation in node_data.values():
        # The integer "perturbation_amount" entry is not a perturbation record.
        if not isinstance(current_perturbation, dict):
            continue

        # Node perturbations are stored as [path_string, removed_node, label].
        node = current_perturbation["perturbation"][-2]
        if node not in G:
            continue

        d = G.degree[node]
        temp_node_degrees.append(d)

        if current_perturbation.get("changed") == 1:
            temp_important_node_degrees.append(d)

    print(temp_node_degrees)
    print(temp_important_node_degrees)
    print()

    # No node of this test could be resolved in the graph; there is no maximum to compare.
    if not temp_node_degrees:
        continue

    max_degree_of_path = max(temp_node_degrees)

    if max_degree_of_path in temp_important_node_degrees:
        print(f"Max degree {max_degree_of_path} is from an important node.")
        max_node_degrees_counter += 1

ratio = max_node_degrees_counter / total_perturbations * 100 if total_perturbations else 0.0

print(f"Total cases where max degree was an important node: {max_node_degrees_counter}")
print(f"Ratio: {ratio}")

**Distribution of important node ranking**

In [None]:
with open(r"../cleaned_results/cleaned_results_mmlu.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)


# Relative degree rank (0 = highest degree, 1 = lowest) of every node whose
# removal changed the answer, ranked among the nodes of the same test.
important_node_ranking_scores = []

for row in data:
    node_data = data[row]["node"]

    temp_node_degrees = {}
    temp_important_nodes = []

    for current_perturbation in node_data.values():
        # The integer "perturbation_amount" entry is not a perturbation record.
        if not isinstance(current_perturbation, dict):
            continue

        # Node perturbations are stored as [path_string, removed_node, label].
        node = current_perturbation["perturbation"][-2]
        if node not in G:
            continue

        temp_node_degrees[node] = G.degree[node]

        if current_perturbation.get("changed") == 1:
            temp_important_nodes.append(node)

    sorted_node_names = sorted(temp_node_degrees, key=temp_node_degrees.get, reverse=True)
    N = len(sorted_node_names)

    # A relative rank needs at least two ranked nodes (the divisor is N - 1);
    # the edge-rank analysis applies the same guard.
    if N <= 1:
        continue

    # dict.fromkeys keeps one entry per node, in first-seen order.
    for node in dict.fromkeys(temp_important_nodes):
        important_node_ranking_scores.append(sorted_node_names.index(node) / (N - 1))


plt.figure(figsize=(10, 5))
sns.histplot(important_node_ranking_scores, bins=20, kde=True, color="skyblue", edgecolor="black")

plt.xlabel("Relative Rank (0 = highest degree, 1 = lowest)", fontsize=20, fontweight='bold')
plt.ylabel("Frequency", fontsize=25, fontweight='bold')

plt.xticks(fontsize=15, fontweight='bold')
plt.yticks(fontsize=15, fontweight='bold')

plt.grid(True, linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig(r"../mmlu_plots_pdf/important_node_relative_rank_distribution.pdf")
plt.show()

Helper functions

In [None]:
import re

# Lookup table used by extract_betweenness(), which reads its "Edge" and
# "Betweenness" columns.
edge_betweenness_path = r"../data/edge_betweenness.csv"
edge_betweenness_df = pd.read_csv(edge_betweenness_path)

def extract_surrounding_entities(path):
    """Return the ``(source, target)`` nodes around the first blanked relation in ``path``.

    ``path`` is a chain string such as ``"A->[__]->B->[r]->C"``. An
    ``IndexError`` is raised when no relation is blanked.
    """
    # With a capturing group, re.split alternates node, relation, node, ...
    tokens = re.split(r'->\[(.*?)\]->', path)

    blanked_hops = [
        (tokens[k - 1].strip(), tokens[k + 1].strip())
        for k in range(1, len(tokens) - 1, 2)  # k indexes the relations
        if tokens[k].strip() == "__"
    ]

    return blanked_hops[0]

def extract_betweenness(entity_pair):
    """Look up the betweenness of an edge in ``edge_betweenness_df``.

    The "Edge" column is compared against ``str(entity_pair)``. Returns the
    first matching "Betweenness" value, or ``None`` when no row matches.
    """
    is_requested_edge = edge_betweenness_df["Edge"] == str(entity_pair)
    hits = edge_betweenness_df[is_requested_edge]

    if hits.empty:
        return None
    return hits["Betweenness"].values[0]

**Distribution of important edge ranking**

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns

with open(r"../cleaned_results/cleaned_results_mmlu.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Relative betweenness rank (0 = highest, 1 = lowest) of every edge whose
# blanking changed the answer, ranked among the edges of the same test.
important_edge_ranking_scores = []

for current_test in data.values():
    edge_data = current_test.get("edge", {})

    betweenness_by_edge = {}
    changed_edges = []

    for current_perturbation in edge_data.values():
        try:
            node_pair = extract_surrounding_entities(current_perturbation["perturbation"])
            score = extract_betweenness(node_pair)
            if score is None:
                continue

            betweenness_by_edge[node_pair] = score
            if current_perturbation.get("changed") == 1:
                changed_edges.append(node_pair)
        except Exception:
            # Entries that cannot be processed (e.g. "perturbation_amount") are skipped.
            continue

    ranked_edges = sorted(betweenness_by_edge.items(), key=lambda item: item[1], reverse=True)
    sorted_edge_names = [f"{pair[0]},{pair[1]}" for pair, _ in ranked_edges]
    N = len(sorted_edge_names)

    # A relative rank needs at least two ranked edges.
    if N <= 1:
        continue

    for edge in changed_edges:
        edge_str = f"{edge[0]},{edge[1]}"
        if edge_str in sorted_edge_names:
            important_edge_ranking_scores.append(sorted_edge_names.index(edge_str) / (N - 1))

plt.figure(figsize=(10, 5))
sns.histplot(important_edge_ranking_scores, bins=20, kde=True, color="skyblue", edgecolor="black")

plt.xlabel("Relative Rank (0 = highest betweenness, 1 = lowest)", fontsize=20, fontweight='bold')
plt.ylabel("Frequency", fontsize=25, fontweight='bold')

plt.xticks(fontsize=15, fontweight='bold')
plt.yticks(fontsize=15, fontweight='bold')

plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()

plt.savefig(r"../mmlu_plots_pdf/important_edge_relative_rank_distribution.pdf")

plt.show()


**Distribution of important sub-path ranking**

In [None]:
from scipy.stats import rankdata

with open(r"../cleaned_results/cleaned_results_medmcqa.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

def get_original_path(test, path_index=0):
    """Rebuild the unperturbed path string of one path of a test.

    The ``"{path_index}_0_node"`` entry stores the path with its first node
    blanked, followed by the name of the removed node; putting that name back
    restores the original path.

    Args:
        test: Result dict of one test (needs its ``"node"`` group).
        path_index: Index of the path within the test (int or str).

    Returns:
        The original ``"a->[rel]->b->..."`` path string.
    """
    first_node_perturbation = test["node"][f"{path_index}_0_node"]["perturbation"]

    path_list = first_node_perturbation[0].split("->")
    path_list[0] = first_node_perturbation[1]

    return "->".join(path_list)

def extract_subpath_entities(original_path, path):
    """Return the two original nodes that a sub-path perturbation blanked.

    Both strings are split on ``"->"``; the positions of the first two bare
    ``"__"`` elements in ``path`` are looked up in ``original_path``.

    Returns:
        ``(node1, node2)``, or ``[]`` when ``path`` has fewer than two blanked nodes.
    """
    original_path_list = original_path.split("->")
    path_list = path.split("->")

    empty_node_indices = [index for index, element in enumerate(path_list) if element == "__"]

    if len(empty_node_indices) < 2:
        return []

    return (f'{original_path_list[empty_node_indices[0]]}',f'{original_path_list[empty_node_indices[1]]}')


subpath_relative_ranking_scores = []

for row in data:
    current_test = data[row]
    subpath_data = current_test["subpath"]

    # Original path per path index, rebuilt lazily (None = could not be rebuilt).
    original_paths = {}

    temp_subpath_scores = []
    temp_changed_subpath_scores = []

    for key, current_perturbation in subpath_data.items():
        # The integer "perturbation_amount" entry is not a perturbation record.
        if not isinstance(current_perturbation, dict):
            continue
        if current_perturbation.get("changed") not in (0, 1):
            continue

        # Keys look like "{path_index}_{i}_subpath"; each perturbation must be
        # compared with the original of its own path, not always path 0.
        path_index = key.split("_", 1)[0]
        if path_index not in original_paths:
            try:
                original_paths[path_index] = get_original_path(current_test, path_index)
            except (KeyError, IndexError):
                original_paths[path_index] = None
        original_path = original_paths[path_index]
        if original_path is None:
            continue

        try:
            node_pair = extract_subpath_entities(original_path, current_perturbation["perturbation"])
        except IndexError:
            # Perturbed path is longer than the rebuilt original; cannot be aligned.
            continue
        if len(node_pair) < 2:
            continue

        temp_edge_betweenness = extract_betweenness(node_pair)
        if temp_edge_betweenness is None:
            continue
        if node_pair[0] not in G or node_pair[1] not in G:
            continue

        degree_sum = G.degree[node_pair[0]] + G.degree[node_pair[1]]
        if degree_sum == 0:
            continue

        # Sub-path score: edge betweenness divided by the summed degrees of its endpoints.
        subpath_score = temp_edge_betweenness / degree_sum
        temp_subpath_scores.append(subpath_score)

        if current_perturbation["changed"] == 1:
            temp_changed_subpath_scores.append(subpath_score)

    if len(temp_subpath_scores) == 0:
        continue

    # Negate so that the highest score receives rank 1; ties share their average rank.
    ranks = rankdata([-s for s in temp_subpath_scores], method='average')
    N = len(ranks)

    for score in temp_changed_subpath_scores:
        indices = [i for i, val in enumerate(temp_subpath_scores) if val == score]
        avg_rank = sum(ranks[i] for i in indices) / len(indices)
        zero_based_rank = avg_rank - 1
        rel_rank = zero_based_rank / (N - 1) if N > 1 else 0
        subpath_relative_ranking_scores.append(rel_rank)

plt.figure(figsize=(8,5))
plt.hist(subpath_relative_ranking_scores, bins=20, color='skyblue', edgecolor='black')
plt.xlabel('Relative Rank (0 = highest rank)', fontsize=20, fontweight='bold')
plt.ylabel('Frequency', fontsize=25, fontweight='bold')

plt.xticks(fontsize=15, fontweight='bold')
plt.yticks(fontsize=15, fontweight='bold')

plt.grid(True)
plt.tight_layout()

plt.savefig(r"../medmcqa_plots_pdf/important_subpath_relative_rank_distribution.pdf")

plt.show()