In [5]:
# file name: R0_ss_sensitive.ipynb

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# NaN-tolerant Pearson correlation (NumPy only)
def pearsonr_np(x, y):
    """Pearson correlation of two paired samples, ignoring pairs that contain NaN.

    Parameters
    ----------
    x, y : array-like
        Paired observations. Both are converted to float and flattened, so
        they must hold the same number of elements.

    Returns
    -------
    (r, n) : tuple of (float, int)
        ``r`` is the correlation coefficient clamped to [-1, 1] and ``n`` is
        the number of pairs left after NaN pairs are dropped. ``r`` is NaN
        when fewer than two pairs remain, when either sample is constant, or
        when the coefficient is not finite (e.g. infinite inputs or overflow).
    """
    x = np.asarray(x, float).ravel()
    y = np.asarray(y, float).ravel()
    # drop NaN pairs
    m = ~(np.isnan(x) | np.isnan(y))
    x, y = x[m], y[m]
    n = int(x.size)
    if n < 2:
        return np.nan, n
    # Infinite inputs or overflow make the arithmetic below produce NaN/inf;
    # that case is handled explicitly, so the floating-point warnings are noise.
    with np.errstate(invalid="ignore", over="ignore", divide="ignore"):
        x = x - x.mean()
        y = y - y.mean()
        sx = np.sqrt(np.dot(x, x))
        sy = np.sqrt(np.dot(y, y))
        if sx == 0.0 or sy == 0.0:
            # a constant sample has no defined correlation
            return np.nan, n
        r = float(np.dot(x, y) / (sx * sy))
    if not np.isfinite(r):
        # Python's min()/max() discard NaN, so clamping an undefined result
        # would silently report a perfect correlation of 1.0.
        return np.nan, n
    # clamp away rounding error that can push |r| marginally above 1
    return max(-1.0, min(1.0, r)), n

# --- group by (x1, x2) and correlate x3 vs y1 across the x3 repetitions ---
def corr_one_group(g):
    """Correlate the "Dimmunity" and "avg_time" columns of one group.

    ``g`` must support ``g[column].values`` for both columns (e.g. a
    DataFrame). Returns a ``pd.Series`` with the coefficient under ``"r"``
    and the number of pairs used under ``"n"``, as reported by ``pearsonr_np``.
    """
    immunity = g["Dimmunity"].values
    avg_time = g["avg_time"].values
    coeff, n_pairs = pearsonr_np(immunity, avg_time)
    return pd.Series({"r": coeff, "n": n_pairs})

In [6]:
# --- input CSV ---
# Loaded with pandas defaults; no dtype coercion or NaN filtering happens here.
df = pd.read_csv("../experimental_data/R0_sigma_Dimmunity_results_v2.csv")   # columns read by the code in this notebook: R0, sigma, Dimmunity, avg_time

In [24]:
def make_10_pngs_each_5x2(
    csv_path,
    out_dir="../../figures/R0_vs_avg_time_panels_by_sigma",
    A1_col="R0",
    A2_col="sigma",
    A3_col="Dimmunity",
    B1_col="avg_time",
    dpi=180
):
    """Save one multi-panel PNG of ``B1_col`` vs ``A1_col`` per ``A2_col`` value.

    The CSV is read, the four named columns are coerced to numeric, and rows
    where any of them is missing are dropped. For every distinct ``A2_col``
    value a figure with a 2-row x 5-column grid of shared-axis panels is
    drawn: one panel per ``A3_col`` value (the ten smallest when there are
    more), each showing ``B1_col`` against ``A1_col`` as a scatter plot with
    a connecting line. Combinations without rows show a "No data" label.

    Parameters
    ----------
    csv_path : path to the input CSV.
    out_dir : directory for the PNG files; created if it does not exist.
    A1_col : column plotted on the x axis.
    A2_col : column held fixed per figure (one PNG per distinct value).
    A3_col : column held fixed per panel.
    B1_col : column plotted on the y axis.
    dpi : resolution passed to ``Figure.savefig``.

    Files are named ``"<A2_col>_<value>.png"``, with the value formatted by
    ``:g`` and every "." before the extension replaced by "p". Because ``:g``
    keeps six significant digits, values that agree to that precision map to
    the same file name. The function returns ``None``.
    """
    max_panels = 10  # capacity of the 2 x 5 grid

    os.makedirs(out_dir, exist_ok=True)

    df = pd.read_csv(csv_path)

    # Coerce to numeric so that unparsable entries become NaN and are dropped.
    used_cols = [A1_col, A2_col, A3_col, B1_col]
    for c in used_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=used_cols)

    A2_vals = np.sort(df[A2_col].unique())
    A3_vals = np.sort(df[A3_col].unique())

    # Expect 10 each, but don't hard-fail if not
    print(f"Unique A2: {len(A2_vals)} values -> {A2_vals}")
    print(f"Unique A3: {len(A3_vals)} values -> {A3_vals}")
    if len(A3_vals) > max_panels:
        # Make the truncation visible instead of dropping values silently.
        print(
            f"Note: only the first {max_panels} of {len(A3_vals)} "
            f"{A3_col} values are plotted"
        )
    panel_vals = A3_vals[:max_panels]

    n_saved = 0
    for a2 in A2_vals:
        fig, axes = plt.subplots(2, 5, figsize=(15, 6), sharex=True, sharey=True)
        axes = axes.ravel()

        for k, a3 in enumerate(panel_vals):
            ax = axes[k]
            # a2/a3 come from the column's own unique values, so exact
            # equality is a safe way to select the matching rows.
            g = df[(df[A2_col] == a2) & (df[A3_col] == a3)].sort_values(A1_col)

            if g.empty:
                ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
            else:
                ax.scatter(g[A1_col], g[B1_col], s=20)
                ax.plot(g[A1_col], g[B1_col], linewidth=1)  # optional: connect points

            # Label the panel with the real column name, like the other labels.
            ax.set_title(f"{A3_col}={a3:g}", fontsize=10)

        # If there are fewer A3 values than panels, hide the unused axes
        for k in range(len(panel_vals), max_panels):
            axes[k].axis("off")

        fig.suptitle(f"{A1_col} vs {B1_col} (fixed {A2_col}={a2:g})", fontsize=14)

        fig.supxlabel(A1_col)
        fig.supylabel(B1_col)
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        # Explicit extension: the output format no longer depends on
        # rcParams["savefig.format"].
        fname = f"{A2_col}_{a2:g}".replace(".", "p") + ".png"
        fig.savefig(os.path.join(out_dir, fname), dpi=dpi)
        plt.close(fig)  # release the figure; many are created in this loop
        n_saved += 1

    print(f"Saved {n_saved} panel PNGs to: {out_dir}")
    

In [25]:
# Run with the function defaults: one figure per sigma value, one panel per
# Dimmunity value, avg_time plotted against R0.
make_10_pngs_each_5x2("../experimental_data/R0_sigma_Dimmunity_results_v2.csv")

Unique A2: 10 values -> [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
Unique A3: 10 values -> [0.05 0.1  0.15 0.2  0.25 0.3  0.35 0.4  0.45 0.5 ]
        R0  sigma  Dimmunity   avg_time  max_time  num_strains  \
0      1.0    0.1       0.05   1.500000       3.0          8.0   
100    2.0    0.1       0.05  14.636364      23.0         11.0   
200    3.0    0.1       0.05  21.100000      23.0         10.0   
300    4.0    0.1       0.05  19.785714      23.0         14.0   
400    5.0    0.1       0.05  19.615385      23.0         13.0   
500    6.0    0.1       0.05  21.470588      23.0         17.0   
600    7.0    0.1       0.05  19.176471      23.0         17.0   
700    8.0    0.1       0.05  21.312500      23.0         16.0   
800    9.0    0.1       0.05  22.500000      23.0         18.0   
900   10.0    0.1       0.05  21.000000      23.0         19.0   
1000  11.0    0.1       0.05  22.526316      23.0         19.0   
1100  12.0    0.1       0.05  21.684211      23.0         19.0   
