In [1]:
import torch 
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict
from torch.utils.data import DataLoader, IterableDataset
from torch.nn.utils.rnn import pad_sequence

import itertools

import os
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
#number of contiguous folds the sequences are split into; one fold is processed per run
N_FOLDS = 10
#maximal model input length in tokens, special tokens included
MAX_TOK_LEN = 1024
#number of (masked) sequences per forward pass
BATCH_SIZE = 64

In [3]:
#True: mask each token in turn and predict it; False: predict all tokens of the unmasked input at once
MASKING = True
#None: sum over all 6-mer tokens with the base at the position; 'reference-aware': only single-base substitutions of the ground truth token
DECODE = None
#if set: only mask the tokens covering this many central nucleotides of each sequence
CENTRAL_WINDOW = None
#reverse complement the outputs of negative-strand sequences (requires strand_info)
reverse_seq_neg_strand = False

In [4]:
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'  #project root; model checkpoints and inputs live below it

In [5]:
#pretrained masked LM (and its tokenizer) to evaluate, stored below data_dir
model_dir = data_dir + 'models/whole_genome/nucleotide-transformer-v2-100m-multi-species'

#model_dir = data_dir + 'models/zoonomia-3utr/ntrans-v2-100m-3utr-2e/checkpoints/chkpt_600/'

In [6]:
#use the GPU when available; the CPU fallback is much slower but lets the notebook run without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

tokenizer = AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_dir,trust_remote_code=True).to(device);

In [7]:
def get_chunks(seq_tokens):
    '''
    Chunk the given token sequence into model inputs of at most MAX_TOK_LEN tokens.

    seq_tokens: list of token ids WITHOUT special tokens
    Every chunk gets the cls token prepended and, if the tokenizer defines one, the eos token appended.
    If there are several chunks, the last one is left-padded with the tail of the previous chunk so that
    it has the same length as the others; left_shift is the number of these overlapping tokens.

    Returns a list of (chunk_tokens, left_shift) tuples; left_shift is 0 for all chunks but the last.
    An empty input gives an empty list.
    '''
    #same rule as the tokenizer uses: eos is only added when the tokenizer defines an eos token
    #(presumably not the case for the original InstaDeep tokenizers, where only cls is added)
    has_eos = tokenizer.eos_token_id is not None
    chunk_len = MAX_TOK_LEN-2 if has_eos else MAX_TOK_LEN-1 #room left for the special tokens
    chunks = [seq_tokens[start:start+chunk_len] for start in range(0,len(seq_tokens),chunk_len)]
    if not chunks:
        return []
    assert [x for y in chunks for x in y]==seq_tokens

    left_shift = 0
    if len(chunks)>1:
        #overlap length for the last chunk and the previous one
        left_shift = min(chunk_len-len(chunks[-1]), len(chunks[-2]))
        if left_shift>0:
            chunks[-1] = chunks[-2][-left_shift:] + chunks[-1]

    if has_eos:
        chunks = [[tokenizer.cls_token_id, *chunk, tokenizer.eos_token_id] for chunk in chunks]
        inner = [chunk[1:-1] for chunk in chunks]
    else:
        chunks = [[tokenizer.cls_token_id, *chunk] for chunk in chunks]
        inner = [chunk[1:] for chunk in chunks]
    #without special tokens and the overlap, the chunks reassemble the input
    assert [x for y in inner[:-1] for x in y] + inner[-1][left_shift:]==seq_tokens

    #left_shift only makes sense for the last chunk, for the other chunks it's 0
    last_idx = len(chunks)-1
    return [(chunk, left_shift if chunk_idx==last_idx else 0) for chunk_idx, chunk in enumerate(chunks)]

In [8]:
def mask_sequence(seq_tokens, mask_crop_left=0,mask_crop_right=None):
    '''
    Consecutively mask tokens in the sequence and yield (mask_pos, masked_seq) for each masked position.

    seq_tokens: 1D token tensor, expected to start with the cls token (position 0 is never masked)
    mask_crop_left, mask_crop_right: token offsets counted after the first token; positions
        1+mask_crop_left ... mask_crop_right (inclusive) are masked.
        mask_crop_right=None masks up to the last position.
    Masking stops at the first eos or pad token, so special tokens at the end are not masked.
    '''
    if mask_crop_right is None:
        mask_crop_right = len(seq_tokens)-1
    #ids that end the maskable region; ids the tokenizer does not define (None) are dropped
    stop_ids = {tokenizer.eos_token_id, tokenizer.pad_token_id} - {None}
    for mask_pos in range(1+mask_crop_left,1+mask_crop_right):
        if seq_tokens[mask_pos].item() in stop_ids:
            break
        masked_seq = seq_tokens.clone()
        masked_seq[mask_pos] = tokenizer.mask_token_id
        yield mask_pos, masked_seq

In [9]:
class SeqDataset(IterableDataset):
    '''
    Iterable dataset over the rows of seq_df, visited in order.

    Each row provides tokens, crop_mask_left and crop_mask_right; the row label is the sequence name.
    With masking=True one item is yielded per position masked by mask_sequence;
    otherwise each row is yielded once, unmasked, with masked position -1.
    Items are tuples (seq_name, gt_tokens, masked_pos, masked_tokens).
    Only rows in [start, end) are visited; worker_init_fn narrows this range per worker.
    '''

    def __init__(self, seq_df, masking=True):
        self.seq_df = seq_df
        self.masking = masking
        #row range visited by __iter__
        self.start, self.end = 0, len(seq_df)

    def __iter__(self):
        for row_idx in range(self.start, self.end):
            row = self.seq_df.iloc[row_idx]
            gt_tokens = torch.LongTensor(row.tokens)
            crop_left, crop_right = row.crop_mask_left, row.crop_mask_right

            if not self.masking:
                yield row.name, gt_tokens, -1, gt_tokens
                continue

            #consecutively mask each token in the allowed range
            for masked_pos, masked_tokens in mask_sequence(gt_tokens, crop_left, crop_right):
                assert masked_tokens[masked_pos] == tokenizer.mask_token_id
                yield row.name, gt_tokens, masked_pos, masked_tokens

def worker_init_fn(worker_id):
    '''
    Give each DataLoader worker its own contiguous slice of the dataset rows
    (the dataset copy of the worker is modified in place)
    '''
    info = torch.utils.data.get_worker_info()
    ds = info.dataset  # the dataset copy in this worker process
    first, last = ds.start, ds.end
    rows_per_worker = int(np.ceil((last - first) / float(info.num_workers)))
    ds.start = first + info.id * rows_per_worker
    ds.end = min(ds.start + rows_per_worker, last)

In [10]:
def predict_on_batch(masked_tokens_batch, gt_tokens_batch=None):
    '''
    Run the masked LM on a batch of token sequences.

    masked_tokens_batch: LongTensor (batch, seq_len) padded with tokenizer.pad_token_id
    gt_tokens_batch: optional LongTensor of the same shape with the unmasked ground truth tokens.
        If given, the labels at masked positions are the ground truth tokens, so the loss is the MLM loss.
        If omitted (original behaviour), the labels at masked positions are the mask token itself and
        the returned loss does NOT measure how well the ground truth is predicted.

    Returns (probas_batch, loss):
        probas_batch: numpy array (batch, seq_len, vocab_size), softmax over the vocabulary at each position
        loss: model loss over the masked positions only (presumably nan if the batch has no mask token)
    '''
    is_masked = masked_tokens_batch==tokenizer.mask_token_id
    label_source = masked_tokens_batch if gt_tokens_batch is None else gt_tokens_batch
    #positions labelled -100 are ignored by the loss, so only masked positions are scored
    targets_masked = label_source.clone()
    targets_masked[~is_masked] = -100
    #moved to the device once, used for both mask arguments below
    attention_mask = (masked_tokens_batch!=tokenizer.pad_token_id).to(device)

    with torch.no_grad():
        torch_outs = model(
            masked_tokens_batch.to(device),
            labels=targets_masked.to(device),
            attention_mask=attention_mask,
            encoder_attention_mask=attention_mask,
            output_hidden_states=False)

    logits = torch_outs["logits"] #batch x seq_len x vocab_size

    probas_batch = F.softmax(logits, dim=-1).cpu().numpy()

    loss = torch_outs["loss"].item()

    return probas_batch, loss

In [11]:
def reverse_complement(seq):
    '''
    Take sequence reverse complement

    Upper- and lowercase A/C/G/T are complemented and keep their case, so lowercase-marked
    regions stay marked (at the mirrored positions); any other character (e.g. N) is kept unchanged.
    '''
    compl_dict = {'A':'T', 'C':'G', 'G':'C', 'T':'A',
                  'a':'t', 'c':'g', 'g':'c', 't':'a'}
    compl_seq = ''.join([compl_dict.get(x,x) for x in seq])
    return compl_seq[::-1]

In [12]:
def crop_seq(seq,tokens):
    '''
    Find the tokens covering a window of CENTRAL_WINDOW nucleotides that starts at the sequence midpoint.

    NOTE(review): crop_seq is redefined in a later cell (after PREDICT_ONLY_LOWERCASE), and nothing
    between the two definitions calls it, so this version is effectively dead code - candidate for removal.
    It also requires CENTRAL_WINDOW to be an int: with CENTRAL_WINDOW=None, central_pos+CENTRAL_WINDOW raises TypeError.

    Returns (seq, tokens, seq_length, seq_cropped, pos_left, pos_right, crop_mask_left, crop_mask_right).
    '''

    L = len(seq)
    central_pos = L//2

    #token strings; the first one is skipped below (assumed to be the cls token)
    decoded_tokens = tokenizer.decode(tokens).split()
    
    #nt_idx[i]: index (counted after the first token) of the token containing nucleotide i; '<...>' tokens are skipped
    nt_idx = []
    for token_idx,token in enumerate(decoded_tokens[1:]):
        if not token.startswith('<'):
            nt_idx.extend([token_idx]*len(token))
    nt_idx = np.array(nt_idx)
    #token range from the token with the midpoint to the token with nucleotide central_pos+CENTRAL_WINDOW (inclusive)
    crop_mask_left, crop_mask_right = nt_idx[central_pos],nt_idx[central_pos+CENTRAL_WINDOW]+1
    #nucleotides covered by these tokens
    seq_idx = np.where((nt_idx>=crop_mask_left) & (nt_idx<crop_mask_right))[0]
    pos_left,pos_right = seq_idx[0], seq_idx[-1]+1
    seq_cropped = seq[pos_left:pos_right]
    assert seq_cropped.startswith(seq[pos_left:pos_left+CENTRAL_WINDOW])
    return (seq, tokens, L, seq_cropped,pos_left, pos_right, crop_mask_left, crop_mask_right)

In [13]:
#tokendict_list[idx][nuc]: ids of all 6-mer tokens that carry base nuc at position idx of the 6-mer
tokendict_list = [dict(A=[], G=[], T=[], C=[]) for _ in range(6)]

for tpl in itertools.product("ACGT", repeat=6):
    kmer = "".join(tpl)
    encoding = tokenizer.encode(kmer)
    kmer_token_id = encoding[1] #skip the leading special (cls) token
    for idx, nuc in enumerate(tpl):
        tokendict_list[idx][nuc].append(kmer_token_id)

In [28]:
fasta = data_dir + 'variants/selected/variants_rna.fa'
#fasta = data_dir + 'fasta/Homo_sapiens_dna_fwd.fa'

#letter case is kept on purpose: lowercase marks the region to predict with PREDICT_ONLY_LOWERCASE
#(sequences are upper-cased only for tokenization)
fasta_seqs = defaultdict(str) #seq_name -> concatenated sequence lines

seq_name = None
with open(fasta, 'r') as f:
    for line in f:
        line = line.rstrip()
        if not line:
            continue
        if line.startswith('>'):
            seq_name = line[1:]
        elif seq_name is None:
            raise ValueError(f'{fasta}: sequence data found before the first FASTA header')
        else:
            fasta_seqs[seq_name] += line

seq_df = pd.DataFrame(list(fasta_seqs.items()), columns=['seq_name','seq']).set_index('seq_name')

In [29]:
fold = 0

#contiguous split (no shuffling): fold k holds sequences k*fold_size ... (k+1)*fold_size-1
#NOTE(review): with fold_size = len//N_FOLDS+1 the last fold(s) can be much smaller or even empty;
#presumably kept so the fold assignment matches earlier runs - confirm before changing
fold_size = len(seq_df)//N_FOLDS+1
folds = np.repeat(np.arange(N_FOLDS), fold_size)[:len(seq_df)]

seq_df = seq_df.loc[folds==fold] #keep only the requested fold
print(f'Fold {fold}: {len(seq_df)} sequences')

original_seqs = seq_df.seq #sequences before tokenization, original letter case

Fold 0: 10046 sequences


In [30]:
PREDICT_ONLY_LOWERCASE = True  #predict only the span from the first to the last lowercase nucleotide; disable (None) when CENTRAL_WINDOW is set

In [31]:
def crop_seq(seq,tokens):
    '''
    Restrict masking to a region of the sequence, extended to whole-token boundaries.

    The region is the central CENTRAL_WINDOW nucleotides if CENTRAL_WINDOW is set, otherwise
    (PREDICT_ONLY_LOWERCASE) the span from the first to the last lowercase character of seq.

    seq: nucleotide sequence in its original letter case
    tokens: token ids of seq.upper() with special tokens; tokens[0] is expected to be the cls token

    Returns (seq, tokens, seq_length, seq_cropped, pos_left, pos_right, crop_mask_left, crop_mask_right):
        pos_left/pos_right: nucleotide span [pos_left, pos_right) covered by the tokens to be masked
        crop_mask_left/crop_mask_right: token span [crop_mask_left, crop_mask_right), counted after the cls token

    Raises ValueError if neither mode is enabled or, with PREDICT_ONLY_LOWERCASE, seq has no lowercase character.
    '''
    L = len(seq)

    #[left, right): nucleotide region that must be predicted (right is exclusive in both modes)
    if CENTRAL_WINDOW:
        assert not PREDICT_ONLY_LOWERCASE, 'CENTRAL_WINDOW and PREDICT_ONLY_LOWERCASE are mutually exclusive'
        left = max(0, L//2-CENTRAL_WINDOW//2)
        right = min(L, left+CENTRAL_WINDOW) #clamped for sequences shorter than the window
    elif PREDICT_ONLY_LOWERCASE:
        lower_idx = [idx for idx, c in enumerate(seq) if c.islower()]
        if not lower_idx:
            raise ValueError('PREDICT_ONLY_LOWERCASE is set but the sequence has no lowercase characters')
        left = lower_idx[0]
        right = lower_idx[-1]+1
    else:
        raise ValueError('crop_seq requires CENTRAL_WINDOW or PREDICT_ONLY_LOWERCASE to be set')

    decoded_tokens = tokenizer.decode(tokens).split()

    #nt_idx[i]: index (counted after the cls token) of the token containing nucleotide i; '<...>' tokens are skipped
    nt_idx = []
    for token_idx,token in enumerate(decoded_tokens[1:]):
        if not token.startswith('<'):
            nt_idx.extend([token_idx]*len(token))
    nt_idx = np.array(nt_idx)

    #tokens covering nucleotides left ... right-1
    crop_mask_left, crop_mask_right = nt_idx[left],nt_idx[right-1]+1
    seq_idx = np.where((nt_idx>=crop_mask_left) & (nt_idx<crop_mask_right))[0]

    pos_left,pos_right = seq_idx[0], seq_idx[-1]+1

    seq_cropped = seq[pos_left:pos_right]

    #the token-aligned region must cover the requested one
    assert pos_left <= left and right <= pos_right

    return (seq, tokens, L, seq_cropped,pos_left, pos_right, crop_mask_left, crop_mask_right)

In [32]:
#reverse complement on the negative strand if reverse_seq_neg_strand=True
#strand_info = pd.read_csv(data_dir + 'UTR_coords/GRCh38_3_prime_UTR_clean-sorted.bed', sep='\t', header = None, names=['seq_name','strand'], usecols=[3,5]).set_index('seq_name').squeeze()
#seq_df.seq = seq_df.apply(lambda x: reverse_complement(x.seq) if strand_info.loc[x.seq_name]=='-' 
#                          and reverse_seq_neg_strand else x.seq, axis=1) #undo reverse complement


#both options are "off" when None/False, so test them by truthiness
#(with `is not None`, PREDICT_ONLY_LOWERCASE=False would wrongly select the cropped mode)
if CENTRAL_WINDOW or PREDICT_ONLY_LOWERCASE:
    #cropped mode: one row per sequence, only the selected region gets masked
    seq_df = seq_df.assign(tokens=seq_df.seq.apply(lambda seq:tokenizer(seq.upper(),add_special_tokens=True)['input_ids']))
    seq_df = pd.DataFrame([crop_seq(seq,tokens) for seq,tokens in seq_df.values], index=seq_df.index,
                          columns=['seq','tokens','seq_length','seq_cropped','pos_left','pos_right','crop_mask_left','crop_mask_right'])
else:
    #whole-sequence mode: one row per chunk of at most MAX_TOK_LEN tokens; only the non-overlapping part of each chunk is masked
    tokens = [(seq_name, chunk_tokens, left_shift)
              for seq_name, seq in seq_df.seq.items()
              for chunk_tokens, left_shift in get_chunks(tokenizer(seq.upper(),add_special_tokens=False)['input_ids'])]
    seq_df = pd.DataFrame(tokens,columns=['seq_name','tokens','crop_mask_left']).set_index('seq_name')
    seq_df['crop_mask_right'] = seq_df.tokens.apply(len)-1

In [34]:
def collate_fn(batch):
    '''
    Pad the ground truth and masked token tensors of a batch to the longest sequence in it.
    Returns (seq_names, padded gt tokens, masked positions, padded masked tokens);
    names and masked positions stay tuples.
    '''
    names, gt_tokens, masked_positions, masked_tokens = zip(*batch)
    pad_id = tokenizer.pad_token_id
    gt_padded = pad_sequence(gt_tokens, batch_first=True, padding_value=pad_id)
    masked_padded = pad_sequence(masked_tokens, batch_first=True, padding_value=pad_id)
    return names, gt_padded, masked_positions, masked_padded

In [35]:
#a single worker yields the items in dataset order; the prediction loop relies on this because it
#treats a change of sequence name as "the previous sequence is complete"
seq_dataset = SeqDataset(seq_df, masking=MASKING)
dataloader = DataLoader(
    seq_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1,
    worker_init_fn=worker_init_fn,
)

In [36]:
def predict_probas_token(seq_token_probas,token_pos,gt_token):
    '''
    Predict probabilities of each bp for a given token.

    seq_token_probas: (seq_len, vocab_size) probabilities over the vocabulary at each token position
    token_pos: position of the token in the sequence
    gt_token: ground truth token string (a 6-mer, or a shorter token such as a single nucleotide)

    Returns one [P(A), P(C), P(G), P(T)] list per nucleotide of gt_token.
    By default the probabilities of all 6-mer tokens with the given base at the given position are summed.
    With DECODE=='reference-aware', and always for tokens that are not 6-mers (tokendict_list only indexes
    6-mers, so it would score a single-nucleotide position with 6-mer tokens), only the token obtained by
    substituting that single base into gt_token is used.
    '''
    token_probas = seq_token_probas[token_pos]
    kmer_len = len(tokendict_list) #k-mer length indexed by tokendict_list
    use_single_token = DECODE=='reference-aware' or len(gt_token)!=kmer_len
    seq_probas = []
    for idx in range(len(gt_token)):
        #loop over all positions of the masked token
        position_probas = [] #probabilities for all bases at given position
        for nuc in 'ACGT':
            if use_single_token:
                token_idx = tokenizer.token_to_id(gt_token[:idx]+nuc+gt_token[idx+1:]) #single token
            else:
                token_idx = tokendict_list[idx][nuc] #all tokens that have given base nuc at given position idx
            position_probas.append(token_probas[token_idx].sum())
        seq_probas.append(position_probas)
    return seq_probas

In [None]:
nuc_dict = {"A":0,"C":1,"G":2,"T":3} #for accuracy; other bases (e.g. N) map to 4 and never count as correct

all_probas = defaultdict(list) #per sequence: [P(A),P(C),P(G),P(T)] for every predicted nucleotide
verif_seqs = defaultdict(str) #reconstruct sequences from mask tokens and make sure that they match the original sequences

all_losses, is_correct = [], [] #per-token losses and per-nucleotide hits, for the running loss and accuracy
finished_seqs = [] #sequences whose processing is completed
prev_seq_name = None #name of the previous sequence

#seq_df has a 'seq_cropped' column only if the sequences were cropped to a region before masking
verify_cropped = 'seq_cropped' in seq_df.columns

def finish_sequence(done_seq_name):
    '''
    Finalize a sequence once all of its tokens have been predicted: update the accuracy, log the
    progress and check that the collected ground truth tokens reassemble the expected (upper-cased) sequence.
    Names with nothing collected (e.g. the initial None) are ignored.
    '''
    reconstructed = verif_seqs.get(done_seq_name, '') #.get: don't create empty entries
    if not reconstructed:
        return
    predicted_idx = np.argmax(all_probas[done_seq_name], axis=1)
    is_correct.extend([nuc_dict.get(base,4)==pred_idx for base, pred_idx in zip(reconstructed, predicted_idx)])
    finished_seqs.append(done_seq_name)
    mean_loss = np.mean(all_losses) if all_losses else float('nan') #no losses without masking
    print(f'Sequence {done_seq_name} processed ({len(finished_seqs)}/{len(original_seqs)}), loss: {mean_loss:.3}, acc:{np.mean(is_correct):.3}')
    if verify_cropped:
        expected_seq = seq_df.loc[done_seq_name]['seq_cropped']
    else:
        expected_seq = original_seqs.loc[done_seq_name]
    #compare reconstruction from the masked tokens with the original sequence
    assert reconstructed==expected_seq.upper(), f'Sequence {done_seq_name}: reconstructed tokens do not match the sequence'
    pbar.update(1)

pbar = tqdm(total=len(original_seqs))

for seq_names_batch, gt_tokens_batch, masked_pos_batch, masked_tokens_batch in dataloader:

    #the batch loss of predict_on_batch is not used: the loss is computed below against the ground truth tokens
    probas_batch, _ = predict_on_batch(masked_tokens_batch)
    #probas_batch = np.zeros((len(seq_names_batch),1024,4108)) #placeholder for testing

    for seq_name, gt_tokens, masked_pos, seq_probas in zip(seq_names_batch, gt_tokens_batch, masked_pos_batch, probas_batch):
        gt_tokens = gt_tokens.cpu().tolist()
        if MASKING:
            gt_token_id = gt_tokens[masked_pos]
            gt_token = tokenizer.id_to_token(gt_token_id) #ground truth masked token
            all_probas[seq_name].extend(predict_probas_token(seq_probas,masked_pos,gt_token))
            verif_seqs[seq_name] += gt_token
            #cross-entropy of the ground truth token (probability clipped to avoid log(0))
            all_losses.append(-np.log(max(float(seq_probas[masked_pos][gt_token_id]), 1e-12)))
        else:
            for token_idx, token_id in enumerate(gt_tokens):
                gt_token = tokenizer.id_to_token(token_id) #ground truth token
                if not gt_token.startswith('<'):
                    all_probas[seq_name].extend(predict_probas_token(seq_probas,token_idx,gt_token))
                    verif_seqs[seq_name] += gt_token
        if seq_name!=prev_seq_name:
            #processing of prev_seq_name is completed
            finish_sequence(prev_seq_name)
            prev_seq_name = seq_name

#the last sequence has no successor that would trigger its completion
finish_sequence(prev_seq_name)
pbar.close()


  2%|▏         | 191/10046 [55:18<47:33:59, 17.38s/it]


In [None]:
#collect the results in processing order
seq_names = [*all_probas]
probs = [np.array(seq_probas) for seq_probas in all_probas.values()]
seqs = original_seqs.loc[seq_names].tolist()

In [None]:
#seq_df has pos_left/pos_right/seq_length only if the sequences were cropped before masking
if 'pos_left' in seq_df.columns:
    #probabilities exist only for [pos_left, pos_right): pad both flanks with uniform probabilities
    #(0.25 per base) so that probs match the full sequence length.
    #probs is rebuilt from all_probas instead of padded in place, so re-running this cell does not pad twice
    probs = []
    for seq_name in seq_names:
        crop_info = seq_df.loc[seq_name]
        pad_left = np.full((crop_info['pos_left'],4), 0.25)
        pad_right = np.full((crop_info['seq_length']-crop_info['pos_right'],4), 0.25)
        probs.append(np.vstack((pad_left,np.array(all_probas[seq_name]),pad_right)))

In [None]:
#probs/seqs come from the previous cells and are NOT rebuilt from all_probas here:
#in cropped mode that would discard the uniform flank padding and leave probs shorter than seqs
if reverse_seq_neg_strand:
    #requires strand_info (strand per seq_name, '-' for the negative strand), whose loading is
    #commented out in the tokenization cell. Re-running this cell alone re-applies the reversal.
    #reverse the positions and swap the ACGT columns to their complements (A<->T, C<->G)
    probs = [x[::-1,[3,2,1,0]] if strand_info.loc[seq_name]=='-' else x for x, seq_name in zip(probs,seq_names)]
    seqs = [reverse_complement(x) if strand_info.loc[seq_name]=='-' else x for x, seq_name in zip(seqs,seq_names)]

In [None]:
#save the predictions of this fold (sequence names, sequences, per-nucleotide probabilities, source fasta)
#NOTE(review): saving is commented out, so nothing computed above is persisted. The file name also refers
#to an NT-MS-v2-500M model while model_dir points to nucleotide-transformer-v2-100m-multi-species - confirm before enabling
#with open(data_dir + f'motif_predictions/split_75_25/ntrans/NT-MS-v2-500M_{fold}.pickle', 'wb') as f:
#    pickle.dump({'seq_names':seq_names,'seqs':seqs, 'probs':probs, 'fasta':fasta},f)

print('Done')