# Block 1: Imports, Configuration, and Base Classes

In [None]:
import asyncio
import gc
import json
import logging
import os
import re
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum, auto
from functools import lru_cache
from pathlib import Path
from typing import List, Dict, Set, Optional, Tuple, Any

import aiohttp
import nest_asyncio
import numpy as np
import pandas as pd
import psutil
import requests
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer, util
from tqdm.notebook import tqdm

# Apply nest_asyncio for Jupyter compatibility
nest_asyncio.apply()

# Configuration
class Config:
    """Central configuration for the technology-mapping pipeline.

    The Neo4j connection settings can be overridden with the environment
    variables ``NEO4J_URI``, ``NEO4J_USER`` and ``NEO4J_PASSWORD``. The
    literals below are only fallbacks so that existing setups keep working.
    """

    # Neo4j connection
    NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
    NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
    # SECURITY: the fallback literal is kept for backward compatibility only.
    # Prefer setting NEO4J_PASSWORD in the environment and rotate this value.
    NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "kuxFc8HN")

    # LLM settings
    OLLAMA_API = "http://localhost:11434/api/generate"
    OLLAMA_MODEL = "mixtral"
    OLLAMA_TIMEOUT = 30

    # Processing settings
    BATCH_SIZE = 5
    # Number of use cases to process; lower it for a quick trial run.
    # NOTE(review): where this cap is applied is not visible here - confirm.
    SAMPLE_SIZE = 1741
    # Memory limit in MB (used for both the start-up check and the cache guard)
    MEMORY_THRESHOLD_MB = 1000

    # Scoring thresholds for cascading evaluation
    class Thresholds:
        KEYWORD = 0.2      # If keyword match is strong enough, skip other evaluations
        SEMANTIC = 0.3     # If semantic match is strong enough, skip LLM
        LLM = 0.5          # Minimum threshold for LLM-based matches

        # Relationship type thresholds
        PRIMARY = 0.7
        SECONDARY = 0.5
        RELATED = 0.3

    # Method weights for final score calculation
    METHOD_WEIGHTS = {
        'keyword': {
            'score_weight': 1.0,
            'confidence': 'HIGH'
        },
        'semantic': {
            'score_weight': 0.9,
            'confidence': 'MEDIUM'
        },
        'llm': {
            'score_weight': 0.8,
            'confidence': 'VARIABLE'  # Will be set by LLM response
        }
    }

# Configure the root logger: every record from DEBUG upwards goes to a log file
logging.basicConfig(
    level=logging.DEBUG,  # Most verbose level; raise to INFO to reduce log volume
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    filename='technology_mapping.log'  # Log to a file
)

# Module-level logger used by the top-level helpers in this file
logger = logging.getLogger(__name__)

# Custom logging formatter for evaluation method tracking
class EvalMethodFormatter(logging.Formatter):
    """Formatter that prefixes a message with its colour-coded evaluation method.

    A record opts in by carrying an ``eval_method`` attribute (for example via
    ``logger.info(msg, extra={'eval_method': 'LLM'})``). Records without that
    attribute are formatted unchanged.
    """

    # ANSI escape sequences per evaluation method
    eval_colors = {
        'KEYWORD': '\033[32m',  # Green
        'SEMANTIC': '\033[33m',  # Yellow
        'LLM': '\033[36m',      # Cyan
        'RESET': '\033[0m'
    }

    def format(self, record):
        """Return the formatted record, tagged with its evaluation method.

        The tag is applied only for the duration of this call. ``record.msg``
        is restored afterwards because the same record object is handed to
        every other handler (e.g. the file handler on the root logger), which
        would otherwise receive the ANSI escape codes as well.
        """
        method = getattr(record, 'eval_method', None)
        if method is None:
            return super().format(record)

        # Unknown methods are still tagged, just without colour, instead of
        # raising KeyError inside the logging machinery.
        color = self.eval_colors.get(method, '')
        reset = self.eval_colors['RESET'] if color else ''

        original_msg = record.msg
        record.msg = f"{color}{method}{reset} - {original_msg}"
        try:
            return super().format(record)
        finally:
            record.msg = original_msg

# Attach a stream handler (stderr by default) that uses the custom formatter,
# so records emitted through the module-level logger are also echoed there.
handler = logging.StreamHandler()
handler.setFormatter(EvalMethodFormatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

class LRUCache:
    """Least Recently Used (LRU) cache backed by an ``OrderedDict``.

    The most recently used entry sits at the end of ``self.cache``; when the
    cache is full the entry at the front (least recently used) is evicted.
    """

    def __init__(self, capacity: int):
        """Create a cache holding at most *capacity* entries.

        A capacity of zero or less yields a cache that never stores anything.
        """
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key: str) -> Any:
        """Return the value stored under *key* and mark it as recently used.

        Returns ``None`` when the key is absent (a stored ``None`` value is
        therefore indistinguishable from a miss).
        """
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: str, value: Any) -> None:
        """Insert or update *key*, evicting the least recently used entry if full."""
        if self.capacity <= 0:
            # Nothing can be stored; without this guard popitem() would be
            # called on an empty OrderedDict and raise KeyError.
            return
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self.cache[key] = value

    def clear(self):
        """Clear all items from cache"""
        self.cache.clear()

# Enums for type safety
class MatchMethod(Enum):
    """Evaluation stage that produced (or failed to produce) a match."""

    KEYWORD = 1
    SEMANTIC = 2
    LLM = 3
    NO_MATCH = 4
    ERROR = 5

class RelationType(Enum):
    """Strength of the relationship between a use case and a category."""

    PRIMARY = 1
    SECONDARY = 2
    RELATED = 3
    NO_MATCH = 4

# Data classes for structured results
@dataclass
class MatchResult:
    """Outcome of evaluating one use case against one technology category."""

    use_case_name: str                          # Name of the evaluated use case
    agency: str                                 # Agency the use case belongs to
    abbreviation: str                           # Agency abbreviation ('' when unknown)
    category_name: str                          # Name of the category it was compared with
    score: float                                # Score reported by the deciding method
    method: MatchMethod                         # Evaluation stage that produced this result
    relationship_type: RelationType             # Relationship strength derived from the score
    confidence: float                           # Confidence attached to the result
    matched_terms: Optional[List[str]] = None   # Keywords/capabilities found in the text, if any
    justification: Optional[str] = None         # Human-readable explanation of the decision
    error: Optional[str] = None                 # Error message when the evaluation failed

def verify_environment():
    """Check that every external dependency of the pipeline is usable.

    Probes, in order: the Neo4j server (connectivity check with the
    credentials from ``Config``), the sentence-transformers model
    ``all-MiniLM-L6-v2`` (by loading it), the Ollama HTTP service (the
    ``/version`` endpoint derived from ``Config.OLLAMA_API``) and the
    available system memory (must exceed ``Config.MEMORY_THRESHOLD_MB``).

    The probes are independent: a failure is logged and the remaining probes
    still run.

    Returns:
        bool: ``True`` when all four checks pass, ``False`` otherwise.
    """
    required_components = {
        'neo4j': False,
        'sentence_transformers': False,
        'ollama': False,
        'memory': False
    }

    try:
        # Check Neo4j connection
        with GraphDatabase.driver(
            Config.NEO4J_URI,
            auth=(Config.NEO4J_USER, Config.NEO4J_PASSWORD)
        ) as driver:
            driver.verify_connectivity()
            required_components['neo4j'] = True
            logger.info("✓ Neo4j connection verified")
    except Exception as e:
        logger.error(f"✗ Neo4j connection failed: {str(e)}")

    try:
        # Check sentence transformers: loading the model is the check itself,
        # the instance is not needed afterwards.
        SentenceTransformer('all-MiniLM-L6-v2')
        required_components['sentence_transformers'] = True
        logger.info("✓ Sentence transformers model loaded")
    except Exception as e:
        logger.error(f"✗ Sentence transformers failed: {str(e)}")

    try:
        # Check Ollama availability
        response = requests.get(
            Config.OLLAMA_API.replace('/generate', '/version'),
            timeout=5
        )
        if response.status_code == 200:
            required_components['ollama'] = True
            logger.info("✓ Ollama service available")
        else:
            # Reachable but unhealthy: say why the check is reported as failed.
            logger.error(
                f"✗ Ollama service returned HTTP {response.status_code}"
            )
    except Exception as e:
        logger.error(f"✗ Ollama service unavailable: {str(e)}")

    # Check available memory
    available_memory = psutil.virtual_memory().available / (1024 * 1024)
    required_components['memory'] = available_memory > Config.MEMORY_THRESHOLD_MB
    if required_components['memory']:
        logger.info(f"✓ Sufficient memory available ({available_memory:.0f}MB)")
    else:
        logger.error(f"✗ Insufficient memory: {available_memory:.0f}MB available, "
                    f"{Config.MEMORY_THRESHOLD_MB}MB required")

    # Return overall status
    if all(required_components.values()):
        logger.info("All components verified successfully!")
        return True

    failed = [k for k, v in required_components.items() if not v]
    logger.error(f"Verification failed for: {', '.join(failed)}")
    return False

# Run the environment check once when this cell/module is executed. The boolean
# result is not used here; failures are only reported through the logger.
print("Verifying environment setup...")
verify_environment()

# Block 2: Neo4j Interface

In [None]:
import csv
from pathlib import Path
from datetime import datetime

@dataclass
class Category:
    """Data structure for AI technology categories"""
    name: str                 # Category name (used as the lookup key)
    definition: str           # Textual definition of the category
    maturity_level: str       # Maturity level stored on the category node
    keywords: List[str]       # Names of the keywords tagged on the category
    capabilities: List[str]   # Names of the capabilities linked to the category
    combined_text: str        # Definition, keywords and capabilities joined into one text
    keyword_count: int        # Number of entries in `keywords`
    capability_count: int     # Number of entries in `capabilities`

@dataclass
class UseCase:
    """Data structure for use cases"""
    name: str               # Use case name
    agency: str             # Agency value stored on the use case
    abbreviation: str       # Abbreviation of the owning agency ('' when none is linked)
    topic_area: str         # Topic area stored on the use case
    dev_stage: str          # Development stage stored on the use case
    purpose_benefits: str   # Purpose/benefits text of the use case
    outputs: List[str]      # Descriptions of the outputs the use case produces
    combined_text: str      # Purpose, purpose descriptions and outputs joined into one text

class Neo4jInterface:
    """Neo4j access layer with cached category/use-case loading and CSV export.

    Categories and use cases are fetched lazily through the ``categories`` and
    ``use_cases`` properties and re-fetched once they are older than the cache
    lifetime. Match results are appended to one CSV file per instance.

    The instance is a context manager; leaving the ``with`` block closes the
    underlying driver.
    """

    # Column order of the results CSV written by save_match()
    _CSV_FIELDNAMES = [
        'use_case_name',
        'agency',
        'abbreviation',
        'category_name',
        'keyword_score',
        'semantic_score',
        'llm_score',
        'final_score',
        'match_method',
        'relationship_type',
        'confidence',
        'matched_keywords',
        'justification',
        'error'
    ]

    def __init__(self, uri: str, user: str, password: str):
        """Open a driver for *uri* and verify that the server is reachable.

        Raises:
            Exception: whatever the driver raises when connectivity cannot be
                verified; the driver is closed before the error propagates.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.logger = logging.getLogger('tech_mapper.neo4j')
        self._categories: Optional[Dict[str, Category]] = None
        self._use_cases: Optional[Dict[str, UseCase]] = None

        # Fail fast on an unreachable server or bad credentials, but do not
        # leak the driver when the check fails.
        try:
            self.driver.verify_connectivity()
        except Exception:
            self.driver.close()
            self.driver = None
            raise

        # Cache settings. Each collection tracks its own load time: with a
        # single shared timestamp, refreshing one collection would make the
        # other one look fresh and let it outlive the cache lifetime.
        self._cache_lifetime = 3600  # 1 hour cache lifetime
        self._categories_loaded_at: Optional[float] = None
        self._use_cases_loaded_at: Optional[float] = None

        # Results file for this run. The name is fixed once here so that every
        # save_match() call appends to the same file instead of starting a new
        # file whenever the wall-clock second changes.
        run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self._output_file = (
            Path("output") / f"technology_mapping_results_{run_timestamp}.csv"
        )

    def close(self):
        """Safely close Neo4j connection"""
        if self.driver:
            self.driver.close()
            self.driver = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def categories(self) -> Dict[str, Category]:
        """Cached access to AI technology categories, keyed by category name."""
        if self._categories is None or self._is_stale(self._categories_loaded_at):
            self._categories = self._fetch_categories()
            self._categories_loaded_at = time.time()
        return self._categories

    @property
    def use_cases(self) -> Dict[str, UseCase]:
        """Cached access to use cases, keyed by ``"<name>|<agency>"``."""
        if self._use_cases is None or self._is_stale(self._use_cases_loaded_at):
            self._use_cases = self._fetch_use_cases()
            self._use_cases_loaded_at = time.time()
        return self._use_cases

    def _is_stale(self, loaded_at: Optional[float]) -> bool:
        """Return True when data loaded at *loaded_at* needs a refresh."""
        if loaded_at is None:
            return True
        return (time.time() - loaded_at) > self._cache_lifetime

    def _execute_query(self, query: str, parameters: Optional[Dict] = None) -> List[Dict]:
        """Run *query* and return its records as dictionaries.

        The query is attempted up to three times with exponential backoff
        (1s, then 2s) between attempts.

        Raises:
            Exception: the error of the final attempt, after it is logged.
        """
        max_retries = 3
        retry_delay = 1

        for attempt in range(max_retries):
            try:
                with self.driver.session() as session:
                    result = session.run(query, parameters or {})
                    # Materialise inside the session; the result cannot be
                    # consumed once the session is closed.
                    return [record.data() for record in result]

            except Exception as e:
                if attempt == max_retries - 1:
                    self.logger.error(f"Query failed after {max_retries} attempts: {str(e)}")
                    raise

                self.logger.warning(
                    f"Query attempt {attempt + 1} failed: {str(e)}. "
                    f"Retrying in {retry_delay}s..."
                )
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff

    def _fetch_categories(self) -> Dict[str, Category]:
        """Fetch all AI technology categories with their keywords and capabilities.

        Returns:
            Mapping of category name to ``Category``.

        Raises:
            Exception: re-raised after logging when the query fails.
        """
        # coalesce() guards the concatenation: in Cypher, `null + string` is
        # null, which would drop the keywords and capabilities from
        # combined_text for any category without a definition.
        query = """
        MATCH (c:AICategory)
        OPTIONAL MATCH (c)-[:TAGGED_WITH]->(k:Keyword)
        OPTIONAL MATCH (c)-[:HAS_CAPABILITY]->(cap:Capability)
        WITH c,
             collect(DISTINCT k.name) as keywords,
             collect(DISTINCT cap.name) as capabilities,
             coalesce(c.definition, '') + ' ' + 
             reduce(s = '', x IN collect(DISTINCT k.name) | s + ' ' + x) + ' ' +
             reduce(s = '', x IN collect(DISTINCT cap.name) | s + ' ' + x) as combined_text
        RETURN 
            c.name as name,
            c.definition as definition,
            c.maturity_level as maturity_level,
            keywords,
            capabilities,
            combined_text,
            size(keywords) as keyword_count,
            size(capabilities) as capability_count
        ORDER BY c.name
        """

        try:
            results = self._execute_query(query)
            categories = {}

            for row in results:
                category = Category(
                    name=row['name'],
                    definition=row['definition'],
                    maturity_level=row['maturity_level'],
                    keywords=row['keywords'],
                    capabilities=row['capabilities'],
                    combined_text=row['combined_text'] or '',
                    keyword_count=row['keyword_count'],
                    capability_count=row['capability_count']
                )
                categories[category.name] = category

            self.logger.info(
                f"Loaded {len(categories)} categories with "
                f"{sum(c.keyword_count for c in categories.values())} total keywords"
            )
            return categories

        except Exception as e:
            self.logger.error(f"Failed to fetch categories: {str(e)}")
            raise

    def _fetch_use_cases(self) -> Dict[str, UseCase]:
        """Fetch all use cases with their purposes, outputs and agency.

        Returns:
            Mapping of ``"<name>|<agency>"`` to ``UseCase``.

        Raises:
            Exception: re-raised after logging when the query fails.
        """
        # coalesce() keeps combined_text non-null when purpose_benefits is
        # missing (string concatenation with null yields null in Cypher).
        query = """
        MATCH (u:UseCase)
        OPTIONAL MATCH (u)-[:HAS_PURPOSE]->(p:PurposeBenefit)
        OPTIONAL MATCH (u)-[:PRODUCES]->(o:Output)
        OPTIONAL MATCH (u)<-[:HAS_USE_CASE]-(a:Agency)
        WITH u, a,
             collect(DISTINCT p.description) as purposes,
             collect(DISTINCT o.description) as outputs,
             coalesce(u.purpose_benefits, '') + ' ' +
             reduce(s = '', x IN collect(DISTINCT p.description) | s + ' ' + x) + ' ' +
             reduce(s = '', x IN collect(DISTINCT o.description) | s + ' ' + x) as combined_text
        RETURN 
            u.name as name,
            u.agency as agency,
            coalesce(a.abbreviation, '') as abbreviation,
            u.topic_area as topic_area,
            u.dev_stage as dev_stage,
            u.purpose_benefits as purpose_benefits,
            outputs,
            combined_text
        """

        try:
            results = self._execute_query(query)
            use_cases = {}

            for row in results:
                use_case = UseCase(
                    name=row['name'],
                    agency=row['agency'],
                    abbreviation=row['abbreviation'] or '',
                    topic_area=row['topic_area'],
                    dev_stage=row['dev_stage'],
                    purpose_benefits=row['purpose_benefits'] or '',
                    outputs=row['outputs'],
                    combined_text=row['combined_text'] or ''
                )
                use_cases[f"{use_case.name}|{use_case.agency}"] = use_case

            self.logger.info(f"Loaded {len(use_cases)} unmapped use cases")
            return use_cases

        except Exception as e:
            self.logger.error(f"Failed to fetch use cases: {str(e)}")
            raise

    def save_match(self, match: MatchResult) -> bool:
        """Append *match* to this run's results CSV.

        The file lives in ``output/`` and is created, with a header row, on
        the first call. Only the column belonging to the method that produced
        the match carries its score; the other per-method score columns are 0.

        Returns:
            ``True`` on success, ``False`` when writing failed (the error is
            logged, not raised).
        """
        try:
            output_file = self._output_file
            # Create output directory if it doesn't exist
            output_file.parent.mkdir(exist_ok=True)

            # A header is only needed when the file is being created
            write_header = not output_file.exists()

            # Explicit encoding so non-ASCII text is written identically on
            # every platform.
            with open(output_file, 'a', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=self._CSV_FIELDNAMES)

                if write_header:
                    writer.writeheader()

                writer.writerow({
                    'use_case_name': match.use_case_name,
                    'agency': match.agency,
                    'abbreviation': match.abbreviation,
                    'category_name': match.category_name,
                    'keyword_score': match.score if match.method == MatchMethod.KEYWORD else 0.0,
                    'semantic_score': match.score if match.method == MatchMethod.SEMANTIC else 0.0,
                    'llm_score': match.score if match.method == MatchMethod.LLM else 0.0,
                    'final_score': match.score,
                    'match_method': match.method.name,
                    'relationship_type': match.relationship_type.name,
                    'confidence': match.confidence,
                    'matched_keywords': ', '.join(match.matched_terms) if match.matched_terms else '',
                    'justification': match.justification,
                    'error': match.error or ''
                })

            return True

        except Exception as e:
            self.logger.error(f"Failed to save match to CSV: {str(e)}")
            return False

    def get_match_statistics(self) -> Dict:
        """Return summary statistics about the use cases in the database.

        Only ``total_use_cases`` is read from the database; the remaining keys
        are returned as zero/empty placeholders.

        Returns:
            The statistics dictionary, or ``{}`` when the query failed (the
            error is logged).
        """
        query = """
        MATCH (u:UseCase)
        RETURN count(DISTINCT u) as total_use_cases
        """

        try:
            result = self._execute_query(query)[0]

            return {
                'total_use_cases': result['total_use_cases'],
                'total_matches': 0,
                'matched_categories': 0,
                'match_methods': {},
                'relationship_types': {},
                'completion_rate': 0
            }

        except Exception as e:
            self.logger.error(f"Failed to get match statistics: {str(e)}")
            return {}

# Example usage
if __name__ == "__main__":
    # Smoke-test the Neo4j interface against the configured database
    with Neo4jInterface(
        Config.NEO4J_URI,
        Config.NEO4J_USER,
        Config.NEO4J_PASSWORD
    ) as db:
        # Print some basic stats
        print(f"Categories loaded: {len(db.categories)}")
        print(f"Use cases loaded: {len(db.use_cases)}")

        # Print match statistics
        stats = db.get_match_statistics()
        print("\nMatch Statistics:")
        if not stats:
            # get_match_statistics() returns {} when its query fails; indexing
            # it would raise KeyError and hide the real (logged) error.
            print("Statistics unavailable (see log for details)")
        else:
            print(f"Completion Rate: {stats['completion_rate']:.1%}")
            print("\nMatch Methods:")
            for method, count in stats['match_methods'].items():
                print(f"- {method}: {count}")

# Block 3: Enhanced text processing and matching with memory management

In [None]:
class MemoryManager:
    """Watches the process's resident memory and triggers garbage collection."""

    def __init__(self, threshold_mb: int = 1000):
        self.logger = logging.getLogger('tech_mapper.memory')
        self.check_interval = 60  # Check every minute
        self.last_check = time.time()
        self.threshold_mb = threshold_mb

    def check_memory(self) -> bool:
        """Report whether memory usage is acceptable, collecting garbage if not.

        Calls made less than ``check_interval`` seconds after the previous
        measurement return ``True`` without measuring. Otherwise the resident
        set size is compared with ``threshold_mb``; when it is exceeded a
        garbage collection runs and the size is measured again.

        Returns:
            ``False`` when usage is still above the threshold after collecting
            or when the measurement itself failed, ``True`` in all other cases.
        """
        now = time.time()
        if now - self.last_check < self.check_interval:
            return True

        bytes_per_mb = 1024 * 1024
        try:
            rss_before = psutil.Process().memory_info().rss / bytes_per_mb
            self.last_check = now

            # Within budget: nothing to clean up
            if not rss_before > self.threshold_mb:
                return True

            self.logger.warning(
                f"Memory usage ({rss_before:.0f}MB) exceeded threshold "
                f"({self.threshold_mb}MB). Running garbage collection."
            )
            gc.collect()

            # Measure again to see whether the collection was enough
            rss_after = psutil.Process().memory_info().rss / bytes_per_mb
            if rss_after > self.threshold_mb:
                self.logger.error(
                    f"Memory usage still high ({rss_after:.0f}MB) "
                    "after garbage collection!"
                )
                return False

            self.logger.info(
                f"Memory reduced from {rss_before:.0f}MB to {rss_after:.0f}MB"
            )
            return True

        except Exception as exc:
            self.logger.error(f"Memory check failed: {str(exc)}")
            return False

class TextProcessor:
    """Handles text processing and matching with memory management.

    Matching is a cascade: keyword overlap first, then embedding-based
    semantic similarity. Embeddings are cached per cleaned-up text, and the
    cache is dropped when the memory manager reports pressure.
    """

    # Weight applied to the matched fraction of a partially matched keyword
    _PARTIAL_MATCH_WEIGHT = 0.7

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformers model *model_name* and set up caches."""
        self.model = SentenceTransformer(model_name)
        self.memory_manager = MemoryManager(Config.MEMORY_THRESHOLD_MB)
        self.embedding_cache = {}
        self.cache_hits = 0
        self.cache_misses = 0
        self.logger = logging.getLogger('tech_mapper.text')

        # Pre-compile regex patterns
        self.cleanup_pattern = re.compile(r'\s+')
        self.split_pattern = re.compile(r'[;,\n]')

    def cleanup_text(self, text: str) -> str:
        """Lower-case *text* and collapse whitespace runs; '' for falsy input."""
        if not text:
            return ""

        # Convert to lowercase and normalize whitespace
        text = text.lower()
        text = self.cleanup_pattern.sub(' ', text)
        return text.strip()

    # NOTE: this method used to be wrapped in functools.lru_cache. On an
    # instance method that cache keeps the instance alive, hands the same
    # mutable result set to every caller and rejects unhashable keyword
    # lists, so the decorator was removed.
    def get_keyword_matches(self, text: str, keywords: Tuple[str], threshold: float = 0.3) -> Tuple[float, Set[str]]:
        """Score how many of *keywords* occur in *text*.

        A keyword counts as an exact match when it occurs in the cleaned text
        as a substring. A multi-word keyword counts as a partial match when at
        least half of its words occur in the text.

        Scoring: each exact match contributes 1, each partial match contributes
        its matched word fraction times ``_PARTIAL_MATCH_WEIGHT``. The total is
        divided by half the number of keywords and capped at 1.0, so matching
        half of the keywords exactly already yields the maximum score.

        Args:
            text: Text to search.
            keywords: Keywords to look for (any iterable with ``len``).
            threshold: Accepted for interface compatibility; not used in the
                computation.

        Returns:
            ``(score, matched_keywords)``; ``(0.0, set())`` when either input
            is empty. The returned set is a fresh object on every call.
        """
        if not text or not keywords:
            return 0.0, set()

        text = self.cleanup_text(text)
        exact_matches: Set[str] = set()
        partial_fractions: Dict[str, float] = {}
        matched_details = {}

        for keyword in keywords:
            if not keyword:
                # An empty keyword is a substring of every text; skip it
                # rather than count it as a match.
                continue
            keyword = keyword.lower()

            # Exact match
            if keyword in text:
                exact_matches.add(keyword)
                matched_details[keyword] = 'exact'
                continue

            # Check for partial matches (multi-word keywords only)
            keyword_terms = set(keyword.split())
            if len(keyword_terms) > 1:
                matched_terms = [term for term in keyword_terms if term in text]

                # Require at least half the terms to match
                if len(matched_terms) >= len(keyword_terms) / 2:
                    partial_fractions[keyword] = len(matched_terms) / len(keyword_terms)
                    matched_details[keyword] = f'partial: {matched_terms}'

        # Log matching details
        if matched_details:
            self.logger.info(f"Keyword Matches: {matched_details}")

        # A partial match must never outweigh an exact one, so it contributes
        # only its weighted fraction (it was previously counted as a full
        # match plus that fraction).
        total_matches = (
            len(exact_matches)
            + self._PARTIAL_MATCH_WEIGHT * sum(partial_fractions.values())
        )
        # Only require matching half of keywords for max score
        score = min(1.0, total_matches / (len(keywords) * 0.5))

        return score, exact_matches | set(partial_fractions)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get or compute the embedding of *text*, with memory management.

        NOTE(review): every path returns the model output with
        ``convert_to_tensor=True``, i.e. a tensor rather than the annotated
        ``np.ndarray``. The annotation is left unchanged for compatibility.
        The empty and error fallbacks used to return numpy arrays, which do
        not mix safely with tensors held on another device.
        """
        if not text:
            return self.model.encode("", convert_to_tensor=True)

        # Check cache
        text = self.cleanup_text(text)
        if text in self.embedding_cache:
            self.cache_hits += 1
            return self.embedding_cache[text]

        # Check memory before computing new embedding
        if not self.memory_manager.check_memory():
            self.logger.warning("Memory threshold exceeded, clearing embedding cache")
            self.embedding_cache.clear()
            gc.collect()

        # Compute new embedding
        self.cache_misses += 1
        try:
            embedding = self.model.encode(text, convert_to_tensor=True)
            self.embedding_cache[text] = embedding
            return embedding

        except Exception as e:
            self.logger.error(f"Failed to compute embedding: {str(e)}")
            # Fall back to the embedding of the empty string
            return self.model.encode("", convert_to_tensor=True)

    def get_semantic_similarity(
        self,
        text1: str,
        text2: str,
        threshold: float = 0.75
    ) -> float:
        """Return the cosine similarity between the embeddings of two texts.

        *threshold* only controls logging: similarities at or above it are
        logged as strong matches. Any failure is logged and reported as 0.0.
        """
        try:
            # Log semantic similarity computation details
            self.logger.info("Computing Semantic Similarity")
            self.logger.debug(f"Text 1: {text1}")
            self.logger.debug(f"Text 2: {text2}")
            self.logger.debug(f"Threshold: {threshold}")

            # Get embeddings
            embedding1 = self.get_embedding(text1)
            embedding2 = self.get_embedding(text2)

            # Calculate similarity
            similarity = float(
                util.pytorch_cos_sim(embedding1, embedding2)[0][0]
            )

            # Log strong matches. The extra key is 'eval_method' because that
            # is the attribute EvalMethodFormatter looks for.
            if similarity >= threshold:
                self.logger.info(
                    f"Strong semantic match ({similarity:.2f})",
                    extra={'eval_method': 'SEMANTIC'}
                )

            return similarity

        except Exception as e:
            self.logger.error(f"Semantic similarity failed: {str(e)}")
            return 0.0

    @staticmethod
    def _build_result(use_case: UseCase, category: Category, **fields) -> MatchResult:
        """Create a MatchResult for the pair, filling the identifying fields."""
        return MatchResult(
            use_case_name=use_case.name,
            agency=use_case.agency,
            abbreviation=use_case.abbreviation or '',
            category_name=category.name,
            **fields
        )

    def evaluate_match(
        self,
        use_case: UseCase,
        category: Category,
        thresholds: Config.Thresholds = Config.Thresholds
    ) -> MatchResult:
        """Evaluate *use_case* against *category* with the cascading strategy.

        1. Keyword matching against the category's keywords and capabilities;
           a score at or above ``thresholds.KEYWORD`` decides the result.
        2. Otherwise semantic similarity of the combined texts; a score at or
           above ``thresholds.SEMANTIC`` decides the result.
        3. Otherwise, when the better of the two scores exceeds
           ``thresholds.RELATED``, the pair is flagged for LLM validation
           (method ``NO_MATCH``, relationship ``RELATED``).
        4. Otherwise a plain ``NO_MATCH`` result is returned.

        Never raises: any exception is logged and returned as a result with
        method ``ERROR``.
        """
        try:
            # Step 1: Quick keyword matching
            keyword_score, matched_terms = self.get_keyword_matches(
                use_case.combined_text,
                tuple(category.keywords + category.capabilities),
                threshold=thresholds.KEYWORD
            )

            if keyword_score >= thresholds.KEYWORD:
                return self._build_result(
                    use_case,
                    category,
                    score=keyword_score,
                    method=MatchMethod.KEYWORD,
                    relationship_type=self._get_relationship_type(keyword_score),
                    confidence=keyword_score,
                    matched_terms=list(matched_terms),
                    justification=f"Keyword match score: {keyword_score:.2f}"
                )

            # Terms found below the keyword threshold are still reported
            partial_terms = list(matched_terms) if matched_terms else None

            # Step 2: Try semantic matching if keywords weren't sufficient
            semantic_score = self.get_semantic_similarity(
                use_case.combined_text,
                category.combined_text,
                threshold=thresholds.SEMANTIC
            )

            if semantic_score >= thresholds.SEMANTIC:
                return self._build_result(
                    use_case,
                    category,
                    score=semantic_score,
                    method=MatchMethod.SEMANTIC,
                    relationship_type=self._get_relationship_type(semantic_score),
                    confidence=semantic_score * 0.9,
                    matched_terms=partial_terms,
                    justification=f"Semantic match score: {semantic_score:.2f}"
                )

            # Step 3: Potentially trigger LLM evaluation
            best_score = max(keyword_score, semantic_score)

            # NOTE(review): with the default Config.Thresholds (KEYWORD 0.2,
            # SEMANTIC 0.3, RELATED 0.3) both scores are below 0.3 at this
            # point, so this branch only fires for custom thresholds.
            if best_score > thresholds.RELATED:
                self.logger.info(
                    f"Potential match detected for use case {use_case.name} "
                    f"in category {category.name}. Best score: {best_score:.2f}"
                )

                return self._build_result(
                    use_case,
                    category,
                    score=best_score,
                    method=MatchMethod.NO_MATCH,
                    relationship_type=RelationType.RELATED,
                    confidence=best_score,
                    matched_terms=partial_terms,
                    justification="Potential match - recommended for LLM validation"
                )

            # No significant match found
            return self._build_result(
                use_case,
                category,
                score=best_score,
                method=MatchMethod.NO_MATCH,
                relationship_type=RelationType.NO_MATCH,
                confidence=0.0,
                justification="No significant match found"
            )

        except Exception as e:
            error_msg = f"Evaluation failed: {str(e)}"
            self.logger.error(error_msg)
            return self._build_result(
                use_case,
                category,
                score=0.0,
                method=MatchMethod.ERROR,
                relationship_type=RelationType.NO_MATCH,
                confidence=0.0,
                error=error_msg
            )

    def _get_relationship_type(self, score: float) -> RelationType:
        """Determine relationship type based on score"""
        if score >= Config.Thresholds.PRIMARY:
            return RelationType.PRIMARY
        elif score >= Config.Thresholds.SECONDARY:
            return RelationType.SECONDARY
        elif score >= Config.Thresholds.RELATED:
            return RelationType.RELATED
        return RelationType.NO_MATCH

    def get_cache_stats(self) -> Dict:
        """Return embedding-cache statistics (hits, misses, hit rate in %, size)."""
        total_requests = self.cache_hits + self.cache_misses
        if total_requests == 0:
            return {
                'cache_hits': 0,
                'cache_misses': 0,
                'hit_rate': 0,
                'cache_size': len(self.embedding_cache)
            }

        return {
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': self.cache_hits / total_requests * 100,
            'cache_size': len(self.embedding_cache)
        }

# Example usage
if __name__ == "__main__":
    processor = TextProcessor()

    # Test text processing
    text1 = "AI-powered data analytics for predictive maintenance"
    text2 = "Machine learning models for equipment failure prediction"

    # Get similarity
    similarity = processor.get_semantic_similarity(text1, text2)
    print(f"Semantic similarity: {similarity:.2f}")

    # Test memory management: embed 100 distinct long texts and report the
    # cache statistics every tenth iteration.
    print("\nMemory usage:")
    for i in range(100):  # `i` is used below, so it is not named `_`
        processor.get_embedding(f"Test text {i}" * 100)
        if i % 10 == 0:
            stats = processor.get_cache_stats()
            print(f"Cache size: {stats['cache_size']}, Hit rate: {stats['hit_rate']:.1f}%")

# Block 4: Enhanced LLM integration with robust error handling and response parsing

In [None]:
import aiohttp
import backoff
from typing import Optional, Dict, Any
import json
import asyncio
from dataclasses import asdict

@dataclass
class LLMResponse:
    """Structured LLM response"""
    is_match: bool                    # Whether the LLM judged the use case to match the category
    confidence: float                 # Confidence reported by the LLM (the prompt asks for 0..1)
    primary_category: bool            # Whether the category is the primary one for the use case
    justification: str                # Free-text explanation returned by the LLM
    suggested_categories: List[str]   # Category names suggested by the LLM
    error: Optional[str] = None       # presumably set when the call or parsing failed - TODO confirm

class LLMInterface:
    """Async client that asks an LLM HTTP endpoint to judge category matches.

    Use it as an async context manager: the aiohttp session is created on
    entry and closed on exit, and ``_make_request`` refuses to run without
    it.  Successfully parsed evaluations are memoised in ``_results_cache``
    for the lifetime of the instance; failed evaluations are not cached.
    """

    # JSON types accepted for each field the model must return.  Confidence
    # accepts ints as well because models frequently answer ``1`` or ``0``.
    _REQUIRED_FIELDS = {
        'is_match': bool,
        'confidence': (int, float),
        'primary_category': bool,
        'justification': str,
        'suggested_categories': list,
    }

    def __init__(
        self, 
        api_url: str, 
        model: str,
        timeout: int = 30,
        max_retries: int = 2
    ):
        """Store connection settings; no network activity happens here.

        Args:
            api_url: URL that prompts are POSTed to.
            model: Model name sent in every request payload.
            timeout: Total per-request timeout in seconds.
            max_retries: Number of re-attempts after a failed request
                (so a request is tried at most ``max_retries + 1`` times).
        """
        self.api_url = api_url
        self.model = model
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.max_retries = max_retries
        self.logger = logging.getLogger('tech_mapper.llm')
        
        # Template management
        self._load_templates()
        
        # Session management
        self._session: Optional[aiohttp.ClientSession] = None
        self._results_cache = {}
    
    def _load_templates(self):
        """Load prompt templates into ``self.templates``.

        NOTE(review): nothing in this class reads ``self.templates`` --
        ``_build_prompt`` assembles its own f-string.  The literal JSON braces
        below are not doubled, so calling ``str.format`` on this template
        would fail; escape them before putting the template to use.
        """
        self.templates = {
            'category_match': """You are a strict JSON-only response system performing AI technology category matching. 
            Evaluate if this use case matches the category based on the provided information and previous evaluations.
            
            Format your response as a JSON object with ONLY these fields:
            {
                "is_match": boolean,
                "confidence": float between 0 and 1,
                "primary_category": boolean,
                "justification": "string explanation",
                "suggested_categories": ["Category1", "Category2", ...]
            }
            
            Previous Evaluation Methods:
            - Keyword Match Score: {keyword_score:.2f}
            - Semantic Match Score: {semantic_score:.2f}
            - Keywords Found: {matched_keywords}
            
            Category to Evaluate:
            Name: {category_name}
            Definition: {category_definition}
            Primary Capabilities: {capabilities}
            
            Use Case:
            Name: {use_case_name}
            Agency: {agency}
            Purpose: {purpose}
            Outputs: {outputs}
            
            Evaluate if this use case matches this category and return ONLY the required JSON object.
            """
        }
    
    async def __aenter__(self):
        """Open the HTTP session used by ``_make_request``."""
        self._session = aiohttp.ClientSession(timeout=self.timeout)
        return self
    
    async def __aexit__(self, exc_type, exc, tb):
        """Close the HTTP session, if one is open."""
        if self._session:
            await self._session.close()
            self._session = None
    
    def _build_prompt(
        self, 
        use_case: UseCase, 
        category: Category, 
        prev_scores: Dict[str, Any]
    ) -> str:
        """Render the evaluation prompt for one use case / category pair.

        Args:
            use_case: Use case being classified.
            category: Candidate category.
            prev_scores: Earlier results; the keys ``keyword_score``,
                ``semantic_score`` and ``matched_terms`` are read, each
                with a default when absent.

        Returns:
            The prompt text.  Only the first five capabilities and outputs
            and the first 500 characters of the purpose are included.
        """
        # Safely handle matched terms
        matched_terms = prev_scores.get('matched_terms', [])
        matched_terms_str = ', '.join(matched_terms) if matched_terms else 'None'
    
        return f"""You are a strict JSON-only response system performing AI technology category matching. 
        Carefully evaluate if this use case matches the category based on the provided information.
    
        IMPORTANT: Your ENTIRE response must be a valid JSON object with these EXACT fields:
        {{
            "is_match": true or false,
            "confidence": a float between 0 and 1,
            "primary_category": true or false,
            "justification": "a clear explanation of your reasoning",
            "suggested_categories": ["Category1", "Category2"]
        }}
    
        Previous Evaluation Methods:
        - Keyword Match Score: {prev_scores.get('keyword_score', 0):.2f}
        - Semantic Match Score: {prev_scores.get('semantic_score', 0):.2f}
        - Keywords Found: {matched_terms_str}
        
        Category to Evaluate:
        Name: {category.name}
        Definition: {category.definition}
        Primary Capabilities: {', '.join(category.capabilities[:5])}
        
        Use Case:
        Name: {use_case.name}
        Agency: {use_case.agency}
        Purpose: {use_case.purpose_benefits[:500]}
        Outputs: {', '.join(use_case.outputs[:5])}
        
        Evaluate carefully and return ONLY the required JSON object."""

    def _get_cache_key(
        self, 
        use_case: UseCase, 
        category: Category,
        scores: Dict
    ) -> str:
        """Build the memoisation key for one evaluation.

        The key combines the use-case name and agency, the category name and
        the keyword/semantic scores rounded to two decimals.
        """
        key_parts = [
            f"uc:{use_case.name}",
            f"ag:{use_case.agency}",
            f"cat:{category.name}",
            f"scores:{scores.get('keyword_score', 0):.2f}|{scores.get('semantic_score', 0):.2f}"
        ]
        return "|".join(key_parts)
    
    async def _make_request(
        self,
        prompt: str,
        temperature: float = 0.1
    ) -> Dict:
        """POST a prompt to the LLM endpoint, retrying transient failures.

        Args:
            prompt: Prompt text to send.
            temperature: Sampling temperature placed in the payload.

        Returns:
            The decoded JSON body of the HTTP response.

        Raises:
            RuntimeError: If no session is open (not retried).
            aiohttp.ClientError, asyncio.TimeoutError: If every attempt fails.
        """
        if not self._session:
            raise RuntimeError("Session not initialized - use async context manager")
            
        # NOTE(review): Ollama's /api/generate documents sampling parameters
        # (temperature, top_p, top_k, num_predict) under an "options" object;
        # confirm these top-level keys are honoured by the server in use.
        request_data = {
            "model": self.model,
            "prompt": prompt,
            "temperature": temperature,
            "top_p": 0.1,
            "top_k": 10,
            "num_predict": 100,
            "stream": False
        }

        # The retry policy is built per call so that it honours the
        # constructor's ``max_retries``.  backoff's ``max_tries`` counts total
        # attempts, hence the +1 (the default of 2 retries gives 3 attempts).
        post_with_retries = backoff.on_exception(
            backoff.expo,
            (aiohttp.ClientError, asyncio.TimeoutError),
            max_tries=max(1, self.max_retries + 1)
        )(self._post_once)
        return await post_with_retries(request_data)

    async def _post_once(self, request_data: Dict) -> Dict:
        """Send one POST and return its JSON body; raises on HTTP error status."""
        async with self._session.post(
            self.api_url,
            json=request_data
        ) as response:
            response.raise_for_status()
            return await response.json()

    @staticmethod
    def _find_json_object(text: str) -> Optional[Dict]:
        """Return the first JSON object embedded in ``text``, or None.

        Each ``{`` is tried as a possible start and decoding stops at the end
        of the object, so prose before or after the JSON (even prose that
        itself contains braces) does not break extraction.
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                candidate, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict):
                return candidate
            start = text.find('{', start + 1)
        return None
    
    def _extract_json(self, text: str) -> Optional[Dict]:
        """Extract and validate the JSON object in an LLM reply.

        Args:
            text: Raw reply text from the model.

        Returns:
            The parsed dict with ``confidence`` normalised to ``float``, or
            None when no object is found or validation fails (the reason is
            logged at error level).
        """
        try:
            text = text.strip()

            data = self._find_json_object(text)
            if data is None:
                raise ValueError("No valid JSON object found in response")

            for field, accepted in self._REQUIRED_FIELDS.items():
                if field not in data:
                    raise ValueError(f"Missing required field: {field}")
                value = data[field]
                # bool is a subclass of int: never accept true/false as a confidence.
                wrong_type = not isinstance(value, accepted) or (
                    field == 'confidence' and isinstance(value, bool)
                )
                if wrong_type:
                    raise ValueError(
                        f"Invalid type for {field}: "
                        f"expected {accepted}, got {type(value)}"
                    )

            data['confidence'] = float(data['confidence'])
            if not 0 <= data['confidence'] <= 1:
                raise ValueError(
                    f"Confidence out of range: {data['confidence']}"
                )
            
            return data
            
        except Exception as e:
            # Best-effort parser: any failure (including a non-string reply)
            # is reported as "no result" rather than propagated.
            self.logger.error(f"JSON extraction failed: {str(e)}\nText: {text}")
            return None
    
    def _create_fallback_response(
        self,
        error_msg: str,
        category: Category
    ) -> LLMResponse:
        """Build the non-matching response returned when evaluation fails."""
        return LLMResponse(
            is_match=False,
            confidence=0.0,
            primary_category=False,
            justification=f"Error occurred: {error_msg}",
            suggested_categories=[category.name],
            error=error_msg
        )
    
    async def evaluate_match(
        self,
        use_case: UseCase,
        category: Category,
        prev_scores: Dict[str, Any]
    ) -> LLMResponse:
        """Ask the LLM whether ``use_case`` belongs to ``category``.

        Args:
            use_case: Use case being classified.
            category: Candidate category.
            prev_scores: Earlier keyword/semantic results, forwarded into
                the prompt and the cache key.

        Returns:
            The parsed ``LLMResponse`` (served from the cache when the same
            evaluation was already done), or a fallback response with
            ``error`` set if the request or parsing failed.  Never raises.
        """
        self.logger.info("LLM Evaluation Started")
        self.logger.info(f"Use Case: {use_case.name}")
        self.logger.info(f"Category: {category.name}")
        self.logger.info(f"Previous Scores: {prev_scores}")
        
        # Check cache first
        cache_key = self._get_cache_key(use_case, category, prev_scores)
        if cache_key in self._results_cache:
            return self._results_cache[cache_key]
        
        try:
            prompt = self._build_prompt(use_case, category, prev_scores)
            self.logger.debug(f"Generated Prompt:\n{prompt}")
            
            response = await self._make_request(prompt)
            self.logger.info(f"Raw LLM Response: {response}")
            
            if 'response' not in response:
                raise ValueError("No response field in LLM output")
                
            result = self._extract_json(response['response'])
            self.logger.info(f"Extracted Result: {result}")
        
            if not result:
                raise ValueError("Failed to extract valid JSON from response")
            
            llm_response = LLMResponse(
                is_match=result['is_match'],
                confidence=result['confidence'],
                primary_category=result['primary_category'],
                justification=result['justification'],
                suggested_categories=result['suggested_categories']
            )
            
            # Only successful evaluations are cached, so failures are retried
            # on the next call.
            self._results_cache[cache_key] = llm_response
            
            # Log strong matches
            if llm_response.is_match and llm_response.confidence >= 0.8:
                self.logger.info(
                    f"LLM confirmed match with {llm_response.confidence:.2f} confidence",
                    extra={'method': 'LLM'}
                )
            
            return llm_response
            
        except Exception as e:
            error_msg = f"LLM evaluation failed: {str(e)}"
            self.logger.error(error_msg)
            return self._create_fallback_response(error_msg, category)
    
    def get_cache_stats(self) -> Dict:
        """Return the number of cached evaluations under both reported keys."""
        return {
            'cache_size': len(self._results_cache),
            'cached_evaluations': len(self._results_cache)
        }

# Example usage
async def test_llm():
    """Smoke-test ``LLMInterface`` with one hand-built use case and category.

    Opens the interface against the configured endpoint and model, runs a
    single evaluation and prints the resulting response as a dict.
    """
    async with LLMInterface(Config.OLLAMA_API, Config.OLLAMA_MODEL) as llm:
        # Sample use case to classify
        sample_use_case = UseCase(
            name="AI-powered predictive maintenance",
            agency="TEST",
            abbreviation="TEST",
            topic_area="Technology",
            dev_stage="Production",
            purpose_benefits="Predict equipment failures using machine learning",
            outputs=["Failure predictions", "Maintenance schedules"],
            combined_text="AI-powered predictive maintenance system..."
        )

        # Candidate category for that use case
        sample_category = Category(
            name="Predictive & Pattern Analytics",
            definition="Advanced analytical and predictive systems",
            maturity_level="Mature",
            keywords=["predictive analytics", "machine learning"],
            capabilities=["failure prediction", "pattern detection"],
            combined_text="Predictive analytics and pattern detection...",
            keyword_count=2,
            capability_count=2
        )

        # Scores the earlier matching stages would have produced
        earlier_scores = dict(
            keyword_score=0.7,
            semantic_score=0.8,
            matched_terms=['predictive', 'analytics'],
        )

        outcome = await llm.evaluate_match(sample_use_case, sample_category, earlier_scores)
        print(f"Match Result: {asdict(outcome)}")

# Run the LLM smoke test when this code executes as the main module (which
# includes a notebook cell).  asyncio.run() is used from inside the notebook;
# this relies on the nest_asyncio.apply() call made at the top of the file.
if __name__ == "__main__":
    asyncio.run(test_llm())

# Block 5: Processing

In [None]:
import asyncio
import json
import logging
import time
from typing import Dict, List, Optional, Set, Tuple
from tqdm.notebook import tqdm
from pathlib import Path
from datetime import datetime
import concurrent.futures
import csv

# Ensure these imports match your existing imports
import pandas as pd
import numpy as np
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer, util
import requests
import psutil
import gc

# Import your existing classes and configurations
from typing import List, Dict, Set, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from enum import Enum, auto

class TechnologyMappingProcessor:
    """Pipeline that matches every use case against every category.

    Each pair is first scored by the text processor; pairs it reports as
    ``NO_MATCH`` but whose score is above ``Config.Thresholds.RELATED`` are
    passed to the LLM.  Matches are saved to the database as they are found,
    and ``process`` finishes by writing a CSV of all matches (optional) and a
    JSON summary into ``output/``.
    """

    def __init__(
        self, 
        sample_size: Optional[int] = None,
        save_intermediates: bool = True
    ):
        """Store settings and create the output directory.

        Args:
            sample_size: Maximum number of use cases to process; falls back
                to ``Config.SAMPLE_SIZE`` when None.
            save_intermediates: Whether ``process`` writes the detailed CSV.
        """
        # Configuration
        self.neo4j_uri = Config.NEO4J_URI
        self.neo4j_user = Config.NEO4J_USER
        self.neo4j_password = Config.NEO4J_PASSWORD
        
        self.sample_size = sample_size if sample_size is not None else Config.SAMPLE_SIZE
        self.batch_size = Config.BATCH_SIZE
        
        # Logging
        self.logger = logging.getLogger('tech_mapper.processor')
        
        # Components (created by initialize())
        self.db = None
        self.text_processor = None
        self.llm = None
        
        # Results tracking
        self.start_time = None
        self.stats = {}
        self.results = []
        self.category_groups = {}
        self._use_cases_processed = 0
        self.save_intermediates = save_intermediates
        
        # Output directory
        self.output_dir = Path("output")
        self.output_dir.mkdir(exist_ok=True)

    def initialize(self) -> bool:
        """Create the database, text-processing and LLM components.

        Returns:
            True on success; False if any component fails to construct or
            ``verify_environment()`` reports a problem (the error is logged).
        """
        try:
            self.db = Neo4jInterface(
                self.neo4j_uri,
                self.neo4j_user,
                self.neo4j_password
            )
            self.text_processor = TextProcessor()
            self.llm = LLMInterface(
                Config.OLLAMA_API,
                Config.OLLAMA_MODEL
            )
            
            if not verify_environment():
                raise RuntimeError("Environment verification failed")
            
            return True
        
        except Exception as e:
            self.logger.error(f"Initialization failed: {str(e)}")
            return False

    def _group_categories(self, categories: Dict[str, Category]) -> Dict[str, Set[str]]:
        """Cluster categories by keyword overlap.

        A category joins the existing group it shares the most keywords with,
        provided more than 30% of its own keywords appear in that group;
        otherwise it starts a new group.

        Args:
            categories: Mapping of category name to category object; only
                each object's ``keywords`` is read.

        Returns:
            Mapping of ``"group_<n>"`` to the union of its members' keywords.
        """
        category_groups: Dict[str, Set[str]] = {}
        
        for category in categories.values():
            keywords = set(category.keywords or ())
            
            # A category without keywords cannot overlap with anything; the
            # guard also avoids dividing by zero.
            matching_groups = [
                group_id for group_id, group_keywords in category_groups.items()
                if keywords and len(keywords & group_keywords) / len(keywords) > 0.3
            ]
            
            if matching_groups:
                best_group = max(
                    matching_groups,
                    key=lambda g: len(keywords & category_groups[g])
                )
                category_groups[best_group].update(keywords)
            else:
                category_groups[f"group_{len(category_groups)}"] = keywords
        
        return category_groups

    async def process(self) -> Dict:
        """Run the whole mapping pipeline.

        Returns:
            The statistics dict from ``_calculate_statistics``, or
            ``{'error': <message>}`` if initialisation or processing failed.
        """
        if not self.initialize():
            return {'error': 'Initialization failed'}
        
        try:
            self.start_time = time.time()
            
            all_use_cases = list(self.db.use_cases.values())
            
            # Apply sample size if specified
            if self.sample_size and self.sample_size < len(all_use_cases):
                use_cases = all_use_cases[:self.sample_size]
                self.logger.info(f"Processing sample of {len(use_cases)} use cases")
            else:
                use_cases = all_use_cases
                self.logger.info(f"Processing ALL {len(use_cases)} use cases")
            self._use_cases_processed = len(use_cases)
            
            # Keep the grouping so it is available to callers; matching below
            # does not consult it.
            self.category_groups = self._group_categories(self.db.categories)
            
            results = []
            # The LLM client only has an HTTP session inside its async
            # context; without this every LLM request fails immediately.
            async with self.llm:
                for use_case in tqdm(use_cases, desc="Processing Use Cases"):
                    results.extend(await self._process_use_case(use_case))
            self.results = results
            
            self.stats = self._calculate_statistics(results)
            
            if self.save_intermediates:
                self._save_detailed_results(results)
            
            self.generate_summary_report()
            
            return self.stats
        
        except Exception as e:
            self.logger.error(f"Processing failed: {str(e)}")
            return {'error': str(e)}
        
    async def _process_use_case(self, use_case: UseCase) -> List[MatchResult]:
        """Match one use case against every category.

        Matches found by the text processor or confirmed by the LLM are
        appended to the returned list and saved to the database.  An error
        on one category is logged and does not stop the remaining ones.
        """
        results = []
        
        for category in self.db.categories.values():
            try:
                match_result = self.text_processor.evaluate_match(
                    use_case, 
                    category
                )
                
                if match_result.method != MatchMethod.NO_MATCH:
                    # The text processor found a match on its own.
                    results.append(match_result)
                    self.db.save_match(match_result)
                elif match_result.score > Config.Thresholds.RELATED:
                    # No match, but the score is high enough to ask the LLM.
                    try:
                        llm_result = await self.llm.evaluate_match(
                            use_case,
                            category,
                            {
                                'keyword_score': match_result.score,
                                'semantic_score': match_result.score,
                                'matched_terms': match_result.matched_terms or []
                            }
                        )
                        
                        if llm_result.is_match:
                            llm_match = MatchResult(
                                use_case_name=use_case.name,
                                agency=use_case.agency,
                                abbreviation=use_case.abbreviation or '',
                                category_name=category.name,
                                score=llm_result.confidence,
                                method=MatchMethod.LLM,
                                relationship_type=self.text_processor._get_relationship_type(llm_result.confidence),
                                confidence=llm_result.confidence,
                                justification=llm_result.justification
                            )
                            results.append(llm_match)
                            self.db.save_match(llm_match)
                    except Exception as llm_error:
                        self.logger.error(f"LLM evaluation error: {str(llm_error)}")
            
            except Exception as e:
                self.logger.error(f"Error processing use case {use_case.name} for category {category.name}: {str(e)}")
        
        return results

    def _calculate_statistics(self, results: List[MatchResult]) -> Dict:
        """Summarise a list of match results.

        Returns:
            Dict with ``total_processed`` (number of match results),
            ``method_distribution`` (count per ``MatchMethod`` name),
            ``total_duration`` (seconds since ``start_time``),
            ``use_cases_processed``, ``use_cases_matched`` and
            ``success_rate`` (share of processed use cases with at least one
            match, 0.0 when nothing was processed).
        """
        per_method = {}
        for result in results:
            per_method[result.method] = per_method.get(result.method, 0) + 1
        method_distribution = {
            method.name: per_method.get(method, 0)
            for method in MatchMethod
        }

        # A use case is identified by name and agency, as in the match records.
        matched_use_cases = {(r.use_case_name, r.agency) for r in results}
        processed = self._use_cases_processed or len(self.db.use_cases)
        started = self.start_time if self.start_time is not None else time.time()
        
        return {
            'total_processed': len(results),
            'method_distribution': method_distribution,
            'total_duration': time.time() - started,
            'use_cases_processed': processed,
            'use_cases_matched': len(matched_use_cases),
            'success_rate': len(matched_use_cases) / processed if processed else 0.0
        }

    @staticmethod
    def _result_to_row(result: Any) -> Dict[str, Any]:
        """Flatten one match result into CSV-friendly values.

        Enum members are written by name and list-like values are joined
        with ``"; "``; everything else is passed through unchanged.
        """
        try:
            raw = asdict(result)
        except TypeError:
            # Not a dataclass instance: fall back to its attribute dict.
            raw = dict(vars(result))

        row = {}
        for key, value in raw.items():
            if isinstance(value, Enum):
                value = value.name
            elif isinstance(value, (set, frozenset)):
                value = '; '.join(sorted(str(item) for item in value))
            elif isinstance(value, (list, tuple)):
                value = '; '.join(str(item) for item in value)
            row[key] = value
        return row

    def _save_detailed_results(self, results: List[MatchResult]):
        """Write all match results to a timestamped CSV in the output directory.

        The matches were already saved to the database while they were
        found, so this method only writes the file.  With no results the
        file is created empty.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = self.output_dir / f"technology_mapping_results_{timestamp}.csv"
        
        rows = [self._result_to_row(result) for result in results]
        # Union of all keys in first-seen order, so one header fits every row.
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))

        with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
            if rows:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(rows)
        
        self.logger.info(f"Detailed results saved to {output_file}")

    def generate_summary_report(self):
        """Write a JSON summary of ``self.stats`` and return it as a dict."""
        report = {
            'timestamp': datetime.now().isoformat(),
            'total_use_cases': len(self.db.use_cases),
            'processed_use_cases': self.stats.get('use_cases_processed', 0),
            'total_matches': self.stats.get('total_processed', 0),
            'processing_duration': self.stats.get('total_duration', 0),
            'method_distribution': self.stats.get('method_distribution', {}),
            'success_rate': self.stats.get('success_rate', 0),
        }
        
        summary_file = self.output_dir / f"mapping_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(summary_file, 'w') as f:
            json.dump(report, f, indent=2)
        
        return report

# Async runner function
async def run_technology_mapping(sample_size: Optional[int] = None):
    """Build a ``TechnologyMappingProcessor``, run it and print its statistics.

    Args:
        sample_size: Forwarded to the processor as its use-case limit.

    Returns:
        The dict returned by ``process()``, or ``{'error': <message>}`` if
        anything in this function raised.
    """
    try:
        mapper = TechnologyMappingProcessor(
            save_intermediates=True,
            sample_size=sample_size,
        )
        outcome = await mapper.process()

        # Show the statistics as pretty-printed JSON under a banner.
        print("\n--- Technology Mapping Results ---")
        print(json.dumps(outcome, indent=2))
    except Exception as exc:
        logging.error(f"Technology mapping failed: {exc!s}")
        return {'error': str(exc)}
    return outcome

# Configure logging
# NOTE(review): logging.basicConfig() does nothing when the root logger
# already has handlers; logging is also set up near the top of this file, so
# confirm this call takes effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='technology_mapping.log'
)

# Run the mapping exactly once, with the sample size from Config.
# asyncio.run() works both as a script and inside the notebook (which applies
# nest_asyncio at the top of the file), whereas a top-level `await` is a
# syntax error outside a notebook.
if __name__ == "__main__":
    results = asyncio.run(run_technology_mapping(sample_size=Config.SAMPLE_SIZE))