#!/usr/bin/env python3
# complete_compare.py
# Exported from a notebook cell (the original "In [None]:" prompt marker is not Python).

import psycopg2
from psycopg2.extras import Json
import math
import numpy as np
import base64
import io
import matplotlib.pyplot as plt

###############################################################################
# 1) Minimal Database Class
###############################################################################
class QuantumMusicDB:
    """
    Minimal PostgreSQL accessor that fetches single rows from the
    'audio_analysis' table.

    The constructor stores the connection settings and immediately calls
    ``connect()``. A failed connection is reported on stdout instead of being
    raised, in which case ``self.conn`` stays ``None`` and ``fetch_analysis``
    raises ``RuntimeError``.

    Instances may be used as context managers; the connection is closed on
    exit from the ``with`` block.
    """

    def __init__(self,
                 db_name="quantummusic",
                 host="localhost",
                 user="postgres",
                 password="postgres"):
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.conn = None
        self.connect()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the with-block.
        return False

    def connect(self):
        """
        Open the connection using the stored settings.

        Best-effort: any failure is printed and ``self.conn`` is left as None
        rather than propagating the exception.
        """
        try:
            self.conn = psycopg2.connect(
                dbname=self.db_name,
                host=self.host,
                user=self.user,
                password=self.password
            )
            print(f"Connected to database {self.db_name}.")
        except Exception as e:
            # Make the "not connected" state explicit for later checks.
            self.conn = None
            print(f"Error connecting to database: {e}")

    def close(self):
        """
        Close the connection if one is open; calling it again is a no-op.
        """
        if self.conn:
            self.conn.close()
            # Drop the handle so a later fetch reports "not connected"
            # instead of operating on a closed connection object.
            self.conn = None
            print("Database connection closed.")

    def fetch_analysis(self, record_id):
        """
        Expects to find a row (id, file_name, sample_rate, analysis_data)
        in 'audio_analysis'. Returns that row or None if not found.

        Raises:
            RuntimeError: if there is no open connection (connect() failed
                or close() was already called).
        """
        if self.conn is None:
            raise RuntimeError(
                f"Not connected to database {self.db_name}; cannot fetch record {record_id}."
            )

        query = """
            SELECT id, file_name, sample_rate, analysis_data
            FROM audio_analysis
            WHERE id = %s
        """
        with self.conn.cursor() as cur:
            # record_id is passed as a bound parameter, never interpolated.
            cur.execute(query, (record_id,))
            return cur.fetchone()


###############################################################################
# 2) HELPER FUNCTIONS FOR QUANTUM COMPARISON
###############################################################################
def _counts_to_probability(measure_counts):
    """
    Convert a mapping of state -> count into state -> relative frequency.

    Returns an empty dict when the counts sum to zero (including when the
    mapping itself is empty), so callers never divide by zero.
    """
    shots = sum(measure_counts.values())
    if shots == 0:
        return {}

    probabilities = {}
    for state, count in measure_counts.items():
        probabilities[state] = count / shots
    return probabilities

def _kl_div(p, q):
    """
    Sum p[k] * log2(p[k] / q[k]) over the keys of ``p``.

    Terms are skipped when p[k] is not positive, or when q[k] is missing or
    not positive, so the result is always finite.
    """
    total = 0.0
    for outcome in p:
        p_mass = p[outcome]
        if not p_mass > 0:
            continue
        q_mass = q.get(outcome, 0.0)
        if not q_mass > 0:
            # No support in q for this outcome: the term is dropped.
            continue
        total += p_mass * math.log2(p_mass / q_mass)
    return total

def _js_distance(p, q):
    """
    Jensen-Shannon distance between two probability mappings.

    Builds the mixture m = (p + q) / 2 over the union of keys, averages the
    two KL divergences against m, and returns the square root of that
    divergence. Keys missing from one mapping count as probability 0.

    The divergence is clamped at 0 before the square root: for nearly
    identical inputs, floating-point rounding can leave it a hair below
    zero, which would make math.sqrt raise ValueError.
    """
    mixture = {
        key: 0.5 * (p.get(key, 0.0) + q.get(key, 0.0))
        for key in set(p) | set(q)
    }

    js_div = 0.5 * _kl_div(p, mixture) + 0.5 * _kl_div(q, mixture)
    return math.sqrt(max(js_div, 0.0))

def _euclidean_distance(v1, v2):
    """
    Euclidean distance between two sequences of numbers.

    Only the overlapping prefix is compared; surplus elements of the longer
    sequence are ignored.
    """
    squared_sum = 0.0
    for left, right in zip(v1, v2):
        delta = left - right
        squared_sum += delta * delta
    return math.sqrt(squared_sum)

def _compare_bitstring_probabilities(p1, p2, top_n=5):
    """
    Rank bitstrings by how much their probability differs between p1 and p2.

    Args:
        p1, p2: mappings of bitstring -> probability; a bitstring missing
            from one mapping counts as probability 0.0 there.
        top_n: maximum number of rows to return.

    Returns:
        A list of (bitstring, p1_value, p2_value, abs_difference) tuples,
        largest difference first. Ties are broken by the bitstring's text so
        the result does not depend on set iteration order, which varies
        between runs for string keys.
    """
    diffs = []
    for bs in set(p1) | set(p2):
        val1 = p1.get(bs, 0.0)
        val2 = p2.get(bs, 0.0)
        diffs.append((bs, val1, val2, abs(val1 - val2)))

    # Descending by difference, then ascending by bitstring for stable output.
    diffs.sort(key=lambda row: (-row[3], str(row[0])))
    return diffs[:top_n]

def _compare_scaled_angles(angles1, angles2, top_n=5):
    """
    Rank angle positions by absolute difference between the two sequences.

    Only positions present in both sequences are compared. Returns up to
    ``top_n`` tuples (index, angle1, angle2, abs_difference), largest
    difference first; equal differences keep ascending index order.
    """
    ranked = sorted(
        (
            (position, first, second, abs(first - second))
            for position, (first, second) in enumerate(zip(angles1, angles2))
        ),
        key=lambda entry: entry[3],
        reverse=True,
    )
    return ranked[:top_n]

def _compare_advanced_stats(adv_stats_1, adv_stats_2):
    """
    List the scalar statistics whose values differ between two stat mappings.

    A statistic missing from one mapping is treated as 0.0 there. Entries
    whose value on either side is not an int/float (for example list-valued
    time series stored alongside the scalar stats) are skipped, because they
    cannot be subtracted.

    Returns:
        A list of (stat_name, value1, value2, abs_difference) tuples for
        every statistic with a non-zero difference, largest difference
        first; ties are ordered by stat name so the output is deterministic.
    """
    diffs = []
    for stat in set(adv_stats_1) | set(adv_stats_2):
        val1 = adv_stats_1.get(stat, 0.0)
        val2 = adv_stats_2.get(stat, 0.0)
        if not (isinstance(val1, (int, float)) and isinstance(val2, (int, float))):
            # Lists, None, strings, ...: not comparable as a single number.
            continue
        diff = abs(val1 - val2)
        if diff > 0:
            diffs.append((stat, val1, val2, diff))

    diffs.sort(key=lambda row: (-row[3], str(row[0])))
    return diffs


###############################################################################
# 3) FEEDBACK CONSOLIDATION & ACTIONABLE SUMMARY
###############################################################################
# Ordering used to compare severity labels. Plain string comparison would rank
# them alphabetically ("medium" > "low" > "high"), which is not a severity order.
# NOTE(review): only the default label "medium" is visible in this file; the
# other labels are an assumed vocabulary - confirm against the producer of
# 'raw_feedback_items'.
_SEVERITY_RANKS = {"low": 0, "medium": 1, "high": 2, "critical": 3}


def _severity_rank(severity):
    """
    Map a severity value to a sortable number.

    Numeric severities are used as-is. Labels are looked up
    case-insensitively; unknown labels rank like "medium".
    """
    if isinstance(severity, (int, float)):
        return severity
    label = str(severity).strip().lower()
    return _SEVERITY_RANKS.get(label, _SEVERITY_RANKS["medium"])


def consolidate_feedback_items(feedback_list, max_time_gap=1.0):
    """
    Merge feedback items that are close together in time.

    Items are sorted by 'start_time'. An item joins the current group when it
    starts no more than ``max_time_gap`` seconds after the latest 'end_time'
    seen in that group (so an item lying inside a long earlier item is never
    split off); otherwise it starts a new group.

    Args:
        feedback_list: list of dicts with 'start_time', 'end_time' and
            'issue' keys, and optionally 'severity'.
        max_time_gap: largest gap, in the same unit as the timestamps, that
            still merges two items.

    Returns:
        A list of merged dicts (see _merge_feedback_group) in chronological
        order; an empty list for empty input.

    Raises:
        KeyError: if an item lacks 'start_time', 'end_time' or 'issue'.
    """
    if not feedback_list:
        return []

    ordered = sorted(feedback_list, key=lambda f: f["start_time"])
    consolidated = []
    current_group = [ordered[0]]
    group_end = ordered[0]["end_time"]

    for item in ordered[1:]:
        if item["start_time"] - group_end <= max_time_gap:
            current_group.append(item)
            # Track the furthest end so far, not just the last item's end.
            group_end = max(group_end, item["end_time"])
        else:
            consolidated.append(_merge_feedback_group(current_group))
            current_group = [item]
            group_end = item["end_time"]

    consolidated.append(_merge_feedback_group(current_group))
    return consolidated

def _merge_feedback_group(group):
    """
    Collapse a non-empty group of feedback items into one summary item.

    The result spans from the earliest 'start_time' to the latest 'end_time',
    joins the per-item issue texts (each prefixed with its own time range)
    with " | ", and carries the most severe 'severity' found in the group
    (items without one count as "medium").
    """
    issues = [
        f"({f['start_time']:.2f}-{f['end_time']:.2f}s) {f['issue']}"
        for f in group
    ]
    severity_levels = [f.get("severity", "medium") for f in group]

    return {
        "start_time": min(f["start_time"] for f in group),
        "end_time": max(f["end_time"] for f in group),
        "issue": " | ".join(issues),
        # Compare by rank but keep the original value in the output.
        "severity": max(severity_levels, key=_severity_rank)
    }

def generate_actionable_feedback(analysis_dict):
    """
    Build a short list of feedback bullet points from raw feedback items.

    Reads analysis_dict['raw_feedback_items'] (missing -> no items),
    consolidates items that are within 1.0 s of each other, orders the
    resulting segments from most to least severe (chronological within the
    same severity) and formats the first four as text. A closing note is
    appended when more than four segments exist.

    Side effect: the bullet list is also stored in
    analysis_dict['actionable_feedback'].

    Returns:
        The list of bullet-point strings.
    """
    raw_items = analysis_dict.get("raw_feedback_items", [])
    consolidated = consolidate_feedback_items(raw_items, max_time_gap=1.0)
    # Stable sort: equally severe segments keep their chronological order.
    consolidated.sort(
        key=lambda seg: _severity_rank(seg.get("severity", "medium")),
        reverse=True
    )

    bullet_points = [
        (f"Time {seg['start_time']:.2f}-{seg['end_time']:.2f}s: {seg['issue']}. "
         f"Severity: {seg['severity']}. Focus on consistent performance.")
        for seg in consolidated[:4]
    ]

    if len(consolidated) > 4:
        bullet_points.append("Additional issues exist but are omitted for brevity.")

    analysis_dict["actionable_feedback"] = bullet_points
    return bullet_points


###############################################################################
# 4) CLASSICAL FEATURE COMPARISON
###############################################################################
def compare_classical_features(analysis_master, analysis_student):
    """
    Compare pitch deviation, tempo and jitter/shimmer of student vs. master.

    Reads 'average_dev_cents' and 'estimated_tempo_bpm' from each analysis'
    'results' section and 'avg_jitter' / 'avg_shimmer' from
    'quantum_analysis' -> 'advanced_stats'; anything missing counts as 0.0.

    Side effect: the comparison is stored in
    analysis_student['master_comparison']['classical'].

    Returns:
        The comparison dict with 'pitch_accuracy', 'tempo' and
        'jitter_shimmer' sections, each carrying a human-readable 'comment'.
    """
    res_m = analysis_master.get("results", {})
    res_s = analysis_student.get("results", {})
    stats_m = analysis_master.get("quantum_analysis", {}).get("advanced_stats", {})
    stats_s = analysis_student.get("quantum_analysis", {}).get("advanced_stats", {})

    dev_m = res_m.get("average_dev_cents", 0.0)
    dev_s = res_s.get("average_dev_cents", 0.0)
    bpm_m = res_m.get("estimated_tempo_bpm", 0.0)
    bpm_s = res_s.get("estimated_tempo_bpm", 0.0)
    bpm_gap = abs(bpm_m - bpm_s)

    jitter_m = stats_m.get("avg_jitter", 0.0)
    jitter_s = stats_s.get("avg_jitter", 0.0)
    shimmer_m = stats_m.get("avg_shimmer", 0.0)
    shimmer_s = stats_s.get("avg_shimmer", 0.0)

    # Pitch: the student may exceed the master's deviation by up to 10 cents.
    if dev_s <= dev_m + 10.0:
        pitch_comment = "Student pitch is close to master."
    else:
        pitch_comment = "Student pitch has noticeably higher deviation."

    # Jitter/shimmer: flagged when either exceeds the master's value by 20%.
    if jitter_s > jitter_m * 1.2 or shimmer_s > shimmer_m * 1.2:
        stability_comment = (
            "Student’s jitter/shimmer is significantly higher (could be vibrato or instability)."
        )
    else:
        stability_comment = "Student’s jitter/shimmer is comparable to master."

    classical_comp = {}
    classical_comp["pitch_accuracy"] = {
        "master_dev_cents": dev_m,
        "student_dev_cents": dev_s,
        "comment": pitch_comment,
    }
    classical_comp["tempo"] = {
        "master_bpm": bpm_m,
        "student_bpm": bpm_s,
        "tempo_diff": bpm_gap,
        "comment": f"Student is {bpm_gap:.2f} BPM away from master.",
    }
    classical_comp["jitter_shimmer"] = {
        "master_jitter": jitter_m,
        "student_jitter": jitter_s,
        "master_shimmer": shimmer_m,
        "student_shimmer": shimmer_s,
        "comment": stability_comment,
    }

    analysis_student.setdefault("master_comparison", {})["classical"] = classical_comp
    return classical_comp


###############################################################################
# 5) QUANTUM FEATURE COMPARISON
###############################################################################
def compare_quantum_features(analysis_master, analysis_student, top_n=5):
    """
    Compare the 'quantum_analysis' sections of the master and student.

    Computes the Jensen-Shannon distance between the two measurement-count
    distributions, the Euclidean distance between the 'scaled_angles'
    vectors, the ``top_n`` largest per-bitstring and per-angle differences,
    and the differing 'advanced_stats', then renders a text summary.

    Side effect: the result is stored in
    analysis_student['master_comparison']['quantum'].

    Returns:
        The result dict. When either analysis has no 'quantum_analysis'
        data, the dict only contains a 'quantum_summary' message.
    """
    master_q = analysis_master.get("quantum_analysis", {})
    student_q = analysis_student.get("quantum_analysis", {})
    comparison = analysis_student.setdefault("master_comparison", {})

    if not (master_q and student_q):
        comparison["quantum"] = {
            "quantum_summary": "Missing quantum data in one or both recordings."
        }
        return comparison["quantum"]

    # Measurement distributions.
    probs_master = _counts_to_probability(master_q.get("measurement_counts", {}))
    probs_student = _counts_to_probability(student_q.get("measurement_counts", {}))
    js_dist = _js_distance(probs_master, probs_student)

    # Encoding angles.
    angles_master = master_q.get("scaled_angles", [])
    angles_student = student_q.get("scaled_angles", [])
    angle_dist = _euclidean_distance(angles_master, angles_student)

    bitstring_diffs = _compare_bitstring_probabilities(probs_master, probs_student, top_n=top_n)
    angle_diffs = _compare_scaled_angles(angles_master, angles_student, top_n=top_n)
    adv_stat_diffs = _compare_advanced_stats(
        master_q.get("advanced_stats", {}),
        student_q.get("advanced_stats", {}),
    )

    # Human-readable report, one entry per output line.
    report = [
        "=== Quantum Master–Student Comparison ===",
        f"Jensen–Shannon distance: {js_dist:.4f}",
        f"Scaled angles Euclidean distance: {angle_dist:.4f}",
        "Lower => more similar.\n",
        "Top Bitstring Probability Differences:"
    ]
    report.extend(
        f"  {rank}) {bs}: master={val1:.4f}, student={val2:.4f}, diff={diff:.4f}"
        for rank, (bs, val1, val2, diff) in enumerate(bitstring_diffs, 1)
    )

    report.append("\nLargest Scaled Angle Differences:")
    report.extend(
        f"  {rank}) Angle {idx}: master={a1:.4f}, student={a2:.4f}, diff={d:.4f}"
        for rank, (idx, a1, a2, d) in enumerate(angle_diffs, 1)
    )

    if not adv_stat_diffs:
        report.append("\nNo advanced stats differences found.")
    else:
        report.append("\nAdvanced Stats Differences:")
        report.extend(
            f"  {stat_name}: master={v1:.4f}, student={v2:.4f}, diff={diff:.4f}"
            for stat_name, v1, v2, diff in adv_stat_diffs
        )

    quantum_results = {
        "quantum_js_distance": js_dist,
        "angle_distance": angle_dist,
        "bitstring_differences": bitstring_diffs,
        "angle_differences": angle_diffs,
        "advanced_stat_differences": adv_stat_diffs,
        "quantum_summary": "\n".join(report)
    }

    comparison["quantum"] = quantum_results
    return quantum_results


###############################################################################
# 6) PITCH CONTOUR COMPARISON
###############################################################################
def compare_pitch_contours(analysis_master, analysis_student, pitch_tolerance_cents=50.0):
    """
    Find stretches where the student's pitch contour strays from the master's.

    The 'pitch_contour' entries ('times' and 'pitches' lists) of both
    analyses are compared index by index over their common length; no time
    alignment is performed. Consecutive frames whose absolute pitch
    difference exceeds ``pitch_tolerance_cents`` form one deviation region,
    reported with the master's timestamps of its first and last frame and
    the mean absolute difference inside it.
    NOTE(review): the pitch values are compared against the tolerance as-is,
    so they are presumably expressed in cents - confirm with the producer.

    Frames where either pitch is None are treated as not comparable and end
    any open deviation region.

    Side effect: the result is stored in
    analysis_student['master_comparison']['pitch_contours'].

    Returns:
        The stored dict: either {'message': ...} when there is nothing to
        compare, or {'num_deviations', 'pitch_tolerance_cents',
        'deviation_regions'}.
    """
    # Ensure the container exists up front so every exit path can write to it.
    comparison = analysis_student.setdefault("master_comparison", {})

    master_pc = analysis_master.get("pitch_contour", {})
    student_pc = analysis_student.get("pitch_contour", {})

    if not master_pc or not student_pc:
        result = {"message": "No pitch_contour data in master or student."}
        comparison["pitch_contours"] = result
        return result

    times_m = master_pc.get("times", [])
    pitches_m = master_pc.get("pitches", [])
    times_s = student_pc.get("times", [])
    pitches_s = student_pc.get("pitches", [])

    n = min(len(times_m), len(times_s), len(pitches_m), len(pitches_s))
    if n == 0:
        result = {"message": "Pitch contour arrays exist but are empty."}
        comparison["pitch_contours"] = result
        return result

    # Per-frame absolute difference; None marks a frame that cannot be compared.
    frame_diffs = [
        None if (pm is None or ps is None) else abs(pm - ps)
        for pm, ps in zip(pitches_m[:n], pitches_s[:n])
    ]

    def _region(first, last):
        # Summarize the frames first..last (inclusive) as one deviation region.
        return {
            "start_time": times_m[first],
            "end_time": times_m[last],
            "average_diff_cents": float(np.mean(frame_diffs[first:last + 1]))
        }

    major_deviations = []
    region_start = None  # index of the first frame of the open region, if any

    for i, diff in enumerate(frame_diffs):
        deviating = diff is not None and diff > pitch_tolerance_cents
        if deviating:
            if region_start is None:
                region_start = i
        elif region_start is not None:
            major_deviations.append(_region(region_start, i - 1))
            region_start = None

    # A region still open at the end runs through the last compared frame.
    if region_start is not None:
        major_deviations.append(_region(region_start, n - 1))

    pitch_contour_result = {
        "num_deviations": len(major_deviations),
        "pitch_tolerance_cents": pitch_tolerance_cents,
        "deviation_regions": major_deviations
    }

    comparison["pitch_contours"] = pitch_contour_result
    return pitch_contour_result


###############################################################################
# 7) PLOTTING MULTIPLE METRICS
###############################################################################
def plot_pitch_jitter_shimmer(analysis_master, analysis_student):
    """
    Plot master vs. student pitch contours; overlay time-series jitter/shimmer if available.

    Pitch data comes from each analysis' 'pitch_contour' ('times',
    'pitches'); the overlays come from 'quantum_analysis' ->
    'advanced_stats' ('time_series_jitter', 'time_series_shimmer') and are
    only drawn when their length equals that of the matching 'times' list.

    Returns None without doing anything when any pitch/time list is missing
    or empty. Otherwise the figure is rendered to PNG, base64-encoded and
    stored in
    analysis_student['master_comparison']['plots']['pitch_jitter_shimmer'].
    """
    master_pc = analysis_master.get("pitch_contour", {})
    student_pc = analysis_student.get("pitch_contour", {})

    times_m = master_pc.get("times", [])
    pitch_m = master_pc.get("pitches", [])
    times_s = student_pc.get("times", [])
    pitch_s = student_pc.get("pitches", [])

    adv_m = analysis_master.get("quantum_analysis", {}).get("advanced_stats", {})
    adv_s = analysis_student.get("quantum_analysis", {}).get("advanced_stats", {})

    if not times_m or not pitch_m or not times_s or not pitch_s:
        return  # no pitch data to plot

    # Matplotlib rejects x/y of different lengths, so plot the common prefix.
    n_m = min(len(times_m), len(pitch_m))
    n_s = min(len(times_s), len(pitch_s))

    # (series, x values, label, color, line style) for the optional overlays.
    overlays = [
        (adv_m.get("time_series_jitter", []), times_m, "Master Jitter", "green", "--"),
        (adv_s.get("time_series_jitter", []), times_s, "Student Jitter", "green", ":"),
        (adv_m.get("time_series_shimmer", []), times_m, "Master Shimmer", "red", "--"),
        (adv_s.get("time_series_shimmer", []), times_s, "Student Shimmer", "red", ":"),
    ]

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(times_m[:n_m], pitch_m[:n_m], label="Master Pitch", color="blue")
        ax.plot(times_s[:n_s], pitch_s[:n_s], label="Student Pitch", color="orange", alpha=0.8)

        # Overplot jitter & shimmer if length matches times
        for series, times, label, color, style in overlays:
            if len(series) == len(times):
                ax.plot(times, series, label=label, color=color, linestyle=style)

        ax.set_title("Pitch + Jitter/Shimmer Comparison")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Pitch (Hz or cents)")
        ax.legend()
        ax.grid(True)

        # Operate on this figure explicitly rather than pyplot's current figure.
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure even if plotting or rendering failed.
        plt.close(fig)

    b64_str = base64.b64encode(buf.getvalue()).decode("utf-8")

    plots = analysis_student.setdefault("master_comparison", {}).setdefault("plots", {})
    plots["pitch_jitter_shimmer"] = b64_str

def plot_energy_formants_comparison(analysis_master, analysis_student):
    """
    Plot RMS, LUFS, and formant data (F1, F2).
    Each is stored typically as:
      analysis_dict["dynamics_summary"]["rms_db"]["time_series"] = [...]
      analysis_dict["dynamics_summary"]["lufs"]["time_series"] = [...]
      analysis_dict["quantum_analysis"]["advanced_stats"]["time_series_F1"] = [...]
      analysis_dict["quantum_analysis"]["advanced_stats"]["time_series_F2"] = [...]

    The four metrics are drawn in a 2x2 grid, master and student on the same
    axes, against the sample index. The figure is rendered to PNG,
    base64-encoded and stored in
    analysis_student['master_comparison']['plots']['energy_formants'].

    Returns None without creating a plot when none of the eight series
    contains data, so no empty image is stored.
    """
    dyn_m = analysis_master.get("dynamics_summary", {})
    dyn_s = analysis_student.get("dynamics_summary", {})
    adv_m = analysis_master.get("quantum_analysis", {}).get("advanced_stats", {})
    adv_s = analysis_student.get("quantum_analysis", {}).get("advanced_stats", {})

    # One entry per panel: (row, col, title, [(series, label, plot kwargs), ...]).
    panels = [
        (0, 0, "RMS (dB)", [
            (dyn_m.get("rms_db", {}).get("time_series", []), "Master RMS (dB)", {"color": "blue"}),
            (dyn_s.get("rms_db", {}).get("time_series", []), "Student RMS (dB)", {"color": "orange"}),
        ]),
        (0, 1, "LUFS", [
            (dyn_m.get("lufs", {}).get("time_series", []), "Master LUFS", {"color": "blue"}),
            (dyn_s.get("lufs", {}).get("time_series", []), "Student LUFS", {"color": "orange"}),
        ]),
        (1, 0, "Formant F1", [
            (adv_m.get("time_series_F1", []), "Master F1", {"color": "green"}),
            (adv_s.get("time_series_F1", []), "Student F1", {"color": "red"}),
        ]),
        (1, 1, "Formant F2", [
            (adv_m.get("time_series_F2", []), "Master F2", {"color": "green", "linestyle": "--"}),
            (adv_s.get("time_series_F2", []), "Student F2", {"color": "red", "linestyle": "--"}),
        ]),
    ]

    has_data = any(
        len(series) > 0
        for _, _, _, curves in panels
        for series, _, _ in curves
    )
    if not has_data:
        return  # nothing to plot

    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    try:
        for row, col, title, curves in panels:
            ax = axs[row, col]
            for series, label, style in curves:
                ax.plot(series, label=label, **style)
            ax.set_title(title)
            ax.grid(True)
            ax.legend()

        # Operate on this figure explicitly rather than pyplot's current figure.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure even if plotting or rendering failed.
        plt.close(fig)

    b64_str = base64.b64encode(buf.getvalue()).decode("utf-8")

    plots = analysis_student.setdefault("master_comparison", {}).setdefault("plots", {})
    plots["energy_formants"] = b64_str


###############################################################################
# 8) MASTER FUNCTION: compare_analysis_results
###############################################################################
def compare_analysis_results(master_id, student_id):
    """
    1) Connect to DB, fetch analysis_master & analysis_student
    2) Compare classical & quantum
    3) Compare pitch contours
    4) Plot pitch/jitter/shimmer, plus RMS/LUFS/F1/F2
    5) Consolidate feedback
    6) Print summary
    7) Return updated student analysis

    Args:
        master_id: 'audio_analysis' record id of the reference recording.
        student_id: 'audio_analysis' record id of the recording to assess.

    Returns:
        The student's analysis dict, extended in place with
        'master_comparison' and 'actionable_feedback'. Nothing is written
        back to the database.

    Raises:
        ValueError: if either record is missing or its analysis_data is not
            a dict.
    """
    # 1) Create DB object & fetch; always release the connection, even when
    #    a fetch fails.
    db = QuantumMusicDB()
    try:
        row_master = db.fetch_analysis(master_id)
        row_student = db.fetch_analysis(student_id)
    finally:
        db.close()

    if not row_master:
        raise ValueError(f"No record found with ID={master_id}")
    if not row_student:
        raise ValueError(f"No record found with ID={student_id}")

    analysis_master = row_master[3]  # analysis_data
    analysis_student = row_student[3]

    # Every step below reads these with dict methods; fail early and clearly
    # if the column holds something else (e.g. NULL).
    for role, record_id, data in (("master", master_id, analysis_master),
                                  ("student", student_id, analysis_student)):
        if not isinstance(data, dict):
            raise ValueError(
                f"analysis_data of {role} record ID={record_id} is not a dict "
                f"(got {type(data).__name__})"
            )

    # 2) Compare classical & quantum
    classical_comp = compare_classical_features(analysis_master, analysis_student)
    quantum_comp = compare_quantum_features(analysis_master, analysis_student)

    # 3) Compare pitch
    compare_pitch_contours(analysis_master, analysis_student, pitch_tolerance_cents=50.0)

    # 4) Plot
    plot_pitch_jitter_shimmer(analysis_master, analysis_student)
    plot_energy_formants_comparison(analysis_master, analysis_student)

    # 5) Merge feedback -> actionable
    final_feedback = generate_actionable_feedback(analysis_student)

    # 6) Print summary
    comparison = analysis_student["master_comparison"]

    print("=== MASTER–STUDENT COMPARISON RESULTS ===")

    print("\n--- Classical Features ---")
    for k, v in classical_comp.items():
        print(f"{k}: {v}")

    print("\n--- Quantum Features ---")
    print(quantum_comp.get("quantum_summary", ""))

    print("\n--- Pitch Contour Deviations ---")
    print(comparison.get("pitch_contours", {}))

    # Show info about saved plots (either may be absent when there was no data)
    plots_dict = comparison.get("plots", {})
    if "pitch_jitter_shimmer" in plots_dict:
        print("\n[INFO] pitch_jitter_shimmer (base64) stored at: "
              "analysis_student['master_comparison']['plots']['pitch_jitter_shimmer']")
    if "energy_formants" in plots_dict:
        print("[INFO] energy_formants (base64) stored at: "
              "analysis_student['master_comparison']['plots']['energy_formants']")

    print("\n--- Actionable Feedback ---")
    for item in final_feedback:
        print(f"* {item}")

    # 7) Return updated student's analysis
    return analysis_student


###############################################################################
# 9) MAIN USAGE EXAMPLE
###############################################################################
if __name__ == "__main__":
    # Standalone entry point: compare one hard-coded master/student pair and
    # print the report produced by compare_analysis_results.
    # Adjust these IDs to whatever is valid in your DB
    MASTER_ID = 2000
    STUDENT_ID = 2001

    print("\n[INFO] Running compare_analysis_results in standalone mode...\n")
    # Raises ValueError if either ID has no row in 'audio_analysis'.
    student_analysis_updated = compare_analysis_results(MASTER_ID, STUDENT_ID)
    print("\nDone.\n")
    # 'student_analysis_updated' now contains 'master_comparison' results,
    # plus 'actionable_feedback'. Nothing has been written back to the
    # database at this point; persist it to the DB or elsewhere if needed.