In [None]:
# Convergence analysis: row counts before vs. after for trajs and waypoint (the plot covers trajs only)

import os
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

# Resolve the project directory: the script's own folder when run as a file,
# otherwise (e.g. in a notebook, where __file__ is undefined) the working dir.
PROJECT_ROOT = (
    os.path.abspath(os.path.dirname(__file__))
    if "__file__" in globals()
    else os.getcwd()
)

# The unified database lives under <project>/Output/database/
DB_PATH = os.path.join(PROJECT_ROOT, "Output", "database", "unified_database.db")

print("PROJECT_ROOT:", PROJECT_ROOT)
print("DB_PATH:", DB_PATH)

# Fail early: sqlite3.connect would otherwise create an empty database file.
if not os.path.exists(DB_PATH):
    raise FileNotFoundError(f"Database not found at: {DB_PATH}")

# Open the database
conn = sqlite3.connect(DB_PATH)

# Collect the names of all tables so later steps can check for existence
tables_df = pd.read_sql(
    sql="SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;",
    con=conn,
)
available = set(tables_df["name"])
print("\nAvailable tables:", available)

# Helper to count rows
def get_rows(table):
    """Return the number of rows in ``table``, or ``None`` if it cannot be counted.

    The query runs on the module-level ``conn`` SQLite connection.

    Args:
        table: Name of the table to count. SQLite cannot bind identifiers as
            query parameters, so the name is double-quoted (with embedded
            double quotes escaped) before being placed in the SQL text. This
            lets names containing spaces or punctuation work.

    Returns:
        The row count as an ``int``, or ``None`` when SQLite rejects the
        query (for example because the table does not exist).
    """
    quoted = '"' + str(table).replace('"', '""') + '"'
    try:
        (count,) = conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()
    except sqlite3.Error as exc:
        # Best-effort: a failed count is reported as missing data. Only
        # database errors are swallowed so unrelated bugs still surface.
        print(f"  WARNING: could not count rows in {table!r}: {exc}")
        return None
    return int(count)

# Entities for convergence: each entry names the table counted for the
# "before" stage and the table counted for the "after" stage.
entities = [
    {"name": "trajs", "before": "trajs_synthesized", "after": "trajs"},
    # no waypoint_synthesized exists, so waypoint has no "before" table
    {"name": "waypoint", "before": None, "after": "waypoint"},
]

rows_list = []

# Get counts. The connection is closed in ``finally`` so it is released even
# if something raises part-way through the loop.
try:
    for e in entities:
        name = e["name"]
        print(f"\n=== {name.upper()} ===")

        for stage in ("before", "after"):
            table = e[stage]
            rows = None
            # A missing (or undefined) table is recorded with rows=None
            # rather than skipped, so every entity has both stages.
            if table and table in available:
                rows = get_rows(table)
                print(f"  {stage.upper():<8}({table}): {rows}")
            rows_list.append(
                {"entity": name, "stage": stage, "table": table, "rows": rows}
            )
finally:
    conn.close()

# Build dataframe (one row per entity/stage, "rows" is None when unavailable)
conv_df = pd.DataFrame(rows_list)

# Compute summary: one record per entity with before/after counts and the
# absolute and relative number of rows removed between the two stages.
summary = []
for ent in conv_df["entity"].unique():
    sub = conv_df[conv_df["entity"] == ent]
    before = sub.loc[sub["stage"] == "before", "rows"].dropna()
    after = sub.loc[sub["stage"] == "after", "rows"].dropna()

    before_val = int(before.iloc[0]) if len(before) > 0 else None
    after_val = int(after.iloc[0]) if len(after) > 0 else None

    # Compare against None explicitly: a count of 0 is valid data (e.g. every
    # row was removed) and must not be treated as "missing".
    if before_val is None or after_val is None:
        removed = None
        pct = None
    else:
        removed = before_val - after_val
        # The percentage is undefined when there were no rows to begin with.
        pct = (removed / before_val) * 100 if before_val != 0 else None

    summary.append({
        "entity": ent,
        "before_rows": before_val,
        "after_rows": after_val,
        "rows_removed": removed,
        "pct_removed": pct,
    })

summary_df = pd.DataFrame(summary)
print("\n=== SUMMARY ===")
print(summary_df)

# Save CSV files (written to the current working directory)
conv_df.to_csv("convergence_counts.csv", index=False)
summary_df.to_csv("convergence_summary.csv", index=False)

# Plot: before vs after counts for the first entity that has both values
plot_df = summary_df.dropna(subset=["before_rows", "after_rows"])

if len(plot_df.index) > 0:
    first_row = plot_df.iloc[0]

    plt.figure(figsize=(7, 4))
    # One bar per stage, each with its own legend entry
    for bar_label, bar_column in (("Before", "before_rows"), ("After", "after_rows")):
        plt.bar([bar_label], [first_row[bar_column]], label=bar_label)

    plt.title("Convergence of Trajectories (Before vs After Cleaning)")
    plt.ylabel("Number of Rows")
    plt.legend()
    plt.grid(axis="y", linestyle="--", alpha=0.5)
    plt.tight_layout()

    # Write the figure to disk and release it
    plt.savefig("trajs_convergence_plot.png", dpi=200)
    plt.close()
    print("Saved plot: trajs_convergence_plot.png")

# Markdown report
def _fmt_count(value):
    """Render a row count for the report: a whole number, or N/A when missing.

    Counts arrive as floats (or NaN/None) because pandas promotes a column
    containing missing values, so they are converted back to ``int`` here.
    """
    return "N/A" if pd.isna(value) else str(int(value))


md = []
md.append("# Convergence Report — trajs & waypoint\n")
md.append(f"- Generated on: {datetime.now().isoformat(timespec='seconds')}\n")
md.append("## Raw convergence table\n")

for _, r in conv_df.iterrows():
    md.append(f"- {r['entity']}: stage `{r['stage']}`, rows = {_fmt_count(r['rows'])}")

md.append("\n## Summary\n")
for _, r in summary_df.iterrows():
    pct = r['pct_removed']
    # Test for missing explicitly: NaN is truthy (would print "nan%") and a
    # legitimate 0.0 is falsy (would wrongly print "N/A").
    pct_text = f"{pct:.2f}%" if pd.notna(pct) else "N/A"

    md.append(f"### {r['entity']}")
    md.append(f"- Before: {_fmt_count(r['before_rows'])}")
    md.append(f"- After: {_fmt_count(r['after_rows'])}")
    md.append(f"- Rows removed: {_fmt_count(r['rows_removed'])}")
    md.append(f"- % removed: {pct_text}\n")

with open("convergence_report.md", "w", encoding="utf-8") as f:
    f.write("\n".join(md))

print("Generated convergence_report.md, convergence CSVs and the plot.")
