# Fine-tuning script

*Anusha, Joyce, Nina*

#### Imports

In [3]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig
import numpy as np
from datasets import load_dataset

#### Load model and tokenizer from Hugging Face

In [None]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA (a hard-coded 'cuda:0' fails there).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Single source of truth for the Hugging Face checkpoint used below.
MODEL_NAME = "zhihan1996/DNABERT-2-117M"

# Load the configuration and pretrained weights, then place the model on
# the selected device.
config = BertConfig.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, config=config).to(device)

# The tokenizer is lightweight and is not moved to the GPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# NOTE(review): this overrides the pad token with the literal "X" — confirm
# that "X" maps to a valid id in this tokenizer's vocabulary.
tokenizer.pad_token = "X"

#### Freeze gradients

In [None]:
# Train only encoder layer 11 (the last layer in the encoder block);
# every other parameter is frozen.
for name, param in model.named_parameters():
    param.requires_grad = "encoder.layer.11" in name

#### Load data from Hugging Face

In [None]:
class custom_data_load(torch.utils.data.Dataset):
    """Dataset that tokenizes the 'sequence' column of a dataframe on demand.

    Each item is a dict with 'input_ids', 'attention_mask' and 'labels'.

    Args:
        dataframe: pandas DataFrame with 'sequence' and 'label' columns.
        tokenizer: callable invoked as ``tokenizer(sequence, padding=...,
            max_length=..., truncation=True, return_tensors='pt')`` and
            returning a mapping with 'input_ids' and 'attention_mask'
            tensors.
        shuffle: if True, the rows are shuffled once at construction time.
        max_length: length every sequence is padded/truncated to
            (defaults to 101, the value that used to be hard-coded).
        device: device the token tensors are moved to. When None, the
            module-level ``device`` is used, as before.
    """

    def __init__(self, dataframe, tokenizer, shuffle=True, max_length=101, device=None):
        if shuffle:
            # sample(frac=1) returns all rows in random order; the index is
            # rebuilt so it runs 0..n-1 again.
            self.dataframe = dataframe.sample(frac=1).reset_index(drop=True)
        else:
            self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device

    def __len__(self):
        """Return the number of rows in the wrapped dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors and label."""
        # One positional lookup serves both columns.
        row = self.dataframe.iloc[idx]

        # The tokenizer generates the attention mask alongside the ids.
        encoded = self.tokenizer(
            row['sequence'],
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )

        # Fall back to the module-level `device` when none was passed in.
        target = self.device if self.device is not None else device
        encoded = {key: value.to(target) for key, value in encoded.items()}

        return {
            # squeeze(0) drops only the leading batch axis; a bare squeeze()
            # would also collapse the sequence axis when its length is 1.
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            # NOTE(review): the label tensor is created without a device
            # argument, unlike the token tensors — confirm the training loop
            # moves it where needed.
            'labels': torch.tensor(row['label'], dtype=torch.long),
        }

In [8]:
def fasta_to_csv(fasta_file, csv_file):
    """Convert a FASTA file into a two-column CSV of sequences and labels.

    Writes a header row ``sequence,label`` followed by one row per FASTA
    record. The label is the text after the last '|' in the record id.

    Args:
        fasta_file: path of the FASTA file to read.
        csv_file: path of the CSV file to create (overwritten if present).
    """
    # Imported locally so the function works even when no earlier cell has
    # imported csv.
    import csv

    with open(fasta_file, 'r') as fasta_fh, open(csv_file, 'w', newline='') as csv_fh:
        writer = csv.writer(csv_fh)
        writer.writerow(['sequence', 'label'])
        # NOTE(review): SeqIO is not imported in the visible part of this
        # file; presumably Biopython's Bio.SeqIO — confirm it is in scope
        # before this runs.
        writer.writerows(
            (str(record.seq), record.id.split('|')[-1])
            for record in SeqIO.parse(fasta_fh, 'fasta')
        )

In [9]:
# Convert the H3K4me3 training FASTA into a sequence/label CSV.
fasta_file, csv_file = (
    '/scratch/gpfs/aa8417/QCB557_project/data/H3K4me3/train.fna',
    '/scratch/gpfs/aa8417/QCB557_project/data/train.csv',
)
fasta_to_csv(fasta_file, csv_file)

In [10]:
# Convert the H3K4me3 test FASTA into a sequence/label CSV.
fasta_file, csv_file = (
    '/scratch/gpfs/aa8417/QCB557_project/data/H3K4me3/test.fna',
    '/scratch/gpfs/aa8417/QCB557_project/data/test.csv',
)
fasta_to_csv(fasta_file, csv_file)