In [1]:
import os
import time
import torch
import argparse
import numpy as np
import pandas as pd
from util import * 
from model import *
from transformers import BertTokenizer
from pandas import DataFrame
from sklearn import metrics

# Expose only physical GPU 5 to this process. This has to happen before the
# first CUDA call so that the restriction is honoured.
os.environ.update({"CUDA_VISIBLE_DEVICES": "5"})

# Run on the GPU when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook-wide configuration, kept in an argparse namespace so that helper
# functions can receive a single `args` object.
parser = argparse.ArgumentParser()

# (flag, type, default) for every option; all help strings are empty.
_ARG_SPECS = (
    ('--epoches',       int,   30),
    ('--batch_size',    int,   16),
    ('--max_length',    int,   2000),
    ('--learning_rate', float, 1e-4),
    ('--model_path',    str,   "../3-new-12w-0"),
    ('--ind_filename',  str,   "../dataset/enhancer_3-mer_DNABERT_ind.txt"),
    ('--tra_filename',  str,   "../dataset/enhancer_3-mer_DNABERT_tra.txt"),
)
for _flag, _type, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_type, default=_default, help='')

# Parse an explicit empty argument list so that only the defaults are used.
# Without "args=[]" argparse reads sys.argv, which raises an error here.
args = parser.parse_args(args=[])


In [None]:
# Ensemble evaluation.
# For every cell line, one saved checkpoint per k-mer size is loaded, the
# independent (test) set is scored with each of them, and the per-model scores
# are averaged into a single ensemble prediction that is evaluated again.
np.random.seed(8888)
random_list = np.random.randint(10000, 60000, size=(10,))  # 10 seeds drawn; the loop consumes the first 4

# average [3,4,5,6]
# with L2 regularisation added; average
# The three lists below are index-aligned: entry i belongs to ensemble member i.
mers = [3, 4, 5, 6]
lambds = [1.2e-4, 1.2e-4, 1.5e-4, 1e-4]      # L2 weights; assigned per model but not used in this cell
learning_rates = [1e-5, 2e-5, 3e-5, 4e-5]   # only used to build the optimizer handed to eval_each_samples

cell_lines = ["HUVEC"]
threshold = 0.50

for cell_line in cell_lines:
    tra_real_labels_list, tra_pre_labels_list = [], []
    ind_real_labels_list, ind_pre_labels_list = [], []

    for mer, seed, lambd, learning_rate in zip(mers, random_list, lambds, learning_rates):
        seed = int(seed)  # plain Python int instead of numpy.int64

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True

        # Check for the checkpoint first: when it is missing there is no point
        # in loading the datasets or the pretrained backbone for this k-mer.
        model_path = "./well_trained_model/{}/C_Bert_2FC_concatenate_{}_{}mer.pt".format(cell_line, cell_line, mer)
        if not os.path.exists(model_path):
            print("skipping {}-{}mer: no checkpoint at {}".format(cell_line, mer, model_path))
            continue

        args.model_path = "../../DNA-BERT/{}-new-12w-0".format(mer)
        prefix = ""

        args.ind_filename = os.path.join(prefix, "../dataset/{}-test-{}mer.txt".format(cell_line, mer))
        args.tra_filename = os.path.join(prefix, "../dataset/{}-train-{}mer.txt".format(cell_line, mer))

        tra_dataloader = getData_concate(args, validation=False, training=True, shuffle=False)
        ind_dataloader = getData_concate(args, validation=False, training=False, shuffle=False)
        print("tra: {}; ind: {}".format(len(tra_dataloader), len(ind_dataloader)))

        args.learning_rate = learning_rate

        model = C_Bert_2FC_concatenate.from_pretrained(args.model_path, num_labels=1).to(device)

        # C_Bert_2FC_concatenate_BN_128
        # C_Bert_2FC_concatenate_BN_256
        # C_Bert_2FC_concatenate_BN_512

        # map_location lets a checkpoint that was saved from a GPU be restored
        # when `device` is the CPU (or a differently numbered GPU).
        model.load_state_dict(torch.load(model_path, map_location=device))

        tokenizer = BertTokenizer.from_pretrained(args.model_path)
        # No training happens in this cell; the optimizer exists only because
        # eval_each_samples takes one as its fourth argument.
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08,)

        # Training-set evaluation is currently disabled, so the tra_* lists stay empty.
        # tra_real_labels, tra_pre_labels = eval_each_samples(tokenizer, model, tra_dataloader, optimizer)
        # tra_real_labels_list.append(tra_real_labels)
        # tra_pre_labels_list.append(tra_pre_labels)
        # tra_acc, tra_mcc, tra_sn, tra_sp = evaluation_criterion(tra_pre_labels, tra_real_labels, threshold=threshold)

        ind_real_labels, ind_pre_labels = eval_each_samples(tokenizer, model, ind_dataloader, optimizer)
        ind_real_labels_list.append(ind_real_labels)
        ind_pre_labels_list.append(ind_pre_labels)
        ind_acc, ind_mcc, ind_sn, ind_sp = evaluation_criterion(ind_pre_labels, ind_real_labels, threshold=threshold)

        ind_fpr, ind_tpr, ind_threshold = metrics.roc_curve(ind_real_labels, ind_pre_labels)
        roc_auc = metrics.auc(ind_fpr, ind_tpr)
        print(mer)
        # content = "{}-{}mer, threshold={}; train-acc: {:.5f}, mcc: {:.5f}, sn: {:.5f}, sp: {:.5f}; ".format(cell_line, mer, threshold, tra_acc, tra_mcc, tra_sn, tra_sp)
        content = "{}-{}mer, threshold={}; ind-acc: {:.5f}, mcc: {:.5f}, sn: {:.5f}, sp: {:.5f}, AUC: {}; ".format(cell_line, mer, threshold, ind_acc, ind_mcc, ind_sn, ind_sp, roc_auc) 
        print(content)

    # Without a single evaluated model there is nothing to average; fail with
    # a clear message instead of propagating NaN into the metric helpers.
    if not ind_pre_labels_list:
        raise FileNotFoundError(
            "no checkpoint found for cell line {!r} under ./well_trained_model/{}/".format(cell_line, cell_line)
        )

    tra_real_labels_list = np.array(tra_real_labels_list)
    tra_pre_labels_list  = np.array(tra_pre_labels_list)
    ind_real_labels_list = np.array(ind_real_labels_list)
    ind_pre_labels_list = np.array(ind_pre_labels_list)

    # Average over the model axis. For the real labels this assumes every
    # loader yields the samples in the same order (shuffle=False above), so the
    # mean of identical label vectors is the label vector itself -- TODO confirm.
    ind_real_labels = np.average(ind_real_labels_list, axis=0)
    ind_pre_labels = np.average(ind_pre_labels_list, axis=0)

    # The training lists are only filled when the disabled block above is
    # re-enabled. Averaging an empty array would yield NaN plus a
    # RuntimeWarning, so in that case the (empty) arrays are kept as they are.
    if tra_pre_labels_list.size > 0:
        tra_real_labels = np.average(tra_real_labels_list, axis=0)
        tra_pre_labels = np.average(tra_pre_labels_list, axis=0)
    else:
        tra_real_labels = tra_real_labels_list
        tra_pre_labels = tra_pre_labels_list

    print(cell_line, ind_pre_labels.shape)

    # tra_acc_int, tra_mcc_int, tra_sn_int, tra_sp_int = evaluation_criterion(tra_pre_labels, tra_real_labels, threshold=threshold)
    ind_acc_int, ind_mcc_int, ind_sn_int, ind_sp_int = evaluation_criterion(ind_pre_labels, ind_real_labels, threshold=threshold)
    ind_fpr, ind_tpr, ind_threshold = metrics.roc_curve(ind_real_labels, ind_pre_labels)
    roc_auc = metrics.auc(ind_fpr, ind_tpr)

    print(ind_acc_int, ind_mcc_int, ind_sn_int, ind_sp_int, roc_auc)



In [None]:
# Re-compute the ensemble metrics from the per-model results collected by the
# evaluation loop. This cell depends on names that loop leaves behind:
# cell_line, threshold, ind_*_labels_list, tra_real_labels and tra_pre_labels.
ind_real_labels = np.average(ind_real_labels_list, axis=0)
ind_pre_labels = np.average(ind_pre_labels_list, axis=0)

print(cell_line, ind_pre_labels.shape)

# Training predictions exist only when training-set evaluation is enabled in
# the loop; otherwise tra_pre_labels is an empty/NaN placeholder and passing
# it to evaluation_criterion would be meaningless.
has_train_predictions = np.size(tra_pre_labels) > 0 and not np.all(np.isnan(tra_pre_labels))
if has_train_predictions:
    tra_acc_int, tra_mcc_int, tra_sn_int, tra_sp_int = evaluation_criterion(tra_pre_labels, tra_real_labels, threshold=threshold)
else:
    tra_acc_int = tra_mcc_int = tra_sn_int = tra_sp_int = None

# Use the same decision threshold as the evaluation loop so the numbers
# printed here are comparable with the ones printed there.
ind_acc_int, ind_mcc_int, ind_sn_int, ind_sp_int = evaluation_criterion(ind_pre_labels, ind_real_labels, threshold=threshold)
ind_fpr, ind_tpr, ind_threshold = metrics.roc_curve(ind_real_labels, ind_pre_labels)
roc_auc = metrics.auc(ind_fpr, ind_tpr)

print(ind_acc_int, ind_mcc_int, ind_sn_int, ind_sp_int, roc_auc)
