In [1]:
import os
import torch
import argparse
import numpy as np
import pandas as pd
from util import * 
from model import *
from pandas import DataFrame

# Restrict CUDA to device index 0 for this process, then pick the compute
# device: the GPU when PyTorch can see one, otherwise the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [2]:
# Notebook-wide settings, expressed as argparse options so the helper functions
# can receive a single `args` namespace. Each entry is (flag, type, default).
# The path and learning-rate defaults are overwritten per k-mer by the training cell.
_ARGUMENT_SPECS = (
    ("--epoches", int, 30),
    ("--batch_size", int, 16),
    ("--max_length", int, 200),
    ("--learning_rate", float, 1e-4),
    ("--model_path", str, "../3-new-12w-0"),
    ("--ind_filename", str, "../dataset/enhancer_3-mer_DNABERT_ind.txt"),
    ("--tra_filename", str, "../dataset/enhancer_3-mer_DNABERT_tra.txt"),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _ARGUMENT_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default, help="")

# Parse an explicit empty argv: inside a notebook sys.argv belongs to the
# kernel launcher, so only the defaults above should apply.
args = parser.parse_args(args=[])


In [None]:
# add L2-normalization + average

# Per-run hyper-parameters. The four lists are parallel: position i of each
# one describes the same fine-tuning run (one run per k-mer size).
mers = list(range(3, 7))
seeds = [5576, 5217, 9653, 5630]    # fixed so that the reported results can be reproduced
lambds = [1e-3] * 2 + [5e-4, 5e-5]  # handed to train_finetuning_Norm by the training loop
learning_rates = [5e-5, 2e-5, 5e-5, 1e-5]

# Metric histories; each name is bound to its own, initially empty, list.
(tra_loss_list, tra_acc_list,
 val_loss_list, val_acc_list,
 ind_loss_list, ind_acc_list) = ([] for _ in range(6))

# Fine-tune one classifier per k-mer size and evaluate it on the independent
# set after every epoch. Position i of `mers`, `seeds`, `lambds` and
# `learning_rates` describes the same run, so the lists must line up.
assert len(mers) == len(seeds) == len(lambds) == len(learning_rates), \
    "mers, seeds, lambds and learning_rates need exactly one entry per run"

for mer, seed, lambd, learning_rate in zip(mers, seeds, lambds, learning_rates):
    # Seed every RNG used here before any data loader or weight is created.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN auto-tuning may select different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False

    # Point the shared `args` namespace at this k-mer's checkpoint and data files.
    args.model_path = "../../DNA-BERT/{}-new-12w-0".format(mer)
    args.ind_filename = "../dataset/enhancer_{}-mer_DNABERT_ind.txt".format(mer)
    args.tra_filename = "../dataset/enhancer_{}-mer_DNABERT_tra.txt".format(mer)

    # NOTE(review): validation=True presumably makes getData read args.ind_filename
    # instead of args.tra_filename -- confirm in util.getData.
    tra_dataloader = getData(args, split=False, validation=False, shuffle=True)
    ind_dataloader = getData(args, split=False, validation=True, shuffle=False)

    # `args` is also handed to train_finetuning_Norm, so keep it in sync.
    args.learning_rate = learning_rate

    # Drop the previous run's model and optimizer state *before* loading the next
    # checkpoint, so two models never sit on the GPU at once. The objects of the
    # final run stay bound after the loop for any later cell.
    model = optimizer = scheduler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = C_Bert_2FC_average.from_pretrained(args.model_path, num_labels=1).to(device)
    print("{}-mer; lr: {}, lambd:{}, seed:{}, dropout:0.30, delay: 0.98 768->25->1, L2".format(mer, learning_rate, lambd, seed))

    epoches = args.epoches
    # Use torch.optim explicitly: `optim` is not imported in this notebook and
    # would only exist if one of the wildcard imports happened to leak it.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.98)    # learning rate decay

    for epoch in range(epoches):
        # NOTE(review): the positional 2 is presumably the norm order (L2) and
        # `lambd` its weight -- confirm in util.train_finetuning_Norm.
        tra_acc, tra_loss = train_finetuning_Norm(model, tra_dataloader, optimizer, args, lambd, 2, kmer=mer)
        # Stepped once per epoch, after the optimizer updates of that epoch.
        scheduler.step()
        ind_acc, ind_mcc, ind_sn, ind_sp, ind_loss = validation_finetuning(model, ind_dataloader)

        # The histories are shared by all runs: entries of successive k-mers are
        # appended one after another, `epoches` entries per run.
        tra_loss_list.append(tra_loss)
        tra_acc_list.append(tra_acc)

        ind_loss_list.append(ind_loss)
        ind_acc_list.append(ind_acc)
        print("{}-mer; epoch:{:2d}, tra loss:{:.4f}, acc:{:.4f};  ind loss:{:.4f}, acc:{:.4f}, mcc:{:.4f}, sn:{:.4f}, sp:{:.4f}".format(mer, epoch, tra_loss, tra_acc, ind_loss, ind_acc, ind_mcc, ind_sn, ind_sp))

    # Optional checkpoint of the fine-tuned weights (disabled):
    # torch.save(model.state_dict(), "fine-trained_model/C_Bert_2FC_average_{}-mer_temp.pt".format(mer))
    