# import 

In [1]:
import toml
import json
import time
import torch
import pickle
import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import Module,Identity,Sequential
from typing import Dict,Tuple
from ecode.m6Anet.blocks import PoolingFilter
from sklearn.metrics import roc_auc_score

# DataLoader

In [2]:
class m6anetDataset(data.Dataset):
    def __init__(self,use_path):
        self.use_path=use_path
        self.LEN=0
        self.f_dict={}
        self.k5_list=[]
        with open(use_path+'/use_files.txt','r') as f:
            for line in f.readlines():
                f_name=line.strip()
                self.f_dict[f_name]={'len':0,'label':1,'cont':[]}
                if f_name[-4:]=='_neg':
                    self.f_dict[f_name]['label']=0
                with open(use_path+'/'+f_name+'.index') as f2:
                    for line2 in f2.readlines():
                        items2=line2.strip().split('\t')
                        self.f_dict[f_name]['cont'].append((int(items2[1]),int(items2[2])))
                        with open(self.use_path+'/'+f_name+'.json') as f:
                            f.seek(int(items2[1]),0)
                            json_str=f.read(int(items2[2])-int(items2[1]))
                            Ls=json_str.strip().split('\n')
                            k_str=json.loads(Ls[0])
                            for k5 in [k_str[9:14],k_str[10:15],k_str[11:16]]:
                                if k5 not in self.k5_list:
                                    self.k5_list.append(k5)
                self.f_dict[f_name]['len']=len(self.f_dict[f_name]['cont'])
                self.LEN+=self.f_dict[f_name]['len']
        self.k5_list=np.array(self.k5_list)
    def __getitem__(self,index):
        R_dict={'X':[],'kmer':[],'label':0}
        for key in self.f_dict:
            if index>=self.f_dict[key]['len']:
                index-=self.f_dict[key]['len']
                continue
            label=self.f_dict[key]['label']
            R_dict['label']=label
            seekit=self.f_dict[key]['cont'][index]
            with open(self.use_path+'/'+key+'.json') as f:
                f.seek(seekit[0],0)
                json_str=f.read(seekit[1]-seekit[0])
                Ls=json_str.strip().split('\n')
                k_str=json.loads(Ls[0])
                for k5 in [k_str[9:14],k_str[10:15],k_str[11:16]]:
                    R_dict['kmer'].append(np.where(self.k5_list==k5)[0][0])
                for L in Ls[1:]:
                    L_data=json.loads(L)
                    t_feature=[]
                    for each in L_data[11:14]:
                        if each[0]<0:
                            t_feature.extend([0,0,0])
                        else:
                            t_feature.extend(each)
                    R_dict['X'].append(t_feature)
            break
        for key2 in R_dict:
            R_dict[key2]=torch.tensor(R_dict[key2])
        R_dict['kmer']=R_dict['kmer'].repeat(len(R_dict['X']),1)
        return R_dict
    def __len__(self):
        return self.LEN

In [6]:
'''
[[block]]
block_type = "KmerMultipleEmbedding"
input_channel = 65  #<--- change the .toml config file here to length of k5_list
output_channel = 2
num_neighboring_features = 1
'''
#k5 list here is from the middle 3 poses, not only the center
m6A_m6Anet_set=m6anetDataset('./edata/DataSet/m6A')
print(m6A_m6Anet_set.k5_list)
print(len(m6A_m6Anet_set.k5_list))

['TAAAC' 'AAACT' 'AACTG' 'AGAAC' 'GAACA' 'AACAT' 'TGAAC' 'GGGAC' 'GGACT'
 'GACTG' 'GAACC' 'AACCG' 'GGACC' 'GACCC' 'AGGAC' 'GACCT' 'GACCG' 'TGGAC'
 'GACTC' 'GACTA' 'GACTT' 'CGGAC' 'CTAAC' 'TAACT' 'AACTT' 'ATGAC' 'TGACC'
 'CAAAC' 'GAGAC' 'AGACT' 'CAGAC' 'AGACC' 'GGACA' 'GACAG' 'TGACT' 'AAGAC'
 'GAACT' 'AACTC' 'TAGAC' 'AACCC' 'AGACA' 'GACAT' 'GGAAC' 'GACAC' 'AACAG'
 'CGAAC' 'GAAAC' 'AACAA' 'AACTA' 'TTGAC' 'GACAA' 'AACCT' 'CTGAC' 'GTGAC'
 'ATAAC' 'TTAAC' 'GACCA' 'TGACA' 'AAAAC' 'AACAC' 'AACCA' 'AAACA' 'GTAAC'
 'AAACC' 'TAACC']
65


In [4]:
RELOAD=0
if RELOAD==1:
    m6A_m6Anet_set=m6anetDataset('./edata/DataSet/m6A')
    train_size=int(len(m6A_m6Anet_set)*0.8)
    test_size=len(m6A_m6Anet_set)-train_size
    m6A_m6Anet_train_set,m6A_m6Anet_test_set=torch.utils.data.random_split(m6A_m6Anet_set,[train_size,test_size])
    with open('./edata/Save_DataSet/m6A_m6Anet_train_set.pkl','wb') as f:
        pickle.dump(m6A_m6Anet_train_set,f)
    with open('./edata/Save_DataSet/m6A_m6Anet_test_set.pkl','wb') as f:
        pickle.dump(m6A_m6Anet_test_set,f)
    m6A_m6Anet_train_loader=DataLoader(m6A_m6Anet_train_set,batch_size=5,shuffle=True)
    m6A_m6Anet_test_loader=DataLoader(m6A_m6Anet_test_set,batch_size=5,shuffle=True)

else:
    with open('./edata/Save_DataSet/m6A_m6Anet_train_set.pkl','rb') as f:
        m6A_m6Anet_train_set=pickle.load(f)
    with open('./edata/Save_DataSet/m6A_m6Anet_test_set.pkl','rb') as f:
        m6A_m6Anet_test_set=pickle.load(f)
    m6A_m6Anet_train_loader=DataLoader(m6A_m6Anet_train_set,batch_size=5,shuffle=True)
    m6A_m6Anet_test_loader=DataLoader(m6A_m6Anet_test_set,batch_size=5,shuffle=True)

# model

In [5]:
class MILModel(Module):
    def __init__(self,model_config):
        super(MILModel,self).__init__()
        self.model_config=model_config
        self.read_level_encoder=None
        self.pooling_filter=None
        self.decoder=None
        self.build_model()

    def build_model(self):
        blocks=self.model_config['block']
        seq_model=[]
        for block in blocks:
            block_type=block.pop('block_type')
            block_obj=self._build_block(block_type,**block)

            if isinstance(block_obj,PoolingFilter):
                if len(seq_model)>0:
                    self.read_level_encoder=Sequential(*seq_model)
                else:
                    self.read_level_encoder=None

                self.pooling_filter=block_obj
                seq_model=[]
            else:
                seq_model.append(block_obj)

        if (self.read_level_encoder is None) and (self.pooling_filter is None):
            self.read_level_encoder=Sequential(*seq_model)
            self.pooling_filter=Identity()
            self.decoder=Identity()
        else:
            if len(seq_model)==0:
                self.decoder=Identity()
            else:
                self.decoder=Sequential(*seq_model)

    def _build_block(self,block_type,**kwargs):
        from ecode.m6Anet import blocks
        block_obj=getattr(blocks,block_type)
        return block_obj(**kwargs)

    def get_read_representation(self,x):
        if self.read_level_encoder is None:
            return x
        else:
            return self.read_level_encoder(x)

    def get_read_probability(self,x):
        read_representation=self.get_read_representation(x)
        return self.pooling_filter.predict_read_level_prob(read_representation)

    def get_site_representation(self,x):
        return self.pooling_filter(self.get_read_representation(x))

    def get_site_probability(self,x):
        return self.decoder(self.get_site_representation(x))

    def get_read_site_probability(self,x):
        read_representation=self.get_read_representation(x)
        read_level_probability=self.pooling_filter.predict_read_level_prob(read_representation)
        site_level_probability=self.decoder(self.pooling_filter(read_representation))
        return read_level_probability,site_level_probability,read_representation

    def get_attention_weights(self,x):
        if hasattr(self.pooling_filter, "get_attention_weights"):
            return self.pooling_filter.get_attention_weights(self.get_read_representation(x))
        else:
            raise ValueError("Pooling filter does not have attention weights")

    def forward(self,x):
        return self.get_site_probability(x)

# For Train and Test

In [6]:
def test(model,test_loader,device,reads_reduce=0):
    model.eval()
    right_count,all_count=0,0
    prob_all,by_all=[],[]
    with torch.no_grad():
        for _,l_dic in enumerate(test_loader):
            by=l_dic['label'].to(device)
            by=by.to(torch.int64)
            l_dic['X']=l_dic['X'][:,reads_reduce:]
            l_dic['kmer']=l_dic['kmer'][:,reads_reduce:]
            ry=model(l_dic).to(device)
            out_y=ry>0.5
            right_count+=out_y.eq(by).sum()
            all_count+=len(by)
            for each in ry:
                prob_all.append(np.array(each.cpu()))
            for each in by:
                by_all.append(np.array(each.cpu()))
    roauc=roc_auc_score(by_all,prob_all)

    accuracy=100*(right_count/all_count).item()
    print('AUC:{:.4f}   accuracy:{:.4f}%'.format(roauc,accuracy))
    torch.cuda.empty_cache()

def train(model,train_loader,test_loader,device,optimizer,loss_func,epochs,reads_reduce=0):
    torch.cuda.empty_cache()
    for epoch in range(epochs):
        total_loss=0
        model.train()
        for _,l_dic in enumerate(train_loader):
            by=l_dic['label'].to(device)
            l_dic['X']=l_dic['X'][:,reads_reduce:]
            l_dic['kmer']=l_dic['kmer'][:,reads_reduce:]
            ry=model(l_dic).to(device)
            loss=loss_func(ry,by.float())
            total_loss+=loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('epoch '+str(epoch+1)+' loss:  ',total_loss/len(test_loader))        
        if epoch%10==9:
            print('At epoch '+str(epoch+1),':')
            test(model,test_loader,device,reads_reduce)
            torch.save(model.state_dict(),'./model/model_'+str(epoch+1)+'_'+str(int(time.time()))+'.pkl')

In [7]:
def detailed_test(model,test_loader,device,reads_reduce=0,curve_name=None):
    model.eval()
    right_count,all_count=0,0
    more_dict={0.5:[0,0],0.6:[0,0],0.8:[0,0],0.7:[0,0],0.9:[0,0],0.95:[0,0],0.98:[0,0],\
               0.99:[0,0],0.995:[0,0],0.999:[0,0],0.9995:[0,0],0.9999:[0,0],0.99995:[0,0],\
               0.99999:[0,0],0.999995:[0,0],0.999999:[0,0]}
    prob_all,by_all=[],[]
    motif_found_dict={}
    motif_dict={}
    with torch.no_grad():
        for _,l_dic in enumerate(test_loader):
            by=l_dic['label'].to(device).to(torch.int64)
            l_dic['X']=l_dic['X'][:,reads_reduce:]
            l_dic['kmer']=l_dic['kmer'][:,reads_reduce:]
            ry=model(l_dic).to(device)
            out_y=ry>0.5
            right_count+=out_y.eq(by).sum()
            all_count+=len(by)
            for each in ry:
                prob_all.append(np.array(each.cpu()))
            for each in by:
                by_all.append(np.array(each.cpu()))
            for key in more_dict:
                more_dict[key][0]+=((ry>key)&by).sum()
                more_dict[key][1]+=(ry>key).sum()
    if curve_name:
        save_frame=pd.DataFrame({'label':by_all,'pred':prob_all})
        save_frame.to_csv('./edata/Save_for_drawing/'+curve_name+'_curve.csv',index=False,sep=',')

    print('Im total',all_count,'samples:')
    auc=roc_auc_score(by_all,prob_all)
    accuracy=100*(right_count/all_count).item()
    print('AUC:{:.4f}   accuracy:{:.4f}%'.format(auc,accuracy))
    for key in more_dict:
        if more_dict[key][1]>0:
            print('Precision when positive threshold at {:g} is :{:.4f}% (total:{:d})'.format(key,more_dict[key][0]/more_dict[key][1],more_dict[key][1]))
    torch.cuda.empty_cache()

# Train

In [11]:
config_file_50reads='./ecode/m6Anet/m6Anet_50reads.toml'
model_config_50reads=toml.load(config_file_50reads)
model=MILModel(model_config_50reads)
device=torch.device('cuda:0')
optimizer=optim.Adam(model.parameters(),lr=0.001)
loss_func=nn.BCELoss().to(device)
epochs=500
train(model,m6A_m6Anet_train_loader,m6A_m6Anet_test_loader,device,optimizer,loss_func,epochs)

epoch 1 loss:   tensor(17.9327, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(13.4488, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(10.1011, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(6.1791, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(3.4221, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(2.6744, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(2.5841, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(2.5063, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(2.4206, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(2.3800, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.7348   accuracy:67.5603%
epoch 11 loss:   tensor(2.3694, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(2.3575, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(2.3401, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 los

epoch 108 loss:   tensor(1.8285, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(1.8281, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(1.8213, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.8239   accuracy:74.2627%
epoch 111 loss:   tensor(1.8072, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(1.8255, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(1.8126, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(1.8116, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(1.8070, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(1.8052, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(1.8098, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(1.8061, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(1.8106, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(1.7918, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(1.6654, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(1.6403, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(1.6512, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(1.6679, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(1.6394, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(1.6705, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(1.6813, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(1.6461, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(1.6425, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.8276   accuracy:74.5979%
epoch 221 loss:   tensor(1.6311, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(1.6478, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(1.6345, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(1.6316, device='cuda:0', grad_fn=<DivBack

epoch 317 loss:   tensor(1.5342, device='cuda:0', grad_fn=<DivBackward0>)
epoch 318 loss:   tensor(1.5398, device='cuda:0', grad_fn=<DivBackward0>)
epoch 319 loss:   tensor(1.5326, device='cuda:0', grad_fn=<DivBackward0>)
epoch 320 loss:   tensor(1.5261, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 320 :
AUC:0.8312   accuracy:75.0000%
epoch 321 loss:   tensor(1.5276, device='cuda:0', grad_fn=<DivBackward0>)
epoch 322 loss:   tensor(1.5375, device='cuda:0', grad_fn=<DivBackward0>)
epoch 323 loss:   tensor(1.5159, device='cuda:0', grad_fn=<DivBackward0>)
epoch 324 loss:   tensor(1.5054, device='cuda:0', grad_fn=<DivBackward0>)
epoch 325 loss:   tensor(1.5094, device='cuda:0', grad_fn=<DivBackward0>)
epoch 326 loss:   tensor(1.5397, device='cuda:0', grad_fn=<DivBackward0>)
epoch 327 loss:   tensor(1.5066, device='cuda:0', grad_fn=<DivBackward0>)
epoch 328 loss:   tensor(1.5136, device='cuda:0', grad_fn=<DivBackward0>)
epoch 329 loss:   tensor(1.5201, device='cuda:0', grad_fn=<DivBack

epoch 421 loss:   tensor(1.3981, device='cuda:0', grad_fn=<DivBackward0>)
epoch 422 loss:   tensor(1.4055, device='cuda:0', grad_fn=<DivBackward0>)
epoch 423 loss:   tensor(1.4112, device='cuda:0', grad_fn=<DivBackward0>)
epoch 424 loss:   tensor(1.3623, device='cuda:0', grad_fn=<DivBackward0>)
epoch 425 loss:   tensor(1.3877, device='cuda:0', grad_fn=<DivBackward0>)
epoch 426 loss:   tensor(1.4109, device='cuda:0', grad_fn=<DivBackward0>)
epoch 427 loss:   tensor(1.4218, device='cuda:0', grad_fn=<DivBackward0>)
epoch 428 loss:   tensor(1.3966, device='cuda:0', grad_fn=<DivBackward0>)
epoch 429 loss:   tensor(1.3858, device='cuda:0', grad_fn=<DivBackward0>)
epoch 430 loss:   tensor(1.3747, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 430 :
AUC:0.8012   accuracy:72.9893%
epoch 431 loss:   tensor(1.3739, device='cuda:0', grad_fn=<DivBackward0>)
epoch 432 loss:   tensor(1.4305, device='cuda:0', grad_fn=<DivBackward0>)
epoch 433 loss:   tensor(1.3907, device='cuda:0', grad_fn=<DivBack

In [9]:
config_file_20reads='./ecode/m6Anet/m6Anet_20reads.toml'
model_config_20reads=toml.load(config_file_20reads)
model=MILModel(model_config_20reads)
device=torch.device('cuda:0')
optimizer=optim.Adam(model.parameters(),lr=0.001)
loss_func=nn.BCELoss().to(device)
epochs=500
train(model,m6A_m6Anet_train_loader,m6A_m6Anet_test_loader,device,optimizer,loss_func,epochs,30)

epoch 1 loss:   tensor(2.7765, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(2.5129, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(2.4634, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(2.4397, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(2.3754, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(2.3415, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(2.3202, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(2.2623, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(2.2331, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(2.2047, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.8085   accuracy:73.1233%
epoch 11 loss:   tensor(2.1865, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(2.1734, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(2.1426, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 loss: 

epoch 108 loss:   tensor(1.8391, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(1.8516, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(1.8505, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.8499   accuracy:74.8660%
epoch 111 loss:   tensor(1.8346, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(1.8527, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(1.8509, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(1.8320, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(1.8371, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(1.8363, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(1.8284, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(1.8467, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(1.8168, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(1.8349, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(1.7436, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(1.7137, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(1.7312, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(1.7254, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(1.7350, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(1.7170, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(1.7258, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(1.7341, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(1.7230, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.8335   accuracy:74.6649%
epoch 221 loss:   tensor(1.6980, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(1.7162, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(1.7273, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(1.7166, device='cuda:0', grad_fn=<DivBack

epoch 317 loss:   tensor(1.6074, device='cuda:0', grad_fn=<DivBackward0>)
epoch 318 loss:   tensor(1.6241, device='cuda:0', grad_fn=<DivBackward0>)
epoch 319 loss:   tensor(1.6070, device='cuda:0', grad_fn=<DivBackward0>)
epoch 320 loss:   tensor(1.6077, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 320 :
AUC:0.8476   accuracy:77.0107%
epoch 321 loss:   tensor(1.6225, device='cuda:0', grad_fn=<DivBackward0>)
epoch 322 loss:   tensor(1.6255, device='cuda:0', grad_fn=<DivBackward0>)
epoch 323 loss:   tensor(1.6045, device='cuda:0', grad_fn=<DivBackward0>)
epoch 324 loss:   tensor(1.6100, device='cuda:0', grad_fn=<DivBackward0>)
epoch 325 loss:   tensor(1.6320, device='cuda:0', grad_fn=<DivBackward0>)
epoch 326 loss:   tensor(1.6038, device='cuda:0', grad_fn=<DivBackward0>)
epoch 327 loss:   tensor(1.6175, device='cuda:0', grad_fn=<DivBackward0>)
epoch 328 loss:   tensor(1.6047, device='cuda:0', grad_fn=<DivBackward0>)
epoch 329 loss:   tensor(1.5943, device='cuda:0', grad_fn=<DivBack

epoch 421 loss:   tensor(1.5203, device='cuda:0', grad_fn=<DivBackward0>)
epoch 422 loss:   tensor(1.5408, device='cuda:0', grad_fn=<DivBackward0>)
epoch 423 loss:   tensor(1.5257, device='cuda:0', grad_fn=<DivBackward0>)
epoch 424 loss:   tensor(1.5306, device='cuda:0', grad_fn=<DivBackward0>)
epoch 425 loss:   tensor(1.5583, device='cuda:0', grad_fn=<DivBackward0>)
epoch 426 loss:   tensor(1.5253, device='cuda:0', grad_fn=<DivBackward0>)
epoch 427 loss:   tensor(1.5259, device='cuda:0', grad_fn=<DivBackward0>)
epoch 428 loss:   tensor(1.5266, device='cuda:0', grad_fn=<DivBackward0>)
epoch 429 loss:   tensor(1.5208, device='cuda:0', grad_fn=<DivBackward0>)
epoch 430 loss:   tensor(1.5400, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 430 :
AUC:0.8360   accuracy:76.3405%
epoch 431 loss:   tensor(1.5215, device='cuda:0', grad_fn=<DivBackward0>)
epoch 432 loss:   tensor(1.5119, device='cuda:0', grad_fn=<DivBackward0>)
epoch 433 loss:   tensor(1.5003, device='cuda:0', grad_fn=<DivBack

# Test

In [14]:
config_file_50reads='./ecode/m6Anet/m6Anet_50reads.toml'
model_config_50reads=toml.load(config_file_50reads)
model=MILModel(model_config_50reads)
device=torch.device('cuda:0')
#model.load_state_dict(torch.load('./model/m6A_m6Anet_keep_model_260.pkl'))
model.load_state_dict(torch.load('./model/m6A_m6Anet_keep_model_120.pkl'))
detailed_test(model,m6A_m6Anet_test_loader,device,0,'m6Anet,50reads_3sites')

Im total 1492 samples:
AUC:0.8411   accuracy:76.2064%
Precision when positive threshold at 0.5 is :0.7544% (total:733)
Precision when positive threshold at 0.6 is :0.7934% (total:610)
Precision when positive threshold at 0.8 is :0.8673% (total:339)
Precision when positive threshold at 0.7 is :0.8470% (total:477)
Precision when positive threshold at 0.9 is :0.9091% (total:187)
Precision when positive threshold at 0.95 is :0.9158% (total:95)
Precision when positive threshold at 0.98 is :0.9412% (total:51)
Precision when positive threshold at 0.99 is :0.9412% (total:34)
Precision when positive threshold at 0.995 is :0.9259% (total:27)
Precision when positive threshold at 0.999 is :1.0000% (total:18)
Precision when positive threshold at 0.9995 is :1.0000% (total:13)
Precision when positive threshold at 0.9999 is :1.0000% (total:9)
Precision when positive threshold at 0.99995 is :1.0000% (total:8)
Precision when positive threshold at 0.99999 is :1.0000% (total:6)
Precision when positive thr

In [10]:
config_file_20reads='./ecode/m6Anet/m6Anet_20reads.toml'
model_config_20reads=toml.load(config_file_20reads)
model=MILModel(model_config_20reads)
device=torch.device('cuda:0')
#model.load_state_dict(torch.load('./model/m6A_m6Anet_keep_model_160_20reads.pkl'))
model.load_state_dict(torch.load('./model/m6A_m6Anet_keep_model_80_20reads.pkl'))
detailed_test(model,m6A_m6Anet_test_loader,device,30,'m6Anet,20reads_3sites')

Im total 1492 samples:
AUC:0.8567   accuracy:75.0670%
Precision when positive threshold at 0.5 is :0.6952% (total:912)
Precision when positive threshold at 0.6 is :0.7465% (total:789)
Precision when positive threshold at 0.8 is :0.8528% (total:496)
Precision when positive threshold at 0.7 is :0.8009% (total:643)
Precision when positive threshold at 0.9 is :0.9020% (total:306)
Precision when positive threshold at 0.95 is :0.9304% (total:158)
Precision when positive threshold at 0.98 is :0.9516% (total:62)
Precision when positive threshold at 0.99 is :1.0000% (total:30)
Precision when positive threshold at 0.995 is :1.0000% (total:12)
Precision when positive threshold at 0.999 is :1.0000% (total:4)
Precision when positive threshold at 0.9995 is :1.0000% (total:2)
Precision when positive threshold at 0.9999 is :1.0000% (total:1)
Precision when positive threshold at 0.99995 is :1.0000% (total:1)
Precision when positive threshold at 0.99999 is :1.0000% (total:1)
Precision when positive thre