In [1]:
import torch
from transformers import AutoTokenizer, BertForMaskedLM
from transformers.models.bert.configuration_bert import BertConfig
from tqdm import tqdm
from data_handling_for_MLM import MutationDetectionDataset, collate_fn
from torch.utils.data import DataLoader
import Levenshtein
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


# Data
Tokenizing the normal and mutated data,
Marking what token has been changed.

In [2]:
fasta_m = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data_m.fa'
fasta_t = '/ems/elsc-labs/habib-n/yuval.rom/school/ANLP/final_project/Mutation-Simulator/data/sample_data/data.fa'

### Tokenizer

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

Using device: cpu




In [4]:
tokenizer.get_vocab()
tokenizer('[PAD]').input_ids

[1, 3, 2]

## Dataloaders and Padding

In [5]:
mutation_dataset = MutationDetectionDataset(fasta_m, fasta_t, tokenizer, verbose=True)

x tensor([   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
          75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
          83,  180,    4,    4,  588,  126,  545,   66,  374, 1602,  283, 1108,
         152,  645,  215,    4,  678, 2045,  556, 1176,  727,   97,  173,  448,
        1227,  486,   48,  220,   65,   20,    4,  268,   27,  283,  104, 1184,
          73, 3532,  245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,
         283,   65,  232,  204,   32,  289,   75, 3715,  151,  987, 1435,  226,
          33,  411,  149, 3654,  494,  163, 1321,   53, 2975,  112,  131, 1069,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4, 1127, 2293,
         448, 3462, 3454,  942,  307,   82, 2491,   50, 1431,  116,   28,  347,
         220,   95,  366,  637,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4, 

In [6]:
dataloader = DataLoader(mutation_dataset, batch_size=2, collate_fn=collate_fn, shuffle=False)

In [7]:
for batch in dataloader:
    print('-' * 100)
    print(batch)
    # print('-' * 100)

----------------------------------------------------------------------------------------------------
{'input_ids': tensor([[   1,    4, 1004,   67,   36,  726,  528, 1104,  319,  746,  296,   28,
           75, 1507,   55,  362,  123,  130,   82,  443,  184, 2063, 2169,  161,
           83,  180,    4,    4,  588,  126,  545,   66,  374, 1602,  283, 1108,
          152,  645,  215,    4,  678, 2045,  556, 1176,  727,   97,  173,  448,
         1227,  486,   48,  220,   65,   20,    4,  268,   27,  283,  104, 1184,
           73, 3532,  245,   61,  208, 3056,  552,  635,   99,  819,   42,  558,
          283,   65,  232,  204,   32,  289,   75, 3715,  151,  987, 1435,  226,
           33,  411,  149, 3654,  494,  163, 1321,   53, 2975,  112,  131, 1069,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
            4,    4,    4,    4,    4,    4,    4,    4,    4,    4, 1127, 2293,
          448, 3462, 3454,  942,  307,   82, 2491,   50, 1431,  116,   28, 

# Model
We are using [DNABERT2](https://github.com/MAGICS-LAB/DNABERT_2/tree/main?tab=readme-ov-file#1-introduction)

# Train the Model

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
lr = 5e-5
weight_decay = 0.01
batch_size = 2 # TODO: change to 32
num_epochs = 3 # TODO: change to ???


model = BertForMaskedLM(config).to(device)
loss_func = torch.nn.CrossEntropyLoss() 
# metric = evaluate.load('seqeval')
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
train_loader = DataLoader(mutation_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
eval_loader = DataLoader(mutation_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

Using device: cpu


In [22]:
def compute_median_edit_distance(edit_distances):
    return np.median(edit_distances)


def compute_mean_edit_distance(edit_distances):
    return np.mean(edit_distances)


def compute_normalized_mean_edit_distance(edit_distances, label_texts):
    # Calculate normalized mean edit distance
    normalized_edit_distances = [edit_distance / len(label_text) for edit_distance, label_text in zip(edit_distances, label_texts)]
    return np.mean(normalized_edit_distances)


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    # Convert predictions and labels to lists of token IDs
    predictions = predictions.tolist()
    labels = labels.tolist()
    
    # Decode predictions and labels, filtering out invalid labels
    pred_texts = []
    label_texts = []
    for pred, label in zip(predictions, labels):
        # Filter out -100 (or any invalid token IDs) from labels
        valid_label_ids = [id for id in label if id != -100]
        # Decode only valid token IDs
        pred_texts.append(tokenizer.decode(pred, skip_special_tokens=True))
        label_texts.append(tokenizer.decode(valid_label_ids, skip_special_tokens=True))
    
    # Calculate Levenshtein distance
    edit_distances = [Levenshtein.distance(pred, label) for pred, label in zip(pred_texts, label_texts)]
    
    
    return {"avg_edit_distance": compute_mean_edit_distance(edit_distances),
            "median_edit_distance": compute_median_edit_distance(edit_distances),
            "normalized_avg_edit_distance": compute_normalized_mean_edit_distance(edit_distances, labels)}


In [23]:
torch.cuda.is_available()

False

In [24]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir="./models/correction/results",  # output directory
    overwrite_output_dir=True,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    save_steps=1,#10_000,
    save_total_limit=2,
    logging_dir='./models/correction/logs',
    evaluation_strategy="epoch",  # Ensure evaluations occur

)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=mutation_dataset,
    eval_dataset=mutation_dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

Epoch,Training Loss,Validation Loss,Avg Edit Distance,Median Edit Distance,Normalized Avg Edit Distance
1,No log,1.412605,169.4,175.0,0.518043
2,No log,1.346656,217.4,275.0,0.664832
3,No log,1.27341,167.2,157.0,0.511315


TrainOutput(global_step=9, training_loss=3.339323255750868, metrics={'train_runtime': 130.745, 'train_samples_per_second': 0.115, 'train_steps_per_second': 0.069, 'total_flos': 1869102832128.0, 'train_loss': 3.339323255750868, 'epoch': 3.0})

In [12]:
def train_model(model,
                loss_func,
                train_dataloader,
                eval_dataloader,
                lr,
                weight_decay,
                batch_size,
                num_epochs,
                device,
                optimizer,
                lora=False):
    # if not lora:
    #     optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch}"):
            torch.cuda.empty_cache()
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = loss_func(outputs.logits, batch['labels'])
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_train_loss = train_loss / len(train_dataloader)
        print(f"Average train loss: {avg_train_loss}")

        # Validation
        model.eval()
        eval_loss = 0
        correct = 0
        with torch.no_grad():
            for eval_batch in eval_dataloader:
                eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
                outputs = model(**eval_batch)
                loss = outputs.loss
                correct += (outputs.logits.argmax(dim=1) == eval_batch['labels']).float().sum()
                eval_loss += loss.item()

            print(f"Average eval loss: {eval_loss / len(eval_dataloader)}")
            accuracy = correct / (len(eval_dataloader) * batch_size)
            print(f"Eval Accuracy: {accuracy}")

    model.save_pretrained(f"models/correction/fine_tuned_model_e{num_epochs}_bc{batch_size}_lr{lr}_wd{weight_decay}")


In [13]:
train_model(model,
            loss_func,
            train_loader,
            eval_loader,
            lr,
            weight_decay,
            batch_size,
            num_epochs,
            device,
            optimizer)

Training Epoch 0:   0%|          | 0/3 [00:00<?, ?it/s]


RuntimeError: Expected target size [2, 4096], got [2, 327]