In [None]:
import pandas as pd
import numpy as np
import sys
# Добавляем путь к модулям проекта
sys.path.append('/home/igor/projects/CollisonDetection')

# Импорты из проекта
from src.graph_rag import custom_embedder, custom_llm
from src.llm_service import YandexCloudLLM
from src.fact_checker import FactConsistencyChecker
from llama_index.core import PropertyGraphIndex, Document
from llama_index.graph_stores.neo4j import Neo4jPGStore
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.retrievers import VectorContextRetriever, LLMSynonymRetriever
from llama_index.core import QueryBundle
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.prompts import PromptTemplate
import tokens


In [4]:
# Load the evaluation claims; the frame is expected to carry `query` and `label` columns
query = pd.read_csv(
    filepath_or_buffer='/home/igor/projects/CollisonDetection/test/query_30.csv',
)
query.head(n=3)


Unnamed: 0,query,evidence_wiki_url,label
0,Iowa is a part of the Midwestern United States.,Midwestern_United_States,SUPPORTS
1,South Korea has a highly educated workforce.,South_Korea,SUPPORTS
2,Michael Fassbender is an actor.,Michael_Fassbender,SUPPORTS


In [None]:
# Connect to the Neo4j property graph store and build an index over the existing graph
graph_store = Neo4jPGStore(
    url=tokens.NEO4J_URL,
    username=tokens.NEO4J_USERNAME,
    password=tokens.NEO4J_PASSWORD,
)
graph_index = PropertyGraphIndex.from_existing(
    property_graph_store=graph_store,
    llm=custom_llm,
    embed_model=custom_embedder,
    include_embeddings=True,
)

In [None]:
# Retrieve and rerank the graph nodes relevant to a query (taken from main.py)
def get_retrieved_nodes(index, query_str, vector_top_k=30, reranker_top_n=5):
    """Retrieve nodes for ``query_str`` from ``index`` and rerank them with an LLM.

    Two sub-retrievers are combined through ``index.as_retriever``: an
    ``LLMSynonymRetriever`` and a ``VectorContextRetriever``, both reading
    from ``index.property_graph_store``. The retrieved nodes are then passed
    through ``LLMRerank`` with a custom choice-select prompt.

    Args:
        index: Index exposing ``property_graph_store``, ``vector_store`` and
            ``as_retriever``.
        query_str: The query (claim) text.
        vector_top_k: Forwarded as ``similarity_top_k`` to the vector retriever.
        reranker_top_n: Forwarded as ``top_n`` to ``LLMRerank``.

    Returns:
        Whatever ``LLMRerank.postprocess_nodes`` returns for the retrieved nodes.
    """
    bundle = QueryBundle(query_str)
    pg_store = index.property_graph_store

    keyword_retriever = LLMSynonymRetriever(
        graph_store=pg_store,
        llm=custom_llm,
        include_text=True,
        max_keywords=8,
        path_depth=5,
    )
    vector_retriever = VectorContextRetriever(
        graph_store=pg_store,
        vector_store=index.vector_store,
        embed_model=custom_embedder,
        include_text=True,
        similarity_top_k=vector_top_k,
        path_depth=5,
    )

    combined = index.as_retriever(
        sub_retrievers=[keyword_retriever, vector_retriever],
        include_text=True,
    )
    candidates = combined.retrieve(bundle)

    # Rerank prompt: asks for relevance regardless of whether a document
    # supports or contradicts the query.
    rerank_template = "".join((
        "A list of documents is shown below. Each document has a number next to it along with a summary of the document. A question is also provided. \n",
        "Respond with the numbers of the documents you should consult to answer the question, in order of relevance, as well as the relevance score. ",
        "The relevance score is a number from 1-10 based on how relevant you think the document is to the question.\n",
        "Prioritize documents based on their relevance to the question, regardless of whether they support or contradict the query. ",
        "Both confirming and contradicting facts are considered equally relevant if they provide significant information, context, or arguments related to the question.\n",
        "Example format: \n",
        "Document 1:\n<summary of document 1>\n\nDocument 2:\n<summary of document 2>\n\n...\n\n",
        "Question: <question>\nAnswer:\nDoc: 9, Relevance: 7\nDoc: 3, Relevance: 4\n\n",
        "Let's try this now: \n\n{context_str}\nQuestion: {query_str}\nAnswer:\n",
    ))
    select_prompt = PromptTemplate(template=rerank_template)
    llm_reranker = LLMRerank(
        llm=custom_llm,
        choice_batch_size=5,
        top_n=reranker_top_n,
        choice_select_prompt=select_prompt,
    )
    return llm_reranker.postprocess_nodes(candidates, bundle)

# Set up the fact checker on top of the Yandex Cloud LLM service
llm_service = YandexCloudLLM(
    model_uri="yandexgpt-lite",
    temperature=0.1,
    api_key=tokens.AUTH_TOKEN,
    folder_id=tokens.FOLDER_ID,
)
checker = FactConsistencyChecker(language="en", llm_service=llm_service)


In [None]:
# Run retrieval + fact checking for every query and collect one record per
# successfully processed query.
# NOTE: queries that raise ValueError are skipped and therefore do not appear
# in results_df; any metric computed from results_df excludes them.
results = []
for q, true_label in zip(query['query'], query['label']):
    try:
        # Fetch the reranked nodes and reduce them to plain-text facts
        nodes = get_retrieved_nodes(graph_index, q)
        facts = [str(node.node.get_text()) for node in nodes]

        # Check the claim against the retrieved facts
        result = checker.check_facts(q, facts)

        # Derive the predicted label; a detected conflict takes priority
        # over supporting facts.
        if result.has_conflicts:
            pred_label = "REFUTES"
        elif result.has_supporting_facts:
            pred_label = "SUPPORTS"
        else:
            pred_label = "NOT ENOUGH INFO"

        results.append({
            'query': q,
            'true_label': true_label,
            'predicted_label': pred_label,
            'confidence': result.confidence
        })
    except ValueError as e:
        # Best-effort evaluation: report the failure and move on to the next
        # query (presumably raised by the retrieval/rerank step — TODO confirm).
        print(f"Ошибка при обработке запроса '{q}': {str(e)}. Пропускаем.")

# Explicit columns keep the frame's schema stable even when every query was
# skipped; without them an empty result list yields a column-less frame and
# later column lookups raise KeyError.
results_df = pd.DataFrame(
    results,
    columns=['query', 'true_label', 'predicted_label', 'confidence'],
)
print(f"\n✓ Обработано {len(results_df)} запросов")


In [99]:
# Convert the label columns to plain lists. An empty results frame (every
# query skipped) may have no columns at all, so handle it explicitly instead
# of failing on the column lookup.
if results_df.empty:
    y_true, y_pred = [], []
else:
    y_true = list(results_df['true_label'])
    y_pred = list(results_df['predicted_label'])

# Class treated as "positive"; every other label counts as negative
target_class = 'SUPPORTS'

# One-vs-rest confusion counts for the target class
label_pairs = list(zip(y_true, y_pred))
tp = sum(1 for yt, yp in label_pairs if yt == target_class and yp == target_class)
fp = sum(1 for yt, yp in label_pairs if yt != target_class and yp == target_class)
fn = sum(1 for yt, yp in label_pairs if yt == target_class and yp != target_class)
tn = sum(1 for yt, yp in label_pairs if yt != target_class and yp != target_class)

# All three ratios fall back to 0 when their denominator is zero, so an empty
# result set no longer raises ZeroDivisionError on the accuracy line.
n_samples = len(y_true)
accuracy = (tp + tn) / n_samples if n_samples > 0 else 0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0

print(f"Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
print(f"Recall:    {recall:.4f} ({recall*100:.2f}%)")



Accuracy:  0.9000 (90.00%)
Precision: 0.9483 (94.83%)
Recall:    0.8871 (88.71%)
