In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, AdamW, get_scheduler, BertForMaskedLM, DataCollatorWithPadding
from transformers.models.bert.configuration_bert import BertConfig
from dnabert_for_token_classification import BertForTokenClassification
import evaluate
from tqdm import tqdm
from data_handling_for_MLM import MutationDetectionDataset, collate_fn
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


# Data
Tokenizing the normal and mutated data,
Marking what token has been changed.

In [2]:
fasta_m = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data_m.fa'
fasta_t = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data.fa'

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

Using device: cpu




In [4]:
tokenizer.get_vocab()
tokenizer('[PAD]').input_ids

[1, 3, 2]

## Dataloaders and Padding

In [5]:
mutation_dataset = MutationDetectionDataset(fasta_m, fasta_t, tokenizer, verbose=True)

x tensor([   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
          75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
          83,  180,    4,    4,  588,  126,  545,   66,  374, 1602,  283, 1108,
         152,  645,  215,    4,  678, 2045,  556, 1176,  727,   97,  173,  448,
        1227,  486,   48,  220,   65,   20,    4,  268,   27,  283,  104, 1184,
          73, 3532,  245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,
         283,   65,  232,  204,   32,  289,   75, 3715,  151,  987, 1435,  226,
          33,  411,  149, 3654,  494,  163, 1321,   53, 2975,  112,  131, 1069,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4, 1127, 2293,
         448, 3462, 3454,  942,  307,   82, 2491,   50, 1431,  116,   28,  347,
         220,   95,  366,  637,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4, 

In [6]:
dataloader = DataLoader(mutation_dataset, batch_size=2, collate_fn=collate_fn, shuffle=False)

In [7]:
for batch in dataloader:
    print('-' * 100)
    print(batch)
    print(batch)
    # print('-' * 100)

----------------------------------------------------------------------------------------------------
{'input_ids': tensor([[   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
           75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
           83,  180,    4,    4,  588,  126,  545,   66,  374, 1602,  283, 1108,
          152,  645,  215,    4,  678, 2045,  556, 1176,  727,   97,  173,  448,
         1227,  486,   48,  220,   65,   20,    4,  268,   27,  283,  104, 1184,
           73, 3532,  245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,
          283,   65,  232,  204,   32,  289,   75, 3715,  151,  987, 1435,  226,
           33,  411,  149, 3654,  494,  163, 1321,   53, 2975,  112,  131, 1069,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4, 1127, 2293,
          448, 3462, 3454,  942,  307,   82, 2491,   50, 1431,  116,   28, 

  return {'input_ids': torch.tensor(self.sequences[idx]), 'labels': torch.tensor(self.tokens_labels[idx])}


In [8]:
mutation_dataset

<data_handling_for_MLM.MutationDetectionDataset at 0x7f26546ca470>

# Model
We are using [DNABERT2](https://github.com/MAGICS-LAB/DNABERT_2/tree/main?tab=readme-ov-file#1-introduction)

In [9]:
model = BertForMaskedLM(config).to(device)

In [10]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4096, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru