# Gecho: Gemma Echo -- Automated Echocardiogram Reporting

A RAG system that extracts key frames from echocardiogram videos, retrieves similar known cases via **MedSigLIP** embeddings, and generates clinical reports with **MedGemma**.

**Models**: MedSigLIP-448 (retrieval) + MedGemma 1.5-4B-IT (generation)

**Dataset**: [EchoNet-Dynamic](https://echonet.github.io/dynamic/) (Stanford)

In [None]:
# Environment setup -- must run before any Keras imports
import os
os.environ["KERAS_BACKEND"] = "jax"

!pip install -q transformers keras keras-hub jax[cuda12] opencv-python faiss-cpu gradio Pillow huggingface_hub

## `config`

In [None]:
"""Configuration for Gecho pipeline."""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path

def _hf_login() -> None:
    """Log in to the HuggingFace Hub using a token from the env or Kaggle.

    Token lookup order:
      1. The ``HF_TOKEN`` environment variable.
      2. A Kaggle Secret named ``HF_TOKEN`` (only consulted when the
         environment variable is not set at all).

    Every failure path prints a warning instead of raising, so callers
    always continue. Nothing is returned.
    """
    hf_token = os.environ.get("HF_TOKEN")

    if hf_token is None:
        # Best-effort lookup through the Kaggle Secrets API (available inside
        # Kaggle notebooks); any failure just leaves the token unset.
        try:
            from kaggle_secrets import UserSecretsClient
            hf_token = UserSecretsClient().get_secret("HF_TOKEN")
        except Exception:
            pass

    # Guard clause: without a usable token there is nothing to log in with.
    if not hf_token:
        print(
            "[WARN] No HF_TOKEN found. Gated models will fail unless "
            "attached locally via Kaggle 'Add Model'."
        )
        return

    try:
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)
        print("Authenticated with HuggingFace.")
    except Exception as err:
        print(f"[WARN] HF login failed: {err}")

def _find_local_model(
    model_name: str,
    kaggle_input: str | Path = "/kaggle/input",
) -> str | None:
    """Search a Kaggle input directory for a locally-attached model.

    Kaggle's "Add Model" feature places models under /kaggle/input/.
    Users typically attach them with slugs like 'medsiglip-448' or
    'medgemma-1-5-4b-it'.  The search term is the part of the model
    identifier after the last "/" and before the first "-", lower-cased:

        "google/medsiglip-448"      -> "medsiglip"
        "google/medgemma-1.5-4b-it" -> "medgemma"

    Args:
        model_name: Hub-style model identifier to look for.
        kaggle_input: Directory whose immediate sub-directories are
            scanned. Defaults to the standard Kaggle mount point.

    Returns:
        The path (as ``str``) of the first directory beneath a matching
        entry that contains ``config.json`` or ``tokenizer.json``; if a
        matching entry has neither anywhere below it, the entry itself.
        ``None`` when the root is not a directory, the search term is
        empty, or nothing matches.
    """
    root = Path(kaggle_input)
    if not root.is_dir():
        return None

    search_term = model_name.split("/")[-1].split("-")[0].lower()
    if not search_term:
        # An empty term is a substring of every name and would match the
        # first directory regardless of what it contains.
        return None

    for entry in sorted(root.iterdir()):
        if not entry.is_dir() or search_term not in entry.name.lower():
            continue

        # Kaggle model dirs can be nested: slug/framework/variant/version.
        # Walk down to find a directory that actually holds model files.
        for current, subdirs, files in os.walk(entry):
            # Sorting in place fixes the traversal order, so the same tree
            # always resolves to the same path.
            subdirs.sort()
            if "config.json" in files or "tokenizer.json" in files:
                local_path = str(Path(current))
                print(f"Found local model for '{model_name}': {local_path}")
                return local_path

        # No config.json/tokenizer.json found: return the top-level match.
        local_path = str(entry)
        print(f"Found local model dir for '{model_name}': {local_path}")
        return local_path

    return None

@dataclass
class GechoConfig:
    """Central configuration for the Gecho pipeline.

    Constructing an instance has side effects (see ``__post_init__``): it
    attempts a HuggingFace login, creates ``output_dir`` and may replace
    the HuggingFace model IDs with paths of locally-attached Kaggle models.
    """

    # --- Dataset paths (Kaggle defaults) ---
    dataset_root: Path = Path("/kaggle/input/datasets/syxlicheng/heartdatabase/EchoNet-Dynamic")
    output_dir: Path = Path("/kaggle/working/gecho_output")

    # --- Model identifiers (HuggingFace Hub IDs) ---
    medsiglip_model_id: str = "google/medsiglip-448"
    medgemma_kerashub_preset: str = "medgemma_1.5_instruct_4b"
    medgemma_hf_model_id: str = "google/medgemma-1.5-4b-it"

    # --- Frame processing ---
    siglip_frame_size: int = 448
    medgemma_frame_size: int = 896

    # --- FAISS / retrieval ---
    faiss_top_k: int = 5
    embedding_dim: int = 768  # verified at runtime from first embedding

    # --- Generation parameters ---
    max_new_tokens: int = 1024
    sequence_length: int = 4096
    dtype: str = "bfloat16"

    # --- Derived paths (set in __post_init__) ---
    videos_dir: Path = field(init=False)
    file_list_path: Path = field(init=False)
    volume_tracings_path: Path = field(init=False)
    faiss_index_path: Path = field(init=False)
    faiss_metadata_path: Path = field(init=False)

    def __post_init__(self) -> None:
        """Derive dependent paths, create the output dir, resolve models."""
        # Callers may pass plain strings; the "/" joins below need Path
        # objects, so normalise the two user-supplied locations first.
        self.dataset_root = Path(self.dataset_root)
        self.output_dir = Path(self.output_dir)

        # Authenticate with HuggingFace for gated models
        _hf_login()

        self.videos_dir = self.dataset_root / "Videos"
        self.file_list_path = self.dataset_root / "FileList.csv"
        self.volume_tracings_path = self.dataset_root / "VolumeTracings.csv"
        self.faiss_index_path = self.output_dir / "gecho_faiss.index"
        self.faiss_metadata_path = self.output_dir / "gecho_faiss_meta.pkl"

        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Resolve model paths: prefer locally-attached Kaggle models,
        # fall back to HuggingFace Hub IDs (requires authentication).
        self.medsiglip_model_id = (
            _find_local_model(self.medsiglip_model_id)
            or self.medsiglip_model_id
        )
        self.medgemma_hf_model_id = (
            _find_local_model(self.medgemma_hf_model_id)
            or self.medgemma_hf_model_id
        )

## `video_processor`

In [None]:
"""Video processing pipeline for echocardiogram frame extraction."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
import pandas as pd

# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

@dataclass
class EchoFrames:
    """Extracted frames and metadata for a single echo video.

    Holds the end-diastole (ED) and end-systole (ES) frames of one video
    together with the frame indices they were read from and, when
    available, the labels taken from the dataset's file list.
    """
    filename: str               # Video file name this record was built from
    ed_frame: np.ndarray        # RGB uint8, resized to siglip_frame_size
    es_frame: np.ndarray        # RGB uint8, resized to siglip_frame_size
    ed_frame_idx: int           # Index of the ED frame within the video
    es_frame_idx: int           # Index of the ES frame within the video
    ef: float | None = None     # Ejection fraction (ground truth)
    esv: float | None = None    # End-systolic volume
    edv: float | None = None    # End-diastolic volume
    ef_category: str | None = None  # Label from classify_ef(ef); None without EF

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def classify_ef(ef: float) -> str:
    """Map an ejection fraction (percent) to its clinical category.

    Lower bounds are inclusive: >= 55 is "Normal", >= 45 "Mild Dysfunction",
    >= 30 "Moderate Dysfunction"; anything else is "Severe Dysfunction".
    """
    # Ordered from the highest lower bound down; the first one met wins.
    lower_bounds = (
        (55, "Normal"),
        (45, "Mild Dysfunction"),
        (30, "Moderate Dysfunction"),
    )
    for bound, category in lower_bounds:
        if ef >= bound:
            return category
    return "Severe Dysfunction"

def load_file_list(config: GechoConfig, split: str = "TRAIN") -> pd.DataFrame:
    """Read EchoNet's FileList.csv and keep only the rows of one split.

    Standard columns: FileName, EF, ESV, EDV, FrameHeight, FrameWidth,
                      FPS, NumberOfFrames, Split
    Some versions also include EDFrame, ESFrame directly.

    The split name is matched case-insensitively and the returned frame
    is re-indexed from 0.
    """
    table = pd.read_csv(config.file_list_path)
    wanted = split.upper()
    in_split = table["Split"].str.upper() == wanted
    return table[in_split].reset_index(drop=True)

def load_frame_indices(config: GechoConfig) -> dict[str, tuple[int, int]]:
    """Derive ED and ES frame indices from VolumeTracings.csv.

    VolumeTracings.csv has one row per tracing line with columns:
      FileName, X1, Y1, X2, Y2, Frame

    For every video the traced frames are ranked by a simple cavity-size
    proxy: the sum of ``|X2 - X1|`` over that frame's tracing rows.  The
    frame with the largest sum is taken as ED (most dilated) and the one
    with the smallest sum as ES (most contracted).  A video with a single
    traced frame maps to that frame for both ED and ES.

    Returns:
        Dict mapping FileName -> (ed_frame_idx, es_frame_idx), both plain
        Python ints.  Empty when the CSV file does not exist.
    """
    if not config.volume_tracings_path.exists():
        return {}

    vt = pd.read_csv(config.volume_tracings_path)
    frame_indices: dict[str, tuple[int, int]] = {}

    for fname, group in vt.groupby("FileName"):
        frames = sorted(group["Frame"].unique())
        if len(frames) < 2:
            # Fallback: only one traced frame.  Cast so both branches
            # return plain ints rather than numpy scalars.
            only = int(frames[0])
            frame_indices[str(fname)] = (only, only)
            continue

        # Width proxy per traced frame: total horizontal extent of the
        # tracing lines (this is not a true enclosed-area computation).
        widths: dict[int, float] = {
            int(frame_num): float((fgroup["X2"] - fgroup["X1"]).abs().sum())
            for frame_num, fgroup in group.groupby("Frame")
        }

        # ED = largest cavity (most dilated), ES = smallest (most contracted)
        ranked = sorted(widths, key=widths.__getitem__, reverse=True)
        frame_indices[str(fname)] = (ranked[0], ranked[-1])

    return frame_indices

def extract_frame(video_path: str | Path, frame_idx: int) -> np.ndarray:
    """Read the frame at ``frame_idx`` from a video and return it as RGB.

    Raises:
        RuntimeError: if the frame cannot be read.
    """
    capture = cv2.VideoCapture(str(video_path))
    try:
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, bgr = capture.read()
    finally:
        # The handle is no longer needed once the frame has been grabbed.
        capture.release()

    if not ok:
        raise RuntimeError(
            f"Could not read frame {frame_idx} from {video_path}"
        )
    # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def resize_frame(frame: np.ndarray, size: int) -> np.ndarray:
    """Return ``frame`` scaled to a ``size`` x ``size`` square (bilinear)."""
    target_shape = (size, size)
    return cv2.resize(frame, target_shape, interpolation=cv2.INTER_LINEAR)

# ---------------------------------------------------------------------------
# Main extraction
# ---------------------------------------------------------------------------

def _optional_float(row: pd.Series, column: str) -> float | None:
    """Return ``row[column]`` as a float, or None when absent or NaN."""
    if column not in row.index:
        return None
    value = row[column]
    return float(value) if pd.notna(value) else None


def extract_echo_frames(
    video_path: str | Path,
    row: pd.Series,
    config: GechoConfig,
    ed_idx: int | None = None,
    es_idx: int | None = None,
) -> EchoFrames:
    """Extract ED + ES frames for one video.

    Frame indices can come from:
      1. Explicit ed_idx/es_idx arguments (from VolumeTracings)
      2. row["EDFrame"] / row["ESFrame"] columns (some CSV versions)
      3. Heuristic fallback (frame 0 and frame at ~33%)

    Args:
        video_path: Path of the video file to read.
        row: FileList row for this video; must contain "FileName" and may
            contain EF, ESV, EDV, EDFrame and ESFrame.
        config: Pipeline configuration (supplies ``siglip_frame_size``).
        ed_idx: End-diastole frame index, or None to resolve it.
        es_idx: End-systole frame index, or None to resolve it.

    Returns:
        EchoFrames with both frames resized to ``siglip_frame_size``.
        EF, ESV and EDV are None when the column is missing or NaN, and
        ``ef_category`` is None whenever EF is None.

    Raises:
        RuntimeError: propagated from ``extract_frame`` when a frame
            cannot be read.
    """
    # Resolve frame indices
    if ed_idx is None:
        ed_idx = int(row["EDFrame"]) if "EDFrame" in row.index else 0

    if es_idx is None:
        if "ESFrame" in row.index:
            es_idx = int(row["ESFrame"])
        else:
            # Fallback: read total frame count and pick ~33%
            cap = cv2.VideoCapture(str(video_path))
            try:
                total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                cap.release()
            es_idx = max(1, int(total * 0.33))

    ed_raw = extract_frame(video_path, ed_idx)
    es_raw = extract_frame(video_path, es_idx)

    ed = resize_frame(ed_raw, config.siglip_frame_size)
    es = resize_frame(es_raw, config.siglip_frame_size)

    # A NaN EF must become None: passed on as a float it would fail every
    # threshold in classify_ef and be labelled "Severe Dysfunction".
    ef = _optional_float(row, "EF")

    return EchoFrames(
        filename=row["FileName"],
        ed_frame=ed,
        es_frame=es,
        ed_frame_idx=ed_idx,
        es_frame_idx=es_idx,
        ef=ef,
        esv=_optional_float(row, "ESV"),
        edv=_optional_float(row, "EDV"),
        ef_category=classify_ef(ef) if ef is not None else None,
    )

def extract_frames_from_upload(
    video_path: str | Path,
    config: GechoConfig,
) -> EchoFrames:
    """Pick ED/ES frames heuristically from a user-uploaded video.

    Without CSV metadata we use:
      - ED = frame 0  (heart typically most dilated at start of clip)
      - ES = frame at ~33% of total frames

    Raises:
        RuntimeError: if the video reports fewer than two frames or either
            frame cannot be read.
    """

    def _grab_rgb(capture, index: int, label: str) -> np.ndarray:
        # Seek to ``index`` on the shared handle, read it, convert to RGB.
        capture.set(cv2.CAP_PROP_POS_FRAMES, index)
        ok, bgr = capture.read()
        if not ok:
            raise RuntimeError(f"Cannot read {label} frame from {video_path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    capture = cv2.VideoCapture(str(video_path))
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count < 2:
            raise RuntimeError(f"Video too short ({frame_count} frames): {video_path}")

        ed_idx = 0
        es_idx = max(1, int(frame_count * 0.33))

        ed_rgb = _grab_rgb(capture, ed_idx, "ED")
        es_rgb = _grab_rgb(capture, es_idx, "ES")
    finally:
        capture.release()

    target = config.siglip_frame_size
    return EchoFrames(
        filename=Path(video_path).name,
        ed_frame=resize_frame(ed_rgb, target),
        es_frame=resize_frame(es_rgb, target),
        ed_frame_idx=ed_idx,
        es_frame_idx=es_idx,
    )

# ---------------------------------------------------------------------------
# Batch processing
# ---------------------------------------------------------------------------

def process_dataset(
    config: GechoConfig,
    split: str = "TRAIN",
    max_videos: int | None = None,
) -> list[EchoFrames]:
    """Extract ED/ES frames for every video of one EchoNet split.

    Frame indices come from the EDFrame/ESFrame columns of FileList.csv
    when both exist, otherwise from VolumeTracings.csv; videos without
    either fall back to the heuristic inside ``extract_echo_frames``.
    Missing videos and per-video failures are reported and skipped.

    Args:
        config: Pipeline configuration (dataset paths, frame size).
        split: Split name to process.
        max_videos: If given, only the first ``max_videos`` rows are used.

    Returns:
        One EchoFrames per successfully processed video.
    """
    table = load_file_list(config, split)
    if max_videos is not None:
        table = table.head(max_videos)

    # The tracing map stays empty when the file list carries its own
    # index columns, so the per-row lookup below then never matches.
    csv_has_indices = {"EDFrame", "ESFrame"}.issubset(table.columns)
    tracing_map: dict[str, tuple[int, int]] = {}
    if csv_has_indices:
        print("Using EDFrame/ESFrame columns from FileList.csv")
    else:
        print("EDFrame/ESFrame not in FileList.csv, loading VolumeTracings.csv ...")
        tracing_map = load_frame_indices(config)
        if tracing_map:
            print(f"Loaded frame indices for {len(tracing_map)} videos from VolumeTracings.csv")
        else:
            print("[WARN] No VolumeTracings.csv found, using heuristic frame selection")

    extracted: list[EchoFrames] = []
    for _, row in table.iterrows():
        fname = row["FileName"]
        # EchoNet filenames may or may not have .avi extension
        video_name = fname if fname.endswith(".avi") else f"{fname}.avi"
        video_path = config.videos_dir / video_name

        if not video_path.exists():
            print(f"[WARN] Video not found, skipping: {video_path}")
            continue

        # VolumeTracings keys may or may not have .avi, so try both forms.
        ed_idx = es_idx = None
        traced = tracing_map.get(fname) or tracing_map.get(video_name)
        if traced:
            ed_idx, es_idx = traced

        try:
            extracted.append(
                extract_echo_frames(
                    video_path, row, config, ed_idx=ed_idx, es_idx=es_idx
                )
            )
        except Exception as e:
            print(f"[WARN] Failed to process {fname}: {e}")

    print(f"Processed {len(extracted)}/{len(table)} videos from {split} split.")
    return extracted

## `embedding_engine`

In [None]:
"""MedSigLIP embedding engine with FAISS retrieval."""

from __future__ import annotations

import pickle
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path

import faiss
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

@dataclass
class RetrievedCase:
    """A single retrieved case from the FAISS index.

    Built by ``EmbeddingEngine.query`` from the stored metadata of one
    matching index vector plus the score FAISS returned for it.
    """
    filename: str                 # Video the matched vector came from
    ef: float | None              # Ejection fraction stored for that video
    ef_category: str | None       # Category label stored for that video
    similarity_score: float       # Inner-product score returned by the search
    frame_type: str  # "ED" or "ES"

@dataclass
class RetrievalResult:
    """Aggregated result from a FAISS query.

    The statistics are computed by ``EmbeddingEngine.query`` over the
    retrieved cases and stay None when no case carries the needed value.
    """
    cases: list[RetrievedCase]            # Retrieved cases, best match first
    mean_ef: float | None = None          # Mean EF of cases that have an EF
    ef_std: float | None = None           # Standard deviation of those EFs
    consensus_category: str | None = None  # Most common ef_category among cases

# ---------------------------------------------------------------------------
# Index metadata
# ---------------------------------------------------------------------------

@dataclass
class IndexEntry:
    """Metadata stored alongside each FAISS vector.

    ``EmbeddingEngine`` keeps one entry per indexed vector, in insertion
    order, so a search result's position indexes straight into the list.
    """
    filename: str                    # Source video of the embedded frame
    frame_type: str  # "ED" or "ES"
    ef: float | None = None          # Ejection fraction of the source video
    ef_category: str | None = None   # Category label of the source video

# ---------------------------------------------------------------------------
# Embedding Engine
# ---------------------------------------------------------------------------

class EmbeddingEngine:
    """MedSigLIP-based embedding engine with FAISS retrieval.

    Loads the MedSigLIP model and processor on construction, turns images
    and text into L2-normalized float32 embeddings, and maintains a FAISS
    inner-product index plus a parallel list of ``IndexEntry`` metadata.
    """

    # (short display label, text prompt) pairs used by zero_shot_classify
    # when the caller does not supply labels.
    _DEFAULT_ZERO_SHOT = (
        ("Normal", "echocardiogram showing normal cardiac function"),
        ("Mild Dysfunction",
         "echocardiogram showing mild left ventricular dysfunction"),
        ("Moderate Dysfunction",
         "echocardiogram showing moderate left ventricular dysfunction"),
        ("Severe Dysfunction",
         "echocardiogram showing severe left ventricular dysfunction"),
    )

    def __init__(self, config: GechoConfig) -> None:
        """Load MedSigLIP onto the GPU when available, else the CPU."""
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading MedSigLIP from {config.medsiglip_model_id} ...")
        self.processor = AutoProcessor.from_pretrained(config.medsiglip_model_id)
        self.model = AutoModel.from_pretrained(
            config.medsiglip_model_id,
            torch_dtype=torch.float16,
        ).to(self.device).eval()
        print(f"MedSigLIP loaded on {self.device}.")

        self.index: faiss.IndexFlatIP | None = None
        self.metadata: list[IndexEntry] = []

    # --- Encoding ---------------------------------------------------------

    @staticmethod
    def _l2_normalize(emb: torch.Tensor) -> torch.Tensor:
        """Scale each row of ``emb`` to unit L2 norm, computed in float32."""
        # The model runs in float16; normalising after the upcast keeps the
        # stored float32 vectors as close to unit length as possible.
        emb = emb.float()
        return emb / emb.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def encode_image(self, image: np.ndarray | Image.Image) -> np.ndarray:
        """Encode a single image to an L2-normalized embedding vector."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        emb = self._l2_normalize(self.model.get_image_features(**inputs))
        return emb.cpu().numpy().astype(np.float32).squeeze()

    @torch.no_grad()
    def encode_batch(
        self, images: list[np.ndarray | Image.Image], batch_size: int = 32
    ) -> np.ndarray:
        """Encode a batch of images to L2-normalized embeddings.

        Returns:
            float32 array with one row per image.  An empty input yields a
            ``(0, config.embedding_dim)`` array instead of raising.
        """
        if not images:
            return np.empty((0, self.config.embedding_dim), dtype=np.float32)

        pil_images = [
            Image.fromarray(img) if isinstance(img, np.ndarray) else img
            for img in images
        ]

        all_embs: list[np.ndarray] = []
        for start in range(0, len(pil_images), batch_size):
            batch = pil_images[start : start + batch_size]
            inputs = self.processor(images=batch, return_tensors="pt").to(self.device)
            emb = self._l2_normalize(self.model.get_image_features(**inputs))
            all_embs.append(emb.cpu().numpy().astype(np.float32))

        return np.concatenate(all_embs, axis=0)

    @torch.no_grad()
    def encode_text(self, text: str) -> np.ndarray:
        """Encode text to an L2-normalized embedding vector."""
        inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(
            self.device
        )
        emb = self._l2_normalize(self.model.get_text_features(**inputs))
        return emb.cpu().numpy().astype(np.float32).squeeze()

    # --- Index building ---------------------------------------------------

    def build_index(self, echo_frames_list: list[EchoFrames]) -> None:
        """Build a FAISS inner-product index from ED+ES frames.

        Every video contributes two vectors (ED then ES); ``self.metadata``
        is rebuilt in the same order.  ``config.embedding_dim`` is updated
        if the model's actual output width differs from it.
        """
        images: list[np.ndarray] = []
        entries: list[IndexEntry] = []

        for ef in echo_frames_list:
            for frame, ftype in [(ef.ed_frame, "ED"), (ef.es_frame, "ES")]:
                images.append(frame)
                entries.append(IndexEntry(
                    filename=ef.filename,
                    frame_type=ftype,
                    ef=ef.ef,
                    ef_category=ef.ef_category,
                ))

        print(f"Encoding {len(images)} frames ...")
        embeddings = self.encode_batch(images)

        # Update embedding_dim from actual data
        dim = embeddings.shape[1]
        if dim != self.config.embedding_dim:
            print(f"Updating embedding_dim: {self.config.embedding_dim} -> {dim}")
            self.config.embedding_dim = dim

        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)
        self.metadata = entries
        print(f"FAISS index built: {self.index.ntotal} vectors, dim={dim}.")

    # --- Persistence ------------------------------------------------------

    def save_index(self, config: GechoConfig | None = None) -> None:
        """Save FAISS index and metadata to disk.

        Args:
            config: Optional config whose paths override ``self.config``.

        Raises:
            RuntimeError: if no index has been built or loaded.
        """
        cfg = config or self.config
        if self.index is None:
            raise RuntimeError("No index to save. Call build_index first.")

        faiss.write_index(self.index, str(cfg.faiss_index_path))
        with open(cfg.faiss_metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)
        print(f"Index saved to {cfg.faiss_index_path}")

    def load_index(self, config: GechoConfig | None = None) -> None:
        """Load FAISS index and metadata from disk.

        Args:
            config: Optional config whose paths override ``self.config``.

        Raises:
            RuntimeError: if the metadata file does not hold exactly one
                entry per index vector.
        """
        cfg = config or self.config
        index = faiss.read_index(str(cfg.faiss_index_path))
        # NOTE: pickle.load can execute arbitrary code; only load metadata
        # files that were written by save_index.
        with open(cfg.faiss_metadata_path, "rb") as f:
            metadata = pickle.load(f)

        # query() indexes metadata by vector position, so a size mismatch
        # would attach the wrong case to a match or raise mid-query.
        if len(metadata) != index.ntotal:
            raise RuntimeError(
                f"Index/metadata mismatch: {index.ntotal} vectors but "
                f"{len(metadata)} metadata entries."
            )

        self.index = index
        self.metadata = metadata
        print(f"Index loaded: {self.index.ntotal} vectors.")

    # --- Querying ---------------------------------------------------------

    def query(
        self,
        image: np.ndarray | Image.Image,
        frame_type: str = "ED",
        top_k: int | None = None,
    ) -> RetrievalResult:
        """Find the most similar cases in the FAISS index.

        Args:
            image: Query frame.
            frame_type: Accepted for interface compatibility; results are
                currently NOT filtered by frame type.
            top_k: Number of distinct videos to return; falls back to
                ``config.faiss_top_k`` when None or 0.

        Returns:
            RetrievalResult with at most ``top_k`` cases (one per video,
            best match first) and their EF statistics.

        Raises:
            RuntimeError: if no index has been built or loaded.
        """
        if self.index is None:
            raise RuntimeError("No index loaded. Call build_index or load_index.")

        k = top_k or self.config.faiss_top_k
        emb = self.encode_image(image).reshape(1, -1)
        # Each video is indexed twice (ED and ES), so fetch 2k candidates
        # to still have k distinct videos after de-duplication.
        scores, indices = self.index.search(emb, k * 2)

        cases: list[RetrievedCase] = []
        seen: set[str] = set()
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0:
                # FAISS pads with -1 when fewer than 2k vectors exist.
                continue
            entry = self.metadata[int(idx)]
            if entry.filename in seen:
                continue
            seen.add(entry.filename)
            cases.append(RetrievedCase(
                filename=entry.filename,
                ef=entry.ef,
                ef_category=entry.ef_category,
                similarity_score=float(score),
                frame_type=entry.frame_type,
            ))
            if len(cases) >= k:
                break

        # Aggregate statistics
        efs = [c.ef for c in cases if c.ef is not None]
        mean_ef = float(np.mean(efs)) if efs else None
        ef_std = float(np.std(efs)) if efs else None

        # Consensus category from most-common category
        cats = [c.ef_category for c in cases if c.ef_category]
        consensus = Counter(cats).most_common(1)[0][0] if cats else None

        return RetrievalResult(
            cases=cases,
            mean_ef=mean_ef,
            ef_std=ef_std,
            consensus_category=consensus,
        )

    # --- Zero-shot classification -----------------------------------------

    def zero_shot_classify(
        self,
        image: np.ndarray | Image.Image,
        labels: list[str] | None = None,
    ) -> dict[str, float]:
        """Zero-shot classification using image-text similarity.

        Args:
            image: Frame to classify.
            labels: Text prompts to score against.  When None, four
                built-in prompts are used and the result is keyed by
                their short display names ("Normal", ...).  Caller
                supplied labels are always returned under their own text.

        Returns:
            Dict mapping label -> probability (sums to 1); empty when an
            empty label list is given.  Probabilities are a plain softmax
            over the raw cosine similarities.
        """
        if labels is None:
            display = [short for short, _ in self._DEFAULT_ZERO_SHOT]
            prompts = [prompt for _, prompt in self._DEFAULT_ZERO_SHOT]
        else:
            # Custom labels keep their own text as keys, whatever their count.
            prompts = list(labels)
            display = prompts

        if not prompts:
            return {}

        img_emb = self.encode_image(image)
        text_embs = np.array([self.encode_text(lbl) for lbl in prompts])

        # Cosine similarities (already L2-normalized)
        sims = text_embs @ img_emb
        # Numerically stable softmax
        exp_sims = np.exp(sims - sims.max())
        probs = exp_sims / exp_sims.sum()

        return dict(zip(display, probs.tolist()))

    # --- Cleanup ----------------------------------------------------------

    def unload(self) -> None:
        """Free GPU memory by deleting the model; safe to call repeatedly."""
        for attr in ("model", "processor"):
            if hasattr(self, attr):
                delattr(self, attr)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("MedSigLIP unloaded from GPU.")

## `report_generator`

In [None]:
"""MedGemma report generation with KerasHub (primary) and transformers (fallback)."""

from __future__ import annotations

import os
from dataclasses import dataclass

import numpy as np
from PIL import Image

# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

@dataclass
class ClinicalReport:
    """Generated clinical report for an echocardiogram analysis.

    All fields are plain text.  The code that fills them is not part of
    this definition, so the per-field notes below are inferred from the
    field names -- TODO confirm against the generator.
    """
    summary: str            # presumably a short overview of the findings
    ef_assessment: str      # presumably the ejection-fraction section
    retrieval_context: str  # presumably the similar-case context given to the model
    full_report: str        # presumably the complete generated report text
    confidence_note: str    # presumably a caveat/confidence statement

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

# Instruction preamble for MedGemma.  The single-frame and comparison prompt
# builders in this module place it at the very start of the prompt text.
SYSTEM_PROMPT = (
    "You are an expert echocardiography AI assistant helping clinicians "
    "interpret echocardiogram images. You provide structured, evidence-based "
    "assessments. Always include a disclaimer that findings require "
    "verification by a qualified cardiologist."
)

def _build_retrieval_context(result: RetrievalResult) -> str:
    """Format retrieval results into a context string for the prompt.

    Args:
        result: Retrieval outcome; only ``cases``, ``mean_ef``, ``ef_std``
            and ``consensus_category`` are read.

    Returns:
        A fixed sentence when there are no cases; otherwise a header line,
        one numbered line per case, and optional mean-EF and consensus
        lines.  Missing EF or category values are rendered as "N/A", and
        the "(SD: ...)" suffix is omitted when ``ef_std`` is None.
    """
    if not result.cases:
        return "No similar cases were found in the reference database."

    lines = [
        f"Retrieved {len(result.cases)} similar cases from EchoNet-Dynamic:"
    ]
    for i, c in enumerate(result.cases, 1):
        ef_str = f"EF={c.ef:.1f}%" if c.ef is not None else "EF=N/A"
        # Never leak a literal "None" into the prompt text.
        category = c.ef_category if c.ef_category is not None else "N/A"
        lines.append(
            f"  {i}. {c.filename} ({c.frame_type}) - {ef_str} "
            f"({category}), similarity={c.similarity_score:.3f}"
        )

    if result.mean_ef is not None:
        mean_line = f"Mean EF of similar cases: {result.mean_ef:.1f}%"
        # ef_std is an independent optional field; formatting None with
        # ":.1f" would raise TypeError.
        if result.ef_std is not None:
            mean_line += f" (SD: {result.ef_std:.1f}%)"
        lines.append(mean_line)
    if result.consensus_category:
        lines.append(f"Consensus category: {result.consensus_category}")

    return "\n".join(lines)

def _build_zeroshot_context(scores: dict[str, float]) -> str:
    """Render zero-shot probabilities as a bulleted block, highest first.

    Returns an empty string when ``scores`` is empty.
    """
    if not scores:
        return ""

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    bullets = [f"  - {label}: {prob:.1%}" for label, prob in ranked]
    return "\n".join(["MedSigLIP zero-shot classification:", *bullets])

def _build_single_frame_prompt(
    frame_type: str,
    retrieval_ctx: str,
    zeroshot_ctx: str,
) -> str:
    """Assemble the MedGemma prompt for analysing one ED or ES frame.

    The prompt is five blank-line-separated sections: system preamble,
    task statement, similar-case context, zero-shot context and the
    requested report structure.
    """
    task = (
        f"Analyze this {frame_type} (End-Diastole = ED, End-Systole = ES) "
        "echocardiogram frame."
    )
    requested_output = "\n".join([
        "### Requested Output",
        "Provide a structured report with the following sections:",
        "1. **Visual Assessment**: Describe left ventricle size, wall motion, "
        "and any visible abnormalities.",
        "2. **EF Estimate**: Based on the visual features and similar-case "
        "context, estimate the ejection fraction range.",
        "3. **Clinical Impression**: Summarize the key findings and their "
        "clinical significance.",
        "4. **Limitations**: Note any caveats about this automated analysis.",
    ])
    sections = [
        SYSTEM_PROMPT,
        task,
        f"### Context from Similar Cases\n{retrieval_ctx}",
        f"### Zero-Shot Classification\n{zeroshot_ctx}",
        requested_output,
    ]
    return "\n\n".join(sections)

def _build_comparison_prompt(
    ed_retrieval_ctx: str,
    es_retrieval_ctx: str,
) -> str:
    """Assemble the MedGemma prompt for a paired ED-vs-ES analysis.

    The prompt is five blank-line-separated sections: system preamble,
    description of the two images, ED similar cases, ES similar cases and
    the requested report structure.
    """
    image_roles = "\n".join([
        "You are given two echocardiogram frames from the same patient:",
        "- **Image 1**: End-Diastole (ED) frame — the heart is maximally dilated.",
        "- **Image 2**: End-Systole (ES) frame — the heart is maximally contracted.",
    ])
    requested_output = "\n".join([
        "### Requested Output",
        "Provide a structured report with the following sections:",
        "1. **Visual Assessment**: Compare LV size between ED and ES. "
        "Describe wall motion and contractility.",
        "2. **EF Estimate**: Based on the visual change between ED and ES "
        "plus similar-case context, estimate the ejection fraction range.",
        "3. **Clinical Impression**: Key findings and clinical significance.",
        "4. **Limitations**: Caveats about this automated analysis.",
    ])
    sections = [
        SYSTEM_PROMPT,
        image_roles,
        f"### ED Similar Cases\n{ed_retrieval_ctx}",
        f"### ES Similar Cases\n{es_retrieval_ctx}",
        requested_output,
    ]
    return "\n\n".join(sections)

# ---------------------------------------------------------------------------
# Report Generator
# ---------------------------------------------------------------------------

class ReportGenerator:
    """MedGemma-based clinical report generator.

    On construction the model is loaded through KerasHub; if that fails for
    any reason the HuggingFace transformers backend is loaded instead. The
    name of the backend that succeeded is kept in ``self._backend`` and
    selects the generation path used by the public ``generate_*`` methods.
    """

    # Section titles requested by the prompts; used to split the model output.
    _SECTION_TITLES = (
        "Visual Assessment",
        "EF Estimate",
        "Clinical Impression",
        "Limitations",
    )

    def __init__(self, config: GechoConfig) -> None:
        """Store the configuration and load the model immediately.

        Args:
            config: Pipeline configuration; this class reads the model ids,
                ``dtype``, ``sequence_length`` and ``max_new_tokens`` from it.
        """
        self.config = config
        self._backend = self._load_model()

    def _load_model(self) -> str:
        """Load MedGemma via KerasHub (preferred) or transformers fallback.

        Returns:
            ``"kerashub"`` or ``"transformers"``, naming the loaded backend.

        Raises:
            Exception: Whatever the transformers loader raises when the
                fallback fails as well.
        """
        # Deliberately broad: any KerasHub failure (missing package, missing
        # preset, out of memory, ...) should trigger the fallback.
        try:
            return self._load_kerashub()
        except Exception as e:
            print(f"[INFO] KerasHub loading failed ({e}), trying transformers...")
            return self._load_transformers()

    def _load_kerashub(self) -> str:
        """Load the model via keras_hub and store it as ``self.keras_model``.

        Returns:
            The backend name ``"kerashub"``.
        """
        # Only sets the backend when the caller has not already chosen one.
        os.environ.setdefault("KERAS_BACKEND", "jax")
        import keras_hub  # noqa: E402

        print(f"Loading MedGemma via KerasHub: {self.config.medgemma_kerashub_preset}")
        self.keras_model = keras_hub.models.Gemma3CausalLM.from_preset(
            self.config.medgemma_kerashub_preset,
            dtype=self.config.dtype,
        )
        print("MedGemma loaded via KerasHub (JAX).")
        return "kerashub"

    def _load_transformers(self) -> str:
        """Load the model via HuggingFace transformers.

        Stores the processor as ``self.hf_processor`` and the model (in eval
        mode) as ``self.hf_model``.

        Returns:
            The backend name ``"transformers"``.
        """
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor

        print(f"Loading MedGemma via transformers: {self.config.medgemma_hf_model_id}")

        # "float32" is mapped explicitly so that it is honored here just as it
        # is on the KerasHub path, instead of silently becoming bfloat16.
        # Unrecognized names still fall back to bfloat16.
        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        torch_dtype = dtype_map.get(self.config.dtype, torch.bfloat16)

        self.hf_processor = AutoProcessor.from_pretrained(
            self.config.medgemma_hf_model_id
        )
        self.hf_model = AutoModelForImageTextToText.from_pretrained(
            self.config.medgemma_hf_model_id,
            torch_dtype=torch_dtype,
            device_map="auto",
        )
        self.hf_model.eval()
        print("MedGemma loaded via transformers (PyTorch).")
        return "transformers"

    # --- Generation backends ----------------------------------------------

    def _generate_kerashub(
        self,
        prompt: str,
        images: list[np.ndarray],
    ) -> str:
        """Generate text using the KerasHub backend.

        Args:
            prompt: Text prompt; one ``<start_of_image>`` token per image is
                prepended to it here.
            images: Frames to attach. ``uint8`` arrays are rescaled to
                ``float32`` in [0, 1]; other dtypes are passed through as-is.

        Returns:
            The newly generated text, without the prompt.
        """
        # KerasHub Gemma3 expects images passed alongside the prompt
        # The prompt should contain <start_of_image> tokens for each image
        image_tokens = "\n".join(["<start_of_image>"] * len(images))
        full_prompt = f"{image_tokens}\n{prompt}"

        response = self.keras_model.generate(
            {
                "prompts": full_prompt,
                "images": [
                    img.astype("float32") / 255.0 if img.dtype == np.uint8 else img
                    for img in images
                ],
            },
            max_length=self.config.sequence_length,
            # By default generate() returns prompt + completion. The prompt
            # itself lists the section titles, so leaving it in would make
            # the report parser pick up the instructions as report content.
            strip_prompt=True,
        )
        if isinstance(response, str):
            return response
        return str(response)

    def _generate_transformers(
        self,
        prompt: str,
        images: list[np.ndarray],
    ) -> str:
        """Generate text using the transformers backend.

        Args:
            prompt: Text prompt placed after the images in a single user turn.
            images: Frames to attach; numpy arrays are converted to PIL images.

        Returns:
            The decoded completion (new tokens only, special tokens removed).
        """
        import torch

        pil_images = [
            Image.fromarray(img) if isinstance(img, np.ndarray) else img
            for img in images
        ]

        # Build chat messages with images first, then the text prompt
        content: list[dict] = [
            {"type": "image", "image": img} for img in pil_images
        ]
        content.append({"type": "text", "text": prompt})

        messages = [{"role": "user", "content": content}]

        inputs = self.hf_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.hf_model.device)

        with torch.no_grad():
            output_ids = self.hf_model.generate(
                **inputs,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=False,
            )

        # Decode only the new tokens: the output starts with the input ids.
        input_len = inputs["input_ids"].shape[1]
        generated = output_ids[0][input_len:]
        return self.hf_processor.decode(generated, skip_special_tokens=True)

    def _generate(self, prompt: str, images: list[np.ndarray]) -> str:
        """Route a generation request to the backend chosen at load time."""
        if self._backend == "kerashub":
            return self._generate_kerashub(prompt, images)
        return self._generate_transformers(prompt, images)

    # --- Public API -------------------------------------------------------

    def generate_single_frame_report(
        self,
        frame: np.ndarray,
        frame_type: str,
        retrieval_result: RetrievalResult,
        zeroshot_scores: dict[str, float] | None = None,
    ) -> ClinicalReport:
        """Generate a clinical report for a single echo frame.

        Args:
            frame: The frame image passed to the model.
            frame_type: Frame label forwarded to the prompt builder
                (callers in this file pass ``"ED"`` or ``"ES"``).
            retrieval_result: Similar cases used as prompt context.
            zeroshot_scores: Optional zero-shot scores; ``None`` is treated
                as an empty mapping.

        Returns:
            The parsed report; its retrieval context is the text that was
            placed in the prompt.
        """
        retrieval_ctx = _build_retrieval_context(retrieval_result)
        zeroshot_ctx = _build_zeroshot_context(zeroshot_scores or {})
        prompt = _build_single_frame_prompt(frame_type, retrieval_ctx, zeroshot_ctx)

        raw = self._generate(prompt, [frame])
        return self._parse_report(raw, retrieval_ctx)

    def generate_comparison_report(
        self,
        ed_frame: np.ndarray,
        es_frame: np.ndarray,
        ed_retrieval: RetrievalResult,
        es_retrieval: RetrievalResult,
    ) -> ClinicalReport:
        """Generate a clinical report comparing ED and ES frames.

        Args:
            ed_frame: End-diastole frame (sent as the first image).
            es_frame: End-systole frame (sent as the second image).
            ed_retrieval: Similar cases retrieved for the ED frame.
            es_retrieval: Similar cases retrieved for the ES frame.

        Returns:
            The parsed report; its retrieval context holds both the ED and
            the ES context under ``--- ED ---`` / ``--- ES ---`` headings.
        """
        ed_ctx = _build_retrieval_context(ed_retrieval)
        es_ctx = _build_retrieval_context(es_retrieval)
        prompt = _build_comparison_prompt(ed_ctx, es_ctx)

        raw = self._generate(prompt, [ed_frame, es_frame])
        combined_ctx = f"--- ED ---\n{ed_ctx}\n\n--- ES ---\n{es_ctx}"
        return self._parse_report(raw, combined_ctx)

    # --- Parsing ----------------------------------------------------------

    @staticmethod
    def _split_sections(raw_text: str) -> dict[str, str]:
        """Split model output into the four requested report sections.

        A line counts as a section header only when a section title is the
        first thing on it (after optional markdown/numbering such as ``##``,
        ``1.`` or ``**``) and is directly followed by markup, a colon, or the
        end of the line. A body sentence that merely mentions a title, e.g.
        "... despite limitations of image quality", therefore does not start
        a new section. Text after the header on the same line is kept as the
        first line of that section. Text before the first header is dropped.

        Args:
            raw_text: Raw text produced by the model.

        Returns:
            Mapping of every section title to its stripped text; sections
            that were not found map to ``""``.
        """
        import re

        titles = ReportGenerator._SECTION_TITLES
        alternatives = "|".join(re.escape(title) for title in titles)
        header = re.compile(
            r"^[\s#>*_\-\d.)]*"           # leading markdown / list numbering
            r"(" + alternatives + r")"    # the section title itself
            r"(?=\s*(?:[*_:]|$))"         # must be followed by markup, ':' or EOL
            r"[\s*_:]*"                   # swallow closing markup and the colon
            r"(.*)$",                     # inline content on the header line
            re.IGNORECASE,
        )
        canonical = {title.lower(): title for title in titles}

        sections = {title: "" for title in titles}
        current_section = None
        lines: list[str] = []

        for line in raw_text.split("\n"):
            match = header.match(line)
            if match is None:
                lines.append(line)
                continue
            # A new header closes the section collected so far.
            if current_section and lines:
                sections[current_section] = "\n".join(lines).strip()
            current_section = canonical[match.group(1).lower()]
            inline = match.group(2).strip()
            lines = [inline] if inline else []

        if current_section and lines:
            sections[current_section] = "\n".join(lines).strip()

        return sections

    @staticmethod
    def _parse_report(raw_text: str, retrieval_ctx: str) -> ClinicalReport:
        """Parse raw model output into a structured ClinicalReport.

        Args:
            raw_text: Raw text produced by the model; stored unchanged as the
                full report.
            retrieval_ctx: Retrieval context text to attach to the report.

        Returns:
            A ClinicalReport whose summary is the Clinical Impression, else
            the Visual Assessment, else the first 500 characters of the raw
            text. Missing EF and Limitations sections are replaced by fixed
            fallback sentences.
        """
        sections = ReportGenerator._split_sections(raw_text)

        # Build summary from first non-empty section
        summary = (
            sections["Clinical Impression"]
            or sections["Visual Assessment"]
            or raw_text[:500]
        )

        return ClinicalReport(
            summary=summary,
            ef_assessment=sections["EF Estimate"] or "See full report.",
            retrieval_context=retrieval_ctx,
            full_report=raw_text,
            confidence_note=(
                sections["Limitations"]
                or "This is an AI-generated analysis and must be reviewed "
                "by a qualified cardiologist before clinical use."
            ),
        )

## `ui`

In [None]:
"""Gradio dashboard for Gecho echocardiogram analysis."""

from __future__ import annotations

from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image

class GechoDashboard:
    """Three-column Gradio dashboard for echocardiogram analysis.

    The left column takes the video and analysis mode and shows the extracted
    frames, the middle column shows zero-shot scores and retrieved cases, and
    the right column shows the generated report with review buttons.
    """

    # Analysis-mode labels shared by the radio selector and the pipeline.
    MODE_SINGLE_ED = "Single Frame (ED)"
    MODE_SINGLE_ES = "Single Frame (ES)"
    MODE_COMPARISON = "Comparison (ED vs ES)"

    def __init__(
        self,
        config: GechoConfig,
        embedding_engine: EmbeddingEngine,
        report_generator: ReportGenerator,
    ) -> None:
        """Keep references to the configuration and the two model wrappers.

        Args:
            config: Pipeline configuration passed on to frame extraction.
            embedding_engine: Used for zero-shot classification and retrieval.
            report_generator: Used to generate the clinical report.
        """
        self.config = config
        self.engine = embedding_engine
        self.generator = report_generator

    # --- Analysis pipeline ------------------------------------------------

    def analyze_video(
        self,
        video_file,
        mode: str,
        progress: gr.Progress = gr.Progress(),
    ) -> tuple:
        """Main pipeline: extract -> classify -> retrieve -> generate.

        Args:
            video_file: Uploaded file, either a path string or an object
                exposing the path as ``.name``.
            mode: One of the analysis-mode labels offered by the radio
                selector.
            progress: Gradio progress tracker.

        Returns:
            A 5-tuple matching the Gradio output order: ED image, ES image,
            zero-shot scores, retrieval table rows, report Markdown.

        Raises:
            gr.Error: If no video file was provided.
        """
        if not video_file:
            raise gr.Error("Please upload a video file first.")

        # gr.File returns a filepath string
        video_path = video_file if isinstance(video_file, str) else video_file.name
        is_comparison = mode == self.MODE_COMPARISON

        # Step 1: Extract frames
        progress(0.1, desc="Extracting frames...")
        frames: EchoFrames = extract_frames_from_upload(video_path, self.config)
        ed_img = Image.fromarray(frames.ed_frame)
        es_img = Image.fromarray(frames.es_frame)

        # Step 2: Zero-shot classification on the primary frame
        # (the ED frame for both the ED and the comparison mode).
        progress(0.3, desc="Running zero-shot classification...")
        primary_frame = frames.ed_frame if "ED" in mode else frames.es_frame
        zeroshot_scores = self.engine.zero_shot_classify(primary_frame)

        # Step 3: FAISS retrieval -- only the frames the mode needs are queried
        progress(0.5, desc="Retrieving similar cases...")
        ed_result = None
        es_result = None
        if is_comparison:
            ed_result = self.engine.query(frames.ed_frame, frame_type="ED")
            es_result = self.engine.query(frames.es_frame, frame_type="ES")
        elif "ES" in mode:
            es_result = self.engine.query(frames.es_frame, frame_type="ES")
        else:
            ed_result = self.engine.query(frames.ed_frame, frame_type="ED")

        # "Was this frame queried?" is decided by an explicit None check so
        # the choice never depends on the truthiness of the result object
        # (e.g. a result type that is falsy when it holds no cases).
        active_result = ed_result if ed_result is not None else es_result

        # Step 4: Build retrieval table (ED results when both were queried)
        retrieval_table = self._format_retrieval_table(active_result)

        # Step 5: Generate report
        progress(0.7, desc="Generating clinical report...")
        if is_comparison:
            report = self.generator.generate_comparison_report(
                frames.ed_frame, frames.es_frame, ed_result, es_result
            )
        elif ed_result is not None:
            report = self.generator.generate_single_frame_report(
                frames.ed_frame, "ED", ed_result, zeroshot_scores
            )
        else:
            report = self.generator.generate_single_frame_report(
                frames.es_frame, "ES", es_result, zeroshot_scores
            )

        progress(1.0, desc="Done!")

        # Format outputs
        report_md = self._format_report_markdown(report)
        # Plain Python floats for gr.Label.
        zeroshot_label = {k: float(v) for k, v in zeroshot_scores.items()}

        return (
            ed_img,                 # ED frame display
            es_img,                 # ES frame display
            zeroshot_label,         # gr.Label (zero-shot)
            retrieval_table,        # gr.Dataframe
            report_md,              # gr.Markdown (report)
        )

    # --- Formatting helpers -----------------------------------------------

    @staticmethod
    def _format_retrieval_table(
        result: RetrievalResult | None,
    ) -> list[list[str]]:
        """Format retrieval results as rows for gr.Dataframe.

        Args:
            result: Retrieval result, or ``None`` when nothing was retrieved.

        Returns:
            One ``[filename, EF, category, similarity]`` row per case, or a
            single "No results" row when there are no cases.
        """
        if result is None or not result.cases:
            return [["No results", "", "", ""]]
        return [
            [
                c.filename,
                f"{c.ef:.1f}%" if c.ef is not None else "N/A",
                c.ef_category or "N/A",
                f"{c.similarity_score:.3f}",
            ]
            for c in result.cases
        ]

    @staticmethod
    def _format_report_markdown(report: ClinicalReport) -> str:
        """Format a ClinicalReport as Markdown for display.

        The full report comes first, followed by the EF assessment, the
        retrieval context in a code fence, and the confidence note as a
        block quote.
        """
        return (
            f"## Clinical Report\n\n"
            f"{report.full_report}\n\n"
            f"---\n\n"
            f"### EF Assessment\n{report.ef_assessment}\n\n"
            f"### Retrieval Context\n"
            f"```\n{report.retrieval_context}\n```\n\n"
            f"### Confidence Note\n"
            f"> {report.confidence_note}"
        )

    # --- Gradio app -------------------------------------------------------

    def build(self) -> gr.Blocks:
        """Build and return the Gradio Blocks app (without launching it)."""
        with gr.Blocks(
            title="Gecho - Automated Echocardiogram Analysis",
            theme=gr.themes.Soft(),
        ) as app:
            gr.Markdown(
                "# Gecho: Gemma Echo\n"
                "Automated echocardiogram interpretation powered by "
                "MedSigLIP retrieval and MedGemma report generation."
            )

            with gr.Row():
                # --- Left column: Input ---
                with gr.Column(scale=1):
                    gr.Markdown("### Input")
                    video_input = gr.File(
                        label="Upload Echo Video (.avi / .mp4)",
                        file_types=[".avi", ".mp4", ".mov", ".mkv"],
                    )
                    mode_selector = gr.Radio(
                        choices=[
                            self.MODE_SINGLE_ED,
                            self.MODE_SINGLE_ES,
                            self.MODE_COMPARISON,
                        ],
                        value=self.MODE_COMPARISON,
                        label="Analysis Mode",
                    )
                    analyze_btn = gr.Button(
                        "Analyze", variant="primary", size="lg"
                    )

                    gr.Markdown("### Extracted Frames")
                    ed_display = gr.Image(label="End-Diastole (ED)", type="pil")
                    es_display = gr.Image(label="End-Systole (ES)", type="pil")

                # --- Middle column: Retrieval ---
                with gr.Column(scale=1):
                    gr.Markdown("### MedSigLIP Classification")
                    zeroshot_label = gr.Label(
                        label="Zero-Shot Cardiac Function",
                        num_top_classes=4,
                    )

                    gr.Markdown("### Similar Cases (FAISS Retrieval)")
                    retrieval_table = gr.Dataframe(
                        headers=["Filename", "EF", "Category", "Similarity"],
                        label="Top Retrieved Cases",
                        interactive=False,
                    )

                # --- Right column: Report ---
                with gr.Column(scale=1):
                    gr.Markdown("### Generated Report")
                    report_output = gr.Markdown(
                        value="*Upload a video and click Analyze to begin.*"
                    )

                    gr.Markdown("### Human-in-the-Loop")
                    with gr.Row():
                        approve_btn = gr.Button("Approve Report", variant="secondary")
                        edit_btn = gr.Button("Edit Report", variant="secondary")
                    status_text = gr.Textbox(
                        label="Status",
                        value="Awaiting analysis...",
                        interactive=False,
                    )

            # --- Event handlers ---
            # The output order must match the tuple returned by analyze_video.
            analyze_btn.click(
                fn=self.analyze_video,
                inputs=[video_input, mode_selector],
                outputs=[
                    ed_display,
                    es_display,
                    zeroshot_label,
                    retrieval_table,
                    report_output,
                ],
            )

            # The review buttons only update the status text.
            approve_btn.click(
                fn=lambda: "Report approved by clinician.",
                outputs=[status_text],
            )
            edit_btn.click(
                fn=lambda: "Report flagged for clinician editing.",
                outputs=[status_text],
            )

        return app

def launch_dashboard(
    config: GechoConfig,
    embedding_engine: EmbeddingEngine,
    report_generator: ReportGenerator,
    share: bool = False,
) -> None:
    """Construct the dashboard, build its Gradio app, and launch it.

    Args:
        config: Pipeline configuration handed to the dashboard.
        embedding_engine: Engine the dashboard uses for classification and
            retrieval.
        report_generator: Generator the dashboard uses for reports.
        share: Forwarded to the app's ``launch`` call.
    """
    GechoDashboard(
        config,
        embedding_engine,
        report_generator,
    ).build().launch(share=share)

## Main Pipeline

In [None]:
# ---- Main pipeline: build index, load models, launch UI ----

from pathlib import Path

# Configuration comes first (paths auto-adjust for Kaggle)
config = GechoConfig()

# Step 1: Obtain a FAISS index -- reuse the saved one when both the index
# file and its metadata file are present, otherwise build it from scratch.
engine = EmbeddingEngine(config)

if not (
    config.faiss_index_path.exists() and config.faiss_metadata_path.exists()
):
    print("No existing index found. Processing EchoNet-Dynamic training set...")
    echo_frames = process_dataset(config, split="TRAIN")
    print("Building embedding index...")
    engine.build_index(echo_frames)
    engine.save_index()
    # The extracted frames are not needed once the index has been saved.
    del echo_frames
else:
    print("Found existing FAISS index, loading...")
    engine.load_index()

# Step 2: Bring up the MedGemma report generator
print("Loading MedGemma report generator...")
generator = ReportGenerator(config)

# Step 3: Start the Gradio dashboard
print("Launching Gecho dashboard...")
launch_dashboard(config, engine, generator, share=True)


## Technical Notes & Citations

### Architecture
- **Retrieval**: MedSigLIP-448 encodes echo frames into 768-dim embeddings. FAISS IndexFlatIP performs cosine-similarity search against ~15K training vectors.
- **Generation**: MedGemma 1.5-4B-IT receives the query frame + RAG context (similar cases, zero-shot scores) and produces a structured clinical report.
- **VRAM Strategy**: MedSigLIP (float16, ~1.6GB) builds the index first, then MedGemma (bfloat16, ~8GB) is loaded for generation. The combined footprint stays under the 16GB available on a P100.

### Citations
- Ouyang et al. *Video-based AI for beat-to-beat assessment of cardiac function.* Nature, 2020. (EchoNet-Dynamic)
- Yang et al. *Advancing Multimodal Medical Capabilities of Gemini.* arXiv:2405.03162, 2024. (MedGemma)
- Radford et al. *Learning Transferable Visual Models From Natural Language Supervision.* ICML, 2021. (SigLIP heritage)
