# Sentence Transformer Requirements Traceability Analysis
**Evaluation and comparison of sentence transformer models for requirements traceability with threshold optimization, performance metrics, statistical visualization, and Neo4j integration.**

In [None]:
# Cell [0] - Setup and Imports
# Purpose: Import all required libraries and configure environment settings
# Dependencies: pandas, numpy, neo4j, sklearn, matplotlib, seaborn
# Breadcrumbs: Setup -> Imports

import os
import logging
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from neo4j import GraphDatabase
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, fbeta_score,
    matthews_corrcoef, confusion_matrix, balanced_accuracy_score,
    cohen_kappa_score, roc_auc_score, precision_recall_curve, auc,
    confusion_matrix, classification_report
)
import json
from datetime import datetime  # Added global import for datetime

def setup_environment():
    """
    Configure logging and load environment variables

    Loads a ``.env`` file via ``load_dotenv`` and then reads the Neo4j
    connection settings and analysis options from the process environment.

    Returns:
        tuple: (config, logger)
            - config (dict): NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD and
              NEO4J_PROJECT_NAME (None when unset), OPTIMIZATION_METRIC
              ('F1' or 'F2'), SHOW_VISUALIZATION (bool) and MATCH_DIRECTION
              (str, exactly as read from the environment)
            - logger (logging.Logger): logger for this module
    """
    # Configure logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Load environment variables
    load_dotenv()

    # Threshold optimisation only distinguishes 'F1' from everything else
    # (which is optimised as F2), so normalise the setting up front instead of
    # logging a metric that would never actually be used.
    optimization_metric = os.getenv('OPTIMIZATION_METRIC', 'F2').strip().upper()
    if optimization_metric not in ('F1', 'F2'):
        logger.warning(
            f"Unsupported OPTIMIZATION_METRIC '{optimization_metric}'. Falling back to 'F2'."
        )
        optimization_metric = 'F2'

    # Accept the common spellings of a true flag, case-insensitively
    show_visualization = (
        os.getenv('SHOW_VISUALIZATION', 'False').strip().lower() in ('true', '1', 'yes')
    )

    # Neo4j credentials and analysis options from environment variables
    config = {
        'NEO4J_URI': os.getenv('NEO4J_URI'),
        'NEO4J_USER': os.getenv('NEO4J_USER'),
        'NEO4J_PASSWORD': os.getenv('NEO4J_PASSWORD'),
        'NEO4J_PROJECT_NAME': os.getenv('NEO4J_PROJECT_NAME'),
        'OPTIMIZATION_METRIC': optimization_metric,
        'SHOW_VISUALIZATION': show_visualization,
        'MATCH_DIRECTION': os.getenv('MATCH_DIRECTION', 'source_to_target'),
    }

    logger.info(f"Using {config['OPTIMIZATION_METRIC']} score for threshold application in traceability analysis")
    logger.info(f"Visualization display is set to: {config['SHOW_VISUALIZATION']}")
    print(f"Visualization setting: {'Enabled' if config['SHOW_VISUALIZATION'] else 'Disabled'}")
    print(f"Optimization metric: {config['OPTIMIZATION_METRIC']}")

    return config, logger

# Execute setup when imported
CONFIG, logger = setup_environment()
NEO4J_URI = CONFIG['NEO4J_URI']
NEO4J_USER = CONFIG['NEO4J_USER']
NEO4J_PASSWORD = CONFIG['NEO4J_PASSWORD']
NEO4J_PROJECT_NAME = CONFIG['NEO4J_PROJECT_NAME']
OPTIMIZATION_METRIC = CONFIG['OPTIMIZATION_METRIC']
SHOW_VISUALIZATION = CONFIG['SHOW_VISUALIZATION']

In [None]:
# Cell [1] - Neo4j Connection Setup
# Purpose: Create connection to Neo4j database containing traceability data
# Dependencies: neo4j, logging
# Breadcrumbs: Setup -> Database Connection

def create_neo4j_driver(uri=None, user=None, password=None):
    """
    Create, verify and return a Neo4j driver instance

    Parameters:
        uri (str, optional): Neo4j URI. Defaults to NEO4J_URI from environment.
        user (str, optional): Neo4j username. Defaults to NEO4J_USER from environment.
        password (str, optional): Neo4j password. Defaults to NEO4J_PASSWORD from environment.

    Returns:
        neo4j.Driver: Driver whose connectivity to the server has been verified

    Raises:
        ValueError: If no Neo4j URI is configured.
        Exception: Any error raised by the neo4j driver while creating the
            driver or verifying connectivity (logged, then re-raised).
    """
    neo4j_driver = None
    try:
        # Use parameters if provided, otherwise use globals from setup_environment
        _uri = uri if uri is not None else NEO4J_URI
        _user = user if user is not None else NEO4J_USER
        _password = password if password is not None else NEO4J_PASSWORD

        if not _uri:
            raise ValueError("Neo4j URI is not configured: set NEO4J_URI in the environment or pass uri=...")

        neo4j_driver = GraphDatabase.driver(_uri, auth=(_user, _password))
        # Per the neo4j driver docs, creating the driver does not contact the
        # server; verify now so the success message below is true and a bad
        # URI/credentials fail here rather than at the first query.
        neo4j_driver.verify_connectivity()
        logger.info("Successfully connected to Neo4j database")
        return neo4j_driver
    except Exception as e:
        # Release any pooled resources of a driver that failed verification
        if neo4j_driver is not None:
            neo4j_driver.close()
        logger.error(f"Failed to connect to Neo4j: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        raise

# Create Neo4j driver when this module is imported or run directly
driver = create_neo4j_driver()

In [None]:
# Cell [2] - Query SIMILAR_TO Links
# Purpose: Retrieve sentence transformer similarity links from Neo4j
# Dependencies: neo4j, pandas, logging
# Breadcrumbs: Data Acquisition -> Similarity Links

# SOURCE -[SIMILAR_TO]-> TARGET links in the project. The source requirement
# must have an outgoing GROUND_TRUTH link and the target an incoming one.
_SOURCE_TO_TARGET_QUERY = """
MATCH (p:Project {name: $project_name})-[:CONTAINS]->(d:Document)-[:CONTAINS]->(source_req:Requirement)-[r:SIMILAR_TO]->(target_req:Requirement)
WHERE source_req.type = 'SOURCE' AND target_req.type = 'TARGET'
AND EXISTS {
    MATCH (source_req)-[:GROUND_TRUTH]->()
}
AND EXISTS {
    MATCH ()-[:GROUND_TRUTH]->(target_req)
}
RETURN
    p.name as project_name,
    source_req.id as source_id,
    target_req.id as target_id,
    r.model as sentence_transformer_model,
    r.similarity as similarity_score,
    p.name as model_project,
    r.timestamp as timestamp,
    'source_to_target' as direction
"""

# TARGET -[SIMILAR_TO]-> SOURCE links in the project, with the same
# GROUND_TRUTH participation filter as the source-to-target query.
_TARGET_TO_SOURCE_QUERY = """
MATCH (p:Project {name: $project_name})-[:CONTAINS]->(d:Document)-[:CONTAINS]->(target_req:Requirement)-[r:SIMILAR_TO]->(source_req:Requirement)
WHERE target_req.type = 'TARGET' AND source_req.type = 'SOURCE'
AND EXISTS {
    MATCH (source_req)-[:GROUND_TRUTH]->()
}
AND EXISTS {
    MATCH ()-[:GROUND_TRUTH]->(target_req)
}
RETURN
    p.name as project_name,
    source_req.id as source_id,
    target_req.id as target_id,
    r.model as sentence_transformer_model,
    r.similarity as similarity_score,
    p.name as model_project,
    r.timestamp as timestamp,
    'target_to_source' as direction
"""


def _run_similarity_query(driver, query, direction):
    """Run one SIMILAR_TO query for NEO4J_PROJECT_NAME; return its rows as a DataFrame (empty if none)."""
    with driver.session() as session:
        records = session.run(query, project_name=NEO4J_PROJECT_NAME).data()

    if not records:
        logger.warning(f"No SIMILAR_TO links found for project: {NEO4J_PROJECT_NAME} in direction: {direction}")
        return pd.DataFrame()

    links_df = pd.DataFrame(records)
    logger.info(f"Retrieved {len(links_df)} SIMILAR_TO links for project: {NEO4J_PROJECT_NAME} in direction: {direction}")
    return links_df


def _log_model_counts(links_df, direction):
    """Log how many links each sentence transformer model contributes in one direction."""
    if links_df.empty:
        return
    logger.info(f"Sentence transformer models found in SIMILAR_TO links ({direction}):")
    for model, count in links_df['sentence_transformer_model'].value_counts().items():
        logger.info(f"  - {model}: {count} links")


def query_similar_to_links(driver):
    """
    Query SIMILAR_TO links from Neo4j, with direction specified by MATCH_DIRECTION
    Only includes requirements that have GROUND_TRUTH links

    MATCH_DIRECTION is read from the environment ('source_to_target',
    'target_to_source' or 'both'; surrounding whitespace and case are
    ignored). Any other value falls back to 'source_to_target'.

    Parameters:
        driver: Neo4j driver connection

    Returns:
        tuple: (combined_df, source_to_target_df, target_to_source_df)
            - All three are pandas DataFrames containing similarity links.
              A direction that was not queried is an empty DataFrame.
              All three are empty if any error occurs (the error is logged).
    """
    try:
        # Get match direction from environment
        raw_direction = os.getenv('MATCH_DIRECTION', 'source_to_target')
        logger.info(f"Using match direction from .env: '{raw_direction}'")

        # Validate MATCH_DIRECTION, tolerating stray whitespace / capitalisation
        valid_directions = ['source_to_target', 'target_to_source', 'both']
        match_direction = raw_direction.strip().lower()
        if match_direction not in valid_directions:
            logger.warning(f"Invalid MATCH_DIRECTION value: '{raw_direction}'. Using default: 'source_to_target'")
            match_direction = 'source_to_target'
        else:
            logger.info(f"Validating MATCH_DIRECTION parameter: '{match_direction}'")

        logger.info(f"Using MATCH_DIRECTION={match_direction} for analysis")
        print(f"DEBUG - MATCH_DIRECTION value: '{match_direction}'")
        print(f"DEBUG - Valid directions: {valid_directions}")
        print(f"MATCH_DIRECTION set to: '{match_direction}'")

        # Query only the direction(s) that were requested
        if match_direction in ('source_to_target', 'both'):
            source_to_target_df = _run_similarity_query(driver, _SOURCE_TO_TARGET_QUERY, 'source_to_target')
        else:
            source_to_target_df = pd.DataFrame()

        if match_direction in ('target_to_source', 'both'):
            target_to_source_df = _run_similarity_query(driver, _TARGET_TO_SOURCE_QUERY, 'target_to_source')
        else:
            target_to_source_df = pd.DataFrame()

        # Combine results based on MATCH_DIRECTION
        if match_direction == 'both':
            logger.info("Using both source_to_target and target_to_source directions")
            # Skip empty frames: they add no rows and newer pandas versions
            # warn about concatenating them.
            frames = [frame for frame in (source_to_target_df, target_to_source_df) if not frame.empty]
            similar_to_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        elif match_direction == 'source_to_target':
            logger.info("Using only source_to_target direction as specified")
            similar_to_df = source_to_target_df
        else:
            logger.info("Using only target_to_source direction as specified")
            similar_to_df = target_to_source_df

        # Count unique models in the dataset
        _log_model_counts(source_to_target_df, 'source_to_target')
        _log_model_counts(target_to_source_df, 'target_to_source')

        return similar_to_df, source_to_target_df, target_to_source_df

    except Exception as e:
        logger.error(f"Error querying SIMILAR_TO links: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

# Execute the query and get results when imported directly
similar_to_df, similar_to_src_to_tgt_df, similar_to_tgt_to_src_df = query_similar_to_links(driver)

# Function to display information about the retrieved data
def display_similar_to_info():
    """
    Display information about the retrieved SIMILAR_TO links data

    Reads the module-level similar_to_df, similar_to_src_to_tgt_df and
    similar_to_tgt_to_src_df frames. Prints:
        - which direction(s) the data actually came from
        - link counts per direction
        - per-model link counts for each non-empty direction
        - a sample of rows
        - the null count of similarity_score
        - the column dtypes
    """
    if similar_to_df.empty:
        print(f"\nNo SIMILAR_TO links available for project: {NEO4J_PROJECT_NAME}")
        return

    # Derive the label from the frames actually populated, so a
    # target_to_source-only run is not reported as 'BOTH'.
    has_src_to_tgt = not similar_to_src_to_tgt_df.empty
    has_tgt_to_src = not similar_to_tgt_to_src_df.empty
    if has_src_to_tgt and has_tgt_to_src:
        direction_label = 'BOTH'
    elif has_tgt_to_src:
        direction_label = 'ONLY target_to_source'
    else:
        direction_label = 'ONLY source_to_target'

    print(f"\nUsing {direction_label} direction for analysis")
    print("\nDataset Information for Project:", NEO4J_PROJECT_NAME)
    print("=" * 80)
    print(f"SIMILAR_TO links (source to target only): {len(similar_to_src_to_tgt_df)}")
    print(f"SIMILAR_TO links (target to source only): {len(similar_to_tgt_to_src_df)}")
    print(f"SIMILAR_TO links (using only MATCH_DIRECTION={os.getenv('MATCH_DIRECTION', 'source_to_target')}): {len(similar_to_df)}")

    # Display models found in each retrieved direction
    for header, links_df in (
        ('SOURCE_TO_TARGET', similar_to_src_to_tgt_df),
        ('TARGET_TO_SOURCE', similar_to_tgt_to_src_df),
    ):
        if links_df.empty:
            continue
        print(f"\nSentence Transformer Models in {header} SIMILAR_TO links:")
        print("-" * 50)
        for model, count in links_df['sentence_transformer_model'].value_counts().items():
            print(f"{model}: {count} links")

    # Display head of DataFrame for debugging
    print("\nSample of SIMILAR_TO links:")
    print("-" * 50)
    display(similar_to_df.head())

    # Check for null values in similarity_score column
    null_scores = similar_to_df['similarity_score'].isna().sum()
    print(f"\nNull values in similarity_score column: {null_scores} ({null_scores/len(similar_to_df)*100:.2f}%)")

    # Display data types of columns
    print("\nData types of columns:")
    print(similar_to_df.dtypes)

# Run the display function when this cell is executed
display_similar_to_info()

In [None]:
# Cell [3] - Query Ground Truth Links
# Purpose: Retrieve ground truth traceability links from Neo4j
# Dependencies: neo4j, pandas, logging
# Breadcrumbs: Data Acquisition -> Ground Truth

def query_ground_truth_links(driver):
    """
    Fetch every GROUND_TRUTH link stored under the configured Neo4j project.

    The project is chosen by the module-level NEO4J_PROJECT_NAME. Links are
    read from requirements contained in the project's documents.

    Parameters:
        driver: Neo4j driver connection

    Returns:
        pd.DataFrame: One row per GROUND_TRUTH relationship, with project,
            document and requirement identifiers/types plus a constant
            ``ground_truth`` column of 1. An empty DataFrame is returned when
            no link exists or when anything fails (errors are logged, not
            raised).
    """
    try:
        # Query shape kept identical to the version known to work elsewhere:
        # Project -> Document -> source Requirement -[GROUND_TRUTH]-> target
        ground_truth_query = """
        MATCH (p:Project {name: $project_name})-[:CONTAINS]->(d:Document)-[:CONTAINS]->(source:Requirement)-[r:GROUND_TRUTH]->(target:Requirement)
        RETURN 
            p.name as project_name,
            p.description as project_description,
            d.id as document_id,
            source.id as source_id,
            source.type as source_type,
            target.id as target_id,
            target.type as target_type,
            1 as ground_truth
        ORDER BY source.id, target.id DESC
        """

        with driver.session() as session:
            try:
                records = session.run(ground_truth_query, project_name=NEO4J_PROJECT_NAME).data()

                # Guard clause: nothing stored for this project
                if not records:
                    logger.warning(f"No ground truth links found for project: {NEO4J_PROJECT_NAME}")
                    return pd.DataFrame()

                logger.info(f"Retrieved {len(records)} ground truth links using GROUND_TRUTH relationship")
                return pd.DataFrame(records)

            except Exception as exc:
                # Failure while running the query or shaping its result
                logger.error(f"Error executing ground truth query: {exc!s}")
                logger.error("Exception details:", exc_info=True)
                return pd.DataFrame()

    except Exception as exc:
        # Failure while opening/closing the session itself
        logger.error(f"Error querying ground truth links: {exc!s}")
        logger.error("Exception details:", exc_info=True)
        return pd.DataFrame()

# Execute the query and get results
df_ground_truth = query_ground_truth_links(driver=driver)

def _print_ground_truth_metrics(ground_truth_df, title):
    """Print link count, unique source/target counts and link density for a ground-truth frame."""
    unique_sources = ground_truth_df['source_id'].nunique()
    unique_targets = ground_truth_df['target_id'].nunique()
    possible_links = unique_sources * unique_targets
    # Guard against a frame whose id columns are entirely null
    density = len(ground_truth_df) / possible_links if possible_links else 0.0
    print(f"\n{title}:")
    print("-" * 50)
    print(f"Total ground truth links: {len(ground_truth_df)}")
    print(f"Unique source requirements: {unique_sources}")
    print(f"Unique target requirements: {unique_targets}")
    print(f"Link density: {density:.4f}")


def _build_synthetic_ground_truth(similar_to_df, quantile=0.9):
    """
    Label each model's top (1 - quantile) share of pairs by similarity as synthetic positives.

    Returns a DataFrame with the same columns as the real ground-truth query
    (deduplicated on source_id/target_id), or an empty DataFrame if no pair
    qualifies.
    """
    frames = []
    # sort=False keeps models in order of first appearance
    for _, model_df in similar_to_df.groupby('sentence_transformer_model', sort=False):
        scores = pd.to_numeric(model_df['similarity_score'], errors='coerce')
        threshold = scores.quantile(quantile)
        high_sim_pairs = model_df.loc[scores >= threshold]
        if high_sim_pairs.empty:
            continue
        frames.append(pd.DataFrame({
            'project_name': high_sim_pairs['project_name'].to_numpy(),
            'project_description': 'Synthetic ground truth',
            'document_id': 'synthetic',
            'source_id': high_sim_pairs['source_id'].to_numpy(),
            'source_type': 'SOURCE',
            'target_id': high_sim_pairs['target_id'].to_numpy(),
            'target_type': 'TARGET',
            'ground_truth': 1,
        }))

    if not frames:
        return pd.DataFrame()
    return (
        pd.concat(frames, ignore_index=True)
        .drop_duplicates(subset=['source_id', 'target_id'])
        .reset_index(drop=True)
    )


def display_ground_truth_info(df_ground_truth, similar_to_df=None):
    """
    Display information about the retrieved ground truth links

    If no ground truth exists, synthetic ground truth is built from the top
    10% highest similarity scores of each model and returned for testing.

    Parameters:
        df_ground_truth: DataFrame containing ground truth links
        similar_to_df: Optional DataFrame containing similarity links
                      used for creating synthetic ground truth if needed

    Returns:
        pd.DataFrame: The ground truth to use downstream:
            - df_ground_truth itself when it is non-empty
            - otherwise the synthetic ground truth
            - or the original (empty) frame when no synthetic links could be built
    """
    if not df_ground_truth.empty:
        print("\nGround Truth Links for Project:", NEO4J_PROJECT_NAME)
        print("=" * 80)
        display(df_ground_truth.head())
        _print_ground_truth_metrics(df_ground_truth, "Ground Truth Dataset Metrics")
        return df_ground_truth

    print("\nNo ground truth links found. Creating synthetic ground truth for testing:")
    print("-" * 80)

    if similar_to_df is None or similar_to_df.empty:
        print("No SIMILAR_TO links available - cannot create synthetic ground truth.")
        return df_ground_truth

    synthetic_gt = _build_synthetic_ground_truth(similar_to_df)
    if synthetic_gt.empty:
        print("No high-similarity pairs found - synthetic ground truth is empty.")
        return df_ground_truth

    print(f"Created {len(synthetic_gt)} synthetic ground truth links from high similarity scores")
    _print_ground_truth_metrics(synthetic_gt, "Synthetic Ground Truth Dataset Metrics")
    print("\nWARNING: Using synthetic ground truth for testing purposes only!")
    return synthetic_gt

# Display information about ground truth links and keep the ground truth the
# function settled on (real, or synthetic when none exists) for later cells
df_ground_truth = display_ground_truth_info(df_ground_truth, similar_to_df)

In [None]:
# Cell [4] - Create combined dataset for analysis
# Purpose: Merge sentence transformer and ground truth data for analysis
# Dependencies: pandas, logging
# Breadcrumbs: Data Preparation -> Combination

# First, create a combined dataset with ground truth information
def create_combined_dataset(similar_to_df=None, df_ground_truth=None):
    """
    Create a combined dataset with sentence transformer results and ground truth

    Parameters:
        similar_to_df: DataFrame containing sentence transformer similarity results
            (defaults to the module-level ``similar_to_df``)
        df_ground_truth: DataFrame containing ground truth traceability links
            (defaults to the module-level ``df_ground_truth``)

    Returns:
        pd.DataFrame: A copy of the similarity data with these additions:
            - ``ground_truth_traceable`` (bool), when ground truth is available
            - ``model``, a copy of ``sentence_transformer_model``
            - ``similarity_score`` made numeric, with missing or unparseable
              values set to 0
        An empty DataFrame is returned when there is no similarity data or an
        error occurs (errors are logged).
    """
    try:
        # Use provided dataframes or global variables if None
        _similar_to_df = similar_to_df if similar_to_df is not None else globals().get('similar_to_df', pd.DataFrame())
        _df_ground_truth = df_ground_truth if df_ground_truth is not None else globals().get('df_ground_truth', pd.DataFrame())

        # Check if we have the required data
        if _similar_to_df.empty:
            logger.error("No SIMILAR_TO links available to create combined dataset")
            return pd.DataFrame()

        # Start with similar_to_df as the base
        combined_df = _similar_to_df.copy()

        # Add ground truth information if available
        if not _df_ground_truth.empty:
            # Set of (source_id, target_id) pairs for O(1) membership tests
            ground_truth_pairs = set(zip(_df_ground_truth['source_id'], _df_ground_truth['target_id']))
            combined_df['ground_truth_traceable'] = [
                pair in ground_truth_pairs
                for pair in zip(combined_df['source_id'], combined_df['target_id'])
            ]
            logger.info(f"Added ground truth data: {combined_df['ground_truth_traceable'].sum()} true links out of {len(combined_df)} total")
        else:
            logger.warning("No ground truth data available for combined dataset")

        # Expose sentence_transformer_model under the shorter name 'model'
        # (a copy; the original column is kept)
        if 'sentence_transformer_model' in combined_df.columns and 'model' not in combined_df.columns:
            combined_df['model'] = combined_df['sentence_transformer_model']

        # Make similarity_score numeric. Coerce first so unparseable values
        # become NaN instead of leaving the column as object dtype, then treat
        # every missing score as 0.
        if 'similarity_score' in combined_df.columns:
            raw_scores = combined_df['similarity_score']
            numeric_scores = pd.to_numeric(raw_scores, errors='coerce')

            null_count = int(raw_scores.isna().sum())
            if null_count > 0:
                logger.warning(f"Found {null_count} null values in similarity_score column. Replacing with 0.")

            invalid_count = int((numeric_scores.isna() & raw_scores.notna()).sum())
            if invalid_count > 0:
                logger.warning(f"Found {invalid_count} non-numeric values in similarity_score column. Replacing with 0.")

            combined_df['similarity_score'] = numeric_scores.fillna(0)

        return combined_df

    except Exception as e:
        logger.error(f"Error creating combined dataset: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        return pd.DataFrame()

# Create the combined dataset
combined_df = create_combined_dataset()

# Function to display information about the combined dataset
def display_combined_dataset_info(combined_df):
    """
    Display information about the combined dataset for analysis

    Prints:
        - the dataset size and which kinds of data it carries
        - a sample of rows
        - null counts for the key columns
        - the similarity score range
        - the column dtypes
    Key columns that are missing are reported as 'N/A' instead of raising.

    Parameters:
        combined_df: Combined DataFrame with similarity and ground truth data
    """
    if combined_df.empty:
        print("\nCombined dataset is empty - nothing to display")
        return

    columns = combined_df.columns
    has_model = 'model' in columns
    has_ground_truth = 'ground_truth_traceable' in columns

    print("\nCombined Dataset for Analysis")
    print("=" * 80)
    print(f"Size of combined dataset: {len(combined_df)} records")
    print(f"Contains sentence transformer data: {has_model}")
    print(f"Contains ground truth data: {has_ground_truth}")

    if has_model:
        model_count = combined_df['model'].nunique()
        print(f" - Sentence transformer records: {len(combined_df)} across {model_count} models")
    if has_ground_truth:
        print(f" - Ground truth traceable records: {int(combined_df['ground_truth_traceable'].sum())}")

    # Display head of combined dataset for debugging
    print("\nSample of combined dataset:")
    print("-" * 50)
    display(combined_df.head())

    # Null values in key columns ('N/A' when the column is absent)
    print("\nNull values in key columns:")
    for col in ('similarity_score', 'model', 'ground_truth_traceable'):
        count = combined_df[col].isna().sum() if col in columns else 'N/A'
        print(f"  - {col}: {count}")

    # Min/max similarity scores (only meaningful for a numeric column)
    if 'similarity_score' in columns and pd.api.types.is_numeric_dtype(combined_df['similarity_score']):
        print(f"\nSimilarity score range: {combined_df['similarity_score'].min():.4f} to {combined_df['similarity_score'].max():.4f}")

    # Display data types
    print("\nData types of columns:")
    print(combined_df.dtypes)

# Display information about the combined dataset
display_combined_dataset_info(combined_df)

In [None]:
# Cell [5] - Model evaluation and threshold optimization
# Purpose: Evaluate sentence transformer models and find optimal thresholds
# Dependencies: pandas, numpy, sklearn.metrics
# Breadcrumbs: Analysis -> Evaluation -> Threshold Optimization

def _safe_ratio(numerator, denominator):
    """
    Element-wise numerator / denominator as floats, giving 0.0 wherever the
    denominator is 0 (the zero_division=0 convention).
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    result = np.zeros(np.broadcast(numerator, denominator).shape)
    np.divide(numerator, denominator, out=result, where=denominator != 0)
    return result


def evaluate_model_thresholds(df, model_name, score_column='similarity_score', 
                             ground_truth_column='ground_truth_traceable', 
                             optimize_for='F2'):
    """
    Evaluate a model's performance across different thresholds

    Every distinct similarity score of the model is tried as a threshold,
    and a pair is predicted traceable when ``score >= threshold``. When all
    scores are below 1.0, a final threshold of 1.0 is added, which predicts
    nothing.

    Confusion-matrix counts for all thresholds are computed at once from the
    sorted positive/negative scores with ``np.searchsorted``. This costs
    O(n log n) instead of one full metric pass per threshold.

    Parameters:
        df: DataFrame containing model predictions and ground truth
            (must have a 'model' column)
        model_name: Name of the model to evaluate
        score_column: Column containing similarity scores. Missing or
            non-numeric values are treated as 0.
        ground_truth_column: Column containing binary ground truth values
        optimize_for: Metric to optimize for ('F1' or 'F2'). Matching is
            case-insensitive; any value other than 'F1' optimizes F2.

    Returns:
        dict: Dictionary containing evaluation results:
            - {} when the model has no rows or the ground-truth column is missing
            - a summary without ``best_*`` keys when ground truth has a single class
            - ``model_name``/``data_points``/``error`` if evaluation raised
            - otherwise the metrics at the best threshold, plus
              ``threshold_results`` (a DataFrame with one row per threshold)
        Precision, F1, F2 and MCC are 0 where they are undefined, e.g. when
        no pair is predicted positive.
    """
    model_df = pd.DataFrame()
    try:
        # Filter data for this model (read-only, so no copy is needed)
        model_df = df[df['model'] == model_name]

        if model_df.empty:
            logger.warning(f"No data available for model: {model_name}")
            return {}

        if ground_truth_column not in model_df.columns:
            logger.warning(f"Ground truth column '{ground_truth_column}' not found for model: {model_name}")
            return {}

        # Get ground truth labels
        y_true = model_df[ground_truth_column].astype(int).to_numpy()
        if not np.isin(y_true, (0, 1)).all():
            raise ValueError(f"Ground truth column '{ground_truth_column}' must be binary (0/1 or bool)")

        # Coerce scores to numbers first so that a single unparseable value
        # becomes NaN instead of zeroing the whole column, then fill NaN with 0.
        numeric_scores = pd.to_numeric(model_df[score_column], errors='coerce')
        if numeric_scores.isna().any():
            logger.warning(f"Found NaN or non-numeric values in {score_column} for model {model_name}. Filling with 0.")
        scores = numeric_scores.fillna(0).to_numpy(dtype=float)

        n_total = len(y_true)
        n_pos = int(y_true.sum())
        n_neg = n_total - n_pos

        # Debug information
        print(f"  - Data points: {len(model_df)}")
        print(f"  - Positive examples: {n_pos} ({n_pos/n_total*100:.2f}%)")
        print(f"  - Negative examples: {n_neg} ({n_neg/n_total*100:.2f}%)")
        print(f"  - Score range: {scores.min():.4f} to {scores.max():.4f}")

        # If all ground truth values are the same, we can't calculate meaningful metrics
        if n_pos == 0 or n_neg == 0:
            logger.warning(f"Insufficient ground truth variety for model {model_name} - all values are {np.unique(y_true)[0]}")
            return {
                'model_name': model_name,
                'data_points': len(model_df),
                'ground_truth_positive': n_pos,
                'ground_truth_negative': n_neg
            }

        # Candidate thresholds: ascending distinct scores, plus a
        # "predict nothing" threshold of 1.0 only when it lies above every
        # score (avoids a duplicate row when a score is already >= 1.0).
        thresholds = np.unique(scores)
        if thresholds[-1] < 1.0:
            thresholds = np.append(thresholds, 1.0)

        # count(score >= t) = len - count(score < t); searchsorted(side='left')
        # returns count(score < t) on a sorted array.
        pos_sorted = np.sort(scores[y_true == 1])
        neg_sorted = np.sort(scores[y_true == 0])
        tp = n_pos - np.searchsorted(pos_sorted, thresholds, side='left')
        fp = n_neg - np.searchsorted(neg_sorted, thresholds, side='left')
        fn = n_pos - tp
        tn = n_neg - fp

        # Float copies for the metric arithmetic (the MCC denominator product
        # would overflow int64 for large datasets)
        tp_f, fp_f, fn_f, tn_f = (counts.astype(float) for counts in (tp, fp, fn, tn))

        precision = _safe_ratio(tp_f, tp_f + fp_f)
        recall = _safe_ratio(tp_f, tp_f + fn_f)
        tnr = _safe_ratio(tn_f, tn_f + fp_f)   # specificity
        fnr = _safe_ratio(fn_f, fn_f + tp_f)   # miss rate
        accuracy = (tp_f + tn_f) / n_total
        balanced_accuracy = (recall + tnr) / 2
        f1 = _safe_ratio(2 * tp_f, 2 * tp_f + fp_f + fn_f)
        # F-beta with beta=2: (1 + 4) TP / ((1 + 4) TP + 4 FN + FP)
        f2 = _safe_ratio(5 * tp_f, 5 * tp_f + 4 * fn_f + fp_f)
        mcc = _safe_ratio(
            tp_f * tn_f - fp_f * fn_f,
            np.sqrt((tp_f + fp_f) * (tp_f + fn_f) * (tn_f + fp_f) * (tn_f + fn_f)),
        )

        results_df = pd.DataFrame({
            'threshold': thresholds,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'tn': tn,
            'accuracy': accuracy,
            'balanced_accuracy': balanced_accuracy,
            'precision': precision,
            'recall': recall,
            'tnr': tnr,
            'fnr': fnr,
            'f1_score': f1,
            'f2_score': f2,
            'mcc': mcc,
        })

        # Find best threshold based on optimization metric (first maximum wins)
        metric_column = 'f1_score' if str(optimize_for).upper() == 'F1' else 'f2_score'
        best_result = results_df.loc[results_df[metric_column].idxmax()]

        # Return comprehensive results
        return {
            'model_name': model_name,
            'data_points': len(model_df),
            'ground_truth_positive': n_pos,
            'ground_truth_negative': n_neg,
            'best_threshold': best_result['threshold'],
            'best_precision': best_result['precision'],
            'best_recall': best_result['recall'],
            'best_accuracy': best_result['accuracy'],
            'best_balanced_accuracy': best_result['balanced_accuracy'],
            'best_f1': best_result['f1_score'],
            'best_f2': best_result['f2_score'],
            'best_tnr': best_result['tnr'],
            'best_fnr': best_result['fnr'],
            'best_mcc': best_result['mcc'],
            # The row Series is float-typed; report counts as integers
            'best_tp': int(best_result['tp']),
            'best_fp': int(best_result['fp']),
            'best_fn': int(best_result['fn']),
            'best_tn': int(best_result['tn']),
            'optimization_metric': optimize_for,
            'threshold_results': results_df
        }
    except Exception as e:
        logger.error(f"Error evaluating model {model_name}: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        return {
            'model_name': model_name,
            'data_points': len(model_df),
            'error': str(e)
        }

def evaluate_all_models(combined_df, optimization_metric='F2'):
    """
    Evaluate all models in the combined dataset and collect each model's best threshold.

    Parameters:
        combined_df: DataFrame containing model predictions and ground truth; must have
            'ground_truth_traceable' and 'model' columns
        optimization_metric: Metric to optimize thresholds for ('F1' or 'F2', case-insensitive;
            anything other than 'F1' is handled as F2 by evaluate_model_thresholds)

    Returns:
        tuple: (evaluation_results, best_thresholds_df)
            - evaluation_results: List of dictionaries with evaluation results, including
              failed evaluations (those carry an 'error' key instead of metrics)
            - best_thresholds_df: DataFrame with best thresholds for each successfully
              evaluated model, sorted by the optimization metric (descending); empty when
              no model could be evaluated
    """
    # Check if we have the necessary data
    if 'ground_truth_traceable' not in combined_df.columns or 'model' not in combined_df.columns:
        logger.error("Cannot evaluate models: missing ground truth or model data")
        return [], pd.DataFrame()

    # evaluate_model_thresholds compares against the upper-case literal 'F1', so normalize here
    optimization_metric = str(optimization_metric).upper()

    # Get list of all models
    all_models = combined_df['model'].unique()

    # Evaluate each model
    evaluation_results = []

    print(f"\nEvaluating {len(all_models)} sentence transformer models")
    print(f"Optimizing for {optimization_metric} score")
    print("=" * 80)

    for model in all_models:
        print(f"Evaluating model: {model}")
        result = evaluate_model_thresholds(combined_df, model, optimize_for=optimization_metric)

        if not result:
            continue

        evaluation_results.append(result)
        print(f"  - Data points: {result['data_points']}")

        if 'ground_truth_positive' in result:
            # Guard against a model with zero data points
            positive_share = (
                result['ground_truth_positive'] / result['data_points'] * 100
                if result['data_points'] else 0.0
            )
            print(f"  - Ground truth positive: {result['ground_truth_positive']} ({positive_share:.2f}%)")

        if 'error' in result:
            print(f"  - Error: {result['error']}")

        if 'best_threshold' in result:
            print(f"  - Best threshold: {result['best_threshold']:.3f}")
            print(f"  - Confusion Matrix (TP, FP, FN, TN): {result['best_tp']}, {result['best_fp']}, {result['best_fn']}, {result['best_tn']}")
            print(f"  - Accuracy: {result['best_accuracy']:.3f}")
            print(f"  - Balanced Accuracy: {result['best_balanced_accuracy']:.3f}")
            print(f"  - Precision: {result['best_precision']:.3f}")
            print(f"  - Recall/TPR: {result['best_recall']:.3f}")
            print(f"  - Specificity/TNR: {result['best_tnr']:.3f}")
            print(f"  - Miss Rate/FNR: {result['best_fnr']:.3f}")
            print(f"  - F1: {result['best_f1']:.3f}")
            print(f"  - F2: {result['best_f2']:.3f}")
            print(f"  - Matthews Correlation Coefficient: {result['best_mcc']:.3f}")

    # Output column -> key in the per-model result dictionary
    column_keys = [
        ('best_threshold', 'best_threshold'),
        ('accuracy', 'best_accuracy'),
        ('balanced_accuracy', 'best_balanced_accuracy'),
        ('precision', 'best_precision'),
        ('recall', 'best_recall'),
        ('specificity', 'best_tnr'),
        ('miss_rate', 'best_fnr'),
        ('f1_score', 'best_f1'),
        ('f2_score', 'best_f2'),
        ('matthews_corr', 'best_mcc'),
        ('true_positives', 'best_tp'),
        ('false_positives', 'best_fp'),
        ('false_negatives', 'best_fn'),
        ('true_negatives', 'best_tn'),
    ]

    rows = []
    for r in evaluation_results:
        # Failed evaluations have no metrics and cannot contribute a threshold
        if 'best_threshold' not in r:
            continue
        row = {'model_name': r['model_name']}
        row.update({column: r.get(key, np.nan) for column, key in column_keys})
        row['data_points'] = r['data_points']
        row['ground_truth_positive'] = r.get('ground_truth_positive', 0)
        row['ground_truth_negative'] = r.get('ground_truth_negative', 0)
        rows.append(row)

    # Without any successful model the frame would have no metric columns to sort by
    if not rows:
        if evaluation_results:
            logger.warning("No model could be evaluated successfully; best thresholds are unavailable")
        return evaluation_results, pd.DataFrame()

    best_thresholds_df = pd.DataFrame(rows)

    # Sort by the appropriate metric
    sort_col = 'f1_score' if optimization_metric == 'F1' else 'f2_score'
    best_thresholds_df = best_thresholds_df.sort_values(sort_col, ascending=False).reset_index(drop=True)

    print("\nBest Thresholds by Model:")
    print("-" * 80)
    display(best_thresholds_df)

    return evaluation_results, best_thresholds_df

# Run model evaluation if this cell is executed directly
if 'combined_df' in globals() and not combined_df.empty:
    evaluation_results, best_thresholds_df = evaluate_all_models(combined_df, OPTIMIZATION_METRIC)

    # Create visualizations if enabled
    if SHOW_VISUALIZATION and not best_thresholds_df.empty:
        # 1. Create confusion matrix for each model
        print("\nConfusion Matrices:")
        print("-" * 80)

        # Grid of at most 5 columns x 2 rows; models beyond that are not drawn
        n_models = len(best_thresholds_df)
        n_cols = min(5, n_models)
        n_rows = min(2, (n_models + n_cols - 1) // n_cols)
        n_panels = n_rows * n_cols
        if n_models > n_panels:
            print(f"Note: showing confusion matrices for the top {n_panels} of {n_models} models")

        # squeeze=False always returns a 2-D array of Axes, so flattening works for any grid shape
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3.5, n_rows * 3), squeeze=False)
        axes = axes.flatten()

        # zip stops at the last subplot slot when there are more models than panels
        for ax, (_, row) in zip(axes, best_thresholds_df.iterrows()):
            # Rows = actual (Negative, Positive), columns = predicted (Negative, Positive)
            model_cm = np.array([
                [row['true_negatives'], row['false_positives']],
                [row['false_negatives'], row['true_positives']]
            ])

            sns.heatmap(model_cm, annot=True, fmt='g', cmap='Blues',
                        xticklabels=['Negative', 'Positive'],
                        yticklabels=['Negative', 'Positive'], ax=ax,
                        cbar=False)

            model_short_name = row['model_name'].split('/')[-1]
            ax.set_title(f"Project: {NEO4J_PROJECT_NAME}\n{model_short_name}", fontsize=9)
            ax.set_xlabel('Predicted', fontsize=8)
            ax.set_ylabel('Actual', fontsize=8)

        # Hide unused subplots
        for ax in axes[n_models:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        # 2. Create comparative heatmap of all metrics
        print("\nMetrics Comparison Heatmap:")
        print("-" * 80)

        # Select relevant metrics for comparison
        metrics_to_compare = ['accuracy', 'balanced_accuracy', 'precision', 'recall',
                              'specificity', 'f1_score', 'f2_score', 'matthews_corr']

        # Model names as index (shortened to the part after the last '/'), metrics as columns
        comparison_df = best_thresholds_df.set_index('model_name')[metrics_to_compare]
        comparison_df.index = [name.split('/')[-1] for name in comparison_df.index]

        plt.figure(figsize=(12, 8))
        sns.heatmap(comparison_df, annot=True, cmap='YlGnBu', fmt='.3f',
                    linewidths=0.5, cbar_kws={'label': 'Score'})
        plt.title(f'Project: {NEO4J_PROJECT_NAME} - Comparison of Models Across Metrics', fontsize=14)
        plt.tight_layout()
        plt.show()

        # 3. Create grouped bar chart for key metrics
        print("\nKey Metrics Comparison Chart:")
        print("-" * 80)

        key_metrics = ['precision', 'recall', 'f1_score', 'f2_score', 'matthews_corr']
        metric_colors = ['#1A85FF', '#FFC61A', '#00C9A7', '#845EC2', '#D41159']

        fig, ax = plt.subplots(figsize=(14, 8))

        x = np.arange(len(comparison_df.index))
        width = 0.15
        # Center the group of bars on each x tick, whatever the number of metrics
        offsets = (np.arange(len(key_metrics)) - (len(key_metrics) - 1) / 2) * width

        for offset, metric, color in zip(offsets, key_metrics, metric_colors):
            ax.bar(x + offset, comparison_df[metric], width,
                   label=metric.replace('_', ' ').title(), color=color)

        ax.set_xticks(x)
        ax.set_xticklabels(comparison_df.index, rotation=45, ha='right')
        ax.set_ylabel('Score')
        ax.set_title(f'Project: {NEO4J_PROJECT_NAME} - Comparison of Key Metrics Across Models', fontsize=14)
        ax.legend()
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Add project name in the top right corner
        plt.figtext(0.95, 0.95, f"Project: {NEO4J_PROJECT_NAME}",
                    horizontalalignment='right', fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='black'))

        plt.tight_layout()
        plt.show()
else:
    print("\nCannot evaluate model thresholds: missing ground truth or model data")

In [None]:
# Cell [6] - Comprehensive traceability analysis with source-target matrix visualization
# Purpose: Analyze traceability prediction using sentence transformers and visualize results
# Dependencies: pandas, numpy, seaborn, matplotlib
# Breadcrumbs: Analysis -> Traceability Evaluation -> Visualization

def apply_thresholds_and_evaluate(combined_df, model_thresholds=None, default_threshold=None):
    """
    Apply per-model similarity thresholds and label every pair with a confusion category.

    A pair is predicted traceable when its similarity_score is >= the threshold of its model.
    Null similarity scores are treated as 0. The model is read from the
    'sentence_transformer_model' column if present, otherwise from 'model'.

    Parameters:
        combined_df: DataFrame with similarity scores and ground truth
        model_thresholds: Dictionary mapping model names to thresholds (REQUIRED)
        default_threshold: Not used - kept for compatibility; there is no fallback threshold

    Returns:
        DataFrame: Copy of input DataFrame with added 'predicted_traceable' (bool) and
        'confusion_category' ('TP'/'FP'/'FN'/'TN', or 'Unknown' when ground truth is missing)
        columns. If an unexpected error occurs while computing them, the error is logged and
        the copy is returned without the missing columns.

    Raises:
        ValueError: If model_thresholds is None, no model column exists, a model present in
            the data has no threshold, or a model's threshold is None (strict mode: no
            fallback to hard-coded values).
    """
    # Create a copy of the combined DataFrame for evaluation
    combined_traced_eval_df = combined_df.copy()

    # --- Strict validation: these errors are raised to the caller, never swallowed ---
    if model_thresholds is None:
        error_msg = (
            "model_thresholds parameter is required and cannot be None. "
            "Please run Cell [5] - Model evaluation and threshold optimization first to calculate optimal thresholds. "
            "This ensures we use F2-optimized thresholds rather than potentially outdated hard-coded values."
        )
        logger.error(error_msg)
        raise ValueError(error_msg)
    logger.info("Using provided model thresholds (calculated from threshold optimization)")

    if 'sentence_transformer_model' in combined_traced_eval_df.columns:
        model_col = 'sentence_transformer_model'
    elif 'model' in combined_traced_eval_df.columns:
        model_col = 'model'
    else:
        error_msg = "No model column found for row and no default threshold available (strict mode)"
        logger.error(error_msg)
        raise ValueError(error_msg)

    # Resolve one threshold per model present in the data (STRICT: each must be provided)
    threshold_by_model = {}
    for model in combined_traced_eval_df[model_col].unique():
        if model not in model_thresholds:
            error_msg = (
                f"No threshold found for model '{model}' in provided model_thresholds. "
                f"Available models: {list(model_thresholds.keys())}. "
                f"Please ensure all models have calculated thresholds from Cell [5]."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        threshold = model_thresholds[model]
        if threshold is None:
            error_msg = f"Threshold for model '{model}' is None. This should not happen with calculated thresholds."
            logger.error(error_msg)
            raise ValueError(error_msg)
        threshold_by_model[model] = float(threshold)

    try:
        # Make sure there are no None values in similarity scores
        if 'similarity_score' in combined_traced_eval_df.columns:
            null_count = combined_traced_eval_df['similarity_score'].isna().sum()
            if null_count > 0:
                logger.warning(f"Found {null_count} null values in similarity_score column before applying thresholds. Replacing with 0.")
                combined_traced_eval_df['similarity_score'] = combined_traced_eval_df['similarity_score'].fillna(0)

            # Ensure similarity_score is numeric
            if combined_traced_eval_df['similarity_score'].dtype == object:
                try:
                    combined_traced_eval_df['similarity_score'] = pd.to_numeric(combined_traced_eval_df['similarity_score'])
                    logger.info("Converted similarity_score to numeric type for threshold application")
                except Exception as e:
                    logger.error(f"Error converting similarity_score to numeric: {str(e)}")
                    logger.error("Using zeros for similarity_score")
                    combined_traced_eval_df['similarity_score'] = 0

        # Vectorized threshold comparison (row-aligned on the shared index)
        row_thresholds = combined_traced_eval_df[model_col].map(threshold_by_model)
        scores = combined_traced_eval_df['similarity_score'].astype(float)
        combined_traced_eval_df['predicted_traceable'] = scores >= row_thresholds

        # Confusion matrix category per row (TP, FP, FN, TN, or Unknown)
        predicted = combined_traced_eval_df['predicted_traceable'].to_numpy(dtype=bool)
        if 'ground_truth_traceable' in combined_traced_eval_df.columns:
            ground_truth = combined_traced_eval_df['ground_truth_traceable']
            known = ground_truth.notna().to_numpy()
            # bool() applies Python truthiness to each label; unknown rows are masked out by `known`
            actual = ground_truth.map(lambda v: bool(v) if pd.notna(v) else False).to_numpy(dtype=bool)
            categories = np.select(
                [~known, actual & predicted, ~actual & predicted, actual & ~predicted],
                ['Unknown', 'TP', 'FP', 'FN'],
                default='TN',
            )
        else:
            categories = np.full(len(combined_traced_eval_df), 'Unknown')
        combined_traced_eval_df['confusion_category'] = categories

        return combined_traced_eval_df

    except Exception as e:
        # Best effort for unexpected failures: report and hand back what was computed so far
        logger.error(f"Error in traceability evaluation: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        print(f"Error in traceability evaluation: {str(e)}")
        return combined_traced_eval_df

def display_evaluation_summary(combined_traced_eval_df, optimization_metric='F2'):
    """
    Print ground-truth, prediction, and confusion-category distributions.

    Parameters:
        combined_traced_eval_df: DataFrame with traceability predictions and evaluations; needs
            'ground_truth_traceable', 'predicted_traceable' and 'confusion_category' columns
        optimization_metric: Metric used for optimization (for display purposes only)

    Prints a short notice instead of the summary when required columns are missing or the
    DataFrame is empty (percentages would otherwise divide by zero).
    """
    required_columns = ['ground_truth_traceable', 'predicted_traceable', 'confusion_category']
    missing_columns = [col for col in required_columns if col not in combined_traced_eval_df.columns]
    if missing_columns:
        print(f"\nCannot display evaluation summary: missing column(s) {missing_columns}")
        return

    total = len(combined_traced_eval_df)
    if total == 0:
        print("\nCannot display evaluation summary: no rows to summarize")
        return

    # Display summary of ground truth values
    print(f"\nGround Truth Distribution ({optimization_metric} optimized):")
    print("-" * 40)
    gt_counts = combined_traced_eval_df['ground_truth_traceable'].value_counts()
    print(gt_counts)
    print(f"Percentage traceable: {gt_counts.get(True, 0) / total * 100:.2f}%")

    # Display summary of predictions
    print("\nPrediction Distribution:")
    print("-" * 40)
    pred_counts = combined_traced_eval_df['predicted_traceable'].value_counts()
    print(pred_counts)
    print(f"Percentage predicted traceable: {pred_counts.get(True, 0) / total * 100:.2f}%")

    # Display confusion category distribution
    print("\nConfusion Matrix Categories:")
    print("-" * 40)
    confusion_counts = combined_traced_eval_df['confusion_category'].value_counts()
    print(confusion_counts)
    for category, count in confusion_counts.items():
        print(f"{category}: {count} ({count / total * 100:.2f}%)")

def _build_category_matrix(model_df, source_reqs, target_reqs, category_to_value, fill_value):
    """
    Build the source x target matrix of confusion-category codes for one model.

    Parameters:
        model_df: Rows of a single model with 'source_id', 'target_id' and 'confusion_category'
        source_reqs: Ordered source requirement ids (matrix rows)
        target_reqs: Ordered target requirement ids (matrix columns)
        category_to_value: Mapping from category label to numeric code
        fill_value: Code for cells with no row in model_df and for unrecognized categories

    Returns:
        np.ndarray: float matrix of shape (len(source_reqs), len(target_reqs))
    """
    # Position lookups keep filling linear in the number of rows (list.index would be quadratic)
    source_pos = {req: i for i, req in enumerate(source_reqs)}
    target_pos = {req: j for j, req in enumerate(target_reqs)}

    matrix = np.full((len(source_reqs), len(target_reqs)), fill_value, dtype=float)
    for source_id, target_id, category in zip(
        model_df['source_id'], model_df['target_id'], model_df['confusion_category']
    ):
        # If a pair appears more than once, the last row wins
        matrix[source_pos[source_id], target_pos[target_id]] = category_to_value.get(category, fill_value)
    return matrix


def create_source_target_matrix_visualization(combined_traced_eval_df, model_thresholds=None, 
                                             project_name=None, show_visualization=True):
    """
    Create source-target matrix visualizations for traceability analysis.

    For each model, draws a matrix with source requirements as rows and target requirements
    as columns, coloring every cell by its confusion category (TP/FP/FN/TN). Source-target
    pairs that have no row for the model are drawn as 'No data / Unknown' and are excluded
    from the counts and metrics.

    Parameters:
        combined_traced_eval_df: DataFrame with 'model', 'source_id', 'target_id' and
            'confusion_category' columns
        model_thresholds: Dictionary mapping model names to thresholds; if None, it is built
            from a global best_thresholds_df (strict: no hard-coded fallback)
        project_name: Name of the project for visualization titles (defaults to the global
            NEO4J_PROJECT_NAME, or 'Unknown Project')
        show_visualization: Whether to display visualizations; when False nothing is drawn

    Returns:
        dict: model name -> {'matrix', 'source_reqs', 'target_reqs', 'metrics'} for each model
        with at least one classified cell. Returns {} when show_visualization is False or
        when any error occurs (errors are logged and printed, not raised).
    """
    if not show_visualization:
        return {}

    try:
        # Use global project name if not provided
        if project_name is None:
            project_name = globals().get('NEO4J_PROJECT_NAME', 'Unknown Project')

        # STRICT: Require model thresholds to be provided (no hard-coded fallbacks in visualization)
        if model_thresholds is None and 'best_thresholds_df' in globals() and not globals()['best_thresholds_df'].empty:
            # Create a dictionary mapping model names to optimal thresholds
            model_thresholds = dict(zip(
                globals()['best_thresholds_df']['model_name'],
                globals()['best_thresholds_df']['best_threshold']
            ))
        elif model_thresholds is None:
            error_msg = (
                "No model thresholds available for visualization. "
                "Please run Cell [5] - Model evaluation and threshold optimization first to calculate "
                "optimal thresholds, or provide model_thresholds parameter. "
                "This ensures visualizations use data-driven thresholds rather than hard-coded values."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        print("\nCreating source-target requirement matrix visualization...")

        all_models = combined_traced_eval_df['model'].unique()

        # STRICT: every model needs a threshold for its title; check before any figure is opened
        for model_name in all_models:
            if model_name not in model_thresholds:
                error_msg = f"Model '{model_name}' not found in model_thresholds for visualization title. Available: {list(model_thresholds.keys())}"
                logger.error(error_msg)
                raise ValueError(error_msg)

        from matplotlib.colors import LinearSegmentedColormap
        from matplotlib.patches import Patch

        # Color scheme: TP (good), FP (ok), FN (not great), TN (neutral)
        color_palette = {
            'TP': '#1A85FF',  # Good - Blue
            'FP': '#FFC61A',  # Okay - Yellow/Gold
            'FN': '#D41159',  # Not Great - Magenta/Red
            'TN': '#CCCCCC',  # Neutral - Light Gray
            'Unknown': '#FFFFFF'  # White for unknown / no data
        }
        category_order = ['TP', 'FP', 'FN', 'TN', 'Unknown']
        category_to_value = {category: code for code, category in enumerate(category_order)}
        unknown_value = category_to_value['Unknown']

        # Exactly one color per category code 0..4 (used with vmin=0, vmax=4)
        cmap = LinearSegmentedColormap.from_list(
            'confusion_cmap', [color_palette[c] for c in category_order], N=len(category_order)
        )

        viz_data = {}

        for model_name in all_models:
            model_df = combined_traced_eval_df[combined_traced_eval_df['model'] == model_name]

            source_reqs = sorted(model_df['source_id'].unique())
            target_reqs = sorted(model_df['target_id'].unique())

            is_large_matrix = len(source_reqs) > 50 or len(target_reqs) > 50
            if is_large_matrix:
                print(f"Warning: Matrix for {model_name} would be too large ({len(source_reqs)}x{len(target_reqs)}). Reducing tick frequency...")

            # Cells start as Unknown so pairs without data are not counted as True Positives
            matrix = _build_category_matrix(model_df, source_reqs, target_reqs,
                                            category_to_value, unknown_value)

            fig_width = min(20, max(10, len(target_reqs) * 0.15))
            fig_height = min(16, max(8, len(source_reqs) * 0.15))
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            # 'nearest' keeps each cell one category color; smoothing would blend codes into wrong categories
            ax.imshow(matrix, cmap=cmap, vmin=0, vmax=len(category_order) - 1,
                      aspect='auto', interpolation='nearest')

            if is_large_matrix:
                # Label roughly every tenth requirement
                x_step = max(1, len(target_reqs) // 10)
                y_step = max(1, len(source_reqs) // 10)
                x_ticks = np.arange(0, len(target_reqs), x_step)
                y_ticks = np.arange(0, len(source_reqs), y_step)
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)
                ax.set_xticklabels([target_reqs[i] for i in x_ticks], rotation=90)
                ax.set_yticklabels([source_reqs[i] for i in y_ticks])
            else:
                ax.set_xticks(np.arange(len(target_reqs)))
                ax.set_yticks(np.arange(len(source_reqs)))
                ax.set_xticklabels(target_reqs, rotation=90)
                ax.set_yticklabels(source_reqs)

            # Minor ticks on cell borders; hidden as ticks but used for the cell grid
            ax.set_xticks(np.arange(-.5, len(target_reqs), 1), minor=True)
            ax.set_yticks(np.arange(-.5, len(source_reqs), 1), minor=True)
            ax.tick_params(which='minor', length=0)
            if not is_large_matrix:
                # On large matrices the cell borders would cover the cells themselves
                ax.grid(which='minor', color='white', linewidth=0.5)

            for spine in ax.spines.values():
                spine.set_visible(False)

            legend_elements = [
                Patch(facecolor=color_palette['TP'], label='True Positive'),
                Patch(facecolor=color_palette['FP'], label='False Positive'),
                Patch(facecolor=color_palette['FN'], label='False Negative'),
                Patch(facecolor=color_palette['TN'], label='True Negative')
            ]
            if np.any(matrix == unknown_value):
                legend_elements.append(
                    Patch(facecolor=color_palette['Unknown'], edgecolor='black', label='No data / Unknown')
                )
            ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1), 
                      fontsize=12, frameon=True, edgecolor='black')

            model_short_name = model_name.split('/')[-1] if '/' in model_name else model_name
            ax.set_title(f'Project: {project_name} - {model_short_name}\nThreshold: {model_thresholds[model_name]:.3f}',
                         fontsize=14, fontweight='bold', pad=20)
            ax.set_xlabel('Target Requirements', fontsize=12, labelpad=10)
            ax.set_ylabel('Source Requirements', fontsize=12, labelpad=10)

            # Count classified cells (Unknown cells are excluded)
            tp_count = int(np.sum(matrix == category_to_value['TP']))
            fp_count = int(np.sum(matrix == category_to_value['FP']))
            fn_count = int(np.sum(matrix == category_to_value['FN']))
            tn_count = int(np.sum(matrix == category_to_value['TN']))

            print(f"Category counts for {model_short_name}:")
            print(f"TP: {int(tp_count)}, FP: {int(fp_count)}, FN: {int(fn_count)}, TN: {int(tn_count)}")

            total = tp_count + fp_count + fn_count + tn_count
            if total > 0:
                accuracy = (tp_count + tn_count) / total
                precision = tp_count / (tp_count + fp_count) if (tp_count + fp_count) > 0 else 0
                recall = tp_count / (tp_count + fn_count) if (tp_count + fn_count) > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
                f2 = 5 * precision * recall / (4 * precision + recall) if (4 * precision + recall) > 0 else 0

                stats_text = (f"TP: {int(tp_count)} ({tp_count/total*100:.1f}%) | FP: {int(fp_count)} ({fp_count/total*100:.1f}%)\n"
                              f"FN: {int(fn_count)} ({fn_count/total*100:.1f}%) | TN: {int(tn_count)} ({tn_count/total*100:.1f}%)\n"
                              f"Acc: {accuracy:.3f} | Prec: {precision:.3f} | Rec: {recall:.3f}\n"
                              f"F1: {f1:.3f} | F2: {f2:.3f}")

                # Text box at the bottom right of the figure
                fig.text(0.95, 0.02, stats_text, horizontalalignment='right',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='black'),
                         fontsize=10)

                print(f"\nConfusion Matrix Statistics for {model_short_name} (Project: {project_name}):")
                print(f"True Positives: {int(tp_count)} ({tp_count/total*100:.2f}%)")
                print(f"False Positives: {int(fp_count)} ({fp_count/total*100:.2f}%)")
                print(f"False Negatives: {int(fn_count)} ({fn_count/total*100:.2f}%)")
                print(f"True Negatives: {int(tn_count)} ({tn_count/total*100:.2f}%)")

                print(f"Accuracy: {accuracy:.3f}")
                print(f"Precision: {precision:.3f}")
                print(f"Recall: {recall:.3f}")
                print(f"F1 Score: {f1:.3f}")
                print(f"F2 Score: {f2:.3f}")

                viz_data[model_name] = {
                    'matrix': matrix,
                    'source_reqs': source_reqs,
                    'target_reqs': target_reqs,
                    'metrics': {
                        'tp': int(tp_count),
                        'fp': int(fp_count),
                        'fn': int(fn_count),
                        'tn': int(tn_count),
                        'accuracy': accuracy,
                        'precision': precision,
                        'recall': recall,
                        'f1': f1,
                        'f2': f2
                    }
                }

            fig.tight_layout()
            plt.show()

        return viz_data

    except Exception as e:
        logger.error(f"Error creating visualizations: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        print(f"Error creating visualizations: {str(e)}")
        return {}

# Build the model -> optimal-threshold lookup from Cell [5]'s results.
# STRICT: there is deliberately no hard-coded fallback; stop the notebook instead.
if 'best_thresholds_df' not in globals() or best_thresholds_df.empty:
    error_msg = (
        "No model evaluation results available (best_thresholds_df is empty or not found). "
        "Please run Cell [5] - Model evaluation and threshold optimization first to calculate "
        f"{OPTIMIZATION_METRIC}-optimized thresholds for all models. This ensures we use data-driven "
        "thresholds rather than potentially outdated hard-coded values."
    )
    logger.error(error_msg)
    raise ValueError(error_msg)

model_thresholds = {
    name: threshold
    for name, threshold in zip(best_thresholds_df['model_name'], best_thresholds_df['best_threshold'])
}
logger.info(f"Using optimal thresholds from model evaluation (optimized for {OPTIMIZATION_METRIC})")

# Predict traceability with those thresholds and label each pair TP/FP/FN/TN
combined_traced_eval_df = apply_thresholds_and_evaluate(combined_df, model_thresholds)

# Print the ground-truth / prediction / confusion-category distributions
display_evaluation_summary(combined_traced_eval_df, OPTIMIZATION_METRIC)

# Per-model source x target matrices, only when visualization is enabled
if SHOW_VISUALIZATION:
    viz_data = create_source_target_matrix_visualization(
        combined_traced_eval_df,
        model_thresholds=model_thresholds,
        project_name=NEO4J_PROJECT_NAME,
        show_visualization=SHOW_VISUALIZATION,
    )

In [None]:
# Cell [7] - Analysis Summary and Conclusions
# Purpose: Summarize findings and draw conclusions about traceability prediction
# Dependencies: pandas
# Breadcrumbs: Analysis -> Summary

def create_analysis_summary(best_thresholds_df=None, optimization_metric=None, project_name=None):
    """
    Create a summary of model evaluation and traceability analysis

    Arguments left as None fall back to the notebook globals
    ``best_thresholds_df``, ``OPTIMIZATION_METRIC`` (default 'F2') and
    ``NEO4J_PROJECT_NAME`` (default 'Unknown Project').

    Parameters:
        best_thresholds_df: DataFrame containing model evaluation results, one row
            per model. Columns read: 'model_name', 'best_threshold', 'precision',
            'recall', 'f1_score', 'f2_score', 'matthews_corr', 'balanced_accuracy',
            'true_positives', 'false_positives', 'false_negatives', 'true_negatives'.
        optimization_metric: Metric used for optimization ('F1' or 'F2',
            case-insensitive). Any value other than 'F2' ranks models by F1.
        project_name: Name of the project being analyzed

    Returns:
        dict: Dictionary containing summary information. When no evaluation
        results are available, 'has_model_evaluation' is False and only the
        default keys are populated.
    """
    # Use global variables if parameters not provided
    _best_thresholds_df = best_thresholds_df if best_thresholds_df is not None else globals().get('best_thresholds_df')
    _optimization_metric = optimization_metric if optimization_metric is not None else globals().get('OPTIMIZATION_METRIC', 'F2')
    _project_name = project_name if project_name is not None else globals().get('NEO4J_PROJECT_NAME', 'Unknown Project')

    # Normalize case so 'f2' and 'F2' select the same ranking metric
    _optimization_metric = str(_optimization_metric).upper()

    # Default result; every key a consumer reads is present even without data
    summary = {
        'project_name': _project_name,
        'optimization_metric': _optimization_metric,
        'has_model_evaluation': False,
        'models_evaluated': 0,
        'best_model': None,
        'best_model_threshold': None,
        'best_model_metrics': {},
        'model_family_comparison': {},
        'recommendations': [],
        'limitations': [],
        'sort_metric': None,
    }

    # Nothing to summarize without model evaluation results
    if _best_thresholds_df is None or _best_thresholds_df.empty:
        return summary

    summary['has_model_evaluation'] = True
    summary['models_evaluated'] = len(_best_thresholds_df)

    # Rank models by the optimization metric (F2 favours recall, F1 is balanced)
    sort_metric = 'f2_score' if _optimization_metric == 'F2' else 'f1_score'
    summary['sort_metric'] = sort_metric

    # Select the best row by position rather than by index label, so duplicate
    # index labels (e.g. from concatenated result frames) still yield one row.
    best_pos = _best_thresholds_df[sort_metric].reset_index(drop=True).idxmax()
    best_model = _best_thresholds_df.iloc[best_pos]
    summary['best_model'] = best_model['model_name']
    summary['best_model_threshold'] = float(best_model['best_threshold'])

    # Store best model metrics
    summary['best_model_metrics'] = {
        'threshold': float(best_model['best_threshold']),
        'precision': float(best_model['precision']),
        'recall': float(best_model['recall']),
        'f1_score': float(best_model['f1_score']),
        'f2_score': float(best_model['f2_score']),
        'matthews_corr': float(best_model['matthews_corr']),
        'balanced_accuracy': float(best_model['balanced_accuracy']),
        'confusion_matrix': {
            'tp': int(best_model['true_positives']),
            'fp': int(best_model['false_positives']),
            'fn': int(best_model['false_negatives']),
            'tn': int(best_model['true_negatives'])
        }
    }

    # Model family analysis: substring match on the model name, case-insensitive
    # for every family (a name may belong to more than one family).
    family_patterns = {'MPNet': 'mpnet', 'MiniLM': 'minilm', 'BERT': 'bert'}
    names = _best_thresholds_df['model_name']
    model_families = {}
    for family, pattern in family_patterns.items():
        members = _best_thresholds_df[names.str.contains(pattern, case=False, na=False, regex=False)]
        if not members.empty:
            model_families[family] = float(members[sort_metric].mean())

    summary['model_family_comparison'] = model_families

    if model_families:
        best_family = max(model_families.items(), key=lambda x: x[1])
        summary['best_model_family'] = best_family[0]
        summary['best_model_family_score'] = best_family[1]

    # Threshold analysis (NaN when there are fewer than two models or no variance)
    threshold_corr = _best_thresholds_df[['best_threshold', sort_metric]].corr().iloc[0, 1]
    summary['threshold_correlation'] = float(threshold_corr)

    # Precision-recall tradeoff analysis
    prec_recall_corr = _best_thresholds_df[['precision', 'recall']].corr().iloc[0, 1]
    summary['precision_recall_correlation'] = float(prec_recall_corr)

    # Average model performance
    summary['average_performance'] = {
        'precision': float(_best_thresholds_df['precision'].mean()),
        'precision_std': float(_best_thresholds_df['precision'].std()),
        'recall': float(_best_thresholds_df['recall'].mean()),
        'recall_std': float(_best_thresholds_df['recall'].std()),
        'f1_score': float(_best_thresholds_df['f1_score'].mean()),
        'f1_score_std': float(_best_thresholds_df['f1_score'].std()),
        'f2_score': float(_best_thresholds_df['f2_score'].mean()),
        'f2_score_std': float(_best_thresholds_df['f2_score'].std()),
        'matthews_corr': float(_best_thresholds_df['matthews_corr'].mean()),
        'matthews_corr_std': float(_best_thresholds_df['matthews_corr'].std())
    }

    # Top 3 models by the optimization metric
    top3 = _best_thresholds_df.sort_values(sort_metric, ascending=False).head(3)
    summary['top_models'] = [
        {
            'model_name': row['model_name'],
            'score': float(row[sort_metric]),
            'threshold': float(row['best_threshold'])
        }
        for _, row in top3.iterrows()
    ]

    # Recommendations
    recommendations = [
        f"Use {best_model['model_name']} for requirement traceability prediction",
        f"Apply a similarity threshold of {best_model['best_threshold']:.3f}",
        f"Expected performance: {best_model[sort_metric]:.3f} {sort_metric.replace('_', ' ').title()}"
    ]

    # Add precision-recall guidance when the two differ noticeably
    if abs(best_model['precision'] - best_model['recall']) > 0.1:
        if best_model['precision'] > best_model['recall']:
            recommendations.append(f"Note that precision ({best_model['precision']:.3f}) is higher than recall ({best_model['recall']:.3f}), meaning the model misses some true links but has fewer false positives")
        else:
            recommendations.append(f"Note that recall ({best_model['recall']:.3f}) is higher than precision ({best_model['precision']:.3f}), meaning the model captures most true links but generates more false positives")

    summary['recommendations'] = recommendations

    # Limitations
    summary['limitations'] = [
        "Results are specific to the current dataset and domain",
        "Ground truth may contain imperfections that affect evaluation",
        "Similarity-based approaches alone may miss some semantic connections",
        "Threshold optimization is based on a single metric and may not be optimal for all use cases"
    ]

    return summary

def display_analysis_summary(summary=None, show_visualization=False):
    """
    Display summary of traceability analysis results

    Parameters:
        summary: Dictionary with analysis summary (from create_analysis_summary).
            If None, a summary is built from the notebook globals.
        show_visualization: Whether to display a grouped bar chart of the top
            models (ranked by the summary's sort metric, read from the global
            best_thresholds_df)
    """
    if summary is None:
        # Create summary if not provided
        summary = create_analysis_summary()

    if not summary['has_model_evaluation']:
        print("\nInsufficient evaluation data to draw conclusions.")
        print("Please run model evaluation in Cell 5 to generate performance metrics.")
        return

    sort_metric = summary['sort_metric']
    sort_label = sort_metric.replace('_', ' ').title()
    best_metrics = summary['best_model_metrics']
    cm = best_metrics['confusion_matrix']

    print(f"\nTRACEABILITY ANALYSIS SUMMARY FOR PROJECT: {summary['project_name']}")
    print("=" * 80)
    print(f"Optimization metric: {summary['optimization_metric']}")
    print(f"Total models evaluated: {summary['models_evaluated']}")
    print(f"\nBest performing model: {summary['best_model']}")
    print(f"  - Threshold: {best_metrics['threshold']:.3f}")
    print(f"  - {sort_label}: {best_metrics[sort_metric]:.3f}")
    print(f"  - Precision: {best_metrics['precision']:.3f}")
    print(f"  - Recall: {best_metrics['recall']:.3f}")
    print(f"  - F1 Score: {best_metrics['f1_score']:.3f}")
    print(f"  - F2 Score: {best_metrics['f2_score']:.3f}")
    print(f"  - Matthews Correlation: {best_metrics['matthews_corr']:.3f}")
    print(f"  - Balanced Accuracy: {best_metrics['balanced_accuracy']:.3f}")
    print(f"  - Confusion Matrix Stats:")
    print(f"    * True Positives: {cm['tp']}")
    print(f"    * False Positives: {cm['fp']}")
    print(f"    * False Negatives: {cm['fn']}")
    print(f"    * True Negatives: {cm['tn']}")

    # Average performance across all models
    print("\nAverage model performance:")
    avg = summary['average_performance']
    print(f"  - Precision: {avg['precision']:.3f} (±{avg['precision_std']:.3f})")
    print(f"  - Recall: {avg['recall']:.3f} (±{avg['recall_std']:.3f})")
    print(f"  - F1 Score: {avg['f1_score']:.3f} (±{avg['f1_score_std']:.3f})")
    print(f"  - F2 Score: {avg['f2_score']:.3f} (±{avg['f2_score_std']:.3f})")
    print(f"  - Matthews Correlation: {avg['matthews_corr']:.3f} (±{avg['matthews_corr_std']:.3f})")

    # 'top_models' holds at most the three best models, so this is the spread
    # among them -- not the difference between the best and worst of all models.
    top_models = summary['top_models']
    if top_models:
        print(f"\nPerformance range among top {len(top_models)} models ({sort_metric}):")
        print(f"  - Best: {top_models[0]['score']:.3f} ({top_models[0]['model_name']})")
        print(f"  - Lowest: {top_models[-1]['score']:.3f} ({top_models[-1]['model_name']})")
        print(f"  - Difference: {top_models[0]['score'] - top_models[-1]['score']:.3f}")

    # Top 3 models
    print("\nTop 3 models:")
    for i, model in enumerate(top_models[:3], 1):
        print(f"  {i}. {model['model_name']}: {model['score']:.3f} {sort_label}")

    print(f"\nCONCLUSIONS FOR PROJECT {summary['project_name']}:")

    # Draw conclusions about model types
    for family, score in summary['model_family_comparison'].items():
        print(f"- {family} models average {sort_metric}: {score:.3f}")

    # Identify if any model family consistently performs better
    if 'best_model_family' in summary:
        print(f"- {summary['best_model_family']} models tend to perform best overall for this dataset")

    # Relationship between the chosen threshold and the score
    print(f"- Correlation between threshold and {sort_metric}: {summary['threshold_correlation']:.3f}")

    if abs(summary['threshold_correlation']) > 0.3:
        if summary['threshold_correlation'] > 0:
            print("- Higher thresholds tend to produce better results")
        else:
            print("- Lower thresholds tend to produce better results")
    else:
        print("- No strong correlation between threshold values and performance")

    # Analysis of precision vs recall tradeoff
    print(f"- Precision-Recall tradeoff correlation: {summary['precision_recall_correlation']:.3f}")

    if summary['precision_recall_correlation'] < -0.5:
        print("- Strong precision-recall tradeoff observed across models")
    elif summary['precision_recall_correlation'] < -0.3:
        print("- Moderate precision-recall tradeoff observed across models")
    else:
        print("- Minimal precision-recall tradeoff observed across models")

    # Final recommendations
    print(f"\nRECOMMENDATIONS FOR PROJECT {summary['project_name']}:")
    for i, rec in enumerate(summary['recommendations'], 1):
        print(f"{i}. {rec}")

    # Limitations
    print(f"\nLIMITATIONS FOR PROJECT {summary['project_name']}:")
    for i, lim in enumerate(summary['limitations'], 1):
        print(f"{i}. {lim}")

    # Create visualization if enabled
    if show_visualization and 'best_thresholds_df' in globals() and not globals()['best_thresholds_df'].empty:
        fig = None
        try:
            # Create a bar chart comparing top models
            fig = plt.figure(figsize=(12, 6))

            # Rank by the optimization metric so the chart really shows the top
            # 5 models (head() alone would just take the first 5 rows)
            all_models_df = globals()['best_thresholds_df']
            top_models_df = all_models_df.sort_values(sort_metric, ascending=False).head(5)

            # Prepare data for visualization
            model_names = [m.split('/')[-1] if '/' in m else m for m in top_models_df['model_name']]
            metrics = ['precision', 'recall', 'f1_score', 'f2_score', 'matthews_corr']
            metric_labels = ['Precision', 'Recall', 'F1 Score', 'F2 Score', 'Matthews Corr']

            # Define colors
            colors = ['#1A85FF', '#FFC61A', '#00C9A7', '#845EC2', '#D41159']

            # Plot each metric as a group of bars centred on the model's tick
            x = np.arange(len(model_names))
            width = 0.15

            for i, metric in enumerate(metrics):
                plt.bar(x + (i - 2) * width, top_models_df[metric], width, label=metric_labels[i], color=colors[i])

            # Customize plot
            plt.xlabel('Model', fontsize=12)
            plt.ylabel('Score', fontsize=12)
            plt.title(f'Project: {summary["project_name"]} - Top Models Comparison', fontsize=14)
            plt.xticks(x, model_names, rotation=45, ha='right')
            plt.legend(loc='upper right')
            plt.grid(axis='y', linestyle='--', alpha=0.7)

            # Add optimization metric in subtitle
            plt.figtext(0.5, 0.01, f"Optimized for {summary['optimization_metric']} score",
                        ha='center', fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.8, edgecolor='black'))

            plt.tight_layout()
            plt.show()

        except Exception as e:
            logger.warning(f"Could not create summary visualization: {str(e)}")
            print(f"Could not create summary visualization: {str(e)}")
            # Don't leave a half-built figure open
            if fig is not None:
                plt.close(fig)

# Check if we have model evaluation results available from Cell 5
if 'best_thresholds_df' in globals() and not best_thresholds_df.empty:
    # Create and display analysis summary
    summary = create_analysis_summary()
    display_analysis_summary(summary, SHOW_VISUALIZATION)
else:
    # Model evaluation / threshold optimization lives in Cell 5 (see Cell 6's error message)
    print("\nInsufficient evaluation data to draw conclusions.")
    print("Please run model evaluation in Cell 5 to generate performance metrics.")

In [None]:
# Cell [8] - Metrics Comparison Whisker Chart with Heatmap Stats
# Purpose: Create box plots with statistical heatmap for metrics across models
# Dependencies: pandas, matplotlib, seaborn
# Breadcrumbs: Visualization -> Metrics Distribution

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

def create_metrics_whisker_plot(model_metrics_df=None, project_name=None, show_visualization=None):
    """
    Create whisker plots showing the distribution of metrics across models with a statistical heatmap

    Parameters:
        model_metrics_df: DataFrame with a 'model_name' column and one column per
            metric (default: the global best_thresholds_df)
        project_name: Name of the project for visualization titles
            (default: the global NEO4J_PROJECT_NAME)
        show_visualization: Whether to display the figure. None (the default)
            falls back to the global SHOW_VISUALIZATION, or True if that is unset.

    Returns:
        dict: 'plot_data' (long-format model/metric/score rows),
        'stats_df_numeric' and 'stats_df_formatted' (per-metric summary
        statistics). Empty dict when there is no data or plotting fails.
    """
    fig = None
    try:
        # Use global variables if parameters not provided
        _model_metrics_df = model_metrics_df if model_metrics_df is not None else globals().get('best_thresholds_df', pd.DataFrame())
        _project_name = project_name if project_name is not None else globals().get('NEO4J_PROJECT_NAME', 'Unknown Project')
        _show_visualization = show_visualization if show_visualization is not None else globals().get('SHOW_VISUALIZATION', True)

        if _model_metrics_df is None or _model_metrics_df.empty:
            print("No model metrics data available. Please run model evaluation in Cell 5 first.")
            return {}

        # Metrics we want to visualize, mapped to human-readable names
        metric_names = {
            'accuracy': 'Accuracy',
            'balanced_accuracy': 'Balanced Accuracy',
            'precision': 'Precision',
            'recall': 'Recall',
            'f1_score': 'F1 Score',
            'f2_score': 'F2 Score',
            'matthews_corr': 'Matthews Correlation'
        }

        # Plot only the metrics actually present instead of failing on the first missing one
        metrics_to_plot = [m for m in metric_names if m in _model_metrics_df.columns]
        missing_metrics = [m for m in metric_names if m not in _model_metrics_df.columns]
        if missing_metrics:
            print(f"Skipping metrics missing from model results: {', '.join(missing_metrics)}")
        if not metrics_to_plot:
            print("None of the expected metric columns are present. Nothing to plot.")
            return {}

        # Reshape data from wide to long format for easier plotting with seaborn
        plot_data = pd.melt(
            _model_metrics_df,
            id_vars=['model_name'],
            value_vars=metrics_to_plot,
            var_name='metric',
            value_name='score'
        )

        # Map metric names to their human-readable versions
        plot_data['metric'] = plot_data['metric'].map(metric_names)

        # Calculate metric statistics for the heatmap (keys of Series.describe())
        stats_list = ['min', '25%', 'mean', '50%', '75%', 'max', 'std']
        stats_names = ['Min', 'Q1', 'Mean', 'Median', 'Q3', 'Max', 'Std Dev']

        stats_values = {}
        for metric in metrics_to_plot:
            stats = _model_metrics_df[metric].describe()
            stats_values[metric_names[metric]] = [stats[stat] for stat in stats_list]

        # Numeric DataFrame drives the heatmap colouring
        stats_df_numeric = pd.DataFrame(stats_values, index=stats_names)

        # Parallel DataFrame of pre-formatted strings for the cell annotations
        formatted_data = {}
        for col in stats_df_numeric.columns:
            formatted_data[col] = []
            for idx in stats_names:
                value = stats_df_numeric.loc[idx, col]
                if idx == 'Std Dev' and value < 0.001:
                    # Very small spreads would print as 0.000; use scientific notation
                    formatted_data[col].append(f"{value:.2e}")
                else:
                    formatted_data[col].append(f"{value:.3f}")

        stats_df_formatted = pd.DataFrame(formatted_data, index=stats_names)

        # Figure with two stacked subplots (box plot above, heatmap below)
        fig = plt.figure(figsize=(14, 12))
        gs = plt.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)

        # Box plot subplot
        ax_box = fig.add_subplot(gs[0])

        # Create box plot with consistent color
        sns.boxplot(
            x='metric',
            y='score',
            data=plot_data,
            ax=ax_box,
            color='khaki',  # Light yellow color for all boxes
            width=0.5
        )

        # Add individual data points
        sns.stripplot(
            x='metric',
            y='score',
            data=plot_data,
            jitter=True,
            color='black',
            marker='o',
            alpha=0.5,
            size=4,
            ax=ax_box
        )

        # Customize the box plot
        ax_box.set_title(f'Project: {_project_name} - Distribution of Performance Metrics', fontsize=14)
        ax_box.set_xlabel('')  # Remove x-axis label as it's clear from the plot
        ax_box.set_ylabel('Score', fontsize=12)
        ax_box.grid(axis='y', linestyle='--', alpha=0.7)
        ax_box.set_ylim(0, 1.0)  # Metrics are typically in range [0, 1]

        # Add a horizontal line at y=0.5 as a reference
        ax_box.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

        # Add legend indicating data source
        from matplotlib.patches import Patch, Rectangle
        legend_elements = [Patch(facecolor='khaki', edgecolor='black', alpha=0.7, label='Sentence Transformer/TF-IDF')]
        ax_box.legend(handles=legend_elements, loc='upper right')

        # Heatmap subplot
        ax_heatmap = fig.add_subplot(gs[1])

        # Mask NaN statistics (e.g. Std Dev of a single model)
        mask = np.isnan(stats_df_numeric.values)

        # Create the heatmap
        cmap = sns.light_palette("gold", as_cmap=True)  # Light yellow color palette to match box plots
        sns.heatmap(
            stats_df_numeric,
            annot=stats_df_formatted.values,  # Use formatted strings for annotations
            fmt="",  # No additional formatting needed as we pre-formatted the values
            cmap=cmap,
            linewidths=0.5,
            linecolor='lightgray',
            cbar=False,  # No color bar needed
            ax=ax_heatmap,
            mask=mask,
            annot_kws={"size": 10, "weight": "normal"},
            vmin=0,  # Set minimum value for color scaling
            vmax=1.0  # Set maximum value for color scaling (most metrics are 0-1)
        )

        # Customize heatmap appearance
        ax_heatmap.set_title('Statistical Summary', fontsize=12)
        ax_heatmap.set_xticklabels(ax_heatmap.get_xticklabels(), rotation=0, ha='center')

        # Grey out the Std Dev row: it is not on the same scale as the other
        # statistics. The heatmap body is a single QuadMesh (not one patch per
        # cell), so overlay one rectangle per cell in data coordinates -- cell
        # (row, col) spans [col, col+1] x [row, row+1]. zorder 2 sits above the
        # mesh (zorder 1) and below the annotation text (zorder 3).
        std_dev_row_idx = stats_names.index('Std Dev')
        for col_idx in range(len(stats_df_numeric.columns)):
            if mask[std_dev_row_idx, col_idx]:
                continue  # leave masked (NaN) cells blank
            ax_heatmap.add_patch(Rectangle(
                (col_idx, std_dev_row_idx), 1, 1,
                facecolor='#f5f5f5', edgecolor='lightgray', linewidth=0.5, zorder=2,
            ))

        # Adjust figure layout
        plt.subplots_adjust(hspace=0.3)

        if _show_visualization:
            plt.show()
        else:
            plt.close(fig)

        return {
            'plot_data': plot_data,
            'stats_df_numeric': stats_df_numeric,
            'stats_df_formatted': stats_df_formatted
        }

    except Exception as e:
        print(f"Error creating metrics whisker plot: {str(e)}")
        import traceback
        print(traceback.format_exc())
        # Release a partially built figure instead of leaking it
        if fig is not None:
            plt.close(fig)
        return {}

# Create the metrics whisker plot with statistical heatmap; display it only
# when visualization is enabled in the environment configuration.
metrics_viz_data = create_metrics_whisker_plot(show_visualization=SHOW_VISUALIZATION)

In [None]:
# Cell [9] - Store Whisker Chart Data in Neo4j
# Purpose: Store the metrics statistics data from whisker chart in Neo4j linked to project
# Dependencies: neo4j, pandas, logging
# Breadcrumbs: Data Storage -> Neo4j Persistence

import json
from datetime import datetime  # Added global import for datetime

def store_whisker_chart_data_in_neo4j(metrics_viz_data=None, driver=None, project_name=None):
    """
    Store the metrics statistics data from the whisker chart in Neo4j
    
    Parameters:
        metrics_viz_data: Dictionary containing whisker chart data (from create_metrics_whisker_plot)
        driver: Neo4j driver connection
        project_name: Project name to attach the metrics data to
    
    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # Use global variables if parameters not provided
        _metrics_viz_data = metrics_viz_data if metrics_viz_data is not None else globals().get('metrics_viz_data', {})
        _driver = driver if driver is not None else globals().get('driver')
        _project_name = project_name if project_name is not None else globals().get('NEO4J_PROJECT_NAME', 'Unknown Project')
        
        if not _metrics_viz_data or not _driver:
            logger.error("Missing metrics data or Neo4j driver for storage")
            return False
        
        # Check if we have the stats data
        if 'stats_df_numeric' not in _metrics_viz_data or _metrics_viz_data['stats_df_numeric'].empty:
            logger.error("No statistical metrics data available for storage")
            return False
        
        # Get the stats DataFrames
        stats_df = _metrics_viz_data['stats_df_numeric']
        
        # Prepare data for Neo4j storage - convert stats dataframe to dictionary
        metrics_data = {}
        for column in stats_df.columns:
            # Each column is a metric
            metric_stats = {}
            for idx, value in stats_df[column].items():
                # Each row is a statistic (min, max, etc.)
                # Convert to standard lowercase keys without spaces
                key = idx.lower().replace(' ', '_')
                metric_stats[key] = float(value)
            
            # Add metric stats to overall data
            metrics_data[column.lower().replace(' ', '_')] = metric_stats
        
        # Also add the number of models analyzed
        if 'plot_data' in _metrics_viz_data and not _metrics_viz_data['plot_data'].empty:
            model_count = _metrics_viz_data['plot_data']['model_name'].nunique()
            metrics_data['model_count'] = model_count
        
        # Create model_data dictionary to store the individual data points that make up the whisker chart
        model_data = {}
        
        # Create results_data dictionary to store TP, FP, FN, TN for each model
        results_data = {}
        
        # If we have best_thresholds_df, extract the confusion matrix data (TP, FP, FN, TN)
        if 'best_thresholds_df' in globals() and not globals()['best_thresholds_df'].empty:
            best_df = globals()['best_thresholds_df']
            print(f"\nExtracting confusion matrix data from best_thresholds_df with shape {best_df.shape}")
            
            # For each model, store its confusion matrix data
            for _, row in best_df.iterrows():
                model_name = row['model_name']
                # Extract just the model name without path
                if '/' in model_name:
                    model_key = model_name.split('/')[-1]
                else:
                    model_key = model_name
                
                # Store confusion matrix data for this model
                results_data[model_key] = {
                    'true_positives': int(row['true_positives']) if not pd.isna(row['true_positives']) else 0,
                    'false_positives': int(row['false_positives']) if not pd.isna(row['false_positives']) else 0,
                    'false_negatives': int(row['false_negatives']) if not pd.isna(row['false_negatives']) else 0,
                    'true_negatives': int(row['true_negatives']) if not pd.isna(row['true_negatives']) else 0,
                    'threshold': float(row['best_threshold']) if not pd.isna(row['best_threshold']) else 0.0
                }
            
            print(f"  Added confusion matrix data for {len(results_data)} models")
            # Show a few examples
            for model_name, data in list(results_data.items())[:3]:
                print(f"    {model_name}: TP={data['true_positives']}, FP={data['false_positives']}, " +
                      f"FN={data['false_negatives']}, TN={data['true_negatives']}, threshold={data['threshold']}")
            if len(results_data) > 3:
                print(f"    ... and {len(results_data) - 3} more")
        
        # If we have plot_data from the whisker chart in cell 8, extract the individual data points by metric
        if 'plot_data' in _metrics_viz_data and not _metrics_viz_data['plot_data'].empty:
            plot_df = _metrics_viz_data['plot_data']
            print(f"\nExtracting whisker chart data points from plot_data with shape {plot_df.shape}")
            
            # Group plot data by metric and extract all scores by model
            for metric in plot_df['metric'].unique():
                # Convert metric to standardized key format
                metric_key = metric.lower().replace(' ', '_')
                model_data[metric_key] = {}
                
                # Get data for just this metric
                metric_data = plot_df[plot_df['metric'] == metric]
                
                # Store scores by model name
                for _, row in metric_data.iterrows():
                    model_name = row['model_name']
                    # For model names with paths, just use the last part
                    if '/' in model_name:
                        model_key = model_name.split('/')[-1]
                    else:
                        model_key = model_name
                    
                    model_data[metric_key][model_key] = float(row['score'])
                
                print(f"  Added {len(metric_data)} data points for {metric}")
        
        # If model_data is empty and we have best_thresholds_df, try to get data from there as fallback
        if not model_data and 'best_thresholds_df' in globals() and not globals()['best_thresholds_df'].empty:
            best_df = globals()['best_thresholds_df']
            print(f"\nExtracting data points from best_thresholds_df with shape {best_df.shape}")
            
            # Metrics we want to capture
            metrics = ['accuracy', 'balanced_accuracy', 'precision', 'recall', 
                       'f1_score', 'f2_score', 'matthews_corr']
            
            # For each metric, create an entry with model scores
            for metric in metrics:
                if metric in best_df.columns:
                    metric_key = metric.lower()
                    model_data[metric_key] = {}
                    
                    # For each model, store its score for this metric
                    for _, row in best_df.iterrows():
                        model_name = row['model_name']
                        # Extract just the model name without path
                        if '/' in model_name:
                            model_key = model_name.split('/')[-1]
                        else:
                            model_key = model_name
                        
                        model_data[metric_key][model_key] = float(row[metric])
                    
                    print(f"  Added {len(best_df)} data points for {metric}")
        
        # If we've collected data points, print a summary
        if model_data:
            print("\nModel data summary:")
            for metric, values in model_data.items():
                print(f"  {metric}: {len(values)} data points")
                # Show a few examples
                for model_name, score in list(values.items())[:3]:
                    print(f"    {model_name}: {score}")
                if len(values) > 3:
                    print(f"    ... and {len(values) - 3} more")
        else:
            print("No individual data points were extracted for the model_data field")
        
        # Serialize data to JSON for Neo4j storage
        metrics_json = json.dumps(metrics_data)
        model_data_json = json.dumps(model_data)
        results_json = json.dumps(results_data)
        
        # Current timestamp for the analysis record
        timestamp = datetime.now().isoformat()
        
        # Cypher query to create metrics data connected to project
        # IMPORTANT: Fixed to use stable identifiers in the MERGE pattern
        query = """
        MATCH (p:Project {name: $project_name})
        MERGE (m:MetricsAnalysis {project_name: $project_name, model_type: $model_type, analysis_type: 'whisker_chart'})
        MERGE (p)-[r:HAS_METRICS_ANALYSIS {model_type: $model_type}]->(m)
        SET m.metrics_data = $metrics_data,
            m.model_count = $model_count,
            m.model_data = $model_data,
            m.results = $results_data,
            m.created_at = CASE WHEN m.created_at IS NULL THEN $timestamp ELSE m.created_at END,
            m.last_updated = $timestamp,
            r.timestamp = $timestamp,
            r.last_updated = $timestamp
        RETURN p.name as project_name, m.created_at as created_at
        """
        
        # Execute query to store metrics data
        with _driver.session() as session:
            result = session.run(
                query,
                project_name=_project_name,
                model_type="sentence_transformer_tf_idf",
                timestamp=timestamp,
                metrics_data=metrics_json,
                model_data=model_data_json,
                results_data=results_json,
                model_count=model_count if 'model_count' in locals() else 0
            ).single()
            
            if result:
                logger.info(f"Successfully stored whisker chart metrics data for project: {result['project_name']} at {result['created_at']}")
                print(f"Stored metrics analysis data for {model_count if 'model_count' in locals() else 0} models in Neo4j at {timestamp}")
                return True
            else:
                logger.warning(f"No result returned when storing metrics data for project: {_project_name}")
                print("No result returned when storing metrics data. Check logs for details.")
                return False
                
    except Exception as e:
        logger.error(f"Error storing whisker chart data in Neo4j: {str(e)}")
        logger.error("Exception details:", exc_info=True)
        print(f"Error storing whisker chart data in Neo4j: {str(e)}")
        return False

# Store the metrics data in Neo4j if available.
# Runs only when an earlier cell has defined a truthy `metrics_viz_data`;
# otherwise it tells the user which cell to run first.
if 'metrics_viz_data' in globals() and metrics_viz_data:
    # store_whisker_chart_data_in_neo4j() returns True on success and False when
    # the query returns no row or an exception was caught (and already logged).
    success = store_whisker_chart_data_in_neo4j()
    if success:
        print("\nMetrics whisker chart data successfully stored in Neo4j")
        print("=" * 80)
        print(f"Project: {NEO4J_PROJECT_NAME}")
        # Plain string literal: there is nothing to interpolate here.
        print("Model type: sentence_transformer_tf_idf")
        # NOTE(review): this is the time of printing. It is NOT the `timestamp`
        # value the store function wrote to Neo4j (that one is taken by a
        # separate datetime.now() call), so the two can differ slightly.
        print(f"Timestamp: {datetime.now().isoformat()}")
    else:
        print("\nFailed to store metrics whisker chart data. Check logs for details.")
else:
    print("\nNo metrics whisker chart data available for storage")
    print("Please run cell 8 first to generate the data")