In [11]:
import json
import requests
import time
from getpass import getpass
from huggingface_hub import HfApi
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from tenacity import retry, stop_after_attempt, wait_exponential
from fuzzywuzzy import process

# ================== CONFIGURATION ==================
HF_TOKEN = getpass("Enter Hugging Face Token (or set HF_TOKEN environment variable): ")

# Initialize Hugging Face API client
hf_api = HfApi(token=HF_TOKEN)

# Ontology prefix to remove
ONTOLOGY_PREFIX = "http://purl.obolibrary.org/obo/mcro.owl#"

# Model name mapping cache
MODEL_NAME_CACHE = {}

# Output file
RESULTS_FILE = "validation_results.json"
# ===================================================

def load_and_clean_triples(file_path):
    """Load JSON triples and remove ontology prefix"""
    try:
        with open(file_path, 'r') as f:
            raw_triples = json.load(f)
        
        cleaned_triples = []
        for t in raw_triples:
            cleaned_triple = {
                key: value.replace(ONTOLOGY_PREFIX, '') 
                for key, value in t.items()
            }
            cleaned_triples.append(cleaned_triple)
        
        return cleaned_triples
    
    except FileNotFoundError:
        print(f"❌ ERROR: File {file_path} not found")
        return []

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, max=10))
def auto_map_model(local_name, threshold=70):
    """Search HF models via direct API call"""
    if local_name in MODEL_NAME_CACHE:
        return MODEL_NAME_CACHE[local_name]
    
    try:
        url = f"https://huggingface.co/api/models?search={local_name[:30]}"
        response = requests.get(url, headers={"Authorization": f"Bearer {HF_TOKEN}"})
        
        if response.status_code != 200:
            raise Exception(f"API error: {response.status_code}")
            
        results = response.json()
        
        if not results:
            MODEL_NAME_CACHE[local_name] = (None, 0)
            return None, 0

        # Extract model IDs and match similarity
        candidates = [(m["id"], m["id"].lower()) for m in results]
        matches = [m[1] for m in candidates]

        best_match = process.extractOne(local_name.lower(), matches)

        if best_match and best_match[1] >= threshold:
            try:
                match_index = matches.index(best_match[0])
                hf_model_id = candidates[match_index][0]
                MODEL_NAME_CACHE[local_name] = (hf_model_id, best_match[1])
                return MODEL_NAME_CACHE[local_name]
            except ValueError:
                pass

        MODEL_NAME_CACHE[local_name] = (None, 0)
        return None, 0
        
    except Exception as e:
        print(f"Search error for '{local_name}': {e}")
        MODEL_NAME_CACHE[local_name] = (None, 0)
        return None, 0

def extract_numeric_value(metadata, keyword):
    """Extract numeric values from description/text fields"""
    text = str(metadata.get("description", "") + 
              " " + str(metadata.get("modelId"))).lower()
    
    import re
    patterns = {
        "batch_size": r"batch[\s_-]?size.*?(\d+)",
        "learning_rate": r"(?:lr|learning[\s_-]?rate).*?([0-9.eE\-]+)",
        "epochs": r"epoch[s]?:?\s*(\d+)"
    }
    
    if pattern := patterns.get(keyword):
        if match := re.search(pattern, text):
            return match.group(1)
    return None

# Map ontology predicates to Hugging Face API fields
PREDICATE_MAP = {
    # Direct field mappings
    "hasLicense": ("license", lambda m: m.get("license")),
    "hasTask": ("pipeline_tag", lambda m: m.get("pipeline_tag")),
    "hasAuthor": ("author", lambda m: m.get("author")),
    "hasDownloadCount": ("downloads", lambda m: m.get("downloads")),
    
    # Custom field handlers
    "hasArchitecture": ("architecture", lambda m: m.get("modelId").split("/")[-1].lower()),
    "hasDataset": ("dataset_name", lambda m: str(m.get("dataset_name", ""))),
    "hasTrainingDataset": ("train_dataset", lambda m: str(m.get("dataset_name", ""))),
    "hasClasses": ("classification", lambda m: "classification" in str(m.get("tags", []))),
    "hasBatchSize": ("batch_size", lambda m: extract_numeric_value(m, "batch_size")),
    "hasLearningRate": ("learning_rate", lambda m: extract_numeric_value(m, "learning_rate")),
    "hasEpochs": ("epochs", lambda m: extract_numeric_value(m, "epochs")),
    "hasMetric": ("metric", lambda m: any(tag in str(m.get("tags", [])).lower() for tag in ["accuracy", "f1", "bleu"])),
    "hasEvaluationAccuracy": ("accuracy", lambda m: "accuracy" in str(m.get("metrics", []))),
    "hasLanguage": ("language", lambda m: "en" in str(m.get("tags", []))),
    "hasModelDate": ("date", lambda m: m.get("last_modified").split("T")[0]),  # YYYY-MM-DD format
}

def validate_hf_triple(triple, max_retries=3):
    """Validate triples against Hugging Face API when possible"""
    subject = triple['s']
    predicate = triple['p']
    expected_value = triple['o'].lower()
    
    # Skip ontology-only statements
    if predicate in ["rdf:type", "rdfs:subClassOf"]:
        return {
            **triple,
            "status": "SKIPPED",
            "reason": "Ontology statement"
        }
    
    # Get mapping info
    mapping_info = PREDICATE_MAP.get(predicate)
    if not mapping_info:
        return {
            **triple,
            "status": "UNMAPPED",
            "reason": f"Unknown predicate: {predicate}"
        }
    
    field_name, value_extractor = mapping_info
    
    # Map model name
    if subject in MODEL_NAME_CACHE:
        hf_model_id, confidence = MODEL_NAME_CACHE[subject]
    else:
        hf_model_id, confidence = auto_map_model(subject)
    
    if not hf_model_id:
        return {
            **triple,
            "status": "UNMAPPED",
            "reason": f"No HF mapping found for '{subject}' (confidence: {confidence})"
        }
    
    # Get model metadata
    try:
        metadata = safe_get_metadata(hf_model_id)
        if not metadata:
            return {
                **triple,
                "status": "ERROR",
                "reason": "Failed to fetch metadata after retries"
            }
            
        actual_value = value_extractor(metadata)
        result = str(expected_value).lower() == str(actual_value).lower() if actual_value else False
        
        return {
            **triple,
            "status": result,
            "actual_value": actual_value,
            "field": field_name,
            "hf_model_id": hf_model_id,
            "confidence": confidence
        }
        
    except Exception as e:
        return {
            **triple,
            "status": "ERROR",
            "reason": str(e),
            "hf_model_id": hf_model_id
        }

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, max=10))
def safe_get_metadata(model_id):
    """Get model metadata with retry logic"""
    url = f"https://huggingface.co/api/models/{model_id}"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    
    response = requests.get(url, headers=headers, timeout=10)
    
    if response.status_code == 429:  # Rate limited
        wait = int(response.headers.get('Retry-After', '5'))
        print(f"⏳ Rate limited. Waiting {wait}s...")
        time.sleep(wait)
        raise Exception("Rate limited")
        
    if response.status_code != 200:
        raise Exception(f"HTTP {response.status_code}: {response.text[:100]}")
        
    return response.json()

def batch_validate_triples(triples, batch_size=50, max_workers=10):
    """Process triples in batches with parallel execution"""
    results = []
    total = len(triples)
    
    for i in range(0, total, batch_size):
        batch = triples[i:i+batch_size]
        print(f"\nProcessing batch {i//batch_size + 1} ({i+1}-{min(i+batch_size, total)})")
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for triple in batch:
                futures.append(executor.submit(validate_hf_triple, triple))
            
            for future in futures:
                results.append(future.result())
                
        time.sleep(5)  # Rate limit buffer
        
    return results

def generate_report(results):
    """Generate summary report focusing on accuracy metrics"""
    valid = [r for r in results if r["status"] is True]
    invalid = [r for r in results if r["status"] is False]
    skipped = [r for r in results if r["status"] == "SKIPPED"]
    unmapped = [r for r in results if r["status"] == "UNMAPPED"]
    errors = [r for r in results if r["status"] == "ERROR"]
    
    # Calculate match confidence statistics
    mapped_results = [r for r in results if r.get("confidence", 0) > 0]
    confidences = [r["confidence"] for r in mapped_results]
    avg_confidence = sum(confidences)/len(confidences) if confidences else 0
    
    # Calculate accuracy by predicate
    predicate_stats = {}
    for r in results:
        pred = r['p']
        if pred not in predicate_stats:
            predicate_stats[pred] = {"total": 0, "valid": 0}
        predicate_stats[pred]["total"] += 1
        if r["status"] is True:
            predicate_stats[pred]["valid"] += 1

    # Calculate semantic accuracy (excluding skipped)
    semantic_results = [r for r in results if r["status"] not in ["SKIPPED", "UNMAPPED", "ERROR"]]
    semantic_valid = [r for r in semantic_results if r["status"] is True]
    semantic_accuracy = len(semantic_valid)/len(semantic_results) * 100 if semantic_results else 0
    
    # Print detailed accuracy report
    print("\n📊 SEMANTIC ACCURACY REPORT")
    print("="*50)
    print(f"Overall Semantic Accuracy: {semantic_accuracy:.1f}%")
    print(f"Average Match Confidence: {avg_confidence:.1f}/100")
    print(f"Total Valid Matches: {len(valid)}")
    print(f"Total Invalid Matches: {len(invalid)}")
    print(f"Unmapped Models: {len(unmapped)}")
    print(f"Validation Errors: {len(errors)}")
    
    print("\n📈 PREDICATE-SPECIFIC ACCURACY:")
    for pred, stats in sorted(predicate_stats.items()):
        acc = stats["valid"]/stats["total"] * 100
        print(f"{pred:<20} {acc:.1f}% ({stats['valid']}/{stats['total']})")
    
    print("\n🔍 CONFIDENCE DISTRIBUTION:")
    print(f"Highest Confidence: {max(confidences) if confidences else 0:.1f}")
    print(f"Lowest Confidence: {min(confidences) if confidences else 0:.1f}")
    print(f"Confidence Threshold: 70%")

    # Save detailed results with accuracy metrics
    report_data = {
        "summary": {
            "overall_semantic_accuracy": semantic_accuracy,
            "average_match_confidence": avg_confidence,
            "validation_stats": {
                "valid": len(valid),
                "invalid": len(invalid),
                "skipped": len(skipped),
                "unmapped": len(unmapped),
                "errors": len(errors)
            },
            "predicate_accuracy": {
                pred: {
                    "accuracy": stats["valid"]/stats["total"] * 100,
                    "count": stats["total"]
                } for pred, stats in predicate_stats.items()
            },
            "confidence_stats": {
                "average": avg_confidence,
                "highest": max(confidences) if confidences else 0,
                "lowest": min(confidences) if confidences else 0,
                "threshold": 70
            }
        },
        "detailed_results": results
    }
    
    with open(RESULTS_FILE, "w") as f:
        json.dump(report_data, f, indent=2)
    
    print(f"\n📄 Full report saved to {RESULTS_FILE}")

def main():
    # Load triples
    TRIPLES_FILE = "extracted_triples.json"
    print(f"🔄 Loading triples from {TRIPLES_FILE}...")
    
    triples = load_and_clean_triples(TRIPLES_FILE)
    
    if not triples:
        print("No triples loaded. Exiting.")
        return
    
    print(f"✅ Loaded {len(triples)} triples\n")
    
    # Validate triples
    results = batch_validate_triples(triples)
    
    # Generate final report
    generate_report(results)

if __name__ == "__main__":
    main()

🔄 Loading triples from extracted_triples.json...
✅ Loaded 284 triples


Processing batch 1 (1-50)

Processing batch 2 (51-100)

Processing batch 3 (101-150)

Processing batch 4 (151-200)

Processing batch 5 (201-250)

Processing batch 6 (251-284)

📊 SEMANTIC ACCURACY REPORT
Overall Semantic Accuracy: 0.0%
Average Match Confidence: 95.0/100
Total Valid Matches: 0
Total Invalid Matches: 29
Unmapped Models: 136
Validation Errors: 1

📈 PREDICATE-SPECIFIC ACCURACY:
dul:hasParameterDataValue 0.0% (0/83)
hasArchitecture      0.0% (0/4)
hasBatchSize         0.0% (0/2)
hasClasses           0.0% (0/2)
hasDataset           0.0% (0/15)
hasDownstreamTask    0.0% (0/3)
hasEvaluationAccuracy 0.0% (0/1)
hasEvaluationMetric  0.0% (0/1)
hasEvaluationResult  0.0% (0/3)
hasFineTuningObjective 0.0% (0/1)
hasFormat            0.0% (0/1)
hasHyperparameter    0.0% (0/2)
hasIntendedUse       0.0% (0/1)
hasLanguage          0.0% (0/1)
hasLearningRate      0.0% (0/2)
hasMetric            0.0% (0/1)
hasModelCreat