In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict

import os,sys

import torch 
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM,AutoModel

sys.path.append('DNABERT/')

from src.transformers import DNATokenizer 
from transformers import BertModel, BertConfig

%load_ext autoreload
%autoreload 2

In [2]:
class dotdict(dict):
    '''
    Dictionary whose keys can also be read, written and deleted as attributes.

    Reading an attribute that is not a key returns None instead of raising,
    so optional settings can be tested with `params.some_key is None`.
    '''

    def __getattr__(self, name):
        # only invoked when regular attribute lookup fails; missing keys yield None
        return self.get(name)

    def __setattr__(self, name, value):
        # attribute assignment stores a dictionary item
        self[name] = value

    def __delattr__(self, name):
        # attribute deletion removes the item (KeyError when the key is absent)
        del self[name]

In [3]:
# Root directory of the project data. Paths elsewhere in the notebook are built by plain
# string concatenation (data_dir + '...'), so the trailing slash is required.
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/'

In [4]:
# Run configuration.
# `exclude` is not set here, so `input_params.exclude` evaluates to None (dotdict returns
# None for missing keys) and SeqDataset skips its exclusion step.
input_params = dotdict(
    # alternative input: data_dir + 'griesemer/fasta/GRCh38_UTR_variants.fa'
    fasta=data_dir + 'fasta/Homo_sapiens_no_reverse.fa',
    model='NT-MS-v2-500M',
    batch_size=10,
    N_folds=10,
    fold=0,
)

# Paths that depend on the selected model name
input_params.update(
    output_dir=data_dir + f'griesemer/embeddings/{input_params.model}/',
    processed_seqs=data_dir + f'/3UTR_embeddings/{input_params.model}/processed_utrs.csv',
)

In [5]:
# SeqDataset truncates every sequence to at most this many characters
MAX_SEQ_LENGTH = 5000

In [6]:
def load_model(model_name):
    '''
    Load the tokenizer and the pretrained model for one of the supported language models.

    Checkpoints are read from sub-directories of the module-level `data_dir`.

    Parameters
    ----------
    model_name : str
        One of 'DNABERT', 'DNABERT-2' or 'NT-MS-v2-500M'.

    Returns
    -------
    tuple
        (tokenizer, model)

    Raises
    ------
    ValueError
        If `model_name` is not one of the supported names.
    '''

    # checkpoint locations relative to data_dir
    model_subdirs = {'DNABERT': 'dnabert/default/6-new-12w-0/',
                     'DNABERT-2': 'dnabert2/DNABERT-2-117M/',
                     'NT-MS-v2-500M': 'nucleotide-transform/nucleotide-transformer-v2-500m-multi-species'}

    # fail early with a clear message instead of an UnboundLocalError at the return statement
    if model_name not in model_subdirs:
        raise ValueError(f'Unknown model {model_name!r}; expected one of {sorted(model_subdirs)}')

    model_dir = data_dir + model_subdirs[model_name]

    if model_name == 'DNABERT':

        # NOTE: the config is fetched from GitHub, so this branch needs network access
        config = BertConfig.from_pretrained('https://raw.githubusercontent.com/jerryji1993/DNABERT/master/src/transformers/dnabert-config/bert-config-6/config.json')
        tokenizer = DNATokenizer.from_pretrained('dna6')
        model = BertModel.from_pretrained(model_dir, config=config)

    elif model_name == 'DNABERT-2':

        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)

    else:  # 'NT-MS-v2-500M'

        # the masked-LM head is kept because get_batch_embeddings reads the logits
        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        model = AutoModelForMaskedLM.from_pretrained(model_dir, trust_remote_code=True)

    return tokenizer, model

In [7]:
#from glob import glob
#import pickle
#
#processed_seqs = []
#for emb_file in glob(data_dir + f'/3UTR_embeddings/{input_params.model}/ENST*.pickle'):
#    with open(emb_file, 'rb') as f:
#        utr_names_batch, _ = pickle.load(f)
#        processed_seqs.extend(utr_names_batch)
#
#pd.Series(processed_seqs).to_csv(data_dir + f'/3UTR_embeddings/{input_params.model}/processed_utrs.csv',index=None,header=None)

In [8]:
class SeqDataset(Dataset):
    '''
    Dataset of (sequence name, sequence) pairs read from a FASTA file.

    The sequence name is the header text between '>' and the first ':'.
    Sequences are upper-cased and truncated to the module-level MAX_SEQ_LENGTH.

    Two optional filters are driven by the module-level `input_params`:

    * `input_params.exclude`: path to a one-column CSV of sequence names to drop;
    * `input_params.N_folds` / `input_params.fold`: keep only every N_folds-th
      sequence, starting at position `fold` (round-robin split).

    Attributes
    ----------
    seqs : list of (str, str)
        The retained (name, sequence) pairs, in file order.
    max_length : int
        Length of the longest retained sequence (0 when nothing is retained).
    '''

    def __init__(self, fasta_file):
        '''
        Parameters
        ----------
        fasta_file : str
            Path to the FASTA file to read.

        Raises
        ------
        ValueError
            If sequence data appears before the first FASTA header.
        '''

        # collect line chunks per transcript and join once (avoids repeated string copies)
        chunks = defaultdict(list)
        transcript_id = None

        with open(fasta_file, 'r') as f:
            for line in f:
                line = line.rstrip()
                if line.startswith('>'):
                    transcript_id = line[1:].split(':')[0].rstrip()
                elif transcript_id is None:
                    # blank lines before the first header are harmless, sequence data is not
                    if line:
                        raise ValueError(f'{fasta_file}: sequence data found before the first FASTA header')
                else:
                    chunks[transcript_id].append(line.upper())

        seqs = [(seq_name, ''.join(parts)[:MAX_SEQ_LENGTH]) for seq_name, parts in chunks.items()]

        if input_params.exclude is not None:
            print(f'Excluding sequences from {input_params.exclude}')
            # a set gives constant-time membership tests
            processed_seqs = set(pd.read_csv(input_params.exclude, names=['seq_name']).seq_name.values)
            seqs = [(seq_name, seq) for seq_name, seq in seqs if seq_name not in processed_seqs]

        if input_params.N_folds is not None:
            print(f'Fold {input_params.fold}')
            # sequence i belongs to fold i % N_folds
            seqs = [x for idx, x in enumerate(seqs) if idx % input_params.N_folds == input_params.fold]

        self.seqs = seqs
        # default=0 keeps an empty selection from raising
        self.max_length = max((len(seq) for _, seq in self.seqs), default=0)

    def __len__(self):
        '''Number of retained sequences.'''

        return len(self.seqs)

    def __getitem__(self, idx):
        '''Return the (name, sequence) pair at position idx.'''

        return self.seqs[idx]

In [9]:
def kmers_stride1(seq, k=6):
    '''
    Split a sequence into its overlapping k-mers, moving one position at a time.

    Returns a list of len(seq)-k+1 substrings of length k, or an empty list
    when the sequence is shorter than k.
    '''
    n_windows = len(seq) - k + 1

    kmers = []
    for start in range(n_windows):
        kmers.append(seq[start:start + k])

    return kmers


def get_batch_embeddings(model_name, sequences):
    '''
    Compute one embedding vector per sequence for a batch of sequences.

    Uses the module-level `tokenizer`, `model` and `dataset` (the latter for
    `dataset.max_length`, the padding length of the 'DNABERT-2' and
    'NT-MS-v2-500M' branches).

    Parameters
    ----------
    model_name : str
        'DNABERT', 'DNABERT-2' or 'NT-MS-v2-500M'.
    sequences : list or tuple of str
        The sequences of the batch.

    Returns
    -------
    torch.Tensor
        For 'DNABERT' (second model output, one row per sequence) and
        'DNABERT-2' (mean of the first model output over the token axis).
    tuple (torch.Tensor, numpy.ndarray)
        For 'NT-MS-v2-500M': the mean of the last hidden state over the
        non-padding tokens, and the log-probability the model assigns to each
        input token (batch x tokens, padding positions included).

    Raises
    ------
    ValueError
        If `model_name` is not one of the supported names.
    '''

    # the default DataLoader collation may hand over a tuple of strings
    if isinstance(sequences, tuple):
        sequences = list(sequences)

    if model_name == 'DNABERT':

        outputs = []

        for seq in sequences:

            seq_kmer = kmers_stride1(seq)

            # NOTE(review): inputs are padded to 512 tokens but not truncated - confirm that
            # longer sequences are handled upstream
            model_input = tokenizer.encode_plus(seq_kmer, add_special_tokens=True, padding='max_length', max_length=512)["input_ids"]
            model_input = torch.tensor(model_input, dtype=torch.long)
            model_input = model_input.unsqueeze(0)   # to generate a fake batch with batch size one

            output = model(model_input)
            outputs.append(output[1])

        return torch.vstack(outputs)

    elif model_name == 'DNABERT-2':

        inputs = tokenizer(sequences, return_tensors = 'pt', padding="max_length", max_length = dataset.max_length)["input_ids"]

        hidden_states = model(inputs)[0]

        # NOTE(review): the mean runs over every position, padding tokens included,
        # unlike the masked mean of the 'NT-MS-v2-500M' branch - confirm this is intended
        mean_sequence_embeddings = torch.mean(hidden_states, dim=1)

        return mean_sequence_embeddings

    elif model_name == 'NT-MS-v2-500M':

        batch_token_ids = tokenizer.batch_encode_plus(sequences, return_tensors="pt", padding="max_length", max_length = dataset.max_length)["input_ids"]

        attention_mask = batch_token_ids != tokenizer.pad_token_id

        torch_outs = model(
            batch_token_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True)

        # keep the hidden states as a tensor so the masked mean below stays in torch
        embeddings = torch_outs['hidden_states'][-1].detach()

        # add the embedding axis so the mask broadcasts over it
        token_mask = torch.unsqueeze(attention_mask, dim=-1)

        # mean over the non-padding tokens of each sequence
        mean_sequence_embeddings = torch.sum(token_mask*embeddings, dim=-2)/torch.sum(token_mask, dim=1)

        # log-probability of each input token; log_softmax avoids log(0) for tiny probabilities
        all_log_probas = torch.log_softmax(torch_outs['logits'].detach(), dim=2).cpu().numpy()
        token_ids = batch_token_ids.cpu().numpy()
        # squeeze only the gathered axis so a batch of size one keeps its batch dimension
        log_probas = np.take_along_axis(all_log_probas, token_ids[...,None], axis=2).squeeze(axis=2)

        return (mean_sequence_embeddings, log_probas)

    raise ValueError(f"Unknown model {model_name!r}; expected 'DNABERT', 'DNABERT-2' or 'NT-MS-v2-500M'")

In [10]:
#sequences = next(iter(dataloader))

In [34]:
# Load the tokenizer and model selected by input_params.model.
# get_batch_embeddings reads both names as globals, so they must keep these names.
tokenizer, model = load_model(input_params.model)

Downloading (…)config-6/config.json: 359B [00:00, 9.92kB/s]                   


<class 'src.transformers.tokenization_dna.DNATokenizer'>


Some weights of the model checkpoint at /lustre/groups/epigenereg01/workspace/projects/vale/MLM/dnabert/default/6-new-12w-0/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
# `dataset` is also read as a global by get_batch_embeddings (for dataset.max_length)
dataset = SeqDataset(input_params.fasta)

# No shuffling: batches are produced in dataset order
dataloader = DataLoader(
    dataset,
    batch_size=input_params.batch_size,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
)

In [None]:
# Embed every batch and stack the per-sequence vectors into one array (one row per sequence)
all_emb = []

for seq_idx, (seq_names,sequences) in enumerate(dataloader):

    print(f'generating embeddings for batch {seq_idx}/{len(dataloader)}')

    with torch.no_grad():
        batch_output = get_batch_embeddings(input_params.model,sequences)

    # get_batch_embeddings returns (embeddings, log_probas) for 'NT-MS-v2-500M' and a bare
    # tensor for the other models; only the embeddings are collected here
    emb = batch_output[0] if isinstance(batch_output, tuple) else batch_output

    all_emb.append(emb.cpu().numpy())

all_emb = np.vstack(all_emb)

In [14]:
# Write the stacked embeddings to <output_dir>embeddings.npy, creating the directory if needed
os.makedirs(input_params.output_dir, exist_ok=True)

np.save(input_params.output_dir + 'embeddings.npy', all_emb)