In [3]:
# Install the OpenAI and LangChain libraries
# - `openai`: Provides access to OpenAI's GPT models for tasks like text generation, embeddings, and completions.
# - `langchain`: A framework for building applications using large language models (LLMs).
#                Includes tools for chaining prompts, memory, and integrations like knowledge graphs.
!pip install -q openai langchain
# Attempt to install the LangChain Community library
# - `langchain-community`: This may refer to a community-supported version or extensions of LangChain.
#   Ensure this package exists and is maintained if errors occur during installation.
!pip install -q langchain-community

This script initializes the OpenAI API client and defines a function to interact with the GPT model. The get_chat_response function sends a user-provided text input to the GPT model (gpt-3.5-turbo) and returns the model's response.

In [4]:
import os
from openai import OpenAI

# Read the OpenAI API key from the environment and prompt for it only when it is
# missing, so no secret is ever stored in the notebook source or its outputs.
# NOTE(review): the key that was previously hard-coded here is exposed in the
# notebook history and must be revoked/rotated.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

client = OpenAI()

def get_chat_response(text):
    """
    Send ``text`` to gpt-3.5-turbo as a single user message and return the reply.

    Uses the module-level ``client`` created in this cell.

    Parameters:
    - text: Prompt sent as the content of the user message.

    Returns:
    - The message content of the first choice in the chat completion.
    """
    user_message = {"role": "user", "content": text}
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[user_message],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


In [None]:
# Install required packages
!pip install -q langchain langchain-community rdflib SPARQLWrapper

# Standard libraries
import os
import re
import json
import random
import textwrap
import unicodedata
import urllib.parse
from collections import defaultdict

# Data handling & scientific computing
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

# Graphs / Networks
import networkx as nx
from langchain.graphs.networkx_graph import NetworkxEntityGraph, KnowledgeTriple

# Machine Learning / Statistics
from sklearn.linear_model import LinearRegression, BayesianRidge
from sklearn.metrics import (
    mean_absolute_error, mean_squared_error, r2_score,
    roc_curve, auc
)
import sklearn.metrics  # For additional metrics if needed
from sklearn.datasets import fetch_20newsgroups

from scipy.spatial.distance import cosine
from scipy.stats import wasserstein_distance

# LangChain & OpenAI
from langchain.llms import OpenAI as LangChainOpenAI
from langchain.chains import GraphQAChain
from langchain.prompts import PromptTemplate
from openai import OpenAI

# Colab specific
from google.colab import drive


This cell loads PrimeKG triples from Google Drive. It keeps the triples related to diabetes, scored by keyword matching on the URL-decoded subject, predicate, and object. The kept triples are grouped into parts: each part is one of the largest connected components of the resulting graph, capped at 20 triples. The cell also defines the helpers used later:
- saving and visualizing the graph,
- perturbing it by removing parts,
- computing embeddings,
- querying a GraphQAChain,
- estimating per-part importance coefficients.

It then runs the full pipeline end to end.

In [None]:
# Mount Google Drive so the PrimeKG files under /content/drive are readable (Colab).
drive.mount('/content/drive')

# Take the OpenAI API key from the environment, prompting for it only if it is not
# set, instead of hard-coding a secret in the notebook.
# NOTE(review): a real key used to be hard-coded here; it is exposed in the notebook
# history and must be revoked/rotated.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])  # Explicitly set API key

# Paths
file_path = "/content/drive/MyDrive/PrimeKG_Data/final_test.json"
output_dir = "/content/drive/MyDrive/PrimeKG_Data/"

# Embedding cache (normalized text -> vector, filled by get_embedding) and model
embedding_cache = {}
EMBEDDING_MODEL = "text-embedding-ada-002"  # Unified model

def get_chat_response(text):
    """Return gpt-3.5-turbo's reply to ``text`` sent as one user message (uses the module-level ``client``)."""
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[dict(role="user", content=text)],
    )
    return response.choices[0].message.content

def load_primekg_data(file_path):
    """
    Load the PrimeKG JSON document stored at ``file_path``.

    On success, prints how many top-level entries were read and returns the parsed
    object. Any exception (missing file, invalid JSON, unsized result, ...) is
    reported with a printed message and an empty list is returned instead, so
    callers can simply test the result for truthiness.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        print(f"Successfully loaded {len(loaded)} entries from {file_path}")
    except Exception as exc:
        print(f"Error loading data: {exc}")
        loaded = []
    return loaded

def is_diabetes_related(text):
    """
    Score how strongly ``text`` relates to diabetes by case-insensitive substring matching.

    Returns:
    - 2 if any core diabetes term occurs in the text,
    - 1 if no core term occurs but at least one secondary (related) term does,
    - 0 otherwise.
    """
    lowered = text.lower()
    core_terms = (
        'diabetes', 'diabetic', 'insulin', 'glucose',
        'type 1 diabetes', 'type 2 diabetes', 't1d', 't2d',
    )
    secondary_terms = (
        'hyperglycemia', 'hypoglycemia', 'glycemic', 'a1c', 'hemoglobin a1c',
        'gestational diabetes', 'prediabetes', 'metabolic syndrome', 'pancreatic',
        'islet', 'beta cell', 'metformin', 'glucagon', 'diabetic retinopathy',
        'diabetic nephropathy', 'diabetic neuropathy',
    )
    if any(term in lowered for term in core_terms):
        return 2
    if any(term in lowered for term in secondary_terms):
        return 1
    return 0

def extract_diabetes_knowledge_graph(data, max_parts=10):
    """
    Build a focused diabetes knowledge graph from raw PrimeKG entries.

    Steps:
    1. Read every 3-element triple under each entry's ``'value'`` key, URL-decode
       subject/predicate/object and sum their ``is_diabetes_related`` scores;
       keep triples with a positive total.
    2. Sort kept triples by score (descending; ties keep input order) and add
       them as edges of an undirected graph keyed on decoded subject/object.
    3. Keep only the ``max_parts`` largest connected components.
    4. Group surviving triples by component. Each group becomes "Part N" and is
       capped at its first 20 triples; the triples are printed and collected.

    Parameters:
    - data: Iterable of dict-like entries; triples are read via ``entry.get('value', [])``.
    - max_parts: Maximum number of connected components ("parts") to keep.

    Returns:
    - kg: List of decoded ``(subject, predicate, object)`` tuples, ordered part by part.
    - portion_indices: Dict ``"Part N" -> range`` of that part's positions in ``kg``.
    """
    decode = urllib.parse.unquote

    # 1. Score every 3-element triple; keep those with any diabetes signal.
    scored = []
    for entry in data:
        for triple in entry.get('value', []) or []:
            if len(triple) != 3:
                continue
            subj, pred, obj = triple
            total = sum(is_diabetes_related(decode(field)) for field in (subj, pred, obj))
            if total > 0:
                scored.append((triple, total))

    # 2. Strongest triples first; list.sort is stable, so ties keep input order.
    scored.sort(key=lambda item: item[1], reverse=True)
    graph = nx.Graph()
    for triple, _ in scored:
        subj, pred, obj = triple
        graph.add_edge(decode(subj), decode(obj), relation=pred)

    # 3. Largest components first; remember which kept component each node is in
    #    (connected components are disjoint, so every node maps to one index).
    components = sorted(nx.connected_components(graph), key=len, reverse=True)[:max_parts]
    component_of = {node: idx for idx, component in enumerate(components) for node in component}

    print("\nStructured Diabetes Knowledge Graph (Focused on Strong Connections):\n")

    # 4. Bucket surviving triples by component: score order inside a bucket,
    #    first-appearance order across buckets.
    buckets = defaultdict(list)
    for triple, _ in scored:
        subj, _, obj = triple
        home = component_of.get(decode(subj))
        if home is not None and component_of.get(decode(obj)) == home:
            buckets[home].append(triple)

    kg = []
    portion_indices = {}
    for part_number, bucket in enumerate(buckets.values(), start=1):
        if part_number > max_parts:
            break
        chosen = bucket[:20]  # Limit to 20 triples per part
        portion_indices[f"Part {part_number}"] = range(len(kg), len(kg) + len(chosen))
        print(f"\n# Part {part_number}")
        for subj, pred, obj in chosen:
            decoded = (decode(subj), decode(pred), decode(obj))
            print(f"({decoded[0]}) → ({decoded[1]}) → ({decoded[2]})")
            kg.append(decoded)
        print("-" * 80)

    return kg, portion_indices

def save_knowledge_graph(kg, output_file=os.path.join(output_dir, "focused_diabetes_kg.txt")):
    """
    Write each triple of ``kg`` to ``output_file`` as ``( s , p , o )`` plus a blank line.

    The default path is resolved once, when this function is defined, from
    ``output_dir``. Errors are printed rather than raised.
    """
    try:
        with open(output_file, "w", encoding="utf-8") as handle:
            handle.writelines(f"( {t[0]} , {t[1]} , {t[2]} )\n\n" for t in kg)
        print(f"\nFocused Diabetes Knowledge Graph saved as '{output_file}'.")
    except Exception as e:
        print(f"Error saving knowledge graph: {e}")

def wrap_text(text, max_words=8):
    """Split ``text`` onto lines of at most 15 characters, but only if it has more than ``max_words`` words."""
    if len(text.split()) <= max_words:
        return text
    return "\n".join(textwrap.wrap(text, width=15))

def visualize_graph_with_chains(kg, part_indices):
    """
    Draw ``kg`` twice side by side: uncolored, and colored by part ("chain") membership.

    Parameters:
    - kg: List of (subject, relation, object) triples.
    - part_indices: Dict mapping part names to index ranges into ``kg``. Names must be
      "Part 1" .. "Part K" (K = number of parts), because colors are keyed that way.

    Each node/edge takes the color of the last part (in ``part_indices`` order) whose
    triples contain it; nodes outside every part stay 'lightblue' and edges 'gray'.
    After showing the figure, prints which part each node was assigned to.
    """
    G = nx.DiGraph()
    for node1, relation, node2 in kg:
        G.add_edge(node1, node2, label=relation)
    pos = nx.spring_layout(G, k=8, iterations=100, seed=0)

    chain_cmap = mcolors.LinearSegmentedColormap.from_list(
        'chain_colors', ['#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3', '#a6d854'])
    chain_norm = mcolors.Normalize(vmin=0, vmax=len(part_indices) - 1)
    chain_color_map = {f"Part {i+1}": chain_cmap(chain_norm(i)) for i in range(len(part_indices))}

    # One pass over the parts builds node/edge -> color lookups (later parts
    # overwrite earlier ones), keeping coloring linear in graph size instead of
    # repeatedly scanning list(G.nodes) and every part's triples.
    node_chain_map = {}
    node_color_lookup = {}
    edge_color_lookup = {}
    for chain_name, indices in part_indices.items():
        color = chain_color_map[chain_name]
        for idx in indices:
            node1, _, node2 = kg[idx]
            for node in (node1, node2):
                node_chain_map[node] = chain_name
                node_color_lookup[node] = color
            edge_color_lookup[(node1, node2)] = color
    # Color lists must follow G.nodes() / G.edges() order for networkx drawing.
    node_colors = [node_color_lookup.get(node, 'lightblue') for node in G.nodes()]
    edge_colors = [edge_color_lookup.get(edge, 'gray') for edge in G.edges()]

    wrapped_labels = {node: wrap_text(node) for node in G.nodes()}
    fig, axs = plt.subplots(1, 2, figsize=(20, 8), dpi=600)

    # Left panel: structure only.
    nx.draw_networkx_nodes(G, pos, node_color='#d3d3d3', node_size=1200, ax=axs[0])
    nx.draw_networkx_edges(G, pos, edge_color='gray', width=1.2, ax=axs[0])
    nx.draw_networkx_labels(G, pos, labels=wrapped_labels, font_size=6, ax=axs[0])
    edge_labels = nx.get_edge_attributes(G, 'label')
    wrapped_edge_labels = {edge: wrap_text(label) for edge, label in edge_labels.items()}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=wrapped_edge_labels, font_size=6, ax=axs[0])
    axs[0].set_title("Original Diabetes Knowledge Graph", fontsize=10)
    axs[0].axis('off')

    # Right panel: same layout, colored by part.
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=1200, ax=axs[1], edgecolors='black')
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=1.5, ax=axs[1])
    nx.draw_networkx_labels(G, pos, labels=wrapped_labels, font_size=6, ax=axs[1])
    nx.draw_networkx_edge_labels(G, pos, edge_labels=wrapped_edge_labels, font_size=6, ax=axs[1])
    axs[1].set_title("Diabetes Graph Highlighted by Chain Membership", fontsize=10)
    axs[1].axis('off')
    handles = [plt.Line2D([0], [0], marker='o', color=color, markersize=10, linestyle='', label=chain_name)
               for chain_name, color in chain_color_map.items()]
    axs[1].legend(handles=handles, title="Chains", loc='upper right', fontsize=8)
    plt.show()

    print("\n--- Node Chain Mapping ---")
    for node, chain in node_chain_map.items():
        print(f"Node '{node}' belongs to chain '{chain}'.")
def perturb_kg_by_removing_parts(kg, parts_to_remove, portion_indices):
    """
    Return a new list with the triples of ``kg`` that do not belong to ``parts_to_remove``.

    ``portion_indices`` maps each part name to the positions of its triples in ``kg``;
    part names missing from it are ignored. ``kg`` itself is not modified.
    """
    dropped = {
        position
        for part in parts_to_remove
        if part in portion_indices
        for position in portion_indices[part]
    }
    return [triple for position, triple in enumerate(kg) if position not in dropped]

def normalize_text(text):
    """
    Canonicalize ``text`` for embedding and cache lookups.

    Newlines become spaces, outer whitespace is stripped, the result is
    NFKC-normalized, whitespace runs collapse to one space, and it is lowercased.
    """
    stripped = text.replace("\n", " ").strip()
    collapsed = re.sub(r"\s+", " ", unicodedata.normalize("NFKC", stripped))
    return collapsed.lower()

def get_embedding(text):
    """
    Return the embedding vector for ``text`` computed with ``EMBEDDING_MODEL``.

    The text is canonicalized with ``normalize_text`` first; the normalized string
    is both what gets embedded and the ``embedding_cache`` key, so repeated inputs
    cost a single API call. Failures are not cached: they print an error and
    return a zero vector of length 1536.
    """
    key = normalize_text(text)
    if key in embedding_cache:
        return embedding_cache[key]
    try:
        result = client.embeddings.create(input=[key], model=EMBEDDING_MODEL)
        vector = result.data[0].embedding
    except AttributeError:
        print(f"Error: 'embeddings' attribute not found. Please update the OpenAI library using '!pip install -U openai' and restart the runtime. Returning zero vector for '{key}'.")
        return [0] * 1536
    except Exception as err:
        print(f"Error getting embedding for '{key}': {err}")
        return [0] * 1536
    embedding_cache[key] = vector
    return vector

def get_answer_and_embedding(question: str, temp: float, graph):
    """
    Ask ``question`` against ``graph`` with a GraphQAChain and embed the answer.

    The answer is embedded through ``get_embedding``. Perturbed answers in the
    coefficient computation go through the same path, so they share its text
    normalization, model and cache. An unchanged answer therefore maps to an
    identical vector (distance 0), instead of differing only by casing or
    whitespace.

    Parameters:
    - question: Natural-language question passed to the chain.
    - temp: Sampling temperature for the LangChain OpenAI LLM.
    - graph: The NetworkxEntityGraph the chain answers from.

    Returns:
    - (answer string, embedding list). If embedding fails, ``get_embedding``
      prints the error and a zero vector is returned.
    """
    # Initialize the GraphQAChain with OpenAI LLM
    llm = LangChainOpenAI(temperature=temp, api_key=os.environ["OPENAI_API_KEY"])  # Explicit API key
    chain = GraphQAChain.from_llm(llm, graph=graph, verbose=False)
    original_answer_str = str(chain.run(question))

    original_answer_embedding = get_embedding(original_answer_str)
    if any(original_answer_embedding):
        print(f"Embedding generated successfully. Length: {len(original_answer_embedding)}")
    else:
        print("Warning: embedding of the original answer is a zero vector; distances to it will be unreliable.")

    return original_answer_str, original_answer_embedding

def plot_knowledge_graph_explainability(kg, part_indices, coeff):
    """
    Plot the knowledge graph next to a copy colored by per-part importance coefficients.

    Parameters:
    - kg: List of (subject, relation, object) triples.
    - part_indices: Dict mapping "Part N" names to index ranges into ``kg``.
    - coeff: Sequence of coefficients; ``coeff[N-1]`` is the importance of "Part N".

    Colors come from a blue (-1) -> light gray (0) -> red (+1) colormap. Each node
    and edge takes the color of the first part, in ``part_indices`` order, whose
    triples contain it. Others are drawn '#8da0cb' (nodes) or 'gray' (edges). Node
    size grows with degree. The figure is saved to ``output_dir`` as
    'knowledge_graph_explainability_improved.png' and then shown.
    """
    G = nx.DiGraph()
    for node1, relation, node2 in kg:
        G.add_edge(node1, node2, label=relation)
    pos = nx.spring_layout(G, k=8, iterations=100, seed=0)
    cmap = mcolors.LinearSegmentedColormap.from_list('red_blue', ['blue', '#d3d3d3', 'red'])
    norm = mcolors.Normalize(vmin=-1, vmax=1)
    node_sizes = [1500 + 100 * G.degree(node) for node in G.nodes()]

    # Resolve every node and edge to the first part that contains it.
    node_color_lookup = {}
    edge_color_lookup = {}
    for part_name, indices in part_indices.items():
        part_idx = int(part_name.split()[-1]) - 1
        color = cmap(norm(coeff[part_idx]))
        for i in indices:
            node1, _, node2 = kg[i]
            node_color_lookup.setdefault(node1, color)
            node_color_lookup.setdefault(node2, color)
            # Edges are matched by their endpoints, not by their position in
            # G.edges(). That order groups edges by source node and merges
            # duplicate endpoint pairs, so it does not line up with kg indices.
            edge_color_lookup.setdefault((node1, node2), color)
    node_colors = [node_color_lookup.get(node, '#8da0cb') for node in G.nodes()]
    edge_colors = [edge_color_lookup.get(edge, 'gray') for edge in G.edges()]

    wrapped_labels = {node: wrap_text(node) for node in G.nodes()}
    fig, axs = plt.subplots(1, 2, figsize=(24, 10), dpi=600, gridspec_kw={'width_ratios': [1, 1.3]})

    # Left panel: plain structure.
    nx.draw_networkx_nodes(G, pos, node_color='#d3d3d3', node_size=node_sizes, ax=axs[0])
    nx.draw_networkx_edges(G, pos, edge_color='gray', width=1.5, ax=axs[0])
    nx.draw_networkx_labels(G, pos, labels=wrapped_labels, font_size=8, ax=axs[0])
    edge_labels = nx.get_edge_attributes(G, 'label')
    wrapped_edge_labels = {edge: wrap_text(label) for edge, label in edge_labels.items()}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=wrapped_edge_labels, font_size=8, ax=axs[0])
    axs[0].set_title("Original Diabetes Knowledge Graph", fontsize=12)
    axs[0].axis('off')

    # Right panel: same layout, colored by part importance.
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, ax=axs[1])
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=1.8, ax=axs[1])
    nx.draw_networkx_labels(G, pos, labels=wrapped_labels, font_size=8, ax=axs[1])
    nx.draw_networkx_edge_labels(G, pos, edge_labels=wrapped_edge_labels, font_size=8, ax=axs[1])
    axs[1].set_title("Simple SMILE GraphRAG Explainability", fontsize=12)
    axs[1].axis('off')
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], label='Importance Coefficients')
    plt.savefig(os.path.join(output_dir, 'knowledge_graph_explainability_improved.png'), bbox_inches='tight')
    plt.show()


def calculate_coefficients_print_Temperature(temp, original, kg, part_names, portion_indices, question,
                                             original_answer_embedding, original_answer_str,
                                             num_perturbations=20, seed=None):
    """
    Function to calculate coefficients for perturbations on a knowledge graph.
    It removes parts of the KG, generates perturbed responses, and calculates coefficients.
    Uses Wasserstein distance as the similarity metric, aligned with DBpedia code.

    Each perturbation works as follows:
    - A random non-empty subset of parts is removed.
    - The GraphQAChain is re-run on the remaining triples.
    - The Wasserstein distance between the perturbed answer's embedding and the
      original answer's embedding is recorded.

    The distances are then inverted and min-max scaled. They are regressed on the
    binary part-presence vectors. Each sample is weighted by a kernel of its
    cosine distance to the all-parts-present vector.

    Parameters:
    - temp: Temperature value (0 or 1); 0 is nudged to 0.001.
    - original: Ignored; the all-parts-present vector is rebuilt from ``part_names``
      (kept for backward compatibility with existing callers).
    - kg: Knowledge graph (list of triples)
    - part_names: List of part names in the KG (must be non-empty)
    - portion_indices: Dictionary mapping part names to index ranges
    - question: Question for GraphQAChain
    - original_answer_embedding: Embedding of the original answer
    - original_answer_str: Original answer text (only printed for comparison)
    - num_perturbations: Number of random perturbations to evaluate (default 20, must be >= 1)
    - seed: Optional seed for a private ``random.Random``. ``None`` keeps using
      the global ``random`` module state.

    Returns:
    - coeff: Coefficients from linear regression, one per part in ``part_names`` order

    Raises:
    - ValueError: if ``part_names`` is empty or ``num_perturbations`` < 1.
    """
    num_parts = len(part_names)
    if num_parts == 0:
        raise ValueError("part_names must contain at least one part")
    if num_perturbations < 1:
        raise ValueError("num_perturbations must be at least 1")

    # The `original` argument is ignored: rebuild the all-parts-present vector from
    # part_names so its width always matches the parts being perturbed.
    original = np.ones(num_parts).reshape(1, -1)  # Shape (1, number of parts)

    # A seeded private RNG makes runs reproducible; without a seed the global
    # `random` module is used.
    rng = random.Random(seed) if seed is not None else random

    # Keeps 1 / (distance + epsilon) finite when a perturbed answer matches exactly.
    epsilon = 1e-3

    # Temperature 0 is nudged to 0.001 to add a little randomness to the responses.
    adjusted_temp = 0.001 if temp == 0 else temp

    # The LLM configuration never changes between perturbations, so build it once.
    llm = LangChainOpenAI(temperature=adjusted_temp, api_key=os.environ["OPENAI_API_KEY"])
    orig_emb = np.asarray(original_answer_embedding, dtype=float)

    similarities_wd = []
    perturbations_vect2 = []

    for i in range(num_perturbations):
        perturbation_vector = original.copy().flatten()

        if num_parts > 3:
            # Remove between 1 and num_parts - 1 parts (never remove all).
            num_parts_to_remove = rng.randint(1, num_parts - 1)
        else:
            # With 3 or fewer parts any non-empty subset, including all parts, may be removed.
            num_parts_to_remove = rng.randint(1, num_parts)
        parts_to_remove_indices = rng.sample(range(num_parts), num_parts_to_remove)
        perturbation_vector[parts_to_remove_indices] = 0
        perturbations_vect2.append(perturbation_vector)
        parts_to_remove = [part_names[idx] for idx in parts_to_remove_indices]

        # Rebuild the QA graph from the triples that survive the removal.
        perturbed_kg = perturb_kg_by_removing_parts(kg, parts_to_remove, portion_indices)
        graph_temp = NetworkxEntityGraph()
        for node1, relation, node2 in perturbed_kg:
            graph_temp.add_triple(KnowledgeTriple(node1, relation, node2))

        chain = GraphQAChain.from_llm(llm, graph=graph_temp, verbose=False)
        temp_response = chain.run(question)

        # Wasserstein distance between the two embeddings' value distributions.
        temp_emb = np.asarray(get_embedding(temp_response), dtype=float)
        similarity_wd = wasserstein_distance(orig_emb, temp_emb)
        similarities_wd.append(similarity_wd)

        # Print progress for each iteration
        print(f"Iteration {i + 1}")
        print(f"Parts removed: {parts_to_remove}")
        print(f"Original answer response: {original_answer_str}")
        print(f"Perturbed response: {temp_response}")
        print(f"Wasserstein distance with original answer: {similarity_wd}\n")

    perturbations_vect2 = np.array(perturbations_vect2)
    # Kernel weights: perturbations closer (cosine distance) to "all parts present" count more.
    distances = sklearn.metrics.pairwise_distances(perturbations_vect2, original, metric='cosine').ravel()
    kernel_width = 0.25
    weights = np.sqrt(np.exp(-(distances**2) / kernel_width**2))

    # Print all similarities and weights
    print(f"similarities_wd: {similarities_wd}")
    print(f"Weights: {weights}")

    # Smaller distance means a more similar answer, so invert the distances.
    inverse_similarities_wd = [1.0 / (dist + epsilon) for dist in similarities_wd]
    print(f"Inverse similarities range: {min(inverse_similarities_wd):.6f} to {max(inverse_similarities_wd):.6f}")

    min_value = min(inverse_similarities_wd)
    max_value = max(inverse_similarities_wd)
    value_range = max_value - min_value
    if value_range < 1e-10:  # Essentially equal values
        print(f"Warning: Very small range in similarities ({value_range:.2e}). Using alternative scaling.")
        # Min-max scaling would divide by ~0; fall back to the raw inverse values.
        scaled_inverse_similarities_wd = inverse_similarities_wd
    else:
        scaled_inverse_similarities_wd = [(value - min_value) / value_range for value in inverse_similarities_wd]

    print(f"Scaled similarities range: {min(scaled_inverse_similarities_wd):.6f} to {max(scaled_inverse_similarities_wd):.6f}")

    # Weighted linear regression: one coefficient per part.
    simpler_model = LinearRegression()
    simpler_model.fit(X=perturbations_vect2, y=scaled_inverse_similarities_wd, sample_weight=weights)
    coeff = simpler_model.coef_

    # Report R² with the same sample weights used for fitting, so it describes the fitted model.
    r2 = simpler_model.score(perturbations_vect2, scaled_inverse_similarities_wd, sample_weight=weights)
    print(f"Regression R² score: {r2:.4f}")

    return coeff
# Main execution
if __name__ == "__main__":
    primekg_data = load_primekg_data(file_path)
    if not primekg_data:
        # Raising SystemExit stops this cell/script; a bare exit() inside
        # Jupyter/Colab can shut down the whole kernel instead.
        raise SystemExit("No data loaded.")
    kg, portion_indices = extract_diabetes_knowledge_graph(primekg_data, max_parts=10)
    if not kg:
        raise SystemExit("No diabetes-related information found.")
    print(f"\nFound {len(kg)} strongly connected diabetes-related triples in {len(portion_indices)} parts.")
    print("\nPortion Indices:\n")
    for part, index_range in portion_indices.items():
        print(f"{part}: {index_range}")
    save_knowledge_graph(kg)
    print("\nFinal Focused Diabetes Knowledge Graph List:\n")
    for triple in kg:
        print("(", triple[0], ",", triple[1], ", ", triple[2], ")")
    print("Original KG node count:", len(set(node for triple in kg for node in (triple[0], triple[2]))))

    # Full (unperturbed) graph used to produce the reference answer.
    graph = NetworkxEntityGraph()
    for node1, relation, node2 in kg:
        graph.add_triple(KnowledgeTriple(node1, relation, node2))

    visualize_graph_with_chains(kg, portion_indices)

    question = 'What is insulin-like growth factor receptor binding associated with?'
    #question = "What are the key biological and medical factors and terms involved in diabetes?"
    temp = 0
    original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
    print(f"\nOriginal Answer: {original_answer_str}")

    original = np.ones(len(portion_indices)).reshape(1, -1)

    coeff = calculate_coefficients_print_Temperature(temp, original, kg, list(portion_indices.keys()), portion_indices, question, original_answer_embedding, original_answer_str)
    plot_knowledge_graph_explainability(kg, portion_indices, coeff)

Visualizes the knowledge graph as a directed graph using NetworkX and Matplotlib. Nodes represent entities, and edges depict relationships with labels for clarity. The layout uses spring positioning with increased spacing for readability. Custom node colors and labeled edges enhance the visualization, displayed without axes.

Defines a function to perturb the knowledge graph by selectively removing triples belonging to specified parts. This allows testing the impact of missing information on downstream tasks or analysis. The function filters out triples associated with the indices of the parts to be removed and returns the modified knowledge graph.

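This function computes the embedding for a given text using the configured model (`EMBEDDING_MODEL`). It first normalizes the text: newlines become spaces, Unicode is NFKC-normalized, whitespace is collapsed, and the text is lowercased. It then returns a cached vector when one exists, or otherwise queries the OpenAI embeddings API. On failure it prints the error and returns a zero vector. The resulting vectors are used for similarity comparisons and downstream tasks.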

Defines a function that queries a GraphQAChain with a question and a temperature setting and returns the answer together with its embedding. The function initializes the chain with the given graph and temperature, runs the question, and computes the embedding of the returned answer for downstream analysis or comparison.

This function visualizes the explainability of a knowledge graph by displaying the original graph next to a copy whose nodes and edges are colored by their importance coefficients. It builds a directed graph, wraps labels for readability, scales node sizes by connectivity, and applies a blue–gray–red colormap to represent each part's significance. The two-panel layout shows both the original structure and the explainability features derived from the Simple SMILE GraphRAG analysis. A color bar provides a reference for the importance coefficients.

Defines the question used to query the GraphQAChain over the knowledge graph. Here, the question asks what insulin-like growth factor receptor binding is associated with. A broader alternative question, about the key biological and medical factors involved in diabetes, is left commented out.

This snippet sets the temperature parameter to 0 for deterministic response generation and queries the GraphQAChain with the question. The function get_answer_and_embedding returns the original answer as a string along with its embedding. The answer is then printed for review.


This function calculates an importance coefficient for each part of the knowledge graph and logs every iteration in detail. For each iteration it:
- removes random parts of the knowledge graph,
- generates a perturbed response,
- measures the Wasserstein distance between the perturbed answer's embedding and the original answer's embedding.

It then converts these distances into scaled inverse similarities. Finally, it fits a kernel-weighted linear regression whose coefficients are the part importances. The function includes:

- Temperature (`temp`) parameter to adjust the behavior of the GraphQAChain.
- Iterative logs showing the removed parts, perturbed responses, and calculated similarities.
- A summary of all similarities and weights after processing.
The coefficients provide insights into the contribution of each part of the knowledge graph to response fidelity.

# Stability: Consistent explanations despite small changes in input or graph structure.

This script augments the knowledge graph ('kg') with an additional triple to assess stability
in response generation and explainability. The stability analysis involves perturbing the knowledge
graph by removing random parts over 20 iterations for each temperature setting (0 and 1).

Key steps for stability analysis:
1. **Augmentation:** Adds a triple to the knowledge graph to introduce a new component for stability assessment.
2. **Perturbations:** Randomly removes parts of the knowledge graph multiple times to analyze the system's
   ability to maintain consistent and meaningful responses.
3. **Temperature Variation:** Runs the process with both deterministic (temperature = 0) and stochastic
   (temperature = 1) configurations to observe the impact of randomness on stability.
4. **Stability Metric:** Evaluates response consistency and similarity under perturbations, providing insights
   into the robustness of the system in preserving the core knowledge structure and response fidelity.

The results highlight how stable and resilient the system is to changes in the underlying knowledge graph.

In [None]:
def wrap_label(label, width=15):
    """Return ``label`` re-flowed into newline-separated lines of at most ``width`` characters."""
    return textwrap.fill(label, width=width)

def wrap_text(node1, relation=None, node2=None, max_words=8):
    """
    Wrap a label onto 15-character lines when it has more than ``max_words`` words.

    Two call forms are supported:
    - ``wrap_text(node1, relation, node2)`` wraps ``node2`` (the triple's object);
      ``node1`` and ``relation`` are ignored. ``build_graph`` uses this form.
    - ``wrap_text(text)`` wraps ``text`` itself. This matches the single-argument
      ``wrap_text`` defined in the earlier cell, which
      ``visualize_graph_with_chains`` and ``plot_knowledge_graph_explainability``
      call. Without this form, redefining ``wrap_text`` here would make those
      calls raise TypeError.
    """
    text = node1 if node2 is None else node2
    if len(text.split()) > max_words:
        return '\n'.join(textwrap.wrap(text, width=15))
    return text

def build_graph(kg, coeff, part_indices):
    """
    Build a directed graph from ``kg`` with wrapped labels, plus its layout and colors.

    Parameters:
    - kg: List of (subject, relation, object) triples.
    - coeff: Sequence of importance coefficients; ``coeff[N-1]`` belongs to "Part N".
    - part_indices: Dict mapping "Part N" names to index ranges into ``kg``
      (indices past the end of ``kg`` are skipped).

    Returns:
    - (G, pos, node_sizes, node_colors, edge_colors). The node and edge color lists
      follow ``G.nodes()`` / ``G.edges()`` order. Each node or edge gets the color of
      the first part (in ``part_indices`` order) containing it, otherwise
      '#8da0cb' (nodes) or 'gray' (edges).
    """
    def node_names(node1, relation, node2):
        # The exact wrapped strings used as graph node names, so lookups match.
        return wrap_label(node1), wrap_label(wrap_text(node1, relation, node2))

    G = nx.DiGraph()
    for node1, relation, node2 in kg:
        src, dst = node_names(node1, relation, node2)
        G.add_edge(src, dst, label=wrap_label(relation))
    pos = nx.spring_layout(G, k=8, iterations=100, seed=0)
    cmap = mcolors.LinearSegmentedColormap.from_list('red_blue', ['blue', '#d3d3d3', 'red'])
    norm = mcolors.Normalize(vmin=-1, vmax=1)
    node_sizes = [1500 + 100 * G.degree(node) for node in G.nodes()]

    node_color_lookup = {}
    edge_color_lookup = {}
    for part_name, indices in part_indices.items():
        part_idx = int(part_name.split()[-1]) - 1
        color = cmap(norm(coeff[part_idx]))
        for i in indices:
            if i >= len(kg):
                continue  # Part ranges may extend past a shortened kg.
            src, dst = node_names(*kg[i])
            node_color_lookup.setdefault(src, color)
            node_color_lookup.setdefault(dst, color)
            # Edges are matched by endpoints, not by position in G.edges(). That
            # order groups edges by source node and merges duplicate endpoint
            # pairs, so it does not line up with kg indices.
            edge_color_lookup.setdefault((src, dst), color)
    node_colors = [node_color_lookup.get(node, '#8da0cb') for node in G.nodes()]
    edge_colors = [edge_color_lookup.get(edge, 'gray') for edge in G.edges()]
    return G, pos, node_sizes, node_colors, edge_colors

def _draw_explainability_panel(ax, kg, coeff, part_indices, title):
    """Draw one part-colored knowledge graph (nodes, edges, labels, edge labels) onto ``ax``."""
    G, pos, node_sizes, node_colors, edge_colors = build_graph(kg, coeff, part_indices)
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=1.5, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=8, ax=ax)
    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8, ax=ax)
    ax.set_title(title, fontsize=14)
    ax.axis('off')


def plot_knowledge_graph_explainability_compare(kg_original, coeff_original, part_indices_original,
                                               kg_added, coeff_added, part_indices_added,
                                               save_path='knowledge_graph_explainability_comparison.png'):
    """Visualize two knowledge graphs (original and added) side by side with explainability features.

    Each panel is colored by ``build_graph``. A shared horizontal colorbar maps
    coefficients in [-1, 1] to blue -> light grey -> red.

    Parameters:
    - kg_original, coeff_original, part_indices_original: triples, per-part coefficients and
      part -> triple-index mapping of the original graph (left panel).
    - kg_added, coeff_added, part_indices_added: the same for the perturbed graph (right panel).
    - save_path: file the figure is saved to before showing; ``None`` skips saving.
    """
    fig, axs = plt.subplots(1, 2, figsize=(24, 10), dpi=300)
    _draw_explainability_panel(axs[0], kg_original, coeff_original, part_indices_original,
                               "Original Knowledge Graph")
    _draw_explainability_panel(axs[1], kg_added, coeff_added, part_indices_added,
                               "Added Knowledge Graph")

    # Same colormap/normalization that build_graph uses to color the parts.
    cmap = mcolors.LinearSegmentedColormap.from_list('red_blue', ['blue', '#d3d3d3', 'red'])
    norm = mcolors.Normalize(vmin=-1, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=axs, orientation='horizontal', label='Importance Coefficients', fraction=0.03, pad=0.05)
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()

In [None]:
def calculate_jaccard_index(coeff_original, coeff_added, threshold=0, verbose=True):
    """
    Calculate the Jaccard index using ABSOLUTE VALUES of coefficients.
    This treats both positive and negative coefficients as indicating importance.

    A part is "important" when ``abs(coefficient) > threshold``. The index is
    |important in both| / |important in either|. If the arrays differ in length, the
    longer one is truncated to the length of the shorter one.

    Note: with the default ``threshold=0`` every non-zero coefficient counts as
    important. Regression coefficients are rarely exactly zero, so with the default the
    index will often be 1.0. A larger threshold (e.g. the mean absolute coefficient)
    gives a more discriminative comparison.

    Parameters:
    coeff_original (array-like): Original coefficients from the PrimeKGQA graph.
    coeff_added (array-like): Added coefficients from the augmented PrimeKGQA graph.
    threshold (float): Absolute-value cut-off above which a coefficient is important.
    verbose (bool): Print intermediate values (default True).

    Returns:
    float: The Jaccard index. It is 0.0 if either input is empty, and 1.0 if the
    absolute coefficients are (nearly) identical or select the same parts.
    """
    import numpy as np

    def log(*args):
        if verbose:
            print(*args)

    coeff_original = np.asarray(coeff_original)
    coeff_added = np.asarray(coeff_added)
    log(f"Original coefficients: {coeff_original}")
    log(f"Added coefficients: {coeff_added}")

    # Compare only the positions present in both arrays.
    length = min(len(coeff_original), len(coeff_added))
    if len(coeff_original) > length:
        log(f"Truncated original to length {length}")
    elif len(coeff_added) > length:
        log(f"Truncated added to length {length}")
    coeff_original = coeff_original[:length]
    coeff_added = coeff_added[:length]

    if length == 0:
        log("Warning: Empty coefficient arrays")
        return 0.0

    # Use ABSOLUTE VALUES - both positive and negative coefficients indicate importance
    abs_original = np.abs(coeff_original)
    abs_added = np.abs(coeff_added)
    log(f"Absolute original: {abs_original}")
    log(f"Absolute added: {abs_added}")

    combined = np.concatenate([abs_original, abs_added])
    combined_range = np.max(combined) - np.min(combined)
    log(f"Original abs range: {np.max(abs_original) - np.min(abs_original):.6f}")
    log(f"Added abs range: {np.max(abs_added) - np.min(abs_added):.6f}")
    log(f"Combined abs range: {combined_range:.6f}")

    # All magnitudes (nearly) equal: the two explanations cannot be told apart.
    if combined_range < 1e-6:
        log(f"All absolute coefficients very similar (range: {combined_range:.2e}) - Jaccard = 1.0")
        return 1.0
    if np.allclose(abs_original, abs_added, atol=1e-6, rtol=1e-4):
        log("Absolute arrays are nearly identical - Jaccard = 1.0")
        return 1.0

    original_mask = abs_original > threshold
    added_mask = abs_added > threshold
    log(f"Original binary (abs > {threshold}): {original_mask.astype(int)} (sum: {int(original_mask.sum())})")
    log(f"Added binary (abs > {threshold}): {added_mask.astype(int)} (sum: {int(added_mask.sum())})")

    # Identical masks (including both all-False) score 1.0; otherwise the union is non-empty.
    if np.array_equal(original_mask, added_mask):
        log("Binary arrays are identical - Jaccard = 1.0")
        return 1.0

    intersection = int(np.sum(original_mask & added_mask))
    union = int(np.sum(original_mask | added_mask))
    jaccard_index = intersection / union
    log(f"Jaccard Index: {jaccard_index}")
    log(f"Intersection: {intersection}, Union: {union}")
    return jaccard_index



In [None]:
# Question passed to the coefficient and stability cells below via `question=question`.
question = "What is insulin-like growth factor receptor binding associated with?"

Get the original (unperturbed-graph) coefficients for temp = 0 and temp = 1.

In [None]:
def get_coefficients(question, kg, part_names, portion_indices, temp=0,
                     n_perturbations=20, kernel_width=0.25, seed=None):
    """
    Estimate per-part importance coefficients for ``question`` with a LIME-style surrogate.

    The full graph's answer embedding is compared with the answer embeddings of
    ``n_perturbations`` graphs. In each of those graphs a random non-empty subset of parts
    has been removed. The Wasserstein distances between embeddings are inverted and
    min-max scaled into similarity targets. A linear regression on the binary
    part-presence vectors, weighted by a kernel on cosine distance from the all-present
    vector, then gives one coefficient per part.

    Parameters:
    - question: Question string (e.g., "What is polycystic ovary syndrome?")
    - kg: Knowledge graph (list of (node1, relation, node2) triples)
    - part_names: List of part names in the KG (fixes the coefficient order)
    - portion_indices: Dictionary mapping part names to index ranges in ``kg``
    - temp: Temperature passed to ``get_answer_and_embedding`` (default 0)
    - n_perturbations: Number of random perturbed graphs to sample (default 20)
    - kernel_width: Width of the exponential kernel on cosine distance (default 0.25)
    - seed: Optional seed for part sampling. ``None`` uses the global ``random`` state,
      as before.

    Returns:
    - coeff: Coefficients array, one entry per element of ``part_names``

    Raises:
    - ValueError: if ``part_names`` is empty or ``n_perturbations`` < 1.
    """
    import random
    import numpy as np
    import sklearn.metrics
    from scipy.stats import wasserstein_distance
    from sklearn.linear_model import LinearRegression

    if not part_names:
        raise ValueError("part_names must contain at least one part")
    if n_perturbations < 1:
        raise ValueError("n_perturbations must be at least 1")

    rng = random if seed is None else random.Random(seed)
    n_parts = len(part_names)
    epsilon = 1e-6  # avoids division by zero when a perturbed answer matches exactly

    def _entity_graph(triples):
        """Build a NetworkxEntityGraph from (node1, relation, node2) triples."""
        entity_graph = NetworkxEntityGraph()
        for node1, relation, node2 in triples:
            entity_graph.add_triple(KnowledgeTriple(node1, relation, node2))
        return entity_graph

    # Reference point: every part present.
    original = np.ones((1, n_parts))
    _, original_answer_embedding = get_answer_and_embedding(question, temp, _entity_graph(kg))
    orig_emb = np.asarray(original_answer_embedding)

    perturbations = []
    similarities_wd = []
    for _ in range(n_perturbations):
        # Drop a random non-empty subset of parts (possibly all of them).
        num_parts_to_remove = rng.randint(1, n_parts)
        removed = rng.sample(range(n_parts), num_parts_to_remove)
        perturbation_vector = np.ones(n_parts)
        perturbation_vector[removed] = 0
        perturbations.append(perturbation_vector)

        perturbed_kg = perturb_kg_by_removing_parts(
            kg, [part_names[idx] for idx in removed], portion_indices)
        _, temp_response_embedding = get_answer_and_embedding(
            question, temp, _entity_graph(perturbed_kg))
        similarities_wd.append(wasserstein_distance(orig_emb, np.asarray(temp_response_embedding)))

    perturbations = np.array(perturbations)

    # LIME kernel: samples closer (cosine) to the all-present vector weigh more.
    distances = sklearn.metrics.pairwise_distances(perturbations, original, metric='cosine').ravel()
    weights = np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))

    # Smaller Wasserstein distance means a more similar answer: invert, then min-max scale.
    inverse = 1.0 / (np.asarray(similarities_wd) + epsilon)
    span = inverse.max() - inverse.min()
    targets = np.ones_like(inverse) if span == 0 else (inverse - inverse.min()) / span

    model = LinearRegression()
    model.fit(X=perturbations, y=targets, sample_weight=weights)
    return model.coef_

In [None]:
# Baseline temperature-0 coefficients for the unperturbed knowledge graph.
# Use the shared `question` so coeff0 is computed on exactly the same prompt as the
# stability runs it is compared with (the inline literal had a stray leading space).
coeff0 = get_coefficients(
    question=question,
    kg=kg,
    part_names=list(portion_indices.keys()),
    portion_indices=portion_indices,
    temp=0
)

print("Coefficients:", coeff0)

In [None]:
def get_coefficients(question, kg, part_names, portion_indices, temp=1):
    """
    Temperature-1 variant of the LIME-style part-importance estimator.

    NOTE(review): ``temp`` is overwritten with 1 at the start of the body. The argument,
    including an explicit ``temp=0``, therefore has no effect, and every LLM call runs at
    temperature 1.

    Procedure:
    1. Answer ``question`` on the full graph.
    2. Twenty times: drop a random non-empty subset of parts, re-answer, and record the
       Wasserstein distance between the two answer embeddings.
    3. Invert and min-max scale the distances into targets.
    4. Regress the targets on the binary part-presence vectors, weighted by a
       cosine-distance kernel of width 0.25.

    Parameters:
    - question: Question string
    - kg: Knowledge graph (list of (node1, relation, node2) triples)
    - part_names: List of part names in the KG (fixes the coefficient order)
    - portion_indices: Dictionary mapping part names to index ranges
    - temp: Ignored (see note above)

    Returns:
    - coeff: Coefficients array, one entry per part name
    """
    import random
    import numpy as np
    import sklearn.metrics
    from scipy.stats import wasserstein_distance
    from sklearn.linear_model import LinearRegression

    temp = 1  # pinned regardless of the argument; see docstring
    n_parts = len(part_names)
    full_presence = np.ones(n_parts).reshape(1, -1)

    def as_entity_graph(triples):
        entity_graph = NetworkxEntityGraph()
        for head, rel, tail in triples:
            entity_graph.add_triple(KnowledgeTriple(head, rel, tail))
        return entity_graph

    _, reference_embedding = get_answer_and_embedding(question, temp, as_entity_graph(kg))
    if isinstance(reference_embedding, list):
        reference_embedding = np.array(reference_embedding)

    presence_rows = []
    wd_distances = []
    for _ in range(20):
        row = full_presence.copy().flatten()
        how_many = random.randint(1, n_parts)
        dropped = random.sample(range(n_parts), how_many)
        row[dropped] = 0
        presence_rows.append(row)

        reduced_kg = perturb_kg_by_removing_parts(kg, [part_names[j] for j in dropped], portion_indices)
        _, reduced_embedding = get_answer_and_embedding(question, temp, as_entity_graph(reduced_kg))
        if isinstance(reduced_embedding, list):
            reduced_embedding = np.array(reduced_embedding)
        wd_distances.append(wasserstein_distance(reference_embedding, reduced_embedding))

    X = np.array(presence_rows)
    kernel_width = 0.25
    cos_dist = sklearn.metrics.pairwise_distances(X, full_presence, metric='cosine').ravel()
    sample_weights = np.sqrt(np.exp(-(cos_dist**2)/kernel_width**2))

    eps = 1e-6
    inverted = [1.0 / (d + eps) for d in wd_distances]
    lo, hi = min(inverted), max(inverted)
    if hi == lo:
        targets = [1.0] * len(inverted)
    else:
        targets = [(v - lo) / (hi - lo) for v in inverted]

    regressor = LinearRegression()
    regressor.fit(X=X, y=targets, sample_weight=sample_weights)
    return regressor.coef_

In [None]:
# Baseline temperature-1 coefficients for the unperturbed knowledge graph.
# coeff1 is compared against the "Temp 1" stability runs below, so request temp=1
# explicitly instead of relying on get_coefficients ignoring its temp argument.
coeff1 = get_coefficients(
    question=question,
    kg=kg,
    part_names=list(portion_indices.keys()),
    portion_indices=portion_indices,
    temp=1
)

print("Coefficients:", coeff1)

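## Stability: Evaluate the explanation for the question (Temp 1) by introducing small perturbations in the input or graph structure - ("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy") - and assessing the consistency of generated explanations.

The triple named above is the augmentation already included in `kg_added_base`. This experiment additionally adds ("insulin-like growth factor receptor binding", "associated_with", "type_1_diabetes") to Part 1.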

In [None]:
# Base triples shared by every stability experiment below, grouped into 10 consecutive
# parts. Per the inline comment, Part 1 holds 20 original triples plus one augmentation
# triple (index 20). The index range of each part is recorded in part_indices_added_base,
# and later cells add one perturbation triple per experiment.
# NOTE(review): several Part 1 triples are exact duplicates (e.g. the 1st and 3rd entries).
# nx.DiGraph collapses duplicates into a single edge when the graph is built.
kg_added_base = [
    # Part 1 (20 original + 1 augmentation)
    ("Neonatal insulin-dependent diabetes mellitus", "associated with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "ppi", "Neonatal insulin-dependent diabetes mellitus"),
    ("Neonatal insulin-dependent diabetes mellitus", "associated with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "ppi", "Neonatal insulin-dependent diabetes mellitus"),
    ("Neonatal insulin-dependent diabetes mellitus", "ppi", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "associated with", "Neonatal insulin-dependent diabetes mellitus"),
    ("Neonatal insulin-dependent diabetes mellitus", "ppi", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "associated with", "Neonatal insulin-dependent diabetes mellitus"),
    ("pancreatic A cell fate commitment", "ppi", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "associated with", "pancreatic A cell fate commitment"),
    ("pancreatic serous cystadenocarcinoma", "interacts with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "associated with", "pancreatic serous cystadenocarcinoma"),
    ("pancreatic A cell fate commitment", "associated with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "ppi", "pancreatic A cell fate commitment"),
    ("pancreatic serous cystadenocarcinoma", "interacts with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "associated with", "pancreatic serous cystadenocarcinoma"),
    ("pancreatic serous cystadenocarcinoma", "associated with", "insulin-like growth factor receptor binding"),
    ("insulin-like growth factor receptor binding", "interacts with", "pancreatic serous cystadenocarcinoma"),
    ("glucagon receptor activity", "ppi", "low-affinity glucose:proton symporter activity"),
    ("low-affinity glucose:proton symporter activity", "associated with", "glucagon receptor activity"),
    ("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy"),  # Common augmentation
    # Part 2
    ("vasoconstriction of artery involved in baroreceptor response to lowering of systemic arterial blood pressure", "expression present", "glucose binding"),
    ("glucose binding", "expression present", "Neurofibrillary tangles"),
    ("Neurofibrillary tangles", "ppi", "glucose binding"),
    # Part 3
    ("Flurandrenolide", "synergistic interaction", "Insulin peglispro"),
    # Part 4
    ("glucose 1-phosphate phosphorylation", "ppi", "mitochondrial genome maintenance"),
    ("mitochondrial genome maintenance", "expression present", "glucose 1-phosphate phosphorylation"),
    ("glucose 1-phosphate phosphorylation", "expression present", "mitochondrial genome maintenance"),
    ("mitochondrial genome maintenance", "ppi", "glucose 1-phosphate phosphorylation"),
    # Part 5
    ("UDP-glucose:glycoprotein glucosyltransferase activity", "ppi", "serotonin:sodium symporter activity"),
    ("serotonin:sodium symporter activity", "associated with", "UDP-glucose:glycoprotein glucosyltransferase activity"),
    # Part 6
    ("Insulin tregopil", "synergistic interaction", "Cyclothiazide"),
    # Part 7
    ("UDP-glucose transmembrane transporter activity", "associated with", "protein-DNA-RNA complex remodeling"),
    # Part 8
    ("Severe intrauterine growth retardation", "phenotype present", "glucose-1-phosphate thymidylyltransferase activity"),
    ("glucose-1-phosphate thymidylyltransferase activity", "ppi", "CCL4L1"),
    # Part 9
    ("SFT2D2", "expression present", "adrenal gland"),
    ("SFT2D2", "expression present", "deltoid"),
    # Part 10
    ("diabetic peripheral angiopathy", "associated with", "calcium-release channel activity"),
    ("diabetic peripheral angiopathy", "associated with", "acute myeloid leukemia with t(8;21)(q22;q22) translocation")
]

In [None]:
# kg_added_base is laid out as 10 consecutive parts. List how many triples each part
# holds and derive every part's half-open index range from a running offset.
_part_sizes = (
    21,  # Part 1: 20 original + 1 augmentation
    3,   # Part 2
    1,   # Part 3
    4,   # Part 4
    2,   # Part 5
    1,   # Part 6
    1,   # Part 7
    2,   # Part 8
    2,   # Part 9
    2,   # Part 10
)
part_indices_added_base = {}
_offset = 0
for _part_no, _size in enumerate(_part_sizes, start=1):
    part_indices_added_base[f"Part {_part_no}"] = range(_offset, _offset + _size)
    _offset += _size
del _part_sizes, _offset, _part_no, _size  # leave no helper names in the notebook namespace
part_names_added = list(part_indices_added_base.keys())

In [None]:
# Perturbation for this experiment: one new Part 1 triple. Insert it directly after the
# base Part 1 triples instead of appending it at the end. The ranges below give Part 1
# one extra slot and shift every later part by one, which only matches this layout.
perturbation_triple = ("insulin-like growth factor receptor binding", "associated_with", "type_1_diabetes")
insert_at = part_indices_added_base["Part 1"].stop
kg_added = kg_added_base[:insert_at] + [perturbation_triple] + kg_added_base[insert_at:]

part_indices_added = {
    "Part 1": range(0, 22),   # 21 base + 1 perturbation
    "Part 2": range(22, 25),
    "Part 3": range(25, 26),
    "Part 4": range(26, 30),
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())
# The last part range must end exactly at the last triple of kg_added.
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not match kg_added"

# Print node counts for verification
print("Original KG node count:", len(set(node for triple in kg for node in (triple[0], triple[2]))))
print("Updated KG node count:", len(set(node for triple in kg_added for node in (triple[0], triple[2]))))

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 1
# (presumably sampled, so it may differ between runs); the coefficient cell below uses both.
temp = 1
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Perturbed-graph coefficients at temperature 1, using kg_added's own part ranges.
temp = 1
coeff = calculate_coefficients_print_Temerature(
    question=question,
    temp=temp,
    original=original,
    original_answer_str=original_answer_str,
    original_answer_embedding=original_answer_embedding,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
)

In [None]:
# Round to 3 decimals; the rounded values are also what the comparison cells below use.
coeff = np.round(coeff, 3)  # Rounds to 3 decimal places
coeff

In [None]:
# Side-by-side explanation plot: temperature-1 baseline (coeff1) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff1, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

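## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure - ("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy") - and assessing the consistency of generated explanations.

As in the Temp 1 run, the triple named above is the augmentation already in `kg_added_base`. The added perturbation is ("insulin-like growth factor receptor binding", "associated_with", "type_1_diabetes").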

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 0.
temp = 0
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Round to 3 decimals; the rounded values are also what the comparison cells below use.
coeff = np.round(coeff, 3)  # Rounds to 3 decimal places
coeff

In [None]:
# Side-by-side explanation plot: temperature-0 baseline (coeff0) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff0, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("pancreatic serous cystadenocarcinoma", "treated_with", "chemotherapy")- and assessing the consistency of generated explanations.

In [None]:
# Perturbation for this experiment: one new Part 1 triple. Insert it directly after the
# base Part 1 triples instead of appending it at the end. The ranges below give Part 1
# one extra slot and shift every later part by one, which only matches this layout.
perturbation_triple = ("pancreatic serous cystadenocarcinoma", "treated_with", "chemotherapy")
insert_at = part_indices_added_base["Part 1"].stop
kg_added = kg_added_base[:insert_at] + [perturbation_triple] + kg_added_base[insert_at:]

part_indices_added = {
    "Part 1": range(0, 22),   # 21 base + 1 perturbation
    "Part 2": range(22, 25),
    "Part 3": range(25, 26),
    "Part 4": range(26, 30),
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())
# The last part range must end exactly at the last triple of kg_added.
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not match kg_added"


In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 1
# (presumably sampled, so it may differ between runs); the coefficient cell below uses both.
temp = 1
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-1 baseline (coeff1) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff1, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure - ("pancreatic serous cystadenocarcinoma", "treated_with", "chemotherapy") - and assessing the consistency of generated explanations.

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 0.
temp = 0
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-0 baseline (coeff0) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff0, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure - ("glucagon receptor activity", "inhibited_by", "insulin") - and assessing the consistency of generated explanations.

In [None]:
# Perturbation for this experiment: one new Part 1 triple. Insert it directly after the
# base Part 1 triples instead of appending it at the end. The ranges below give Part 1
# one extra slot and shift every later part by one, which only matches this layout.
perturbation_triple = ("glucagon receptor activity", "inhibited_by", "insulin")
insert_at = part_indices_added_base["Part 1"].stop
kg_added = kg_added_base[:insert_at] + [perturbation_triple] + kg_added_base[insert_at:]

part_indices_added = {
    "Part 1": range(0, 22),   # 21 base + 1 perturbation
    "Part 2": range(22, 25),
    "Part 3": range(25, 26),
    "Part 4": range(26, 30),
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())
# The last part range must end exactly at the last triple of kg_added.
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not match kg_added"


In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 1
# (presumably sampled, so it may differ between runs); the coefficient cell below uses both.
temp = 1
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-1 baseline (coeff1) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff1, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("glucagon receptor activity", "inhibited_by", "insulin")- and assessing the consistency of generated explanations.

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 0.
temp = 0
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-0 baseline (coeff0) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff0, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("glucose binding", "expression_present", "pancreas")- and assessing the consistency of generated explanations.

In [None]:
# Perturbation for this experiment: one new Part 2 triple. Insert it directly after the
# base Part 2 triples instead of appending it at the end. The ranges below give Part 2
# one extra slot and shift every later part by one, which only matches this layout.
perturbation_triple = ("glucose binding", "expression_present", "pancreas")
insert_at = part_indices_added_base["Part 2"].stop
kg_added = kg_added_base[:insert_at] + [perturbation_triple] + kg_added_base[insert_at:]

part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 25),  # 3 base + 1 perturbation
    "Part 3": range(25, 26),
    "Part 4": range(26, 30),
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())
# The last part range must end exactly at the last triple of kg_added.
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not match kg_added"

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 1
# (presumably sampled, so it may differ between runs); the coefficient cell below uses both.
temp = 1
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Perturbed-graph coefficients at temperature 1. Parts must be removed using the ranges
# that describe kg_added (part_indices_added), not the original kg's portion_indices.
temp = 1
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-1 baseline (coeff1) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff1, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("glucose binding", "expression_present", "pancreas")- and assessing the consistency of generated explanations.

In [None]:
# Regenerate the answer and embedding for the unperturbed graph at temperature 0.
temp = 0
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Perturbed-graph coefficients. Parts must be removed using the ranges that describe
# kg_added (part_indices_added), not the original kg's portion_indices.
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Display the perturbed-run coefficients.
coeff

In [None]:
# Side-by-side explanation plot: temperature-0 baseline (coeff0) vs. the perturbed run.
# coeff_original / coeff_added stay defined for the Jaccard cell that follows.
coeff_original, coeff_added = coeff0, coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    kg_added=kg_added,
    coeff_original=coeff_original,
    coeff_added=coeff_added,
    part_indices_original=portion_indices,
    part_indices_added=part_indices_added,
)

In [None]:
# Overlap between the sets of "important" parts of the baseline and perturbed explanations.
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure - ("Flurandrenolide", "used_for", "diabetic_skin_conditions") - and assessing the consistency of generated explanations.

In [None]:
# Perturbation for this experiment: one new Part 3 triple. Insert it directly after the
# base Part 3 triple instead of appending it at the end. The ranges below give Part 3
# one extra slot and shift every later part by one, which only matches this layout.
perturbation_triple = ("Flurandrenolide", "used_for", "diabetic_skin_conditions")
insert_at = part_indices_added_base["Part 3"].stop
kg_added = kg_added_base[:insert_at] + [perturbation_triple] + kg_added_base[insert_at:]

part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 24),
    "Part 3": range(24, 26),  # 1 base + 1 perturbation
    "Part 4": range(26, 30),
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())
# The last part range must end exactly at the last triple of kg_added.
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not match kg_added"


In [None]:
# Optional (disabled): visualize the graph partition.
# NOTE(review): as written this would draw the unperturbed kg/portion_indices, not kg_added/part_indices_added defined above.
#visualize_graph_with_chains(kg, portion_indices)

In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff1
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("Flurandrenolide", "used_for", "diabetic_skin_conditions")- and assessing the consistency of the generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("glucose 1-phosphate phosphorylation", "associated_with", "glycogen_storage_disease")- and assessing the consistency of the generated explanations.

In [None]:
# Perturbation for this section: one extra triple that belongs to Part 4.
_perturbation = ("glucose 1-phosphate phosphorylation", "associated_with", "glycogen_storage_disease")

# Part layout of the perturbed graph: Part 4 grows by one slot and every later part shifts by one.
part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 24),
    "Part 3": range(24, 25),
    "Part 4": range(25, 30),  # 4 base + 1 perturbation
    "Part 5": range(30, 32),
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())

# Insert the perturbation into Part 4's last slot. Appending it (kg_added_base + [...]) leaves it at the
# final index, which these ranges assign to Part 10, and the shifted ranges of Parts 5-10 would no
# longer line up with their base triples.
_insert_at = part_indices_added["Part 4"].stop - 1
kg_added = kg_added_base[:_insert_at] + [_perturbation] + kg_added_base[_insert_at:]
# Guard: the part ranges must span kg_added exactly (assumes kg_added_base holds the 39 unperturbed triples).
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not cover kg_added"



In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_original = coeff1
coeff_added = coeff
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("glucose 1-phosphate phosphorylation", "associated_with", "glycogen_storage_disease")- and assessing the consistency of generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
# NOTE(review): open question from the original cell — should kg_added replace kg as the reference graph?
# Kept kg here, matching every other comparison cell.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("UDP-glucose:glycoprotein glucosyltransferase activity", "regulates", "protein_folding")- and assessing the consistency of the generated explanations.

In [None]:
# Perturbation for this section: one extra triple that belongs to Part 5.
_perturbation = ("UDP-glucose:glycoprotein glucosyltransferase activity", "regulates", "protein_folding")

# Part layout of the perturbed graph: Part 5 grows by one slot and every later part shifts by one.
part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 24),
    "Part 3": range(24, 25),
    "Part 4": range(25, 29),
    "Part 5": range(29, 32),  # 2 base + 1 perturbation
    "Part 6": range(32, 33),
    "Part 7": range(33, 34),
    "Part 8": range(34, 36),
    "Part 9": range(36, 38),
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())

# Insert the perturbation into Part 5's last slot. Appending it (kg_added_base + [...]) leaves it at the
# final index, which these ranges assign to Part 10, and the shifted ranges of Parts 6-10 would no
# longer line up with their base triples.
_insert_at = part_indices_added["Part 5"].stop - 1
kg_added = kg_added_base[:_insert_at] + [_perturbation] + kg_added_base[_insert_at:]
# Guard: the part ranges must span kg_added exactly (assumes kg_added_base holds the 39 unperturbed triples).
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not cover kg_added"


In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff1
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("UDP-glucose:glycoprotein glucosyltransferase activity", "regulates", "protein_folding")- and assessing the consistency of generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("SFT2D2", "expression_present", "pancreatic_islets")- and assessing the consistency of the generated explanations.

In [None]:
# Perturbation for this section: one extra triple that belongs to Part 9.
_perturbation = ("SFT2D2", "expression_present", "pancreatic_islets")

# Part layout of the perturbed graph: Part 9 grows by one slot and Part 10 shifts by one.
part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 24),
    "Part 3": range(24, 25),
    "Part 4": range(25, 29),
    "Part 5": range(29, 31),
    "Part 6": range(31, 32),
    "Part 7": range(32, 33),
    "Part 8": range(33, 35),
    "Part 9": range(35, 38),  # 2 base + 1 perturbation
    "Part 10": range(38, 40)
}
part_names_added = list(part_indices_added.keys())

# Insert the perturbation into Part 9's last slot. Appending it (kg_added_base + [...]) leaves it at the
# final index, which these ranges assign to Part 10, and the shifted Part 10 range would no longer
# line up with its base triples.
_insert_at = part_indices_added["Part 9"].stop - 1
kg_added = kg_added_base[:_insert_at] + [_perturbation] + kg_added_base[_insert_at:]
# Guard: the part ranges must span kg_added exactly (assumes kg_added_base holds the 39 unperturbed triples).
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not cover kg_added"


In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff1
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("SFT2D2", "expression_present", "pancreatic_islets")- and assessing the consistency of generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("diabetic peripheral angiopathy", "caused_by", "chronic_hyperglycemia")- and assessing the consistency of generated explanations.

In [None]:
# Perturbation for this section: one extra triple that belongs to Part 10.
_perturbation = ("diabetic peripheral angiopathy", "caused_by", "chronic_hyperglycemia")

# Part layout of the perturbed graph: only Part 10 grows by one slot.
part_indices_added = {
    "Part 1": range(0, 21),
    "Part 2": range(21, 24),
    "Part 3": range(24, 25),
    "Part 4": range(25, 29),
    "Part 5": range(29, 31),
    "Part 6": range(31, 32),
    "Part 7": range(32, 33),
    "Part 8": range(33, 35),
    "Part 9": range(35, 37),
    "Part 10": range(37, 40)  # 2 base + 1 perturbation
}
part_names_added = list(part_indices_added.keys())

# Insert the perturbation into Part 10's last slot. For Part 10 that is the end of the list, so this
# matches a plain append while using the same pattern as the other perturbation cells.
_insert_at = part_indices_added["Part 10"].stop - 1
kg_added = kg_added_base[:_insert_at] + [_perturbation] + kg_added_base[_insert_at:]
# Guard: the part ranges must span kg_added exactly (assumes kg_added_base holds the 39 unperturbed triples).
assert len(kg_added) == part_indices_added["Part 10"].stop, "part ranges do not cover kg_added"



In [None]:
# Optional (disabled): visualize the graph partition.
# NOTE(review): `part_indices` is not defined in this section — likely portion_indices (original kg)
# or part_indices_added (with kg_added) was intended; confirm before re-enabling.
#visualize_graph_with_chains(kg, part_indices)

In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(original_answer_str)

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff1
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("diabetic peripheral angiopathy", "caused_by", "chronic_hyperglycemia")- and assessing the consistency of the generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 1) by introducing small perturbations in the input or graph structure -("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy")- and assessing the consistency of generated explanations.

In [None]:
# Perturbed graph for this section, written part by part. The perturbation triple
# ("Insulin tregopil", "used_for", "type_2_diabetes") sits inside Part 6, not at the end.
_kg_added_parts = {
    # Part 1 (20 original + 1 augmentation)
    "Part 1": [
        ("Neonatal insulin-dependent diabetes mellitus", "associated with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "ppi", "Neonatal insulin-dependent diabetes mellitus"),
        ("Neonatal insulin-dependent diabetes mellitus", "associated with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "ppi", "Neonatal insulin-dependent diabetes mellitus"),
        ("Neonatal insulin-dependent diabetes mellitus", "ppi", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "associated with", "Neonatal insulin-dependent diabetes mellitus"),
        ("Neonatal insulin-dependent diabetes mellitus", "ppi", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "associated with", "Neonatal insulin-dependent diabetes mellitus"),
        ("pancreatic A cell fate commitment", "ppi", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "associated with", "pancreatic A cell fate commitment"),
        ("pancreatic serous cystadenocarcinoma", "interacts with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "associated with", "pancreatic serous cystadenocarcinoma"),
        ("pancreatic A cell fate commitment", "associated with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "ppi", "pancreatic A cell fate commitment"),
        ("pancreatic serous cystadenocarcinoma", "interacts with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "associated with", "pancreatic serous cystadenocarcinoma"),
        ("pancreatic serous cystadenocarcinoma", "associated with", "insulin-like growth factor receptor binding"),
        ("insulin-like growth factor receptor binding", "interacts with", "pancreatic serous cystadenocarcinoma"),
        ("glucagon receptor activity", "ppi", "low-affinity glucose:proton symporter activity"),
        ("low-affinity glucose:proton symporter activity", "associated with", "glucagon receptor activity"),
        ("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy"),  # Common augmentation
    ],
    "Part 2": [
        ("vasoconstriction of artery involved in baroreceptor response to lowering of systemic arterial blood pressure", "expression present", "glucose binding"),
        ("glucose binding", "expression present", "Neurofibrillary tangles"),
        ("Neurofibrillary tangles", "ppi", "glucose binding"),
    ],
    "Part 3": [
        ("Flurandrenolide", "synergistic interaction", "Insulin peglispro"),
    ],
    "Part 4": [
        ("glucose 1-phosphate phosphorylation", "ppi", "mitochondrial genome maintenance"),
        ("mitochondrial genome maintenance", "expression present", "glucose 1-phosphate phosphorylation"),
        ("glucose 1-phosphate phosphorylation", "expression present", "mitochondrial genome maintenance"),
        ("mitochondrial genome maintenance", "ppi", "glucose 1-phosphate phosphorylation"),
    ],
    "Part 5": [
        ("UDP-glucose:glycoprotein glucosyltransferase activity", "ppi", "serotonin:sodium symporter activity"),
        ("serotonin:sodium symporter activity", "associated with", "UDP-glucose:glycoprotein glucosyltransferase activity"),
    ],
    "Part 6": [  # 1 base + 1 perturbation
        ("Insulin tregopil", "synergistic interaction", "Cyclothiazide"),
        ("Insulin tregopil", "used_for", "type_2_diabetes"),  # Perturbation triple
    ],
    "Part 7": [
        ("UDP-glucose transmembrane transporter activity", "associated with", "protein-DNA-RNA complex remodeling"),
    ],
    "Part 8": [
        ("Severe intrauterine growth retardation", "phenotype present", "glucose-1-phosphate thymidylyltransferase activity"),
        ("glucose-1-phosphate thymidylyltransferase activity", "ppi", "CCL4L1"),
    ],
    "Part 9": [
        ("SFT2D2", "expression present", "adrenal gland"),
        ("SFT2D2", "expression present", "deltoid"),
    ],
    "Part 10": [
        ("diabetic peripheral angiopathy", "associated with", "calcium-release channel activity"),
        ("diabetic peripheral angiopathy", "associated with", "acute myeloid leukemia with t(8;21)(q22;q22) translocation"),
    ],
}

# Flatten the parts in order and derive each part's contiguous index range from its length,
# so the ranges cannot drift out of sync with the triples.
kg_added = []
part_indices_added = {}
for _part_name, _part_triples in _kg_added_parts.items():
    part_indices_added[_part_name] = range(len(kg_added), len(kg_added) + len(_part_triples))
    kg_added.extend(_part_triples)
part_names_added = list(part_indices_added.keys())
del _kg_added_parts, _part_name, _part_triples



In [None]:
temp = 1
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff1 (presumably the unperturbed-graph coefficients at temperature 1)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff1
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)

## Stability: Evaluate the explanation (Temp 0) by introducing small perturbations in the input or graph structure -("Neonatal insulin-dependent diabetes mellitus", "treated_with", "insulin_therapy")- and assessing the consistency of the generated explanations.

In [None]:
temp = 0
# Reference answer and embedding at this temperature; the coefficient cell below compares against them.
# NOTE(review): this queries `graph`, not anything built from kg_added — confirm that is the intended baseline.
original_answer_str, original_answer_embedding = get_answer_and_embedding(question, temp, graph)
print(f"Original answer: {original_answer_str}")

In [None]:
# Explanation coefficients for the perturbed graph. The part ranges must index kg_added, so pass
# part_indices_added; portion_indices is the partition of the unperturbed kg (see the plot cells).
coeff = calculate_coefficients_print_Temerature(
    temp=temp,
    original=original,
    kg=kg_added,
    part_names=part_names_added,
    portion_indices=part_indices_added,
    question=question,
    original_answer_embedding=original_answer_embedding,
    original_answer_str=original_answer_str,
)

In [None]:
# Inspect the perturbed-graph coefficients computed in the previous cell.
coeff

In [None]:
# Side-by-side comparison: coeff0 (presumably the unperturbed-graph coefficients at temperature 0)
# against the coefficients just computed on the perturbed graph.
coeff_added = coeff
coeff_original = coeff0
plot_knowledge_graph_explainability_compare(
    kg_original=kg,
    part_indices_original=portion_indices,
    coeff_original=coeff_original,
    kg_added=kg_added,
    part_indices_added=part_indices_added,
    coeff_added=coeff_added,
)

In [None]:
# Jaccard index between the original and perturbed explanations (definition lives in calculate_jaccard_index).
jaccard_index = calculate_jaccard_index(coeff_original, coeff_added)
print("Jaccard Index:", jaccard_index)