In [None]:
import matplotlib.pyplot as plt
import matplotlib
import scienceplots

import numpy as np
import pandas as pd

import os
import shutil

os.environ["PATH"] += os.pathsep + "/Library/TeX/texbin"
print("LaTeX found:", shutil.which('latex'))  # Should return a valid path

# Route all figure text through LaTeX (requires a working `latex` on PATH).
# plt.rcParams is the same RcParams object as matplotlib.rcParams.
plt.rcParams['text.usetex'] = True

# Style sheets applied left to right; later sheets override keys set by
# earlier ones, so keep this order.
STYLE_SHEETS = ['science', 'ieee', 'std-colors']
plt.style.use(STYLE_SHEETS)
# plt.rcParams.update({'figure.dpi': '800'})
# plt.rcParams.update(plt.rcParamsDefault)

In [None]:
import os
import json
import pandas as pd
from pathlib import Path

# Base directory
base_dir = Path('/Data/phi2FM_n_shot')

records = []

# Walk through modes, models, tasks
for mode in ['finetuning', 'lp']:
    mode_dir = base_dir / mode
    if not mode_dir.exists():
        continue
    for model_dir in mode_dir.iterdir():
        if not model_dir.is_dir():
            continue
        for task_dir in model_dir.iterdir():
            if not task_dir.is_dir():
                continue
            # nested task repetition
            nested = task_dir / task_dir.name
            if not nested.exists():
                continue
            for run_dir in nested.iterdir():
                if not run_dir.is_dir():
                    continue
                artifacts_path = run_dir / 'artifacts.json'
                if not artifacts_path.exists():
                    continue
                # parse parameters from path
                # folder name ends with _{shots}
                shots = run_dir.name.split('_')[-1]
                try:
                    n_shots = int(shots)
                except ValueError:
                    n_shots = None
                # load json
                data = json.load(open(artifacts_path))
                metrics = data.get('test_metrics', {})
                record = {
                    'mode': mode,
                    'model': model_dir.name,
                    'task': task_dir.name,
                    'n_shots': n_shots,
                    'precision_micro': metrics.get('precision_micro'),
                    'precision_macro': metrics.get('precision_macro'),
                    "recall_micro": metrics.get('recall_micro'),
                    "recall_macro": metrics.get("recall_macro"),
                    "f1_micro": metrics.get("f1_micro"), 
                    "f1_macro": metrics.get("f1_macro"),
                    'accuracy': metrics.get('acc')
                }
                # optionally include class-wise
                for i, p in enumerate(metrics.get('precision_per_class', [])):
                    record[f'precision_class_{i}'] = p
                records.append(record)

# Create DataFrame
df = pd.DataFrame(records)



In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import LogLocator, FormatStrFormatter, LogFormatterMathtext

# ─────────────────────────────────────────────────────────────────────────────
# 1.  Read & reshape the CSV
# ─────────────────────────────────────────────────────────────────────────────

# CSV of the earlier spreadsheet-based pipeline; only the commented-out
# loaders below read it — the live plot uses `df` built from artifacts.json.
FILENAME = "data_plot_v2.csv"

# Shot counts (training-set sizes) used as x-axis ticks on every panel.
FEW_SHOT = [50, 100, 500, 1000, 5000]

# Legacy CSV loader, kept for reference (disabled):
#df = pd.read_csv(FILENAME, dtype={"model": str, "task": str})
#df.columns = (
#    ["model", "task", "ignore"]
#    + [f"lp_{s}" for s in FEW_SHOT]
#    + [f"f_{s}" for s in FEW_SHOT]
#)

# Legend/display name -> model directory name as stored in df["model"].
_MODEL_DIR_PAIRS = (
    ("CaCO",        "phi2_caco"),
    ("DINO",        "phi2_dino"),
    ("GASSL",       "phi2_gassl"),
    ("GeoAware",    "phi2_GeoAware"),
    ("MoCo",        "phi2_moco"),
    ("PhisatNet",   "phisatnet_downstream"),
    ("Prithvi 1.0", "phi2_prithvi"),
    ("SatMAE",      "phi2_SatMAE"),
    ("SeCo",        "phi2_seasonal_contrast"),
    ("UniPhi",      "phi2_phileo_precursor"),
)
model_names = dict(_MODEL_DIR_PAIRS)

# Legacy reshaping of CSV rows into {model: {task: {"lp": [...], "f": [...]}}}
# (disabled):
# data = {}
# for _, row in df.iterrows():
#     m, t = row["model"], row["task"]
#     if pd.notna(m) and pd.notna(t):
#         data.setdefault(m, {})[t] = {
#             "lp": [row[f"lp_{s}"] if pd.notna(row[f"lp_{s}"]) else None for s in FEW_SHOT],
#             "f":  [row[f"f_{s}"]  if pd.notna(row[f"f_{s}"])  else None for s in FEW_SHOT],
#         }

# ─────────────────────────────────────────────────────────────────────────────
# 2.  Task meta-data
# ─────────────────────────────────────────────────────────────────────────────
# Active panels as (panel title, task name in df["task"], y-axis label).
# Commented entries are inactive tasks; uncomment one to plot it.
TASKS = [
    ("Fire Classification \n (Image-Level)",       "fire",        "Micro F1"),
    ("Burned Area Segmentation \n (Pixel-Level)",  "burned_area", "Micro F1"),
    #("Land Cover Segmentation \n (Pixel-Level)",     "lc",  "Micro F1"),
    #("Land Cover Classification \n (Image-Level)",   "lcc", "Micro F1"),
    #("Building Density Regression \n (Pixel-Level)", "blg", "MSE"),
    #("Road Density Regression \n (Pixel-Level)",     "rds", "MSE"),
]

# Backbone family -> member models (display names). Dict order is the order
# families (and their members) appear in plots and the legend.
GROUPS = {
    "U-Net":  ["PhisatNet", "GeoAware", "UniPhi"],
    "ResNet": ["MoCo", "DINO", "SeCo", "CaCO", "GASSL"],
    "ViT":    ["SatMAE", "Prithvi 1.0"],
}

# Flat model list: members of each family, families taken in GROUPS order.
MODEL_ORDER = [name for family in GROUPS for name in GROUPS[family]]

# ─────────────────────────────────────────────────────────────────────────────
# 3.  Explicit color map + markers + linestyles
# ─────────────────────────────────────────────────────────────────────────────
# Explicit line colour per model (keys must cover every model in GROUPS).
COLOR_MAP = {
    "PhisatNet":   "#800000",
    "GeoAware":    "#ff3300",
    "UniPhi":      "#ff9933",
    "MoCo":        "#33cc33",
    "DINO":        "#00720D",
    "SeCo":        "#00BFFF",
    "CaCO":        "#3030C1",
    "GASSL":       "#C26FFF",
    "Prithvi 1.0": "#000000",
    "SatMAE":      "#7C7C7C",
}

# One line style per backbone family, and a marker per model within a family.
LINESTYLES = {"U-Net": "-", "ResNet": "--", "ViT": ":"}
MARKERS    = {
    "U-Net":  ["o", "s", "^"],
    "ResNet": ["v", ">", "<", "D", "P"],
    "ViT":    ["X", "*"],
}

# Model drawn with emphasis (thicker line, larger markers, on top). It must
# match the name used in COLOR_MAP/GROUPS. The previous comparison against
# the old LaTeX label never matched, so the emphasis was silently lost.
HIGHLIGHT_MODEL = "PhisatNet"
assert HIGHLIGHT_MODEL in COLOR_MAP, f"unknown highlight model {HIGHLIGHT_MODEL!r}"

# model -> Line2D keyword arguments (color, marker, linestyle, lw, ms, zorder).
model_styles = {}
for group, models in GROUPS.items():
    group_markers = MARKERS[group]
    for idx, model in enumerate(models):
        highlighted = model == HIGHLIGHT_MODEL
        model_styles[model] = dict(
            color     = COLOR_MAP[model],
            # cycle so that adding a model to a family cannot raise IndexError
            marker    = group_markers[idx % len(group_markers)],
            linestyle = LINESTYLES[group],
            lw        = 3.0 if highlighted else 1.6,
            ms        = 6   if highlighted else 4,
            zorder    = 4   if highlighted else 2,
        )

# One column of panels per active task: row 0 = probing, row 1 = fine-tuning.
n_cols = len(TASKS)
if n_cols == 0:
    raise ValueError("TASKS is empty: nothing to plot")

# Panel letters in row-major order: the probing row gets a), b), ... and the
# fine-tuning row continues the sequence (no letters skipped).
labels = [f"{chr(ord('a') + i)})" for i in range(2 * n_cols)]

# Column of `df` plotted on every panel. NOTE(review): the regression tasks in
# TASKS (ylabel "MSE") would need their own metric column — TODO if re-enabled.
METRIC_COLUMN = "f1_micro"


def _plot_mode(ax, key, mode, model):
    """Plot one model's n-shot curve for a task/mode pair onto ``ax``.

    Args:
        ax: target matplotlib Axes.
        key: task name as stored in ``df["task"]``.
        mode: ``"lp"`` (probing row) or ``"finetuning"`` (fine-tuning row).
        model: display name; mapped to ``df["model"]`` via ``model_names``.

    Runs whose shot count could not be parsed (``n_shots`` missing) are
    dropped; nothing is drawn when no run matches.
    """
    runs = df[
        (df["task"] == key)
        & (df["mode"] == mode)
        & (df["model"] == model_names[model])
    ]
    runs = runs.dropna(subset=["n_shots"]).sort_values("n_shots")
    if runs.empty:
        return
    ax.plot(
        runs["n_shots"].values,
        runs[METRIC_COLUMN].values,
        label=model,
        **model_styles.get(model, {}),
    )


# ─────────────────────────────────────────────────────────────────────────────
# 4.  Build figure
# ─────────────────────────────────────────────────────────────────────────────
# squeeze=False keeps `axes` 2-D even when only one task is active.
fig, axes = plt.subplots(2, n_cols, figsize=(11, 8), dpi=350, squeeze=False)

for col, (title, key, ylabel) in enumerate(TASKS):
    ax_lp, ax_ft = axes[0, col], axes[1, col]

    # draw curves in fixed MODEL_ORDER
    for model in MODEL_ORDER:
        _plot_mode(ax_lp, key, "lp", model)
        _plot_mode(ax_ft, key, "finetuning", model)

    # common cosmetics
    for ax in (ax_lp, ax_ft):
        ax.set_xscale("log")
        ax.set_xticks(FEW_SHOT)
        ax.set_xticklabels([str(s) for s in FEW_SHOT], fontsize=9)
        ax.set_xlabel("n-shot", fontsize=9.5)
        ax.margins(x=0.05)
        ax.grid(True, which="major", lw=0.4, alpha=0.35)
        # y-label placed close to the axis
        ax.set_ylabel(ylabel, labelpad=2, fontsize=9)

    ax_lp.set_title(f"{labels[col]} {title} - Probing", fontsize=12)
    ax_ft.set_title(f"{labels[n_cols + col]} {title} - Fine-Tuning", fontsize=12)

    # enhanced log–log for regression
    if key in {"blg", "rds"}:
        for ax in (ax_lp, ax_ft):
            ax.set_yscale("log")
            ax.yaxis.set_major_locator(LogLocator(base=10))
            ax.yaxis.set_minor_locator(LogLocator(base=10, subs=(2,5)))
            ax.yaxis.set_major_formatter(LogFormatterMathtext())
            ax.grid(True, which="minor", lw=0.25, alpha=0.2)

# ─────────────────────────────────────────────────────────────────────────────
# 5.  Unified legend
# ─────────────────────────────────────────────────────────────────────────────
# Proxy handles so every model appears once, styled as in the panels.
handles = [mpl.lines.Line2D([], [], **model_styles[m], label=m)
           for m in MODEL_ORDER]
legend = fig.legend(
    handles=handles,
    labels=MODEL_ORDER,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.03),
    ncol=len(MODEL_ORDER),
    frameon=True,
    fontsize=9.5,
)
# Bold the emphasised model's entry. The comparison uses its current name; the
# old "$\Phi$satNet" label never matched and was an invalid escape sequence.
# NOTE(review): confirm set_weight takes effect while text.usetex is enabled.
for txt in legend.get_texts():
    if txt.get_text() == "PhisatNet":
        txt.set_weight("bold")

# ─────────────────────────────────────────────────────────────────────────────
# 6.  Save & show
# ─────────────────────────────────────────────────────────────────────────────
fig.tight_layout(rect=[0, 0.02, 1, 1])
fig.savefig("few_shot_experiments_grid.pdf", bbox_inches="tight")
fig.savefig("few_shot_experiments_grid.png", dpi=300, bbox_inches="tight")
plt.show()
