# Sentence Transformer Requirements Matching
**Processes requirement documents using sentence transformer models to identify and score semantic similarities between source and target requirements.**


In [None]:
# Cell [0] - Setup and Imports
# Purpose: Import all required libraries and configure environment settings for Multi-LLM testing
# Dependencies: os, sys, logging, pathlib, xml, collections, typing, torch, matplotlib, sentence_transformers, tqdm, dotenv, praxis_sentence_transformer
# Breadcrumbs: Setup -> Imports -> Environment Configuration

import os
import sys
import logging
from pathlib import Path
import xml.etree.ElementTree as ET
from collections import defaultdict
from typing import List, Tuple, Dict
import torch
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util
from tqdm.notebook import tqdm
from dotenv import load_dotenv

# Populate the process environment from a .env file, if one is present
load_dotenv()

# Project helpers ship in the pip-installable praxis-sentence-transformer distribution.
try:
    from praxis_sentence_transformer import (
        DebugTimer,
        handle_exception,
        setup_logging,
    )
except ImportError as import_err:
    # Point notebook users at the fix, then surface the original ImportError unchanged.
    print(f"Failed to import praxis_sentence_transformer: {str(import_err)}")
    print("Please install the package using pip install praxis-sentence-transformer")
    raise

# Module-wide logger shared by the classes in this file, at DEBUG verbosity
logger = setup_logging("sentence-transformer-notebook", logging.DEBUG)

class RequirementsLoader:
    """Handles loading and parsing of requirements from XML files.

    Both parsers search the whole document (``.//tag``), so the relevant
    elements may be nested at any depth. Entries with missing elements or
    blank identifiers are skipped rather than aborting the whole parse.
    """

    @staticmethod
    def _element_text(element: ET.Element) -> str:
        """Return the element's text with surrounding whitespace removed.

        An empty element such as ``<content/>`` has ``text is None``; that is
        treated as ``""`` instead of crashing on ``None.strip()``.
        """
        return (element.text or "").strip()

    @handle_exception
    def parse_requirements(self, file_path: str) -> List[Tuple[str, str]]:
        """
        Parse requirements from XML file

        Every ``<artifact>`` element that has both an ``<id>`` and a
        ``<content>`` child yields one entry. Artifacts whose id is blank are
        skipped (and counted in a warning) because they cannot be matched
        against an answer set. A blank description is kept as ``""``.

        Parameters:
            file_path (str): Path to the requirements XML file

        Returns:
            List[Tuple[str, str]]: List of (requirement_id, description) tuples in document order
        """
        logger.debug(f"Parsing requirements from {file_path}")
        root = ET.parse(file_path).getroot()
        requirements = []
        skipped = 0

        for artifact in root.findall('.//artifact'):
            id_elem = artifact.find('id')
            content_elem = artifact.find('content')
            if id_elem is None or content_elem is None:
                skipped += 1
                continue

            req_id = self._element_text(id_elem)
            if not req_id:
                skipped += 1
                continue
            requirements.append((req_id, self._element_text(content_elem)))

        if skipped:
            logger.warning(f"Skipped {skipped} incomplete artifacts in {file_path}")
        logger.info(f"Successfully parsed {len(requirements)} requirements")
        return requirements

    @handle_exception
    def parse_answer_set(self, file_path: str) -> List[Tuple[str, str]]:
        """
        Parse the answer set from XML file

        Every ``<link>`` element with non-blank ``<source_artifact_id>`` and
        ``<target_artifact_id>`` children yields one (source, target) pair.
        Links missing either endpoint are skipped and counted in a warning.

        Parameters:
            file_path (str): Path to the answer set XML file

        Returns:
            List[Tuple[str, str]]: List of (source_id, target_id) tuples in document order
        """
        logger.debug(f"Parsing answer set from {file_path}")
        root = ET.parse(file_path).getroot()
        mappings = []
        skipped = 0

        for link in root.findall('.//link'):
            source_elem = link.find('source_artifact_id')
            target_elem = link.find('target_artifact_id')
            source = self._element_text(source_elem) if source_elem is not None else ""
            target = self._element_text(target_elem) if target_elem is not None else ""

            if not source or not target:
                skipped += 1
                continue
            mappings.append((source, target))

        if skipped:
            logger.warning(f"Skipped {skipped} incomplete links in {file_path}")
        logger.info(f"Successfully parsed {len(mappings)} reference mappings")
        return mappings

class SentenceTransformerAnalyzer:
    """Analyzes requirements using sentence transformers.

    Requirement texts are parsed with :class:`RequirementsLoader`, embedded
    with a ``SentenceTransformer`` (``normalize_embeddings=True``) and compared
    pairwise with ``util.pytorch_cos_sim``. Every (source, target) pair whose
    similarity is ``>= threshold`` becomes a predicted trace link, which can
    then be scored against a reference answer set.
    """

    # Thresholds swept when no explicit range is given: 0.05, 0.10, ..., 0.60
    _DEFAULT_THRESHOLDS = tuple(round(step * 0.05, 2) for step in range(1, 13))

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the analyzer with a specific model

        Parameters:
            model_name (str): Name of the sentence transformer model to use
                (passed to ``SentenceTransformer``). The model is loaded by
                ``initialize()``, or lazily on first use.
        """
        self.model_name = model_name
        self.model = None  # set by initialize()
        self.loader = RequirementsLoader()

    @handle_exception
    def initialize(self):
        """Load the sentence transformer model, moving it to CUDA when available."""
        logger.info(f"Initializing sentence transformer model: {self.model_name}")
        self.model = SentenceTransformer(self.model_name)
        if torch.cuda.is_available():
            self.model = self.model.to(torch.device("cuda"))
            logger.info("Model moved to CUDA")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_div(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0 when the denominator is 0."""
        return numerator / denominator if denominator else 0

    def _encode(self, texts: List[str]) -> torch.Tensor:
        """Embed ``texts`` as L2-normalised tensors, loading the model if needed."""
        if self.model is None:
            # Without this, forgetting initialize() fails with
            # "'NoneType' object has no attribute 'encode'".
            self.initialize()
        return self.model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)

    def _score_requirements(
        self, source_file: str, target_file: str
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[List[float]]]:
        """Parse both files and compute the source x target similarity matrix.

        Returns:
            (source_reqs, target_reqs, scores) where ``scores[i][j]`` is the
            cosine similarity of ``source_reqs[i]`` and ``target_reqs[j]`` as a
            Python float. The matrix is copied to the host once so that
            thresholding does not call ``.item()`` (a device sync on CUDA) per pair.
        """
        source_reqs = self.loader.parse_requirements(source_file)
        target_reqs = self.loader.parse_requirements(target_file)
        logger.info(f"Analyzing {len(source_reqs)} source and {len(target_reqs)} target requirements")

        if not source_reqs or not target_reqs:
            # Nothing to compare. Encoding an empty batch with convert_to_tensor=True
            # may raise in some sentence-transformers versions, so skip it.
            logger.warning("No requirements to compare; no matches can be produced")
            return source_reqs, target_reqs, [[] for _ in source_reqs]

        with DebugTimer(logger, "Encoding source texts"):
            source_embeddings = self._encode([desc for _, desc in source_reqs])
        with DebugTimer(logger, "Encoding target texts"):
            target_embeddings = self._encode([desc for _, desc in target_reqs])

        similarities = util.pytorch_cos_sim(source_embeddings, target_embeddings)
        return source_reqs, target_reqs, similarities.cpu().tolist()

    @staticmethod
    def _transpose(scores: List[List[float]]) -> List[List[float]]:
        """Swap rows and columns of a nested-list matrix."""
        return [list(column) for column in zip(*scores)]

    @staticmethod
    def _collect_matches(row_reqs: List[Tuple[str, str]],
                         col_reqs: List[Tuple[str, str]],
                         scores: List[List[float]],
                         threshold: float) -> Dict[str, List[Tuple[str, float]]]:
        """Map each row id to its ``(col_id, score)`` pairs with ``score >= threshold``.

        ``scores[i][j]`` must be the score of ``row_reqs[i]`` vs ``col_reqs[j]``.
        Ids without a qualifying pair are omitted. Each list is sorted by
        descending score; ties keep column order because the sort is stable.
        Duplicate row ids have their matches merged into one list.
        """
        matches = defaultdict(list)
        for (row_id, _), row_scores in zip(row_reqs, scores):
            for (col_id, _), score in zip(col_reqs, row_scores):
                if score >= threshold:
                    matches[row_id].append((col_id, score))
        for hits in matches.values():
            hits.sort(key=lambda pair: pair[1], reverse=True)
        return dict(matches)

    @staticmethod
    def _invert_mapping(mapping: Dict[str, List[Tuple[str, float]]]) -> Dict[str, List[Tuple[str, float]]]:
        """Turn ``{a: [(b, score), ...]}`` into ``{b: [(a, score), ...]}``."""
        inverted = defaultdict(list)
        for key_id, matches in mapping.items():
            for other_id, score in matches:
                inverted[other_id].append((key_id, score))
        return dict(inverted)

    def _evaluate_with_fnr(self, mapping: Dict[str, List[Tuple[str, float]]], answer_set_file: str) -> Dict[str, float]:
        """Run evaluate_results and add a zero-safe ``false_negative_rate`` entry."""
        evaluation = self.evaluate_results(mapping, answer_set_file)
        evaluation['false_negative_rate'] = self._safe_div(
            evaluation['false_negatives'],
            evaluation['true_positives'] + evaluation['false_negatives'],
        )
        return evaluation

    # ------------------------------------------------------------------
    # Analysis
    # ------------------------------------------------------------------

    @handle_exception
    def analyze_requirements(self,
                             source_file: str,
                             target_file: str,
                             threshold: float = 0.5) -> Dict[str, List[Tuple[str, float]]]:
        """
        Analyze requirements and find similarities using cosine similarity of
        normalized sentence embeddings (pooling is whatever the loaded model uses)

        Parameters:
            source_file (str): Path to source requirements XML file
            target_file (str): Path to target requirements XML file
            threshold (float): Similarity threshold for matching, inclusive (default: 0.5)

        Returns:
            Dict[str, List[Tuple[str, float]]]: Dictionary mapping source IDs to a list of
                (target_id, similarity_score) sorted by descending score. Source IDs with
                no match at or above the threshold are absent.
        """
        with DebugTimer(logger, "Requirements Analysis"):
            source_reqs, target_reqs, scores = self._score_requirements(source_file, target_file)
            mappings = self._collect_matches(source_reqs, target_reqs, scores, threshold)
            logger.info(f"Found {len(mappings)} source requirements with matches above threshold {threshold}")
            return mappings

    @handle_exception
    def evaluate_results(self,
                         calculated_mapping: Dict[str, List[Tuple[str, float]]],
                         answer_set_file: str) -> Dict[str, float]:
        """
        Evaluate the calculated mappings against the reference answer set and generate confusion matrix

        Parameters:
            calculated_mapping (Dict[str, List[Tuple[str, float]]]): Mapping keyed by SOURCE id with
                (target_id, score) values; scores are ignored, only the pairs are evaluated
            answer_set_file (str): Path to the answer set XML file ((source, target) pairs)

        Returns:
            Dict[str, float]: Metrics ('precision', 'recall', 'f1_score', 'accuracy',
                'balanced_accuracy'), integer counts ('true_positives', 'false_positives',
                'true_negatives', 'false_negatives') and a formatted 'confusion_matrix' string.
                Any ratio whose denominator is 0 is reported as 0.
        """
        reference_mappings = set(self.loader.parse_answer_set(answer_set_file))
        calculated_set = {
            (source_id, target_id)
            for source_id, matches in calculated_mapping.items()
            for target_id, _ in matches
        }

        true_positives = len(reference_mappings & calculated_set)
        false_positives = len(calculated_set - reference_mappings)
        false_negatives = len(reference_mappings - calculated_set)

        # NOTE(review): negatives are counted only over the source/target ids that
        # occur in the reference or the predictions, not over every requirement in
        # the input files, so the size of this universe depends on the threshold.
        # Every known pair lies inside sources x targets, so the count is a simple
        # difference; no |S| x |T| set is materialized.
        known_pairs = reference_mappings | calculated_set
        source_count = len({source_id for source_id, _ in known_pairs})
        target_count = len({target_id for _, target_id in known_pairs})
        true_negatives = source_count * target_count - len(known_pairs)

        precision = self._safe_div(true_positives, true_positives + false_positives)
        recall = self._safe_div(true_positives, true_positives + false_negatives)
        f1_score = self._safe_div(2 * (precision * recall), precision + recall)
        accuracy = self._safe_div(
            true_positives + true_negatives,
            true_positives + true_negatives + false_positives + false_negatives,
        )
        specificity = self._safe_div(true_negatives, true_negatives + false_positives)
        balanced_accuracy = (recall + specificity) / 2

        confusion_matrix = f"""
Confusion Matrix:
                 Predicted Positive | Predicted Negative
Actual Positive |        {true_positives:^10d} |       {false_negatives:^10d}
Actual Negative |        {false_positives:^10d} |       {true_negatives:^10d}
"""

        results = {
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'accuracy': accuracy,
            'balanced_accuracy': balanced_accuracy,
            'confusion_matrix': confusion_matrix,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'true_negatives': true_negatives,
            'false_negatives': false_negatives
        }

        logger.info(f"Evaluation Results: {results}")
        logger.info(confusion_matrix)
        return results

    @handle_exception
    def find_optimal_threshold(self,
                               source_file: str,
                               target_file: str,
                               answer_set_file: str,
                               threshold_range: List[float] = None) -> Dict[float, Dict[str, float]]:
        """
        Find optimal threshold by testing different values

        Requirements are parsed and encoded once; each threshold only re-filters
        the cached similarity matrix.

        Parameters:
            source_file (str): Path to source requirements file
            target_file (str): Path to target requirements file
            answer_set_file (str): Path to answer set file
            threshold_range (List[float]): List of thresholds to test (default: [0.05 to 0.6 by 0.05])

        Returns:
            Dict[float, Dict[str, float]]: Dictionary mapping thresholds to their evaluation metrics
                (evaluate_results output plus 'false_negative_rate', 0 when there are no reference links)
        """
        if threshold_range is None:
            threshold_range = list(self._DEFAULT_THRESHOLDS)

        # Embeddings do not depend on the threshold: compute them once for the sweep.
        with DebugTimer(logger, "Requirements Analysis"):
            source_reqs, target_reqs, scores = self._score_requirements(source_file, target_file)

        results = {}
        for threshold in threshold_range:
            logger.info(f"Testing threshold: {threshold}")
            mapping = self._collect_matches(source_reqs, target_reqs, scores, threshold)
            logger.info(f"Found {len(mapping)} source requirements with matches above threshold {threshold}")
            results[threshold] = self._evaluate_with_fnr(mapping, answer_set_file)

        return results

    @handle_exception
    def analyze_requirements_bidirectional(self,
                                           source_file: str,
                                           target_file: str,
                                           threshold: float = 0.5) -> Tuple[Dict[str, List[Tuple[str, float]]], Dict[str, List[Tuple[str, float]]]]:
        """
        Analyze requirements in both directions (source->target and target->source)

        Both directions are read from the same similarity matrix with the same
        threshold, so they contain the same pairs, keyed differently.

        Parameters:
            source_file (str): Path to source requirements XML file
            target_file (str): Path to target requirements XML file
            threshold (float): Similarity threshold for matching (inclusive)

        Returns:
            Tuple[Dict, Dict]: (source_to_target_mappings keyed by source id,
                target_to_source_mappings keyed by target id), each list sorted by descending score
        """
        with DebugTimer(logger, "Bidirectional Requirements Analysis"):
            source_reqs, target_reqs, scores = self._score_requirements(source_file, target_file)

            source_to_target = self._collect_matches(source_reqs, target_reqs, scores, threshold)
            target_to_source = self._collect_matches(target_reqs, source_reqs, self._transpose(scores), threshold)

            logger.info(f"Found {len(source_to_target)} source->target and {len(target_to_source)} target->source mappings above threshold {threshold}")
            return source_to_target, target_to_source

    @handle_exception
    def find_optimal_threshold_bidirectional(self,
                                             source_file: str,
                                             target_file: str,
                                             answer_set_file: str,
                                             threshold_range: List[float] = None) -> Dict[str, Dict[float, Dict[str, float]]]:
        """
        Find optimal threshold by testing different values in both directions

        Target->source mappings are keyed by target id, so they are inverted to
        (source, target) orientation before being scored against the answer set.
        NOTE(review): with one shared threshold both directions select identical
        pairs, so their metrics coincide.

        Parameters:
            source_file (str): Path to source requirements file
            target_file (str): Path to target requirements file
            answer_set_file (str): Path to answer set file
            threshold_range (List[float]): List of thresholds to test (default: [0.05 to 0.6 by 0.05])

        Returns:
            Dict[str, Dict[float, Dict[str, float]]]: {'source_to_target': {...}, 'target_to_source': {...}},
                each mapping thresholds to evaluation metrics including 'false_negative_rate'
        """
        if threshold_range is None:
            threshold_range = list(self._DEFAULT_THRESHOLDS)

        with DebugTimer(logger, "Bidirectional Requirements Analysis"):
            source_reqs, target_reqs, scores = self._score_requirements(source_file, target_file)
        transposed = self._transpose(scores)

        results = {
            'source_to_target': {},
            'target_to_source': {}
        }

        for threshold in threshold_range:
            logger.info(f"Testing threshold: {threshold}")
            s2t_mapping = self._collect_matches(source_reqs, target_reqs, scores, threshold)
            t2s_mapping = self._collect_matches(target_reqs, source_reqs, transposed, threshold)

            results['source_to_target'][threshold] = self._evaluate_with_fnr(s2t_mapping, answer_set_file)
            # The answer set holds (source, target) pairs; flip t2s before comparing.
            results['target_to_source'][threshold] = self._evaluate_with_fnr(
                self._invert_mapping(t2s_mapping), answer_set_file
            )

        return results

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def _metric_series(self, threshold_results: Dict[float, Dict[str, float]]):
        """Return ``(sorted_thresholds, {metric_name: [value per threshold]})``.

        Accuracy and balanced accuracy are recomputed from the counts (zero-safe),
        and 'false_negative_rate' falls back to the counts when absent.
        """
        thresholds = sorted(threshold_results)
        names = ('precision', 'recall', 'f1_score', 'accuracy', 'balanced_accuracy', 'false_negative_rate')
        series = {name: [] for name in names}

        for threshold in thresholds:
            result = threshold_results[threshold]
            tp = result['true_positives']
            fp = result['false_positives']
            tn = result['true_negatives']
            fn = result['false_negatives']

            series['precision'].append(result['precision'])
            series['recall'].append(result['recall'])
            series['f1_score'].append(result['f1_score'])
            series['accuracy'].append(self._safe_div(tp + tn, tp + tn + fp + fn))
            series['balanced_accuracy'].append(
                (self._safe_div(tp, tp + fn) + self._safe_div(tn, tn + fp)) / 2
            )
            series['false_negative_rate'].append(
                result.get('false_negative_rate', self._safe_div(fn, tp + fn))
            )

        return thresholds, series

    def _resolve_save_path(self, save_path: str, default_name: str):
        """Map a caller-supplied save path to ``results/<model>/<timestamp>/<file name>``.

        Returns None when ``save_path`` is falsy (nothing is saved). Only the file
        name of ``save_path`` is kept; the directory is always the per-model,
        per-run results folder, which is created if needed.
        """
        from datetime import datetime

        if not save_path:
            return None
        model_name = self.model_name.split('/')[-1]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = os.path.join('results', model_name, timestamp)
        os.makedirs(save_dir, exist_ok=True)
        return os.path.join(save_dir, os.path.basename(save_path) or default_name)

    @staticmethod
    def _style_axis(ax, ylabel: str, title: str):
        """Apply the common labels, title, grid and legend to one subplot."""
        ax.set_xlabel('Threshold')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)
        ax.legend()

    def _plot_metric_grid(self, panels, save_path: str, default_name: str):
        """Draw one 3-row column of plots per ``(title_prefix, threshold_results)`` panel.

        Rows: precision & recall, all metrics, false negative rate. The figure is
        saved when ``save_path`` is truthy (see _resolve_save_path) and returned.
        """
        import matplotlib.pyplot as plt

        save_path = self._resolve_save_path(save_path, default_name)
        fig, axes = plt.subplots(3, len(panels), figsize=(12 * len(panels), 18), squeeze=False)

        for column, (prefix, threshold_results) in enumerate(panels):
            thresholds, metrics = self._metric_series(threshold_results)
            ax_pr, ax_all, ax_fnr = axes[:, column]

            # Plot 1: Precision-Recall vs Threshold
            ax_pr.plot(thresholds, metrics['precision'], 'b-', label='Precision')
            ax_pr.plot(thresholds, metrics['recall'], 'r-', label='Recall')
            self._style_axis(ax_pr, 'Score', f"{prefix}Precision and Recall vs Threshold")

            # Plot 2: All Metrics vs Threshold
            for metric_name in ('precision', 'recall', 'f1_score', 'accuracy', 'balanced_accuracy'):
                ax_all.plot(thresholds, metrics[metric_name], label=metric_name.replace('_', ' ').title())
            self._style_axis(ax_all, 'Score', f"{prefix}All Metrics vs Threshold")

            # Plot 3: False Negative Rate vs Threshold
            ax_fnr.plot(thresholds, metrics['false_negative_rate'], 'r-', label='False Negative Rate')
            self._style_axis(ax_fnr, 'Rate', f"{prefix}False Negative Rate vs Threshold")

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            logger.info(f"Plots saved to {save_path}")

        return fig

    def plot_threshold_metrics(self, threshold_results: Dict[float, Dict[str, float]], save_path: str = None):
        """
        Plot threshold analysis metrics

        Parameters:
            threshold_results (Dict[float, Dict[str, float]]): Results from find_optimal_threshold
            save_path (str): Optional; when truthy the figure is saved as
                results/<model_name>/<timestamp>/<file name of save_path>

        Returns:
            matplotlib.figure.Figure: The created figure (3 stacked subplots)
        """
        return self._plot_metric_grid([("", threshold_results)], save_path, 'threshold_analysis.png')

    def plot_bidirectional_metrics(self, threshold_results: Dict[str, Dict[float, Dict[str, float]]], save_path: str = None):
        """
        Plot threshold analysis metrics for bidirectional analysis

        Draws the same three plots as plot_threshold_metrics, one column per direction.

        Parameters:
            threshold_results (Dict): Results from find_optimal_threshold_bidirectional
                (keys 'source_to_target' and 'target_to_source')
            save_path (str): Optional; when truthy the figure is saved as
                results/<model_name>/<timestamp>/<file name of save_path>

        Returns:
            matplotlib.figure.Figure: The created figure (3 rows x 2 columns)
        """
        panels = [
            ("Source -> Target: ", threshold_results['source_to_target']),
            ("Target -> Source: ", threshold_results['target_to_source']),
        ]
        return self._plot_metric_grid(panels, save_path, 'bidirectional_threshold_analysis.png')

if __name__ == "__main__":
    # Initialize analyzer with its default model and load it up front
    analyzer = SentenceTransformerAnalyzer()
    analyzer.initialize()

    # File paths for the CM1 dataset
    source_file = "datasets/CM1/CM1-sourceArtifacts.xml"
    target_file = "datasets/CM1/CM1-targetArtifacts.xml"
    answer_set_file = "datasets/CM1/CM1-answerSet.xml"

    # Sweep thresholds; each entry maps a threshold to its evaluation metrics
    threshold_results = analyzer.find_optimal_threshold(
        source_file,
        target_file,
        answer_set_file,
        threshold_range=[0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]
    )

    # Create visualizations (the analyzer chooses the results directory)
    analyzer.plot_threshold_metrics(threshold_results, save_path="threshold_analysis.png")

    # Print results for each threshold
    print("\nThreshold Analysis:")
    print("Threshold | Precision | Recall | F1-Score | Accuracy | Bal Acc | FN Rate")
    print("-" * 75)
    for threshold, metrics in sorted(threshold_results.items()):
        print(f"{threshold:^9.2f} | {metrics['precision']:^9.3f} | {metrics['recall']:^6.3f} | "
              f"{metrics['f1_score']:^8.3f} | {metrics['accuracy']:^8.3f} | "
              f"{metrics['balanced_accuracy']:^7.3f} | {metrics['false_negative_rate']:^7.3f}")

    # Pick the threshold with the best F1 score (first one wins on ties). Its evaluation
    # was already computed during the sweep, so reuse it instead of re-encoding every
    # requirement and re-evaluating the same mapping.
    best_threshold, evaluation = max(threshold_results.items(), key=lambda item: item[1]['f1_score'])

    print(f"\nBest threshold by F1-score: {best_threshold}")

    print("\nBest Threshold Results:")
    print(f"Precision: {evaluation['precision']:.3f}")
    print(f"Recall: {evaluation['recall']:.3f}")
    print(f"F1-Score: {evaluation['f1_score']:.3f}")
    print(f"Accuracy: {evaluation['accuracy']:.3f}")
    print(f"Balanced Accuracy: {evaluation['balanced_accuracy']:.3f}")
    print(evaluation['confusion_matrix'])