# Semantic Search Implementation with Language Models

In [None]:
import torch
import torchrec
from transformers import (
    AutoTokenizer,
    AutoModel,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from typing import Dict, List, Tuple, Optional, NamedTuple
from dataclasses import dataclass
import numpy as np
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


## Semantic Processing Components

In [None]:
@dataclass
class SemanticSearchResult:
    """A single ranked hit produced by a semantic search.

    Fields are declared in the order the generated ``__init__`` accepts them
    positionally, so the order below is part of the interface.
    """

    # Identifier of the matched document.
    document_id: int
    # Overall score used to rank this hit.
    score: float
    # Component scores that were reported alongside the overall score.
    semantic_similarity: float
    entity_match_score: float
    intent_match_score: float
    # Excerpt of the document text shown for this hit.
    text_snippet: str
    # Entities reported as matching for this hit.
    matched_entities: List[str]
    # Human-readable description of why the document was returned.
    relevance_explanation: str

class SemanticProcessor:
    """Load and drive the models used for query and document understanding.

    Three pretrained components are created and moved to ``device``:

    * ``embedding_model`` - a SentenceTransformer producing dense embeddings.
    * ``tokenizer`` / ``t5_model`` - a T5 model that is prompted with
      ``"classify intent: ..."`` and ``"expand query: ..."``.
    * ``entity_tokenizer`` / ``entity_model`` - an encoder loaded with
      ``AutoModel`` from a CoNLL-03 NER checkpoint; only its
      ``last_hidden_state`` is used in this class.
    """
    def __init__(
        self,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        t5_model: str = "t5-base",
        device: str = "cuda"
    ):
        """Instantiate every model and place it on ``device``.

        Args:
            model_name: SentenceTransformer checkpoint used for embeddings.
            t5_model: T5 checkpoint used for intent and expansion prompts.
            device: Torch device string the models and inputs are moved to.
        """
        self.device = device

        # Semantic embedding model
        self.embedding_model = SentenceTransformer(model_name).to(device)

        # Query understanding model
        self.tokenizer = T5Tokenizer.from_pretrained(t5_model)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model).to(device)

        # Entity recognition model
        self.entity_tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        self.entity_model = AutoModel.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english").to(device)

    def _encode_for_t5(self, prompt: str):
        """Tokenize ``prompt`` for the T5 model (max 128 tokens) on ``self.device``."""
        return self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=128
        ).to(self.device)

    def process_query(self, query: str) -> Dict[str, torch.Tensor]:
        """Compute the signals used to match ``query`` against documents.

        Args:
            query: Raw query text.

        Returns:
            A dict with four keys: ``"embedding"`` (tensor from the embedding
            model), ``"intent"`` (decoded T5 output, a ``str``),
            ``"entity_embeddings"`` (the entity encoder's last hidden state)
            and ``"original_query"`` (the input ``str``). Note that two of
            the values are plain strings despite the annotated value type.
        """
        # Inference only: without no_grad the entity encoder's forward pass
        # records an autograd graph that the returned hidden states keep alive.
        with torch.no_grad():
            # Generate semantic embedding
            query_embedding = self.embedding_model.encode(
                query,
                convert_to_tensor=True,
                device=self.device
            )

            # Understand query intent
            intent_tokens = self._encode_for_t5(f"classify intent: {query}")
            intent_output = self.t5_model.generate(
                **intent_tokens,
                max_length=32,
                num_beams=4
            )
            intent = self.tokenizer.decode(intent_output[0], skip_special_tokens=True)

            # Extract entities
            entity_tokens = self.entity_tokenizer(
                query,
                return_tensors="pt",
                truncation=True,
                max_length=128
            ).to(self.device)
            entity_embeddings = self.entity_model(**entity_tokens).last_hidden_state

        return {
            "embedding": query_embedding,
            "intent": intent,
            "entity_embeddings": entity_embeddings,
            "original_query": query
        }

    def expand_query(self, query: str) -> List[str]:
        """Generate alternative phrasings of ``query`` with the T5 model.

        Args:
            query: Raw query text.

        Returns:
            One decoded string per generated sequence; three sequences are
            requested from a 4-beam search.
        """
        expansion_tokens = self._encode_for_t5(f"expand query: {query}")

        with torch.no_grad():
            expansion_outputs = self.t5_model.generate(
                **expansion_tokens,
                max_length=128,
                num_beams=4,
                num_return_sequences=3
            )

        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in expansion_outputs
        ]

class SemanticIndex:
    """Index for semantic search"""
    def __init__(
        self,
        processor: SemanticProcessor,
        use_faiss: bool = True
    ):
        """Create an empty index.

        Args:
            processor: Provides the embedding, T5 and entity models.
            use_faiss: When true, retrieval goes through a FAISS index;
                otherwise similarities are computed over the stored
                embedding matrix.

        Raises:
            ImportError: If ``use_faiss`` is true and ``faiss`` is not
                installed.
        """
        self.processor = processor
        self.use_faiss = use_faiss
        self.document_store = {}
        self.document_embeddings = None
        # Always define the attribute so it exists on every instance,
        # whichever retrieval backend is selected.
        self.index = None

        if use_faiss:
            # Availability check only: this binds a function-local name, so
            # any other method that needs the module must import it itself.
            import faiss  # noqa: F401
    
    def add_documents(
        self,
        documents: List[Dict[str, str]],
        batch_size: int = 32
    ):
        """Embed ``documents`` and append them to the index.

        Documents receive consecutive integer ids that continue after any
        documents added by earlier calls, so an id always equals the row of
        the document's embedding in ``document_embeddings`` (and its position
        in the FAISS index when one is used).

        Args:
            documents: Dicts that each provide a ``"text"`` entry.
            batch_size: Number of texts embedded per ``encode`` call.
        """
        if not documents:
            # Nothing to index; np.vstack would raise on an empty list.
            return

        print("Processing documents...")
        # Continue numbering after previously indexed documents instead of
        # restarting at 0 and overwriting earlier entries.
        first_id = len(self.document_store)
        embeddings = []

        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            batch_texts = [doc["text"] for doc in batch]

            # Generate embeddings
            batch_embeddings = self.processor.embedding_model.encode(
                batch_texts,
                convert_to_tensor=True,
                device=self.processor.device
            )

            embeddings.append(batch_embeddings.cpu().numpy())

            # Process and store documents
            for offset, doc in enumerate(batch):
                doc_id = first_id + start + offset
                self.document_store[doc_id] = {
                    "id": doc_id,
                    "text": doc["text"],
                    "entities": self._extract_entities(doc["text"]),
                    "semantic_properties": self._extract_semantic_properties(doc["text"])
                }

        # FAISS expects a C-contiguous float32 matrix.
        new_embeddings = np.ascontiguousarray(np.vstack(embeddings), dtype=np.float32)

        if self.document_embeddings is None:
            self.document_embeddings = new_embeddings
        else:
            self.document_embeddings = np.vstack(
                [self.document_embeddings, new_embeddings]
            )

        if self.use_faiss:
            # Imported here because the module is not bound at file level.
            import faiss

            if self.index is None:
                # Initialize FAISS index (inner-product similarity)
                self.index = faiss.IndexFlatIP(new_embeddings.shape[1])
            self.index.add(new_embeddings)
    
    def _extract_entities(self, text: str) -> List[str]:
        """Extract entities from ``text`` (placeholder implementation).

        The entity encoder is run over the text (truncated to 512 tokens),
        but its output is not decoded yet, so the result is always empty.

        Args:
            text: Document text.

        Returns:
            An empty list.
        """
        tokens = self.processor.entity_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.processor.device)

        # Inference only: skip autograd bookkeeping for a forward pass whose
        # output is currently discarded.
        with torch.no_grad():
            self.processor.entity_model(**tokens)

        # Simplified entity extraction (in practice, would use more sophisticated NER)
        return []  # Placeholder
    
    def _extract_semantic_properties(self, text: str) -> Dict[str, float]:
        """Prompt the T5 model for properties of ``text`` (placeholder).

        The prompt is truncated to 512 tokens and decoded from a 4-beam
        generation, but the decoded string is not parsed yet, so the result
        is always empty.

        Args:
            text: Document text.

        Returns:
            An empty dict.
        """
        prompt = f"extract properties: {text}"

        encoded = self.processor.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )
        encoded = encoded.to(self.processor.device)

        generated = self.processor.t5_model.generate(
            **encoded,
            max_length=128,
            num_beams=4
        )
        best_sequence = generated[0]

        # Decoded but not yet interpreted.
        _decoded = self.processor.tokenizer.decode(
            best_sequence,
            skip_special_tokens=True
        )

        # Simplified property extraction (in practice, would parse the output)
        return {}  # Placeholder
    
    def search(
        self,
        query: str,
        k: int = 10,
        rerank: bool = True
    ) -> List[SemanticSearchResult]:
        """Retrieve the documents most relevant to ``query``.

        Up to ``2 * k`` candidates are fetched by embedding similarity. With
        ``rerank`` enabled each candidate is rescored as a weighted sum of
        the retrieval score (0.4) and the semantic (0.3), entity (0.2) and
        intent (0.1) scores, then the best ``k`` are returned.

        The scoring, snippet and explanation helpers called here
        (``_calculate_*``, ``_generate_snippet``, ``_find_matching_entities``,
        ``_explain_relevance``) are not defined in this method and must be
        provided by the class.

        Args:
            query: Raw query text.
            k: Maximum number of results.
            rerank: Whether to rescore the candidates.

        Returns:
            With ``rerank=True``, ``SemanticSearchResult`` objects sorted by
            descending score. With ``rerank=False``, the raw
            ``(document, score)`` tuples from retrieval (kept for backward
            compatibility even though it differs from the annotation). An
            empty list when ``k <= 0`` or nothing has been indexed.
        """
        if (
            k <= 0
            or self.document_embeddings is None
            or len(self.document_embeddings) == 0
        ):
            return []

        # Process query
        query_data = self.processor.process_query(query)
        query_embedding = query_data["embedding"].cpu().numpy()

        # Initial retrieval: fetch more than k so reranking has room to work.
        initial_results = self._retrieve_candidates(query_embedding, k * 2)

        if not rerank:
            return initial_results[:k]

        # Rerank results using more sophisticated scoring
        results = []
        for doc, initial_score in initial_results:
            # Calculate additional relevance signals
            semantic_score = self._calculate_semantic_score(query_data, doc)
            entity_score = self._calculate_entity_score(query_data, doc)
            intent_score = self._calculate_intent_score(query_data, doc)

            # Combined score
            final_score = (
                initial_score * 0.4 +
                semantic_score * 0.3 +
                entity_score * 0.2 +
                intent_score * 0.1
            )

            results.append(
                SemanticSearchResult(
                    document_id=doc["id"],
                    score=final_score,
                    semantic_similarity=semantic_score,
                    entity_match_score=entity_score,
                    intent_match_score=intent_score,
                    text_snippet=self._generate_snippet(query, doc["text"]),
                    matched_entities=self._find_matching_entities(
                        query_data,
                        doc
                    ),
                    relevance_explanation=self._explain_relevance(
                        query_data,
                        doc,
                        final_score
                    )
                )
            )

        # Sort by final score
        results.sort(key=lambda hit: hit.score, reverse=True)
        return results[:k]

    def _retrieve_candidates(self, query_embedding, limit):
        """Return up to ``limit`` ``(document, score)`` pairs, best first."""
        query_matrix = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        # Never ask for more neighbours than there are documents.
        limit = min(limit, len(self.document_embeddings))

        if self.use_faiss:
            scores, doc_indices = self.index.search(query_matrix, limit)
            return [
                (self.document_store[int(idx)], float(score))
                for score, idx in zip(scores[0], doc_indices[0])
                if idx >= 0  # FAISS reports missing neighbours as -1
            ]

        # Compute similarities
        similarities = cosine_similarity(
            query_matrix,
            self.document_embeddings
        )[0]

        # Get top results
        top_indices = np.argsort(similarities)[-limit:][::-1]
        return [
            (self.document_store[int(idx)], float(similarities[idx]))
            for idx in top_indices
        ]