# In [None]:
"""
Neural Code Search & Contrastive Learning Examples
For Nokia AI FES Working Student Position

This file demonstrates key technical implementations relevant to:
- Neural code search using embedding models
- Contrastive learning for code-requirement alignment  
- RAG integration for context-aware responses
- Evaluation metrics for retrieval systems

Author: Mahesh Sadupalli
Project: RAG-assisted LLM-based Verification
"""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from torch.nn import functional as F


# =============================================================================
# 1. Neural Code Search Implementation
# =============================================================================

class NeuralCodeSearch:
    """
    Neural code search system using embedding models to match
    requirements to relevant code files.

    Key Features:
    - Semantic embedding of code and requirements
    - Cosine similarity for ranking
    - Configurable similarity thresholds

    Args:
        model_name: SentenceTransformer model identifier, used to embed both
            code descriptions and requirement queries.
        similarity_threshold: Minimum cosine similarity a file must reach to
            be returned by ``search``. Defaults to 0.3.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 similarity_threshold: float = 0.3):
        self.model = SentenceTransformer(model_name)
        self.code_embeddings = None
        self.code_files = []
        self.similarity_threshold = similarity_threshold

    def index_codebase(self, code_files: List[Dict]) -> None:
        """Create embeddings for the entire codebase.

        Args:
            code_files: One dict per file. ``search`` reads the ``name`` key
                and, optionally, ``description`` and ``functions``. The text
                that gets embedded is built by ``_extract_code_features``.
        """
        # Copy so later mutation of the caller's list cannot desynchronise
        # ``code_files`` from the embedding rows computed here.
        self.code_files = list(code_files)
        code_texts = [self._extract_code_features(cf) for cf in self.code_files]
        if code_texts:
            self.code_embeddings = self.model.encode(code_texts)
        else:
            # Nothing to encode; keep a well-formed empty index.
            self.code_embeddings = np.empty((0, 0), dtype=np.float32)
        print(f"Indexed {len(self.code_files)} code files")

    def search(self, requirement: str, top_k: int = 5) -> List[Dict]:
        """Find the most relevant code files for a requirement.

        Args:
            requirement: Natural-language requirement text.
            top_k: Maximum number of results. Values <= 0 yield no results.

        Returns:
            Up to ``top_k`` dicts with keys ``file``, ``similarity``,
            ``description`` and ``functions``. Only files whose similarity is
            at least ``similarity_threshold`` are included, ordered by
            descending similarity.

        Raises:
            ValueError: If ``index_codebase`` has not been called.
        """
        if self.code_embeddings is None:
            raise ValueError("Codebase not indexed. Call index_codebase() first.")

        # Nothing to rank. Returning early avoids encoding the query and
        # avoids handing an empty matrix to cosine_similarity, which raises on it.
        if top_k <= 0 or len(self.code_files) == 0:
            return []

        req_embedding = self.model.encode([requirement])
        similarities = cosine_similarity(req_embedding, self.code_embeddings)[0]

        # Keep files above the threshold, then take the top_k by similarity.
        valid_indices = np.where(similarities >= self.similarity_threshold)[0]
        ranked = np.argsort(similarities[valid_indices])[::-1]
        top_indices = valid_indices[ranked[:top_k]]

        results = []
        for idx in top_indices:
            code_file = self.code_files[idx]
            results.append({
                'file': code_file['name'],
                'similarity': float(similarities[idx]),
                'description': code_file.get('description', ''),
                'functions': code_file.get('functions', [])
            })

        return results

    def _extract_code_features(self, code_file: Dict) -> str:
        """Join the file's textual metadata into one string for embedding.

        The text uses ``description``, ``comments``, ``function_names`` (space
        separated) and ``module_purpose``. Missing or empty fields are skipped.
        """
        features = [
            code_file.get('description', ''),
            code_file.get('comments', ''),
            ' '.join(code_file.get('function_names', [])),
            code_file.get('module_purpose', ''),
        ]
        return ' '.join(filter(None, features))


# =============================================================================
# 2. Contrastive Learning for Code-Requirement Alignment
# =============================================================================

class ContrastiveCodeEncoder(nn.Module):
    """
    Neural encoder with contrastive learning for aligning code and requirements.

    A frozen SentenceTransformer produces base embeddings. A trainable linear
    projection head then maps them into a shared space. Outputs are
    L2-normalised, so dot products between them are cosine similarities.

    Args:
        base_model_name: SentenceTransformer model identifier.
        embedding_dim: Output width of the base model. When ``None`` it is
            read from the model via ``get_sentence_embedding_dimension()``,
            falling back to 768 if the model does not report a width.
        projection_dim: Width of the projected embedding space.
    """

    def __init__(self, base_model_name: str, embedding_dim: Optional[int] = None,
                 projection_dim: int = 256):
        super().__init__()
        self.base_model = SentenceTransformer(base_model_name)
        if embedding_dim is None:
            # Match the projection to the loaded model instead of assuming a
            # fixed width; a mismatch would otherwise fail only at forward().
            embedding_dim = self.base_model.get_sentence_embedding_dimension() or 768
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, texts: List[str]) -> torch.Tensor:
        """Encode texts through the frozen base model and the projection head.

        Args:
            texts: Batch of strings to embed.

        Returns:
            Projected embeddings with unit L2 norm along dim 1, presumably of
            shape (len(texts), projection_dim).
        """
        # The base model is frozen; no gradients flow into it.
        with torch.no_grad():
            embeddings = self.base_model.encode(texts, convert_to_tensor=True)
        # encode() returns a tensor on the base model's own device and dtype.
        # These may differ from the projection's, for example if the base
        # model sits on a GPU. Align both with the projection weights.
        # clone() gives an ordinary tensor even if encode() ran under
        # inference mode, so autograd can save it for the backward pass.
        weight = self.projection.weight
        embeddings = embeddings.to(device=weight.device, dtype=weight.dtype).clone()

        projected = self.projection(self.dropout(embeddings))
        return F.normalize(projected, p=2, dim=1)


def contrastive_loss(anchor: torch.Tensor, positive: torch.Tensor,
                     negative: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """
    InfoNCE-style contrastive loss for training code-requirement alignments.

    Each anchor is scored against its positive and its negative(s) by dot
    product. The loss is the cross-entropy of picking the positive, which is
    logit index 0. Inputs should be L2-normalised so that the dot products
    are cosine similarities.

    Args:
        anchor: Requirement embeddings, shape (B, D).
        positive: Matching code embeddings, shape (B, D).
        negative: Non-matching code embeddings. Either one per anchor with
            shape (B, D) (or any shape broadcastable against ``anchor``), or
            K per anchor with shape (B, K, D).
        temperature: Positive scaling factor; similarities are divided by it.

    Returns:
        Scalar mean cross-entropy loss.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    pos_sim = torch.sum(anchor * positive, dim=1, keepdim=True)  # (B, 1)
    if negative.dim() == 3:
        # K negatives per anchor: score each one against its own anchor.
        neg_sim = torch.einsum("bd,bkd->bk", anchor, negative)  # (B, K)
    else:
        neg_sim = torch.sum(anchor * negative, dim=1, keepdim=True)  # (B, 1)

    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    # The positive always sits in column 0.
    labels = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)

    return F.cross_entropy(logits, labels)


def _as_text_batch(texts) -> List[str]:
    """Wrap a single string in a list so the encoder always receives a batch."""
    return [texts] if isinstance(texts, str) else texts


def train_contrastive_model(model: ContrastiveCodeEncoder,
                            train_data: List[Tuple],
                            epochs: int = 10,
                            learning_rate: float = 1e-4) -> None:
    """
    Train a contrastive model on requirement-code batches.

    Args:
        model: ContrastiveCodeEncoder to train. It is called as
            ``model(texts)`` with a list of strings.
        train_data: Sequence of ``(requirements, positive_codes,
            negative_codes)`` batches, where each slot is a list of strings
            of equal length. A bare string in any slot is treated as a batch
            of one.
        epochs: Number of passes over ``train_data``.
        learning_rate: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If ``train_data`` is empty, since there is nothing to
            train on or average over.
    """
    if len(train_data) == 0:
        raise ValueError("train_data must contain at least one batch")

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        for requirements, pos_codes, neg_codes in train_data:
            # Embed each side of the triplet as a batch.
            req_embeds = model(_as_text_batch(requirements))
            pos_embeds = model(_as_text_batch(pos_codes))
            neg_embeds = model(_as_text_batch(neg_codes))

            loss = contrastive_loss(req_embeds, pos_embeds, neg_embeds)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_data)
        print(f"Epoch {epoch+1}/{epochs}: Average Loss = {avg_loss:.4f}")


# =============================================================================
# 3. RAG Integration with Code Context
# =============================================================================

class CodeRAGSystem:
    """
    Retrieval-Augmented Generation system that combines code search
    with LLM generation for context-aware responses.

    Args:
        code_search: Search backend. Only its ``search(requirement, top_k=...)``
            method is used.
        llm_client: Generation backend. If it exposes a callable
            ``generate(prompt)``, that method produces the answer. Otherwise
            (including ``None``) a fixed placeholder response is returned.
    """

    def __init__(self, code_search: NeuralCodeSearch, llm_client):
        self.code_search = code_search
        self.llm = llm_client
        # Budget for the retrieved-file context, approximated as
        # whitespace-separated words rather than model tokens.
        self.max_context_length = 2000

    def query_with_context(self, requirement: str, max_files: int = 3) -> Dict:
        """
        RAG pipeline: retrieve relevant code, then generate an LLM response.

        Args:
            requirement: Natural language requirement/query.
            max_files: Maximum number of code files to retrieve.

        Returns:
            Dict with keys ``response``, ``retrieved_files``,
            ``similarity_scores``, ``context_used`` and ``prompt_length``
            (prompt size in words; 0 when nothing was retrieved).
        """
        # Step 1: retrieve relevant code files.
        relevant_code = self.code_search.search(requirement, top_k=max_files)

        if not relevant_code:
            return {
                'response': "No relevant code found for this requirement.",
                'retrieved_files': [],
                'similarity_scores': [],
                'context_used': "",
                'prompt_length': 0
            }

        # Step 2: build the context from the retrieved code.
        context = self._build_context(relevant_code, requirement)

        # Step 3: generate a response grounded in that context.
        prompt = self._create_prompt(requirement, context)
        response = self._generate_response(prompt)

        return {
            'response': response,
            'retrieved_files': [r['file'] for r in relevant_code],
            'similarity_scores': [r['similarity'] for r in relevant_code],
            'context_used': context,
            'prompt_length': len(prompt.split())
        }

    def _build_context(self, relevant_code: List[Dict], requirement: str) -> str:
        """Format retrieved files into a numbered context block.

        Files are added in ranked order until the next one would push the
        context past ``max_context_length`` words. The top-ranked file is
        always included. ``requirement`` is accepted but currently unused.
        """
        context_parts = []
        words_used = 0

        for i, result in enumerate(relevant_code, 1):
            context_part = "\n".join([
                f"[{i}] File: {result['file']}",
                f"Similarity: {result['similarity']:.3f}",
                f"Description: {result.get('description', '')}",
                f"Key Functions: {', '.join(result.get('functions', []))}",
            ])
            part_words = len(context_part.split())
            if context_parts and words_used + part_words > self.max_context_length:
                break
            context_parts.append(context_part)
            words_used += part_words

        return '\n\n'.join(context_parts)

    def _create_prompt(self, requirement: str, context: str) -> str:
        """Create a structured prompt for the LLM (without stray indentation)."""
        lines = [
            "You are a hardware verification expert. Based on the following relevant code files,",
            "provide guidance for this verification requirement.",
            "",
            f"REQUIREMENT: {requirement}",
            "",
            "RELEVANT CODE FILES:",
            context,
            "",
            "Please provide:",
            "1. Specific implementation guidance",
            "2. Potential issues to consider",
            "3. Best practices for this use case",
            "4. References to the most relevant files above",
            "",
            "Response:",
        ]
        return "\n".join(lines)

    def _generate_response(self, prompt: str) -> str:
        """Generate a response with the configured LLM client.

        Calls ``self.llm.generate(prompt)`` when the client provides it.
        Otherwise returns a fixed placeholder so the pipeline still runs
        without a real model.
        """
        generate = getattr(self.llm, "generate", None)
        if callable(generate):
            return generate(prompt)

        return "\n".join([
            "Based on the retrieved code files, I recommend:",
            "",
            "1. Implementation: Start with the highest similarity file as a template",
            "2. Testing: Ensure coverage of edge cases mentioned in the code comments",
            "3. Integration: Check compatibility with existing verification framework",
            "4. Validation: Use assertion patterns from similar modules",
            "",
            "The retrieved files provide excellent starting points for your implementation.",
        ])


# =============================================================================
# 4. Evaluation Metrics for Code Search Systems
# =============================================================================

def calculate_mrr(ranked_results: List[List[bool]]) -> float:
    """
    Calculate Mean Reciprocal Rank for search results.

    Args:
        ranked_results: One list per query. Each inner list holds booleans
            (any truthy value counts) marking whether the result at that
            1-based rank is relevant.

    Returns:
        Mean over queries of 1/rank of the first relevant result. A query
        with no relevant result contributes 0. Returns 0.0 for an empty
        ``ranked_results`` instead of NaN.
    """
    if len(ranked_results) == 0:
        return 0.0

    # For each query, find the reciprocal of the first relevant rank
    # (0.0 if none is relevant).
    reciprocal_ranks = [
        next((1.0 / rank for rank, is_relevant in enumerate(results, 1) if is_relevant), 0.0)
        for results in ranked_results
    ]
    return float(np.mean(reciprocal_ranks))


def calculate_ndcg_at_k(ranked_results: List[List[float]], k: int = 5) -> float:
    """
    Calculate Normalized Discounted Cumulative Gain at K.

    Args:
        ranked_results: One list per query, holding the relevance score of
            each result in ranked order.
        k: Cutoff rank for evaluation. Values <= 0 score every query as 0.

    Returns:
        Mean NDCG@K over queries. A query whose ideal DCG is 0 (no relevant
        results) scores 0. Returns 0.0 for an empty ``ranked_results``
        instead of NaN.
    """
    if len(ranked_results) == 0:
        return 0.0

    def dcg_at_k(relevance_scores: List[float], k: int) -> float:
        """DCG of the first ``k`` scores with log2(rank + 1) discounts."""
        # max(k, 0) keeps a negative k from slicing off the tail instead.
        gains = np.asarray(list(relevance_scores)[:max(k, 0)], dtype=float)
        discounts = np.log2(np.arange(2, gains.size + 2))  # rank 1 -> log2(2)
        return float(np.sum(gains / discounts))

    ndcg_scores = []
    for results in ranked_results:
        actual_dcg = dcg_at_k(results, k)
        # The ideal ranking is the same scores sorted in descending order.
        ideal_dcg = dcg_at_k(sorted(results, reverse=True), k)
        ndcg_scores.append(actual_dcg / ideal_dcg if ideal_dcg > 0 else 0.0)

    return float(np.mean(ndcg_scores))


# =============================================================================
# 5. Example Usage & Demo
# =============================================================================

def demo_neural_code_search():
    """Index a small SystemVerilog sample codebase and print the top matches
    for a few example queries.

    Returns:
        The indexed NeuralCodeSearch instance, so other demos can reuse it.
    """
    # Hand-written metadata for three SystemVerilog files.
    codebase = [
        dict(
            name='protocol_checker.sv',
            description='Protocol buffer validation and checking logic',
            comments='Validates incoming protocol buffers against specification',
            function_names=['validate_protocol', 'check_buffer_integrity', 'parse_header'],
            module_purpose='Protocol validation and verification',
        ),
        dict(
            name='memory_sva.sv',
            description='Memory access assertions for RISC-V processor',
            comments='SystemVerilog assertions for memory subsystem verification',
            function_names=['check_memory_access', 'assert_address_valid', 'verify_data_integrity'],
            module_purpose='Memory verification and assertion checking',
        ),
        dict(
            name='cdc_assertions.sv',
            description='Clock domain crossing verification assertions',
            comments='Verifies safe clock domain crossing implementations',
            function_names=['check_cdc_safety', 'verify_synchronizer', 'assert_metastability'],
            module_purpose='Clock domain crossing safety verification',
        ),
    ]

    engine = NeuralCodeSearch()
    engine.index_codebase(codebase)

    queries = (
        "Protocol buffer validation logic",
        "Memory access assertion for RISC-V",
        "Clock domain crossing verification",
    )

    print("=== Neural Code Search Demo ===\n")

    for text in queries:
        print(f"Query: '{text}'")
        for rank, hit in enumerate(engine.search(text, top_k=3), start=1):
            print(f"  {rank}. {hit['file']} (similarity: {hit['similarity']:.3f})")
            print(f"     {hit['description']}")
        print()

    return engine


def demo_evaluation_metrics():
    """Compute MRR and NDCG@5 on small hand-made rankings and print them.

    Returns:
        Tuple ``(mrr_score, ndcg_score)``.
    """
    # Binary relevance per rank for three queries (True = relevant).
    binary_rankings = [
        [True, False, True, False, False],   # query 1: relevant at ranks 1 and 3
        [False, True, False, True, False],   # query 2: relevant at ranks 2 and 4
        [True, True, False, False, False],   # query 3: relevant at ranks 1 and 2
    ]

    # Graded relevance (0-1 scale) for three queries.
    graded_rankings = [
        [0.9, 0.1, 0.8, 0.2, 0.0],
        [0.0, 0.85, 0.1, 0.75, 0.0],
        [0.95, 0.9, 0.1, 0.0, 0.0],
    ]

    mrr = calculate_mrr(binary_rankings)
    ndcg = calculate_ndcg_at_k(graded_rankings, k=5)

    report = (
        "=== Evaluation Metrics Demo ===",
        f"Mean Reciprocal Rank (MRR): {mrr:.3f}",
        f"NDCG@5: {ndcg:.3f}",
    )
    for line in report:
        print(line)

    return mrr, ndcg


def demo_rag_system():
    """Run the search demo, wrap its index in a CodeRAGSystem with a mock
    LLM client, and print the outcome of one RAG query.

    Returns:
        The result dict produced by ``CodeRAGSystem.query_with_context``.
    """
    # Builds and indexes the sample codebase; this also prints that demo's output.
    searcher = demo_neural_code_search()

    class MockLLMClient:
        # Stand-in for a real LLM client (OpenAI, Claude, etc.).
        def generate(self, prompt, max_tokens=500):
            return "Generated response based on retrieved code context..."

    rag = CodeRAGSystem(searcher, MockLLMClient())

    question = "I need to implement protocol buffer validation for a new communication module"

    print("=== RAG System Demo ===")
    print(f"Query: {question}\n")

    outcome = rag.query_with_context(question, max_files=2)

    print("Retrieved Files:")
    pairs = zip(outcome['retrieved_files'], outcome['similarity_scores'])
    for filename, score in pairs:
        print(f"  - {filename} (similarity: {score:.3f})")

    word_count = len(outcome['context_used'].split())
    print(f"\nGenerated Response:\n{outcome['response']}")
    print(f"\nContext Length: {word_count} words")

    return outcome


if __name__ == "__main__":
    # Main demonstration of the key components:
    #   1. Neural code search with embedding models
    #   2. Evaluation metrics (MRR, NDCG)
    #   3. RAG integration for context-aware responses
    # Together these illustrate the technologies relevant to Nokia's AI FES
    # position: neural code search, contrastive learning concepts, and RAG.
    # Note: demo_rag_system() calls demo_neural_code_search() itself, so the
    # search demo output is printed twice.
    print("🔍 Neural Code Search & RAG Demonstration")
    print("=" * 50)

    for run_demo in (demo_neural_code_search, demo_evaluation_metrics, demo_rag_system):
        run_demo()

    print("\n Demo completed successfully!")
    print("\nKey Technologies Demonstrated:")
    highlights = (
        "Neural code search using embedding models",
        "Contrastive learning for code-requirement alignment",
        "RAG integration with LLM generation",
        "Evaluation metrics (MRR, NDCG)",
        "End-to-end AI lifecycle implementation",
    )
    for item in highlights:
        print(f"- {item}")


# =============================================================================
# 6. Configuration and Requirements
# =============================================================================

# Required packages for this implementation (minimum versions).
# The triple-quoted string below is a bare expression, not an assignment,
# so it has no runtime effect; it is kept as inline reference text.
"""
sentence-transformers>=2.2.0
torch>=1.12.0
scikit-learn>=1.1.0
numpy>=1.21.0
"""

# Example requirements.txt content, with exact pinned versions.
# NOTE(review): faiss-cpu, langchain and openai are not among this module's
# imports; presumably they are listed for a fuller RAG deployment. Verify
# they are still needed before pinning them.
REQUIREMENTS = """
sentence-transformers==2.2.2
torch==1.13.1
scikit-learn==1.2.1
numpy==1.24.0
faiss-cpu==1.7.3
langchain==0.1.0
openai==1.0.0
"""