In [2]:
import numpy as np

# %%
import sys
import os

sys.path.insert(0, os.path.abspath("../.."))
sys.path.append(os.path.abspath("../../backend"))
sys.path.append(os.path.abspath(""))

from rdflib.plugins.stores.sparqlstore import SPARQLStore

from backend.ontology import OntologyManager, OntologyConfig, Graph
from backend.explorative.explorative_support import GuidanceManager
from backend.explorative.llm_query import (
    EnrichedEntitiesRelations,
    LLMQuery,
    QueryProgress,
)
from tqdm import tqdm
import pandas as pd

In [3]:
from typing import Optional

import networkx as nx
from seqeval import metrics
from sklearn.metrics import f1_score

In [9]:
# Connection and model settings for this evaluation run.
SPARQL_ENDPOINT = "http://localhost:7012/"
LLM_MODEL_ID = "unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF"
LLM_QUANT_MODEL = "*Q6_K.gguf"

# NOTE(review): `infer` / `sameAs` look like server-side flags disabling inference
# and sameAs expansion (GraphDB/RDF4J style) — confirm against the running server.
store = SPARQLStore(
    SPARQL_ENDPOINT,
    method="POST_FORM",
    params=dict(infer=False, sameAs=False),
)
graph = Graph(store=store)
config = OntologyConfig()
ontology_manager = OntologyManager(config, graph)

# The guidance manager owns the LLM; LLMQuery drives query extraction through it.
topic_man = GuidanceManager(
    ontology_manager,
    llm_model_id=LLM_MODEL_ID,
    llm_quant_model=LLM_QUANT_MODEL,
)
query_man = LLMQuery(topic_man)

In [10]:
# Load the generated evaluation set; the `erl` column holds each target
# EnrichedEntitiesRelations serialized as JSON, so parse it back into a model.
generated_queries = pd.read_csv("llama_examples.csv", index_col=0)
generated_queries["erl"] = generated_queries["erl"].map(
    EnrichedEntitiesRelations.model_validate_json
)
generated_queries

Unnamed: 0,erl,response,generator,n_nodes,seed
0,relations=[EnrichedRelation(entity='Organisati...,A member of an organization died in a stream a...,llama,3,0
1,relations=[EnrichedRelation(entity='Organisati...,a Organisation member has a death place with s...,templated,3,0
2,relations=[EnrichedRelation(entity='written wo...,A multi-volume publication has volumes of a wr...,llama,3,1
3,relations=[EnrichedRelation(entity='written wo...,"a written work has a starring with actor, and ...",templated,3,1
4,"relations=[EnrichedRelation(entity='work', rel...",A work features an actor and is directed by a ...,llama,3,2
...,...,...,...,...,...
1795,relations=[EnrichedRelation(entity='ethnic gro...,a ethnic group has a population place with pop...,templated,10,297
1796,relations=[EnrichedRelation(entity='Organisati...,An Organisation member has nationality of a co...,llama,10,298
1797,relations=[EnrichedRelation(entity='Organisati...,a Organisation member has a nationality with c...,templated,10,298
1798,"relations=[EnrichedRelation(entity='person', r...",- A person is born in a populated place.\n- A ...,llama,10,299


In [11]:
# Smoke test: run the extraction pipeline on a single generated query (row position 1)
# before launching the full evaluation.
generated_query = generated_queries.iloc[1]
target_erl: EnrichedEntitiesRelations
target_erl, query = generated_query["erl"], generated_query["response"]
progress = QueryProgress(start_time="0", max_steps=1, id="0")
query_man.run_query(progress=progress, query=query)


Loading LLM model unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF None


(…)al-Small-3.1-24B-Instruct-2503-Q6_K.gguf:   0%|          | 0.00/19.3G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# Side-by-side entity types: (target, retrieved), each wrapped in an outer list.
tuple(
    [[entity.type for entity in erl.entities]]
    for erl in (target_erl, progress.enriched_relations)
)

([['written work', 'actor', 'multi volume publication']],
 [['work', 'actor', 'multi volume publication']])

In [None]:
def graph_from_erl(erl: EnrichedEntitiesRelations):
    """Build a labelled ``nx.DiGraph`` from an ``EnrichedEntitiesRelations``.

    Nodes:
        One node per entity, keyed by ``entity.identifier``, with attribute
        ``label=entity.type``.

    Edges:
        One directed edge ``relation.entity -> relation.target`` per relation.
        Each edge carries ``weight=relation.link.instance_count`` and
        ``label=relation.relation``.

    Relation endpoints that are not listed in ``erl.entities`` are created
    implicitly by networkx and therefore have no ``label`` attribute.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from(
        (entity.identifier, {"label": entity.type}) for entity in erl.entities
    )
    digraph.add_edges_from(
        (
            relation.entity,
            relation.target,
            {"weight": relation.link.instance_count, "label": relation.relation},
        )
        for relation in erl.relations
    )
    return digraph

In [None]:
def f1k(y_true, y_pred, k: Optional[int] = None) -> float:
    """Set-based F1 score of the first ``k`` predictions against the ground truth.

    Both inputs are reduced to sets, so duplicate items are counted once and
    order only matters for the ``k`` cutoff.

    Args:
        y_true: Relevant items (e.g. target entity types).
        y_pred: Retrieved items as a sliceable sequence (list/tuple).
        k: Only ``y_pred[:k]`` is considered; ``None`` uses every prediction.

    Returns:
        The harmonic mean of precision and recall as a float. Returns ``0.0``
        when no predicted item is relevant, including when either input is empty.
    """
    relevant = set(y_true)
    retrieved = set(y_pred[:k])
    tp = len(retrieved & relevant)  # retrieved and relevant
    fp = len(retrieved - relevant)  # retrieved but not relevant (false positives)
    fn = len(relevant - retrieved)  # relevant but not retrieved (misses)
    if tp == 0:
        # No overlap: precision and recall are both 0, so F1 is 0 (also avoids 0/0).
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

In [None]:
# Entity-level F1 for the smoke-test query: target vs. retrieved entity types.
f1_score_ents = f1k(
    y_true=[entity.type for entity in target_erl.entities],
    y_pred=[entity.type for entity in progress.enriched_relations.entities],
)
f1_score_ents


0.6666666666666666

In [None]:
target_graph = graph_from_erl(target_erl)
retrieved_grap = graph_from_erl(progress.enriched_relations)


def node_match(a, b) -> bool:
    """Match two node (or edge) attribute dicts by their ``label`` attribute.

    Returns False when either side has no ``label``. This covers, e.g., a node
    that ``graph_from_erl`` only created implicitly as a relation endpoint.
    Such elements never count as matching.
    """
    return "label" in a and "label" in b and a["label"] == b["label"]


# Edges built by graph_from_erl also carry a ``label`` (the relation), so the same rule applies.
edge_match = node_match

# Cost of deleting every node/edge of one graph and inserting every node/edge of
# the other: an upper bound on the (unit-cost) graph edit distance.
max_edit_cost = (
    len(target_graph.nodes)
    + len(target_graph.edges)
    + len(retrieved_grap.nodes)
    + len(retrieved_grap.edges)
)
# Normalized similarity in [0, 1]; two empty graphs are identical (and 0/0 would raise).
(
    1.0
    if max_edit_cost == 0
    else 1
    - nx.graph_edit_distance(
        target_graph, retrieved_grap, edge_match=edge_match, node_match=node_match
    )
    / max_edit_cost
)

0.7

In [None]:
def run_eval(query: pd.Series, llm_query: LLMQuery):
    """Run one generated query through the LLM pipeline and score the result.

    Args:
        query: A row of ``generated_queries``. It must have an ``erl`` column
            (the target ``EnrichedEntitiesRelations``) and a ``response`` column
            (the query text passed to ``llm_query.run_query``).
        llm_query: The ``LLMQuery`` instance that extracts entities/relations.

    Returns:
        ``(f1_score_ents, edit_distance)``:
        - ``f1_score_ents`` is the set-based F1 over entity types (see ``f1k``).
        - ``edit_distance`` is the graph *similarity*
          ``1 - GED / (|V_t| + |E_t| + |V_r| + |E_r|)``. It lies in ``[0, 1]``,
          is 1.0 when both graphs are empty, and higher is better despite the name.
    """
    target_erl: EnrichedEntitiesRelations = query["erl"]
    query_text = query["response"]
    progress = QueryProgress(id="0", max_steps=1, start_time="0")
    llm_query.run_query(query=query_text, progress=progress)
    retrieved_erl = progress.enriched_relations

    f1_score_ents = f1k(
        [ent.type for ent in target_erl.entities],
        [ent.type for ent in retrieved_erl.entities],
    )

    target_graph = graph_from_erl(target_erl)
    retrieved_graph = graph_from_erl(retrieved_erl)
    # Deleting one graph entirely and inserting the other bounds the unit-cost GED.
    max_edit_cost = (
        len(target_graph.edges)
        + len(retrieved_graph.edges)
        + len(target_graph.nodes)
        + len(retrieved_graph.nodes)
    )
    if max_edit_cost == 0:
        # Both graphs are empty, hence identical; avoids a ZeroDivisionError that
        # would otherwise abort the whole evaluation loop.
        edit_distance = 1.0
    else:
        # NOTE(review): exact GED is exponential in graph size; consider `timeout=`
        # for the larger (e.g. 10-node) queries if runs stall here.
        ged = nx.graph_edit_distance(
            target_graph, retrieved_graph, edge_match=edge_match, node_match=node_match
        )
        edit_distance = 1 - ged / max_edit_cost
    return f1_score_ents, edit_distance


def run_evals(queries: pd.DataFrame, llm_query: LLMQuery):
    """Score every generated query and return a copy of ``queries`` with the scores.

    Args:
        queries: Frame of generated queries; each row is passed to ``run_eval``.
        llm_query: The ``LLMQuery`` instance to evaluate.

    Returns:
        A copy of ``queries`` with two extra float columns, ``f1_score`` and
        ``edit_distance``. ``edit_distance`` holds the normalized graph similarity
        returned by ``run_eval``.
    """
    f1_scores = []
    edit_distances = []
    for _, row in tqdm(queries.iterrows(), total=len(queries)):
        f1_score_ents, edit_distance = run_eval(row, llm_query)
        f1_scores.append(float(f1_score_ents))
        edit_distances.append(float(edit_distance))
    # Assign positionally rather than via `results.loc[label, ...]`: label-based
    # writes would hit every row sharing a label if the index has duplicates.
    return queries.assign(f1_score=f1_scores, edit_distance=edit_distances)


# Earlier runs of this cell wrote to the misspelled "restuls/" directory and
# swapped the oneshot/zeroshot labels; results now go to RESULTS_DIR.
RESULTS_DIR = "results"
# Create the output directory up front so a multi-hour run cannot fail at to_csv.
os.makedirs(RESULTS_DIR, exist_ok=True)

for zero_shot in [True, False]:
    # Assumes LLMQuery(zero_shot=True) means zero-shot prompting (no in-context
    # example) — TODO confirm against LLMQuery.
    prompt_mode = "zeroshot" if zero_shot else "oneshot"
    print(f"zero_shot={zero_shot} ({prompt_mode})")
    query_man = LLMQuery(topic_man, zero_shot=zero_shot)
    results = run_evals(generated_queries, llm_query=query_man)
    model_tag = topic_man.llm_model_id.replace("/", "-")
    results.to_csv(
        os.path.join(RESULTS_DIR, f"eval_results_{prompt_mode}_{model_tag}.csv")
    )

oneshot=True


  0%|          | 0/897 [00:00<?, ?it/s]

Llama.generate: 6 prefix-match hit, remaining 120 prompt tokens to eval
llama_perf_context_print:        load time =     308.59 ms
llama_perf_context_print: prompt eval time =       0.00 ms /   120 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /   264 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    8165.69 ms /   384 tokens
Llama.generate: 6 prefix-match hit, remaining 98 prompt tokens to eval
llama_perf_context_print:        load time =     308.59 ms
llama_perf_context_print: prompt eval time =       0.00 ms /    98 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /   325 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    9485.19 ms /   423 tokens
  0%|          | 1/897 [00:20<5:00:07, 20.10s/it]Llama.generate: 6 prefix-match hit, re

KeyboardInterrupt: 