In [1]:
import math
import itertools
import collections
from collections.abc import Mapping
import numpy as np
import pandas as pd
import tqdm
import os
import torch

from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding

from datasets import Dataset

# Basic Information

## Sequence length constraints

The 3' model expects an input sequence that is 300 bp long (stop codon + 297 bp). It will handle shorter sequences (although sequences shorter than 11 bp cannot be masked), and in theory it can even predict on inputs of up to 512 — but this is out-of-distribution and likely performs very poorly, as the positional encodings are not adapted for it.

The 5' model expects an input that is 1003 bp long (1000 bp + start codon). Longer sequences will not work. Shorter sequences must be padded (e.g. with a fixed sequence); otherwise the start codon receives the wrong positional encoding, which confuses the model.

# Basic Functions

## Utilities

In [2]:
def chunkstring(string, length):
    """Lazily split ``string`` into consecutive pieces of ``length`` characters.

    The final piece may be shorter. Returns a generator.
    """
    starts = range(0, len(string), length)
    return (string[s:s + length] for s in starts)

def kmers(seq, k=6):
    """Split ``seq`` into consecutive, non-overlapping k-mers.

    A trailing remainder shorter than ``k`` is dropped.
    """
    last_start = len(seq) - k
    return [seq[start:start + k] for start in range(0, last_start + 1, k)]

def kmers_stride1(seq, k=6):
    """Return every overlapping k-mer of ``seq`` (stride 1), in order."""
    n_windows = len(seq) - k + 1
    return [seq[offset:offset + k] for offset in range(n_windows)]

def one_hot_encode(gts, dim=5):
    """One-hot encode a nucleotide string.

    Args:
        gts: iterable of nucleotides from ``"ACGT"``.
        dim: width of each one-hot vector; A, C, G and T occupy columns 0-3.

    Returns:
        float array of shape ``(len(gts), dim)``. An empty input yields an
        empty ``(0, dim)`` array instead of failing in ``np.stack``.

    Raises:
        KeyError: for a character outside ``"ACGT"``.
    """
    nuc_dict = {"A": 0, "C": 1, "G": 2, "T": 3}
    indices = np.fromiter((nuc_dict[nt] for nt in gts), dtype=np.intp)
    encoded = np.zeros((len(indices), dim))
    encoded[np.arange(len(indices)), indices] = 1
    return encoded

def class_label_gts(gts):
    """Map each nucleotide of ``gts`` to its class index (A=0, C=1, G=2, T=3)."""
    lookup = {"A": 0, "C": 1, "G": 2, "T": 3}
    return np.array(list(map(lookup.__getitem__, gts)))

def tok_func_standard(x, seq_col):
    """Tokenize row ``x`` as its space-separated overlapping k-mers (no species token)."""
    kmer_text = " ".join(kmers_stride1(x[seq_col]))
    return tokenizer(kmer_text)

def tok_func_species(x, species_proxy, seq_col):
    """Tokenize row ``x`` as ``species_proxy`` followed by its overlapping k-mers."""
    kmer_text = " ".join(kmers_stride1(x[seq_col]))
    return tokenizer(species_proxy + " " + kmer_text)

In [3]:
def count_special_tokens(tokens, tokenizer, where = "left"):
    """Count the run of consecutive special tokens at one end of ``tokens``.

    Args:
        tokens: sequence of token ids (list or 1-D numpy array).
        tokenizer: object exposing ``all_special_ids``.
        where: ``"right"`` counts from the end; any other value counts from
            the start.

    Returns:
        Number of leading (or trailing) ids that are special tokens.
    """
    # Hoisted out of the loop: on Hugging Face tokenizers ``all_special_ids``
    # is a property that rebuilds its list on every access.
    special_ids = set(tokenizer.all_special_ids)
    ordered = tokens[::-1] if where == "right" else tokens
    count = 0
    for tok in ordered:
        # int() so numpy / torch scalars hash like plain ints in the set.
        if int(tok) not in special_ids:
            break
        count += 1
    return count

## Data Collator

In [4]:
# Seed torch's global RNG so the random span masking performed by the
# collator (torch.bernoulli / torch.randint) is reproducible.
torch.manual_seed(0)

def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result

class DataCollatorForLanguageModelingSpan():
    """Collate tokenized k-mer sequences into a batch, optionally with span masking.

    Args:
        tokenizer: tokenizer providing ``pad``, ``pad_token_id``,
            ``get_special_tokens_mask`` and the mask token.
        mlm: if True, inputs are span-masked and ``labels`` only score the
            affected positions; if False, ``labels`` is a copy of
            ``input_ids`` with padding set to -100.
        mlm_probability: base per-token probability used to seed masked spans.
        span_length: width of each masked span. Batches are also padded to a
            multiple of this length.
    """

    def __init__(self, tokenizer, mlm, mlm_probability, span_length):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.span_length = span_length
        self.mlm_probability = mlm_probability
        self.pad_to_multiple_of = span_length

    def __call__(self, examples):
        """Pad ``examples`` into a dict of tensors and attach ``labels``."""
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                # -100 is ignored by the loss, so padding is never scored.
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def _seeds_to_spans(self, seeds):
        """Widen every True in the 2-D bool tensor ``seeds`` to a ``span_length`` window.

        Each row is convolved with a box filter (``np.convolve``, ``mode='same'``).
        Overlapping windows simply stay True.
        """
        widened = np.apply_along_axis(
            lambda row: np.convolve(row, [1] * self.span_length, mode='same'),
            axis=1,
            arr=seeds.numpy(),
        )
        return torch.from_numpy(widened.astype(bool))

    def torch_mask_tokens(self, inputs, special_tokens_mask):
        """Span-mask ``inputs`` for masked-language-model training.

        Three disjoint groups of spans are drawn. The seed probabilities below
        are per token, before each seed is widened to a span:
          1. ``0.1 * mlm_probability``: tokens left unchanged but scored.
          2. ``0.1 * mlm_probability``: one nucleotide is replaced by a random
             different one, and the change is propagated into the following
             ``span_length - 1`` tokens (the other k-mers containing it).
          3. ``0.8 * mlm_probability``: tokens replaced by the mask token.
        Special tokens are never altered and never scored.

        Args:
            inputs: ``(batch, seq_len)`` tensor of token ids; modified in place.
            special_tokens_mask: same-shape mask of special tokens, or None to
                compute it with the tokenizer.

        Returns:
            ``(inputs, labels)``. ``labels`` holds the original ids at scored
            positions and -100 elsewhere.
        """
        import torch

        original_inputs = inputs.clone()
        labels = inputs.clone()
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # Group 1: spans that keep their original tokens but are still scored.
        probability_matrix = torch.full(labels.shape, self.mlm_probability*0.1)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        m_save = self._seeds_to_spans(torch.bernoulli(probability_matrix).bool())

        # Group 2: random nucleotide substitution, outside group-1 spans.
        # Offsets 1..3 (mod 4) guarantee a different nucleotide.
        offsets = torch.randint(1, 4, labels.shape)
        probability_matrix = torch.full(labels.shape, self.mlm_probability*0.1)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        probability_matrix.masked_fill_(m_save, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        masked_offsets = masked_indices*offsets
        # NOTE(review): assumes k-mer token id = 5 + base-4 code of the k-mer,
        # with the last nucleotide as the least-significant digit. Confirm
        # against the vocab.
        modulo_matrix = torch.remainder(inputs - 5, 4)
        masked_shifts = torch.remainder(modulo_matrix + masked_offsets, 4) - modulo_matrix
        # The substituted nucleotide sits in base-4 digit i of the token i
        # positions further right, so propagate the shift with weight 4**i.
        for i in range(self.span_length):
            shifted_shift = masked_shifts[:,:(masked_shifts.shape[1] - i)]
            inputs[:,i:] = (shifted_shift*(4**i)) + inputs[:,i:]
            # OR in every touched position. Adding the signed shift to the bool
            # mask instead turned a -1 shift on an already-True position into
            # False, dropping the substituted token from the loss.
            masked_indices[:,i:] = masked_indices[:,i:] | (shifted_shift != 0)
        m_save = m_save | masked_indices

        # Group 3: spans replaced by the mask token, outside groups 1 and 2.
        probability_matrix = torch.full(labels.shape, self.mlm_probability*0.8)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        probability_matrix.masked_fill_(m_save, value=0.0)
        masked_indices = self._seeds_to_spans(torch.bernoulli(probability_matrix).bool())

        # Score every position touched by any group; -100 is ignored by the loss.
        m_final = masked_indices | m_save
        labels[~m_final] = -100
        inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Spans and shifts can spill onto special tokens: restore them and
        # never score them.
        inputs[special_tokens_mask] = original_inputs[special_tokens_mask]
        labels[special_tokens_mask] = -100

        return inputs, labels

# Parameters

In [6]:
# Input parquet with the S. cerevisiae 3' sequences.
seq_df_path = "data/Sequences/Annotation/Sequences/saccharomyces_cerevisiae/saccharomyces_cerevisiae_three_prime.parquet"

# Column of that frame holding the raw sequences.
seq_col = "three_prime_seq"
# k-mer size of the tokenizer vocabulary (always 6 for this model).
kmer_size = 6
# Species token prepended to every tokenized sequence.
proxy_species = "candida_glabrata"
# Number of rolling-masked copies scored per forward pass.
pred_batch_size = 3 * 128
# Hidden layers to embed with: a 1-tuple (start,) averages layers start..last,
# a 2-tuple (start, stop) averages layers start..stop-1, an int picks one layer.
target_layer = (8,)

# Load Data and Model

## Load the model

In [5]:
from transformers import Trainer
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig  
# Tokenizer and masked-LM head are loaded from the same hub repo and revision.
_hub_kwargs = dict(revision="downstream_species_lm")
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", **_hub_kwargs)
model = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", **_hub_kwargs)

In [8]:
# Use the GPU when one is available; fall back to CPU so the notebook still
# runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
model.eval()  # inference mode (e.g. dropout disabled)

print("Done")

Done


## Prepare the data

In [9]:
# Keep exactly the 300 nt the 3' model expects: truncate longer sequences and
# drop shorter ones. reset_index gives a contiguous RangeIndex, so
# Dataset.from_pandas does not add an ``__index_level_0__`` column that would
# otherwise be carried through tokenization into the collator.
dataset = pd.read_parquet(seq_df_path)
dataset[seq_col] = dataset[seq_col].str[:300]  # truncate longer sequences
dataset = dataset.loc[dataset[seq_col].str.len() == 300].reset_index(drop=True)  # throw out too short sequences

In [10]:
def tok_func(x):
    """Tokenize one dataset row, prefixed with the proxy species token."""
    return tok_func_species(x, proxy_species, seq_col)


ds = Dataset.from_pandas(dataset[[seq_col]])
tok_ds = ds.map(tok_func, batched=False, num_proc=2)
# Only token fields should reach the collator.
rem_tok_ds = tok_ds.remove_columns(seq_col)

# mlm=False: no masking here; rolling masks are applied at prediction time.
data_collator = DataCollatorForLanguageModelingSpan(
    tokenizer,
    mlm=False,
    mlm_probability=0,
    span_length=6,
)
data_loader = torch.utils.data.DataLoader(
    rem_tok_ds,
    batch_size=1,
    collate_fn=data_collator,
    shuffle=False,
)

# Reconstruction Predictions

## Functions

In [10]:
def predict_on_batch_generator(tokenized_data, dataset, seq_idx, 
                               special_token_offset, 
                               kmer_size = kmer_size,
                               seq_col = seq_col,
                               pred_batch_size = pred_batch_size,
                               mask_token_id = 4):
    """Score every rolling span-masked copy of one tokenized sequence.

    One copy is built per window of ``kmer_size`` consecutive k-mer tokens, with
    that window replaced by ``mask_token_id``. The copies are sent to the model
    in chunks of ``pred_batch_size``.

    Args:
        tokenized_data: collated batch of size 1 (``input_ids`` and, if present,
            ``attention_mask``).
        dataset: DataFrame of raw sequences; row ``seq_idx`` (positional) is the
            tokenized sequence.
        seq_idx: positional row index into ``dataset``.
        special_token_offset: number of special tokens before the first k-mer token.
        kmer_size: k-mer length, also the width of each masked window.
        seq_col: column of ``dataset`` holding the sequence.
        pred_batch_size: number of masked copies per forward pass.
        mask_token_id: id written into masked positions. NOTE(review): 4 is
            assumed to equal ``tokenizer.mask_token_id``; confirm.

    Yields:
        First the total number of masked copies (int). Then, per chunk,
        ``(logits, res)``: the model logits for that chunk and the full
        (un-chunked) tensor of masked copies.

    Raises:
        ValueError: on the first ``next()`` if the sequence is shorter than
            ``kmer_size``. A generator cannot hand back a placeholder via
            ``return``; it would only surface as a bare StopIteration.
    """
    label_len = len(dataset.iloc[seq_idx][seq_col])
    if label_len < kmer_size:
        raise ValueError(
            f"sequence {seq_idx} has length {label_len}, shorter than kmer_size={kmer_size}"
        )

    input_ids = tokenized_data['input_ids']
    # Row r of the identity convolved with a width-kmer_size box
    # (np.convolve mode='same') masks columns r - (kmer_size-1)//2 .. r + kmer_size//2.
    diag_matrix = torch.eye(input_ids.shape[1]).numpy()
    masked_indices = np.apply_along_axis(
        lambda m: np.convolve(m, [1] * kmer_size, mode='same'), axis=1, arr=diag_matrix
    ).astype(bool)
    masked_indices = torch.from_numpy(masked_indices)
    # Keep only the rows whose window lies entirely on k-mer tokens: from the
    # window starting at the first k-mer token to the one ending at the last.
    first_row = special_token_offset + (kmer_size - 1) // 2
    stop_row = special_token_offset + label_len - kmer_size - kmer_size // 2 + 1
    masked_indices = masked_indices[first_row:stop_row]

    res = input_ids.expand(masked_indices.shape[0], -1).clone()
    res[masked_indices] = mask_token_id
    # The collator pads to a multiple of the span length; pass its attention
    # mask so the model does not attend to those padding tokens.
    attention_mask = tokenized_data.get('attention_mask')

    yield res.shape[0]  # total number of masked copies
    for batch_idx in range(math.ceil(res.shape[0] / pred_batch_size)):
        res_batch = res[batch_idx * pred_batch_size:(batch_idx + 1) * pred_batch_size].to(device)
        model_kwargs = {}
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.expand(res_batch.shape[0], -1).to(device)
        with torch.no_grad():
            logits = model(res_batch, **model_kwargs)["logits"].detach()
        yield logits, res

In [11]:
# Build a 0/1 filter that converts k-mer probabilities into nucleotide
# probabilities (the default filter of extract_prbs_from_pred).
#
# A nucleotide at sequence position p is contained in kmer_size k-mers: from
# the one starting at p - (kmer_size - 1), where it is the LAST nucleotide, to
# the one starting at p, where it is the FIRST. Filter row i corresponds to the
# i-th of those k-mers, so it checks nucleotide index kmer_size - 1 - i:
#   prb_filter[i, j, n] == 1  iff  the k-mer with token index j (= vocab id - 5)
#   has nucleotide n (A=0, C=1, G=2, T=3) at index kmer_size - 1 - i.
# Example: ACGTGC ends in C, so row 0 of its column is 1 at n = 1.
vocab = tokenizer.get_vocab()
kmer_list = ["".join(x) for x in itertools.product("ACGT", repeat=kmer_size)]
nt_mapping = {"A":0,"C":1,"G":2,"T":3}
prb_filter = np.zeros((kmer_size, 4**kmer_size, 4))
for kmer in kmer_list:
    token = vocab[kmer] - 5 # there are 5 special tokens
    for idx, nt in enumerate(kmer):
        prb_filter[kmer_size - 1 - idx, token, nt_mapping[nt]] = 1
prb_filter = torch.from_numpy(prb_filter).to(device)

In [12]:
def extract_prbs_from_pred(kmer_prediction, 
                           label_pos,
                           max_pos,
                           prb_filter=prb_filter,
                           kmer_size=kmer_size):
    """Collapse the k-mer probabilities around one nucleotide into 4 probabilities.

    Args:
        kmer_prediction: per-token k-mer probabilities of one masked copy,
            indexed ``[token, kmer]``. Presumably padded with ``kmer_size - 1``
            zero rows on each side by the caller (TODO confirm).
        label_pos: position of the nucleotide in the original sequence.
        max_pos: unused; kept for call compatibility.
        prb_filter: ``(kmer_size, n_kmers, 4)`` 0/1 filter mapping each
            (window row, k-mer) pair to the nucleotide it implies.
        kmer_size: k-mer length.

    Returns:
        numpy array of 4 probabilities summing to 1, in the nucleotide order
        used to build ``prb_filter``.
    """
    window = kmer_prediction[label_pos:label_pos + kmer_size, :]
    # Broadcast each k-mer probability over the 4 nucleotide columns, keep only
    # the nucleotide each k-mer implies, and sum over k-mers, then window rows.
    per_row = (window[:, :, None] * prb_filter).sum(dim=1)
    scores = per_row.sum(dim=0)
    probabilities = scores / scores.sum()
    return probabilities.cpu().numpy()

## Run Inference

In [17]:
# Rolling-mask reconstruction: for every nucleotide of every sequence, predict a
# distribution over the 4 nucleotides from masked copies of the input.
#   predicted_prbs: one length-4 vector per nucleotide, sequences concatenated
#   gts:            the matching ground-truth nucleotides
predicted_prbs,gts = [],[]
# len(predicted_prbs) after the previous sequence; used to assert exactly one
# prediction per nucleotide.
prev_len = 0

# NOTE(review): tqdm gets no total here, so the bar shows only a count.
for no_of_index, tokenized_data in tqdm.tqdm(enumerate(data_loader)):
    #if no_of_index > 10:
    #    break
    label = dataset.iloc[no_of_index][seq_col]
    label_len = len(label)
    
    # Number of special tokens (species/CLS at the front, SEP/PAD at the end)
    # to strip from the per-token predictions.
    left_special_tokens = count_special_tokens(tokenized_data['input_ids'].numpy()[0], tokenizer, where="left")
    right_special_tokens = count_special_tokens(tokenized_data['input_ids'].numpy()[0], tokenizer, where="right")
    
    # Edge case: for a sequence shorter than 11 nt
    # we cannot even feed 6 mask tokens
    # so we might as well predict uniform probabilities
    # NOTE(review): these placeholders are torch tensors, while
    # extract_prbs_from_pred returns numpy arrays.
    if label_len < 11: 
        #print (no_of_index)
        for i in range(label_len):
            predicted_prbs.append(torch.tensor([0.25,0.25,0.25,0.25]))
            gts.append(label[i])
        added_len = len(predicted_prbs) - prev_len
        prev_len = len(predicted_prbs)
        assert added_len == len(label)
        continue

    # we do a batched predict to process the sequence
    # batch_start: index (among all masked copies) of the first copy in the current chunk
    # pos: next nucleotide position to predict
    batch_start = 0
    pos = 0
    prediction_generator = predict_on_batch_generator(tokenized_data, dataset, no_of_index, special_token_offset = left_special_tokens)
    # The generator first yields the total number of masked copies.
    max_idx = next(prediction_generator)
    for predictions, res in prediction_generator:
    
        # prepare predictions for processing
        logits = predictions[:,:,5:(5+prb_filter.shape[1])] # remove any non k-mer dims
        kmer_preds = torch.softmax(logits,dim=2)
        # remove special tokens:
        kmer_preds = kmer_preds[:,(left_special_tokens):(kmer_preds.shape[1] - right_special_tokens),:]
        # NOTE(review): max_pos is passed on but extract_prbs_from_pred does not use it.
        max_pos = kmer_preds.shape[1] - 1
        # pad with kmer_size-1 zero rows on both ends so the first 5 and last 5 nt
        # get a full window of kmer_size rows
        padded_tensor = torch.zeros((kmer_preds.shape[0],2*(kmer_size-1) + kmer_preds.shape[1],kmer_preds.shape[2]),device=device)
        padded_tensor[:,kmer_size-1:-(kmer_size-1),:] = kmer_preds
        kmer_preds = padded_tensor
        
        # Consume every position whose masked copy lies in this chunk.
        while pos < label_len:
            # get prediction
            # copy whose masked window starts at k-mer pos-5, clipped to the first/last copy
            theoretical_idx = min(max(pos-5,0),max_idx-1) # idx if we did it all in one batch
            actual_idx = max(theoretical_idx - batch_start,0) 
            if actual_idx >= kmer_preds.shape[0]:
                # that copy belongs to the next chunk
                break
            kmer_prediction = kmer_preds[actual_idx]
            nt_prbs = extract_prbs_from_pred(kmer_prediction=kmer_prediction, 
                                             label_pos=pos,
                                             max_pos=max_pos)
            predicted_prbs.append(nt_prbs)
            # extract ground truth
            gt = label[pos]
            gts.append(gt)
            # update
            pos += 1
        
        # The copy index we stopped at (pos-5) is the first row of the next chunk.
        batch_start = pos - 5

    added_len = len(predicted_prbs) - prev_len
    prev_len = len(predicted_prbs)
    assert added_len == len(label)

11it [00:14,  1.33s/it]


In [18]:
# predicted_prbs holds one length-4 distribution per nucleotide, with all
# sequences concatenated in order; regroup them per sequence. ``no_of_index``
# is the last 0-based loop index (one less than the sequence count), so let
# numpy infer the sequence dimension instead.
prbs_arr = np.stack(predicted_prbs).reshape((-1, 300, 4))

In [27]:
# Load previously saved probabilities (presumably from an earlier run) and show
# the (300, 4) block of the first sequence for comparison.
reference_prbs = torch.load("outputs/bertadn_origtest_convolved_prb_candida_glabrata_scer_downstream_fixedlen_512_withstop/prbs.pt")
reference_prbs.reshape((len(dataset), 300, 4))[0]

tensor([[5.3996e-05, 3.1054e-05, 2.2263e-04, 9.9969e-01],
        [6.7587e-01, 9.7943e-05, 3.2381e-01, 2.2240e-04],
        [7.6734e-01, 1.1883e-04, 2.3232e-01, 2.1588e-04],
        ...,
        [3.0795e-01, 2.0922e-01, 1.9542e-01, 2.8741e-01],
        [2.9296e-01, 1.2911e-01, 1.8747e-01, 3.9046e-01],
        [2.2976e-01, 2.7151e-01, 1.7726e-01, 3.2147e-01]], dtype=torch.float64)

# Embedding Sequences

## Functions

In [11]:
def embed_on_batch(tokenized_data, dataset, seq_idx, 
                   special_token_offset,
                   target_layer = target_layer):
    """Return hidden-state embeddings for one tokenized sequence.

    Args:
        tokenized_data: collated batch of size 1 (``input_ids`` and, if present,
            ``attention_mask``).
        dataset: DataFrame of raw sequences; row ``seq_idx`` (positional) is the
            tokenized sequence.
        seq_idx: positional row index into ``dataset``.
        special_token_offset: currently unused; kept for call compatibility.
        target_layer: an int selects a single hidden state; ``(start,)``
            averages hidden states ``start..last``; ``(start, stop)`` averages
            ``start..stop-1``.

    Returns:
        numpy array of per-token embeddings for the batch. For sequences
        shorter than 6 nt, a zero tensor of shape ``(len, len, 768)`` instead.
    """
    label_len = len(dataset.iloc[seq_idx][seq_col])
    if label_len < 6:
        print("This should not occur")
        return torch.zeros(label_len,label_len,768)

    input_ids = tokenized_data['input_ids'].clone().to(device)
    model_kwargs = {"output_hidden_states": True}
    # The collator pads to a multiple of the span length; pass its attention
    # mask so the model does not attend to those padding tokens.
    attention_mask = tokenized_data.get('attention_mask')
    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.to(device)
    with torch.no_grad():
        hidden_states = model(input_ids, **model_kwargs)['hidden_states']

    if isinstance(target_layer, int):
        embedding = hidden_states[target_layer]
    else:
        stop = target_layer[1] if len(target_layer) > 1 else None
        embedding = torch.stack(hidden_states[target_layer[0]:stop], dim=0).mean(dim=0)
    return embedding.detach().cpu().numpy()

In [12]:
def extract_embedding_from_pred(hidden_states, batch_pos, pos=None):
    """Average the hidden states of tokens ``pos - 5 .. pos`` for one batch row.

    Args:
        hidden_states: array/tensor indexed as ``[batch, token, feature]``.
        batch_pos: row of the batch to read.
        pos: last token index of the window. Both window ends are clipped to
            the valid token range. This helper used to read ``pos`` implicitly
            from the notebook's global namespace; that is still the fallback
            when it is omitted, but pass it explicitly.

    Returns:
        Mean over the selected tokens (one value per feature).

    Raises:
        TypeError: ``pos`` omitted and no global ``pos`` exists.
    """
    if pos is None:
        if "pos" not in globals():
            raise TypeError("extract_embedding_from_pred() needs an explicit 'pos' argument")
        pos = globals()["pos"]
    last_token = hidden_states.shape[1] - 1
    pred_pos_min = min(max(pos - 5, 0), last_token)
    pred_pos_max = min(max(pos, 0), last_token)
    token_embedding = hidden_states[batch_pos, pred_pos_min:pred_pos_max + 1, :]
    return token_embedding.mean(axis=0)

## Run Inference

In [13]:
# Mean-pool one embedding vector per sequence. Sequences shorter than 11 nt
# (too short for rolling masking) get an all-zero placeholder of the model's
# hidden size.
averaged_embeddings = []

for no_of_index, tokenized_data in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)):
    label = dataset.iloc[no_of_index][seq_col]
    label_len = len(label)
    
    left_special_tokens = count_special_tokens(tokenized_data['input_ids'].numpy()[0], tokenizer, where="left")
    right_special_tokens = count_special_tokens(tokenized_data['input_ids'].numpy()[0], tokenizer, where="right")

    if label_len < 11: 
        averaged_embeddings.append(np.zeros(model.config.hidden_size))
        continue

    hidden_states = embed_on_batch(tokenized_data, dataset, no_of_index, special_token_offset = left_special_tokens)
    # NOTE(review): the mean runs over every token position, including the
    # special tokens counted above (left/right_special_tokens are not used to
    # trim). Confirm this is intended.
    avg = hidden_states.mean(axis=(0,1))
    
    averaged_embeddings.append(avg)

0it [00:00, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
6594it [01:18, 83.97it/s]


In [14]:
# One mean embedding per sequence -> 2-D array (n_sequences, embedding_dim).
embeddings = np.stack(averaged_embeddings, axis=0)