## Define the inputs and hyperparameters

In [1]:
# run_ewc_adapt_then_eval.py
import os, sys, json, joblib
from pathlib import Path
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

# ---------------------------------------------------------------------
# Paths / config
# ---------------------------------------------------------------------
PROJECT_DIR = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
DEV_DIR     = os.path.join(PROJECT_DIR, "dev/transformer")
sys.path.append(DEV_DIR)

from transformer import MultiomicTransformer
from transformer_dataset import MultiomicTransformerDataset
from transformer_training import prepare_dataloader
import ewc_utils  # your separate module

TRAINED_MODEL_SAMPLE_NAME = "mESC"
EVAL_SAMPLE_NAME          = "mESC_holdout"
CHROM_ID                  = "chr1"

OUTPUT_DIR                = os.path.join(PROJECT_DIR, "output/transformer_testing_output")
TRAINED_MODEL_DIR         = os.path.join(OUTPUT_DIR, "model_0.77_corr")
COMMON_DATA_DIR           = os.path.join(DEV_DIR, "transformer_data", "common")

TRAINED_MODEL_DATASET_DIR = os.path.join(DEV_DIR, f"transformer_data/{TRAINED_MODEL_SAMPLE_NAME}")
EVAL_DATASET_DIR          = os.path.join(DEV_DIR, f"transformer_data/{EVAL_SAMPLE_NAME}")

CAL_SPLIT_FRAC = 0.5
CAL_ALPHAS     = [0.1, 0.3, 1.0, 3.0, 10.0]
BATCH_SIZE_FALLBACK = 64
SEED = 42

## Helper Functions

In [2]:
def set_seed(seed=SEED):
    import random
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
set_seed()

def inverse_transform(X, mean, scale):
    X = np.asarray(X, dtype=np.float32)
    if scale is not None: X = X * scale
    if mean  is not None: X = X + mean
    return X

def run_model(model, loader, device, zscore_tf=True):
    """Return (preds, true) where preds are in TRAIN z-space; true in EVAL z-space."""
    preds_all, true_all = [], []
    model.eval()
    with torch.no_grad():
        for atac_wins, tf_tensor, tg_true, bias, tf_ids, tg_ids in loader:
            atac_wins = atac_wins.to(device)
            tf_tensor = tf_tensor.to(device)
            tg_true   = tg_true.to(device)
            bias      = bias.to(device)
            tf_ids    = tf_ids.to(device)
            tg_ids    = tg_ids.to(device)

            if zscore_tf:
                mu = tf_tensor.mean(dim=1, keepdim=True)
                sd = tf_tensor.std(dim=1, keepdim=True).clamp_min(1e-6)
                tf_tensor = (tf_tensor - mu) / sd

            preds = model(atac_wins, tf_tensor, tf_ids=tf_ids, tg_ids=tg_ids, bias=bias)
            preds_all.append(preds.cpu().numpy())
            true_all.append(tg_true.cpu().numpy())
    return np.vstack(preds_all), np.vstack(true_all)

def build_overlap_and_spaces(preds, true, dataset, train_dataset_dir, chrom_id):
    """Align gene order, inverse-transform to raw spaces, and map truth to TRAIN z-space."""
    train_scaler = joblib.load(Path(train_dataset_dir) / f"{chrom_id}/tg_scaler_{chrom_id}.pkl")
    eval_scaler  = dataset.scaler

    with open(Path(train_dataset_dir) / f"{chrom_id}/tg_names_{chrom_id}.json") as f:
        train_tg_names = json.load(f)
    
    # Find the index for each gene name in the training and target datasets
    train_name_to_idx = {g:i for i,g in enumerate(train_tg_names)}
    eval_name_to_idx  = {g:i for i,g in enumerate(dataset.tg_names)}

    # Only use target genes that are present in both datasets
    overlap_genes = [g for g in dataset.tg_names if g in train_name_to_idx]
    mask_eval = np.array([g in train_name_to_idx for g in dataset.tg_names], dtype=bool)
    train_idx = np.array([train_name_to_idx[g] for g in overlap_genes])
    eval_idx  = np.array([eval_name_to_idx[g]  for g in overlap_genes])

    preds_ov = preds[:, mask_eval]  # TRAIN z-space
    true_ov  = true[:,  mask_eval]  # EVAL  z-space

    # inverse transform the predictions and true values using their own scaler for each gene
    preds_raw = inverse_transform(preds_ov, train_scaler.mean_[train_idx], train_scaler.scale_[train_idx])
    true_raw  = inverse_transform(true_ov,  eval_scaler.mean_[eval_idx],   eval_scaler.scale_[eval_idx])

    # Re-standardize the target raw truth using the training scaler
    true_in_train_z = (true_raw - train_scaler.mean_[train_idx]) / train_scaler.scale_[train_idx]

    return {
        "overlap_genes": overlap_genes,
        "preds_ov": preds_ov,
        "true_in_train_z": true_in_train_z,
        "preds_raw": preds_raw,
        "true_raw": true_raw,
        "train_scaler": train_scaler,
        "train_idx": train_idx
    }

def split_for_calibration(X, Y, frac=CAL_SPLIT_FRAC, seed=SEED):
    n = X.shape[0]
    rng = np.random.RandomState(seed)
    idx = rng.permutation(n)
    k  = int(np.floor(frac * n))
    return (X[idx[:k]], Y[idx[:k]]), (X[idx[k:]], Y[idx[k:]])

def fit_ridge_calibrator(X_cal, Y_cal, alphas=CAL_ALPHAS):

    best_alpha, best_r = None, -np.inf
    kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
    for a in alphas:
        rs = []
        for tr, va in kf.split(X_cal):
            rr = Ridge(alpha=a, fit_intercept=True)
            rr.fit(X_cal[tr], Y_cal[tr])
            yhat = rr.predict(X_cal[va])
            rs.append(np.corrcoef(Y_cal[va].ravel(), yhat.ravel())[0,1])
        r_mean = float(np.mean(rs))
        if r_mean > best_r:
            best_r, best_alpha = r_mean, a
    ridge = Ridge(alpha=best_alpha, fit_intercept=True)
    ridge.fit(X_cal, Y_cal)
    return ridge, best_alpha, best_r

def metrics_block(y, yhat):
    r_p = pearsonr(y.ravel(), yhat.ravel())[0]
    r_s = spearmanr(y.ravel(), yhat.ravel()).correlation
    mae = np.mean(np.abs(y - yhat))
    return dict(pearson=float(r_p), spearman=float(r_s), mae=float(mae))

def scatter_plot(y, yhat, title, out_png, max_points=5000, seed=SEED):
    n = min(max_points, y.shape[0])
    idx = np.random.RandomState(seed).choice(y.shape[0], n, replace=False)
    plt.figure(figsize=(6.5,6.5))
    plt.scatter(y[idx].ravel(), yhat[idx].ravel(), alpha=0.25, s=12)
    lims = [min(y.min(), yhat.min()), max(y.max(), yhat.max())]
    plt.plot(lims, lims, 'r--', linewidth=1)
    rp = pearsonr(y.ravel(), yhat.ravel())[0]
    plt.title(f"{title}\nPearson r = {rp:.2f}")
    plt.xlabel("Actual"); plt.ylabel("Predicted")
    plt.tight_layout(); plt.savefig(out_png, dpi=150); plt.close()


## Data Loading Functions

In [3]:
def get_vocab_sizes(common_dir, chrom_id):
    with open(os.path.join(common_dir, "tf_vocab.json")) as f: tf_vocab = json.load(f)
    with open(os.path.join(common_dir, "tg_vocab.json")) as f: tg_vocab = json.load(f)  # single file for all CHRs
    return len(tf_vocab), len(tg_vocab)

def get_loaders(data_dir, chrom_id, common_dir, batch):
    ds = MultiomicTransformerDataset(
        data_dir=data_dir,
        chrom_id=chrom_id,
        tf_vocab_path=os.path.join(common_dir, "tf_vocab.json"),
        tg_vocab_path=os.path.join(common_dir, "tg_vocab.json"),
    )
    return ds, prepare_dataloader(ds, batch_size=batch, world_size=1, rank=0)

def build_model(run_params, tf_vocab_size, tg_vocab_size, ckpt_path, device):
    d_model   = run_params["d_model"]
    num_heads = run_params["Attention Heads"]
    num_layers= run_params["Model Layers"]
    d_ff      = run_params["d_feedforward"]
    dropout   = run_params["Dropout"]

    state_dict = torch.load(ckpt_path, map_location=device)
    use_shortcut = ("shortcut_scale" in state_dict)

    model = MultiomicTransformer(
        d_model=d_model, num_heads=num_heads, num_layers=num_layers,
        d_ff=d_ff, dropout=dropout,
        tf_vocab_size=tf_vocab_size, tg_vocab_size=tg_vocab_size,
        use_shortcut=use_shortcut
    ).to(device)
    model.load_state_dict(state_dict, strict=True)
    return model

## Elastic Weight Consolidation

In [4]:
def run_ewc_adaptation(model, device, source_loader, target_train_loader,
                       include_strong, include_weak, lambda_strong=1000.0, lambda_weak=50.0,
                       epochs=5, lr=1e-4, weight_decay=1e-4):
    # Fisher on source (protect knowledge)
    fisher_diag = ewc_utils.compute_fisher_diag(model, source_loader, device, n_batches=100, loss_fn="mse")
    ref_params  = ewc_utils.clone_params(model)

    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=weight_decay)

    model.train()
    for epoch in range(1, epochs+1):
        for atac_wins, tf_tensor, tg_true, bias, tf_ids, tg_ids in target_train_loader:
            atac_wins, tf_tensor, tg_true, bias = [x.to(device) for x in (atac_wins, tf_tensor, tg_true, bias)]
            tf_ids, tg_ids = tf_ids.to(device), tg_ids.to(device)

            # per-cell TF z-score
            mu = tf_tensor.mean(dim=1, keepdim=True)
            sd = tf_tensor.std(dim=1, keepdim=True).clamp_min(1e-6)
            tf_norm = (tf_tensor - mu) / sd

            preds = model(atac_wins, tf_norm, tf_ids=tf_ids, tg_ids=tg_ids, bias=bias)
            data_loss = F.mse_loss(preds, tg_true)

            pen_head   = ewc_utils.ewc_penalty(model, fisher_diag, ref_params, lambda_weak,   include=include_weak)
            pen_strong = ewc_utils.ewc_penalty(model, fisher_diag, ref_params, lambda_strong, include=include_strong)

            loss = data_loss + pen_head + pen_strong
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
    return model, fisher_diag, ref_params

## Main Pipeline

### Load training model, dataset, and run parameters

In [16]:
# Run params + vocab sizes
with open(os.path.join(TRAINED_MODEL_DIR, "run_parameters.json")) as f:
    run_params = json.load(f)
batch = run_params.get("Batch Size", BATCH_SIZE_FALLBACK)

tf_vocab_size, tg_vocab_size = get_vocab_sizes(COMMON_DATA_DIR, CHROM_ID)
source_ds, (src_train_loader, _, src_test_loader) = get_loaders(TRAINED_MODEL_DATASET_DIR, CHROM_ID, COMMON_DATA_DIR, batch)

# Build the trained model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ckpt_path = os.path.join(TRAINED_MODEL_DIR, "checkpoint.pt")
model = build_model(run_params, tf_vocab_size, tg_vocab_size, ckpt_path, device)

### Load the target dataset

In [6]:
# Load datasets/loaders
target_ds, (tgt_train_loader, _, tgt_test_loader) = get_loaders(EVAL_DATASET_DIR, CHROM_ID, COMMON_DATA_DIR, batch)


### Baseline evaluation (no scaler alignment)

In [7]:
preds, true = run_model(model, tgt_test_loader, device, zscore_tf=True)
spaces = build_overlap_and_spaces(preds, true, target_ds, TRAINED_MODEL_DATASET_DIR, CHROM_ID)

overlap_genes   = spaces["overlap_genes"]
preds_ov        = spaces["preds_ov"]
true_in_train_z = spaces["true_in_train_z"]
preds_raw       = spaces["preds_raw"]
true_raw        = spaces["true_raw"]
train_scaler    = spaces["train_scaler"]
train_idx       = spaces["train_idx"]

base_train = metrics_block(true_in_train_z, preds_ov)
base_raw   = metrics_block(true_raw,       preds_raw)

### Align scaler and fit ridge calibrator

In [8]:
(X_cal, Y_cal), (X_test, Y_test) = split_for_calibration(preds_ov, true_in_train_z, frac=CAL_SPLIT_FRAC, seed=SEED)
ridge, best_alpha, cv_r = fit_ridge_calibrator(X_cal, Y_cal, alphas=CAL_ALPHAS)

preds_test_cal = ridge.predict(X_test)  # TRAIN z-space
cal_train = metrics_block(Y_test, preds_test_cal)

preds_test_cal_raw = inverse_transform(preds_test_cal, train_scaler.mean_[train_idx], train_scaler.scale_[train_idx])
Y_test_raw         = inverse_transform(Y_test,         train_scaler.mean_[train_idx], train_scaler.scale_[train_idx])
cal_raw = metrics_block(Y_test_raw, preds_test_cal_raw)

out_dir = os.path.join(OUTPUT_DIR, f"infer_{EVAL_SAMPLE_NAME}_{CHROM_ID}")
os.makedirs(out_dir, exist_ok=True)
joblib.dump(ridge, os.path.join(out_dir, f"ridge_calibrator_alpha{best_alpha}.pkl"))

scatter_plot(Y_test_raw, preds_test_cal_raw,
                f"{EVAL_SAMPLE_NAME} {CHROM_ID}: Predicted vs Actual (calibrated, test split, pre-EWC)",
                os.path.join(out_dir, "scatter_calibrated_test_pre_ewc.png"))

# Per-gene corr (pre-EWC, calibrated)
gene_corr = []
for j, g in enumerate(overlap_genes):
    y, yhat = Y_test_raw[:, j], preds_test_cal_raw[:, j]
    r = pearsonr(y, yhat)[0] if np.std(y) > 1e-8 else 0.0
    gene_corr.append((g, r))
pd.DataFrame(gene_corr, columns=["gene","pearson"]).sort_values("pearson", ascending=False)\
    .to_csv(os.path.join(out_dir, "per_gene_pearson_test_pre_ewc.csv"), index=False)

### Run Elastic Weight Consolidation

In [9]:
# ===== EWC adaptation =====
# include_strong is a set of model layers to keep stable
# include_weak is a set of model layers that can change more easily to new data
include_strong = {"encoder", "tf_emb_table", "tg_emb_table", "tg_decoder_table"}
include_weak   = {"out_dense", "shortcut_scale"}

ewc_model, fisher_diag_src, _ = run_ewc_adaptation(
    model, device,
    source_loader=src_train_loader,
    target_train_loader=tgt_train_loader,
    include_strong=include_strong,
    include_weak=include_weak,
    lambda_strong=1000.0,
    lambda_weak=50.0,
    epochs=5, lr=1e-4, weight_decay=1e-4
)

# Save adapted model & EWC bundle for target (optional future chaining)
torch.save(ewc_model.state_dict(), os.path.join(out_dir, f"model_ewc_{EVAL_SAMPLE_NAME}_{CHROM_ID}.ckpt"))
fisher_ds011 = ewc_utils.compute_fisher_diag(ewc_model, tgt_train_loader, device, n_batches=100, loss_fn="mse")
ewc_utils.save_ewc_bundle(os.path.join(out_dir, f"ewc_{EVAL_SAMPLE_NAME}_{CHROM_ID}.pt"), ewc_model, fisher_ds011)

### Evaluate EWC-Trained Network

In [10]:
# ===== Post-EWC evaluation =====
preds_post, true_post = run_model(ewc_model, tgt_test_loader, device, zscore_tf=True)
spaces_post = build_overlap_and_spaces(preds_post, true_post, target_ds, TRAINED_MODEL_DATASET_DIR, CHROM_ID)

preds_ov_post        = spaces_post["preds_ov"]
true_in_train_z_post = spaces_post["true_in_train_z"]
preds_raw_post       = spaces_post["preds_raw"]
true_raw_post        = spaces_post["true_raw"]

### Re-fit ridge calibrator (post-EWC)


In [11]:
(X_cal2, Y_cal2), (X_test2, Y_test2) = split_for_calibration(preds_ov_post, true_in_train_z_post, frac=CAL_SPLIT_FRAC, seed=SEED)
ridge2, best_alpha2, cv_r2 = fit_ridge_calibrator(X_cal2, Y_cal2, alphas=CAL_ALPHAS)
joblib.dump(ridge2, os.path.join(out_dir, f"ridge_calibrator_post_ewc_alpha{best_alpha2}.pkl"))

preds_test_cal2 = ridge2.predict(X_test2)
cal_train_post = metrics_block(Y_test2, preds_test_cal2)

preds_test_cal2_raw = inverse_transform(preds_test_cal2, train_scaler.mean_[train_idx], train_scaler.scale_[train_idx])
Y_test2_raw         = inverse_transform(Y_test2,       train_scaler.mean_[train_idx], train_scaler.scale_[train_idx])
cal_raw_post = metrics_block(Y_test2_raw, preds_test_cal2_raw)

# Save metrics summary
pd.DataFrame({
    "metric":     ["pearson","spearman","mae"],
    "pre_train":  [base_train["pearson"], base_train["spearman"], base_train["mae"]],
    "pre_raw":    [base_raw["pearson"],   base_raw["spearman"],   base_raw["mae"]],
    "post_train": [cal_train_post["pearson"], cal_train_post["spearman"], cal_train_post["mae"]],
    "post_raw":   [cal_raw_post["pearson"],   cal_raw_post["spearman"],   cal_raw_post["mae"]],
}).to_csv(os.path.join(out_dir, "metrics_pre_post_ewc.csv"), index=False)

scatter_plot(Y_test2_raw, preds_test_cal2_raw,
                f"{EVAL_SAMPLE_NAME} {CHROM_ID}: Predicted vs Actual (calibrated, test split, post-EWC)",
                os.path.join(out_dir, "scatter_calibrated_test_post_ewc.png"))

# Per-gene corr (post-EWC, calibrated)
gene_corr2 = []
for j, g in enumerate(spaces_post["overlap_genes"]):
    y, yhat = Y_test2_raw[:, j], preds_test_cal2_raw[:, j]
    r = pearsonr(y, yhat)[0] if np.std(y) > 1e-8 else 0.0
    gene_corr2.append((g, r))
pd.DataFrame(gene_corr2, columns=["gene","pearson"]).sort_values("pearson", ascending=False)\
    .to_csv(os.path.join(out_dir, "per_gene_pearson_test_post_ewc.csv"), index=False)

## Output Summary

In [12]:
# Save matrices for reproducibility
pd.DataFrame(preds_test_cal2_raw, columns=spaces_post["overlap_genes"])\
    .to_csv(os.path.join(out_dir, "predictions_calibrated_test_post_ewc.csv"), index=False)
pd.DataFrame(Y_test2_raw, columns=spaces_post["overlap_genes"])\
    .to_csv(os.path.join(out_dir, "truth_raw_test_post_ewc.csv"), index=False)

# Console summary
print("\n== Summary ==")
print(f"Pre-EWC (train z): r={base_train['pearson']:.3f} | Pre-EWC (raw): r={base_raw['pearson']:.3f}")
print(f"Post-EWC (train z): r={cal_train_post['pearson']:.3f} | Post-EWC (raw): r={cal_raw_post['pearson']:.3f}")
print(f"(Pre ridge α={best_alpha}, CV r={cv_r:.3f})  (Post ridge α={best_alpha2}, CV r={cv_r2:.3f})")
print(f"Artifacts in: {out_dir}")




== Summary ==
Pre-EWC (train z): r=0.346 | Pre-EWC (raw): r=0.340
Post-EWC (train z): r=0.428 | Post-EWC (raw): r=0.441
(Pre ridge α=10.0, CV r=0.514)  (Post ridge α=3.0, CV r=0.488)
Artifacts in: /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/output/transformer_testing_output/infer_mESC_holdout_chr1
