In [None]:
import psycopg2
from psycopg2.extras import Json

# ====== Quantum Comparison Helper Functions ======

import math

def _counts_to_probability(measure_counts):
    """
    Normalize raw measurement counts into a probability distribution.

    Every count is divided by the sum of all counts, e.g.
    {'0010': 5, '1110': 10} -> {'0010': 0.3333..., '1110': 0.6667...}.
    An empty dict is returned when the counts sum to zero, which also
    covers an empty input.
    """
    shots = sum(measure_counts.values())
    if shots == 0:
        return {}
    probabilities = {}
    for state, count in measure_counts.items():
        probabilities[state] = count / shots
    return probabilities

def _kl_div(p, q):
    """
    Kullback–Leibler divergence KL(p || q), measured in bits (log base 2).

    Accumulates p(x) * log2(p(x) / q(x)) over the keys of p. A term is
    skipped when p(x) is not positive, or when q(x) is missing from q or
    not positive.
    """
    total = 0.0
    for outcome in p:
        p_x = p[outcome]
        if not p_x > 0:
            continue
        q_x = q.get(outcome, 0.0)
        if not q_x > 0:
            continue
        total += p_x * math.log(p_x / q_x, 2)
    return total

def _js_distance(p, q):
    """
    Jensen–Shannon distance between two probability dicts.

    The distance is sqrt(JS divergence), where
    JS divergence = 0.5*KL(p||m) + 0.5*KL(q||m) and m = 0.5*(p+q) is the
    mixture over the union of keys (a key missing from one side counts as
    probability 0.0 there). KL is computed by _kl_div, which uses log base 2.

    Args:
        p: Mapping of outcome -> probability for the first distribution.
        q: Mapping of outcome -> probability for the second distribution.

    Returns:
        The distance as a float; never negative.
    """
    mixture = {
        key: 0.5 * (p.get(key, 0.0) + q.get(key, 0.0))
        for key in set(p) | set(q)
    }
    js_div = 0.5 * _kl_div(p, mixture) + 0.5 * _kl_div(q, mixture)
    # For nearly identical distributions, floating-point rounding can leave
    # js_div a hair below zero, which would make math.sqrt raise ValueError.
    return math.sqrt(max(js_div, 0.0))

def _euclidean_distance(v1, v2):
    """
    L2 distance between two numeric sequences.

    Only the overlapping prefix is compared: when the lengths differ, the
    trailing elements of the longer sequence are ignored.
    """
    shared = range(min(len(v1), len(v2)))
    squared_total = 0.0
    for delta in (v1[k] - v2[k] for k in shared):
        squared_total += delta * delta
    return math.sqrt(squared_total)

# ----- Deeper Helper Subroutines -----

def _compare_bitstring_probabilities(p1, p2, top_n=5):
    """
    Return the top-N bitstrings with the largest absolute probability gap.

    Args:
        p1: Mapping of bitstring -> probability for the first recording.
        p2: Mapping of bitstring -> probability for the second recording.
        top_n: Maximum number of entries to return.

    Returns:
        A list of (bitstring, pval1, pval2, abs_diff) tuples sorted by
        abs_diff descending. A bitstring missing from one mapping is treated
        as probability 0.0 there. Entries with equal abs_diff are ordered by
        the bitstring's string form.
    """
    diffs = []
    for bs in set(p1) | set(p2):
        val1 = p1.get(bs, 0.0)
        val2 = p2.get(bs, 0.0)
        diffs.append((bs, val1, val2, abs(val1 - val2)))
    # Set iteration order for string keys can change between interpreter
    # runs; the secondary sort key keeps tied entries (and therefore the
    # top-N cut) reproducible.
    diffs.sort(key=lambda item: (-item[3], str(item[0])))
    return diffs[:top_n]

def _compare_scaled_angles(angles1, angles2, top_n=5):
    """
    Rank angle positions by how far apart the two recordings are.

    Only indices present in both sequences are compared. Returns at most
    top_n tuples of (angle_index, a1, a2, abs_diff), largest abs_diff first;
    positions with equal differences stay in ascending index order.
    """
    shared = min(len(angles1), len(angles2))
    ranked = sorted(
        (
            (idx, angles1[idx], angles2[idx], abs(angles1[idx] - angles2[idx]))
            for idx in range(shared)
        ),
        key=lambda entry: entry[3],
        reverse=True,
    )
    return ranked[:top_n]

def _compare_advanced_stats(adv_stats_1, adv_stats_2):
    """
    Compare two advanced_stats dictionaries (e.g. "avg_jitter", "std_jitter").

    Args:
        adv_stats_1: Mapping of stat name -> value for the first recording.
        adv_stats_2: Mapping of stat name -> value for the second recording.

    Returns:
        A list of (stat_name, val1, val2, abs_diff) for every stat whose
        values differ, sorted by abs_diff descending; entries with equal
        abs_diff are ordered by stat name. A stat missing on one side is
        compared against 0.0. Stats whose two values cannot be subtracted
        (for example None or a string) are skipped.
    """
    diffs = []
    for stat in set(adv_stats_1) | set(adv_stats_2):
        val1 = adv_stats_1.get(stat, 0.0)
        val2 = adv_stats_2.get(stat, 0.0)
        try:
            diff = abs(val1 - val2)
        except TypeError:
            # No numeric difference exists for this stat; leave it out
            # rather than aborting the whole comparison.
            continue
        if diff > 0:
            diffs.append((stat, val1, val2, diff))
    # Set iteration order for string keys can change between interpreter
    # runs; the secondary sort key keeps tied entries in a stable order.
    diffs.sort(key=lambda item: (-item[3], str(item[0])))
    return diffs


# --- Database Retrieval ---

def get_analysis_from_db(record_id):
    """
    Fetch the analysis record for record_id and return its analysis_data.

    Args:
        record_id: Identifier passed to QuantumMusicDB.fetch_analysis.

    Returns:
        The fourth column of the fetched row. The expected row format is
        (id, file_name, sample_rate, analysis_data), with analysis_data a
        dictionary stored as JSONB.

    Raises:
        ValueError: If no row is returned for record_id.
    """
    # QuantumMusicDB is not defined in this section; it must already be
    # defined or imported wherever this function runs.
    db = QuantumMusicDB()
    try:
        row = db.fetch_analysis(record_id)
    finally:
        # Release the connection even when the fetch raises.
        db.close()

    if not row:
        raise ValueError(f"Record with ID {record_id} not found.")
    return row[3]


# --- Comparison Module ---

def _fewer_issues_sentence(count1, count2, claim, unit, tie_noun):
    """Build the sentence stating which recording has the lower issue count."""
    if count1 == count2:
        return f"Both recordings have a similar number of {tie_noun} ({count1})."
    if count1 < count2:
        better, fewer, worse, more = 1, count1, 2, count2
    else:
        better, fewer, worse, more = 2, count2, 1, count1
    return (f"Recording {better} {claim} "
            f"({fewer} {unit}) compared to Recording {worse} ({more}).")


def compare_recordings(analysis1, analysis2):
    """
    Compare two analysis dictionaries on several metrics and build a report.

    The comparison covers:
      - Raga detection: whether "raga_info"/"best_raga" match.
      - Aggregated feedback: which recording has fewer "aggregate_feedback"
        items.
      - High-level feedback: which recording has fewer "high_level_feedback"
        items.
      - Pitch accuracy: which "average_pitch_accuracy" is higher, when both
        are present.

    A key that is missing, or present with a None/empty value, is treated as
    "no data" for that metric.

    Args:
        analysis1: Analysis dictionary for Recording 1.
        analysis2: Analysis dictionary for Recording 2.

    Returns:
        A dict with the keys "raga_comparison",
        "aggregated_feedback_comparison", "high_level_feedback_comparison",
        "pitch_comparison" (one sentence each) and "detailed_report" (all
        four combined into a multi-line string).
    """
    comparison = {}

    # 1. Raga detection.
    # `or` fallbacks are used throughout: .get's default only applies to a
    # missing key, not to a key stored with a None value.
    raga_info1 = analysis1.get("raga_info") or {}
    raga_info2 = analysis2.get("raga_info") or {}
    if raga_info1 and raga_info2:
        best1 = raga_info1.get("best_raga", "Unknown")
        best2 = raga_info2.get("best_raga", "Unknown")
        if best1 == best2:
            raga_comp = f"Both recordings were detected as '{best1}'."
        else:
            raga_comp = (f"Recording 1 was detected as '{best1}', "
                         f"while Recording 2 was detected as '{best2}'.")
    else:
        raga_comp = "Insufficient raga detection data for one or both recordings."
    comparison["raga_comparison"] = raga_comp

    # 2. Aggregated feedback counts.
    comparison["aggregated_feedback_comparison"] = _fewer_issues_sentence(
        len(analysis1.get("aggregate_feedback") or []),
        len(analysis2.get("aggregate_feedback") or []),
        claim="appears to have fewer performance issues",
        unit="aggregated feedback messages",
        tie_noun="aggregated feedback items",
    )

    # 3. High-level (actionable) feedback counts.
    comparison["high_level_feedback_comparison"] = _fewer_issues_sentence(
        len(analysis1.get("high_level_feedback") or []),
        len(analysis2.get("high_level_feedback") or []),
        claim="has fewer sustained performance issues",
        unit="high-level issues",
        tie_noun="high-level issues",
    )

    # 4. Pitch accuracy, only when both recordings provide it.
    avg_pitch1 = analysis1.get("average_pitch_accuracy")
    avg_pitch2 = analysis2.get("average_pitch_accuracy")
    if avg_pitch1 is None or avg_pitch2 is None:
        pitch_comp = "Pitch accuracy data is missing for one or both recordings."
    elif avg_pitch1 > avg_pitch2:
        pitch_comp = (f"Recording 1 has a higher average pitch accuracy ({avg_pitch1:.2f}) "
                      f"than Recording 2 ({avg_pitch2:.2f}).")
    elif avg_pitch1 < avg_pitch2:
        pitch_comp = (f"Recording 2 has a higher average pitch accuracy ({avg_pitch2:.2f}) "
                      f"than Recording 1 ({avg_pitch1:.2f}).")
    else:
        pitch_comp = f"Both recordings have similar average pitch accuracy ({avg_pitch1:.2f})."
    comparison["pitch_comparison"] = pitch_comp

    # Combine the individual comparisons into a detailed report.
    comparison["detailed_report"] = (
        "=== Detailed Comparison Report ===\n\n"
        f"Raga Detection:\n{comparison['raga_comparison']}\n\n"
        f"Aggregated Feedback:\n{comparison['aggregated_feedback_comparison']}\n\n"
        f"High-Level Feedback:\n{comparison['high_level_feedback_comparison']}\n\n"
        f"Pitch Accuracy:\n{comparison['pitch_comparison']}\n"
    )
    return comparison


def compare_quantum_recordings(analysis1, analysis2, top_n=5):
    """
    Compare the "quantum_analysis" sections of two analysis dictionaries.

    Steps:
      1) Convert measurement_counts -> probabilities and compute the
         Jensen–Shannon distance between them.
      2) Compute the Euclidean distance between the scaled_angles lists.
      3) Collect the top-N bitstrings whose probabilities differ most.
      4) Collect the top-N angle positions that differ most.
      5) Collect advanced_stats differences, if any.

    Args:
        analysis1: Analysis dictionary for the first recording.
        analysis2: Analysis dictionary for the second recording.
        top_n: Maximum number of bitstring and angle differences to report.

    Returns:
        A dict with the keys "quantum_js_distance", "angle_distance",
        "bitstring_differences", "angle_differences",
        "advanced_stat_differences" and "quantum_summary" (a multi-line
        text). When either side has no quantum_analysis data, both distances
        are None, the lists are empty and the summary says so.
    """
    # `or` fallbacks are used throughout: .get's default only applies to a
    # missing key, not to a key stored with a None value.
    qa1 = analysis1.get("quantum_analysis") or {}
    qa2 = analysis2.get("quantum_analysis") or {}

    # If either side is missing quantum data, bail out.
    if not qa1 or not qa2:
        return {
            "quantum_summary": "One or both analyses have no quantum_analysis data.",
            "quantum_js_distance": None,
            "angle_distance": None,
            "bitstring_differences": [],
            "angle_differences": [],
            "advanced_stat_differences": []
        }

    # 1) Probability distributions and their JS distance.
    p1 = _counts_to_probability(qa1.get("measurement_counts") or {})
    p2 = _counts_to_probability(qa2.get("measurement_counts") or {})
    js_dist = _js_distance(p1, p2)

    # 2) Scaled angles.
    angles1 = qa1.get("scaled_angles") or []
    angles2 = qa2.get("scaled_angles") or []
    angle_dist = _euclidean_distance(angles1, angles2)

    # 3) and 4) Largest individual differences.
    bitstring_diffs = _compare_bitstring_probabilities(p1, p2, top_n=top_n)
    angle_diffs = _compare_scaled_angles(angles1, angles2, top_n=top_n)

    # 5) Advanced stats, if present.
    adv_stat_diffs = _compare_advanced_stats(
        qa1.get("advanced_stats") or {},
        qa2.get("advanced_stats") or {},
    )

    # Build the textual summary as a list of pieces joined once at the end.
    parts = [
        "=== Quantum Comparison ===\n",
        f"Jensen–Shannon distance: {js_dist:.4f}\n",
        f"Euclidean distance of scaled angles: {angle_dist:.4f}\n",
        "Lower is more similar.\n\n",
        "Top Bitstring Probability Differences:\n",
    ]
    parts.extend(
        f"  {rank}) '{bs}': p1={pval1:.4f}, p2={pval2:.4f}, abs diff={diff:.4f}\n"
        for rank, (bs, pval1, pval2, diff) in enumerate(bitstring_diffs, 1)
    )

    parts.append("\nLargest Angle Differences:\n")
    parts.extend(
        f"  Angle {idx}: rec1={a1:.4f}, rec2={a2:.4f}, diff={d:.4f}\n"
        for idx, a1, a2, d in angle_diffs
    )

    if adv_stat_diffs:
        parts.append("\nAdvanced Stats Differences:\n")
        parts.extend(
            f"  {stat_name}: rec1={val1:.4f}, rec2={val2:.4f}, diff={diff:.4f}\n"
            for stat_name, val1, val2, diff in adv_stat_diffs
        )
    else:
        parts.append("\nNo advanced stats differences found or advanced_stats missing.\n")

    return {
        "quantum_js_distance": js_dist,
        "angle_distance": angle_dist,
        "bitstring_differences": bitstring_diffs,
        "angle_differences": angle_diffs,
        "advanced_stat_differences": adv_stat_diffs,
        "quantum_summary": "".join(parts)
    }


def compare_analysis_results(record_id1, record_id2):
    """
    Load the stored analyses for two recording IDs and compare them.

    The result is the dictionary produced by compare_recordings, extended
    with a "quantum_comparison" entry holding the output of
    compare_quantum_recordings. When the quantum summary is non-empty, it is
    also appended to the "detailed_report" text.
    """
    first = get_analysis_from_db(record_id1)
    second = get_analysis_from_db(record_id2)

    report = compare_recordings(first, second)

    # Run the quantum comparison as well and attach its full result.
    quantum = compare_quantum_recordings(first, second)
    report["quantum_comparison"] = quantum

    # Fold the quantum summary into the combined human-readable report.
    quantum_text = quantum["quantum_summary"]
    if quantum_text:
        report["detailed_report"] += f"\n\nQuantum Analysis:\n{quantum_text}\n"

    return report


# --- Example Usage ---

# Manually set the recording IDs here (replace these with your actual record IDs).
# record_id1 = 81  # e.g., 1
# record_id2 = 82  # e.g., 2

# Fetch the analysis from the database and compare the recordings.
# comparison_result = compare_analysis_results(record_id1, record_id2)
# print(comparison_result["detailed_report"])