In [None]:
from robovast_common.analysis import read_output_files, read_output_csv

# Configure paths — DATA_DIR is handed to read_output_files below
# (left empty here; presumably set per analysis run — confirm).
DATA_DIR = ''


def _read_run_output(test_dir):
    """Load one test directory's ``out.csv`` via ``read_output_csv`` with ``skiprows=1``."""
    return read_output_csv(test_dir, "out.csv", skiprows=1)


# Read all CSV files into a combined dataframe
combined_df = read_output_files(DATA_DIR, _read_run_output)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Sequence, Tuple

# -----------------------
# Helpers
# -----------------------
def _validate_df(df: pd.DataFrame):
    """Verify the required columns and return a dtype-normalized copy of *df*.

    ``time`` becomes ``int``, ``population`` becomes ``float`` and ``test``,
    ``variant``, ``growth_rate`` and ``initial_population`` become
    ``category``. Any other columns are carried over unchanged; *df* itself is
    not modified.

    Raises:
        ValueError: if one or more required columns are missing.
    """
    categorical = ('test', 'variant', 'growth_rate', 'initial_population')
    required = {'time', 'population', *categorical}
    absent = required - set(df.columns)
    if absent:
        raise ValueError(f"DataFrame missing expected columns: {absent}")
    target_dtypes = {'time': int, 'population': float}
    target_dtypes.update({col: 'category' for col in categorical})
    # A dtype mapping converts only the listed columns and returns a new frame.
    return df.astype(target_dtypes)

def _pivot_variant_to_runs(df: pd.DataFrame, variant: str) -> pd.DataFrame:
    """Reshape one variant's rows into a time x run table of populations.

    The index holds ``time`` values (ascending), the columns are ``test``
    identifiers and the cells hold ``population``. An empty DataFrame is
    returned when no row belongs to *variant*. ``DataFrame.pivot`` raises if a
    (time, test) pair occurs more than once within the variant.
    """
    in_variant = df['variant'] == variant
    if not in_variant.any():
        return pd.DataFrame()
    wide = df.loc[in_variant].pivot(index='time', columns='test', values='population')
    return wide.sort_index()

def _agg_variant_stats(df: pd.DataFrame):
    """Summarize each variant's runs into one row of statistics.

    Expects *df* as returned by ``_validate_df`` (``variant`` must be
    categorical). For every variant category:

    * ``n_runs``     – number of runs (pivot columns) for the variant,
    * ``final_mean`` – mean population at the variant's last time step,
    * ``final_std``  – sample standard deviation (ddof=1) of those values,
    * ``max_mean``   – mean over runs of each run's maximum population.

    Runs with no value at the last time step are left out of the ``final_*``
    statistics. Statistics that cannot be computed (no values, or fewer than
    two for the standard deviation) are NaN, without emitting NumPy warnings.

    Returns:
        DataFrame indexed by ``variant`` with the columns listed above
        (empty, but with those columns, when there are no variants).
    """
    columns = ['variant', 'n_runs', 'final_mean', 'final_std', 'max_mean']
    rows = []
    for v in sorted(df['variant'].cat.categories):
        pivot = _pivot_variant_to_runs(df, v)
        if pivot.empty:
            rows.append({'variant': v, 'n_runs': 0, 'final_mean': np.nan,
                         'final_std': np.nan, 'max_mean': np.nan})
            continue
        finals = pivot.loc[pivot.index.max()].dropna().to_numpy(dtype=float)
        run_max = pivot.max(axis=0).dropna().to_numpy(dtype=float)
        rows.append({
            'variant': v,
            'n_runs': pivot.shape[1],
            'final_mean': finals.mean() if finals.size else np.nan,
            # ddof=1 needs at least two observations; avoid the "Degrees of freedom <= 0" warning.
            'final_std': finals.std(ddof=1) if finals.size > 1 else np.nan,
            'max_mean': run_max.mean() if run_max.size else np.nan,
        })
    # Explicit columns keep set_index('variant') valid even when rows is empty.
    return pd.DataFrame(rows, columns=columns).set_index('variant')

def compute_ecdf(values):
    """Return ``(x, y)`` coordinates of the empirical CDF of *values*.

    ``x`` is the sorted sample and ``y[k] == (k + 1) / len(x)``, i.e. the
    fraction of observations up to and including ``x[k]`` (for distinct
    values). Suitable for ``Axes.step(x, y, where='post')``.
    """
    ordered = np.sort(values)
    count = len(ordered)
    ranks = np.arange(count) + 1
    return ordered, ranks / count

# -----------------------
# Plotting functions
# -----------------------
def plot_mean_ribbon_variants(df, variants=None, q_lo=0.1, q_hi=0.9, figsize=(10,5), ax=None):
    """Plot each variant's mean population trajectory with a quantile ribbon.

    Each variant's runs are pivoted to a time x run table. The line is the
    across-run mean at each time step; the shaded band spans the ``q_lo`` to
    ``q_hi`` quantiles across runs.

    Args:
        df: Long-format results; normalized with ``_validate_df``.
        variants: Variants to draw (default: all). Names not present in the
            data are dropped.
        q_lo, q_hi: Quantiles (0-1) bounding the ribbon.
        figsize: Figure size, used only when ``ax`` is None.
        ax: Axes to draw on; a new figure is created when None.

    Returns:
        The axes that was drawn on.
    """
    df = _validate_df(df)
    all_variants = sorted(df['variant'].cat.categories)
    if variants is None:
        variants = all_variants
    variants = [v for v in variants if v in all_variants]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    cmap = plt.get_cmap('tab10')
    for i, v in enumerate(variants):
        pivot = _pivot_variant_to_runs(df, v)
        if pivot.empty:
            continue
        mean_traj = pivot.mean(axis=1)
        lo = pivot.quantile(q_lo, axis=1)
        hi = pivot.quantile(q_hi, axis=1)
        # Colour by position in `variants` so the same list gives matching colours across panels.
        color = cmap(i % 10)
        ax.plot(mean_traj.index, mean_traj.values, label=f"{v} (n={pivot.shape[1]})", color=color, linewidth=2)
        ax.fill_between(mean_traj.index, lo, hi, alpha=0.2, color=color)
    ax.set_xlabel('time')
    ax.set_ylabel('population')
    # ':g' instead of int(): int(100 * 0.29) == 28 because of float rounding.
    ax.set_title(f"Mean trajectories ± {100 * q_lo:g}-{100 * q_hi:g} percentile")
    # Skip the legend when nothing labelled was drawn (matplotlib warns otherwise).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='best', fontsize='small')
    ax.grid(alpha=0.2)
    return ax

def plot_final_boxplots_by_variant(df, variants=None, figsize=(8,6), ax=None):
    """Box plot of per-run final populations, one box per variant.

    The "final" value of a run is its population at the variant's last time
    step; runs without a value there are dropped. Variants with no rows are
    skipped. Outliers (fliers) are not drawn.

    Args:
        df: Long-format results; normalized with ``_validate_df``.
        variants: Variants to include (default: all).
        figsize: Figure size, used only when ``ax`` is None.
        ax: Axes to draw on; a new figure is created when None.

    Returns:
        The axes that was drawn on.
    """
    df = _validate_df(df)
    all_variants = sorted(df['variant'].cat.categories)
    if variants is None:
        variants = all_variants
    finals_by_variant = []
    labels = []
    for v in variants:
        pivot = _pivot_variant_to_runs(df, v)
        if pivot.empty:
            continue
        finals = pivot.loc[pivot.index.max()].dropna().values
        finals_by_variant.append(finals)
        labels.append(v)
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    # boxplot([], tick_labels=[]) raises a labels/data dimension mismatch.
    if finals_by_variant:
        ax.boxplot(finals_by_variant, tick_labels=labels, showfliers=False)
    ax.set_xlabel('variant')
    ax.set_ylabel('final population')
    ax.set_title('Final population distribution by variant')
    ax.grid(alpha=0.2)
    # Rotate *this* axes' labels: plt.xticks() targets pyplot's current axes,
    # which is a different panel when the caller supplies `ax`.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    return ax

def plot_ecdf_final_by_variant(df, variants=None, figsize=(8,5), ax=None):
    """Plot the ECDF of per-run final populations for each variant.

    The "final" value of a run is its population at the variant's last time
    step; runs without a value there are dropped. Variants with no data are
    skipped.

    Args:
        df: Long-format results; normalized with ``_validate_df``.
        variants: Variants to draw (default: all). Names not present in the
            data are dropped.
        figsize: Figure size, used only when ``ax`` is None.
        ax: Axes to draw on; a new figure is created when None.

    Returns:
        The axes that was drawn on.
    """
    df = _validate_df(df)
    all_variants = sorted(df['variant'].cat.categories)
    if variants is None:
        variants = all_variants
    # Drop unknown names the same way plot_mean_ribbon_variants does, so the
    # colour index of each variant matches between the two plots.
    variants = [v for v in variants if v in all_variants]
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    cmap = plt.get_cmap('tab10')
    for i, v in enumerate(variants):
        pivot = _pivot_variant_to_runs(df, v)
        if pivot.empty:
            continue
        finals = pivot.loc[pivot.index.max()].dropna().values
        if finals.size == 0:
            continue
        x, y = compute_ecdf(finals)
        ax.step(x, y, where='post', label=f"{v} (n={len(finals)})", color=cmap(i % 10))
    ax.set_xlabel('final population')
    ax.set_ylabel('ECDF')
    ax.set_title('ECDF of final populations by variant')
    ax.grid(alpha=0.2)
    # Skip the legend when nothing labelled was drawn (matplotlib warns otherwise).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize='small')
    return ax

def plot_variant_category_scatter(df, variants=None, ax=None):
    """Scatter plot of growth_rate (x) vs initial_population (y) for each variant, labeled by variant name.

    Each variant's parameter pair is read from its first row in ``df``.
    Variants with no rows are skipped. Such variants can be unknown names or
    unused categories kept from an already-categorical input. The original
    ``.iloc[0]`` raised IndexError on them.

    Args:
        df: Long-format results; normalized with ``_validate_df``.
        variants: Variants to plot (default: all).
        ax: Axes to draw on; a new 8x6 figure is created when None.

    Returns:
        The axes that was drawn on.
    """
    df = _validate_df(df)
    all_variants = sorted(df['variant'].cat.categories)
    if variants is None:
        variants = all_variants
    if ax is None:
        _, ax = plt.subplots(figsize=(8,6))
    x, y, names = [], [], []
    for v in variants:
        rows = df[df['variant'] == v]
        if rows.empty:
            continue
        first = rows.iloc[0]
        x.append(float(first['growth_rate']))
        y.append(float(first['initial_population']))
        names.append(v)
    ax.scatter(x, y, s=80, c='tab:blue', alpha=0.7)
    for xi, yi, name in zip(x, y, names):
        ax.text(xi, yi, str(name), fontsize=9, ha='left', va='bottom', fontweight='bold', color='tab:gray')
    ax.set_xlabel('growth_rate')
    ax.set_ylabel('initial_population')
    ax.set_title('Variant Parameter Combinations')
    ax.grid(alpha=0.2)
    return ax

# -----------------------
# Main dashboard
# -----------------------
def compare_variants_dashboard(df, variants=None, max_small_multiple=6, q_lo=0.1, q_hi=0.9, heatmap_bins=60,
                              two_column_text=False, save_path:Optional[str]=None):
    """Draw a 2x2 comparison dashboard and return per-variant summary statistics.

    Panels are laid out as follows:

    * top left: mean trajectory with a ``q_lo`` to ``q_hi`` ribbon,
    * top right: final-population box plots,
    * bottom left: ECDF of final populations,
    * bottom right: growth_rate vs. initial_population scatter.

    Args:
        df: Long-format results; normalized with ``_validate_df``.
        variants: Variants to include (default: all); unknown names are dropped.
        max_small_multiple, heatmap_bins, two_column_text: accepted for
            call-site compatibility; not used by this function.
        q_lo, q_hi: Quantiles bounding the mean-trajectory ribbon.
        save_path: When truthy, the figure is also saved there at 200 dpi.

    Returns:
        The ``_agg_variant_stats`` table reindexed to the selected variants.
    """
    df = _validate_df(df)
    known = sorted(df['variant'].cat.categories)
    selected = [v for v in (known if variants is None else variants) if v in known]

    summary = _agg_variant_stats(df).reindex(selected)

    fig = plt.figure(constrained_layout=True, figsize=(16,12))
    grid = fig.add_gridspec(2,2)
    # Created row-major: (0,0) mean, (0,1) box, (1,0) ecdf, (1,1) scatter.
    ax_mean, ax_box, ax_ecdf, ax_scatter = [fig.add_subplot(grid[r, c]) for r in (0, 1) for c in (0, 1)]

    plot_mean_ribbon_variants(df, variants=selected, q_lo=q_lo, q_hi=q_hi, ax=ax_mean)
    plot_final_boxplots_by_variant(df, variants=selected, ax=ax_box)
    plot_ecdf_final_by_variant(df, variants=selected, ax=ax_ecdf)
    plot_variant_category_scatter(df, variants=selected, ax=ax_scatter)

    fig.suptitle("Multi-variant Population Dashboard", fontsize=16)

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Dashboard saved to {save_path}")

    plt.show()
    return summary


# Build the dashboard for every variant in combined_df.
# NOTE: max_small_multiple, heatmap_bins and two_column_text are accepted by
# compare_variants_dashboard but not used by it (no heatmaps or text layout are drawn).
summary_df = compare_variants_dashboard(combined_df,
                                        max_small_multiple=6,  # currently unused by the dashboard
                                        q_lo=0.10,             # lower quantile of the mean-trajectory ribbon
                                        q_hi=0.90,             # upper quantile of the mean-trajectory ribbon
                                        heatmap_bins=60,       # currently unused by the dashboard
                                        two_column_text=True,  # currently unused by the dashboard
                                        save_path=None)        # or a filename to also save the figure

# Display the per-variant summary table, highest mean final population first
display(summary_df.sort_values('final_mean', ascending=False))