## GENERATING COMP TABLE

In [2]:
import os
import pandas as pd
import numpy as np
import glob
from pathlib import Path
import mir_eval
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

def parse_lab_file(filepath):
    """Parse a .lab file and return intervals and labels."""
    intervals = []
    labels = []
    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) >= 3:
                start = float(parts[0])
                end = float(parts[1])
                chord = ' '.join(parts[2:])
                intervals.append([start, end])
                labels.append(chord)
    return np.array(intervals), labels

def calculate_metrics_for_inference(inference_dir, ground_truth_dir):
    """Calculate all metrics for a single inference directory."""
    # Get all matching files
    inference_files = glob.glob(os.path.join(inference_dir, "*.lab"))
    
    if not inference_files:
        return None
    
    all_scores = defaultdict(list)
    
    for inf_file in inference_files:
        track_id = os.path.basename(inf_file)
        gt_file = os.path.join(ground_truth_dir, track_id)
        
        if not os.path.exists(gt_file):
            continue
        
        try:
            # Parse files
            ref_intervals, ref_labels = parse_lab_file(gt_file)
            est_intervals, est_labels = parse_lab_file(inf_file)
            
            if len(ref_intervals) == 0 or len(est_intervals) == 0:
                continue
            
            # Calculate mir_eval metrics
            scores = mir_eval.chord.evaluate(ref_intervals, ref_labels, 
                                            est_intervals, est_labels)
            
            # Store scores
            for metric, value in scores.items():
                all_scores[metric].append(value)
            
            # Calculate additional metrics
            all_scores['num_predictions'].append(len(est_labels))
            all_scores['num_ground_truth'].append(len(ref_labels))
            all_scores['duration_seconds'].append(ref_intervals[-1][1])
            
            # Chord change rate
            est_changes = len(est_labels) / (est_intervals[-1][1] / 60) if est_intervals[-1][1] > 0 else 0
            ref_changes = len(ref_labels) / (ref_intervals[-1][1] / 60) if ref_intervals[-1][1] > 0 else 0
            all_scores['pred_changes_per_min'].append(est_changes)
            all_scores['gt_changes_per_min'].append(ref_changes)
            
        except Exception as e:
            print(f"Error processing {track_id}: {str(e)}")
            continue
    
    # Calculate mean and std metrics
    if not all_scores:
        return None
    
    mean_metrics = {key: np.mean(values) for key, values in all_scores.items()}
    std_metrics = {f"{key}_std": np.std(values) for key, values in all_scores.items()}
    
    # Combine mean and std metrics
    combined_metrics = {**mean_metrics, **std_metrics}
    combined_metrics['num_tracks'] = len(all_scores['root'])
    
    return combined_metrics

def generate_comparison_table():
    """Generate a comparison table for all inference directories."""
    # Base paths
    inferences_base = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences"
    ground_truth_dir = "/home/daniel.melo/datasets/rwc/annotations"
    
    # Get all inference directories
    inference_dirs = [d for d in glob.glob(os.path.join(inferences_base, "inferences_*")) 
                     if os.path.isdir(d)]
    inference_dirs = sorted(inference_dirs)
    
    if not inference_dirs:
        print(f"No inference directories found in {inferences_base}")
        return None
    
    print(f"Found {len(inference_dirs)} inference directories")
    print("Calculating metrics...\n")
    
    # Calculate metrics for each inference directory
    results = {}
    for inf_dir in inference_dirs:
        dir_name = os.path.basename(inf_dir)
        # Extract experiment number from directory name
        exp_num = dir_name.replace('inferences_', '').split('_')[0]
        
        print(f"Processing: {dir_name}")
        metrics = calculate_metrics_for_inference(inf_dir, ground_truth_dir)
        
        if metrics:
            results[f"Exp_{exp_num}"] = metrics
            print(f"  ‚úì Processed {metrics['num_tracks']} tracks")
        else:
            print(f"  ‚úó No valid tracks found")
    
    if not results:
        print("No results to display")
        return None
    
    # Create DataFrame
    df = pd.DataFrame(results).T
    
    # Select and order important metrics (mean values)
    metric_order = [
        'root', 'majmin', 'thirds', 'triads', 'tetrads', 'sevenths',
        'overseg', 'underseg', 'seg',
        'num_tracks', 'duration_seconds',
        'num_predictions', 'num_ground_truth',
        'pred_changes_per_min', 'gt_changes_per_min'
    ]
    
    # Keep only metrics that exist
    available_metrics = [m for m in metric_order if m in df.columns]
    
    # Add corresponding std columns
    std_metrics = [f"{m}_std" for m in available_metrics if f"{m}_std" in df.columns]
    
    # Combine mean and std columns
    all_columns = []
    for metric in available_metrics:
        all_columns.append(metric)
        if f"{metric}_std" in df.columns:
            all_columns.append(f"{metric}_std")
    
    df = df[all_columns]
    
    # Rename columns for better readability
    column_names = {
        'root': 'Root Accuracy (Mean)',
        'root_std': 'Root Accuracy (Std)',
        'majmin': 'Maj/Min Accuracy (Mean)',
        'majmin_std': 'Maj/Min Accuracy (Std)',
        'thirds': 'Thirds Accuracy (Mean)',
        'thirds_std': 'Thirds Accuracy (Std)',
        'triads': 'Triads Accuracy (Mean)',
        'triads_std': 'Triads Accuracy (Std)',
        'tetrads': 'Tetrads Accuracy (Mean)',
        'tetrads_std': 'Tetrads Accuracy (Std)',
        'sevenths': 'Sevenths Accuracy (Mean)',
        'sevenths_std': 'Sevenths Accuracy (Std)',
        'overseg': 'Over-segmentation (Mean)',
        'overseg_std': 'Over-segmentation (Std)',
        'underseg': 'Under-segmentation (Mean)',
        'underseg_std': 'Under-segmentation (Std)',
        'seg': 'Segmentation (Mean)',
        'seg_std': 'Segmentation (Std)',
        'num_tracks': 'Number of Tracks',
        'duration_seconds': 'Avg Duration (s) (Mean)',
        'duration_seconds_std': 'Avg Duration (s) (Std)',
        'num_predictions': 'Avg Predictions (Mean)',
        'num_predictions_std': 'Avg Predictions (Std)',
        'num_ground_truth': 'Avg Ground Truth (Mean)',
        'num_ground_truth_std': 'Avg Ground Truth (Std)',
        'pred_changes_per_min': 'Pred Changes/min (Mean)',
        'pred_changes_per_min_std': 'Pred Changes/min (Std)',
        'gt_changes_per_min': 'GT Changes/min (Mean)',
        'gt_changes_per_min_std': 'GT Changes/min (Std)'
    }
    
    df = df.rename(columns=column_names)
    
    # Format numeric columns
    for col in df.columns:
        if 'Number of Tracks' in col:
            df[col] = df[col].apply(lambda x: f"{x:.0f}")
        elif 'Accuracy' in col or 'segmentation' in col.lower() or 'Segmentation' in col:
            df[col] = df[col].apply(lambda x: f"{x:.4f}")
        elif 'Changes' in col or 'Duration' in col or 'Predictions' in col or 'Ground Truth' in col:
            df[col] = df[col].apply(lambda x: f"{x:.2f}")
    
    return df

# Generate and display the comparison table
comparison_table = generate_comparison_table()

if comparison_table is not None:
    print("\n" + "="*100)
    print("COMPARISON TABLE - All Inference Experiments (Mean ¬± Std)")
    print("="*100)
    display(comparison_table)
    
    # Save to CSV for easy access
    output_path = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/comparison_metrics.csv"
    comparison_table.to_csv(output_path)
    print(f"\n‚úì Table saved to: {output_path}")


Found 5 inference directories
Calculating metrics...

Processing: inferences_1trainBillJaah_testRwc
  ‚úì Processed 94 tracks
Processing: inferences_2trainBillJaahDjavan_testRwc
  ‚úì Processed 94 tracks
Processing: inferences_3trainBillJaahDjavan_testRwc
  ‚úì Processed 94 tracks
Processing: inferences_4trainJaahDjavan_testRwc
  ‚úì Processed 94 tracks
Processing: inferences_moises
  ‚úó No valid tracks found

COMPARISON TABLE - All Inference Experiments (Mean ¬± Std)


Unnamed: 0,Root Accuracy (Mean),Root Accuracy (Std),Maj/Min Accuracy (Mean),Maj/Min Accuracy (Std),Thirds Accuracy (Mean),Thirds Accuracy (Std),Triads Accuracy (Mean),Triads Accuracy (Std),Tetrads Accuracy (Mean),Tetrads Accuracy (Std),...,Avg Duration (s) (Mean),Avg Duration (s) (Std),Avg Predictions (Mean),Avg Predictions (Std),Avg Ground Truth (Mean),Avg Ground Truth (Std),Pred Changes/min (Mean),Pred Changes/min (Std),GT Changes/min (Mean),GT Changes/min (Std)
Exp_1trainBillJaah,0.3234,0.1934,0.2946,0.1882,0.2853,0.18,0.2709,0.1743,0.1959,0.1568,...,244.32,42.3,171.27,59.44,133.18,35.14,42.64,14.24,32.99,8.26
Exp_2trainBillJaahDjavan,0.7834,0.1796,0.7603,0.189,0.7591,0.1877,0.721,0.188,0.5771,0.1876,...,244.32,42.3,174.81,44.36,133.18,35.14,43.5,10.05,32.99,8.26
Exp_3trainBillJaahDjavan,0.7905,0.1762,0.7627,0.1849,0.7672,0.1838,0.7259,0.1858,0.5856,0.1842,...,244.32,42.3,179.6,45.92,133.18,35.14,44.84,11.15,32.99,8.26
Exp_4trainJaahDjavan,0.575,0.19,0.5488,0.1931,0.5349,0.1914,0.5118,0.1899,0.2558,0.1644,...,244.32,42.3,177.31,58.26,133.18,35.14,43.81,11.98,32.99,8.26



‚úì Table saved to: /home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/comparison_metrics.csv


## COMPARISON PLOTS

In [20]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

def visualize_comparison_table(df):
    """Create interactive visualizations for the comparison table."""
    if df is None or df.empty:
        print("No data to visualize")
        return
    
    # Convert string values back to numeric for plotting
    df_numeric = df.copy()
    for col in df_numeric.columns:
        try:
            df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace(',', ''))
        except:
            pass
    
    # Create subplots
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            'Accuracy Metrics Comparison',
            'Segmentation Metrics',
            'Chord Changes per Minute',
            'Number of Predictions vs Ground Truth'
        ),
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "scatter"}]]
    )
    
    experiments = df_numeric.index.tolist()
    
    # 1. Accuracy Metrics
    accuracy_cols = [col for col in df_numeric.columns if 'Accuracy' in col]
    for col in accuracy_cols:
        fig.add_trace(
            go.Bar(name=col, x=experiments, y=df_numeric[col], 
                   text=df_numeric[col].apply(lambda x: f"{x:.3f}"),
                   textposition='auto'),
            row=1, col=1
        )
    
    # 2. Segmentation Metrics
    seg_cols = [col for col in df_numeric.columns if 'segmentation' in col.lower() or 'Segmentation' in col]
    for col in seg_cols:
        if col in df_numeric.columns:
            fig.add_trace(
                go.Bar(name=col, x=experiments, y=df_numeric[col],
                       text=df_numeric[col].apply(lambda x: f"{x:.3f}"),
                       textposition='auto'),
                row=1, col=2
            )
    
    # 3. Chord Changes per Minute
    if 'Pred Changes/min' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Predictions', x=experiments, y=df_numeric['Pred Changes/min'],
                   text=df_numeric['Pred Changes/min'].apply(lambda x: f"{x:.1f}"),
                   textposition='auto'),
            row=2, col=1
        )
    if 'GT Changes/min' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Ground Truth', x=experiments, y=df_numeric['GT Changes/min'],
                   text=df_numeric['GT Changes/min'].apply(lambda x: f"{x:.1f}"),
                   textposition='auto'),
            row=2, col=1
        )
    
    # 4. Predictions vs Ground Truth scatter
    if 'Avg Predictions' in df_numeric.columns and 'Avg Ground Truth' in df_numeric.columns:
        fig.add_trace(
            go.Scatter(
                x=df_numeric['Avg Ground Truth'], 
                y=df_numeric['Avg Predictions'],
                mode='markers+text',
                marker=dict(size=15, color=list(range(len(experiments))), 
                          colorscale='Viridis', showscale=True),
                text=experiments,
                textposition='top center',
                name='Experiments'
            ),
            row=2, col=2
        )
        # Add diagonal reference line
        min_val = min(df_numeric['Avg Ground Truth'].min(), df_numeric['Avg Predictions'].min())
        max_val = max(df_numeric['Avg Ground Truth'].max(), df_numeric['Avg Predictions'].max())
        fig.add_trace(
            go.Scatter(x=[min_val, max_val], y=[min_val, max_val],
                      mode='lines', line=dict(dash='dash', color='red'),
                      name='Perfect Match', showlegend=False),
            row=2, col=2
        )
    
    # Update layout
    fig.update_layout(
        height=900,
        title_text="Inference Experiments Comparison Dashboard",
        showlegend=True,
        barmode='group'
    )
    
    fig.update_xaxes(title_text="Experiment", row=1, col=1)
    fig.update_xaxes(title_text="Experiment", row=1, col=2)
    fig.update_xaxes(title_text="Experiment", row=2, col=1)
    fig.update_xaxes(title_text="Avg Ground Truth Segments", row=2, col=2)
    
    fig.update_yaxes(title_text="Score", row=1, col=1)
    fig.update_yaxes(title_text="Score", row=1, col=2)
    fig.update_yaxes(title_text="Changes/min", row=2, col=1)
    fig.update_yaxes(title_text="Avg Predicted Segments", row=2, col=2)
    
    fig.show()
    
    # Create a heatmap of all metrics
    fig2 = go.Figure(data=go.Heatmap(
        z=df_numeric.T.values,
        x=df_numeric.index,
        y=df_numeric.columns,
        colorscale='RdYlGn',
        text=df.T.values,
        texttemplate='%{text}',
        textfont={"size": 10},
        colorbar=dict(title="Value")
    ))
    
    fig2.update_layout(
        title='All Metrics Heatmap',
        xaxis_title='Experiment',
        yaxis_title='Metric',
        height=600,
        width=1000
    )
    
    fig2.show()
    
    # Create radar chart for accuracy metrics
    if len(accuracy_cols) > 2:
        fig3 = go.Figure()
        
        for exp in experiments:
            values = [df_numeric.loc[exp, col] for col in accuracy_cols]
            fig3.add_trace(go.Scatterpolar(
                r=values,
                theta=[col.replace(' Accuracy', '') for col in accuracy_cols],
                fill='toself',
                name=exp
            ))
        
        fig3.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    range=[0, 1]
                )),
            showlegend=True,
            title='Accuracy Metrics - Radar Chart',
            height=600
        )
        
        fig3.show()

# Visualize the comparison table
if comparison_table is not None:
    visualize_comparison_table(comparison_table)



In [21]:
def analyze_best_and_worst_experiments(df):
    """Identify best and worst performing experiments for each metric."""
    if df is None or df.empty:
        print("No data to analyze")
        return
    
    # Convert string values back to numeric
    df_numeric = df.copy()
    for col in df_numeric.columns:
        try:
            df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace(',', ''))
        except:
            pass
    
    print("="*100)
    print("BEST AND WORST EXPERIMENTS BY METRIC")
    print("="*100)
    
    # Focus on main accuracy metrics
    accuracy_metrics = [col for col in df_numeric.columns if 'Accuracy' in col]
    
    for metric in accuracy_metrics:
        if metric in df_numeric.columns:
            best_exp = df_numeric[metric].idxmax()
            worst_exp = df_numeric[metric].idxmin()
            best_val = df_numeric.loc[best_exp, metric]
            worst_val = df_numeric.loc[worst_exp, metric]
            improvement = ((best_val - worst_val) / worst_val * 100) if worst_val > 0 else 0
            
            print(f"\n{metric}:")
            print(f"  üèÜ Best:  {best_exp} = {best_val:.4f}")
            print(f"  ‚ö†Ô∏è  Worst: {worst_exp} = {worst_val:.4f}")
            print(f"  üìà Improvement: {improvement:.2f}%")
    
    # Overall ranking based on weighted average of metrics
    print("\n" + "="*100)
    print("OVERALL EXPERIMENT RANKING")
    print("="*100)
    
    # Calculate weighted score (you can adjust weights as needed)
    weights = {
        'Root Accuracy': 0.20,
        'Maj/Min Accuracy': 0.25,
        'Thirds Accuracy': 0.15,
        'Triads Accuracy': 0.15,
        'Tetrads Accuracy': 0.15,
        'Sevenths Accuracy': 0.10
    }
    
    scores = {}
    for exp in df_numeric.index:
        score = 0
        total_weight = 0
        for metric, weight in weights.items():
            if metric in df_numeric.columns:
                score += df_numeric.loc[exp, metric] * weight
                total_weight += weight
        scores[exp] = score / total_weight if total_weight > 0 else 0
    
    ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    
    print("\nRanking (weighted average of accuracy metrics):")
    for i, (exp, score) in enumerate(ranked, 1):
        medal = "ü•á" if i == 1 else "ü•à" if i == 2 else "ü•â" if i == 3 else f"{i}."
        print(f"  {medal} {exp}: {score:.4f}")
    
    return ranked

def compare_specific_experiments(df, exp1, exp2):
    """Compare two specific experiments side by side."""
    if df is None or df.empty:
        print("No data to compare")
        return
    
    if exp1 not in df.index or exp2 not in df.index:
        print(f"Error: One or both experiments not found")
        print(f"Available experiments: {list(df.index)}")
        return
    
    # Convert to numeric
    df_numeric = df.copy()
    for col in df_numeric.columns:
        try:
            df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace(',', ''))
        except:
            pass
    
    print("="*100)
    print(f"COMPARISON: {exp1} vs {exp2}")
    print("="*100)
    
    comparison = pd.DataFrame({
        exp1: df_numeric.loc[exp1],
        exp2: df_numeric.loc[exp2],
        'Difference': df_numeric.loc[exp2] - df_numeric.loc[exp1],
        'Improvement (%)': ((df_numeric.loc[exp2] - df_numeric.loc[exp1]) / df_numeric.loc[exp1] * 100).fillna(0)
    })
    
    # Format for display
    comparison_display = comparison.copy()
    for col in [exp1, exp2, 'Difference']:
        comparison_display[col] = comparison_display[col].apply(lambda x: f"{x:.4f}" if abs(x) < 10 else f"{x:.2f}")
    comparison_display['Improvement (%)'] = comparison_display['Improvement (%)'].apply(lambda x: f"{x:+.2f}%")
    
    display(comparison_display)
    
    # Create visualization
    fig = go.Figure()
    
    accuracy_cols = [col for col in df_numeric.columns if 'Accuracy' in col]
    
    fig.add_trace(go.Bar(
        name=exp1,
        x=accuracy_cols,
        y=[df_numeric.loc[exp1, col] for col in accuracy_cols],
        text=[f"{df_numeric.loc[exp1, col]:.3f}" for col in accuracy_cols],
        textposition='auto'
    ))
    
    fig.add_trace(go.Bar(
        name=exp2,
        x=accuracy_cols,
        y=[df_numeric.loc[exp2, col] for col in accuracy_cols],
        text=[f"{df_numeric.loc[exp2, col]:.3f}" for col in accuracy_cols],
        textposition='auto'
    ))
    
    fig.update_layout(
        title=f'Accuracy Comparison: {exp1} vs {exp2}',
        xaxis_title='Metric',
        yaxis_title='Score',
        barmode='group',
        height=500
    )
    
    fig.show()
    
    return comparison

# Run analyses
if comparison_table is not None:
    # Analyze best and worst
    ranking = analyze_best_and_worst_experiments(comparison_table)
    
    # If we have at least 2 experiments, compare the top 2
    if len(comparison_table) >= 2 and ranking:
        print("\n" + "="*100)
        print("DETAILED COMPARISON: Top 2 Experiments")
        print("="*100)
        exp1, exp2 = ranking[0][0], ranking[1][0]
        compare_specific_experiments(comparison_table, exp1, exp2)



BEST AND WORST EXPERIMENTS BY METRIC

Root Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.7905
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.3234
  üìà Improvement: 144.43%

Maj/Min Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.7627
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.2946
  üìà Improvement: 158.89%

Thirds Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.7672
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.2853
  üìà Improvement: 168.91%

Triads Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.7259
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.2709
  üìà Improvement: 167.96%

Tetrads Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.5856
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.1959
  üìà Improvement: 198.93%

Sevenths Accuracy:
  üèÜ Best:  Exp_3trainBillJaahDjavan = 0.6371
  ‚ö†Ô∏è  Worst: Exp_1trainBillJaah = 0.2198
  üìà Improvement: 189.85%

OVERALL EXPERIMENT RANKING

Ranking (weighted average of accuracy metrics):
  ü•á Exp_3trainBillJaahDjavan: 0.7243
  ü•à Exp

Unnamed: 0,Exp_3trainBillJaahDjavan,Exp_2trainBillJaahDjavan,Difference,Improvement (%)
Root Accuracy,0.7905,0.7834,-0.0071,-0.90%
Maj/Min Accuracy,0.7627,0.7603,-0.0024,-0.31%
Thirds Accuracy,0.7672,0.7591,-0.0081,-1.06%
Triads Accuracy,0.7259,0.721,-0.0049,-0.68%
Tetrads Accuracy,0.5856,0.5771,-0.0085,-1.45%
Sevenths Accuracy,0.6371,0.6308,-0.0063,-0.99%
Over-segmentation,0.8427,0.8511,0.0084,+1.00%
Under-segmentation,0.8872,0.889,0.0018,+0.20%
Segmentation,0.8298,0.8381,0.0083,+1.00%
Number of Tracks,94.0,94.0,0.0,+0.00%


In [22]:
# ============================================================================
# UTILITY FUNCTIONS FOR FUTURE USE
# ============================================================================

def quick_refresh():
    """
    Quick function to refresh the comparison table with all current experiments.
    Use this whenever a new inference folder is added.
    """
    global comparison_table
    comparison_table = generate_comparison_table()
    
    if comparison_table is not None:
        print("\n" + "="*100)
        print("UPDATED COMPARISON TABLE")
        print("="*100)
        display(comparison_table)
        
        # Save updated CSV
        output_path = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/comparison_metrics.csv"
        comparison_table.to_csv(output_path)
        print(f"\n‚úì Updated table saved to: {output_path}")
        
        return comparison_table
    return None

def analyze_single_experiment(experiment_name):
    """
    Analyze a single experiment in detail.
    
    Args:
        experiment_name: Name of the experiment (e.g., 'Exp_1')
    """
    if comparison_table is None:
        print("Please run the comparison table generation first")
        return
    
    if experiment_name not in comparison_table.index:
        print(f"Experiment '{experiment_name}' not found")
        print(f"Available experiments: {list(comparison_table.index)}")
        return
    
    # Convert to numeric
    df_numeric = comparison_table.copy()
    for col in df_numeric.columns:
        try:
            df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace(',', ''))
        except:
            pass
    
    print("="*100)
    print(f"DETAILED ANALYSIS: {experiment_name}")
    print("="*100)
    
    exp_data = df_numeric.loc[experiment_name]
    
    # Display metrics
    print("\nüìä Accuracy Metrics:")
    for col in df_numeric.columns:
        if 'Accuracy' in col:
            value = exp_data[col]
            print(f"  {col}: {value:.4f}")
    
    print("\nüìà Segmentation Metrics:")
    for col in df_numeric.columns:
        if 'segmentation' in col.lower() or 'Segmentation' in col:
            value = exp_data[col]
            print(f"  {col}: {value:.4f}")
    
    print("\nüî¢ Statistics:")
    for col in ['Number of Tracks', 'Avg Duration (s)', 'Avg Predictions', 'Avg Ground Truth']:
        if col in df_numeric.columns:
            value = exp_data[col]
            print(f"  {col}: {value:.2f}")
    
    print("\nüîÑ Chord Changes:")
    for col in ['Pred Changes/min', 'GT Changes/min']:
        if col in df_numeric.columns:
            value = exp_data[col]
            print(f"  {col}: {value:.2f}")
    
    # Compare with average
    print("\n" + "="*100)
    print(f"COMPARISON WITH AVERAGE")
    print("="*100)
    
    accuracy_cols = [col for col in df_numeric.columns if 'Accuracy' in col]
    
    comparison_data = []
    for col in accuracy_cols:
        exp_val = exp_data[col]
        avg_val = df_numeric[col].mean()
        diff = exp_val - avg_val
        diff_pct = (diff / avg_val * 100) if avg_val > 0 else 0
        comparison_data.append({
            'Metric': col,
            experiment_name: f"{exp_val:.4f}",
            'Average': f"{avg_val:.4f}",
            'Difference': f"{diff:+.4f}",
            'Diff (%)': f"{diff_pct:+.2f}%"
        })
    
    comp_df = pd.DataFrame(comparison_data)
    display(comp_df)

def export_metrics_summary(filename=None):
    """
    Export a comprehensive summary of all metrics to CSV and Excel formats.
    
    Args:
        filename: Base filename (without extension). If None, uses timestamp.
    """
    if comparison_table is None:
        print("Please run the comparison table generation first")
        return
    
    if filename is None:
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"metrics_summary_{timestamp}"
    
    base_path = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences"
    
    # Save CSV
    csv_path = os.path.join(base_path, f"{filename}.csv")
    comparison_table.to_csv(csv_path)
    print(f"‚úì CSV saved to: {csv_path}")
    
    # Save Excel with formatting
    try:
        excel_path = os.path.join(base_path, f"{filename}.xlsx")
        
        # Convert to numeric for Excel
        df_numeric = comparison_table.copy()
        for col in df_numeric.columns:
            try:
                df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace(',', ''))
            except:
                pass
        
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            df_numeric.to_excel(writer, sheet_name='Metrics')
            
            # Add a summary sheet
            summary_data = {
                'Metric': ['Best Root Accuracy', 'Best Maj/Min Accuracy', 'Average Root Accuracy', 'Average Maj/Min Accuracy'],
                'Value': [
                    df_numeric['Root Accuracy'].max() if 'Root Accuracy' in df_numeric.columns else 0,
                    df_numeric['Maj/Min Accuracy'].max() if 'Maj/Min Accuracy' in df_numeric.columns else 0,
                    df_numeric['Root Accuracy'].mean() if 'Root Accuracy' in df_numeric.columns else 0,
                    df_numeric['Maj/Min Accuracy'].mean() if 'Maj/Min Accuracy' in df_numeric.columns else 0,
                ],
                'Experiment': [
                    df_numeric['Root Accuracy'].idxmax() if 'Root Accuracy' in df_numeric.columns else 'N/A',
                    df_numeric['Maj/Min Accuracy'].idxmax() if 'Maj/Min Accuracy' in df_numeric.columns else 'N/A',
                    'Average',
                    'Average'
                ]
            }
            summary_df = pd.DataFrame(summary_data)
            summary_df.to_excel(writer, sheet_name='Summary', index=False)
        
        print(f"‚úì Excel saved to: {excel_path}")
    except Exception as e:
        print(f"‚ö†Ô∏è  Could not save Excel file: {str(e)}")

# Display usage instructions
print("="*100)
print("UTILITY FUNCTIONS LOADED")
print("="*100)
print("\nüìö Available functions:")
print("  1. quick_refresh() - Refresh table when new experiments are added")
print("  2. analyze_single_experiment('Exp_N') - Detailed analysis of a specific experiment")
print("  3. compare_specific_experiments(df, 'Exp_1', 'Exp_2') - Compare two experiments")
print("  4. export_metrics_summary('filename') - Export results to CSV/Excel")
print("\nüí° Example usage:")
print("  >> quick_refresh()  # Update table with new experiments")
print("  >> analyze_single_experiment('Exp_3')  # Analyze experiment 3")
print("  >> export_metrics_summary('my_results')  # Export to files")
print("="*100)



UTILITY FUNCTIONS LOADED

üìö Available functions:
  1. quick_refresh() - Refresh table when new experiments are added
  2. analyze_single_experiment('Exp_N') - Detailed analysis of a specific experiment
  3. compare_specific_experiments(df, 'Exp_1', 'Exp_2') - Compare two experiments
  4. export_metrics_summary('filename') - Export results to CSV/Excel

üí° Example usage:
  >> quick_refresh()  # Update table with new experiments
  >> analyze_single_experiment('Exp_3')  # Analyze experiment 3
  >> export_metrics_summary('my_results')  # Export to files


In [23]:
# ============================================================================
# ADVANCED METRICS CALCULATION (Optional)
# ============================================================================
# This cell calculates additional advanced metrics beyond mir_eval standards

import scipy.stats as stats
from collections import Counter

def calculate_advanced_metrics(inference_dir, ground_truth_dir):
    """Calculate advanced metrics for a single inference directory."""
    inference_files = glob.glob(os.path.join(inference_dir, "*.lab"))
    
    if not inference_files:
        return None
    
    advanced_metrics = {
        'chord_diversity': [],
        'unique_chords_ratio': [],
        'avg_segment_duration': [],
        'segment_duration_std': [],
        'chord_vocabulary_size': [],
        'no_chord_percentage': [],
        'major_minor_ratio': [],
        'complex_chord_percentage': [],  # 7ths, 9ths, etc.
        'correlation_segment_count': [],
        'mae_segment_count': []
    }
    
    all_predicted_chords = []
    all_gt_chords = []
    
    for inf_file in inference_files:
        track_id = os.path.basename(inf_file)
        gt_file = os.path.join(ground_truth_dir, track_id)
        
        if not os.path.exists(gt_file):
            continue
        
        try:
            ref_intervals, ref_labels = parse_lab_file(gt_file)
            est_intervals, est_labels = parse_lab_file(inf_file)
            
            if len(ref_intervals) == 0 or len(est_intervals) == 0:
                continue
            
            # Chord diversity (entropy)
            chord_counts = Counter(est_labels)
            total = len(est_labels)
            probabilities = [count/total for count in chord_counts.values()]
            entropy = -sum(p * np.log2(p) if p > 0 else 0 for p in probabilities)
            advanced_metrics['chord_diversity'].append(entropy)
            
            # Unique chords ratio
            unique_ratio = len(set(est_labels)) / len(est_labels) if len(est_labels) > 0 else 0
            advanced_metrics['unique_chords_ratio'].append(unique_ratio)
            
            # Average segment duration
            durations = [interval[1] - interval[0] for interval in est_intervals]
            advanced_metrics['avg_segment_duration'].append(np.mean(durations))
            advanced_metrics['segment_duration_std'].append(np.std(durations))
            
            # Chord vocabulary size
            advanced_metrics['chord_vocabulary_size'].append(len(set(est_labels)))
            
            # No chord percentage
            no_chord_count = sum(1 for label in est_labels if label == 'N' or label.lower() == 'n')
            no_chord_pct = no_chord_count / len(est_labels) * 100 if len(est_labels) > 0 else 0
            advanced_metrics['no_chord_percentage'].append(no_chord_pct)
            
            # Major/Minor ratio
            major_count = sum(1 for label in est_labels if ':maj' in label or label.split(':')[0] in ['C', 'D', 'E', 'F', 'G', 'A', 'B'] and ':min' not in label and ':' in label)
            minor_count = sum(1 for label in est_labels if ':min' in label)
            maj_min_ratio = major_count / minor_count if minor_count > 0 else 0
            advanced_metrics['major_minor_ratio'].append(maj_min_ratio)
            
            # Complex chord percentage (7ths, 9ths, etc.)
            complex_count = sum(1 for label in est_labels if any(x in label for x in ['7', '9', '11', '13', 'sus', 'dim', 'aug']))
            complex_pct = complex_count / len(est_labels) * 100 if len(est_labels) > 0 else 0
            advanced_metrics['complex_chord_percentage'].append(complex_pct)
            
            # Correlation between predicted and GT segment counts
            advanced_metrics['correlation_segment_count'].append(len(est_labels))
            advanced_metrics['mae_segment_count'].append(abs(len(est_labels) - len(ref_labels)))
            
            # Collect all chords for global analysis
            all_predicted_chords.extend(est_labels)
            all_gt_chords.extend(ref_labels)
            
        except Exception as e:
            continue
    
    if not advanced_metrics['chord_diversity']:
        return None
    
    # Calculate mean metrics
    mean_metrics = {key: np.mean(values) if values else 0 for key, values in advanced_metrics.items()}
    
    # Add correlation
    if len(advanced_metrics['correlation_segment_count']) > 1:
        pred_counts = advanced_metrics['correlation_segment_count']
        gt_counts = [len(parse_lab_file(os.path.join(ground_truth_dir, os.path.basename(f)))[0]) 
                     for f in inference_files if os.path.exists(os.path.join(ground_truth_dir, os.path.basename(f)))]
        if len(pred_counts) == len(gt_counts) and len(pred_counts) > 1:
            correlation = np.corrcoef(pred_counts[:len(gt_counts)], gt_counts)[0, 1]
            mean_metrics['segment_count_correlation'] = correlation
    
    return mean_metrics

def generate_advanced_comparison():
    """Generate advanced comparison table for all experiments."""
    inferences_base = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences"
    ground_truth_dir = "/home/daniel.melo/datasets/rwc/annotations"
    
    inference_dirs = sorted([d for d in glob.glob(os.path.join(inferences_base, "inferences_*")) 
                             if os.path.isdir(d)])
    
    if not inference_dirs:
        print("No inference directories found")
        return None
    
    print("Calculating advanced metrics...\n")
    
    results = {}
    for inf_dir in inference_dirs:
        dir_name = os.path.basename(inf_dir)
        exp_num = dir_name.replace('inferences_', '').split('_')[0]
        
        print(f"Processing: {dir_name}")
        metrics = calculate_advanced_metrics(inf_dir, ground_truth_dir)
        
        if metrics:
            results[f"Exp_{exp_num}"] = metrics
            print(f"  ‚úì Calculated advanced metrics")
    
    if not results:
        return None
    
    df = pd.DataFrame(results).T
    
    # Rename columns for better readability
    column_names = {
        'chord_diversity': 'Chord Diversity (Entropy)',
        'unique_chords_ratio': 'Unique Chords Ratio',
        'avg_segment_duration': 'Avg Segment Duration (s)',
        'segment_duration_std': 'Segment Duration StdDev',
        'chord_vocabulary_size': 'Vocabulary Size',
        'no_chord_percentage': 'No-Chord %',
        'major_minor_ratio': 'Major/Minor Ratio',
        'complex_chord_percentage': 'Complex Chords %',
        'mae_segment_count': 'MAE Segment Count',
        'segment_count_correlation': 'Segment Count Correlation'
    }
    
    df = df.rename(columns={k: v for k, v in column_names.items() if k in df.columns})
    
    # Format columns
    for col in df.columns:
        if '%' in col:
            df[col] = df[col].apply(lambda x: f"{x:.2f}%")
        elif 'Correlation' in col or 'Ratio' in col:
            df[col] = df[col].apply(lambda x: f"{x:.3f}")
        elif 'Duration' in col or 'MAE' in col or 'Size' in col or 'Entropy' in col:
            df[col] = df[col].apply(lambda x: f"{x:.2f}")
    
    return df

# Generate advanced metrics table
print("="*100)
print("ADVANCED METRICS ANALYSIS")
print("="*100)
print("\nThis may take a few moments...\n")

advanced_table = generate_advanced_comparison()

if advanced_table is not None:
    print("\n" + "="*100)
    print("ADVANCED METRICS TABLE")
    print("="*100)
    display(advanced_table)
    
    # Save to CSV
    output_path = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/advanced_metrics.csv"
    advanced_table.to_csv(output_path)
    print(f"\n‚úì Advanced metrics saved to: {output_path}")
    
    # Visualize advanced metrics
    df_numeric = advanced_table.copy()
    for col in df_numeric.columns:
        try:
            df_numeric[col] = pd.to_numeric(df_numeric[col].astype(str).str.replace('%', '').str.replace(',', ''))
        except:
            pass
    
    # Create visualization
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            'Chord Diversity & Vocabulary Size',
            'Segment Duration Statistics',
            'Chord Type Distribution',
            'Prediction vs Ground Truth Correlation'
        )
    )
    
    experiments = df_numeric.index.tolist()
    
    # Plot 1: Diversity and Vocabulary
    if 'Chord Diversity (Entropy)' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Diversity', x=experiments, y=df_numeric['Chord Diversity (Entropy)'],
                   marker_color='lightblue'),
            row=1, col=1
        )
    
    if 'Vocabulary Size' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Vocab Size', x=experiments, y=df_numeric['Vocabulary Size'],
                   marker_color='lightgreen'),
            row=1, col=1
        )
    
    # Plot 2: Segment Duration
    if 'Avg Segment Duration (s)' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Avg Duration', x=experiments, y=df_numeric['Avg Segment Duration (s)'],
                   marker_color='orange'),
            row=1, col=2
        )
    
    # Plot 3: Chord Types
    if 'No-Chord %' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='No-Chord %', x=experiments, y=df_numeric['No-Chord %'],
                   marker_color='red'),
            row=2, col=1
        )
    
    if 'Complex Chords %' in df_numeric.columns:
        fig.add_trace(
            go.Bar(name='Complex %', x=experiments, y=df_numeric['Complex Chords %'],
                   marker_color='purple'),
            row=2, col=1
        )
    
    # Plot 4: Correlation
    if 'Segment Count Correlation' in df_numeric.columns:
        fig.add_trace(
            go.Scatter(x=experiments, y=df_numeric['Segment Count Correlation'],
                      mode='markers+lines', marker=dict(size=10, color='green'),
                      name='Correlation'),
            row=2, col=2
        )
    
    fig.update_layout(height=800, showlegend=True, title_text="Advanced Metrics Dashboard")
    fig.show()
    
    print("\n‚úÖ Advanced analysis complete!")
else:
    print("‚ö†Ô∏è  Could not generate advanced metrics table")



ADVANCED METRICS ANALYSIS

This may take a few moments...

Calculating advanced metrics...

Processing: inferences_1trainBillJaah_testRwc
  ‚úì Calculated advanced metrics
Processing: inferences_2trainBillJaahDjavan_testRwc
  ‚úì Calculated advanced metrics
Processing: inferences_3trainBillJaahDjavan_testRwc
  ‚úì Calculated advanced metrics

ADVANCED METRICS TABLE


Unnamed: 0,Chord Diversity (Entropy),Unique Chords Ratio,Avg Segment Duration (s),Segment Duration StdDev,Vocabulary Size,No-Chord %,Major/Minor Ratio,Complex Chords %,correlation_segment_count,MAE Segment Count,Segment Count Correlation
Exp_1trainBillJaah,3.53,0.138,1.63,2.34,22.76,1.47%,1.462,23.02%,171.265957,58.57,0.344
Exp_2trainBillJaahDjavan,3.73,0.138,1.46,1.25,23.3,1.43%,0.662,35.65%,174.808511,42.33,0.793
Exp_3trainBillJaahDjavan,3.75,0.136,1.43,1.22,23.65,1.36%,0.702,42.79%,179.595745,46.41,0.809



‚úì Advanced metrics saved to: /home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/advanced_metrics.csv



‚úÖ Advanced analysis complete!


In [24]:
import os
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from collections import Counter
import glob
from pathlib import Path

# Set paths
INFERENCE_DIR = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/inferences_3trainBillJaahDjavan_testRwc"
GROUND_TRUTH_DIR = "/home/daniel.melo/datasets/rwc/annotations"

# Note: If the directory doesn't exist, try:
# INFERENCE_DIR = "/home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/inferences_2trainBillJaahDjavan_testRwc"

print("Libraries imported successfully!")


Libraries imported successfully!


In [25]:
# DIAGNOSTIC: Check what files are actually found
print("=== DIAGNOSTIC: Checking file paths and names ===\n")
print(f"Inference directory: {INFERENCE_DIR}")
print(f"Ground truth directory: {GROUND_TRUTH_DIR}")
print(f"Inference dir exists: {os.path.exists(INFERENCE_DIR)}")
print(f"Ground truth dir exists: {os.path.exists(GROUND_TRUTH_DIR)}\n")

# Check inference files
inference_files = glob.glob(os.path.join(INFERENCE_DIR, "*.lab"))
print(f"Found {len(inference_files)} inference .lab files")
if inference_files:
    print(f"First 5 inference files:")
    for f in inference_files[:5]:
        print(f"  {os.path.basename(f)}")
else:
    # Try alternative patterns
    print("Trying alternative search patterns...")
    alt_patterns = [
        os.path.join(INFERENCE_DIR, "*.lab"),
        os.path.join(INFERENCE_DIR.replace("//", "/"), "*.lab"),
        os.path.join(INFERENCE_DIR.replace("inferences_3", "inferences_2"), "*.lab"),
    ]
    for pattern in alt_patterns:
        files = glob.glob(pattern)
        if files:
            print(f"Found {len(files)} files with pattern: {pattern}")
            print(f"  First 3: {[os.path.basename(f) for f in files[:3]]}")

# Check ground truth files
gt_files = glob.glob(os.path.join(GROUND_TRUTH_DIR, "*.lab"))
print(f"\nFound {len(gt_files)} ground truth .lab files")
if gt_files:
    print(f"First 5 ground truth files:")
    for f in gt_files[:5]:
        print(f"  {os.path.basename(f)}")

# Compare naming patterns
if inference_files and gt_files:
    inference_basenames = {os.path.basename(f) for f in inference_files}
    gt_basenames = {os.path.basename(f) for f in gt_files}
    print(f"\nInference file names (sample): {sorted(list(inference_basenames))[:5]}")
    print(f"Ground truth file names (sample): {sorted(list(gt_basenames))[:5]}")
    print(f"\nCommon files: {len(inference_basenames.intersection(gt_basenames))}")
    print(f"Only in inference: {sorted(list(inference_basenames - gt_basenames))[:5]}")
    print(f"Only in ground truth: {sorted(list(gt_basenames - inference_basenames))[:5]}")


=== DIAGNOSTIC: Checking file paths and names ===

Inference directory: /home/daniel.melo/BTC_ORIGINAL/BTC-ISMIR19/inferences/inferences_3trainBillJaahDjavan_testRwc
Ground truth directory: /home/daniel.melo/datasets/rwc/annotations
Inference dir exists: True
Ground truth dir exists: True

Found 94 inference .lab files
First 5 inference files:
  rwc-pop_17.lab
  rwc-pop_83.lab
  rwc-pop_59.lab
  rwc-pop_16.lab
  rwc-pop_42.lab

Found 100 ground truth .lab files
First 5 ground truth files:
  rwc-pop_17.lab
  rwc-pop_83.lab
  rwc-pop_59.lab
  rwc-pop_16.lab
  rwc-pop_42.lab

Inference file names (sample): ['rwc-pop_00.lab', 'rwc-pop_01.lab', 'rwc-pop_02.lab', 'rwc-pop_03.lab', 'rwc-pop_04.lab']
Ground truth file names (sample): ['rwc-pop_00.lab', 'rwc-pop_01.lab', 'rwc-pop_02.lab', 'rwc-pop_03.lab', 'rwc-pop_04.lab']

Common files: 94
Only in inference: []
Only in ground truth: ['rwc-pop_90.lab', 'rwc-pop_91.lab', 'rwc-pop_92.lab', 'rwc-pop_93.lab', 'rwc-pop_94.lab']


In [26]:
# Fix file names if needed - This cell will rename inference files to match ground truth naming
import shutil

def fix_file_names(inference_dir, ground_truth_dir):
    """Rename inference files to match ground truth file names."""
    # Get all files from both directories
    inference_files = glob.glob(os.path.join(inference_dir, "*.lab"))
    gt_files = glob.glob(os.path.join(ground_truth_dir, "*.lab"))
    
    if not inference_files:
        print(f"No inference files found in {inference_dir}")
        return
    
    if not gt_files:
        print(f"No ground truth files found in {ground_truth_dir}")
        return
    
    # Get base names (without extension)
    gt_basenames = {os.path.basename(f).replace('.lab', '') for f in gt_files}
    
    # Create mapping from inference files to ground truth names
    renamed_count = 0
    skipped_count = 0
    
    print("Renaming inference files to match ground truth names...\n")
    
    for inf_file in inference_files:
        inf_basename = os.path.basename(inf_file).replace('.lab', '')
        
        # Check if it already matches
        if inf_basename in gt_basenames:
            print(f"  ‚úì {os.path.basename(inf_file)} already matches")
            skipped_count += 1
            continue
        
        # Try to find a matching ground truth file
        # Strategy 1: Direct match (already checked)
        # Strategy 2: Try removing/adding prefixes/suffixes
        # Strategy 3: Try case-insensitive match
        matched = False
        
        # Case-insensitive match
        for gt_basename in gt_basenames:
            if inf_basename.lower() == gt_basename.lower():
                new_path = os.path.join(inference_dir, f"{gt_basename}.lab")
                if not os.path.exists(new_path):
                    os.rename(inf_file, new_path)
                    print(f"  ‚úì Renamed: {os.path.basename(inf_file)} -> {gt_basename}.lab (case fix)")
                    renamed_count += 1
                    matched = True
                    break
        
        if not matched:
            # Try pattern matching - remove common prefixes/suffixes
            # For example: "RWC-POP_00" -> "rwc-pop_00"
            normalized_inf = inf_basename.lower().replace('-', '_')
            for gt_basename in gt_basenames:
                normalized_gt = gt_basename.lower().replace('-', '_')
                if normalized_inf == normalized_gt:
                    new_path = os.path.join(inference_dir, f"{gt_basename}.lab")
                    if not os.path.exists(new_path):
                        os.rename(inf_file, new_path)
                        print(f"  ‚úì Renamed: {os.path.basename(inf_file)} -> {gt_basename}.lab (normalized)")
                        renamed_count += 1
                        matched = True
                        break
        
        if not matched:
            # Extract number from inference file and try to match
            import re
            inf_numbers = re.findall(r'\d+', inf_basename)
            if inf_numbers:
                for gt_basename in gt_basenames:
                    gt_numbers = re.findall(r'\d+', gt_basename)
                    if inf_numbers == gt_numbers:
                        new_path = os.path.join(inference_dir, f"{gt_basename}.lab")
                        if not os.path.exists(new_path):
                            os.rename(inf_file, new_path)
                            print(f"  ‚úì Renamed: {os.path.basename(inf_file)} -> {gt_basename}.lab (number match)")
                            renamed_count += 1
                            matched = True
                            break
        
        if not matched:
            print(f"  ‚úó Could not match: {os.path.basename(inf_file)}")
    
    print(f"\nSummary: Renamed {renamed_count} files, {skipped_count} already matched")

# Uncomment the line below to run the file name fix
# fix_file_names(INFERENCE_DIR, GROUND_TRUTH_DIR)


## EXPERIMENT PLOTS

In [27]:
def parse_lab_file(filepath):
    """Parse a .lab file and return a list of (start, end, chord) tuples."""
    chords = []
    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) >= 3:
                start = float(parts[0])
                end = float(parts[1])
                chord = ' '.join(parts[2:])  # Handle chords with spaces
                chords.append((start, end, chord))
    return chords

def get_all_track_ids():
    """Get all track IDs that exist in both inference and ground truth."""
    inference_files = glob.glob(os.path.join(INFERENCE_DIR, "*.lab"))
    gt_files = glob.glob(os.path.join(GROUND_TRUTH_DIR, "*.lab"))
    
    inference_ids = {os.path.basename(f).replace('.lab', '') for f in inference_files}
    gt_ids = {os.path.basename(f).replace('.lab', '') for f in gt_files}
    
    common_ids = sorted(inference_ids.intersection(gt_ids))
    return common_ids

# Test the functions
track_ids = get_all_track_ids()
print(f"Found {len(track_ids)} common tracks")
print(f"First 10 tracks: {track_ids[:10]}")


Found 94 common tracks
First 10 tracks: ['rwc-pop_00', 'rwc-pop_01', 'rwc-pop_02', 'rwc-pop_03', 'rwc-pop_04', 'rwc-pop_05', 'rwc-pop_06', 'rwc-pop_07', 'rwc-pop_08', 'rwc-pop_09']


In [28]:
# Load all data
all_data = []

for track_id in track_ids:
    inference_path = os.path.join(INFERENCE_DIR, f"{track_id}.lab")
    gt_path = os.path.join(GROUND_TRUTH_DIR, f"{track_id}.lab")
    
    if os.path.exists(inference_path) and os.path.exists(gt_path):
        inference_chords = parse_lab_file(inference_path)
        gt_chords = parse_lab_file(gt_path)
        
        all_data.append({
            'track_id': track_id,
            'inference': inference_chords,
            'ground_truth': gt_chords,
            'inference_duration': max([c[1] for c in inference_chords]) if inference_chords else 0,
            'gt_duration': max([c[1] for c in gt_chords]) if gt_chords else 0,
        })

print(f"Loaded data for {len(all_data)} tracks")
print(f"\nExample track: {all_data[0]['track_id']}")
print(f"  Inference: {len(all_data[0]['inference'])} chord segments")
print(f"  Ground Truth: {len(all_data[0]['ground_truth'])} chord segments")


Loaded data for 94 tracks

Example track: rwc-pop_00
  Inference: 153 chord segments
  Ground Truth: 134 chord segments


In [29]:
# Analyze chord distributions across all tracks
def extract_chord_roots(chord_label):
    """Extract root note from chord label (e.g., 'A:maj7' -> 'A')."""
    if not chord_label or chord_label == 'N':
        return 'N'
    # Remove everything after colon or slash
    root = chord_label.split(':')[0].split('/')[0]
    return root

def analyze_chord_statistics(all_data):
    """Calculate statistics about chord predictions."""
    stats = []
    
    for track_data in all_data:
        inference_chords = [c[2] for c in track_data['inference']]
        gt_chords = [c[2] for c in track_data['ground_truth']]
        
        inference_roots = [extract_chord_roots(c) for c in inference_chords]
        gt_roots = [extract_chord_roots(c) for c in gt_chords]
        
        # Count unique chords
        inference_unique = len(set(inference_chords))
        gt_unique = len(set(gt_chords))
        
        # Count chord changes
        inference_changes = len(inference_chords)
        gt_changes = len(gt_chords)
        
        # Calculate chord changes per minute
        gt_changes_per_min = (gt_changes / (track_data['gt_duration'] / 60)) if track_data['gt_duration'] > 0 else 0
        inf_changes_per_min = (inference_changes / (track_data['inference_duration'] / 60)) if track_data['inference_duration'] > 0 else 0
        
        stats.append({
            'track_id': track_data['track_id'],
            'inference_unique_chords': inference_unique,
            'gt_unique_chords': gt_unique,
            'inference_chord_changes': inference_changes,
            'gt_chord_changes': gt_changes,
            'inference_duration': track_data['inference_duration'],
            'gt_duration': track_data['gt_duration'],
            'inference_roots': inference_roots,
            'gt_roots': gt_roots,
            'gt_changes_per_min': gt_changes_per_min,
            'inf_changes_per_min': inf_changes_per_min
        })
    
    return pd.DataFrame(stats)

stats_df = analyze_chord_statistics(all_data)
print(stats_df.head())


     track_id  inference_unique_chords  gt_unique_chords  \
0  rwc-pop_00                       27                29   
1  rwc-pop_01                       18                29   
2  rwc-pop_02                       36                44   
3  rwc-pop_03                       26                36   
4  rwc-pop_04                       25                26   

   inference_chord_changes  gt_chord_changes  inference_duration  gt_duration  \
0                      153               134             209.259      209.213   
1                      175               123             222.778      222.707   
2                      230               161             195.370      195.360   
3                      184               114             242.500      242.507   
4                      209               125             228.148      228.107   

                                     inference_roots  \
0  [G#, F#, E, E, B, B, F#, G#, N, G#, E, B, D#, ...   
1  [G#, G, G#, A#, D#, G#, G, G#, D#, G#

In [30]:
# Plot 1: Comparison of number of unique chords per track
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['gt_unique_chords'],
    mode='lines+markers',
    name='Ground Truth',
    line=dict(color='blue', width=2),
    marker=dict(size=6)
))

fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['inference_unique_chords'],
    mode='lines+markers',
    name='Inference',
    line=dict(color='red', width=2),
    marker=dict(size=6)
))

fig.update_layout(
    title='Number of Unique Chords per Track',
    xaxis_title='Track ID',
    yaxis_title='Number of Unique Chords',
    hovermode='x unified',
    height=500
)

fig.show()


In [31]:
# Plot 2: Comparison of chord change frequency
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['gt_chord_changes'],
    mode='lines+markers',
    name='Ground Truth',
    line=dict(color='blue', width=2),
    marker=dict(size=6)
))

fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['inference_chord_changes'],
    mode='lines+markers',
    name='Inference',
    line=dict(color='red', width=2),
    marker=dict(size=6)
))

fig.update_layout(
    title='Number of Chord Changes per Track',
    xaxis_title='Track ID',
    yaxis_title='Number of Chord Changes',
    hovermode='x unified',
    height=500
)

fig.show()


In [32]:
# Plot 3: Chord root distribution comparison
def get_root_distribution(all_data, source='inference'):
    """Get distribution of chord roots."""
    all_roots = []
    for track_data in all_data:
        if source == 'inference':
            chords = [c[2] for c in track_data['inference']]
        else:
            chords = [c[2] for c in track_data['ground_truth']]
        
        roots = [extract_chord_roots(c) for c in chords]
        all_roots.extend(roots)
    
    return Counter(all_roots)

gt_roots = get_root_distribution(all_data, 'ground_truth')
inf_roots = get_root_distribution(all_data, 'inference')

# Get all unique roots
all_roots = sorted(set(list(gt_roots.keys()) + list(inf_roots.keys())))

fig = go.Figure()

fig.add_trace(go.Bar(
    x=all_roots,
    y=[gt_roots.get(root, 0) for root in all_roots],
    name='Ground Truth',
    marker_color='blue',
    opacity=0.7
))

fig.add_trace(go.Bar(
    x=all_roots,
    y=[inf_roots.get(root, 0) for root in all_roots],
    name='Inference',
    marker_color='red',
    opacity=0.7
))

fig.update_layout(
    title='Chord Root Distribution (All Tracks)',
    xaxis_title='Chord Root',
    yaxis_title='Frequency',
    barmode='group',
    height=600,
    xaxis={'categoryorder': 'total descending'}
)

fig.show()


In [33]:
# Plot 4: Duration comparison
fig = go.Figure()


fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['gt_duration'],
    mode='lines+markers',
    name='Ground Truth Duration',
    line=dict(color='blue', width=2),
    marker=dict(size=6)
))

fig.add_trace(go.Scatter(
    x=stats_df['track_id'],
    y=stats_df['inference_duration'],
    mode='lines+markers',
    name='Inference Duration',
    line=dict(color='red', width=2),
    marker=dict(size=6)
))

fig.update_layout(
    title='Track Duration Comparison',
    xaxis_title='Track ID',
    yaxis_title='Duration (seconds)',
    hovermode='x unified',
    height=500
)

fig.show()


In [34]:
# Plot 5: Summary statistics heatmap
summary_stats = {
    'Metric': ['Mean Unique Chords', 'Mean Chord Changes', 'Mean Duration (s)', 'Mean Changes/min'],
    'Ground Truth': [
        stats_df['gt_unique_chords'].mean(),
        stats_df['gt_chord_changes'].mean(),
        stats_df['gt_duration'].mean(),
        stats_df['gt_changes_per_min'].mean()
    ],
    'Inference': [
        stats_df['inference_unique_chords'].mean(),
        stats_df['inference_chord_changes'].mean(),
        stats_df['inference_duration'].mean(),
        stats_df['inf_changes_per_min'].mean()
    ]
}

summary_df = pd.DataFrame(summary_stats)

fig = go.Figure(data=go.Heatmap(
    z=summary_df[['Ground Truth', 'Inference']].values.T,
    x=summary_df['Metric'],
    y=['Ground Truth', 'Inference'],
    colorscale='Viridis',
    text=summary_df[['Ground Truth', 'Inference']].values.T,
    texttemplate='%{text:.2f}',
    textfont={"size": 12},
    colorbar=dict(title="Value")
))

fig.update_layout(
    title='Summary Statistics Comparison',
    height=300
)

fig.show()

print("\nSummary Statistics:")
print(summary_df.to_string(index=False))



Summary Statistics:
            Metric  Ground Truth  Inference
Mean Unique Chords     24.957447  23.648936
Mean Chord Changes    133.180851 179.595745
 Mean Duration (s)    244.320819 242.878245
  Mean Changes/min     32.991886  44.838025


In [35]:
# Plot 6: Chord type distribution (major, minor, 7th, etc.)
def extract_chord_type(chord_label):
    """Extract chord type from label (e.g., 'A:maj7' -> 'maj7', 'B:min' -> 'min', 'B' -> 'maj')."""
    if not chord_label or chord_label == 'N':
        return 'N'
    if ':' in chord_label:
        chord_type = chord_label.split(':', 1)[1]
        # Remove bass note if present
        if '/' in chord_type:
            chord_type = chord_type.split('/')[0]
        return chord_type
    # If no colon, it's a major chord (default in chord representation)
    # Check if it's just a root note (contains only letters, #, or b)
    if chord_label.replace('#', '').replace('b', '').isalpha():
        return 'maj'
    return 'unknown'

def get_chord_type_distribution(all_data, source='inference'):
    """Get distribution of chord types."""
    all_types = []
    for track_data in all_data:
        if source == 'inference':
            chords = [c[2] for c in track_data['inference']]
        else:
            chords = [c[2] for c in track_data['ground_truth']]
        
        types = [extract_chord_type(c) for c in chords]
        all_types.extend(types)
    
    return Counter(all_types)

gt_types = get_chord_type_distribution(all_data, 'ground_truth')
inf_types = get_chord_type_distribution(all_data, 'inference')

# Get all unique types and sort by frequency
all_types = sorted(set(list(gt_types.keys()) + list(inf_types.keys())))
# Sort by total frequency
type_freq = {t: gt_types.get(t, 0) + inf_types.get(t, 0) for t in all_types}
all_types = sorted(all_types, key=lambda x: type_freq[x], reverse=True)

# Take top 20 most common types
top_types = all_types[:20]

fig = go.Figure()

fig.add_trace(go.Bar(
    x=top_types,
    y=[gt_types.get(t, 0) for t in top_types],
    name='Ground Truth',
    marker_color='blue',
    opacity=0.7
))

fig.add_trace(go.Bar(
    x=top_types,
    y=[inf_types.get(t, 0) for t in top_types],
    name='Inference',
    marker_color='red',
    opacity=0.7
))

fig.update_layout(
    title='Top 20 Chord Type Distribution (All Tracks)',
    xaxis_title='Chord Type',
    yaxis_title='Frequency',
    barmode='group',
    height=600,
    xaxis={'tickangle': -45}
)

fig.show()


### Optional plots

In [36]:
# Plot 7: Interactive track selector for detailed timeline view
# Option 1: Using ipywidgets (if available)
try:
    import ipywidgets as widgets
    from IPython.display import display
    
    def create_track_selector():
        """Create an interactive widget to select and view tracks."""
        track_options = [(f"{data['track_id']}", i) for i, data in enumerate(all_data)]
        
        track_dropdown = widgets.Dropdown(
            options=track_options,
            value=0,
            description='Track:',
            style={'description_width': 'initial'}
        )
        
        def on_track_change(change):
            track_idx = change['new']
            track_data = all_data[track_idx]
            fig = plot_chord_timeline_comparison(track_data, track_data['track_id'])
            fig.show()
        
        track_dropdown.observe(on_track_change, names='value')
        return track_dropdown
    
    # Create the widget
    track_selector = create_track_selector()
    display(track_selector)
    
    # Show initial track
    if all_data:
        fig = plot_chord_timeline_comparison(all_data[0], all_data[0]['track_id'])
        fig.show()
except ImportError:
    print("ipywidgets not available. Showing first 3 tracks as examples:")
    # Show first 3 tracks as examples
    for i in range(min(3, len(all_data))):
        track_data = all_data[i]
        fig = plot_chord_timeline_comparison(track_data, track_data['track_id'])
        fig.show()


ipywidgets not available. Showing first 3 tracks as examples:


NameError: name 'plot_chord_timeline_comparison' is not defined

In [None]:
# Plot 9: Chord duration distribution
def get_chord_durations(chords):
    """Extract durations from chord list."""
    return [end - start for start, end, _ in chords]

all_gt_durations = []
all_inf_durations = []

for track_data in all_data:
    gt_durations = get_chord_durations(track_data['ground_truth'])
    inf_durations = get_chord_durations(track_data['inference'])
    all_gt_durations.extend(gt_durations)
    all_inf_durations.extend(inf_durations)

fig = go.Figure()

fig.add_trace(go.Histogram(
    x=all_gt_durations,
    name='Ground Truth',
    opacity=0.7,
    nbinsx=50,
    marker_color='blue'
))

fig.add_trace(go.Histogram(
    x=all_inf_durations,
    name='Inference',
    opacity=0.7,
    nbinsx=50,
    marker_color='red'
))

fig.update_layout(
    title='Distribution of Chord Durations (All Tracks)',
    xaxis_title='Chord Duration (seconds)',
    yaxis_title='Frequency',
    barmode='overlay',
    height=500
)

fig.show()

print(f"\nGround Truth - Mean duration: {np.mean(all_gt_durations):.2f}s, Median: {np.median(all_gt_durations):.2f}s")
print(f"Inference - Mean duration: {np.mean(all_inf_durations):.2f}s, Median: {np.median(all_inf_durations):.2f}s")



Ground Truth - Mean duration: 1.83s, Median: 1.62s
Inference - Mean duration: 1.42s, Median: 0.74s


In [None]:
def plot_chord_timeline_comparison(track_data, track_id):
    """Create an interactive timeline comparison plot for a single track."""
    inference = track_data['inference']
    gt = track_data['ground_truth']
    
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=('Ground Truth', 'Inference'),
        vertical_spacing=0.15,
        row_heights=[0.5, 0.5]
    )
    
    # Ground truth timeline
    y_pos_gt = 0
    for i, (start, end, chord) in enumerate(gt):
        fig.add_trace(
            go.Scatter(
                x=[start, start, end, end, start],
                y=[y_pos_gt - 0.4, y_pos_gt + 0.4, y_pos_gt + 0.4, y_pos_gt - 0.4, y_pos_gt - 0.4],
                fill='toself',
                mode='lines',
                name=f'GT: {chord}',
                line=dict(width=0),
                fillcolor=f'hsl({(i * 60) % 360}, 70%, 80%)',
                hovertemplate=f'<b>{chord}</b><br>Time: {start:.2f} - {end:.2f}s<br>Duration: {end-start:.2f}s<extra></extra>',
                showlegend=False
            ),
            row=1, col=1
        )
    
    # Inference timeline
    y_pos_inf = 0
    for i, (start, end, chord) in enumerate(inference):
        fig.add_trace(
            go.Scatter(
                x=[start, start, end, end, start],
                y=[y_pos_inf - 0.4, y_pos_inf + 0.4, y_pos_inf + 0.4, y_pos_inf - 0.4, y_pos_inf - 0.4],
                fill='toself',
                mode='lines',
                name=f'Inf: {chord}',
                line=dict(width=0),
                fillcolor=f'hsl({(i * 60 + 30) % 360}, 70%, 80%)',
                hovertemplate=f'<b>{chord}</b><br>Time: {start:.2f} - {end:.2f}s<br>Duration: {end-start:.2f}s<extra></extra>',
                showlegend=False
            ),
            row=2, col=1
        )
    
    # Update layout
    max_duration = max(track_data['inference_duration'], track_data['gt_duration'])
    
    fig.update_xaxes(title_text="Time (seconds)", row=2, col=1)
    fig.update_yaxes(showticklabels=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=2, col=1)
    
    fig.update_layout(
        title=f"Chord Timeline Comparison: {track_id}",
        height=600,
        hovermode='closest'
    )
    
    return fig

# Plot first track as example
if all_data:
    fig = plot_chord_timeline_comparison(all_data[0], all_data[0]['track_id'])
    fig.show()
