In [1]:
from Bio.Seq import Seq
from datasets import load_dataset
from gpn.data import Genome
import grelu.resources
from grelu.sequence.format import strings_to_one_hot
import grelu.variant
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import average_precision_score
import tempfile
import torch
from transformers import Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm
2024-08-10 23:05:03.843435: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-10 23:05:03.904652: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Variant benchmark to score; presumably a Hugging Face Hub dataset id
# (it is passed to `load_dataset` below). Columns: chrom, pos, ref, alt, label.
dataset_path = "gonzalobenegas/siraj_gwas_highpip"

In [3]:
# Labels and coordinates as a DataFrame, used for evaluation below.
V = (
    load_dataset(dataset_path, split="test")
    .to_pandas()
    # "chr"-prefixed chromosome names for this table only; the Dataset `d`
    # that is fed to the model keeps the original, unprefixed names.
    .assign(chrom=lambda frame: "chr" + frame["chrom"])
)
V

Unnamed: 0,chrom,pos,ref,alt,label
0,chr1,3080038,T,C,True
1,chr1,3774964,A,G,True
2,chr1,6616131,C,T,False
3,chr1,7665224,C,A,False
4,chr1,8407293,G,A,False
...,...,...,...,...,...
1778,chr22,47019717,G,T,False
1779,chr22,47990921,C,T,True
1780,chr22,50007172,T,C,False
1781,chr22,50190508,G,A,True


In [4]:
# Same split as `V`, but kept as a `datasets.Dataset` so it can be passed to the
# Trainer with an on-the-fly transform. Chromosome names here are unprefixed.
d = load_dataset(path=dataset_path, split="test")
d

Dataset({
    features: ['chrom', 'pos', 'ref', 'alt', 'label'],
    num_rows: 1783
})

In [5]:
# Reference genome that transform_f reads the windows around each variant from.
# The path is relative to the notebook's working directory.
genome = Genome("../../results/genome.fa.gz")

In [6]:
def transform_f(V):
    """Build one-hot reference/alternate windows on both strands for a variant batch.

    Intended as a ``datasets`` on-the-fly transform (see ``d.set_transform``),
    so ``V`` is a batch mapping column names to sequences of values. It must
    provide ``chrom``, ``pos`` (1-based), ``ref`` and ``alt``.

    Relies on module-level globals: ``genome`` (sequence source) and
    ``window_size`` (model input length, defined in the model-loading cell).

    Returns:
        dict with keys ``x_ref_fwd``, ``x_alt_fwd``, ``x_ref_rev`` and
        ``x_alt_rev``. Each value is the result of ``strings_to_one_hot`` on the
        batch's window strings.

    Raises:
        AssertionError: if a window does not have length ``window_size``, or if
            the genome base at the variant does not match ``ref``.
    """
    # we convert from 1-based coordinate (standard in VCF) to
    # 0-based, to use with Genome
    chrom = np.array(V["chrom"])
    pos = np.array(V["pos"]) - 1
    start = pos - window_size // 2
    # Derive `end` from `start` so the window is exactly `window_size` long even
    # when it is odd (`pos + window_size // 2` would be one base short then).
    end = start + window_size
    seq_fwd, seq_rev = zip(
        *(genome.get_seq_fwd_rev(c, s, e) for c, s, e in zip(chrom, start, end))
    )
    # One row per variant, one single-character object cell per base.
    seq_fwd = np.array([list(seq.upper()) for seq in seq_fwd], dtype="object")
    seq_rev = np.array([list(seq.upper()) for seq in seq_rev], dtype="object")
    assert seq_fwd.shape[1] == window_size
    assert seq_rev.shape[1] == window_size
    ref_fwd = np.array(V["ref"])
    alt_fwd = np.array(V["alt"])
    ref_rev = np.array([str(Seq(x).reverse_complement()) for x in ref_fwd])
    alt_rev = np.array([str(Seq(x).reverse_complement()) for x in alt_fwd])
    # The variant is at index window_size // 2 of the forward window. Assuming
    # seq_rev is the reverse complement of seq_fwd, index i maps to
    # window_size - 1 - i on the reverse strand (same as the even/odd rule).
    pos_fwd = window_size // 2
    pos_rev = window_size - 1 - pos_fwd

    def prepare_output(seq, pos, ref, alt):
        # Check the genome agrees with the annotated reference allele, then
        # substitute the alternate allele at `pos` in a copy of the windows.
        assert (seq[:, pos] == ref).all(), f"{seq[:, pos]}, {ref}"
        seq_ref = seq
        seq_alt = seq.copy()
        seq_alt[:, pos] = alt
        return (
            strings_to_one_hot(["".join(x) for x in seq_ref]),
            strings_to_one_hot(["".join(x) for x in seq_alt]),
        )

    res = {}
    res["x_ref_fwd"], res["x_alt_fwd"] = prepare_output(seq_fwd, pos_fwd, ref_fwd, alt_fwd)
    res["x_ref_rev"], res["x_alt_rev"] = prepare_output(seq_rev, pos_rev, ref_rev, alt_rev)
    return res

In [7]:
# Apply transform_f on the fly whenever `d` is indexed (e.g. by the Trainer's
# dataloader). transform_f reads the global `window_size`, which is only defined
# in the model-loading cell further down. This works only because the transform
# is lazy, so run that cell before indexing `d`.
d.set_transform(transform_f)

In [8]:
class VEPModel(torch.nn.Module):
    """Variant-effect scorer wrapping a sequence-to-track model.

    A variant is scored by comparing the wrapped model's predictions for the
    reference and alternate sequences. The comparison is averaged over the
    forward and reverse-complement inputs.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def get_scores(self, x_ref, x_alt):
        """L2 norm over dim 2 of the log2(1 + y) difference between alt and ref.

        NOTE(review): assumes model outputs are (batch, tracks, bins) with
        y > -1, giving one score per track. TODO confirm the output layout.
        """
        log_ref = torch.log2(1 + self.model(x_ref))
        log_alt = torch.log2(1 + self.model(x_alt))
        return torch.linalg.norm(log_alt - log_ref, dim=2)

    def forward(self, x_ref_fwd=None, x_alt_fwd=None, x_ref_rev=None, x_alt_rev=None):
        """Strand-averaged scores; argument names match transform_f's keys."""
        per_strand = [
            self.get_scores(x_ref_fwd, x_alt_fwd),
            self.get_scores(x_ref_rev, x_alt_rev),
        ]
        return (per_strand[0] + per_strand[1]) / 2

In [9]:
# Pretrained Borzoi from the gReLU model zoo.
# Alternative model used for comparison: project="enformer", model_name="human".
grelu_model = grelu.resources.load_model(project="borzoi", model_name="human_fold0")

# Output track names (used as DataFrame columns) and the training sequence
# length, which transform_f uses as the input window size.
columns = grelu_model.data_params["tasks"]["name"]
window_size = grelu_model.data_params["train_seq_len"]

# Wrap the underlying torch module for variant-effect scoring.
model = VEPModel(grelu_model.model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mgonzalobenegas[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact human_fold0:latest, 711.00MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.1


In [10]:
# Use the Hugging Face Trainer purely as a batched, multi-worker prediction loop.
# Optional speed-ups (left disabled): torch_compile=True, fp16=True /
# fp16_full_eval=True, torch.set_float32_matmul_precision("medium").
#
# Keep a reference to the TemporaryDirectory. With
# `tempfile.TemporaryDirectory().name` the object is finalized immediately in
# CPython, which deletes the directory (with a ResourceWarning) before use.
output_tmpdir = tempfile.TemporaryDirectory()
training_args = TrainingArguments(
    output_dir=output_tmpdir.name,
    per_device_eval_batch_size=8,
    dataloader_num_workers=8,
    # Keep the raw chrom/pos/ref/alt columns that transform_f needs. They are not
    # arguments of VEPModel.forward, so Trainer would otherwise drop them.
    remove_unused_columns=False,
)
trainer = Trainer(model=model, args=training_args)
preds = trainer.predict(test_dataset=d).predictions
assert len(preds) == len(d), "expected one prediction row per variant"
preds.shape

## Benchmarking

In [None]:
# Fraction of positive labels. This is approximately the average precision of a
# random scorer, i.e. the baseline for the average precision computed below.
V.label.mean()

0.5002804262478968

In [None]:
# Collapse the per-track scores into one score per variant (L2 norm across
# tracks) and evaluate against the labels. `preds` rows follow the order of `d`.
# That is presumably the same split and order as `V`, so at least check lengths.
assert len(preds) == len(V), "predictions and labels are misaligned"
variant_scores = np.linalg.norm(preds, axis=1)
average_precision_score(V.label, variant_scores)

0.7106313639491217

In [None]:
# Per-variant scores labelled by model track: one row per variant, one column per track.
df = pd.DataFrame(data=preds, columns=columns)
df

Unnamed: 0,ENCFF833POA,ENCFF110QGM,ENCFF880MKD,ENCFF463ZLQ,ENCFF890OGQ,ENCFF996AEF,ENCFF660YSU,ENCFF787MSC,ENCFF568LMQ,ENCFF685MZL,...,CNhs14551,CNhs14618,CNhs14226,CNhs14229,CNhs14238,CNhs14239,CNhs14240,CNhs14241,CNhs14244,CNhs14245
0,0.020627,0.027453,0.022877,0.014595,0.011404,0.012971,0.009009,0.014654,0.010176,0.026699,...,0.030669,0.028385,0.014161,0.009556,0.004341,0.008560,0.003383,0.008210,0.017461,0.018274
1,0.034373,0.026412,0.037801,0.026490,0.027976,0.030215,0.028804,0.027649,0.030563,0.028170,...,0.059913,0.032793,0.019731,0.008741,0.034403,0.020556,0.023130,0.022730,0.060444,0.061170
2,0.013653,0.013146,0.022949,0.011773,0.012057,0.014716,0.007150,0.011956,0.008903,0.016390,...,0.044309,0.031237,0.013514,0.054213,0.008707,0.007368,0.008335,0.008033,0.023703,0.020034
3,0.029997,0.034185,0.063735,0.015230,0.026252,0.029426,0.020775,0.033114,0.027260,0.037913,...,0.091137,0.069207,0.024454,0.019247,0.004269,0.005681,0.003253,0.003965,0.014895,0.017959
4,0.020948,0.011623,0.015968,0.010985,0.022360,0.020332,0.020782,0.017494,0.020130,0.011037,...,0.078673,0.040349,0.017659,0.014112,0.008577,0.012185,0.011752,0.010652,0.041655,0.026237
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1778,0.010188,0.015055,0.031079,0.005793,0.014833,0.013747,0.009445,0.014654,0.010966,0.012773,...,0.060567,0.021760,0.007151,0.014033,0.004122,0.001136,0.001977,0.002988,0.010180,0.010185
1779,0.042335,0.041975,0.008428,0.005667,0.005185,0.009152,0.005836,0.005353,0.006187,0.009516,...,0.013600,0.011352,0.005881,0.005740,0.001590,0.001055,0.001519,0.001720,0.005846,0.003889
1780,0.030721,0.039224,0.036249,0.023633,0.029598,0.035346,0.024490,0.033465,0.029582,0.034701,...,0.048981,0.033054,0.013186,0.011193,0.011780,0.009787,0.007663,0.008840,0.021838,0.019903
1781,0.185299,0.264034,0.273924,0.256359,0.342722,0.415615,0.179939,0.262859,0.255552,0.240080,...,0.623208,0.439687,0.121340,0.054223,0.568436,0.201643,0.287033,0.352236,0.761115,0.776491


In [None]:
# Persist the per-track scores. Name the file after the model that produced them
# (the Enformer run wrote "Enformer.parquet").
parquet_path = "Borzoi.parquet"
df.to_parquet(parquet_path, index=False)

  if _pandas_api.is_sparse(col):
