In [26]:
import os
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from bend.utils import embedders, Annotation
from Bio import SeqIO
from scipy.spatial.distance import cosine
from sklearn.metrics import roc_auc_score

# --- Input locations -------------------------------------------------------
DATA_DIR = '../data/'
REF_GENOME_FILE = os.path.join(DATA_DIR, 'genomes/GRCh38.primary_assembly.genome.fa')
VAR_EXP_FILE = os.path.join(DATA_DIR, 'variant_effects/variant_effects_expression.bed')
VAR_DIS_FILE = os.path.join(DATA_DIR, 'variant_effects/variant_effects_disease.bed')

# --- Pretrained model ------------------------------------------------------
EMBEDDER_DIR = '../pretrained_models/convnet/'

# --- Output location (directory is created on first run) -------------------
OUT_DIR = '../results/variant_effects/'
FILENAME = 'convnet_expression.csv'
os.makedirs(OUT_DIR, exist_ok=True)

# --- Sequence window -------------------------------------------------------
# Number of bases added on each side of a variant position.
len_context = 256
extra_context_left = len_context
extra_context_right = len_context
# Position inside the extracted window whose embedding is compared;
# it coincides with the amount of left context.
embedding_idx = extra_context_left

# Extra keyword arguments intended for the embedder call.
kwargs = {'disable_tqdm': True}

In [None]:
# Variant annotation table, read from a tab-separated file.
annotation = pd.read_csv(
    VAR_EXP_FILE,
    sep='\t',
)

# Reference genome parsed from FASTA and held fully in memory as a dict of
# sequence records (expensive: the whole file is read here).
genome_dict = SeqIO.to_dict(
    SeqIO.parse(REF_GENOME_FILE, "fasta")
)

In [3]:
from torch.utils.data import Dataset, DataLoader

class AnnotationData(Dataset):
    """
    Dataset of (reference, alternative) DNA sequence pairs for variant annotations.

    Each item is the reference sequence of a genomic window around a variant and
    the same sequence with the base at the variant position replaced by ``alt``.

    Parameters
    ----------
    annotation : pd.DataFrame
        Table with at least the columns ``chromosome``, ``start``, ``end`` and ``alt``.
        A copy is stored, so the caller's dataframe is never modified.
        NOTE(review): the code treats ``start`` as the 0-based index of the variant
        base and expects ``start == end`` for SNPs - confirm against the BED file.
    genome_dict : dict
        Maps chromosome name to an object exposing a sliceable ``.seq`` attribute.
    extra_context_left : int, optional
        Number of bases added before the variant position. Defaults to 0.
    extra_context_right : int, optional
        Amount added to ``end``. Defaults to 0.

    Raises
    ------
    ValueError
        If both contexts are non-zero and differ, or if a required column is missing.
    """

    def __init__(self, annotation: pd.DataFrame, genome_dict: dict, extra_context_left: int = 0, extra_context_right: int = 0):
        super().__init__()

        if extra_context_left != extra_context_right and extra_context_left != 0 and extra_context_right != 0:
            raise ValueError('Left and right context must be equal or one of them must be 0')

        if not {'chromosome', 'start', 'end', 'alt'}.issubset(annotation.columns):
            raise ValueError('Annotation dataframe must contain columns: chromosome, start, end, alt')

        self.annotation = annotation.copy()
        self.genome_dict = genome_dict

        # SNP annotation has start position equal to end position -> need to include context
        if extra_context_left == extra_context_right and extra_context_left == 0:
            # avoid having empty sequence in case of no extra context
            extra_context_right += 1

        # Remember the variant position before the window start is shifted.
        variant_pos = self.annotation['start'].copy()

        self.annotation['start'] = variant_pos - extra_context_left

        # The window must always contain the variant base itself. Without this,
        # left-only context (right == 0) on a start == end annotation yields a
        # window that stops just before the variant, and __getitem__ fails with
        # an IndexError when substituting the alt base.
        window_end = self.annotation['end'] + extra_context_right
        self.annotation['end'] = window_end.where(window_end > variant_pos, variant_pos + 1)

        # Index of the variant base inside every extracted window.
        self.idx_alt = extra_context_left

    def extend_context(self, extra_context_left: int = 0, extra_context_right: int = 0):
        """
        Widen every window by the given number of bases on each side.

        Parameters
        ----------
        extra_context_left : int, optional
            Additional bases to include before the current window start.
        extra_context_right : int, optional
            Additional bases to include after the current window end.
        """

        self.extra_context_left = extra_context_left
        self.extra_context_right = extra_context_right

        self.annotation['start'] = self.annotation['start'] - extra_context_left
        self.annotation['end'] = self.annotation['end'] + extra_context_right

        # Moving the window start moves the variant further into the window;
        # keep the substitution index pointing at the variant base.
        self.idx_alt += extra_context_left

    def __len__(self):
        """Return the number of annotated variants."""
        return self.annotation.shape[0]

    def __getitem__(self, idx):
        """
        Return the ``(reference, alternative)`` sequence pair for row ``idx``.

        Raises
        ------
        IndexError
            If the window starts before the beginning of the chromosome, or if the
            extracted sequence is too short to contain the variant position.
        """

        item = self.annotation.iloc[idx]
        start, end = int(item['start']), int(item['end'])

        # A negative start would silently wrap around via Python slicing and
        # return the wrong (usually empty) sequence.
        if start < 0:
            raise IndexError(
                f"Context window for row {idx} starts before the beginning of "
                f"{item['chromosome']} (start={start})"
            )

        dna_seq = str(self.genome_dict[item['chromosome']].seq[start:end])

        # Substitute the alt allele at the variant position.
        alt_bases = list(dna_seq)
        alt_bases[self.idx_alt] = item['alt']
        alt_dna_seq = ''.join(alt_bases)

        return (dna_seq, alt_dna_seq)
        

In [37]:
from bend.utils.embedders import BaseEmbedder
import torch
import numpy as np
from typing import List, Iterable
import os

from bend.models.dilated_cnn import ConvNetModel
from bend.utils.download import download_model, download_model_zenodo

from tqdm.auto import tqdm
from transformers import logging, AutoTokenizer
# Restrict the transformers logger (imported above as `logging`) to error-level output.
logging.set_verbosity_error()

# Select the compute device once for the whole notebook.
# Preference order: Apple MPS, then CUDA, then CPU. Without the fallbacks
# `device` would be undefined on machines without MPS and every later
# `.to(device)` call would raise a NameError.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class ConvNetEmbedder(BaseEmbedder):
    """
    Embedder backed by the GPN-inspired ConvNet baseline LM trained in BEND.
    """
    def load_model(self, model_path, **kwargs):
        """
        Load the tokenizer and the ConvNet model from ``model_path``.

        Parameters
        ----------
        model_path : str
            Directory holding the pretrained model.
            When the path is missing, ``download_model`` is asked to fetch the
            ``convnet`` model into it before loading.
        **kwargs
            Accepted but not used by this method.
        """

        logging.set_verbosity_error()

        model_is_missing = not os.path.exists(model_path)
        if model_is_missing:
            print(f'Path {model_path} does not exists, model is downloaded from https://sid.erda.dk/cgi-sid/ls.py?share_id=dbQM0pgSlM&current_dir=pretrained_models&flags=f')
            download_model(model = 'convnet', destination_dir = model_path)

        # Tokenizer first, then the network itself.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        convnet = ConvNetModel.from_pretrained(model_path)
        # Place the model on the notebook-wide `device` and switch to eval mode.
        self.model = convnet.to(device).eval()

    def embed(self, sequences: List[str], disable_tqdm: bool = False, upsample_embeddings: bool = False):
        """
        Compute embeddings for a batch of sequences in one forward pass.

        Parameters
        ----------
        sequences : List[str]
            Sequences to embed. They are tokenized together into a single tensor.
        disable_tqdm : bool, optional
            Not used by this implementation. Defaults to False.
        upsample_embeddings : bool, optional
            Not used by this implementation. Defaults to False.

        Returns
        -------
        np.ndarray
            The model's ``last_hidden_state`` for the whole batch, moved to the CPU
            and converted to a NumPy array.
        """

        with torch.no_grad():
            encoded = self.tokenizer(
                sequences,
                return_tensors="pt",
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            token_ids = encoded["input_ids"].to(device)

            hidden_states = self.model(input_ids=token_ids).last_hidden_state
            return hidden_states.detach().cpu().numpy()

In [38]:
# Re-read the variant table so this cell always starts from the original coordinates.
annotation = pd.read_csv(VAR_EXP_FILE, sep = '\t')

extra_context_left = extra_context_right = 256
dataset = AnnotationData(annotation, genome_dict, extra_context_left=extra_context_left, extra_context_right=extra_context_right)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)

# The embedder loads its own tokenizer and model; loading a second copy of
# either here would only cost time and device memory.
embedder = ConvNetEmbedder(EMBEDDER_DIR)

# Compare embeddings at the position where the alt base was substituted.
# Taking it from the dataset keeps it in sync with the context settings.
variant_idx = dataset.idx_alt

cosine_distances = []

for batch_ref, batch_alt in tqdm(dataloader):

    ref_embeddings = embedder.embed(list(batch_ref))[:, variant_idx, :]
    alt_embeddings = embedder.embed(list(batch_alt))[:, variant_idx, :]

    cosine_distances.extend(
        cosine(ref_emb, alt_emb)
        for ref_emb, alt_emb in zip(ref_embeddings, alt_embeddings)
    )

# Every variant must have received exactly one score before it is attached.
assert len(cosine_distances) == len(dataset), "number of distances does not match number of variants"

dataset.annotation['distance'] = cosine_distances

# NOTE(review): dataset.annotation holds the context-extended start/end, so the
# saved table contains window coordinates rather than the original variant
# positions - confirm this is intended.
dataset.annotation.to_csv(os.path.join(OUT_DIR, FILENAME), index=False)

  0%|          | 0/412 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# ROC AUC of the cosine distance used as a score against the 'label' column.
roc_auc_score(
    y_true=dataset.annotation['label'],
    y_score=dataset.annotation['distance'],
)