In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, AdamW, get_scheduler, BertForMaskedLM, DataCollatorWithPadding
from transformers.models.bert.configuration_bert import BertConfig
from dnabert_for_token_classification import BertForTokenClassification
import evaluate
from tqdm import tqdm
from data_handling_for_MLM import MutationDetectionDataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


# Data
Tokenizing the normal and mutated data,
Marking what token has been changed.

In [2]:
fasta_m = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data_m.fa'
fasta_t = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data.fa'

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

Using device: cpu




In [4]:
tokenizer.get_vocab()
tokenizer('[PAD]').input_ids

[1, 3, 2]

In [5]:
mutation_dataset = MutationDetectionDataset(fasta_m, fasta_t, tokenizer, verbose=True)


tensor([   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
          75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
          83,  180,    4,    4,  588,  126,  545,    2])
tensor([   1,  153, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
          75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
          83,  180, 1719,   54,  588,  126,  545,    2])
True
----------------

tensor([   1,   30,  126,  545,   66,  374, 1602,  283, 1108,  152,  793,  128,
         245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,  283,   65,
         232,  204,   32,    4,   75, 3715,  151,    2])
tensor([   1,   30,  126,  545,   66,  374, 1602,  283, 1108,  152,  793,  128,
         245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,  283,   65,
         232,  204,   32,  289,   75, 3715,  151,    2])
True
----------------

tensor([   1,    5,   23,   78,  479, 3270,   32,    4,    4,    4,  158,   70,
     

## Dataloaders and Padding

In [6]:
from torch.utils.data import Dataset, DataLoader
import torch
from Bio import SeqIO
import random
MAX_LEN = 512
MASK_TOKEN = 4 # [MASK] token
PAD_TOKEN = 3 # [PAD] token
ENDING_TOKEN = 2 # [SEP] token

def insert_single_replacement(seq, mutation_rate):
    new_seq = []
    mutation_added = False
    for c in seq:
        if mutation_added == False and random.random() < mutation_rate:
            choice_str = list({'A', 'C', 'G', 'T'} - {c})
            new_seq.append(random.choice(choice_str))
        else:
            new_seq.append(c)
    return ''.join(new_seq)


def get_onehot_for_first_missmatch(seq1, seq2):
    for index, (a, b) in enumerate(zip(seq1, seq2)):
        if a != b:
            return [1 if i == index else 0 for i in range(len(seq1))]
    return [0] * len(seq1)

def pad_sequences(seq1, seq2, max_len=MAX_LEN):
    max_len = max(len(seq1), len(seq2))
    if seq1.shape[0] < max_len:
        # seq1[-1] = PAD_TOKEN
        pad_vector = torch.ones(max_len - len(seq1), dtype=torch.long) * PAD_TOKEN
        seq1 = torch.cat((seq1, pad_vector))
        # seq1[-1] = ENDING_TOKEN


    elif seq2.shape[0] < max_len:
        # seq2[-1] = PAD_TOKEN
        pad_vector = torch.ones(max_len - len(seq2), dtype=torch.long) * PAD_TOKEN
        seq2 = torch.cat((seq2, pad_vector))
        # seq2[-1] = ENDING_TOKEN
    
    assert seq1.shape == seq2.shape
    return seq1, seq2


def compare_two_sequences(seq1, seq2):
    return seq1 != seq2


def mask_sequence(seq, mask_vector):
    seq[mask_vector] = MASK_TOKEN # [MASK] token
    return seq


class MutationDetectionDataset(Dataset):

    def __init__(self, fasta_m, fasta_t, tokenization_f, replacement_flag=False, mutation_rate=0.01, verbose=False):
 
        zipped_fasta_lines = zip(SeqIO.parse(fasta_m, "fasta"), SeqIO.parse(fasta_t, "fasta"))
        self.sequences = []
        self.tokens_labels = []
        self.mutations = []
        for record_m, record_t in zipped_fasta_lines:
            if replacement_flag :
                x = insert_single_replacement(record_m.seq, mutation_rate=mutation_rate)
            else:
                x = record_m.seq
            tokenized_x = tokenization_f(str(x), padding=False, truncation=True, max_length=MAX_LEN, return_tensors='pt')['input_ids'].squeeze(0)
            print(tokenized_x)
            tokenized_y = tokenization_f(str(record_t.seq), padding=False, truncation=True, max_length=MAX_LEN, return_tensors='pt')['input_ids'].squeeze(0)
            tokenized_x, tokenized_y = pad_sequences(tokenized_x, tokenized_y)
            mask_vector = compare_two_sequences(tokenized_x, tokenized_y)
            tokenized_x = mask_sequence(tokenized_x, mask_vector)
            self.sequences.append(tokenized_x)
            self.tokens_labels.append(tokenized_y)
            if verbose:
                print('x', tokenized_x)
                print('y', tokenized_y)
                # print(tokenized_x.shape == tokenized_y.shape)
                
                # print(tokenized_y.dtype)
                # print('compare', mask_vector)
                print('----------------')

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        assert self.sequences[idx].shape == self.tokens_labels[idx].shape
        return {'input_ids': torch.tensor(self.sequences[idx]), 'labels': torch.tensor(self.tokens_labels[idx])}
    

mutation_dataset = MutationDetectionDataset(fasta_m, fasta_t, tokenizer, verbose=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=MAX_LEN)
def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    labels = [item['labels'] for item in batch]
    max_len = max([len(t) for t in input_ids])

    input_ids = [torch.cat((t, torch.ones(max_len - len(t), dtype=torch.long) * PAD_TOKEN)) for t in input_ids]
    labels = [torch.cat((t, torch.ones(max_len - len(t), dtype=torch.long) * PAD_TOKEN)) for t in labels]

    return {'input_ids': torch.stack(input_ids), 'labels': torch.stack(labels)}

dataloader = DataLoader(mutation_dataset, batch_size=2, collate_fn=collate_fn, shuffle=False)

for batch in dataloader:
    print('-' * 100)
    print(batch)
    print(batch)
    # print('-' * 100)

tensor([   1,  250, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
          75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
          83,  180,   90, 1898,  588,  126,  545,   66,  374, 1602,  283, 1108,
         152,  645,  215, 3373,  678, 2045,  556, 1176,  727,   97,  173,  448,
        1227,  486,   48,  220,   65,   20,  164,  268,   27,  283,  104, 1184,
          73, 3532,  245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,
         283,   65,  232,  204,   32,  289,   75, 3715,  151,  987, 1435,  226,
          33,  411,  149, 3654,  494,  163, 1321,   53, 2975,  112,  131, 1069,
         776,  937, 3920,  108, 1957,   80,  716, 3833,  177,  100,  118,   48,
         941, 3115,  519, 1624,  107,   89,  586,   31,  181,   51, 1127, 2293,
         448, 3462, 3454,  942,  307,   82, 2491,   50, 1431,  116,   28,  347,
         220,   95,  366,  637,   68, 1886,  257,   53,  769,  210, 1795,   19,
         402,  232,  586, 1889,  403, 36

  return {'input_ids': torch.tensor(self.sequences[idx]), 'labels': torch.tensor(self.tokens_labels[idx])}


In [7]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataloader = DataLoader(mutation_dataset, batch_size=2, collate_fn=data_collator, shuffle=False)
# Iterate over the DataLoader
for batch in dataloader:
    print(batch)
    print('-'*100)

  return {'input_ids': torch.tensor(self.sequences[idx]), 'labels': torch.tensor(self.tokens_labels[idx])}
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`labels` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [None]:
# Iterate over the DataLoader
for batch in dataloader:
    print(batch)
    print('-'*100)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
           75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
           83,  180,    4,    4,  588,  126,  545,    2],
        [   1,   30,  126,  545,   66,  374, 1602,  283, 1108,  152,  793,  128,
          245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,  283,   65,
          232,  204,   32,    4,   75, 3715,  151,    2]]), 'labels': tensor([[   1,  153, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
           75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
           83,  180, 1719,   54,  588,  126,  545,    2],
        [   1,   30,  126,  545,   66,  374, 1602,  283, 1108,  152,  793,  128,
          245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,  283,   65,
          232,  204,   32,  289,   75, 3715,  151,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`labels` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [None]:
mutation_dataset

<data_handling_for_MLM.MutationDetectionDataset at 0x7f3e02406920>

# Model
We are using [DNABERT2](https://github.com/MAGICS-LAB/DNABERT_2/tree/main?tab=readme-ov-file#1-introduction)

In [None]:
model = BertForMaskedLM(config).to(device)

In [None]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4096, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru