In [1]:
from utils.eval_utils import load_evaluation_results_from_dir, evaluation_results_to_dataframe

# Root results directory; judging by the loader's log lines it holds
# "<model>/<experiment>" sub-directories.
save_path = "../results"

# First call leaves `dataset` at the loader's default; the others select one dataset each.
evaluation_results = load_evaluation_results_from_dir(save_path=save_path)
evaluation_results_dummy, evaluation_results_train, evaluation_results_test = (
    load_evaluation_results_from_dir(save_path=save_path, dataset=dataset_name)
    for dataset_name in (
        "dummy_data",
        "txt2sql_alerce_train_v4_0",
        "txt2sql_alerce_test_v4_0",
    )
)
# evaluation_results_df = evaluation_results_to_dataframe(evaluation_results)

Skipping gpt-4.1-mini-2025-04-14/alerce_train_direct_v8 as config.json does not exist.
Evaluation results file not found for Qwen2.5-1.5B-Instruct/alerce_dummy_direct_v8. Skipping...
Skipping gpt-4.1-mini-2025-04-14/alerce_train_direct_v8 as config.json does not exist.
Evaluation results file not found for Qwen2.5-1.5B-Instruct/alerce_dummy_direct_v8. Skipping...
Skipping gpt-4.1-mini-2025-04-14/alerce_train_direct_v8 as config.json does not exist.
Skipping gpt-4.1-mini-2025-04-14/alerce_train_direct_v8 as config.json does not exist.


In [2]:
import pandas as pd
# Gold Text-to-SQL splits (train first, then test)
db_train, db_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/txt2sql_alerce_train_v4_0.csv',
        'data/txt2sql_alerce_test_v4_0.csv',
    )
)

In [3]:
# The expected structure of evaluation_results is:
#         {
#             'model_name': {
#                 'experiment_name': {
#                     'self_corrected': {       # 'corrected' when self_corr is False
#                         'detailed_results': [
#                             {  "req_id": str,           # Unique request ID
#                                "n_exp": str or int,     # Experiment number
#                                "comparison": { ... }  # Comparison details
#                              ...
#                             },
#                             ...
#                         ],
#                         'aggregate_metrics': {
#                             'oids': {'perfect_match_rate': float},
#                             'columns': {'perfect_match_rate': float},
#                             'difficulty': str  # e.g., 'easy', 'medium', 'hard'
#                         }
#                     }
#                 }
#             }
#         }
from utils.eval_utils import metrics_aggregation

# Extract model names from the results dictionary
models = list(evaluation_results.keys())

self_corr = True  # True: read 'self_corrected' results; False: read 'corrected'
std_dev = True    # flag kept for downstream cells; not used in this cell
# Data collection containers, initialised empty here (kept for later cells)
experiment_labels = {}
difficulties = []
self_corr_key = 'self_corrected' if self_corr else 'corrected'

# metrics_aggregation output per "<model>-<experiment>" label. Previously each
# result was bound to `aggregate_metrics` and overwritten by the next
# experiment, so only the last one survived the loop.
aggregate_metrics_by_exp = {}

for model in models:
    for experiment in evaluation_results[model]:
        exp_label = f"{model}-{experiment}"
        try:
            results = evaluation_results[model][experiment][self_corr_key]['detailed_results']  # list of dictionaries
            aggregate_metrics = metrics_aggregation(results=results)
        except KeyError as e:
            # Best effort: experiments lacking the requested result type are skipped.
            print(f"Warning: Missing data for {model} - {experiment}: {e}")
            continue
        aggregate_metrics_by_exp[exp_label] = aggregate_metrics




In [4]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import pygments
from pygments.lexers import SqlLexer
from pygments.formatters import HtmlFormatter
import json
import difflib
import ast
from datetime import datetime
import pandas as pd

class SQLComparisonWidget:
    """
    An interactive widget for visualizing and comparing SQL query predictions
    made by language models in a Text-to-SQL task.
    """
    
    def __init__(self, results_dict, df_train=None, df_test=None):
        """
        Initialize the widget with a dictionary of result dictionaries.
        
        Parameters:
        -----------
        results_dict : dict
            Dictionary where keys are 'model_name-experiment_name' and values are lists of result dictionaries
        df_train : pandas.DataFrame, optional
            DataFrame containing training data with requests and knowledge
        df_test : pandas.DataFrame, optional
            DataFrame containing test data with requests and knowledge
        """
        self.results_dict = results_dict
        self.df_train = df_train
        self.df_test = df_test
        
        # Extract model and experiment data from the keys
        self.model_experiment_map = {}
        for key in results_dict.keys():
            if '-' in key:
                model, experiment = key.split('-', 1)
                if model not in self.model_experiment_map:
                    self.model_experiment_map[model] = []
                self.model_experiment_map[model].append(experiment)
        
        # Sort the experiments for each model
        for model in self.model_experiment_map:
            self.model_experiment_map[model] = sorted(self.model_experiment_map[model])
        
        # Flatten results for initial processing
        self.all_results = []
        for result_list in results_dict.values():
            self.all_results.extend(result_list)
        
        # Extract unique request IDs
        self.req_ids = sorted(list(set([r['req_id'] for r in self.all_results])))
        
        # Extract unique models and experiments
        self.models = sorted(list(self.model_experiment_map.keys()))
        if not self.models:
            self.models = ['All']
        
        # Extract unique difficulties
        self.difficulties = sorted(list(set([r.get('difficulty', 'unknown') for r in self.all_results if 'difficulty' in r])))
        self.difficulties = ['All'] + self.difficulties
        
        # Set up CSS styles for the UI
        self.setup_styles()
        
        # Create widgets
        self.create_widgets()
        
        # Connect event handlers
        self.setup_event_handlers()
        
        # Display the widget
        self.display_widget()
    
    def setup_styles(self):
        """Define the CSS classes used by the widget's HTML and inject them into the notebook.

        The stylesheet is stored on ``self.style_html`` and rendered once via
        ``display(HTML(...))`` so that HTML produced later can reference the
        classes (``sql-*``, ``metrics-*``, ``diff-*``, ``compare-table``, ...).
        """
        self.style_html = """
        <style>
            .sql-container {
                display: flex;
                flex-direction: row;
                width: 100%;
                margin-bottom: 10px;
            }
            .sql-box {
                flex: 1;
                margin: 5px;
                border: 1px solid #ccc;
                border-radius: 5px;
                padding: 10px;
                overflow: auto;
                max-height: 400px;
            }
            .sql-title {
                font-weight: bold;
                margin-bottom: 5px;
                padding: 5px;
                background-color: #f0f0f0;
                border-radius: 3px;
            }
            .sql-query {
                font-family: monospace;
                white-space: pre-wrap;
                padding: 10px;
                background-color: #f8f8f8;
                border-radius: 3px;
            }
            .metrics-container {
                display: flex;
                flex-direction: row;
                flex-wrap: wrap;
                width: 100%;
            }
            .metrics-box {
                flex: 1;
                min-width: 300px;
                margin: 5px;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 5px;
            }
            .metrics-title {
                font-weight: bold;
                margin-bottom: 10px;
                padding: 5px;
                background-color: #f0f0f0;
                border-radius: 3px;
            }
            .metrics-table {
                width: 100%;
                border-collapse: collapse;
            }
            .metrics-table th, .metrics-table td {
                border: 1px solid #ddd;
                padding: 8px;
                text-align: left;
            }
            .metrics-table th {
                background-color: #f2f2f2;
            }
            .metrics-table tr:nth-child(even) {
                background-color: #f9f9f9;
            }
            .error-container {
                background-color: #ffebee;
                color: #c62828;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
                border-left: 4px solid #c62828;
            }
            .success-container {
                background-color: #e8f5e9;
                color: #2e7d32;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
                border-left: 4px solid #2e7d32;
            }
            .metadata-container {
                background-color: #e3f2fd;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
            }
            .request-container {
                background-color: #e8eaf6;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
                border-left: 4px solid #3f51b5;
            }
            .knowledge-container {
                background-color: #f3e5f5;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
                border-left: 4px solid #9c27b0;
            }
            .summary-container {
                background-color: #fff8e1;
                padding: 10px;
                border-radius: 5px;
                margin: 10px 0;
                border: 1px solid #ffe082;
            }
            .diff-same {
                background-color: transparent;
            }
            .diff-add {
                background-color: #e6ffed !important;
                color: #22863a !important;
                border-left: 3px solid #22863a !important;
                padding-left: 2px !important;
                margin: 2px 0 !important;
            }
            .diff-remove {
                background-color: #ffeef0 !important;
                color: #cb2431 !important;
                border-left: 3px solid #cb2431 !important;
                padding-left: 2px !important;
                margin: 2px 0 !important;
            }
            .diff-change {
                background-color: #fff5b1 !important;
                color: #735c0f !important;
                border-left: 3px solid #735c0f !important;
                padding-left: 2px !important;
                margin: 2px 0 !important;
            }
            .sql-with-diffs {
                white-space: pre-wrap;
                font-family: monospace;
            }
            .sql-with-diffs div {
                margin: 0;
                padding: 0 2px;
            }
            .compare-table {
                width: 100%;
                border-collapse: collapse;
                margin: 10px 0;
            }
            .compare-table th, .compare-table td {
                border: 1px solid #ddd;
                padding: 8px;
                text-align: left;
            }
            .compare-table th {
                background-color: #f2f2f2;
                font-weight: bold;
            }
            .match {
                background-color: #e6ffed;
            }
            .mismatch {
                background-color: #ffeef0;
            }
        </style>
        """
        display(HTML(self.style_html))
    
    def create_widgets(self):
        """Instantiate every control and output area of the UI.

        Nothing is displayed here. The model dropdown is disabled when only the
        'All' placeholder model exists; the experiment dropdown is disabled when
        no '<model>-<experiment>' keys were parsed.
        """
        def width(px):
            # Every control only customises its width.
            return widgets.Layout(width=px)

        has_named_models = bool(self.models) and self.models[0] != 'All'

        # --- Filters ---------------------------------------------------------
        self.model_dropdown = widgets.Dropdown(
            options=self.models,
            value=self.models[0] if self.models else None,
            description='Model:',
            disabled=not has_named_models,
            layout=width('300px'),
        )

        self.experiment_dropdown = widgets.Dropdown(
            description='Experiment:',
            disabled=not self.model_experiment_map,
            layout=width('300px'),
        )
        # Pre-populate the experiments of the initially selected model.
        if self.model_experiment_map and has_named_models:
            self.experiment_dropdown.options = ['All'] + self.model_experiment_map.get(self.models[0], [])
            self.experiment_dropdown.value = 'All'

        self.difficulty_dropdown = widgets.Dropdown(
            options=self.difficulties,
            value='All',
            description='Difficulty:',
            disabled=False,
            layout=width('200px'),
        )

        # --- Request / run selection -----------------------------------------
        self.req_id_dropdown = widgets.Dropdown(
            options=self.req_ids,
            description='Request ID:',
            disabled=False,
            layout=width('300px'),
        )
        # Options are filled in later from the selected request ID.
        self.exp_dropdown = widgets.Dropdown(
            description='Experiment #:',
            disabled=False,
            layout=width('200px'),
        )

        # --- Display options -------------------------------------------------
        self.view_mode = widgets.RadioButtons(
            options=['Side by Side', 'Diff View'],
            value='Side by Side',
            description='View Mode:',
            layout=width('250px'),
        )
        # State holder for diff highlighting; the toggle button below is the
        # control placed in the layout and mirrors its value into this checkbox.
        self.highlight_diffs = widgets.Checkbox(value=True, description='Highlight differences', disabled=False)
        self.highlight_toggle = widgets.ToggleButton(
            value=True,
            description='Highlighting',
            disabled=False,
            button_style='success',
            tooltip='Toggle highlighting on/off',
            icon='check',
        )
        self.format_sql = widgets.Checkbox(value=True, description='Format SQL', disabled=False)
        self.show_metrics = widgets.Checkbox(value=True, description='Show Metrics', disabled=False)
        self.show_request = widgets.Checkbox(value=True, description='Show Request & Knowledge', disabled=False)

        self.search_box = widgets.Text(
            value='',
            placeholder='Search in queries...',
            description='Search:',
            disabled=False,
            layout=width('300px'),
        )

        # --- Output areas ----------------------------------------------------
        self.output = widgets.Output()
        self.summary_output = widgets.Output()
    
    def setup_event_handlers(self):
        """Set up event handlers for the widgets."""
        self.model_dropdown.observe(self.on_model_change, names='value')
        self.experiment_dropdown.observe(self.on_experiment_change, names='value')
        self.difficulty_dropdown.observe(self.on_difficulty_change, names='value')
        self.req_id_dropdown.observe(self.on_req_id_change, names='value')
        self.exp_dropdown.observe(self.on_exp_change, names='value')
        self.view_mode.observe(self.update_display, names='value')
        self.highlight_diffs.observe(self.update_display, names='value')
        self.highlight_toggle.observe(self.on_highlight_toggle, names='value')
        self.format_sql.observe(self.update_display, names='value')
        self.show_metrics.observe(self.update_display, names='value')
        self.show_request.observe(self.update_display, names='value')
        self.search_box.observe(self.on_search, names='value')
    
    def display_widget(self):
        """Lay out the controls, initialise dependent state and render the UI.

        Before rendering, the experiment-number dropdown, the request filter and
        the summary panel are refreshed; after rendering, ``update_display``
        draws the current comparison.
        """
        # Filters: model/experiment column next to the difficulty column.
        filters_row = widgets.HBox([
            widgets.VBox([self.model_dropdown, self.experiment_dropdown]),
            widgets.VBox([self.difficulty_dropdown]),
        ])
        # Request ID stacked above the experiment-number selector.
        selection_row = widgets.HBox([
            widgets.VBox([self.req_id_dropdown, self.exp_dropdown]),
        ])
        options_row = widgets.HBox([
            widgets.VBox([self.view_mode]),
            widgets.VBox([self.highlight_toggle, self.format_sql]),
            widgets.VBox([self.show_metrics, self.show_request]),
        ])
        search_row = widgets.HBox([self.search_box])

        # Populate the dependent widgets for the initial selections.
        self.update_exp_dropdown()
        self.filter_requests()
        self.update_summary_stats()

        display(widgets.VBox([
            self.summary_output,
            widgets.HBox([filters_row, selection_row]),
            widgets.HBox([options_row, search_row]),
            self.output,
        ]))

        self.update_display()
    
    def on_model_change(self, change):
        """Event handler for model dropdown change."""
        model = change.new
        
        # Update experiment options based on selected model
        if self.model_experiment_map and model != 'All':
            self.experiment_dropdown.options = ['All'] + self.model_experiment_map.get(model, [])
            self.experiment_dropdown.value = 'All'
        
        # Filter requests based on new model selection
        self.filter_requests()
    
    def on_experiment_change(self, change):
        """Event handler for experiment dropdown change."""
        # Filter requests based on new experiment selection
        self.filter_requests()
    
    def on_difficulty_change(self, change):
        """Event handler for difficulty dropdown change."""
        # Filter requests based on new difficulty selection
        self.filter_requests()
    
    def get_filtered_results(self):
        """Get results filtered by current model, experiment, and difficulty selections."""
        model = self.model_dropdown.value
        experiment = self.experiment_dropdown.value if hasattr(self.experiment_dropdown, 'value') else None
        difficulty = self.difficulty_dropdown.value
        
        # Start with all results
        if model == 'All':
            filtered_results = self.all_results
        else:
            # Filter by model
            if experiment == 'All':
                # Get all experiments for this model
                filtered_results = []
                for exp in self.model_experiment_map.get(model, []):
                    key = f"{model}-{exp}"
                    if key in self.results_dict:
                        filtered_results.extend(self.results_dict[key])
            else:
                # Get specific model-experiment combination
                key = f"{model}-{experiment}"
                filtered_results = self.results_dict.get(key, [])
        
        # Apply difficulty filter if not 'All'
        if difficulty != 'All':
            filtered_results = [r for r in filtered_results if r.get('difficulty', 'unknown') == difficulty]
        
        return filtered_results
    
    def filter_requests(self):
        """Filter request IDs based on model, experiment, and difficulty selections."""
        # Get filtered results
        filtered_results = self.get_filtered_results()
        
        # Update req_ids dropdown
        filtered_req_ids = sorted(list(set([r['req_id'] for r in filtered_results])))
        
        if filtered_req_ids:
            current_req_id = self.req_id_dropdown.value
            self.req_id_dropdown.options = filtered_req_ids
            
            # Try to keep the current selection if possible
            if current_req_id in filtered_req_ids:
                self.req_id_dropdown.value = current_req_id
            else:
                self.req_id_dropdown.value = filtered_req_ids[0]
        else:
            self.req_id_dropdown.options = []
        
        # Update experiment number dropdown
        self.update_exp_dropdown()
        
        # Update summary stats
        self.update_summary_stats(filtered_results)
    
    def update_exp_dropdown(self):
        """Update the experiment dropdown based on the selected request ID."""
        req_id = self.req_id_dropdown.value if self.req_id_dropdown.options else None
        if not req_id:
            self.exp_dropdown.options = []
            return
        
        # Get filtered results
        filtered_results = self.get_filtered_results()
        
        # Filter by request ID
        req_results = [r for r in filtered_results if r['req_id'] == req_id]
        
        # Extract experiment numbers
        exp_nums = sorted(list(set([r['n_exp'] for r in req_results])))
        
        self.exp_dropdown.options = exp_nums
        
        # If there are experiment numbers, select the first one
        if exp_nums:
            self.exp_dropdown.value = exp_nums[0]
    
    def on_req_id_change(self, change):
        """Event handler for request ID dropdown change."""
        self.update_exp_dropdown()
    
    def on_exp_change(self, change):
        """Event handler for experiment number dropdown change."""
        self.update_display()
    
    def on_highlight_toggle(self, change):
        """Event handler for highlight toggle button."""
        if change.new:  # Highlighting on
            self.highlight_toggle.button_style = 'success'
            self.highlight_toggle.icon = 'check'
            self.highlight_diffs.value = True
        else:  # Highlighting off
            self.highlight_toggle.button_style = 'warning'
            self.highlight_toggle.icon = 'remove'
            self.highlight_diffs.value = False
    
    def on_search(self, change):
        """Event handler for search box change."""
        search_term = self.search_box.value.lower()
        
        if not search_term:
            # Reset filters
            self.filter_requests()
            return
        
        # Get filtered results based on current model, experiment, and difficulty
        filtered_results = self.get_filtered_results()
        
        # Filter by search term
        search_filtered = []
        for result in filtered_results:
            if (search_term in result.get('req_id', '').lower() or 
                search_term in str(result.get('n_exp', '')).lower() or 
                search_term in result.get('comparison', {}).get('sql_gold', '').lower() or 
                search_term in result.get('comparison', {}).get('sql_pred', '').lower()):
                search_filtered.append(result)
        
        # Update req_ids dropdown
        filtered_req_ids = sorted(list(set([r['req_id'] for r in search_filtered])))
        
        if filtered_req_ids:
            current_req_id = self.req_id_dropdown.value
            self.req_id_dropdown.options = filtered_req_ids
            
            # Try to keep the current selection if possible
            if current_req_id in filtered_req_ids:
                self.req_id_dropdown.value = current_req_id
            else:
                self.req_id_dropdown.value = filtered_req_ids[0]
        else:
            self.req_id_dropdown.options = []
        
        # Update experiment number dropdown
        self.update_exp_dropdown()
        
        # Update summary stats
        self.update_summary_stats(search_filtered)
    
    def get_current_result(self):
        """Get the current result based on selected request ID and experiment number."""
        req_id = self.req_id_dropdown.value if self.req_id_dropdown.options else None
        exp_num = self.exp_dropdown.value if self.exp_dropdown.options else None
        
        if not req_id or not exp_num:
            return None
        
        # Get filtered results
        filtered_results = self.get_filtered_results()
        
        # Find the specific result
        for result in filtered_results:
            if result['req_id'] == req_id and result['n_exp'] == exp_num:
                return result
        
        return None
    
    def get_request_knowledge(self, req_id):
        """Get request and knowledge information from train or test dataframes."""
        if not self.df_train is None and not self.df_test is None:
            # Look in train dataframe
            train_row = self.df_train[self.df_train['req_id'] == int(req_id)]
            if not train_row.empty:
                return {
                    'request': train_row['request'].values[0],
                    'external_knowledge': train_row['external_knowledge'].values[0] if 'external_knowledge' in train_row.columns else '',
                    'domain_knowledge': train_row['domain_knowledge'].values[0] if 'domain_knowledge' in train_row.columns else ''
                }
            
            # Look in test dataframe
            test_row = self.df_test[self.df_test['req_id'] == int(req_id)]
            if not test_row.empty:
                return {
                    'request': test_row['request'].values[0],
                    'external_knowledge': test_row['external_knowledge'].values[0] if 'external_knowledge' in test_row.columns else '',
                    'domain_knowledge': test_row['domain_knowledge'].values[0] if 'domain_knowledge' in test_row.columns else ''
                }
        
        return None
    
    def format_sql_query(self, query, format_enabled=True):
        """Format a SQL query for better readability."""
        if not query:
            return ""
        
        if format_enabled:
            try:
                import sqlparse
                format_query = sqlparse.format(query,
                                               strip_comments=True,
                                               reindent=True,
                                               keyword_case='upper',
                                               )
                # Simple SQL formatting (basic indentation)
                # For more advanced formatting, consider using sqlparse or similar libraries
                formatted_lines = []
                keywords = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN', 'OUTER JOIN', 'LIMIT', 'OFFSET']
                
                lines = format_query.split('\n')
                for line in lines:
                    line = line.strip()
                    for keyword in keywords:
                        if line.upper().startswith(keyword):
                            line = '\n' + line
                            break
                    formatted_lines.append(line)
                
                return ' '.join(formatted_lines)
            except ImportError:
                # Fallback if sqlparse is not available
                formatted_lines = []
                keywords = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN', 'OUTER JOIN', 'LIMIT', 'OFFSET']
                
                lines = query.split('\n')
                for line in lines:
                    line = line.strip()
                    for keyword in keywords:
                        if line.upper().startswith(keyword):
                            line = '\n' + line
                            break
                    formatted_lines.append(line)
                
                return ' '.join(formatted_lines)
        
        return query
    
    def highlight_sql(self, query):
        """Render ``query`` as syntax-highlighted HTML with pygments' SQL lexer."""
        lexer = SqlLexer()
        formatter = HtmlFormatter()
        return pygments.highlight(query, lexer, formatter)
    
    def generate_diff_html(self, gold_query, pred_query, format_enabled=True):
        """Generate HTML displaying the differences between two SQL queries."""
        gold_formatted = self.format_sql_query(gold_query, format_enabled)
        pred_formatted = self.format_sql_query(pred_query, format_enabled)
        
        gold_lines = gold_formatted.splitlines() or [""]
        pred_lines = pred_formatted.splitlines() or [""]
        
        differ = difflib.Differ()
        diff = list(differ.compare(gold_lines, pred_lines))
        
        html = "<div class='diff-view'>"
        
        for line in diff:
            if line.startswith('+ '):
                html += f"<div class='diff-add'>{line[2:]}</div>"
            elif line.startswith('- '):
                html += f"<div class='diff-remove'>{line[2:]}</div>"
            elif line.startswith('? '):
                continue  # Skip the indicator line
            else:
                html += f"<div class='diff-same'>{line[2:]}</div>"
        
        html += "</div>"
        return html
    
    def generate_highlighted_diffs(self, gold_query, pred_query, highlight_diffs=True, format_enabled=True):
        """Generate HTML for side-by-side comparison with highlighted differences."""
        # If highlighting is disabled, just apply syntax highlighting
        if not highlight_diffs:
            gold_formatted = self.format_sql_query(gold_query, format_enabled)
            pred_formatted = self.format_sql_query(pred_query, format_enabled)
            gold_html = self.highlight_sql(gold_formatted)
            pred_html = self.highlight_sql(pred_formatted)
            return gold_html, pred_html
        
        # Format queries for better readability
        gold_formatted = self.format_sql_query(gold_query, format_enabled)
        pred_formatted = self.format_sql_query(pred_query, format_enabled)
        
        # Break into lines for line-by-line comparison
        gold_lines = gold_formatted.splitlines() if gold_formatted else [""]
        pred_lines = pred_formatted.splitlines() if pred_formatted else [""]
        
        # Use SequenceMatcher to find differences between lines
        matcher = difflib.SequenceMatcher(None, gold_lines, pred_lines)
        
        # Create line-based HTML with proper diff highlighting classes
        gold_html_with_classes = []
        pred_html_with_classes = []
        
        # Process differences using the opcodes from the matcher
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == 'equal':
                # No differences - keep as is
                for i in range(i1, i2):
                    line = gold_lines[i]
                    gold_html_with_classes.append(f'<div class="line">{line}</div>')
                
                for j in range(j1, j2):
                    line = pred_lines[j]
                    pred_html_with_classes.append(f'<div class="line">{line}</div>')
                    
            elif tag == 'replace':
                # Lines were changed
                for i in range(i1, i2):
                    line = gold_lines[i]
                    gold_html_with_classes.append(f'<div class="line diff-change">{line}</div>')
                
                for j in range(j1, j2):
                    line = pred_lines[j]
                    pred_html_with_classes.append(f'<div class="line diff-change">{line}</div>')
                    
            elif tag == 'delete':
                # Lines in gold not in pred
                for i in range(i1, i2):
                    line = gold_lines[i]
                    gold_html_with_classes.append(f'<div class="line diff-remove">{line}</div>')
                    
            elif tag == 'insert':
                # Lines in pred not in gold
                for j in range(j1, j2):
                    line = pred_lines[j]
                    pred_html_with_classes.append(f'<div class="line diff-add">{line}</div>')
        
        # Join the lines into a single string
        gold_with_diff_markers = '\n'.join(gold_html_with_classes)
        pred_with_diff_markers = '\n'.join(pred_html_with_classes)
        
        # Apply syntax highlighting to the full content
        # First, create plain text versions without the HTML markers
        gold_plain = '\n'.join(gold_lines)
        pred_plain = '\n'.join(pred_lines)
        
        # Apply syntax highlighting
        gold_highlighted = self.highlight_sql(gold_plain)
        pred_highlighted = self.highlight_sql(pred_plain)
        
        # Now extract the highlighted content (which will be inside a <div class="highlight">)
        gold_content = gold_highlighted
        pred_content = pred_highlighted
        
        # Create combined HTML with both syntax highlighting and diff classes
        gold_final = '<div class="sql-with-diffs">'
        pred_final = '<div class="sql-with-diffs">'
        
        # Create line-by-line HTML with proper class assignments
        gold_highlighted_lines = gold_content.splitlines()
        pred_highlighted_lines = pred_content.splitlines()
        
        # Match up the highlighted lines with our diff-marked lines
        for i, line in enumerate(gold_html_with_classes):
            # Extract the class name for the diff highlighting
            diff_class = ""
            if "diff-change" in line:
                diff_class = "diff-change"
            elif "diff-remove" in line:
                diff_class = "diff-remove"
            elif "diff-add" in line:
                diff_class = "diff-add"
            
            # Check if we have a corresponding highlighted line
            if i < len(gold_highlighted_lines):
                highlighted_line = gold_highlighted_lines[i]
                if diff_class:
                    gold_final += f'<div class="{diff_class}">{highlighted_line}</div>'
                else:
                    gold_final += f'<div>{highlighted_line}</div>'
            else:
                # If no highlighted line, use the original with diff class
                content = line.replace('<div class="line diff-change">', '').replace('<div class="line diff-remove">', '').replace('<div class="line diff-add">', '').replace('<div class="line">', '').replace('</div>', '')
                if diff_class:
                    gold_final += f'<div class="{diff_class}">{content}</div>'
                else:
                    gold_final += f'<div>{content}</div>'
        
        for i, line in enumerate(pred_html_with_classes):
            # Extract the class name for the diff highlighting
            diff_class = ""
            if "diff-change" in line:
                diff_class = "diff-change"
            elif "diff-remove" in line:
                diff_class = "diff-remove"
            elif "diff-add" in line:
                diff_class = "diff-add"
            
            # Check if we have a corresponding highlighted line
            if i < len(pred_highlighted_lines):
                highlighted_line = pred_highlighted_lines[i]
                if diff_class:
                    pred_final += f'<div class="{diff_class}">{highlighted_line}</div>'
                else:
                    pred_final += f'<div>{highlighted_line}</div>'
            else:
                # If no highlighted line, use the original with diff class
                content = line.replace('<div class="line diff-change">', '').replace('<div class="line diff-remove">', '').replace('<div class="line diff-add">', '').replace('<div class="line">', '').replace('</div>', '')
                if diff_class:
                    pred_final += f'<div class="{diff_class}">{content}</div>'
                else:
                    pred_final += f'<div>{content}</div>'
        
        gold_final += '</div>'
        pred_final += '</div>'
        
        return gold_final, pred_final
    
    def format_metrics_table(self, metrics_dict):
        """Format a metrics mapping as an HTML table.

        Parameters
        ----------
        metrics_dict : dict | str | None
            Mapping of metric name to value. A string is treated as an error
            message and rendered inside an ``error-container`` div instead of
            a table. ``None`` is treated as an empty mapping (header-only table).

        Returns
        -------
        str
            HTML fragment. Metric names, values and error text are HTML-escaped
            so that e.g. ``<`` inside a SQL error message cannot break the page.
        """
        from html import escape

        if isinstance(metrics_dict, str):
            # This is an error message
            return f"<div class='error-container'>üõë {escape(metrics_dict, quote=False)}</div>"

        if metrics_dict is None:
            metrics_dict = {}

        rows = "".join(
            f"<tr><td>{escape(str(key), quote=False)}</td>"
            f"<td>{escape(str(value), quote=False)}</td></tr>"
            for key, value in metrics_dict.items()
        )
        return (
            "<table class='metrics-table'>"
            "<tr><th>Metric</th><th>Value</th></tr>"
            f"{rows}</table>"
        )
    
    def format_tables_comparison(self, gold_tables, pred_tables):
        """Format a gold-vs-predicted table-usage comparison as HTML.

        Each side may be any of the following:

        - a list
        - a dict (its keys are used)
        - the string form of a list/tuple/set/dict literal, e.g.
          ``"['object', 'probability']"``
        - a comma-separated string
        - a single table name
        - an empty value

        Names are normalized before comparing: they are lower-cased, surrounding
        whitespace, quotes and backticks are removed, ``table_``/``tbl_``/``t_``
        prefixes are dropped, and duplicates are collapsed.

        Parameters
        ----------
        gold_tables : list | dict | str | None
            Tables referenced by the gold query.
        pred_tables : list | dict | str | None
            Tables attributed to the predicted query.

        Returns
        -------
        str
            ``"<p>No table information available</p>"`` when neither side has
            table information. Otherwise, a hidden debug ``<div>`` followed by
            a ``compare-table`` with one row per table name (sorted). A row is
            marked ``match`` when the name appears on both sides.
        """
        import ast
        from html import escape

        if (not gold_tables or gold_tables == "[]" or gold_tables == "None") and \
           (not pred_tables or pred_tables == "[]" or pred_tables == "None"):
            return "<p>No table information available</p>"

        def text(value):
            """HTML-escape ``str(value)`` for use as element content."""
            return escape(str(value), quote=False)

        # Hidden debug information to help troubleshoot parsing. It is escaped
        # because e.g. "<class 'list'>" would otherwise be parsed as a tag.
        debug_info = f"""
        <div style="display: none;">
            <p>Gold tables type: {text(type(gold_tables))}</p>
            <p>Gold tables raw: {text(gold_tables)}</p>
            <p>Pred tables type: {text(type(pred_tables))}</p>
            <p>Pred tables raw: {text(pred_tables)}</p>
        </div>
        """

        def normalize_table_name(name):
            """Lower-case a table name and strip quoting and common prefixes."""
            if not name:
                return ""
            # Strip quotes before looking for a prefix, so "'table_x'" still
            # loses its prefix. Strip again afterwards for inputs like "table_'x'".
            name = str(name).lower().strip().strip('"\'`').strip()
            for prefix in ['table_', 'tbl_', 't_']:
                if name.startswith(prefix):
                    name = name[len(prefix):]
            return name.strip('"\'`')

        def parse_tables(tables):
            """Coerce one side's table specification into a list of raw entries."""
            if not tables:
                return []
            if isinstance(tables, list):
                return tables
            if isinstance(tables, dict):
                return list(tables.keys())
            if not isinstance(tables, str):
                # Any other type, convert to string
                return [str(tables)]

            tables = tables.strip()
            if tables in ("None", "[]", "{}", ""):
                return []

            looks_like_literal = (tables.startswith('[') and tables.endswith(']')) or \
                                 (tables.startswith('{') and tables.endswith('}'))
            if not looks_like_literal:
                if ',' in tables:
                    return [t.strip() for t in tables.split(',')]
                # Single table name
                return [tables]

            try:
                parsed = ast.literal_eval(tables)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a valid Python literal: fall back to splitting on commas
                inner = tables.strip('[]{}')
                if ',' in tables:
                    return [t.strip(' \'"') for t in inner.split(',')]
                return [inner]

            if isinstance(parsed, dict):
                return list(parsed.keys())
            if isinstance(parsed, (list, tuple, set, frozenset)):
                # A set/tuple literal such as "{'object', 'probability'}" holds
                # several names, not one.
                return list(parsed)
            return [parsed]

        def clean(entries):
            """Normalize entries and return the set of non-empty table names."""
            names = {normalize_table_name(entry) for entry in entries if entry}
            names.discard("")
            return names

        gold_set = clean(parse_tables(gold_tables))
        pred_set = clean(parse_tables(pred_tables))

        # Create comparison table
        html = debug_info + "<table class='compare-table'>"
        html += "<tr><th>Gold Tables</th><th>Predicted Tables</th><th>Status</th></tr>"

        all_tables = sorted(gold_set | pred_set)
        if not all_tables:
            html += "<tr><td colspan='3'>No tables identified</td></tr>"
        else:
            for table in all_tables:
                in_gold = table in gold_set
                in_pred = table in pred_set
                matched = in_gold and in_pred
                status = "match" if matched else "mismatch"
                status_text = "‚úì Match" if matched else "‚ùå Mismatch"
                safe_name = text(table)

                html += f"<tr class='{status}'>"
                html += f"<td>{'‚úì' if in_gold else '‚ùå'} {safe_name}</td>"
                html += f"<td>{'‚úì' if in_pred else '‚ùå'} {safe_name}</td>"
                html += f"<td>{status_text}</td>"
                html += "</tr>"

        html += "</table>"
        return html
    
    def format_difficulty_comparison(self, gold_difficulty, pred_difficulty):
        """Render gold vs. predicted difficulty as a one-row HTML comparison table.

        Returns an empty string when both values are falsy. Otherwise returns a
        ``compare-table`` whose data row has class ``match`` when the two values
        compare equal, and ``mismatch`` otherwise.
        """
        if not (gold_difficulty or pred_difficulty):
            return ""

        is_match = gold_difficulty == pred_difficulty
        if is_match:
            row_class, verdict = "match", "‚úì Match"
        else:
            row_class, verdict = "mismatch", "‚ùå Mismatch"

        cells = "".join(f"<td>{cell}</td>" for cell in (gold_difficulty, pred_difficulty, verdict))
        return (
            "<table class='compare-table'>"
            "<tr><th>Gold Difficulty</th><th>Predicted Difficulty</th><th>Status</th></tr>"
            f"<tr class='{row_class}'>{cells}</tr>"
            "</table>"
        )
    
    def update_display(self, change=None):
        """Render the currently selected result into ``self.output``.

        ``change`` is accepted so the method can be registered directly as a
        widget ``observe`` callback; its value is not used.

        The output area is cleared and then filled with, in order:

        1. the metadata;
        2. the user request and knowledge, when ``show_request`` is enabled and
           both DataFrames are available;
        3. the tables/difficulty comparisons, or an error notice when the
           predicted query failed;
        4. the gold and predicted SQL in the selected view mode;
        5. optionally, the metrics tables.

        Free text taken from the dataset or from query execution (request,
        knowledge, error messages, identifiers) is HTML-escaped before being
        embedded, so characters such as ``<`` do not corrupt the page.
        """
        from html import escape

        def text(value):
            """HTML-escape ``str(value)`` for use as element content."""
            return escape(str(value), quote=False)

        with self.output:
            clear_output()

            result = self.get_current_result()
            if not result:
                print("No result found for the selected criteria.")
                return

            # Display options chosen in the widget controls
            view_mode = self.view_mode.value
            highlight_diffs = self.highlight_diffs.value
            format_sql = self.format_sql.value
            show_metrics = self.show_metrics.value
            show_request = self.show_request.value

            # Extract data from the result (`comparison` may be absent or None)
            comparison = result.get('comparison') or {}
            gold_query = comparison.get('sql_gold', '')
            pred_query = comparison.get('sql_pred', '')
            exec_time_gold = comparison.get('execution_time_gold')
            exec_time_pred = comparison.get('execution_time_pred')
            error_gold = comparison.get('error_gold')
            error_pred = comparison.get('error_pred')
            oids = comparison.get('oids', {})
            columns = comparison.get('columns', {})
            columns_formatted = comparison.get('columns_formatted', {})

            # Metadata
            difficulty = result.get('difficulty', '')
            gold_tables = result.get('gold_tables', [])
            task_type = result.get('type', '')

            # Predicted tables and difficulty are only looked up when the
            # predicted query executed without error. The lookup is limited to
            # the current result: other models' results sharing this
            # req_id/n_exp are different predictions.
            pred_tables = []
            pred_difficulty = ''
            if not error_pred:
                pred_tables = self._lookup_pred_tables(result)
                pred_difficulty = comparison.get('pred_diff', result.get('pred_difficulty', ''))

            # Get request and knowledge info if available
            request_info = None
            if show_request and self.df_train is not None and self.df_test is not None:
                request_info = self.get_request_knowledge(result['req_id'])

            # Display metadata
            metadata_html = f"""
            <div class='metadata-container'>
                <h3>Metadata</h3>
                <p><strong>Request ID:</strong> {text(result['req_id'])} | <strong>Experiment:</strong> {text(result['n_exp'])}</p>
                <p><strong>Difficulty:</strong> {text(difficulty)} | <strong>Type:</strong> {text(task_type)}</p>
            </div>
            """
            display(HTML(metadata_html))

            # Display request and knowledge if available
            if request_info:
                request_html = f"""
                <div class='request-container'>
                    <h3>User Request</h3>
                    <p>{text(request_info['request'])}</p>
                </div>
                """
                display(HTML(request_html))

                external_knowledge = request_info['external_knowledge']
                domain_knowledge = request_info['domain_knowledge']
                if external_knowledge or domain_knowledge:
                    knowledge_html = f"""
                    <div class='knowledge-container'>
                        <h3>Knowledge</h3>
                        <p><strong>External Knowledge:</strong></p>
                        <p>{text(external_knowledge or 'None')}</p>
                        <p><strong>Domain Knowledge:</strong></p>
                        <p>{text(domain_knowledge or 'None')}</p>
                    </div>
                    """
                    display(HTML(knowledge_html))

            # Tables/difficulty comparison only when the predicted query ran
            if not error_pred:
                tables_html = f"""
                <div class='metadata-container'>
                    <h3>Tables Comparison</h3>
                    {self.format_tables_comparison(gold_tables, pred_tables)}
                </div>
                """
                display(HTML(tables_html))

                if pred_difficulty:
                    difficulty_html = f"""
                    <div class='metadata-container'>
                        <h3>Difficulty Comparison</h3>
                        {self.format_difficulty_comparison(difficulty, pred_difficulty)}
                    </div>
                    """
                    display(HTML(difficulty_html))
            else:
                error_message_html = f"""
                <div class='metadata-container'>
                    <h3>Tables Comparison</h3>
                    <div class='error-container'>
                        <p>Table comparison unavailable due to execution error in predicted query:</p>
                        <p>{text(error_pred)}</p>
                    </div>
                </div>
                """
                display(HTML(error_message_html))

            # Execution status blocks shown under each query
            gold_exec_html = self._execution_info_html(exec_time_gold, error_gold)
            pred_exec_html = self._execution_info_html(exec_time_pred, error_pred)

            # Display SQL queries
            if view_mode == 'Side by Side':
                gold_html, pred_html = self.generate_highlighted_diffs(gold_query, pred_query, highlight_diffs, format_sql)

                queries_html = f"""
                <div class='sql-container'>
                    <div class='sql-box'>
                        <div class='sql-title'>Gold Query</div>
                        <div class='sql-query'>{gold_html}</div>
                        {gold_exec_html}
                    </div>
                    <div class='sql-box'>
                        <div class='sql-title'>Predicted Query</div>
                        <div class='sql-query'>{pred_html}</div>
                        {pred_exec_html}
                    </div>
                </div>
                """
                display(HTML(queries_html))
            else:  # Diff View
                diff_html = self.generate_diff_html(gold_query, pred_query, format_sql)

                queries_html = f"""
                <div class='sql-box'>
                    <div class='sql-title'>Diff View (- Gold, + Predicted)</div>
                    <div class='sql-query'>{diff_html}</div>
                </div>
                <div class='sql-container'>
                    <div class='sql-box'>
                        <div class='sql-title'>Gold Query Execution</div>
                        {gold_exec_html}
                    </div>
                    <div class='sql-box'>
                        <div class='sql-title'>Predicted Query Execution</div>
                        {pred_exec_html}
                    </div>
                </div>
                """
                display(HTML(queries_html))

            # Display metrics
            if show_metrics:
                metrics_html = f"""
                <div class='metrics-container'>
                    <div class='metrics-box'>
                        <div class='metrics-title'>OIDs Metrics</div>
                        {self.format_metrics_table(oids)}
                    </div>
                    <div class='metrics-box'>
                        <div class='metrics-title'>Columns Metrics</div>
                        {self.format_metrics_table(columns)}
                    </div>
                    <div class='metrics-box'>
                        <div class='metrics-title'>Columns Formatted Metrics</div>
                        {self.format_metrics_table(columns_formatted)}
                    </div>
                </div>
                """
                display(HTML(metrics_html))

    @staticmethod
    def _lookup_pred_tables(result):
        """Return the first truthy predicted-tables value stored on ``result``, else ``[]``.

        Candidate locations are checked in a fixed order. Only ``result`` itself
        is consulted.
        """
        candidate_paths = (
            ('comparison', 'pred_tables'),
            ('comparison', 'tables'),
            ('pred_tables',),
            ('tables',),
            ('table_info',),
            ('table_schema',),
        )
        for path in candidate_paths:
            node = result
            for key in path:
                if not (isinstance(node, dict) and key in node):
                    break
                node = node[key]
            else:
                # Every key on the path was present
                if node:
                    return node
        return []

    @staticmethod
    def _execution_info_html(exec_time, error):
        """Return the success/error status divs shown under a SQL query.

        The error is converted with ``str`` and HTML-escaped, so non-string
        error objects and messages containing ``<`` render safely.
        """
        from html import escape

        parts = []
        if exec_time:
            parts.append(f'<div class="success-container">‚úÖ Execution time: {exec_time} seconds</div>')
        if error:
            parts.append(f'<div class="error-container">üõë Error: {escape(str(error), quote=False)}</div>')
        return '\n'.join(parts)
    
    def update_summary_stats(self, filtered_results=None):
        """Render aggregate statistics for a set of results into ``self.summary_output``.

        Parameters
        ----------
        filtered_results : list of dict, optional
            Results to summarize; when None, ``self.all_results`` is used. An
            empty list is honoured and produces the "no results" message.

        Notes
        -----
        ``comparison['oids']`` and ``comparison['columns']`` may hold an error
        string instead of a metrics dict. Such entries are skipped, rather than
        raising, for both the perfect-match counts and the F1 averages.

        A result counts as an execution error when ``comparison['error_pred']``
        is truthy. This matches how ``update_display`` decides whether the
        predicted query failed.
        """
        from html import escape

        def metrics_of(result, name):
            """Return ``comparison[name]`` when it is a metrics dict, else ``{}``."""
            value = (result.get('comparison') or {}).get(name)
            return value if isinstance(value, dict) else {}

        def average_f1(metric_dicts):
            """Mean of the non-None ``f1_score`` entries, or 0 when there are none."""
            scores = [m['f1_score'] for m in metric_dicts if m.get('f1_score') is not None]
            return sum(scores) / len(scores) if scores else 0

        with self.summary_output:
            clear_output()

            # Use filtered results if provided, otherwise use all results
            results = filtered_results if filtered_results is not None else self.all_results

            total_results = len(results)
            if total_results == 0:
                display(HTML("<div class='summary-container'><h3>No results match the current filters</h3></div>"))
                return

            oid_metrics = [metrics_of(r, 'oids') for r in results]
            col_metrics = [metrics_of(r, 'columns') for r in results]

            perfect_matches_oids = sum(1 for m in oid_metrics if m.get('perfect_match', 0) == 1)
            perfect_matches_cols = sum(1 for m in col_metrics if m.get('perfect_match', 0) == 1)

            avg_f1_oids = average_f1(oid_metrics)
            avg_f1_cols = average_f1(col_metrics)

            # Count by difficulty (in order of first occurrence)
            difficulties = {}
            for r in results:
                diff = r.get('difficulty', 'unknown')
                difficulties[diff] = difficulties.get(diff, 0) + 1

            # Count execution errors (same truthiness rule as update_display)
            execution_errors = sum(1 for r in results if (r.get('comparison') or {}).get('error_pred'))

            by_difficulty = ', '.join(f"{escape(str(k), quote=False)}: {v}" for k, v in difficulties.items())

            # Create summary HTML
            summary_html = f"""
            <div class='summary-container'>
                <h3>Summary Statistics</h3>
                <p><strong>Total Results:</strong> {total_results}</p>
                <p><strong>Perfect Matches (OIDs):</strong> {perfect_matches_oids} ({perfect_matches_oids/total_results*100:.1f}%)</p>
                <p><strong>Perfect Matches (Columns):</strong> {perfect_matches_cols} ({perfect_matches_cols/total_results*100:.1f}%)</p>
                <p><strong>Average F1 Score (OIDs):</strong> {avg_f1_oids:.3f}</p>
                <p><strong>Average F1 Score (Columns):</strong> {avg_f1_cols:.3f}</p>
                <p><strong>Execution Errors:</strong> {execution_errors} ({execution_errors/total_results*100:.1f}%)</p>
                <p><strong>By Difficulty:</strong> {by_difficulty}</p>
            </div>
            """
            display(HTML(summary_html))


def visualize_sql_comparisons(results_dict, df_train=None, df_test=None):
    """
    Build the interactive widget used to browse SQL query comparisons.

    Parameters
    -----------
    results_dict : dict
        Maps 'model_name-experiment_name' keys to lists of result dictionaries.
    df_train : pandas.DataFrame, optional
        Training data holding the user requests and knowledge.
    df_test : pandas.DataFrame, optional
        Test data holding the user requests and knowledge.

    Returns
    --------
    SQLComparisonWidget
        The constructed widget. This function does not call ``display``
        itself.
    """
    return SQLComparisonWidget(results_dict, df_train, df_test)

# SQL Query Comparison Widget for Text-to-SQL Tasks

This notebook provides an interactive visualization for comparing SQL query predictions made by language models in Text-to-SQL tasks. The widget allows for:

1. Selecting specific models, experiments, and queries by their attributes
2. Viewing gold and predicted SQL queries side by side or in diff view
3. Analyzing execution times and errors
4. Exploring evaluation metrics (OIDs, columns, columns_formatted)
5. Viewing the original user request and associated knowledge
6. Comparing table usage and difficulty predictions
7. Filtering and searching through the results

## Features

- **Selection Controls**: 
  - Select models, experiments, and difficulty levels
  - Choose specific request IDs and experiment numbers
- **Display Options**: 
  - Side-by-side or diff view 
  - Toggle highlighting of differences on/off
  - Format SQL for better readability
- **Execution Info**: Shows execution times and any errors that occurred
- **Metadata Display**: Shows difficulty, task type, and gold tables used
- **Request Context**: Displays the original user request and associated knowledge
- **Table Comparison**: Compares gold tables with predicted tables
- **Metrics Visualization**: Displays evaluation metrics in an organized tabular format
- **Summary Statistics**: Provides aggregate statistics about all results

## Usage Example

To use the widget with your data:

```python
# Create a dictionary where keys are "model_name-experiment_name" and values are lists of results
results_dict = {
    "model1-experiment1": results_list_1,
    "model1-experiment2": results_list_2,
    "model2-experiment1": results_list_3
}

# Create the widget
widget = visualize_sql_comparisons(
    results_dict=results_dict,
    df_train=your_train_dataframe,  # Optional: for request & knowledge display
    df_test=your_test_dataframe     # Optional: for request & knowledge display
)
```

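Each list in `results_dict` holds the per-request result dictionaries for that model-experiment combination. Every result needs a `req_id` and an `n_exp`. Its `comparison` entry carries:

- the gold and predicted SQL (`sql_gold`, `sql_pred`);
- the execution times and errors;
- the `oids`, `columns`, and `columns_formatted` metrics.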

In [None]:
# Build the "model-experiment" -> detailed-results mapping for the test split
# and open the comparison widget on it.
from utils.eval_utils import metrics_aggregation

evaluation_results_ = evaluation_results_test
models = list(evaluation_results_.keys())

# Which result variant to inspect; self-corrected results are expected
self_corr = True
std_dev = True
self_corr_key = 'self_corrected' if self_corr else 'corrected'

# Initialized here but not filled in by this cell
experiment_labels = {}
difficulties = []

# "model-experiment" label -> list of per-request result dicts
results_dict = {}

for model, experiments in evaluation_results_.items():
    for experiment, exp_results in experiments.items():
        try:
            exp_label = f"{model}-{experiment}"
            # Per-experiment placeholders; not populated in this loop
            oid_match_rates, column_match_rates = {}, {}
            if std_dev:
                oid_match_rates_std, column_match_rates_std = {}, {}

            results = exp_results[self_corr_key]['detailed_results']  # list of dictionaries
            results_dict[exp_label] = results

            aggregate_metrics = metrics_aggregation(results=results)
        except KeyError as e:
            print(f"Warning: Missing data for {model} - {experiment}: {e}")

# Create the visualization widget with the results dictionary
widget = visualize_sql_comparisons(
    results_dict=results_dict,
    df_train=db_train,
    df_test=db_test,
)



VBox(children=(Output(), HBox(children=(HBox(children=(VBox(children=(Dropdown(description='Model:', layout=La‚Ä¶