In [None]:
from robovast.common.analysis import read_output_files, read_output_csv

# Root directory of the run outputs. Empty as written — presumably a placeholder
# to be set before running (TODO confirm how read_output_files resolves '').
DATA_DIR = ''


def _read_test_output(test_dir):
    """Load ``out.csv`` from one test directory via read_output_csv with skiprows=1."""
    return read_output_csv(test_dir, "out.csv", skiprows=1)


# Apply the per-test loader under DATA_DIR and gather every CSV into one combined dataframe
combined_df = read_output_files(DATA_DIR, _read_test_output)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Dict, Any, Tuple

# ---------- helpers ----------
def compute_derivative(times: np.ndarray, values: np.ndarray) -> np.ndarray:
    """
    Numerical derivative d(values)/dt that handles non-uniform time spacing.

    Delegates to ``np.gradient`` with ``times`` as the sample coordinates:
    second-order central differences in the interior, first-order one-sided
    differences at the two ends. ``times`` should be strictly increasing —
    repeated timestamps mean zero spacing and therefore non-finite results.

    Parameters
    ----------
    times : sample times (array-like, same length as ``values``).
    values : sampled quantity (array-like).

    Returns
    -------
    Float array shaped like ``values``. With fewer than two samples the
    derivative is undefined and ``np.gradient`` would raise; a zero array is
    returned instead so callers get a well-defined (flat) result.
    """
    values = np.asarray(values)
    times = np.asarray(times)
    if values.size < 2:
        return np.zeros(values.shape, dtype=float)
    return np.gradient(values, times)

def find_time_to_threshold(times: np.ndarray, values: np.ndarray, threshold: float) -> Optional[float]:
    """
    Return the earliest time at which ``values`` is at or above ``threshold``.

    ``times`` and ``values`` are parallel arrays; the first index (in array
    order) where ``values >= threshold`` holds selects the returned time.
    NaN samples never satisfy the comparison. Returns None when no sample
    reaches the threshold, including for empty input.
    """
    hits = np.asarray(values >= threshold).nonzero()[0]
    if hits.size == 0:
        return None
    first_hit = hits[0]
    return float(times[first_hit])

def fit_log_linear(times: np.ndarray, values: np.ndarray, low_frac=0.05, high_frac=0.5) -> Optional[Tuple[float, float]]:
    """
    Fit a straight line to log(values) over the exponential-growth phase.

    The fit window is every sample at or before the peak (the first index of
    the finite maximum) whose value lies in [low_frac*max, high_frac*max].
    Samples after the peak are excluded so that a later decline passing back
    through the same band does not contaminate the growth-rate estimate.
    If that window holds fewer than 3 samples, the first 30% of samples
    (by index, at least 3) is used instead. Only finite, strictly positive
    values at finite times enter the fit.

    Parameters
    ----------
    times, values : 1-D array-likes of equal length. Both the peak cut-off and
        the fallback window are index-based, so ``times`` should be sorted
        ascending.
    low_frac, high_frac : band limits as fractions of the maximum value.

    Returns
    -------
    (slope, intercept) such that log(values) ~= slope * t + intercept; the
    slope is the exponential growth rate per time unit. None when no fit is
    possible: empty input, no finite positive maximum, fewer than 3 usable
    samples, or all usable samples sharing one time.
    """
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return None

    finite = np.isfinite(values)
    if not finite.any():
        return None
    vmax = values[finite].max()
    if vmax <= 0:
        return None

    idx = np.arange(values.size)
    # First occurrence of the maximum; non-finite samples can never be the peak.
    peak_idx = int(np.argmax(np.where(finite, values, -np.inf)))
    in_band = (values >= low_frac * vmax) & (values <= high_frac * vmax)
    mask = in_band & (idx <= peak_idx)

    if np.count_nonzero(mask) < 3:
        # Too few samples in the growth band: fall back to the first 30% of samples.
        n_head = max(3, int(0.3 * values.size))
        mask = idx < n_head
        if np.count_nonzero(mask) < 3:
            return None

    # log() needs strictly positive values; polyfit needs finite coordinates.
    usable = mask & finite & (values > 0) & np.isfinite(times)
    if np.count_nonzero(usable) < 3:
        return None
    xs = times[usable]
    if np.ptp(xs) == 0:
        # All samples at a single time: the slope is undefined.
        return None
    slope, intercept = np.polyfit(xs, np.log(values[usable]), 1)
    return float(slope), float(intercept)

# ---------- main function ----------
def plot_single_run(df: pd.DataFrame,
                    threshold: Optional[float] = None,
                    smoothing_window: Optional[int] = 5,
                    figsize: Tuple[int, int] = (14, 10),
                    save_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Draw four-panel diagnostics for a single test run and return summary metrics.

    Panels:
      TL: population vs time (raw and smoothed)
      TR: instantaneous growth rate dN/dt of the smoothed population
      BL: phase plot (smoothed population vs dN/dt, coloured by time)
      BR: histogram of the dN/dt values

    Parameters
    ----------
    df : DataFrame with at least 'time' and 'population' columns. If it also
        has a 'test' column with more than one distinct value, only rows of
        the first test (in order of appearance) are used and a warning is
        printed.
    threshold : optional population level. When given, the first time the raw
        population reaches it is reported and marked on the top-left panel.
    smoothing_window : window of the centred rolling mean applied to the
        population before differentiating; None or <= 1 disables smoothing.
    figsize : matplotlib figure size.
    save_path : if given, the figure is also saved there at dpi=200.

    Returns
    -------
    dict with keys:
      'AUC' : trapezoidal integral of the raw population over time
      'max_growth_rate', 'time_of_max_growth' : peak of the smoothed dN/dt
      'time_to_threshold' : None if no threshold was given or it was never reached
      'estimated_exp_rate' : slope of a log-linear fit on the smoothed
          population, or None if no fit was possible

    Raises
    ------
    ValueError : required columns missing, no rows left to plot, or the growth
        rate is NaN for every sample.
    """
    # ---- validation and run selection ----
    required = {'time', 'population'}
    if not required.issubset(df.columns):
        raise ValueError(f"DataFrame must contain {required}")
    # If there are multiple tests the caller should have filtered already; otherwise pick the first one
    if 'test' in df.columns:
        unique_tests = df['test'].unique()
        if len(unique_tests) > 1:
            print(f"Warning: input df has {len(unique_tests)} tests. Using first test: {unique_tests[0]}")
            df = df[df['test'] == unique_tests[0]]
    run_df = df.sort_values('time').reset_index(drop=True)
    if run_df.empty:
        raise ValueError("DataFrame has no rows to plot")
    times = run_df['time'].to_numpy(dtype=float)
    pop = run_df['population'].to_numpy(dtype=float)

    # ---- smoothing: centred rolling mean; edges average over the partial window ----
    if smoothing_window is None or smoothing_window <= 1:
        pop_smooth = pop
    else:
        pop_smooth = (pd.Series(pop)
                      .rolling(window=smoothing_window, center=True, min_periods=1)
                      .mean()
                      .to_numpy())

    # ---- derivative and metrics ----
    dpop_dt = compute_derivative(times, pop_smooth)
    if np.isnan(dpop_dt).all():
        raise ValueError("Growth rate is NaN for every sample; check 'time'/'population' for missing values")

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = float(trapezoid(pop, x=times))
    idx_max = int(np.nanargmax(dpop_dt))
    max_growth = float(dpop_dt[idx_max])
    time_of_max = float(times[idx_max])
    time_to_thresh = find_time_to_threshold(times, pop, threshold) if threshold is not None else None

    # Exponential-phase fit (semi-log): slope of log(pop) = growth rate per time unit
    log_fit = fit_log_linear(times, pop_smooth, low_frac=0.05, high_frac=0.5)
    exp_rate = float(log_fit[0]) if log_fit is not None else None

    # ---- plotting ----
    fig = plt.figure(constrained_layout=True, figsize=figsize)
    gs = fig.add_gridspec(2, 2)

    # Top-left: population vs time
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.plot(times, pop, alpha=0.6, label='raw', lw=1)
    ax0.plot(times, pop_smooth, lw=2, label=f'smooth (w={smoothing_window})')
    ax0.scatter(times[-1], pop[-1], color='k', zorder=5)
    ax0.annotate(f'final={pop[-1]:.2f}', xy=(times[-1], pop[-1]), xytext=(-60, 10),
                 textcoords='offset points', arrowprops=dict(arrowstyle='->', alpha=0.5))
    if threshold is not None:
        ax0.axhline(threshold, color='red', linestyle='--', lw=1)
        if time_to_thresh is not None:
            ax0.axvline(time_to_thresh, color='red', linestyle=':', lw=1)
            ax0.annotate(f't={time_to_thresh:.0f}', xy=(time_to_thresh, threshold),
                         xytext=(5, 5), textcoords='offset points', color='red')
    ax0.set_title('Population vs Time')
    ax0.set_xlabel('Time')
    ax0.set_ylabel('Population')
    ax0.grid(alpha=0.3)
    ax0.legend()

    # Top-right: instantaneous growth rate
    ax1 = fig.add_subplot(gs[0, 1])
    ax1.plot(times, dpop_dt, color='tab:blue', lw=1.5, label='dN/dt')
    ax1.axhline(0, color='gray', linestyle='--', lw=0.8)
    ax1.scatter(time_of_max, max_growth, color='tab:orange', zorder=5)
    ax1.annotate(f'max={max_growth:.2f}\nt={time_of_max:.0f}', xy=(time_of_max, max_growth),
                 xytext=(5, 10), textcoords='offset points',
                 bbox=dict(boxstyle='round,pad=0.3', fc='yellow', alpha=0.2),
                 arrowprops=dict(arrowstyle='->', alpha=0.5))
    ax1.set_title('Instantaneous Growth Rate')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('dN/dt')
    ax1.grid(alpha=0.3)

    # Bottom-left: phase plot (population vs growth rate)
    ax2 = fig.add_subplot(gs[1, 0])
    sc = ax2.scatter(pop_smooth, dpop_dt, c=times, cmap='viridis', s=8)
    ax2.set_xlabel('Population')
    ax2.set_ylabel('dN/dt')
    ax2.set_title('Phase plot: population vs growth rate (color=time)')
    fig.colorbar(sc, ax=ax2, label='time')
    ax2.grid(alpha=0.2)
    # Small arrow from the first sample to ~10% of the way in, showing temporal direction
    if len(times) >= 2:
        mid_idx = max(1, len(times) // 10)
        ax2.annotate('', xy=(pop_smooth[mid_idx], dpop_dt[mid_idx]), xytext=(pop_smooth[0], dpop_dt[0]),
                     arrowprops=dict(arrowstyle='->', color='black', alpha=0.4))

    # Bottom-right: histogram of growth rate distribution
    ax3 = fig.add_subplot(gs[1, 1])
    ax3.hist(dpop_dt, bins=30, alpha=0.85, color='tab:blue')
    ax3.set_title('Growth Rate Distribution')
    ax3.set_xlabel('dN/dt')
    ax3.set_ylabel('Frequency')
    ax3.grid(alpha=0.2)

    # Suptitle: categorical columns of the run being plotted (skip 'variant' and 'test').
    # Read from run_df, not any global dataframe, so the title matches this run.
    cat_cols = [c for c in run_df.columns
                if isinstance(run_df[c].dtype, pd.CategoricalDtype) and c not in ('variant', 'test')]
    cat_str = " | ".join(f"{col}={run_df[col].iloc[0]}" for col in cat_cols)
    fig.suptitle(f'Single-run — {cat_str}' if cat_str else 'Single-run', fontsize=14)

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Saved figure to {save_path}")

    plt.show()

    summary = {
        'AUC': auc,
        'max_growth_rate': max_growth,
        'time_of_max_growth': time_of_max,
        'time_to_threshold': time_to_thresh,
        'estimated_exp_rate': exp_rate
    }
    # Print a concise, right-aligned summary
    print("\nSummary:")
    for key, value in summary.items():
        print(f"{key:>22}: {value}")
    return summary

# Run the four-panel diagnostics on the loaded data (population threshold 1000,
# rolling window 10). If combined_df holds several tests, plot_single_run prints
# a warning and uses only the first one.
summary = plot_single_run(
    combined_df,
    threshold=1000,
    smoothing_window=10,
)