# TSV-driven single-mutant Kermut pipeline

This notebook ingests a TSV file describing variant measurements, filters for single substitutions, prepares the Kermut-compatible data layout, and evaluates Kermut with three-fold cross-validation. Run it inside the Kermut repository after activating the project environment so that Hydra configs and preprocessing utilities are available.

## Notebook outline

1. Configure inputs, objectives, and feature toggles.
2. Parse the TSV, construct mutated sequences, and aggregate multi-objective targets.
3. Save assay tables plus a reference manifest in the layout expected by Kermut.
4. Inspect skipped records and verify that required embeddings/structure features exist.
5. Compose a Hydra config and run 3-fold train/test splits to obtain performance metrics.
6. Export predictions, per-fold metrics, and aggregated summaries for downstream analysis.

In [None]:
import json
import re
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from IPython.display import display
from omegaconf import DictConfig, OmegaConf

from kermut.data import prepare_GP_inputs, prepare_GP_kwargs, split_inputs, standardize
from kermut.gp import instantiate_gp, optimize_gp, predict

### Configure input paths and runtime options

Update the paths in the next cell to point to your TSV file and any directories where intermediate outputs should be written. The workspace folder will store generated assay CSVs, Hydra reference manifests, cross-validation predictions, and metric summaries.

In [None]:
# --- Update these values before running the pipeline ---

# Tab-separated input. The preparation cell requires the columns
# reference_key, sequence, mutations, backbone_id and pdb, plus one column
# per configured objective.
tsv_path = Path("path/to/your_assay.tsv")  # <-- replace with the TSV you want to process

# Every generated artefact is written beneath workspace_root.
workspace_root = Path("data/notebook_runs/example_run")
assay_output_dir = workspace_root / "assays"  # one <reference_key>.csv per assay
reference_manifest_path = workspace_root / "reference.csv"  # DMS_id / target_seq manifest
cv_predictions_dir = workspace_root / "cv_predictions"  # per-assay CV predictions and metrics

# Feature locations copied into cfg.data.paths by build_base_cfg. A value of
# None leaves the corresponding entry of the loaded config untouched.
feature_paths: Dict[str, Optional[Path]] = {
    "embeddings": Path("data/embeddings/substitutions_singles/ESM2"),
    "zero_shot": Path("data/zero_shot_fitness_predictions"),
    "conditional_probs": Path("data/conditional_probs/ProteinMPNN"),
    "coords": Path("data/structures/coords"),
}

kernel_config_name = "kermut"  # pick from kermut/hydra_configs/kernel/*.yaml
# Forwarded to cfg.data.standardize; when true, the CV loop passes each fold's
# targets through kermut.data.standardize.
DATA_STANDARDIZE = True
# Key/value pairs applied on top of the kernel config via OmegaConf.update.
kernel_overrides: Dict[str, object] = {
    "use_zero_shot": False,
    # Example: "structure_kernel.use_distance_comparison": False,
}

# Key/value pairs applied on top of cfg.optim; lr, n_steps and progress_bar
# are the fields read by the training helpers in this notebook.
optimization_overrides: Dict[str, object] = {
    "n_steps": 150,
    "lr": 0.1,
    "progress_bar": False,
}

FOLD_COLUMN = "cv_fold_3"  # column holding fold ids; also used as cfg.cv_scheme
CV_FOLDS = 3  # number of folds created per assay by assign_cv_folds
RANDOM_SEED = 2024  # seeds fold assignment and, via cfg.seed, torch/NumPy before each fit
USE_GPU = torch.cuda.is_available()

# Create the output directories up front so later cells can write without checks.
for path in [workspace_root, assay_output_dir, cv_predictions_dir]:
    path.mkdir(parents=True, exist_ok=True)


### Configure optimisation objectives

Describe each target column that should contribute to the aggregate `DMS_score`. Provide the TSV column name, a weight (the code will normalise weights to sum to 1), and whether the objective should be maximised or minimised.

In [None]:
# Example configuration:
# objectives = [
#     {"column": "fitness", "weight": 0.7, "goal": "maximize"},
#     {"column": "stability", "weight": 0.3, "goal": "minimize"},
# ]

# Must be filled in before running the preparation cell:
# validate_objective_config raises ValueError on an empty list.
objectives: List[Dict[str, object]] = []
# How each objective column is rescaled before weighting (see normalize_objectives).
OBJECTIVE_NORMALIZATION = "zscore"  # one of {"zscore", "minmax", "none"}
# When True, variants with a missing value in any objective column are removed
# before aggregation and reported in the diagnostics cell.
DROP_ROWS_WITH_MISSING_OBJECTIVES = True


### Helper functions

The utilities below handle mutation parsing, multi-objective aggregation, fold assignment, Hydra configuration, and metric summarisation.

In [None]:
# Matches one substitution token such as "A123G": wild-type letter, position,
# mutant letter. Any uppercase A-Z letter is accepted.
# NOTE(review): this includes non-canonical letters (e.g. B, X, Z) -- confirm
# whether downstream tokenisation supports them.
SINGLE_MUTATION_RE = re.compile(r"^(?P<wt>[A-Z])(?P<pos>\d+)(?P<mut>[A-Z])$")
# Accepted values for the `method` argument of normalize_objectives.
NORMALIZATION_METHODS = {"zscore", "minmax", "none"}

def validate_objective_config(obj_list: List[Dict[str, object]]) -> pd.DataFrame:
    """Validate the objective definitions and return them as a normalised table.

    Args:
        obj_list: One dict per objective with the keys ``column`` (TSV column
            name), ``weight`` (non-negative, finite number) and ``goal``
            (``"maximize"`` or ``"minimize"``, case-insensitive).

    Returns:
        A DataFrame with ``column`` cast to ``str``, ``weight`` rescaled so the
        weights sum to 1, and ``goal`` lower-cased and stripped.

    Raises:
        ValueError: If the list is empty, a required key is missing, a column
            is listed twice, a weight is negative or non-finite, all weights
            are zero, or a goal is not recognised.
    """
    if not obj_list:
        raise ValueError(
            "No objectives configured. Update the `objectives` list in the configuration cell."
        )
    df_obj = pd.DataFrame(obj_list)
    required_keys = {"column", "weight", "goal"}
    missing_keys = required_keys - set(df_obj.columns)
    if missing_keys:
        raise ValueError(f"Objective definitions are missing keys: {sorted(missing_keys)}")

    df_obj["column"] = df_obj["column"].astype(str)
    # A column listed twice would be weighted inconsistently when the
    # per-objective contributions are summed, so reject it up front.
    duplicated = sorted(set(df_obj.loc[df_obj["column"].duplicated(), "column"]))
    if duplicated:
        raise ValueError(f"Objective columns must be unique, got duplicates: {duplicated}")

    df_obj["weight"] = df_obj["weight"].astype(float)
    # NaN compares False against 0 and is skipped by sum(), so it must be
    # rejected explicitly or the objective would silently drop out of the score.
    if not np.isfinite(df_obj["weight"]).all():
        raise ValueError("Objective weights must be finite numbers.")
    if (df_obj["weight"] < 0).any():
        raise ValueError("Objective weights must be non-negative.")
    weight_sum = df_obj["weight"].sum()
    if weight_sum == 0:
        raise ValueError("Objective weights must sum to a positive value.")
    df_obj["weight"] = df_obj["weight"] / weight_sum

    # Cast to str first so non-string goals (None, numbers) reach the
    # ValueError below instead of failing inside the .str accessor.
    df_obj["goal"] = df_obj["goal"].astype(str).str.strip().str.lower()
    allowed_goals = {"maximize", "minimize"}
    invalid_goals = set(df_obj["goal"]) - allowed_goals
    if invalid_goals:
        raise ValueError(
            f"Objective goals must be 'maximize' or 'minimize', got {sorted(invalid_goals)}."
        )
    return df_obj


def ensure_columns(df: pd.DataFrame, required: Iterable[str]) -> None:
    """Raise ``KeyError`` naming every entry of ``required`` that is not a column of ``df``."""
    available = df.columns
    absent: List[str] = []
    for name in required:
        if name not in available:
            absent.append(name)
    if absent:
        raise KeyError(f"Missing required columns in TSV: {absent}")


def parse_mutation_tokens(value) -> List[str]:
    """Split a raw ``mutations`` cell into upper-cased tokens.

    Tokens are separated by semicolons, commas or whitespace. A missing value
    (as judged by ``pd.isna``) yields an empty list, and empty fragments are
    discarded.
    """
    if pd.isna(value):
        return []
    cleaned: List[str] = []
    for fragment in re.split(r"[;,\s]+", str(value).strip()):
        fragment = fragment.strip()
        if fragment:
            cleaned.append(fragment.upper())
    return cleaned


def apply_single_mutation(sequence: str, mutation: str) -> Tuple[str, str, str, int]:
    """Apply one substitution token (e.g. ``"A12G"``) to ``sequence``.

    Positions in the token are 1-based.

    Returns:
        ``(mutated_sequence, wild_type_residue, mutant_residue, position)``.

    Raises:
        ValueError: If the token is not a single substitution, the sequence is
            missing or empty, the position lies outside the sequence, or the
            wild-type residue does not match the sequence at that position.
    """
    parsed = SINGLE_MUTATION_RE.match(mutation)
    if parsed is None:
        raise ValueError(f"Mutation '{mutation}' is not a single amino-acid substitution.")
    wild_type, substituted = parsed.group("wt"), parsed.group("mut")
    position = int(parsed.group("pos"))

    if not (isinstance(sequence, str) and sequence):
        raise ValueError("Missing reference sequence.")
    seq_len = len(sequence)
    if not 1 <= position <= seq_len:
        raise ValueError(f"Position {position} out of bounds for sequence of length {seq_len}.")

    offset = position - 1  # convert the 1-based position to a 0-based index
    observed = sequence[offset]
    if observed != wild_type:
        raise ValueError(
            f"Expected residue '{wild_type}' at position {position}, but sequence contains '{observed}'."
        )
    return f"{sequence[:offset]}{substituted}{sequence[offset + 1:]}", wild_type, substituted, position


def extract_single_mutants(
    df_raw: pd.DataFrame, objectives_df: pd.DataFrame
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Keep the rows of ``df_raw`` that describe exactly one substitution.

    A row is accepted only when its ``mutations`` cell contains a single token
    and that token is a valid substitution that applies cleanly to the row's
    ``sequence``. Rows that pair a valid substitution with any other token
    (for example ``"A1B, del5"``) are rejected as ``multi_mutation`` rather
    than being treated as the single mutant ``A1B``.

    Args:
        df_raw: Raw TSV contents. Reads ``reference_key``, ``mutations`` and
            ``sequence``; ``backbone_id``, ``pdb`` and the objective columns
            are copied when present.
        objectives_df: Validated objective table; its ``column`` values name
            the measurement columns to carry over.

    Returns:
        ``(df_single, skip_df)``. ``df_single`` has one row per accepted
        variant. ``skip_df`` has one row per rejected input row with the
        columns ``row_index``, ``reference_key``, ``reason`` and
        ``raw_mutations``. Either frame has no columns when it has no rows.
    """
    records: List[Dict[str, object]] = []
    skipped: List[Dict[str, object]] = []
    objective_cols = objectives_df["column"].tolist()

    def _skip(idx, row, reason: str) -> None:
        # Record why an input row was rejected, for the diagnostics cell.
        skipped.append(
            {
                "row_index": idx,
                "reference_key": row["reference_key"],
                "reason": reason,
                "raw_mutations": row.get("mutations", ""),
            }
        )

    for idx, row in df_raw.iterrows():
        mutation_tokens = parse_mutation_tokens(row["mutations"])
        single_tokens = [token for token in mutation_tokens if SINGLE_MUTATION_RE.match(token)]
        # Require the substitution to be the ONLY token: counting just the
        # regex-matching tokens would let variants with extra, unrecognised
        # edits slip through as single mutants.
        if len(single_tokens) != 1 or len(mutation_tokens) != 1:
            _skip(idx, row, "multi_mutation" if single_tokens else "no_single_mutation")
            continue

        mutation = single_tokens[0]
        try:
            mutated_sequence, wt, mut, pos = apply_single_mutation(row["sequence"], mutation)
        except ValueError as exc:
            _skip(idx, row, str(exc))
            continue

        record: Dict[str, object] = {
            "reference_key": row["reference_key"],
            "mutant": mutation,
            "mutation": mutation,
            "wt_aa": wt,
            "mut_aa": mut,
            "position": pos,
            "sequence": row["sequence"],
            "mutated_sequence": mutated_sequence,
            "backbone_id": row.get("backbone_id"),
            "pdb": row.get("pdb"),
        }
        for col in objective_cols:
            record[col] = row.get(col)
        records.append(record)

    return pd.DataFrame(records), pd.DataFrame(skipped)


def drop_rows_with_missing_objectives(
    df: pd.DataFrame, objective_cols: List[str]
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Separate rows with complete objective values from rows with gaps.

    Returns:
        ``(kept, dropped)``, both re-indexed from zero. ``dropped`` holds every
        row with a null in at least one of ``objective_cols``. For an empty
        input, a copy of ``df`` and an empty frame with the same columns are
        returned.
    """
    if df.empty:
        return df.copy(), pd.DataFrame(columns=df.columns)
    has_gap = df[objective_cols].isnull().any(axis=1)
    complete_rows = df[~has_gap].copy().reset_index(drop=True)
    incomplete_rows = df[has_gap].copy().reset_index(drop=True)
    return complete_rows, incomplete_rows


def normalize_objectives(
    df: pd.DataFrame, objectives_df: pd.DataFrame, method: str
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
    """Add a ``<column>__normalized`` column for every objective.

    Args:
        df: Table containing the raw objective columns.
        objectives_df: Objective table; only its ``column`` values are read.
        method: ``"zscore"`` (population standard deviation), ``"minmax"`` or
            ``"none"``; matched case-insensitively.

    Returns:
        ``(normalized_copy, stats)``. ``stats`` maps each objective column to
        the constants used (``mean``/``std``, ``min``/``max``, or an empty
        dict for ``"none"``). A column with no spread is only shifted, never
        divided.

    Raises:
        ValueError: If ``method`` is not a known normalization method.
    """
    method = method.lower()
    if method not in NORMALIZATION_METHODS:
        raise ValueError(
            f"Unknown normalization method '{method}'. Valid options: {sorted(NORMALIZATION_METHODS)}."
        )
    result = df.copy()
    summary: Dict[str, Dict[str, float]] = {}
    for col in objectives_df["column"]:
        raw = result[col].astype(float)
        if method == "zscore":
            center = raw.mean()
            spread = raw.std(ddof=0)
            shifted = raw - center
            scaled = shifted if spread == 0 else shifted / spread
            summary[col] = {"mean": float(center), "std": float(spread)}
        elif method == "minmax":
            lowest = raw.min()
            highest = raw.max()
            shifted = raw - lowest
            scaled = shifted if highest == lowest else shifted / (highest - lowest)
            summary[col] = {"min": float(lowest), "max": float(highest)}
        else:
            scaled = raw
            summary[col] = {}
        result[f"{col}__normalized"] = scaled
    return result, summary


def compute_weighted_scores(df: pd.DataFrame, objectives_df: pd.DataFrame) -> pd.DataFrame:
    """Combine the normalised objectives into a single ``DMS_score`` column.

    For each objective, ``<column>__signed`` holds the normalised value
    (negated when the goal is not ``"maximize"``) and ``<column>__weighted``
    holds the signed value times the objective weight. ``DMS_score`` is the
    row-wise sum of the weighted columns. The input frame is not modified.
    """
    scored = df.copy()
    weighted_names: List[str] = []
    objective_specs = zip(
        objectives_df["column"], objectives_df["goal"], objectives_df["weight"]
    )
    for name, goal, weight in objective_specs:
        direction = -1.0 if goal != "maximize" else 1.0
        signed_name = f"{name}__signed"
        weighted_name = f"{name}__weighted"
        scored[signed_name] = scored[f"{name}__normalized"] * direction
        scored[weighted_name] = scored[signed_name] * weight
        weighted_names.append(weighted_name)
    scored["DMS_score"] = scored[weighted_names].sum(axis=1)
    return scored


def assign_cv_folds(
    df: pd.DataFrame, fold_column: str, n_splits: int, seed: int
) -> pd.DataFrame:
    """Return a copy of ``df`` with a random fold id in ``fold_column``.

    Row positions are shuffled with a generator seeded by ``seed`` and cut
    into ``n_splits`` nearly equal groups; fold ids run from 0 to
    ``n_splits - 1``.

    Raises:
        ValueError: If ``n_splits`` is below 2 or exceeds the number of rows.
    """
    n_rows = len(df)
    if n_splits < 2:
        raise ValueError("n_splits must be >= 2.")
    if n_rows < n_splits:
        raise ValueError(
            f"Cannot create {n_splits} folds because the assay only contains {n_rows} variants."
        )
    generator = np.random.default_rng(seed)
    order = np.arange(n_rows)
    generator.shuffle(order)
    labels = np.empty(n_rows, dtype=int)
    # Each chunk of shuffled positions becomes one fold.
    for label, members in enumerate(np.array_split(order, n_splits)):
        labels[members] = label
    result = df.copy()
    result[fold_column] = labels
    return result


def build_reference_table(df: pd.DataFrame) -> pd.DataFrame:
    """Build the reference manifest: one row per distinct ``reference_key``.

    The first row seen for each key is kept. Columns are renamed
    ``reference_key`` -> ``DMS_id``, ``sequence`` -> ``target_seq`` and
    ``pdb`` -> ``pdb_path``; ``backbone_id`` keeps its name.

    Raises:
        KeyError: If any of the four source columns is missing from ``df``.
    """
    wanted = ["reference_key", "sequence", "backbone_id", "pdb"]
    ensure_columns(df, wanted)
    manifest_names = {
        "reference_key": "DMS_id",
        "sequence": "target_seq",
        "pdb": "pdb_path",
    }
    unique_refs = df[wanted].drop_duplicates(subset=["reference_key"])
    manifest = unique_refs.rename(columns=manifest_names)
    return manifest.reset_index(drop=True)


def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return Spearman, Pearson, RMSE and MAE between targets and predictions.

    All four values are NaN when ``y_true`` is empty.
    """
    if y_true.size == 0:
        return {"spearman": np.nan, "pearson": np.nan, "rmse": np.nan, "mae": np.nan}
    observed = pd.Series(y_true)
    predicted = pd.Series(y_pred)
    rank_corr = observed.corr(predicted, method="spearman")
    linear_corr = observed.corr(predicted, method="pearson")
    residuals = y_true - y_pred
    return {
        "spearman": float(rank_corr),
        "pearson": float(linear_corr),
        "rmse": float(np.sqrt(np.mean(residuals ** 2))),
        "mae": float(np.mean(np.abs(residuals))),
    }


def summarize_skip_table(skip_df: pd.DataFrame) -> pd.DataFrame:
    """Count skipped rows per ``reason``, most frequent reason first.

    An empty input yields an empty frame with the columns ``reason`` and
    ``count``.
    """
    if skip_df.empty:
        return pd.DataFrame(columns=["reason", "count"])
    per_reason = skip_df.groupby("reason").size().reset_index(name="count")
    ranked = per_reason.sort_values("count", ascending=False)
    return ranked.reset_index(drop=True)


def apply_overrides(cfg_section: DictConfig, overrides: Dict[str, object]) -> None:
    """Write every ``overrides`` entry into ``cfg_section`` in place.

    Each key is handed to ``OmegaConf.update`` with ``force_add=True``, so
    keys that are not yet present in the section are created.
    """
    for override_key in overrides:
        OmegaConf.update(cfg_section, override_key, overrides[override_key], force_add=True)


def inspect_feature_availability(cfg: DictConfig, dataset_ids: Iterable[str]) -> pd.DataFrame:
    """Report, per assay, whether the feature files the kernel needs exist.

    Args:
        cfg: Config providing the ``kernel`` switches and the ``data.paths``
            directories.
        dataset_ids: Assay identifiers used as file stems.

    Returns:
        One row per assay with the columns ``DMS_id``, ``embeddings_exists``,
        ``zero_shot_exists``, ``conditional_probs_exists`` and
        ``coords_exists``. A feature that the kernel configuration does not
        use is reported as ``True`` (nothing is missing), so every row carries
        the same columns whichever kernels are enabled.
    """
    records: List[Dict[str, object]] = []
    for dms_id in dataset_ids:
        record: Dict[str, object] = {"DMS_id": dms_id}

        if cfg.kernel.use_sequence_kernel:
            embedding_path = Path(cfg.data.paths.embeddings_singles) / f"{dms_id}.h5"
            record["embeddings_exists"] = embedding_path.exists()
        else:
            record["embeddings_exists"] = True

        if cfg.kernel.use_zero_shot:
            zero_path = (
                Path(cfg.data.paths.zero_shot)
                / cfg.kernel.zero_shot_method
                / f"{dms_id}.csv"
            )
            record["zero_shot_exists"] = zero_path.exists()
        else:
            record["zero_shot_exists"] = True

        if cfg.kernel.use_structure_kernel:
            cond_path = Path(cfg.data.paths.conditional_probs) / f"{dms_id}.npy"
            record["conditional_probs_exists"] = cond_path.exists()
            # Coordinates are only needed when the structure kernel compares distances.
            use_distance = getattr(cfg.kernel.structure_kernel, "use_distance_comparison", False)
            if use_distance:
                coords_path = Path(cfg.data.paths.coords) / f"{dms_id}.npy"
                record["coords_exists"] = coords_path.exists()
            else:
                record["coords_exists"] = True
        else:
            # Structure kernel disabled: report the structure features as
            # satisfied, matching how the other disabled features are handled,
            # instead of leaving the columns out of the table.
            record["conditional_probs_exists"] = True
            record["coords_exists"] = True

        records.append(record)
    return pd.DataFrame(records)


def clone_config(cfg: DictConfig) -> DictConfig:
    """Return an independent copy of ``cfg`` built from its resolved container form."""
    resolved_snapshot = OmegaConf.to_container(cfg, resolve=True)
    return OmegaConf.create(resolved_snapshot)


def build_base_cfg(
    kernel_config_name: str,
    kernel_overrides: Dict[str, object],
    optimization_overrides: Dict[str, object],
    feature_paths: Dict[str, Optional[Path]],
    reference_csv: Path,
    assay_dir: Path,
    output_dir: Path,
    fold_column: str,
    use_gpu: bool,
    seed: int,
    standardize: bool,
) -> DictConfig:
    """Assemble the benchmark config used for every assay in this notebook.

    Loads ``kermut/hydra_configs/benchmark.yaml`` and the selected kernel YAML
    (both relative to the working directory), then overlays the notebook
    settings.

    Args:
        kernel_config_name: Stem of a file in ``kermut/hydra_configs/kernel``.
        kernel_overrides: Entries applied to ``cfg.kernel``.
        optimization_overrides: Entries applied to ``cfg.optim``.
        feature_paths: Optional locations for ``embeddings``, ``zero_shot``,
            ``conditional_probs`` and ``coords``; ``None`` or absent entries
            keep the value from the loaded config.
        reference_csv: Reference manifest path.
        assay_dir: Directory holding the per-assay CSV files.
        output_dir: Stored as ``cfg.data.paths.output_folder``.
        fold_column: Stored as ``cfg.cv_scheme``.
        use_gpu: Stored as ``cfg.use_gpu``.
        seed: Stored as ``cfg.seed``.
        standardize: Stored as ``cfg.data.standardize``.

    Raises:
        FileNotFoundError: If the kernel config file does not exist.

    NOTE(review): ``OmegaConf.load`` reads the YAML as-is and does not compose
    Hydra ``defaults`` lists -- confirm that ``benchmark.yaml`` itself provides
    the ``data``, ``optim`` and ``single`` sections used below.
    """
    cfg = OmegaConf.load("kermut/hydra_configs/benchmark.yaml")

    # Swap in the requested kernel and apply the user's tweaks.
    kernel_cfg_path = Path("kermut/hydra_configs/kernel") / f"{kernel_config_name}.yaml"
    if not kernel_cfg_path.exists():
        raise FileNotFoundError(f"Kernel config not found: {kernel_cfg_path}")
    cfg.kernel = OmegaConf.load(str(kernel_cfg_path))
    apply_overrides(cfg.kernel, kernel_overrides)

    # Run-level switches; the assay id is filled in per assay by the caller.
    cfg.dataset = "single"
    cfg.cv_scheme = fold_column
    cfg.single.use_id = True
    cfg.single.id = None
    cfg.use_gpu = bool(use_gpu)
    cfg.overwrite = True
    cfg.seed = int(seed)
    cfg.progress_bar = False

    # Point the data section at the files generated by this notebook.
    paths_cfg = cfg.data.paths
    paths_cfg.reference_file = str(reference_csv)
    paths_cfg.DMS_input_folder = str(assay_dir)
    paths_cfg.output_folder = str(output_dir)
    feature_to_cfg_key = (
        ("embeddings", "embeddings_singles"),
        ("zero_shot", "zero_shot"),
        ("conditional_probs", "conditional_probs"),
        ("coords", "coords"),
    )
    for feature_name, cfg_key in feature_to_cfg_key:
        location = feature_paths.get(feature_name)
        if location is not None:
            paths_cfg[cfg_key] = str(location)

    cfg.data.standardize = bool(standardize)
    cfg.data.test_index = -1
    apply_overrides(cfg.optim, optimization_overrides)
    return cfg


def run_kermut_cv(cfg: DictConfig, dms_id: str, target_seq: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Cross-validate the GP on one assay, holding out each fold in turn.

    Folds are the distinct values of the ``cfg.cv_scheme`` column of the
    assay table returned by ``prepare_GP_inputs``.

    Args:
        cfg: Run configuration; reads ``cv_scheme``, ``seed``,
            ``data.standardize`` and ``optim.{lr,n_steps,progress_bar}`` here
            and is also passed on to the kermut helpers.
        dms_id: Assay identifier.
        target_seq: Reference sequence forwarded to ``prepare_GP_kwargs``.

    Returns:
        ``(predictions, metrics)``. ``predictions`` has one row per variant
        with ``mutant``, ``test_fold``, ``y``, ``y_pred``, ``y_var``,
        ``mutated_sequence``, the fold column, ``DMS_score`` (when present in
        the assay table) and ``DMS_id``. ``metrics`` has two rows per fold,
        one with ``partition == "test"`` and one with ``partition == "train"``.
    """
    df, y, x_toks, x_embed, x_zero_shot = prepare_GP_inputs(cfg, dms_id)
    gp_inputs = prepare_GP_kwargs(cfg, dms_id, target_seq)
    # Prediction table with NaN placeholders; presumably filled per fold by
    # kermut's `predict` -- confirm against its implementation.
    df_out = df[["mutant"]].copy()
    df_out = df_out.assign(fold=np.nan, y=np.nan, y_pred=np.nan, y_var=np.nan)
    metrics_records: List[Dict[str, object]] = []
    unique_folds = sorted(df[cfg.cv_scheme].unique())
    for test_fold in unique_folds:
        # Re-seed before every fold so each fit starts from the same RNG state.
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)
        # Boolean masks over all variants, passed on as plain lists of bools.
        test_mask = df[cfg.cv_scheme] == test_fold
        train_mask = ~test_mask
        train_idx = train_mask.tolist()
        test_idx = test_mask.tolist()
        y_train, y_test = split_inputs(train_idx, test_idx, y)
        if cfg.data.standardize:
            # NOTE(review): presumably scales both splits with training-set
            # statistics -- confirm in kermut.data.standardize.
            y_train, y_test = standardize(y_train, y_test)
        x_toks_train, x_toks_test = split_inputs(train_idx, test_idx, x_toks)
        x_embed_train, x_embed_test = split_inputs(train_idx, test_idx, x_embed)
        x_zero_train, x_zero_test = split_inputs(train_idx, test_idx, x_zero_shot)
        train_inputs = (x_toks_train, x_embed_train, x_zero_train)
        test_inputs = (x_toks_test, x_embed_test, x_zero_test)
        gp, likelihood = instantiate_gp(
            cfg=cfg,
            train_inputs=train_inputs,
            train_targets=y_train,
            gp_inputs=gp_inputs,
        )
        gp, likelihood = optimize_gp(
            gp=gp,
            likelihood=likelihood,
            train_inputs=train_inputs,
            train_targets=y_train,
            lr=cfg.optim.lr,
            n_steps=cfg.optim.n_steps,
            progress_bar=cfg.optim.progress_bar,
        )
        df_out = predict(
            gp=gp,
            likelihood=likelihood,
            test_inputs=test_inputs,
            test_targets=y_test,
            test_fold=int(test_fold),
            test_idx=test_idx,
            df_out=df_out,
        )
        # Held-out metrics are computed from the values now stored in df_out
        # for this fold's rows (test_idx acts as a boolean row mask).
        test_count = int(test_mask.sum())
        y_test_np = df_out.loc[test_idx, "y"].to_numpy(dtype=float)
        y_pred_test_np = df_out.loc[test_idx, "y_pred"].to_numpy(dtype=float)
        metrics_test = compute_regression_metrics(y_test_np, y_pred_test_np)
        metrics_test.update({"fold": int(test_fold), "partition": "test", "n": test_count})
        metrics_records.append(metrics_test)
        # In-sample fit: evaluate the model on its own training inputs,
        # skipping input slots that are None.
        # NOTE(review): the GP's train/eval mode at this point is set inside
        # optimize_gp/predict, not here -- confirm it is the intended one.
        train_inputs_filtered = tuple(x for x in train_inputs if x is not None)
        if train_inputs_filtered:
            with torch.no_grad():
                train_dist = likelihood(gp(*train_inputs_filtered))
                train_mean = train_dist.mean.detach().cpu().numpy()
            y_train_np = y_train.detach().cpu().numpy()
            metrics_train = compute_regression_metrics(y_train_np, train_mean)
        else:
            metrics_train = {"spearman": np.nan, "pearson": np.nan, "rmse": np.nan, "mae": np.nan}
        train_count = int(train_mask.sum())
        metrics_train.update({"fold": int(test_fold), "partition": "train", "n": train_count})
        metrics_records.append(metrics_train)
    # Attach identifying columns positionally (.values) from the assay table.
    df_out.rename(columns={"fold": "test_fold"}, inplace=True)
    df_out["mutated_sequence"] = df["mutated_sequence"].values
    df_out[cfg.cv_scheme] = df[cfg.cv_scheme].values
    if "DMS_score" in df.columns:
        df_out["DMS_score"] = df["DMS_score"].values
    df_out["DMS_id"] = dms_id
    metrics_df = pd.DataFrame(metrics_records)
    return df_out, metrics_df


def summarize_cv_metrics(metrics_df: pd.DataFrame) -> pd.DataFrame:
    """Average the per-fold metrics for each assay and partition.

    Returns a frame indexed by ``(DMS_id, partition)`` whose columns are a
    two-level index of metric name and statistic (``mean``, ``std``).
    """
    metric_names = ["spearman", "pearson", "rmse", "mae"]
    per_group = metrics_df.groupby(["DMS_id", "partition"])[metric_names]
    fold_stats = per_group.agg(["mean", "std"])
    return fold_stats.sort_index()


def aggregate_prediction_metrics(cv_results: Dict[str, Dict[str, pd.DataFrame]]) -> pd.DataFrame:
    """Score each assay once over all of its stored predictions.

    Rows without a ``y_pred`` value are ignored. The result has one row per
    assay with the regression metrics, ``DMS_id`` and the number of scored
    rows ``n``.
    """
    rows: List[Dict[str, object]] = []
    for dms_id in cv_results:
        scored = cv_results[dms_id]["predictions"].dropna(subset=["y_pred"])
        observed = scored["y"].to_numpy(dtype=float)
        predicted = scored["y_pred"].to_numpy(dtype=float)
        row = compute_regression_metrics(observed, predicted)
        row.update({"DMS_id": dms_id, "n": len(scored)})
        rows.append(row)
    return pd.DataFrame(rows)


def fit_full_model(cfg: DictConfig, dms_id: str, target_seq: str) -> pd.DataFrame:
    """Fit the GP on every variant of one assay and evaluate it on those same variants.

    Args:
        cfg: Run configuration; reads ``seed``, ``data.standardize`` and
            ``optim.{lr,n_steps,progress_bar}`` here and is also passed on to
            the kermut helpers.
        dms_id: Assay identifier.
        target_seq: Reference sequence forwarded to ``prepare_GP_kwargs``.

    Returns:
        A copy of the assay table with ``posterior_mean``,
        ``posterior_variance`` and ``DMS_id`` columns added. The model is
        evaluated on its own training inputs, so these are in-sample values;
        when ``cfg.data.standardize`` is true the model is trained on the
        output of ``standardize``, so the values are on that scale.
    """
    df, y, x_toks, x_embed, x_zero_shot = prepare_GP_inputs(cfg, dms_id)
    gp_inputs = prepare_GP_kwargs(cfg, dms_id, target_seq)
    # Seed before model construction so repeated fits are reproducible.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    if cfg.data.standardize:
        # Only the first return value is used; y is passed twice because
        # standardize takes a (train, test) pair.
        y_train, _ = standardize(y, y)
    else:
        y_train = y
    train_inputs = (x_toks, x_embed, x_zero_shot)
    gp, likelihood = instantiate_gp(
        cfg=cfg,
        train_inputs=train_inputs,
        train_targets=y_train,
        gp_inputs=gp_inputs,
    )
    gp, likelihood = optimize_gp(
        gp=gp,
        likelihood=likelihood,
        train_inputs=train_inputs,
        train_targets=y_train,
        lr=cfg.optim.lr,
        n_steps=cfg.optim.n_steps,
        progress_bar=cfg.optim.progress_bar,
    )
    # Input slots that are None are skipped when calling the model.
    # NOTE(review): the GP's train/eval mode here is whatever optimize_gp
    # leaves it in -- confirm it yields posterior (not prior) predictions.
    inputs_filtered = tuple(x for x in train_inputs if x is not None)
    with torch.no_grad():
        posterior = likelihood(gp(*inputs_filtered))
        posterior_mean = posterior.mean.detach().cpu().numpy()
        # Per-variant variance taken from the diagonal of the covariance matrix.
        posterior_var = posterior.covariance_matrix.diag().detach().cpu().numpy()
    df_out = df.copy()
    df_out["posterior_mean"] = posterior_mean
    df_out["posterior_variance"] = posterior_var
    df_out["DMS_id"] = dms_id
    return df_out


### Prepare Kermut-ready assay files

The next cell reads the TSV, filters single substitutions, aggregates the configured objectives, assigns three cross-validation folds, and saves one CSV per assay alongside a reference manifest.

In [None]:
# Fail early with a pointer to the configuration cell if the input is missing.
if not tsv_path.exists():
    raise FileNotFoundError(
        f"Input TSV not found at {tsv_path}. Update `tsv_path` in the configuration cell."
    )

df_raw = pd.read_csv(tsv_path, sep="\t")
print(f"Loaded {len(df_raw):,} rows from {tsv_path}")

# Validate the objective definitions, then check that the TSV provides the
# fixed columns plus one column per objective.
objectives_df = validate_objective_config(objectives)
required_columns = {"reference_key", "sequence", "mutations", "backbone_id", "pdb"}
required_columns |= set(objectives_df["column"])
ensure_columns(df_raw, required_columns)

# Keep single substitutions only; rejected rows are collected in skipped_rows.
df_single, skipped_rows = extract_single_mutants(df_raw, objectives_df)
print(f"Identified {len(df_single):,} candidate single mutants (skipped {len(skipped_rows):,} rows).")

if df_single.empty:
    raise ValueError("No single mutants available after filtering. Check the `mutations` column.")

if DROP_ROWS_WITH_MISSING_OBJECTIVES:
    df_single, dropped_missing = drop_rows_with_missing_objectives(df_single, objectives_df["column"].tolist())
    if not dropped_missing.empty:
        print(f"Dropped {len(dropped_missing):,} rows with missing objective values.")
else:
    # Keep the variable defined so the diagnostics cell can always read it.
    dropped_missing = pd.DataFrame(columns=df_single.columns)

# Collapse repeated measurements of the same variant: objective columns are
# averaged, descriptive columns take the first value seen.
group_cols = ["reference_key", "mutant", "mutation", "wt_aa", "mut_aa", "position"]
agg_dict = {col: "mean" for col in objectives_df["column"]}
agg_dict.update(
    {
        "sequence": "first",
        "mutated_sequence": "first",
        "backbone_id": "first",
        "pdb": "first",
    }
)
df_single = (
    df_single.groupby(group_cols, dropna=False, as_index=False)
    .agg(agg_dict)
    .reset_index(drop=True)
)

# NOTE: normalisation runs on the pooled table, so its statistics are computed
# across all reference keys together rather than separately per assay.
df_normalized, objective_stats = normalize_objectives(df_single, objectives_df, OBJECTIVE_NORMALIZATION)
df_scored = compute_weighted_scores(df_normalized, objectives_df)

# Split into one table per reference key and assign folds within each assay.
# Every assay uses the same RANDOM_SEED.
# NOTE(review): this groupby uses pandas' default handling of missing keys
# while the reference table below is built from all rows -- confirm that
# reference_key is never missing.
assay_tables: Dict[str, pd.DataFrame] = {}
fold_summaries: List[Dict[str, object]] = []
for dms_id, group_df in df_scored.groupby("reference_key"):
    df_with_folds = assign_cv_folds(group_df.reset_index(drop=True), FOLD_COLUMN, CV_FOLDS, RANDOM_SEED)
    assay_tables[dms_id] = df_with_folds
    fold_counts = df_with_folds[FOLD_COLUMN].value_counts().sort_index()
    fold_summary = {"reference_key": dms_id, "n_variants": int(len(df_with_folds))}
    for fold_id, count in fold_counts.items():
        fold_summary[f"fold_{fold_id}"] = int(count)
    fold_summaries.append(fold_summary)

processed_df = pd.concat(assay_tables.values(), ignore_index=True)
reference_df = build_reference_table(df_scored)

# Persist the effective objective configuration and normalisation constants
# next to the outputs so the aggregate score can be traced back later.
objective_config_path = workspace_root / "objectives.json"
objective_config_path.write_text(json.dumps(objectives_df.to_dict(orient="records"), indent=2))
normalization_stats_path = workspace_root / "objective_normalization.json"
normalization_stats_path.write_text(json.dumps(objective_stats, indent=2))

# One CSV per assay, named after its reference key, holding the variant,
# its sequence, the aggregate score, the fold id and the raw objectives.
for dms_id, table in assay_tables.items():
    save_cols = ["mutant", "mutated_sequence", "DMS_score", FOLD_COLUMN] + objectives_df["column"].tolist()
    output_path = assay_output_dir / f"{dms_id}.csv"
    table[save_cols].to_csv(output_path, index=False)

reference_df.to_csv(reference_manifest_path, index=False)

# Summaries displayed by the next cell.
fold_summary_df = pd.DataFrame(fold_summaries).sort_values("reference_key").reset_index(drop=True)
skip_summary = summarize_skip_table(skipped_rows)

print(f"Prepared {len(assay_tables)} assays. Files written to {assay_output_dir}.")
print(f"Reference manifest: {reference_manifest_path}")

processed_df.head()


In [None]:
# Number of variants in each fold, one row per assay.
print("Fold counts per assay:")
display(fold_summary_df)

# Counts of rejected TSV rows grouped by rejection reason.
print("\nSkipped or rejected rows:")
display(skip_summary)

# The locals() guard keeps this cell usable if the preparation cell has not
# defined dropped_missing in the current session.
if 'dropped_missing' in locals() and not dropped_missing.empty:
    print("\nRows dropped because of missing objective values (showing first 5 rows):")
    display(dropped_missing.head())


### Feature preparation checklist

The default `kermut` kernel expects:

- **ESM2 embeddings** stored under `data/embeddings/substitutions_singles/ESM2/<DMS_id>.h5`.
- **ProteinMPNN conditional probabilities** saved as `<DMS_id>.npy` files inside `data/conditional_probs/ProteinMPNN/`.
- **3D coordinates** (`.npy`) matching the same `<DMS_id>` in `data/structures/coords/` when distance comparisons are enabled.
- Optional **zero-shot scores** (`data/zero_shot_fitness_predictions/<model>/<DMS_id>.csv`) if `use_zero_shot` is true.

Generate these artefacts with the CLI utilities in `kermut/cmdline/preprocess_data`, for example:

```bash
python -m kermut.cmdline.preprocess_data.extract_esm2_embeddings data.paths.reference_file=<your_reference.csv> data.paths.DMS_input_folder=<assay_dir>
python -m kermut.cmdline.preprocess_data.extract_ProteinMPNN_probs data.paths.reference_file=<your_reference.csv> data.paths.DMS_input_folder=<assay_dir>
python -m kermut.cmdline.preprocess_data.extract_3d_coords data.paths.reference_file=<your_reference.csv>
```

Adjust the overrides so that the scripts pick up the manifest and assay directory generated above. Skip commands for features that you disable via `kernel_config_name` or `kernel_overrides`.

### Compose Hydra config and verify feature availability

Build a benchmark configuration programmatically so we can pass it to the reusable training utilities without spawning a new Hydra process. The feature availability table helps catch missing embeddings or structure files before training.

In [None]:
# Nothing to train on without at least one assay in the manifest.
if reference_df.empty:
    raise ValueError("Reference manifest is empty. Verify that the TSV produced at least one assay.")

# Shared configuration; the CV loop clones it and sets the assay id per assay.
base_cfg = build_base_cfg(
    kernel_config_name=kernel_config_name,
    kernel_overrides=kernel_overrides,
    optimization_overrides=optimization_overrides,
    feature_paths=feature_paths,
    reference_csv=reference_manifest_path,
    assay_dir=assay_output_dir,
    output_dir=workspace_root / "outputs",
    fold_column=FOLD_COLUMN,
    use_gpu=USE_GPU,
    seed=RANDOM_SEED,
    standardize=DATA_STANDARDIZE,
)

# Reads `name` from the loaded kernel YAML.
print(f"Kernel configuration: {base_cfg.kernel.name}")
# A False entry in the table marks a feature file that is required by the
# kernel settings but was not found on disk.
feature_status = inspect_feature_availability(base_cfg, reference_df["DMS_id"].tolist())
display(feature_status)


### Run three-fold cross-validation

For each assay the loop below trains Kermut on two folds, evaluates on the held-out fold, and records the metrics. Per-fold predictions and metrics are written to `cv_predictions/` inside the workspace.

In [None]:
# cv_results keeps the per-assay frames in memory for the summary cell;
# metrics_tables collects the per-fold metrics tagged with their assay id.
cv_results: Dict[str, Dict[str, pd.DataFrame]] = {}
metrics_tables: List[pd.DataFrame] = []

for row in reference_df.itertuples(index=False):
    dms_id = row.DMS_id
    target_seq = row.target_seq
    # Work on a copy so the per-assay id never leaks into base_cfg.
    cfg = clone_config(base_cfg)
    cfg.single.id = dms_id
    # NOTE(review): the message says "3-fold" regardless of CV_FOLDS.
    print(f"Running 3-fold CV for {dms_id} (length {len(target_seq)})")
    predictions_df, metrics_df = run_kermut_cv(cfg, dms_id, target_seq)
    predictions_path = cv_predictions_dir / f"{dms_id}_cv_predictions.csv"
    metrics_path = cv_predictions_dir / f"{dms_id}_cv_metrics.csv"
    predictions_df.to_csv(predictions_path, index=False)
    metrics_df.to_csv(metrics_path, index=False)
    payload = {"predictions": predictions_df, "metrics": metrics_df}
    cv_results[dms_id] = payload
    # run_kermut_cv's metrics carry no assay id, so add one before pooling.
    metrics_with_id = metrics_df.copy()
    metrics_with_id["DMS_id"] = dms_id
    metrics_tables.append(metrics_with_id)

# Mean and standard deviation across folds for each assay and partition.
all_metrics_df = pd.concat(metrics_tables, ignore_index=True)
metrics_summary = summarize_cv_metrics(all_metrics_df)
metrics_summary


### Summarise train/test performance across folds

Flatten the multi-level columns of the cross-validation summary, save both the per-fold metrics and the flattened summary for downstream reporting, and compute overall metrics for each assay from its pooled held-out predictions.

In [None]:
# Flatten the two-level (metric, statistic) columns into names such as
# "spearman_mean" so the summary can be written as a plain CSV.
metrics_summary_flat = metrics_summary.copy()
metrics_summary_flat.columns = [
    f"{metric}_{stat}" for metric, stat in metrics_summary_flat.columns.to_flat_index()
]
metrics_summary_flat = metrics_summary_flat.reset_index()
display(metrics_summary_flat)

# Per-fold rows and the across-fold summary.
all_metrics_df.to_csv(workspace_root / "cv_metrics_folds.csv", index=False)
metrics_summary_flat.to_csv(workspace_root / "cv_metrics_summary.csv", index=False)

# One row per assay, scored over all rows that have a prediction rather than
# averaged across folds.
overall_prediction_metrics = aggregate_prediction_metrics(cv_results)
display(overall_prediction_metrics)
overall_prediction_metrics.to_csv(
    workspace_root / "cv_metrics_overall_predictions.csv", index=False
)


### Optional: fit final models on the full dataset

After inspecting cross-validation performance you can refit Kermut on the entire assay to obtain posterior means/variances for every single mutant. Uncomment and run the next cell if you want those predictions.

In [None]:
# Optional step, disabled by default: uncomment the lines below to refit each
# assay on all of its variants and save the resulting table.
# NOTE: fit_full_model evaluates the GP on the same variants it was trained
# on, so the saved posterior values are in-sample estimates.
#
# final_predictions_dir = workspace_root / "final_model_predictions"
# final_predictions_dir.mkdir(parents=True, exist_ok=True)
#
# for row in reference_df.itertuples(index=False):
#     dms_id = row.DMS_id
#     cfg = clone_config(base_cfg)
#     cfg.single.id = dms_id
#     print(f"Fitting final GP on all data for {dms_id}")
#     full_df = fit_full_model(cfg, dms_id, row.target_seq)
#     full_df.to_csv(final_predictions_dir / f"{dms_id}_posterior.csv", index=False)
# print("Saved posterior predictions for each assay.")
