In [6]:
from datasets import load_dataset, disable_caching
disable_caching()
from gpn.data import GenomeMSA, Tokenizer, ReverseComplementer
import gpn.model
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from scipy.stats import pearsonr, spearmanr
import seaborn as sns
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.pipeline import Pipeline
import tempfile
from transformers import AutoModel, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import yaml

In [2]:
# Shared project settings are read from the repository-level YAML config;
# later cells look up entries under config["gpn_msa"].
with open("../../config/config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Load the "test" split of the `gonzalobenegas/gwas` variant dataset and keep
# a pandas copy (`V`) for inspection and label lookups.
dataset = load_dataset(path="gonzalobenegas/gwas", split="test")
V = dataset.to_pandas()
V

Unnamed: 0,chrom,pos,ref,alt,pip,maf,trait,label,consequence,tss_dist,exon_dist,match_group
0,1,930312,C,G,0.001963,0.000683,,False,missense_variant,0,0,missense_variant_5
1,1,976215,A,G,0.005418,0.199367,,False,missense_variant,2703,0,missense_variant_39
2,1,1203822,T,C,0.006379,0.059558,,False,synonymous_variant,1856,0,synonymous_variant_0
3,1,1224844,C,T,0.001119,0.096291,,False,intron_variant,3750,876,intron_variant_11
4,1,1291417,C,G,0.003143,0.003690,,False,missense_variant,4308,0,missense_variant_40
...,...,...,...,...,...,...,...,...,...,...,...,...
4589,22,50190508,G,A,1.000000,0.073316,Plt,True,intron_variant,88,88,intron_variant_709
4590,22,50309599,G,A,0.002612,0.005661,,False,3_prime_UTR_variant,1953,0,3_prime_UTR_variant_61
4591,22,50395752,A,G,0.001928,0.083293,,False,intron_variant,23680,1617,intron_variant_692
4592,22,50516206,G,A,0.978163,0.095820,MCH,True,intron_variant,6317,239,intron_variant_710


In [68]:
# TODO: the data should come with proper train, validation and test splits
# (and datasets). The workflow should be one training run followed by a
# separate prediction run, although it may be worth computing ROC AUC on the
# validation split while training.
#
# Until then, variants are split by chromosome: 3-22 for training,
# 1 for validation and 2 for testing.
split_chroms = dict(
    train=list(map(str, range(3, 23))),
    validation=["1"],
    test=["2"],
)

In [4]:
# Open the genome-wide MSA, restricted to the chromosomes that actually occur
# in the variant dataset.
# NOTE(review): in_memory=False presumably avoids loading the full alignment
# into RAM -- confirm against gpn.data.GenomeMSA.
msa_path = config["gpn_msa"]["msa_path"]
variant_chroms = dataset.unique("chrom")
genome_msa = GenomeMSA(msa_path, subset_chroms=variant_chroms, in_memory=False)
genome_msa

Loading MSA...
Loading MSA... Done


<gpn.data.GenomeMSA at 0x7effac8e3b90>

In [69]:
# One dataset per split, selected by chromosome; tokenization is attached as
# an on-the-fly transform rather than materialized up front.
# NOTE(review): `tokenize_function` is defined in a later cell of this
# notebook, so on a fresh kernel this cell raises NameError unless that cell
# has been run first.
datasets = {
    split_name: (
        dataset
        .filter(lambda variant: variant["chrom"] in split_chrom_list)
        .with_transform(tokenize_function)
    )
    for split_name, split_chrom_list in split_chroms.items()
}

Filter: 100%|██████████| 4594/4594 [00:00<00:00, 135323.88 examples/s]
Filter: 100%|██████████| 4594/4594 [00:00<00:00, 148198.59 examples/s]
Filter: 100%|██████████| 4594/4594 [00:00<00:00, 149237.36 examples/s]


In [84]:
class VEP(nn.Module):
    """Binary variant-effect classifier on top of a pretrained encoder.

    The reference and alternate inputs are embedded with the same encoder
    (last hidden state averaged over dim 1). The two embeddings are combined
    by element-wise product and fed to a single linear layer that outputs one
    logit per example.
    """

    def __init__(self, model_path):
        """Load the pretrained encoder from `model_path` and build the linear head."""
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path)
        # The head takes `hidden_size` inputs because the features are an
        # element-wise product of two embeddings. Other feature constructions
        # change this: concatenating ref and alt needs `hidden_size * 2`
        # inputs, and a scalar cosine similarity needs 1.
        self.cls = nn.Linear(self.model.config.hidden_size, 1)

    def get_embedding(self, input_ids, aux_features):
        """Return the encoder's `last_hidden_state` averaged over dim 1."""
        encoder_output = self.model(
            input_ids=input_ids, aux_features=aux_features,
        )
        return encoder_output.last_hidden_state.mean(dim=1)

    def forward(
        self,
        input_ids_ref=None,
        aux_features_ref=None,
        input_ids_alt=None,
        aux_features_alt=None,
        labels=None,
    ):
        """Score a batch of (reference, alternate) input pairs.

        Args:
            input_ids_ref, aux_features_ref: encoder inputs for the reference allele.
            input_ids_alt, aux_features_alt: encoder inputs for the alternate allele.
            labels: optional binary targets, one per example. When given, a
                `BCEWithLogitsLoss` is computed against the logits.

        Returns:
            SequenceClassifierOutput with `logits` of shape (batch,) and
            `loss` (None when `labels` is not provided).
        """
        embedding_ref = self.get_embedding(input_ids_ref, aux_features_ref)
        embedding_alt = self.get_embedding(input_ids_alt, aux_features_alt)
        features = embedding_ref * embedding_alt
        # Drop only the trailing singleton dimension. A bare `.squeeze()`
        # would also drop the batch dimension for a batch of one example,
        # yielding a 0-d tensor that no longer matches `labels`.
        logits = self.cls(features).squeeze(-1)
        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            # Labels may arrive as float64; align them with the logits' dtype.
            loss = loss_fct(logits, labels.to(logits.dtype))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )


def tokenize_function(V):
    """Build paired reference/alternate model inputs for a batch of variants.

    Reads the notebook-level globals `window_size`, `genome_msa` and
    `tokenizer`.

    Args:
        V: batch with "chrom", "pos", "ref", "alt" and "label" columns.
            "pos" is 1-based (the VCF convention).

    Returns:
        dict with keys "input_ids_ref", "aux_features_ref", "input_ids_alt",
        "aux_features_alt" and "labels". Both aux_features entries refer to
        the same array; only the input ids differ between ref and alt.

    Raises:
        AssertionError: if the token at the window's center column does not
            equal the tokenized reference allele for every variant.
    """
    half_window = window_size // 2

    chrom = np.array(V["chrom"])
    # Shift from 1-based VCF coordinates to the 0-based ones GenomeMSA uses.
    center = np.array(V["pos"]) - 1
    window_start = center - half_window
    window_end = center + half_window
    forward_strand = ["+"] * len(chrom)

    msa = genome_msa.get_msa_batch(
        chrom,
        window_start,
        window_end,
        forward_strand,
        tokenize=True,
    )
    # Column of the variant inside each extracted window.
    variant_col = half_window

    def _tokenize_alleles(alleles):
        # Each allele string becomes an array of single ASCII bytes; the
        # stacked result is flattened before being handed to the tokenizer.
        allele_bytes = np.array(
            [np.frombuffer(allele.encode("ascii"), dtype="S1") for allele in alleles]
        )
        return tokenizer(allele_bytes.flatten())

    ref = _tokenize_alleles(V["ref"])
    alt = _tokenize_alleles(V["alt"])

    # Channel 0 of the last axis is used as the input ids, the remaining
    # channels as auxiliary features.
    # NOTE(review): presumably the target sequence vs. the other aligned
    # sequences -- confirm against GenomeMSA.
    input_ids = msa[:, :, 0]
    aux_features = msa[:, :, 1:]

    # Sanity check: the alignment must carry the reference allele at the
    # variant position.
    assert (
        input_ids[:, variant_col] == ref
    ).all(), f"{input_ids[:, variant_col].tolist()}, {ref.tolist()}"

    # The alternate input is a copy of the reference window with the center
    # token swapped for the alt allele.
    input_ids_alt = input_ids.copy()
    input_ids_alt[:, variant_col] = alt

    return {
        "input_ids_ref": input_ids.astype(np.int64),
        "aux_features_ref": aux_features,
        "input_ids_alt": input_ids_alt.astype(np.int64),
        "aux_features_alt": aux_features,
        "labels": np.array(V["label"]).astype(float),
    }

In [101]:
model_path = config["gpn_msa"]["model_path"]
window_size = config["gpn_msa"]["window_size"]

# Seed before building the model: the classifier head is randomly
# initialized at construction time, which happens before Trainer applies
# the seed from TrainingArguments.
SEED = 42
torch.manual_seed(SEED)

model = VEP(model_path)
tokenizer = Tokenizer()
reverse_complementer = ReverseComplementer()

# Keep a reference to the TemporaryDirectory object. Using
# `tempfile.TemporaryDirectory().name` directly drops the only reference at
# once, so the directory is removed right away and the path Trainer later
# recreates is never cleaned up.
output_tmpdir = tempfile.TemporaryDirectory()

training_args = TrainingArguments(
    output_dir=output_tmpdir.name,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    dataloader_num_workers=8,
    # The model's forward() takes custom argument names, so the columns
    # produced by the tokenization transform must not be dropped.
    remove_unused_columns=False,
    torch_compile=True,
    fp16=True,
    # Predictions (not just the loss) are needed to compute ROC AUC afterwards.
    prediction_loss_only=False,
    weight_decay=0.01,
    optim="adamw_torch",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    warmup_ratio=0.1,
    num_train_epochs=1,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    load_best_model_at_end=True,
    save_total_limit=1,
    seed=SEED,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
)
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.6858,0.691978


TrainOutput(global_step=60, training_loss=0.6857836405436198, metrics={'train_runtime': 23.214, 'train_samples_per_second': 164.986, 'train_steps_per_second': 2.585, 'total_flos': 0.0, 'train_loss': 0.6857836405436198, 'epoch': 1.0})

In [102]:
# Run inference on the held-out test split and keep only the model outputs
# (one score per variant).
test_output = trainer.predict(test_dataset=datasets["test"])
pred = test_output.predictions
pred.shape

(346,)

In [103]:
# ROC AUC on the test chromosomes. Labels come from the rows of `V` on those
# chromosomes, which keep the same order as the filtered test dataset.
is_test_variant = V["chrom"].isin(split_chroms["test"])
roc_auc_score(V.loc[is_test_variant, "label"], pred)

0.6702529319389221