## Imports and Setup

In [None]:
import random
import time
from os import path as osp

import fsspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from tqdm.auto import tqdm

In [None]:
# Distance-to-nearest-TSS buckets as [min, max] pairs; the last bucket is open-ended.
# `np.inf` is used because the `np.infty` alias was removed in NumPy 2.0.
DIST_TO_TSS = [[0, 30_000], [30_000, 100_000], [100_000, np.inf]]
# Swept as an extra loop when fitting the SVM: whether to append the tissue embedding.
USE_TISSUE = [True]
# Swept when fitting the SVM: inverse of the L2 penalty strength (sklearn `C` hyperparam).
Cs = [1, 5, 10]
# Root directory holding the per-model embedding folders; results are also written here.
PATH_TO_OUTPUTS = "./outputs/downstream/vep_embeddings"

In [None]:
def fsspec_exists(filename: str) -> bool:
    """Check whether a file exists, in a manner compatible with fsspec.

    Args:
        filename: Local path or fsspec URL (e.g. with a protocol prefix).

    Returns:
        True if the resolved filesystem reports that the path exists.
    """
    # `url_to_fs` returns both the filesystem and the filesystem-specific path;
    # query with that resolved path rather than the raw URL.
    fs, fs_path = fsspec.core.url_to_fs(filename)
    return fs.exists(fs_path)

In [None]:
def dataset_nan_filter(data: dict, data_key: str) -> dict:
    """Drop every item whose embedding contains a NaN.

    An item is removed when either its forward embedding (``data[data_key]``)
    or its reverse-complement embedding (``data[f"rc_{data_key}"]``) has a NaN
    anywhere along dim 1.

    Args:
        data: Mapping of field name to a tensor indexed by item along dim 0.
        data_key: Name of the embedding field to inspect.

    Returns:
        A new dict with the same keys, each value restricted to the kept items.
    """
    fwd_has_nan = data[data_key].isnan().any(dim=1)
    rc_has_nan = data[f"rc_{data_key}"].isnan().any(dim=1)
    keep = ~(fwd_has_nan | rc_has_nan)

    # Apply the same row mask to every field so they stay aligned.
    return {name: values[keep] for name, values in data.items()}

def dataset_tss_filter(data: dict, min_distance: int, max_distance: int) -> dict:
    """Keep only the items whose distance to the nearest TSS lies in a bucket.

    Both bounds are inclusive, so an item sitting exactly on a shared boundary
    is selected by both adjacent buckets.

    Args:
        data: Mapping of field name to an array/tensor indexed by item; must
            contain the key ``"distance_to_nearest_tss"``.
        min_distance: Lower bound of the bucket (inclusive).
        max_distance: Upper bound of the bucket (inclusive).

    Returns:
        A new dict with the same keys, each value restricted to the bucket.
    """
    tss_distance = data["distance_to_nearest_tss"]
    in_bucket = (tss_distance >= min_distance) & (tss_distance <= max_distance)
    return {name: values[in_bucket] for name, values in data.items()}

## Specify which models to test

In [None]:
# Embeddings to test.
# Each entry maps a display name to the settings read by the SVM fitting loop:
#   embed_path    - sub-directory of PATH_TO_OUTPUTS containing the saved embeddings
#   rc_aug        - only checked in the NT sanity assertion within this notebook
#   conjoin_train - average forward and reverse-complement embeddings for train (and test)
#   conjoin_test  - average forward and reverse-complement embeddings for test only
#   key           - name of the embedding field inside the saved dataset dict
model_dict = {
    "HyenaDNA": {
        "embed_path": "hyena_downstream-seqlen=131k",
        "rc_aug": False,
        "conjoin_train": False,
        "conjoin_test": False,
        "key": "concat_avg_ws",
    },
    "Caduceus-Ph": {
        "embed_path": "caduceus-ph_downstream-seqlen=131k",
        "rc_aug": False,
        "conjoin_train": False,
        "conjoin_test": True,
        "key": "concat_avg_ws",
    },
    # Same embeddings as Caduceus-Ph, but evaluated without conjoining.
    "Caduceus w/o Equiv.": {
        "embed_path": "caduceus-ph_downstream-seqlen=131k",
        "rc_aug": False,
        "conjoin_train": False,
        "conjoin_test": False,
        "key": "concat_avg_ws",
    },
    "Caduceus-PS": {
        "embed_path": "caduceus-ps_downstream-seqlen=131k",
        "rc_aug": False,
        "conjoin_train": True,
        "conjoin_test": False,
        "key": "concat_avg_ws",
    },
    "Enformer": {
        "embed_path": "enformer-seqlen=196k",
        "rc_aug": False,
        "conjoin_train": False,
        "conjoin_test": False,
        "key": "concat_avg_ws",
    },
    "NTv2": {
        "embed_path": "NTv2_downstream-seqlen=12k",
        "rc_aug": False,
        "conjoin_train": False,
        "conjoin_test": False,
        "key": "concat_avg_ws",
    },
}

## Fit and test SVM

In [None]:
# Accumulates one row per (model, TSS bucket, use_tissue, C, seed) SVM fit.
metrics = {
    "model_name": [],
    "bucket_id": [],
    "use_tissue": [],
    "C": [],
    "seed": [],
    "AUROC": [],
}

# Number of training rows drawn for each SVM fit (with replacement only when
# the bucket holds fewer rows than this).
SVM_TRAIN_SUBSAMPLE = 5000

for model_name, downstream_kwargs in model_dict.items():
    print(f"********** Gathering results for: {model_name} **********")
    embed_path = downstream_kwargs["embed_path"]
    rc_aug = downstream_kwargs["rc_aug"]
    conjoin_train = downstream_kwargs["conjoin_train"]
    conjoin_test = downstream_kwargs["conjoin_test"]
    key = downstream_kwargs["key"]

    # Sanity check on the config: NT models must not use RC augmentation or conjoining.
    if "NT" in model_name:
        assert not (rc_aug or conjoin_train or conjoin_test), (
            "NT models are expected to have rc_aug, conjoin_train and conjoin_test disabled"
        )

    embeds_path = osp.join(PATH_TO_OUTPUTS, embed_path)
    print(f"Embed Path: {embeds_path}")

    # NOTE: `torch.load` unpickles the file contents; only load embedding files
    # from a trusted source.
    with fsspec.open(osp.join(embeds_path, "train_embeds_combined.pt"), "rb") as f:
        train_val_ds_raw = torch.load(f, map_location="cpu")
    train_val_ds_raw = dataset_nan_filter(train_val_ds_raw, data_key=key)
    with fsspec.open(osp.join(embeds_path, "test_embeds_combined.pt"), "rb") as f:
        test_ds_raw = torch.load(f, map_location="cpu")
    test_ds_raw = dataset_nan_filter(test_ds_raw, data_key=key)
    print(f"Total Train size: {len(train_val_ds_raw[key])},", end=" ")
    print(f"Total Test size: {len(test_ds_raw[key])},", end=" ")
    print(f"Shape: {test_ds_raw[key].shape[1:]}")

    for bucket_id, (min_dist, max_dist) in enumerate(DIST_TO_TSS):
        # Filter data to desired TSS bucket
        train_val_ds_filter = dataset_tss_filter(train_val_ds_raw, min_dist, max_dist)
        test_ds_filter = dataset_tss_filter(test_ds_raw, min_dist, max_dist)
        print(f"- TSS bucket: [{min_dist}, {max_dist}],", end=" ")
        print(f"Train size: {len(train_val_ds_filter[key])},", end=" ")
        print(f"Test size: {len(test_ds_filter[key])}")

        # Build the train/test matrices once per bucket: they do not depend on
        # use_tissue, C or seed, and building them consumes no random numbers.
        # `np.array(...)` copies, so the in-place averaging below never touches
        # the loaded tensors.
        X = np.array(train_val_ds_filter[key])
        if conjoin_train:
            # Average forward and reverse-complement embeddings.
            X += np.array(train_val_ds_filter[f"rc_{key}"])
            X /= 2
        X_with_tissue = np.concatenate(
            [X, np.array(train_val_ds_filter["tissue_embed"])[..., None]],
            axis=-1,
        )
        y = train_val_ds_filter["labels"]

        X_test_plain = np.array(test_ds_filter[key])
        if conjoin_train or conjoin_test:
            # Test embeddings are conjoined whenever either flag is set.
            X_test_plain += np.array(test_ds_filter[f"rc_{key}"])
            X_test_plain /= 2
        X_test_with_tissue = np.concatenate(
            [X_test_plain, np.array(test_ds_filter["tissue_embed"])[..., None]],
            axis=-1,
        )
        y_test = test_ds_filter["labels"]

        for use_tissue in USE_TISSUE:
            X_pool = X_with_tissue if use_tissue else X
            X_test = X_test_with_tissue if use_tissue else X_test_plain

            for C in Cs:
                for seed in range(1, 6):
                    # Re-seed for SVM fitting
                    random.seed(seed)
                    np.random.seed(seed)
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed_all(seed)

                    svm_clf = make_pipeline(
                        StandardScaler(),
                        SVC(C=C, random_state=seed),
                    )

                    print(f"\tFitting SVM ({use_tissue=}, {C=}, {seed=})...", end=" ")

                    # Seed-dependent subsample of training row indices.
                    sample_idx = np.random.choice(
                        len(X),
                        size=SVM_TRAIN_SUBSAMPLE,
                        replace=SVM_TRAIN_SUBSAMPLE > len(X),
                    )
                    X_train = X_pool[sample_idx]
                    y_train = y[sample_idx]

                    # Timing covers fit, prediction and scoring.
                    start = time.time()
                    svm_clf.fit(X_train, y_train)
                    svm_y_pred = svm_clf.predict(X_test)
                    # NOTE(review): AUROC is computed from hard class predictions
                    # (`predict`) rather than continuous scores such as
                    # `decision_function`; kept as-is so results stay comparable
                    # with previously saved metrics - confirm this is intended.
                    svm_aucroc = roc_auc_score(y_test, svm_y_pred)
                    end = time.time()
                    print(f"Completed! ({end - start:0.3f} s) -", end=" ")
                    print(f"AUROC: {svm_aucroc}")

                    metrics["model_name"].append(model_name)
                    metrics["bucket_id"].append(bucket_id)
                    metrics["use_tissue"].append(use_tissue)
                    metrics["C"].append(C)
                    metrics["seed"].append(seed)
                    metrics["AUROC"].append(svm_aucroc)

In [None]:
# Persist the collected per-fit metrics so the plotting cells can reload them.
results_csv_path = osp.join(PATH_TO_OUTPUTS, "SVM_results.csv")
df_metrics = pd.DataFrame(metrics)
df_metrics.to_csv(results_csv_path)

## Plot results

In [None]:
# Maps each model name used in the metrics to its plot label (with a
# parameter-count suffix). Insertion order is also the bar order in the plot.
model_name_replacement = dict(
    [
        ("Caduceus w/o Equiv.", "Caduceus w/o\nEquiv. (7.7M)"),
        ("Caduceus-Ph", "Caduceus-Ph\n(7.7M)"),
        ("Caduceus-PS", "Caduceus-PS\n(7.7M)"),
        ("HyenaDNA", "HyenaDNA\n(6.6M)"),
        ("NTv2", "NTv2\n(500M)"),
        ("Enformer", "Enformer\n(252M)"),
    ]
)

In [None]:
# Formatting changes to df: reload the saved metrics and swap in display labels.
df = pd.read_csv(osp.join(PATH_TO_OUTPUTS, "SVM_results.csv"), index_col=0)
df_display = df.rename(columns={"bucket_id": "Distance to TSS"})
df_display = df_display.replace({"Distance to TSS": {0: "0 - 30k", 1: "30 - 100k", 2: "100k+"}})
df_display = df_display.replace({"model_name": model_name_replacement})

# Take average over seeds.
# The string "mean" is used instead of `np.mean`: recent pandas versions emit
# a FutureWarning when a NumPy callable is passed to `GroupBy.agg`.
df_display_selected = (
    df_display
    .groupby(["model_name", "Distance to TSS", "use_tissue", "C"])
    .agg(AUROC=("AUROC", "mean"))
    .reset_index()
)

# Select best hyperparam by model/bucket: `idxmax` yields, per group, the row
# label of the highest seed-averaged AUROC.
best_ids = df_display_selected.groupby(["model_name", "Distance to TSS"])["AUROC"].idxmax()
df_display_selected = df_display_selected.loc[best_ids.values]

# Model-by-bucket table of the best AUROC, with buckets in increasing distance.
display(
    df_display_selected.pivot(
        index="model_name", columns="Distance to TSS", values="AUROC"
    )[["0 - 30k", "30 - 100k", "100k+"]]
)
# Hyperparameters that achieved the best AUROC for each model/bucket.
display(df_display_selected[["model_name", "Distance to TSS", "C", "use_tissue"]])

In [None]:
# Filter results to selected hyperparams: the inner merge keeps only the
# per-seed rows whose (model, bucket, use_tissue, C) was selected as best.
# The averaged AUROC is dropped from the right side so the per-seed AUROC
# column keeps its name.
df_plot = pd.merge(
    df_display,
    df_display_selected.drop(columns=["AUROC"]),
    on=["model_name", "Distance to TSS", "use_tissue", "C"],
)

# Plot results by distance to TSS (one panel per bucket, error bars = SD over seeds)
sns.set_style("whitegrid")
g = sns.catplot(
    data=df_plot,
    x="model_name",
    y="AUROC",
    col="Distance to TSS",
    hue="Distance to TSS",
    kind="bar",
    errorbar="sd",
    height=12,
    aspect=1,
    dodge=False,
    order=list(model_name_replacement.values()),
)
g.set_xticklabels(rotation=60, fontsize=30)
g.set(xlabel="")
g.set(ylim=(0.4, 0.7))
g.set_titles(template="Dist. to TSS: {col_name}", fontsize=40)
# `figure` is the public accessor; `fig` is deprecated in seaborn.
g.figure.suptitle("Predicting Effects of Variants on Gene Expression", y=1.1, fontsize=40)
# The hue duplicates the column facet, so the legend is redundant. Remove it
# only if seaborn actually created one.
if g.legend is not None:
    g.legend.remove()
# Display bar values
# (See: https://stackoverflow.com/questions/55586912/seaborn-catplot-set-values-over-the-bars)
for ax in tqdm(g.axes.ravel(), leave=False):
    title = ax.title.get_text()
    ax.set_title(title, fontsize=35)
    for c in tqdm(ax.containers, leave=False):
        labels = [f"{v.get_height():0.3f}" for v in c]
        ax.bar_label(c, labels=labels, label_type="center", color="white", weight="bold", fontsize=24)
# Save before showing: some matplotlib backends discard the figure once
# `plt.show()` returns.
g.savefig(osp.join(PATH_TO_OUTPUTS, "SVM_results.png"))
plt.show()