# CLIP LoRA Prompting Experiment Notebook

This notebook mirrors `experiments/clip_lora_prompting_experiment.py` while adding short notes so every logical block is easier to follow.

> **Tip:** The notebook resets `sys.argv` before running the main experiment so that the argument parser sees the same defaults you would get when executing the Python script from the command line. Adjust the arguments manually if you want to experiment with different settings.

In [None]:
import argparse
import csv
import json
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import torch
from peft import LoraConfig, LoraModel

# Set before any project module is imported so that it is already in the
# environment when those modules (and whatever they import) are loaded.
os.environ.setdefault("NUMBA_DISABLE_CACHING", "1")

# Locate the project root: two levels above this file when run as a script,
# or the current working directory when `__file__` is undefined (notebook).
try:
    NOTEBOOK_PATH = Path(__file__).resolve()
    PROJECT_ROOT = NOTEBOOK_PATH.parent.parent
except NameError:
    PROJECT_ROOT = Path.cwd().resolve()

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Project-local imports. They must stay below the sys.path bootstrap above:
# `experiments` and `ecogrow` are resolved relative to PROJECT_ROOT, so
# importing them earlier fails unless the root is already importable.
from experiments.utils import compute_init_ctx, export_family_embeddings
from ecogrow.benchmark.ecogrow_benchmark import EcogrowBenchmark
from ecogrow.models.open_clip_wrapper import init_open_clip, FamilyAdaptedClipDetector
from ecogrow.preprocessing.image_segmentator import (
    black_bg_composite,
    crop_to_alpha_bbox,
    segment_plant_rgba,
)
from ecogrow.training.prompt_learners import ClipPromptLearner
from ecogrow.training.trainers import ClipFineTuneEngine
from ecogrow.data.plant_data import PlantData, make_segment_fn, DISEASE_MAPPING
from ecogrow.models.checkpoint_cache import ensure_mobileclip_checkpoint

# Filled by run_clip_lora_prompting_experiment() so later cells can reuse the
# trained objects; None until the experiment has been run.
_last_run_context = None


## Configuration Handling

We keep the same dataclass and CLI parsing logic from the script so the training run behaves identically regardless of whether it is launched from a terminal or this notebook.

In [None]:
@dataclass(frozen=True)
class Config:
    """Immutable settings for one experiment run, as produced by `_parse_args`."""

    # Resolved dataset root; `_parse_args` verifies it is an existing directory.
    dataset_path: Path
    # Where exported text embeddings are written; None when the option is empty.
    embeddings_dir: Optional[Path]
    # Resolved path of the JSON prompts file (families -> classes -> prompts).
    prompts_config: Path
    # Identifier of the run; derived from the prompts file name when not given.
    run_id: str
    # Root directory under which results are saved.
    exp_dir: Path
    # Number of fine-tuning epochs.
    epochs: int
    # Batch size used for training and for the test DataLoader.
    batch_size: int
    # Fraction of the training data to hold out for validation (clamped to >= 0).
    perc_eval: float
    # Learning rate passed to the trainer.
    lr: float
    # Dropout rate handed to the detector as `feature_dropout` (clamped to >= 0).
    classifier_dropout: float


def _parse_args(argv: Optional[List[str]] = None) -> Config:
    """Parse the experiment's command-line options into a validated `Config`.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            which is the original behaviour. Passing an explicit list (for
            example ``[]`` for pure defaults) lets a notebook choose options
            without overwriting the process-wide ``sys.argv``.

    Returns:
        A frozen `Config` whose paths are user-expanded and resolved to
        absolute paths. ``perc_eval`` and ``classifier_dropout`` are clamped
        so they are never negative.

    Raises:
        FileNotFoundError: If the dataset directory or the prompts file does
            not exist.
    """
    parser = argparse.ArgumentParser(description="EcoGrow CLIP fine-tuning experiment")
    parser.add_argument(
        "--dataset-path",
        default="datasets",
        help="Percorso della directory del dataset (es. data/Indoor-Plant-disease-dataset-1)",
    )
    parser.add_argument(
        "--embeddings-dir",
        default="artifacts/embeddings",
        help="Directory dove salvare gli embedding testuali esportati (per l'inferenza).",
    )
    parser.add_argument(
        "--prompts-config",
        default="experiments/prompts.json",
        help="File JSON con la configurazione delle famiglie e relative classi",
    )
    parser.add_argument(
        "--run-id",
        default=None,
        help="Identificativo della run; se non specificato viene derivato dal file di prompt.",
    )
    parser.add_argument(
        "--exp-dir",
        default="experiments",
        help="Directory principale dove salvare i risultati.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Numero di epoche di fine-tuning per ciascuna famiglia.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size utilizzata durante il fine-tuning.",
    )
    parser.add_argument(
        "--perc-eval",
        type=float,
        default=0.2,
        help="Frazione del training da usare come validation (0 disabilita lo split).",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=5e-4,
        help="Learning rate per l'ottimizzatore AdamW.",
    )
    parser.add_argument(
        "--classifier-dropout",
        type=float,
        default=0.1,
        help="Dropout applicato prima del classificatore lineare.",
    )
    # argv=None makes argparse read sys.argv[1:], exactly as before.
    args = parser.parse_args(argv)

    # Validate every required input first, so a bad invocation fails before
    # any directory has been created on disk.
    dataset_path = Path(args.dataset_path).expanduser().resolve()
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset path '{dataset_path}' non esiste o non è una directory.")

    prompts_path = Path(args.prompts_config).expanduser().resolve()
    if not prompts_path.is_file():
        raise FileNotFoundError(f"Prompts config '{prompts_path}' non esiste.")

    # An empty --embeddings-dir leaves the value as None; the caller then
    # picks its own fallback location.
    embeddings_dir = None
    if args.embeddings_dir:
        embeddings_dir = Path(args.embeddings_dir).expanduser().resolve()
        embeddings_dir.mkdir(parents=True, exist_ok=True)

    exp_dir = Path(args.exp_dir).expanduser().resolve()
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Default run id is derived from the prompts file name (without suffix).
    run_id = args.run_id or f"clip_lora_finetuning_{prompts_path.stem}"

    return Config(
        dataset_path=dataset_path,
        embeddings_dir=embeddings_dir,
        prompts_config=prompts_path,
        run_id=run_id,
        exp_dir=exp_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        perc_eval=max(0.0, float(args.perc_eval)),
        lr=args.lr,
        classifier_dropout=max(0.0, float(args.classifier_dropout)),
    )

## Prompt Helpers

The following utilities manipulate the prompt configuration exactly as in the script so that downstream training uses the same class ordering.

In [None]:
def _canonicalize_label(label: str) -> str:
    """Map a raw class label onto its canonical disease name.

    Hyphens and spaces become underscores and the text is lower-cased. The
    first `DISEASE_MAPPING` alias found as a substring of that normalized
    text decides the canonical name; with no match the normalized text
    itself is returned.
    """
    key = label.translate(str.maketrans("- ", "__")).lower()
    matches = (
        canonical
        for alias, canonical in DISEASE_MAPPING.items()
        if alias in key
    )
    return next(matches, key)


def _collect_class_prompts(prompt_config: Dict[str, object]) -> Dict[str, List[str]]:
    """Extract prompt texts per disease class from the JSON config."""

    per_class: Dict[str, List[str]] = defaultdict(list)

    def _append(label: str, value) -> None:
        if not label:
            return
        canonical = _canonicalize_label(label)

        if isinstance(value, str):
            text = value.strip()
            if text:
                per_class[canonical].append(text)
        elif isinstance(value, (list, tuple, set)):
            for item in value:
                _append(label, item)
        elif isinstance(value, dict):
            for nested in value.values():
                _append(label, nested)

    for family_payload in prompt_config.values():
        if isinstance(family_payload, dict):
            for disease_label, entries in family_payload.items():
                _append(disease_label, entries)

    return per_class


def _default_prompt(label: str) -> str:
    pretty = label.replace("_", " ")
    return f"a close-up photo of a plant showing {pretty}"

## Run the Experiment

The function below is a literal translation of the script's `main` routine, with only notebook-friendly tweaks. Comments highlight why the steps happen.

In [None]:
def run_clip_lora_prompting_experiment() -> Dict[str, Dict[str, object]]:
    """Run the full CLIP + LoRA + prompt-learning experiment.

    Steps, in order: parse the configuration, load MobileCLIP-S2, wrap its
    visual tower with LoRA adapters, build the detector / prompt learner /
    trainer, train through `EcogrowBenchmark.run`, evaluate on the test split
    when it exists, then save the LoRA adapter, the exported text embeddings
    and a one-row `results.csv`.

    Returns:
        The mapping returned by `benchmark.run`, with `test_samples` and
        `test_metrics` added by this function.

    Side effects:
        Writes files under the benchmark run directory (and the embeddings
        directory), prints progress, and stores the trained objects in the
        module-level `_last_run_context` for later cells.

    Raises:
        ValueError: If the prompts config defines no family.
    """
    global _last_run_context
    # Reads sys.argv; see the cell that resets it before calling this function.
    config = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "MobileCLIP-S2"
    pretrained_tag = ensure_mobileclip_checkpoint(model_name=model_name)
    clip_model, preprocess, tokenizer, text_encoder = init_open_clip(
        model_name=model_name,
        pretrained_tag=pretrained_tag,
        device=device,
    )

    # Preferred module-name fragments to receive LoRA adapters.
    candidate_targets = [
        "token_mixer.qkv",
        "token_mixer.proj",
        "head.fc"
    ]

    # Keep only the candidates that occur as a substring of at least one
    # submodule name of the visual tower; otherwise try a generic fallback list.
    submodule_names = [name for name, _ in clip_model.visual.named_modules()]
    filtered_targets = [t for t in candidate_targets if any(t in n for n in submodule_names)]
    if not filtered_targets:
        fallback = ["attn.qkv", "attn.proj", "mlp.fc1", "mlp.fc2", "qkv", "proj"]
        filtered_targets = [t for t in fallback if any(t in n for n in submodule_names)]

    print(f"[LoRA] target_modules candidates matched: {filtered_targets if filtered_targets else 'NONE'}")

    # Freeze every original visual parameter before the tower is wrapped.
    base_visual = clip_model.visual
    for p in base_visual.parameters():
        p.requires_grad_(False)

    # If nothing matched at all, the original candidates are passed through
    # unchanged and it is left to LoraModel to accept or reject them.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=filtered_targets if filtered_targets else candidate_targets,
        bias="none",
        task_type="FEATURE_EXTRACTION",
    )

    # Replace the visual tower in place with its LoRA-wrapped version.
    clip_model.visual = LoraModel(base_visual, lora_config, adapter_name="default")

    # Diagnostics: wrapped layers, trainable LoRA parameters, and all trainable
    # parameters in the visual tower.
    num_adapters = sum(1 for m in clip_model.visual.modules() if hasattr(m, "lora_A") and hasattr(m, "lora_B"))
    num_lora_params = sum(p.numel() for n, p in clip_model.visual.named_parameters() if p.requires_grad and "lora_" in n)
    total_trainable_visual = sum(p.numel() for p in clip_model.visual.parameters() if p.requires_grad)
    print(f"[LoRA] adapters inserted: {num_adapters}")
    print(f"[LoRA] trainable LoRA params: {num_lora_params}")
    print(f"[LoRA] total trainable params in visual: {total_trainable_visual}")

    benchmark = EcogrowBenchmark(
        run_id=config.run_id,
        exp_dir=str(config.exp_dir),
        data_root=str(config.dataset_path),
    )

    # Image pre-processing callable built from the three segmentation helpers.
    segment_fn = make_segment_fn(
        segment_plant_rgba,
        crop_to_alpha_bbox,
        black_bg_composite,
        pad=12,
    )
    with open(config.prompts_config, "r", encoding="utf-8") as f:
        prompt_config = json.load(f)

    # Top-level keys of the prompts file, in file order, are the families.
    families = tuple(dict.fromkeys(prompt_config.keys()))
    if not families:
        raise ValueError("Prompts config must define at least one family.")

    # This dataset instance is only used to read the class list.
    preview_dataset = PlantData(
        dataset_root=config.dataset_path,
        families=families,
        split="train",
        segment_fn=segment_fn,
        transform=preprocess,
    )
    classnames = preview_dataset.classes

    # One prompt per class, aligned with `classnames`: the first configured
    # prompt when there is one, otherwise the generic default sentence.
    prompt_texts_map = _collect_class_prompts(prompt_config)
    class_prompt_texts: List[str] = []
    for cls in classnames:
        prompts_for_class = prompt_texts_map.get(cls)
        if prompts_for_class:
            class_prompt_texts.append(prompts_for_class[0])
        else:
            class_prompt_texts.append(_default_prompt(cls))

    # Move the model again now that the visual tower has been replaced.
    clip_model.to(device)

    family_detector = FamilyAdaptedClipDetector(
        name="global",
        classes=classnames,
        clip_model=clip_model,
        preprocess=preprocess,
        device=device,
        feature_dropout=config.classifier_dropout,
        train_backbone=True,
        text_encoder=text_encoder,
    )
    # Initial context vectors for the prompt learner, derived from the
    # per-class prompt texts.
    ctx_init = compute_init_ctx(
        n_ctx=16,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        class_prompts=class_prompt_texts,
    )

    prompt_learner = ClipPromptLearner(
        classnames=classnames,
        text_encoder=text_encoder,
        ctx_vectors=ctx_init,
        model_name=model_name,
    ).to(device)

    trainer = ClipFineTuneEngine(
        family_detector=family_detector,
        prompt_learner=prompt_learner
    )

    fit_args = {
        "epochs": config.epochs,
        "batch_size": config.batch_size,
        "lr": config.lr,
        "log_fn": lambda msg: print(f"[GLOBAL] {msg}"),
    }

    # NOTE(review): `config.perc_eval` is parsed from the CLI but is not
    # forwarded here; `perc_eval` is hard-coded to None. Confirm this is
    # intended (e.g. the dataset ships its own validation split).
    result = benchmark.run(
        trainer=trainer,
        segment_fn=segment_fn,
        families=families,
        perc_eval=None,
        fit_predictor_args=fit_args,
    )

    # Test evaluation is best-effort: a missing test split only yields a warning.
    test_metrics = None
    result["test_samples"] = 0
    try:
        test_dataset = PlantData(
            dataset_root=config.dataset_path,
            families=families,
            split="test",
            segment_fn=segment_fn,
            transform=preprocess,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
        )
        result["test_samples"] = len(test_dataset)
        test_epoch = trainer.eval(test_loader)
        test_metrics = {"loss": test_epoch.loss, "f1": test_epoch.f1}
        result["test_metrics"] = test_metrics
    except FileNotFoundError as exc:
        print(f"[WARN] Test split unavailable: {exc}")
        result["test_metrics"] = None

    # Saving the adapter is also best-effort: any failure is printed, not raised.
    # NOTE(review): the visual tower is a bare LoraModel rather than a
    # PeftModel; confirm `save_pretrained` is actually available on it,
    # otherwise this block silently degrades to the warning below.
    try:
        lora_dir = Path(benchmark.run_dir) / "lora"
        lora_dir.mkdir(parents=True, exist_ok=True)
        clip_model.visual.save_pretrained(lora_dir)
        print(f"[LoRA] adapter saved to {lora_dir}")
    except Exception as e:
        print(f"[LoRA][WARN] failed to save adapter: {e}")

    # Fall back to a folder inside the run directory when no embeddings
    # directory was configured.
    embeddings_dir = (
        config.embeddings_dir
        if config.embeddings_dir is not None
        else Path(benchmark.run_dir) / "embeddings"
    )
    embeddings_dir.mkdir(parents=True, exist_ok=True)
    embedding_file = embeddings_dir / f"{config.run_id}_{family_detector.name}.pt"
    export_family_embeddings(
        embedding_file,
        family_detector.name,
        classnames,
        prompt_learner,
        text_encoder,
        temperature=family_detector.temperature,
    )
    print(f"[PROMPTS] embeddings exported to {embedding_file}")

    # Flatten the metrics into a single summary row; missing metrics become None.
    eval_metrics = result.get("eval_metrics")
    test_metrics = result.get("test_metrics")
    summary_row = {
        "family_id": "global",
        "train_samples": result["train_samples"],
        "eval_samples": result["eval_samples"],
        "test_samples": result["test_samples"],
        "eval_loss": eval_metrics["loss"] if eval_metrics else None,
        "eval_f1": eval_metrics["f1"] if eval_metrics else None,
        "test_loss": test_metrics["loss"] if test_metrics else None,
        "test_f1": test_metrics["f1"] if test_metrics else None,
        "temperature": result.get("temperature"),
    }

    # The CSV is rewritten from scratch on every run (header plus one row).
    csv_path = Path(benchmark.run_dir) / "results.csv"
    fieldnames = list(summary_row.keys())
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerow(summary_row)
    print(f"Results saved to {csv_path}")

    # Keep the trained objects around so follow-up cells (e.g. the confusion
    # matrix) can reuse them without re-running the experiment.
    _last_run_context = {
        "config": config,
        "families": families,
        "segment_fn": segment_fn,
        "preprocess": preprocess,
        "family_detector": family_detector,
        "prompt_learner": prompt_learner,
        "classnames": classnames,
    }
    return result

## Execute with Default Arguments

Reset `sys.argv` (so argparse ignores notebook-specific flags) and launch the experiment. Edit the list if you want to pass custom CLI options.

In [None]:
# Argument list seen by argparse: program name only, i.e. all defaults.
# Append CLI options to this list to run with custom settings.
notebook_args = ['clip_lora_prompting_experiment.py']

# Anything after the program name did not come from this notebook, so drop it
# before the argument parser reads sys.argv.
if sys.argv[1:]:
    sys.argv = notebook_args

# Launch the run; the bare name on the last line displays the result mapping.
result = run_clip_lora_prompting_experiment()
result

## Confusion Matrix

After running the experiment, you can visualize how predictions are distributed across classes. Change `split_to_plot` to `"test"` if you want to inspect the test split instead.

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix_for_split(split="val"):
    """Plot and return the confusion matrix of the last run on one dataset split.

    The dataset for `split` is rebuilt from the objects stored in
    `_last_run_context`, every batch is scored with the trained detector, and
    the counts are accumulated with rows as true classes and columns as
    predicted classes.

    Args:
        split: Name of the dataset split to load (for example "val" or "test").

    Returns:
        An int64 tensor of shape (num_classes, num_classes) holding the counts.

    Raises:
        RuntimeError: If the experiment has not been run yet in this session.
        ValueError: If the requested split contains no samples.
    """
    ctx = _last_run_context
    if not ctx:
        raise RuntimeError("Run the experiment cell first so the context is available.")

    run_config = ctx["config"]
    split_data = PlantData(
        dataset_root=run_config.dataset_path,
        families=ctx["families"],
        split=split,
        segment_fn=ctx["segment_fn"],
        transform=ctx["preprocess"],
    )
    if len(split_data) == 0:
        raise ValueError(f"Split '{split}' has no samples.")

    batches = torch.utils.data.DataLoader(
        split_data,
        batch_size=run_config.batch_size,
        shuffle=False,
    )

    detector = ctx["family_detector"]
    learner = ctx["prompt_learner"]

    # The learned prompts do not depend on the images, so compute them once.
    prompts_embeds, tokenized_prompts = None, None
    if learner is not None:
        learner.eval()
        with torch.no_grad():
            prompts_embeds, tokenized_prompts = learner()

    class_labels = ctx["classnames"]
    n_cls = len(class_labels)
    n_cells = n_cls * n_cls
    cm = torch.zeros((n_cls, n_cls), dtype=torch.int64)
    detector.classifier.eval()

    for images, targets in batches:
        images = images.to(detector.device)
        targets = targets.to(detector.device)
        with torch.no_grad():
            scores = detector.logits(
                images,
                prompts_embeds=prompts_embeds,
                tokenized_prompts=tokenized_prompts,
            )
        predicted = scores.argmax(dim=-1)
        # Encode each (true, predicted) pair as one flat cell index so a
        # single bincount yields the whole batch's contribution.
        flat_cell = targets.view(-1) * n_cls + predicted.view(-1)
        counts = torch.bincount(flat_cell.to(torch.long).cpu(), minlength=n_cells)
        cm += counts.view(n_cls, n_cls)

    cm_np = cm.numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm_np, interpolation="nearest", cmap="Blues")
    ax.set_title(f"Confusion matrix ({split} split)")
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")
    tick_positions = np.arange(n_cls)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(class_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(class_labels)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Write each count into its cell; switch to white text on the darker
    # cells (above half of the largest count) so the numbers stay readable.
    max_val = cm_np.max() if cm_np.size else 1
    half_max = max_val * 0.5
    for (row, col), cell in np.ndenumerate(cm_np):
        count = int(cell)
        text_color = "white" if count > half_max else "black"
        ax.text(col, row, count, ha="center", va="center", color=text_color, fontsize=8)

    fig.tight_layout()
    plt.show()
    return cm


# Pick the split, draw its confusion matrix, and display the raw count tensor
# (the bare name on the last line is what the notebook renders).
split_to_plot = "val"  # change to "test" to inspect the held-out split
confusion_matrix = plot_confusion_matrix_for_split(split=split_to_plot)
confusion_matrix