In [1]:
from Bio.Seq import Seq
from datasets import load_dataset
from gpn.data import Genome
import grelu.resources
from grelu.sequence.format import strings_to_one_hot
import grelu.variant
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import average_precision_score
import tempfile
import torch
from transformers import Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm
2024-08-10 23:16:38.838590: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-10 23:16:38.885784: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Hugging Face Hub dataset id, in "<user>/<name>" form.
hf_user, hf_name = "gonzalobenegas", "siraj_gwas_highpip"
dataset_path = f"{hf_user}/{hf_name}"

In [3]:
V = load_dataset(dataset_path, split="test").to_pandas()
# Prefix chromosome names with "chr". Only this pandas copy (used for the
# labels below) is changed; the Dataset `d` fed to the model is loaded
# separately and keeps the original chromosome names.
V["chrom"] = "chr" + V["chrom"]
V

Unnamed: 0,chrom,pos,ref,alt,label
0,chr1,3080038,T,C,True
1,chr1,3774964,A,G,True
2,chr1,6616131,C,T,False
3,chr1,7665224,C,A,False
4,chr1,8407293,G,A,False
...,...,...,...,...,...
1778,chr22,47019717,G,T,False
1779,chr22,47990921,C,T,True
1780,chr22,50007172,T,C,False
1781,chr22,50190508,G,A,True


In [4]:
# Same test split as `V`, kept as a `datasets.Dataset` so that an on-the-fly
# transform can be attached for model inference.
d = load_dataset(path=dataset_path, split="test")
d

Dataset({
    features: ['chrom', 'pos', 'ref', 'alt', 'label'],
    num_rows: 1783
})

In [5]:
# Reference genome FASTA from which the sequence window around each variant is read.
genome_path = "../../results/genome.fa.gz"
genome = Genome(genome_path)

In [6]:
def transform_f(V):
    """Turn a batch of variants into one-hot ref/alt sequences on both strands.

    Registered with `Dataset.set_transform`, so it runs lazily on each batch
    the dataset yields. At call time it reads two notebook globals:
    `genome` (opened with `Genome(...)`) and `window_size` (set from the
    model's `train_seq_len`). Both must exist before the dataset is read,
    even though `window_size` is only assigned in a later cell.

    Args:
        V: batch mapping with columns "chrom", "pos" (1-based, as in VCF),
            "ref" and "alt".

    Returns:
        dict with keys "x_ref_fwd", "x_alt_fwd", "x_ref_rev", "x_alt_rev":
        the `strings_to_one_hot` encodings of the reference and alternate
        windows on the forward and reverse-complement strands.

    Raises:
        AssertionError: if a fetched window does not have length
            `window_size`, or the genome base at the variant does not match
            the "ref" allele.
    """
    # Convert from 1-based coordinates (standard in VCF) to 0-based, as
    # used by Genome.
    chrom = np.array(V["chrom"])
    n = len(chrom)
    pos = np.array(V["pos"]) - 1
    # Window of exactly `window_size` bases with the variant at index
    # window_size // 2. Deriving `end` from `start` (rather than
    # `pos + window_size // 2`) keeps the length correct for odd window
    # sizes, where the old form produced windows one base short.
    start = pos - window_size // 2
    end = start + window_size
    seq_fwd, seq_rev = zip(
        *(genome.get_seq_fwd_rev(chrom[i], start[i], end[i]) for i in range(n))
    )
    # 2-D arrays of single characters (one row per variant), so the variant
    # column can be checked and substituted with array indexing.
    seq_fwd = np.array([list(seq.upper()) for seq in seq_fwd], dtype="object")
    seq_rev = np.array([list(seq.upper()) for seq in seq_rev], dtype="object")
    assert seq_fwd.shape[1] == window_size
    assert seq_rev.shape[1] == window_size
    ref_fwd = np.array(V["ref"])
    alt_fwd = np.array(V["alt"])
    ref_rev = np.array([str(Seq(x).reverse_complement()) for x in ref_fwd])
    alt_rev = np.array([str(Seq(x).reverse_complement()) for x in alt_fwd])
    pos_fwd = window_size // 2
    # Reverse complementing maps forward index i to window_size - 1 - i,
    # i.e. pos_fwd - 1 for even windows and pos_fwd for odd ones.
    pos_rev = window_size - 1 - pos_fwd

    def prepare_output(seq, pos, ref, alt):
        """One-hot encode `seq` and a copy with `alt` written at column `pos`."""
        # Sanity check that the genome agrees with the dataset's ref allele.
        assert (seq[:, pos] == ref).all(), f"{seq[:, pos]}, {ref}"
        seq_ref = seq
        seq_alt = seq.copy()
        # Single-column substitution: presumes SNVs (one-base ref/alt).
        # TODO confirm the dataset contains no indels.
        seq_alt[:, pos] = alt
        return (
            strings_to_one_hot(["".join(x) for x in seq_ref]),
            strings_to_one_hot(["".join(x) for x in seq_alt]),
        )

    res = {}
    res["x_ref_fwd"], res["x_alt_fwd"] = prepare_output(seq_fwd, pos_fwd, ref_fwd, alt_fwd)
    res["x_ref_rev"], res["x_alt_rev"] = prepare_output(seq_rev, pos_rev, ref_rev, alt_rev)
    return res

In [7]:
# Build model inputs lazily from each batch the dataset yields.
d.set_transform(transform=transform_f)

In [8]:
class VEPModel(torch.nn.Module):
    """Wrap a sequence-to-track network so that it scores variants.

    A score is the L2 norm over dim 2 of log2(1 + alt) - log2(1 + ref)
    predictions. It is averaged over the forward and reverse-complement
    inputs.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped network, applied to both ref and alt inputs in get_scores.
        self.model = model

    def get_scores(self, x_ref, x_alt):
        """Return ||log2(1 + f(x_alt)) - log2(1 + f(x_ref))||_2 along dim 2.

        Assumes the wrapped model returns 3-D outputs, presumably
        (batch, tracks, positions). TODO confirm. The norm then leaves one
        value per (batch, track).
        """
        # Reference is evaluated before alternate, as in the original code.
        pred_ref = self.model(x_ref)
        pred_alt = self.model(x_alt)
        log_fold_change = torch.log2(1 + pred_alt) - torch.log2(1 + pred_ref)
        return torch.linalg.norm(log_fold_change, dim=2)

    def forward(
        self,
        x_ref_fwd=None,
        x_alt_fwd=None,
        x_ref_rev=None,
        x_alt_rev=None,
    ):
        """Average the forward- and reverse-strand scores.

        The keyword names match the keys produced by the dataset transform,
        so Trainer can pass batches straight through.
        """
        strand_scores = [
            self.get_scores(x_ref_fwd, x_alt_fwd),
            self.get_scores(x_ref_rev, x_alt_rev),
        ]
        forward_score, reverse_score = strand_scores
        return (forward_score + reverse_score) / 2

In [9]:
# Pretrained gReLU model used for scoring. The other model used in this
# analysis is project="enformer", model_name="human".
model_project, model_name = "borzoi", "human_fold0"
grelu_model = grelu.resources.load_model(project=model_project, model_name=model_name)

# Track names (one output column per task) and the training sequence length,
# which is also the window size the dataset transform extracts.
columns = grelu_model.data_params["tasks"]["name"]
window_size = grelu_model.data_params["train_seq_len"]
# Wrap the underlying torch network so it returns per-variant, per-track scores.
# The loaded gReLU model stays available as `grelu_model`.
model = VEPModel(grelu_model.model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgonzalobenegas[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact human_fold0:latest, 711.00MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


In [10]:
# Keep a reference to the TemporaryDirectory. Calling `.name` on an
# unreferenced TemporaryDirectory lets it be finalized (and the directory
# removed) immediately, so anything Trainer later writes there is never
# cleaned up.
# Optional speed-ups tried previously: torch.set_float32_matmul_precision("medium"),
# torch_compile=True, fp16_full_eval=True.
output_tmpdir = tempfile.TemporaryDirectory()
training_args = TrainingArguments(
    output_dir=output_tmpdir.name,
    per_device_eval_batch_size=8,
    dataloader_num_workers=8,
    # Trainer otherwise drops dataset columns that are not arguments of
    # model.forward, but the transform needs chrom/pos/ref/alt.
    remove_unused_columns=False,
)
trainer = Trainer(model=model, args=training_args)
preds = trainer.predict(test_dataset=d).predictions
preds.shape

## Benchmarking

In [16]:
# Fraction of positive labels, which is the average precision a random scorer would get.
V["label"].mean()

0.5002804262478968

In [17]:
# One score per variant: the L2 norm of its per-track scores (axis 1 of `preds`).
variant_scores = np.linalg.norm(preds, axis=1)
average_precision_score(y_true=V["label"], y_score=variant_scores)

0.7148561444772359

In [None]:
# Per-track variant scores, with one column per model track name.
df = pd.DataFrame(data=preds, columns=columns)
df

In [None]:
# To save the per-track scores, set this to a filename such as "Borzoi.parquet"
# or "Enformer.parquet". It is left as None so that re-running the notebook
# writes nothing.
scores_output_path = None
if scores_output_path is not None:
    df.to_parquet(scores_output_path, index=False)

In [18]:
# Sanity check: the model should produce no NaN scores. Fail loudly instead of
# relying on someone reading the displayed count.
n_nan = np.isnan(preds).sum()
assert n_nan == 0, f"{n_nan} NaN values in predictions"
n_nan

0