In [23]:
import re
import requests
import numpy as np
from rdflib import Graph
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# --- Configuration ---
GTOKEN = "AIzaSyDQQY8FmOW6erFivgwsHjAdf419PYddNis"
genai.configure(api_key=GTOKEN)
GRAPHDB_ENDPOINT = "http://Vishals-MacBook-Air.local:7200/repositories/thesis"
TBOX_PATH = "model_card.ttl"
ENCODER_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
GEMINI_MODEL = genai.GenerativeModel("gemini-1.5-flash")

# --- Load T-Box ---
kg = Graph()
kg.parse(TBOX_PATH, format="turtle")

# --- KG Prefix Extraction ---
PREFIXES = """
PREFIX mcro: <http://purl.obolibrary.org/obo/mcro.owl#>
PREFIX dul: <http://www.ontologydesignpatterns.org/ont/dul/DUL.owl#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
"""

# --- T-Box Query Functions ---
def get_classes_properties():
    query = """
    SELECT DISTINCT ?resource ?label ?comment WHERE {
        { ?resource a rdfs:Class }
        UNION
        { ?resource a rdf:Property }
        OPTIONAL { ?resource rdfs:label ?label }
        OPTIONAL { ?resource rdfs:comment ?comment }
    }
    """
    return kg.query(query)

# --- Core Functions ---
def get_kg_subgraph(nl_question: str) -> str:
    tokens = set(re.findall(r'\b\w+\b', nl_question.lower()))
    matched_resources = []
    
    # Add class/property matching
    for row in get_classes_properties():
        uri = str(row.resource)
        label = str(row.label).lower() if row.label else uri.split("#")[-1].lower()
        comment = str(row.comment).lower() if row.comment else ""
        
        # Exact match boosting
        if any(token == label.split("#")[-1] for token in tokens):
            score = 2.0
        else:
            score = sum(1 for token in tokens if token in label or token in comment)
        
        if score > 0:
            matched_resources.append({
                "uri": uri,
                "label": label,
                "comment": comment,
                "score": score
            })
    
    # Sort by relevance
    matched_resources.sort(key=lambda x: -x['score'])
    
    return "\n".join([
        f"- {res['label']} ({res['comment']})" 
        for res in matched_resources[:5]
    ])

def get_similar_examples(nl_question: str, k=3) -> str:
    EXAMPLE_DATASET = [
        {
            "nl": "Show all models",
            "sparql": f"""
            {PREFIXES}
            SELECT DISTINCT ?model WHERE {{
                ?model a mcro:Model .
            }}
            """
        },
        {
            "nl": "Show models using Vision Transformer architecture",
            "sparql": f"""
            {PREFIXES}
            SELECT ?model WHERE {{
                ?model mcro:hasArchitecture ?arch .
                ?arch dul:hasParameterDataValue "Vision Transformer (ViT)" .
            }}
            """
        },
        {
            "nl": "Find models with accuracy over 90%",
            "sparql": f"""
            {PREFIXES}
            SELECT ?model ?accuracy WHERE {{
                ?model mcro:hasEvaluationScores ?scores .
                ?scores dul:hasParameterDataValue ?data .
                BIND(REPLACE(STR(?data), '[^0-9.]+', '') AS ?acc_str)
                BIND(xsd:decimal(?acc_str) AS ?accuracy)
                FILTER(?accuracy > 90)
            }}
            """
        }
    ]
    
    query_embedding = ENCODER_MODEL.encode(nl_question)
    example_embeddings = [ENCODER_MODEL.encode(ex["nl"]) for ex in EXAMPLE_DATASET]
    
    similarities = cosine_similarity([query_embedding], example_embeddings)[0]
    top_indices = np.argsort(similarities)[-k:][::-1]
    
    examples = []
    for i in top_indices:
        examples.append(f"Question: {EXAMPLE_DATASET[i]['nl']}\nSPARQL: {EXAMPLE_DATASET[i]['sparql']}")
    
    return "\n".join(examples)

def text_to_sparql(nl_question: str) -> str:
    kg_context = get_kg_subgraph(nl_question)
    examples = get_similar_examples(nl_question)
    
    prompt = f"""
You are a SPARQL expert. Generate a query for this KG schema:

KG Schema:
{kg_context}

Prefixes:
{PREFIXES}

Examples:
{examples}

Question: {nl_question}

Rules:
1. Use SELECT DISTINCT for model queries
2. mcro:Model is the base class for ML models
3. For numeric filters: BIND(REPLACE(STR(?value), '[^0-9.]', '') AS ?num)
4. Use FILTER with xsd:decimal for comparisons
5. Always include PREFIX declarations
6. Only return valid SPARQL in ```sparql blocks
"""
    
    response = GEMINI_MODEL.generate_content(prompt)
    raw_sparql = response.text
    
    if "```sparql" in raw_sparql:
        return re.search(r"```sparql(.*?)```", raw_sparql, re.DOTALL).group(1).strip()
    return raw_sparql

def execute_sparql(sparql: str) -> dict:
    try:
        response = requests.post(
            GRAPHDB_ENDPOINT,
            headers={
                "Accept": "application/sparql-results+json",
                "Content-Type": "application/sparql-query"
            },
            data=sparql.encode('utf-8'),
            timeout=30  # Increased timeout
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        error_details = {
            "status_code": e.response.status_code,
            "message": e.response.text,
            "query": sparql
        }
        return {"error": error_details}
    except Exception as e:
        return {"error": str(e), "query": sparql}
    
def sparql_to_nl(results: dict) -> str:
    if "error" in results:
        error = results["error"]
        return f"Error {error.get('status_code', '')}: {error.get('message', 'Unknown error')}"
    
    bindings = results.get("results", {}).get("bindings", [])
    if not bindings:
        return "No results found"
    
    output = []
    for item in bindings:
        model_uri = item.get("model", {}).get("value", "")
        model_name = model_uri.split("/")[-1].replace("mcro_", "")
        
        # Handle additional fields
        accuracy = item.get("accuracy", {}).get("value", "")
        if accuracy:
            model_name += f" (Accuracy: {accuracy}%)"
        
        output.append(f"- {model_name}")
    
    return "\n".join(sorted(output)) if output else "No models found"

# --- Main Workflow ---
if __name__ == "__main__":
    question = "Show all models"
    
    print("Generating SPARQL query...")
    sparql = text_to_sparql(question)
    print("\nGenerated SPARQL:\n", sparql)
    
    if "SELECT" in sparql.upper():
        print("\nExecuting query...")
        results = execute_sparql(sparql)
        answer = sparql_to_nl(results)
        print("\nResults:\n", answer)
    else:
        print("Invalid SPARQL query generated")

Generating SPARQL query...

Generated SPARQL:
 PREFIX mcro: <http://purl.obolibrary.org/obo/mcro.owl#>
PREFIX dul: <http://www.ontologydesignpatterns.org/ont/dul/DUL.owl#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>

SELECT DISTINCT ?model WHERE {
    ?model a mcro:Model .
}

Executing query...

Results:
 - mcro.owl#Falconsainsfwimagedetection
- mcro.owl#dima806fairfaceageimagedetection
- mcro.owl#googlebertbertbaseuncased
- mcro.owl#openaiclipvitlargepatch14
- mcro.owl#sentencetransformersallMiniLML6v2
- mcro.owl#timmmobilenetv3small100lambin1k
