# 2. Predictions

In this step, we take our 524 kb one-hot inputs from FlashZoi preparation and run them through the pretrained Borzoi model to predict CAGE signal in **skeletal muscle**. We:

1. Load our FlashZoi metadata (`flashzoi_meta.tsv`), which tells us where each `.npy` lives and its measured GTEx TPM.
2. Pull down Borzoi’s config/checkpoint so that we can ask it to return **all** output tracks (not just the central bin).
3. Find the indices of the CAGE skeletal-muscle tracks in the FANTOM5 target list: https://github.com/johahi/borzoi-pytorch/blob/main/borzoi_pytorch/precomputed/targets.txt or https://github.com/calico/borzoi/blob/main/examples/targets_gtex.txt
4. For each gene:
   - Load its one-hot array.
   - Run a forward pass.
   - Slice out that one skeletal-muscle track (length 16 352 bins).
   - Save both the full per-bin prediction and its mean.
5. Write out a combined `borzoi_preds_meta.tsv` so we can join predictions back to expression.

In [9]:
#!/usr/bin/env python3
import argparse, os, time
from pathlib import Path
from typing import Sequence

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from borzoi_pytorch import Borzoi
from flashzoi_helpers import load_flashzoi_models

BASE2ROW = {b: i for i, b in enumerate("ACGT")}


# ─── Flashzoi Forward ──────────────────────────────────────────────────────
def fwd(models: Sequence[Borzoi], x: torch.Tensor) -> torch.Tensor:
    y_sum = None
    with torch.autocast(x.device.type, dtype=torch.float16, enabled=x.device.type == "cuda"):
        for m in models:
            y = m(x)
            y = y["human"] if isinstance(y, dict) else y
            c = y.shape[-1] // 2
            sig = y[:, 1, c - 16:c + 16].mean(1)
            y_sum = sig if y_sum is None else y_sum + sig
    return y_sum / len(models)


# ─── Per-Gene Scoring ──────────────────────────────────────────────────────
def score_gene(gene_dir: Path, onehot_dir: Path, pred_dir: Path,
               out_dir: Path, models: Sequence[Borzoi],
               batch: int, device: torch.device):
    """Score every not-yet-scored SNP of one gene and checkpoint the results.

    Reads ``<gene_dir>/<gene>_variants.tsv`` (``<gene>`` is the directory
    name), upper-cases the column names, keeps rows whose REF and ALT are a
    single character, and for each row without a numeric ``DELTA``:

    * builds the ALT sequence by clearing the REF row and setting the ALT
      row at column ``POS0`` of the gene's reference one-hot array,
    * computes ``DELTA = fwd(models, alt_batch) - mean_ref``,
    * computes ``VAR_I = DELTA**2 * 2 * AF * (1 - AF)`` (NaN when there is
      no ``AF`` column or the value is not numeric).

    Rows whose REF or ALT is not one of A/C/G/T (case-insensitive), or whose
    ``POS0`` is missing or outside the one-hot array, keep NaN in both
    columns instead of being sent through the model.

    The whole table is rewritten to ``<out_dir>/<gene>/<gene>_variants.tsv``
    after every batch (temporary file + ``os.replace``), and one timing line
    is appended to ``logs/variant_benchmark.tsv`` when the gene is finished.
    Input problems are reported with ``print`` and end the call early;
    nothing is raised for them.

    Args:
        gene_dir: Directory of one gene; its name is used as the gene id.
        onehot_dir: Directory holding ``<gene>.npy`` reference one-hot arrays.
        pred_dir: Directory holding ``<gene>/mean_ref.npy``.
        out_dir: Root directory for the scored per-gene tables.
        models: Models forwarded to ``fwd``.
        batch: Number of variants scored per forward pass.
        device: Device on which the batches are built and scored.
    """
    gene = gene_dir.name
    start_time = time.time()

    in_tsv = gene_dir / f"{gene}_variants.tsv"
    out_gene_dir = out_dir / gene
    out_gene_dir.mkdir(parents=True, exist_ok=True)
    tsv = out_gene_dir / f"{gene}_variants.tsv"

    if not in_tsv.exists():
        print(f"✘ {gene}: missing input TSV")
        return

    df = pd.read_csv(in_tsv, sep="\t")
    df.rename(columns={c: c.upper() for c in df.columns}, inplace=True)

    if not {"REF", "ALT", "POS0"}.issubset(df.columns):
        print(f"✘ {gene}: missing required columns")
        return

    # Keep single-nucleotide substitutions only; the fresh RangeIndex means
    # labels and positions coincide from here on.
    snp_mask = (df["REF"].str.len() == 1) & (df["ALT"].str.len() == 1)
    df = df[snp_mask].reset_index(drop=True)

    if df.empty:
        print(f"✘ {gene}: no SNP records")
        return

    # Create the result columns if absent; anything non-numeric becomes NaN
    # and is therefore treated as "still to do".
    for col in ("DELTA", "VAR_I"):
        df[col] = pd.to_numeric(df.get(col, pd.Series(np.nan, index=df.index)),
                                errors="coerce")

    todo = np.flatnonzero(df["DELTA"].isna().to_numpy()).tolist()

    if not todo:
        print(f"✔ {gene}: already done")
        return

    try:
        ref_np = np.load(onehot_dir / f"{gene}.npy").astype(np.float16)
        ref_t = torch.from_numpy(ref_np).to(device)
        ref_avg = float(np.load(pred_dir / gene / "mean_ref.npy"))
    except Exception as e:
        print(f"✘ {gene}: missing onehot or mean_ref — {e}")
        return

    # Positions are indexed along the last axis of the one-hot array and base
    # rows along the axis before it (see the indexing below).
    # NOTE(review): assumes the base rows follow BASE2ROW's A,C,G,T order —
    # confirm against the one-hot preparation step.
    seq_len = ref_t.shape[-1]

    print(f"▶ {gene}: {len(todo)} variants to process")

    for s in tqdm(range(0, len(todo), batch), leave=False, desc=f"{gene} chunks"):
        idx = todo[s:s + batch]
        sub = df.loc[idx]

        # Validate on the host first so that no bad base or position ever
        # reaches the tensor indexing (an out-of-range index on CUDA is a
        # device-side assert, and a negative one silently wraps around).
        alt_rows = np.array([BASE2ROW.get(str(a).upper(), -1) for a in sub["ALT"]],
                            dtype=np.int64)
        ref_rows = np.array([BASE2ROW.get(str(r).upper(), -1) for r in sub["REF"]],
                            dtype=np.int64)
        pos = pd.to_numeric(sub["POS0"], errors="coerce").to_numpy(dtype=float)
        # NaN positions compare False and are therefore rejected as well.
        valid_mask = (alt_rows >= 0) & (ref_rows >= 0) & (pos >= 0) & (pos < seq_len)

        idx_arr = np.asarray(idx)
        valid_idx = idx_arr[valid_mask].tolist()
        invalid_idx = idx_arr[~valid_mask].tolist()

        if valid_idx:
            n_valid = len(valid_idx)
            rows = torch.arange(n_valid, device=device)
            pos_t = torch.as_tensor(pos[valid_mask].astype(np.int64), device=device)
            ref_sel = torch.as_tensor(ref_rows[valid_mask], device=device)
            alt_sel = torch.as_tensor(alt_rows[valid_mask], device=device)

            # One copy of the reference per variant, then swap REF -> ALT.
            xb = ref_t.expand(n_valid, -1, -1).clone()
            xb[rows, ref_sel, pos_t] = 0.
            xb[rows, alt_sel, pos_t] = 1.

            with torch.no_grad():
                pred = fwd(models, xb)
            del xb

            # Subtract in float32: a float16 array minus a Python float stays
            # float16 in NumPy, which would round both the prediction
            # difference and mean_ref to half precision.
            delta_valid = (pred.float().cpu().numpy() - ref_avg).astype(np.float32)
            df.loc[valid_idx, "DELTA"] = delta_valid

            if "AF" in sub.columns:
                af = pd.to_numeric(sub["AF"], errors="coerce").to_numpy(dtype=float)
            else:
                af = np.full(len(sub), np.nan, dtype=float)
            af_valid = af[valid_mask]

            df.loc[valid_idx, "VAR_I"] = (delta_valid ** 2) * (2 * af_valid * (1 - af_valid))

        df.loc[invalid_idx, "DELTA"] = np.nan
        df.loc[invalid_idx, "VAR_I"] = np.nan

        # Checkpoint after every batch; os.replace swaps the file in one step
        # so an interrupted run never leaves a half-written table behind.
        tmp = out_gene_dir / f"{gene}_variants.working"
        df.to_csv(tmp, sep="\t", index=False)
        os.replace(tmp, tsv)
        if device.type == "cuda":
            torch.cuda.empty_cache()

    elapsed = time.time() - start_time
    vps = len(todo) / elapsed if elapsed > 0 else 0
    print(f"✓ {gene}: done in {elapsed:.2f}s | {len(todo)} variants | {vps:.1f} var/sec")

    log_path = Path("logs/variant_benchmark.tsv")
    log_path.parent.mkdir(exist_ok=True)
    with open(log_path, "a") as f:
        f.write(f"{gene}\t{len(todo)}\t{elapsed:.2f}\t{vps:.2f}\n")


def main():
    """Command-line entry point: score one gene or every gene directory.

    Parses the arguments, picks the device (any ``cuda``/``cuda:N`` request
    falls back to CPU when CUDA is not available), loads the models once via
    ``load_flashzoi_models(folds, device)`` and calls ``score_gene`` for each
    selected gene directory in sorted order.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--variants-dir", required=True, type=Path,
                    help="Directory with one sub-directory per gene")
    ap.add_argument("--onehot-dir", required=True, type=Path,
                    help="Directory with <gene>.npy reference one-hot arrays")
    ap.add_argument("--pred-dir", required=True, type=Path,
                    help="Directory with <gene>/mean_ref.npy")
    ap.add_argument("--out-dir", required=True, type=Path,
                    help="Directory that receives the scored per-gene tables")
    ap.add_argument("--folds", default=4, type=int,
                    help="Forwarded to load_flashzoi_models")
    ap.add_argument("--batch", default=64, type=int,
                    help="Variants per forward pass")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--gene", help="Optional single gene to run")

    args = ap.parse_args()

    # Fail before the (slow) model load: a batch size below 1 would only
    # surface later as a range() error inside score_gene.
    if args.batch < 1:
        ap.error("--batch must be a positive integer")

    # Treat "cuda" and "cuda:N" alike when deciding on the CPU fallback.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print(f"⚠ CUDA requested ({args.device}) but not available — using CPU")
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    models = load_flashzoi_models(args.folds, device)

    if args.gene:
        dirs = [args.variants_dir / args.gene]
    else:
        dirs = sorted(d for d in args.variants_dir.iterdir() if d.is_dir())

    for d in dirs:
        score_gene(d, args.onehot_dir, args.pred_dir,
                   args.out_dir, models, args.batch, device)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()

# 2. Predict Variant ΔS and Compute Genetic Variance

For each common gnomAD SNP in our ±262 kb windows, we:

1. Load its reference one-hot and the precomputed **mean_ref** prediction.
2. Mutate the one-hot at the SNP position (zero out the ref base, set the alt base).
3. Batch up a handful of these ALT one-hots and run them through Borzoi.
4. Compute Δ = (mean signal on ALT) – **mean_ref**, then
   $$
   \mathrm{var}_i \;=\; \Delta^2 \times 2\,\mathrm{AF}\,(1-\mathrm{AF})
   $$
   which is the per-variant contribution to expected genetic variance.
5. Write the `DELTA` and `VAR_I` columns (the script upper-cases all column names) to each gene’s `*_variants.tsv` under the output directory.

## 2.1 Iterate genes & predict deltas

## 3. Inspect data

## 3.1 Reference vs Expression

In [None]:
# Combined prediction metadata table (tab-separated) loaded for the
# reference-vs-expression plot.
PRED_META = "../data/output/dataset1/borzoi_preds/borzoi_preds_meta.tsv"
# Root directory that is globbed for per-gene "<gene>/*_variants.tsv" tables.
VAR_DIR   = "../data/intermediate/dataset1/flashzoi_inputs/variants"

In [None]:
# Scatter of the reference prediction ("mean_ref") against measured
# expression ("expr_value") on log-log axes, titled with both correlations.
# NOTE(review): plt, pearsonr and spearmanr are not imported in the visible
# cells — confirm an earlier cell provides them.
preds = pd.read_csv(PRED_META, sep="\t")
plt.figure(figsize=(5,5))
plt.scatter(preds["expr_value"], preds["mean_ref"], alpha=0.7)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("GTEx muscle TPM")
plt.ylabel("Borzoi mean CAGE signal")
# Both correlations are computed on the raw values, not on the
# log-transformed values that the axes display.
pearson_r, pearson_p = pearsonr(preds["expr_value"], preds["mean_ref"])
spearman_r, spearman_p = spearmanr(preds["expr_value"], preds["mean_ref"])
plt.title(f"rₚ={pearson_r:.2f}, rₛ={spearman_r:.2f}")
plt.tight_layout()
plt.show()

## 3.2 Δ distribution across all variants

In [None]:
# Gather the per-variant Δ column from every gene's variants table.
# Headers are lower-cased on read because the scoring script upper-cases all
# column names before writing (the column is stored as "DELTA"); lower-casing
# accepts either spelling.
all_deltas = []
for tsv in sorted(glob.glob(os.path.join(VAR_DIR, "*", "*_variants.tsv"))):
    df = pd.read_csv(tsv, sep="\t")
    df.columns = df.columns.str.lower()
    if "delta" not in df.columns:
        # Gene not scored yet — nothing to collect from this file.
        continue
    all_deltas.append(df["delta"].dropna().astype(float))

# pd.concat raises on an empty list, so fall back to an empty Series when no
# scored file was found.
if all_deltas:
    all_deltas = pd.concat(all_deltas, ignore_index=True)
else:
    all_deltas = pd.Series(dtype=float)

plt.figure(figsize=(5,4))
plt.hist(all_deltas, bins=100)
plt.xlabel("Δ (Borzoi ALT − REF)")
plt.ylabel("Count")
plt.title("Distribution of predicted Δ across all variants")
plt.tight_layout()
plt.show()

## 3.3 per‐gene expected genetic variance

In [None]:
# Sum the per-variant variance contributions into one value per gene.
# Headers are lower-cased on read because the scoring script upper-cases all
# column names before writing (the column is stored as "VAR_I"); lower-casing
# accepts either spelling.
gene_var = {}
for tsv in sorted(glob.glob(os.path.join(VAR_DIR, "*", "*_variants.tsv"))):
    gene = os.path.basename(os.path.dirname(tsv))
    df = pd.read_csv(tsv, sep="\t")
    df.columns = df.columns.str.lower()
    if "var_i" not in df.columns:
        # Gene not scored yet — leave it out rather than fail on the lookup.
        continue
    # Series.sum() skips NaN, so variants without a score or AF add nothing.
    gene_var[gene] = df["var_i"].sum()

var_df = pd.DataFrame.from_dict(gene_var, orient="index", columns=["exp_genetic_var"])
var_df.index.name = "gene"
var_df = var_df.reset_index()

# merge with expression
merged = preds.merge(var_df, on="gene")
merged.head()

## 3.4 scatter expected genetic variance vs. expression

In [None]:
# Scatter of the summed per-gene variance ("exp_genetic_var") against
# expression ("expr_value") from the merged table, on log-log axes.
plt.figure(figsize=(5,5))
plt.scatter(merged["expr_value"], merged["exp_genetic_var"], alpha=0.7)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("GTEx muscle TPM")
plt.ylabel("Expected genetic variance ∑var_i")
plt.title("Genetic variance vs expression")
plt.tight_layout()
plt.show()

## 3.5 Inspecting one variant file

In [None]:
# Show the first rows of one variants table. The matches are sorted so the
# same file is picked on every run, and an empty match list is reported with
# a clear error instead of a bare IndexError.
variant_files = sorted(glob.glob(os.path.join(VAR_DIR, "*", "*_variants.tsv")))
if not variant_files:
    raise FileNotFoundError(f"No *_variants.tsv files found under {VAR_DIR}")
example = variant_files[0]
print("Example file:", example)
pd.read_csv(example, sep="\t").head()