# import

In [1]:
import math
import time
import json
import torch
import random
import pickle
import torch.nn as nn
import numpy as np
import pandas as pd
from functools import partial
import torch.optim as optim
from torch.utils import data
from torch.utils.data import DataLoader
from functools import partialmethod
from scipy.stats import truncnorm
from sklearn.metrics import roc_auc_score

# DataLoader

In [2]:
Seq_Coding={'A':[1.,0.,0.,0.],'T':[0.,1.,0.,0.],'C':[0.,0.,1.,0.],'G':[0.,0.,0.,1.],'N':[0.25,0.25,0.25,0.25]}
class NanoDataset(data.Dataset):
    def __init__(self,use_path):
        self.use_path=use_path
        self.LEN=0
        self.f_dict={}
        with open(use_path+'/use_files.txt','r') as f:
            for line in f.readlines():
                f_name=line.strip()
                self.f_dict[f_name]={'len':0,'label':1,'cont':[]}
                if f_name[-4:]=='_neg':
                    self.f_dict[f_name]['label']=0
                with open(use_path+'/'+f_name+'.index') as f2:
                    for line2 in f2.readlines():
                        items2=line2.strip().split('\t')
                        self.f_dict[f_name]['cont'].append((int(items2[1]),int(items2[2])))
                self.f_dict[f_name]['len']=len(self.f_dict[f_name]['cont'])
                self.LEN+=self.f_dict[f_name]['len']
    def __getitem__(self,index):
        R_dict={'seq_feature':[],'seq_mask':[],'nano_feature':[],'nano_mask':[],'label':0}
        for key in self.f_dict:
            if index>=self.f_dict[key]['len']:
                index-=self.f_dict[key]['len']
                continue
            label=self.f_dict[key]['label']
            R_dict['label']=label
            seekit=self.f_dict[key]['cont'][index]
            with open(self.use_path+'/'+key+'.json') as f:
                f.seek(seekit[0],0)
                json_str=f.read(seekit[1]-seekit[0])
                Ls=json_str.strip().split('\n')
                for each in json.loads(Ls[0]):
                    R_dict['seq_feature'].append(Seq_Coding[each])
                    if each=='N':
                        R_dict['seq_mask'].append(0)
                    else:
                        R_dict['seq_mask'].append(1)

                for L in Ls[1:]:
                    L_data=json.loads(L)
                    t_feature=[]
                    t_mask=[]
                    for each in L_data:
                        if each[0]<0:
                            t_feature.append([0,0,0])
                            t_mask.append(0)
                        else:
                            t_feature.append(each)
                            t_mask.append(1)
                    R_dict['nano_mask'].append(t_mask)
                    R_dict['nano_feature'].append(t_feature)
            break
        for key2 in R_dict:
            R_dict[key2]=torch.tensor(R_dict[key2])
        return R_dict
    def __len__(self):
        return self.LEN

In [3]:
m5C_Nano_set=NanoDataset('./edata/DataSet/m5C')
len(m5C_Nano_set)

5722

In [4]:
RELOAD=0
if RELOAD==1:
    m5C_Nano_set=NanoDataset('./edata/DataSet/m5C')
    train_size=int(len(m5C_Nano_set)*0.8)
    test_size=len(m5C_Nano_set)-train_size
    m5C_Nano_train_set,m5C_Nano_test_set=torch.utils.data.random_split(m5C_Nano_set,[train_size,test_size])
    with open('./edata/Save_DataSet/m5C_Nano_train_set.pkl','wb') as f:
        pickle.dump(m5C_Nano_train_set,f)
    with open('./edata/Save_DataSet/m5C_Nano_test_set.pkl','wb') as f:
        pickle.dump(m5C_Nano_test_set,f)
    m5C_Nano_train_loader=DataLoader(m5C_Nano_train_set,batch_size=5,shuffle=True)
    m5C_Nano_test_loader=DataLoader(m5C_Nano_test_set,batch_size=5,shuffle=True)

else:
    with open('./edata/Save_DataSet/m5C_Nano_train_set.pkl','rb') as f:
        m5C_Nano_train_set=pickle.load(f)
    with open('./edata/Save_DataSet/m5C_Nano_test_set.pkl','rb') as f:
        m5C_Nano_test_set=pickle.load(f)
    m5C_Nano_train_loader=DataLoader(m5C_Nano_train_set,batch_size=5,shuffle=True)
    m5C_Nano_test_loader=DataLoader(m5C_Nano_test_set,batch_size=5,shuffle=True)

# Tools for Model

In [5]:
def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights,gain=1)
def zero_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)
def permute_final_dims(tensor,inds):
    zero_index=-1*len(inds)
    first_inds=list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds+[zero_index+i for i in inds])
def flatten_final_dims(t,no_dims):
    return t.reshape(t.shape[:-no_dims]+(-1,))
def relu_init_(weights,scale=2.0):
    shape=weights.shape
    _,f=shape
    scale=scale/max(1,f)
    a=-2
    b=2
    std=math.sqrt(scale)/truncnorm.std(a=a,b=b,loc=0,scale=1)
    size=1
    for n in shape:
        size=size*n
    samples=truncnorm.rvs(a=a,b=b,loc=0,scale=std,size=size)
    samples=np.reshape(samples,shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples,device=weights.device))

class Dropout(nn.Module):
    def __init__(self,r,batch_dim):
        super(Dropout,self).__init__()
        self.r=r
        if type(batch_dim)==int:
            batch_dim=[batch_dim]
        self.batch_dim=batch_dim
        self.dropout=nn.Dropout(r)
    def forward(self,x):
        shape=list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd]=1
        mask=x.new_ones(shape)
        mask=self.dropout(mask)
        x*=mask
        return x
class DropoutRowwise(Dropout):
    __init__=partialmethod(Dropout.__init__,batch_dim=-3)
class DropoutColwise(Dropout):
    __init__=partialmethod(Dropout.__init__,batch_dim=-2)

In [6]:
class Linear(nn.Linear):
    def __init__(self,in_dim,out_dim,bias=True,init="zero"):
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)
        with torch.no_grad():
            if init=="zero":
                zero_init_(self.weight)
            elif init=="glorot":
                glorot_uniform_init_(self.weight)
            elif init=="gating":
                zero_init_(self.weight)
                if bias:
                    self.bias.fill_(1.0)
class LayerNorm(nn.Module):
    def __init__(self,c_in,eps=1e-5):
        super(LayerNorm, self).__init__()
        self.c_in=(c_in,)
        self.eps=eps
        self.weight=nn.Parameter(torch.ones(c_in))
        self.bias=nn.Parameter(torch.zeros(c_in))
    def forward(self,x): 
        out=nn.functional.layer_norm(x,self.c_in,self.weight,self.bias,self.eps)
        return out

In [7]:
class LinearEmbedder(nn.Module):
    def __init__(self,c_in,c_out):
        super(LinearEmbedder,self).__init__()
        self.c_in=c_in
        self.c_out=c_out
        self.linear_1=nn.Linear(c_in,c_out)
        self.relu=nn.ReLU()
        self.linear_2=nn.Linear(c_out,c_out)
    def forward(self,x):
        x=self.linear_1(x)
        x=self.relu(x)
        x=self.linear_2(x)
        return x

# Model Component

In [8]:
MAX_SEQ_LEN=50
def precompute_freqs_cis(dim,seq_len,theta=10000.0):
    freqs=1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim))
    t=torch.arange(seq_len,device=freqs.device)
    freqs=torch.outer(t,freqs).float()
    freqs_cis=torch.polar(torch.ones_like(freqs),freqs)
    return freqs_cis

def apply_rotary_emb(q,k,freqs_cis,same=True):
    _q=q.float().reshape(*q.shape[:-1],-1,2)
    _k=k.float().reshape(*k.shape[:-1],-1,2)
    _q=torch.view_as_complex(_q)
    _k=torch.view_as_complex(_k)
    
    if same==False:
        if _k.shape[-2]%2!=0:
            q_out=torch.view_as_real(_q*freqs_cis[int((_k.shape[-2]-1)/2)].to(q.device)).flatten(-2)
        else:
            q_out=torch.view_as_real(_q*freqs_cis[_k.shape[-2]/2].to(q.device)).flatten(-2)
    else:
        q_out=torch.view_as_real(_q*freqs_cis[:_q.shape[-2]].to(q.device)).flatten(-2)
    k_out=torch.view_as_real(_k*freqs_cis[:_k.shape[-2]].to(k.device)).flatten(-2)
    return q_out.type_as(q),k_out.type_as(k)

In [9]:
class Attention(nn.Module):
    def __init__(self,c_q,c_k,c_v,c_hidden,no_heads,gating=True,use_rel_pos=False):
        super(Attention, self).__init__()
        self.c_q=c_q
        self.c_k=c_k
        self.c_v=c_v
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.gating=gating
        self.use_rel_pos=use_rel_pos

        self.linear_q=Linear(c_q,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_k=Linear(c_k,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_v=Linear(c_v,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_o=Linear(c_hidden*no_heads,c_q,init="zero")
        if self.gating:
            self.linear_g=Linear(c_q,c_hidden*no_heads,init="gating")
        self.sigmoid=nn.Sigmoid()

        self.freqs_cis=None
        if self.use_rel_pos:
            self.freqs_cis=precompute_freqs_cis(c_hidden,MAX_SEQ_LEN)

    def forward(self,q_x,kv_x,biases=None):
        if(biases is None):
            biases=[]
        q=self.linear_q(q_x)
        k=self.linear_k(kv_x)
        v=self.linear_v(kv_x)
        q=q.view(q.shape[:-1]+(self.no_heads,-1))
        k=k.view(k.shape[:-1]+(self.no_heads,-1))
        v=v.view(v.shape[:-1]+(self.no_heads,-1))

        q=q.transpose(-2,-3)#r,H,s,h
        k=k.transpose(-2,-3)
        v=v.transpose(-2,-3)
        
        if self.use_rel_pos:
            q,k=apply_rotary_emb(q,k,freqs_cis=self.freqs_cis,same=True)
        k=permute_final_dims(k,(1,0))
        a=torch.matmul(q,k)/math.sqrt(self.c_hidden)#r,H,s,h * r,H,h,s = r,H,s,s
        for b in biases:
            a+=b
        a=torch.nn.functional.softmax(a,dim=-1)
        o=torch.matmul(a,v)#r,H,s,s * r,H,s,h = r,H,s,h
        o=o.transpose(-2,-3)#r,s,H,h

        if self.gating:
            g=self.sigmoid(self.linear_g(q_x))
            g=g.view(g.shape[:-1]+(self.no_heads,-1))
            o=o*g
        o=flatten_final_dims(o,2)#r,s,H*h
        o=self.linear_o(o)#r,s,o
        return o

In [10]:
class NanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,use_rel_pos=False):
        super(NanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self.layer_norm_x=LayerNorm(c_in)
        self.mha=Attention(c_in,c_in,c_in,c_hidden,no_heads,True,use_rel_pos)

    def forward(self,x,mask=None):
        n_seq,n_pos=x.shape[-3:-1]
        if mask is None:
            mask=x.new_ones(x.shape[:-3]+(n_seq,n_pos))
        mask_bias=(self.inf*(mask-1))[...,:,None,None,:]
        biases=[mask_bias]

        x=self.layer_norm_x(x)
        x=self.mha(x,x,biases)
        return x

class Trans_NanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,use_rel_pos=False):
        super(Trans_NanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self._NanoAttention=NanoAttention(c_in,c_hidden,no_heads,inf,use_rel_pos)

    def forward(self,x,mask=None):
        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        x=self._NanoAttention(x,mask=mask)

        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        return x

In [11]:
class GlobalAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e5,eps=1e-8,use_rel_pos=False):
        super(GlobalAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.eps=eps
        self.use_rel_pos=use_rel_pos
        
        self.linear_q=Linear(c_in,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_k=Linear(c_in,c_hidden,bias=False,init="glorot")
        self.linear_v=Linear(c_in,c_hidden,bias=False,init="glorot")
        self.linear_g=Linear(c_in,c_hidden*no_heads,init="gating")
        self.linear_o=Linear(c_hidden*no_heads,c_in,init="zero")
        self.sigmoid=nn.Sigmoid()
        self.freqs_cis=None
        if self.use_rel_pos:
            self.freqs_cis=precompute_freqs_cis(c_hidden,MAX_SEQ_LEN)
    def forward(self,m,mask):
        q=torch.sum(m*mask.unsqueeze(-1),dim=-2)/(torch.sum(mask,dim=-1)[...,None]+self.eps)
        q=self.linear_q(q)
        k=self.linear_k(m)#r,s,h
        v=self.linear_v(m)#r,s,h
        q=q.view(q.shape[:-1]+(self.no_heads,-1))#r,H,h
        if self.use_rel_pos:
            q,k=apply_rotary_emb(q,k,freqs_cis=self.freqs_cis)
        
        bias=(self.inf*(mask-1))[...,:,None,:]
        a=torch.matmul(q,k.transpose(-1,-2))/math.sqrt(self.c_hidden)#r,H,h * r,h,s = r,H,s
        a+=bias
        a=torch.nn.functional.softmax(a,dim=-1)
        
        o=torch.matmul(a,v)#r,H,s * r,s,h = r,H,h
        g=self.sigmoid(self.linear_g(m))
        g=g.view(g.shape[:-1]+(self.no_heads,-1))
        o=o.unsqueeze(-3)*g#r,1,H,h * r,s,H,h = r,s,H,h
        o=o.reshape(o.shape[:-2]+(-1,))
        
        m=self.linear_o(o)#r,s,H*h->r,s,c_in
        return m

In [12]:
class GlobalNanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,eps=1e-8,use_rel_pos=False):
        super(GlobalNanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self.layer_norm_x=LayerNorm(c_in)
        self.gmha=GlobalAttention(c_in,c_hidden,no_heads,inf,eps,use_rel_pos)

    def forward(self,x,mask=None):
        n_seq,n_pos=x.shape[-3:-1]
        if mask is None:
            mask=x.new_ones(x.shape[:-3]+(n_seq,n_pos))
        x=self.layer_norm_x(x)
        x=self.gmha(x,mask)
        return x

class Trans_GlobalNanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,eps=1e-8,use_rel_pos=False):
        super(Trans_GlobalNanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self._GlobalNanoAttention=GlobalNanoAttention(c_in,c_hidden,no_heads,inf,eps,use_rel_pos)

    def forward(self,x,mask=None):
        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        x=self._GlobalNanoAttention(x,mask=mask)
        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        return x

In [13]:
class LineAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e5,eps=1e-8,use_rel_pos=False):
        super(LineAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.eps=eps
        self.use_rel_pos=use_rel_pos
        
        self.linear_q0=Linear(c_in,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_k0=Linear(c_in,c_hidden,bias=False,init="glorot")
        self.linear_v0=Linear(c_in,c_hidden,bias=False,init="glorot")
        self.linear_q1=Linear(c_hidden,c_hidden,bias=False,init="glorot")
        self.linear_k1=Linear(c_hidden,c_hidden,bias=False,init="glorot")
        self.linear_v1=Linear(c_hidden,c_hidden,bias=False,init="glorot")
        self.linear_g=Linear(c_in,c_hidden*no_heads,init="gating")
        self.linear_q2=Linear(c_in,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_k2=Linear(c_in,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_v2=Linear(c_in,c_hidden*no_heads,bias=False,init="glorot")
        self.linear_o=Linear(c_hidden*no_heads,c_in,init="zero")
        self.sigmoid=nn.Sigmoid()
        self.freqs_cis=precompute_freqs_cis(c_hidden,MAX_SEQ_LEN)
    def forward(self,m,mask):
        l_sum=torch.sum(m*mask.unsqueeze(-1),dim=-2)/(torch.sum(mask,dim=-1)[...,None]+self.eps)
        q0=self.linear_q0(l_sum)
        k0=self.linear_k0(m)#r,s,h
        v0=self.linear_v0(m)#r,s,h
        q0=q0.view(q0.shape[:-1]+(self.no_heads,-1))#r,H,h
        if self.use_rel_pos:
            q0,k0=apply_rotary_emb(q0,k0,freqs_cis=self.freqs_cis,same=False)#r,H,h;r,s,h
        bias=(self.inf*(mask-1))[...,:,None,:]
        a0=torch.matmul(q0,k0.transpose(-1,-2))/math.sqrt(self.c_hidden)#r,H,h * r,h,s = r,H,s
        a0+=bias
        a0=torch.nn.functional.softmax(a0,dim=-1)
        r0=torch.matmul(a0,v0)#r,H,s * r,s,h = r,H,h
        
        q1=self.linear_q1(r0)
        k1=self.linear_q1(r0)
        v1=self.linear_q1(r0)
        q1=q1.transpose(-2,-3)
        k1=k1.transpose(-2,-3)
        v1=v1.transpose(-2,-3)
        if not self.use_rel_pos:
            q1,k1=apply_rotary_emb(q1,k1,freqs_cis=self.freqs_cis,same=True)#H,r,h;H,r,h
        a1=torch.matmul(q1,k1.transpose(-1,-2))/math.sqrt(self.c_hidden)#H,r,h * H,h,r = H,r,r
        a1=torch.nn.functional.softmax(a1,dim=-1)
        r1=torch.matmul(a1,v1)#H,r,r * H,r,h = H,r,h
        
        q2=self.linear_q2(m)
        k2=self.linear_k2(m)
        v2=self.linear_v2(m)
        q2=q2.view(q2.shape[:-1]+(self.no_heads,-1))
        k2=k2.view(k2.shape[:-1]+(self.no_heads,-1))
        v2=v2.view(v2.shape[:-1]+(self.no_heads,-1))
        q2=q2.transpose(-2,-3)#r,H,s,h
        k2=k2.transpose(-2,-3)
        v2=v2.transpose(-2,-3)
        if self.use_rel_pos:
            q2,k2=apply_rotary_emb(q2,k2,freqs_cis=self.freqs_cis,same=True)
        
        bias2=(self.inf*(mask-1))[...,:,None,None,:]
        a2=torch.matmul(q2,k2.transpose(-1,-2))/math.sqrt(self.c_hidden)#r,H,s,h * r,H,h,s = r,H,s,s
        a2+=bias2
        a2=torch.nn.functional.softmax(a2,dim=-1)
        r2=torch.matmul(a2,v2)#r,H,s,s * r,H,s,h = r,H,s,h
        r2=r2.transpose(-3,-4)#H,r,s,h
        
        g=self.sigmoid(self.linear_g(m))
        g=g.view(g.shape[:-1]+(self.no_heads,-1))
        g=g.transpose(-2,-3)
        g=g.transpose(-3,-4)
        
        if self.use_rel_pos:
            r1,_=apply_rotary_emb(r1,r2,freqs_cis=self.freqs_cis,same=False)#r,H,h;r,s,h
        r=(r1.unsqueeze(-2)+r2)*g#(H,r,1,h+H,r,s,h)*H,r,s,h=H,r,s,h
        r=r.transpose(-3,-4)
        r=r.transpose(-2,-3)
        r=r.reshape(r.shape[:-2]+(-1,))
        m=self.linear_o(r)#r,s,H*h->r,s,c_in
        return m

In [14]:
class LineNanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,eps=1e-8,use_rel_pos=False):
        super(LineNanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self.layer_norm_x=LayerNorm(c_in)
        self.lmha=LineAttention(c_in,c_hidden,no_heads,inf,eps,use_rel_pos)

    def forward(self,x,mask=None):
        n_seq,n_pos=x.shape[-3:-1]
        if mask is None:
            mask=x.new_ones(x.shape[:-3]+(n_seq,n_pos))
        x=self.layer_norm_x(x)
        x=self.lmha(x,mask)
        return x

class Trans_LineNanoAttention(nn.Module):
    def __init__(self,c_in,c_hidden,no_heads,inf=1e9,eps=1e-8,use_rel_pos=False):
        super(Trans_LineNanoAttention,self).__init__()
        self.c_in=c_in
        self.c_hidden=c_hidden
        self.no_heads=no_heads
        self.inf=inf
        self.use_rel_pos=use_rel_pos
        self._LineNanoAttention=LineNanoAttention(c_in,c_hidden,no_heads,inf,eps,use_rel_pos)

    def forward(self,x,mask=None):
        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        x=self._LineNanoAttention(x,mask=mask)
        x=x.transpose(-2,-3)
        if mask is not None:
            mask=mask.transpose(-1,-2)
        return x

# Build Model

In [15]:
class NanoBlock(nn.Module):
    def __init__(self,c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps):
        super(NanoBlock,self).__init__()
        self.att_row=NanoAttention(c_x,c_hidden_att,no_heads,inf,use_rel_pos=True)
        self.att_col=Trans_NanoAttention(c_x,c_hidden_att,no_heads,inf,use_rel_pos=False)
        self.row_dropout_layer=DropoutRowwise(row_dropout)
        self.col_dropout_layer=DropoutColwise(col_dropout)

        self.layer_norm=LayerNorm(c_x)
        self.linear_1=Linear(c_x,transition_n*c_x,init="relu")
        self.relu=nn.ReLU()
        self.linear_2=Linear(transition_n*c_x,c_x,init="zero")
    def _transition(self,x):
        x=self.layer_norm(x)
        x=self.linear_1(x)
        x=self.relu(x)
        x=self.linear_2(x)
        return x

    def forward(self,x,x_mask):
        x=x+self.row_dropout_layer(self.att_row(x,x_mask))
        x=x+self.col_dropout_layer(self.att_col(x,x_mask))
        x=x+self._transition(x)
        return x

In [16]:
class NanoGlobalBlock(nn.Module):
    def __init__(self,c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps):
        super(NanoGlobalBlock,self).__init__()
        self.gatt_row=GlobalNanoAttention(c_x,c_hidden_att,no_heads,inf,eps,use_rel_pos=True)
        self.gatt_col=Trans_GlobalNanoAttention(c_x,c_hidden_att,no_heads,inf,eps,use_rel_pos=False)
        self.row_dropout_layer=DropoutRowwise(row_dropout)
        self.col_dropout_layer=DropoutColwise(col_dropout)

        self.layer_norm=LayerNorm(c_x)
        self.linear_1=Linear(c_x,transition_n*c_x,init="relu")
        self.relu=nn.ReLU()
        self.linear_2=Linear(transition_n*c_x,c_x,init="zero")
    def _transition(self,x):
        x=self.layer_norm(x)
        x=self.linear_1(x)
        x=self.relu(x)
        x=self.linear_2(x)
        return x

    def forward(self,x,x_mask):
        x=x+self.row_dropout_layer(self.gatt_row(x,x_mask))
        x=x+self.col_dropout_layer(self.gatt_col(x,x_mask))
        x=x+self._transition(x)
        return x

In [17]:
class NanoLineBlock(nn.Module):
    def __init__(self,c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps):
        super(NanoLineBlock,self).__init__()
        self.latt_row=LineNanoAttention(c_x,c_hidden_att,no_heads,inf,eps,use_rel_pos=True)
        self.latt_col=Trans_LineNanoAttention(c_x,c_hidden_att,no_heads,inf,eps,use_rel_pos=False)
        self.row_dropout_layer=DropoutRowwise(row_dropout)
        self.col_dropout_layer=DropoutColwise(col_dropout)

        self.layer_norm=LayerNorm(c_x)
        self.linear_1=Linear(c_x,transition_n*c_x,init="relu")
        self.relu=nn.ReLU()
        self.linear_2=Linear(transition_n*c_x,c_x,init="zero")
    def _transition(self,x):
        x=self.layer_norm(x)
        x=self.linear_1(x)
        x=self.relu(x)
        x=self.linear_2(x)
        return x

    def forward(self,x,x_mask):
        x=x+self.row_dropout_layer(self.latt_row(x,x_mask))
        x=x+self.col_dropout_layer(self.latt_col(x,x_mask))
        x=x+self._transition(x)
        return x

In [18]:
class NanoStack(nn.Module):
    def __init__(self,c_x,c_hidden_att,no_heads,blocks_lis,
        row_dropout,col_dropout,transition_n,
        inf,eps,clear_cache_between_blocks=False):
        super(NanoStack,self).__init__()
        self.clear_cache_between_blocks=clear_cache_between_blocks
        self.blocks=nn.ModuleList()
        for block_type in blocks_lis:
            if block_type==0:
                block=NanoBlock(c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps)
            elif block_type==1:
                block=NanoGlobalBlock(c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps)
            elif block_type==2:
                block=NanoLineBlock(c_x,c_hidden_att,no_heads,row_dropout,col_dropout,transition_n,inf,eps)
            self.blocks.append(block)

    def _prep_blocks(self,x_mask):
        blocks=[partial(b,x_mask=x_mask)for b in self.blocks]
        if(self.clear_cache_between_blocks):
            def block_with_cache_clear(block,*args,**kwargs):
                torch.cuda.empty_cache()
                return block(*args,**kwargs)
            blocks=[partial(block_with_cache_clear,b) for b in blocks]
        return blocks

    def forward(self,x,x_mask):
        blocks=self._prep_blocks(x_mask)
        for block in blocks:
            x=block(x)
        return x

In [19]:
class Nano(nn.Module):
    def __init__(self,c_s,c_x,c_emb,c_f,c_hidden_att,c_o,no_heads,blocks_lis,
                row_dropout,col_dropout,transition_n,inf=1e9,eps=1e-8,clear_cache_between_blocks=False):
        super(Nano,self).__init__()
        self.s_embedder=LinearEmbedder(c_s,c_emb)
        self.x_embedder=LinearEmbedder(c_x,c_emb)
        self.stack=NanoStack(c_emb,c_hidden_att,no_heads,blocks_lis,
                     row_dropout,col_dropout,transition_n,inf,eps,clear_cache_between_blocks)
        self.linear_f=Linear(c_emb,c_f)
        self.classifier=nn.Sequential(
            nn.Linear(c_f,int(c_f/2)),
            nn.ReLU(),
            nn.Linear(int(c_f/2),c_o),
            nn.Sigmoid()
        )
    def forward(self,s,x,s_mask,x_mask):
        s=self.s_embedder(s)
        x=self.x_embedder(x)
        
        x=torch.cat([s.unsqueeze(-3),x],dim=-3)
        x_mask=torch.cat([s_mask.unsqueeze(-2),x_mask],dim=-2)

        x=self.stack(x,x_mask)
        x=self.linear_f(torch.mean(x[...,:,int(x.shape[-2]/2)+1,:],-2))
        #x=self.linear_f(torch.max(x[...,:,int(x.shape[-2]/2)+1,:],-2)[0])
        o=self.classifier(x).squeeze(-1)
        return o

# For Train and Test

In [20]:
def test(model,test_loader,device,line_reduce=0,col_reduce=0):
    model.eval()
    right_count,all_count=0,0
    prob_all,by_all=[],[]
    with torch.no_grad():
        for _,l_dic in enumerate(test_loader):
            by=l_dic['label'].to(device)
            by=by.to(torch.int64)
            if line_reduce==0:
                seq_feature=l_dic['seq_feature'].to(device)
                seq_mask=l_dic['seq_mask'].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:].to(device)
            else:
                side_reduce=int(line_reduce/2)
                seq_feature=l_dic['seq_feature'][:,side_reduce:-side_reduce].to(device)
                seq_mask=l_dic['seq_mask'][:,side_reduce:-side_reduce].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:,side_reduce:-side_reduce].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:,side_reduce:-side_reduce].to(device)
            ry=model(seq_feature,nano_feature,seq_mask,nano_mask)
            out_y=ry>0.5
            right_count+=out_y.eq(by).sum()
            all_count+=len(by)
            for each in ry:
                prob_all.append(np.array(each.cpu()))
            for each in by:
                by_all.append(np.array(each.cpu()))
    roauc=roc_auc_score(by_all,prob_all)

    accuracy=100*(right_count/all_count).item()
    print('AUC:{:.4f}   accuracy:{:.4f}%'.format(roauc,accuracy))
    torch.cuda.empty_cache()

def train(model,train_loader,test_loader,device,optimizer,loss_func,epochs,line_reduce=0,col_reduce=0):
    torch.cuda.empty_cache()
    for epoch in range(epochs):
        total_loss=0
        model.train()
        for _,l_dic in enumerate(train_loader):
            by=l_dic['label'].to(device)
            if line_reduce==0:
                seq_feature=l_dic['seq_feature'].to(device)
                seq_mask=l_dic['seq_mask'].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:].to(device)
            else:
                side_reduce=int(line_reduce/2)
                seq_feature=l_dic['seq_feature'][:,side_reduce:-side_reduce].to(device)
                seq_mask=l_dic['seq_mask'][:,side_reduce:-side_reduce].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:,side_reduce:-side_reduce].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:,side_reduce:-side_reduce].to(device)
            ry=model(seq_feature,nano_feature,seq_mask,nano_mask)
            loss=loss_func(ry,by.float())
            total_loss+=loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('epoch '+str(epoch+1)+' loss:  ',total_loss/len(test_loader))        
        if epoch%10==9:
            print('At epoch '+str(epoch+1),':')
            test(model,test_loader,device,line_reduce,col_reduce)
            torch.save(model.state_dict(),'./model/model_'+str(epoch+1)+'_'+str(int(time.time()))+'.pkl')

In [21]:
def detailed_test(model,test_loader,device,line_reduce=0,col_reduce=0,curve_name=None,histo_name=None):
    model.eval()
    right_count,all_count=0,0
    more_dict={0.5:[0,0],0.6:[0,0],0.8:[0,0],0.7:[0,0],0.9:[0,0],0.95:[0,0],0.98:[0,0],\
               0.99:[0,0],0.995:[0,0],0.999:[0,0],0.9995:[0,0],0.9999:[0,0],0.99995:[0,0],\
               0.99999:[0,0],0.999995:[0,0],0.999999:[0,0]}
    prob_all,by_all=[],[]
    motif_dict={}
    range_list=[]
    with torch.no_grad():
        for _,l_dic in enumerate(test_loader):
            by=l_dic['label'].to(device).to(torch.int64)
            if line_reduce==0:
                seq_feature=l_dic['seq_feature'].to(device)
                seq_mask=l_dic['seq_mask'].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:].to(device)
            else:
                side_reduce=int(line_reduce/2)
                seq_feature=l_dic['seq_feature'][:,side_reduce:-side_reduce].to(device)
                seq_mask=l_dic['seq_mask'][:,side_reduce:-side_reduce].to(device)
                nano_feature=l_dic['nano_feature'][:,col_reduce:,side_reduce:-side_reduce].to(device)
                nano_mask=l_dic['nano_mask'][:,col_reduce:,side_reduce:-side_reduce].to(device)
            ry=model(seq_feature,nano_feature,seq_mask,nano_mask)
            out_y=ry>0.5
            right_count+=out_y.eq(by).sum()
            all_count+=len(by)
            for each in ry:
                prob_all.append(np.array(each.cpu()))
            for each in by:
                by_all.append(np.array(each.cpu()))
            for key in more_dict:
                more_dict[key][0]+=((ry>key)&by).sum()
                more_dict[key][1]+=(ry>key).sum()

            if histo_name:
                middle_pos=int((len(l_dic['seq_feature'][0])-1)/2)
                center_seqs=l_dic['seq_feature'][:,middle_pos-2:middle_pos+3]            
                for i in range(len(by)):
                    _Seq=''
                    for j in range(5):
                        if abs(center_seqs[i][j][0]-1)<0.01:
                            _Seq+='A'
                        elif abs(center_seqs[i][j][1]-1)<0.01:
                            _Seq+='T'
                        elif abs(center_seqs[i][j][2]-1)<0.01:
                            _Seq+='C'
                        elif abs(center_seqs[i][j][3]-1)<0.01:
                            _Seq+='G'
                        else:
                            _Seq+='N'
                    if 'N' not in _Seq:
                        if _Seq not in motif_dict:
                            motif_dict[_Seq]={'TP':0,'FP':0,'TN':0,'FN':0}
                        if out_y[i]==1 and by[i]==1:
                            motif_dict[_Seq]['TP']+=1
                        elif out_y[i]==1 and by[i]==0:
                            motif_dict[_Seq]['FP']+=1
                        elif out_y[i]==0 and by[i]==0:
                            motif_dict[_Seq]['TN']+=1
                        elif out_y[i]==0 and by[i]==1:
                            motif_dict[_Seq]['FN']+=1
                for i in range(len(ry)):
                    range_list.append([ry[i].cpu().item(),by[i].cpu().item()])
    if histo_name:
        save_frame=pd.DataFrame(motif_dict).T
        save_frame.to_csv('./edata/Save_for_drawing/'+histo_name+'_motif_histo.csv',index=True,sep=',')
        save_frame=pd.DataFrame(range_list)
        save_frame.columns=['Probability score','Ground Truth']
        save_frame.to_csv('./edata/Save_for_drawing/'+histo_name+'_range_histo.csv',index=False,sep=',')
    if curve_name:
        save_frame=pd.DataFrame({'label':by_all,'pred':prob_all})
        save_frame.to_csv('./edata/Save_for_drawing/'+curve_name+'_curve.csv',index=False,sep=',')

    print('Im total',all_count,'samples:')
    auc=roc_auc_score(by_all,prob_all)
    accuracy=100*(right_count/all_count).item()
    print('AUC:{:.4f}   accuracy:{:.4f}%'.format(auc,accuracy))
    for key in more_dict:
        if more_dict[key][1]>0:
            print('Precision when positive threshold at {:g} is :{:.4f}% (total:{:d})'.format(key,more_dict[key][0]/more_dict[key][1],more_dict[key][1]))
    torch.cuda.empty_cache()

# Train

In [32]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,2,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
optimizer=optim.Adam(model.parameters(),lr=0.0001)
loss_func=nn.BCELoss().to(device)
epochs=300
train(model,m5C_Nano_train_loader,m5C_Nano_test_loader,device,optimizer,loss_func,epochs,0,0)

epoch 1 loss:   tensor(2.5643, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(2.2738, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(2.1611, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(2.1215, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(2.0797, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(2.0458, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(1.9788, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(1.9488, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(1.9025, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(1.8760, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.8478   accuracy:76.8559%
epoch 11 loss:   tensor(1.8268, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(1.7704, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(1.7357, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 loss: 

epoch 108 loss:   tensor(0.1240, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(0.1120, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(0.1375, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.9338   accuracy:87.0742%
epoch 111 loss:   tensor(0.1572, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(0.1177, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(0.1803, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(0.1151, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(0.1035, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(0.1366, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(0.1289, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(0.1490, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(0.0603, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(0.2141, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(0.0614, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(0.0615, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(0.0707, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(0.1037, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(0.0435, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(0.0311, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(0.0568, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(0.0707, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(0.0891, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.9396   accuracy:88.6463%
epoch 221 loss:   tensor(0.0390, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(0.0872, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(0.0274, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(0.0342, device='cuda:0', grad_fn=<DivBack

In [33]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,0,0,1,1],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
optimizer=optim.Adam(model.parameters(),lr=0.0001)
loss_func=nn.BCELoss().to(device)
epochs=300
train(model,m5C_Nano_train_loader,m5C_Nano_test_loader,device,optimizer,loss_func,epochs,0,0)

epoch 1 loss:   tensor(2.5054, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(2.2480, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(2.1425, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(2.0561, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(2.0239, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(1.9581, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(1.9252, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(1.8546, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(1.8025, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(1.7595, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.8706   accuracy:78.0786%
epoch 11 loss:   tensor(1.7013, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(1.6787, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(1.6382, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 loss: 

epoch 108 loss:   tensor(0.1027, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(0.1289, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(0.1544, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.9430   accuracy:88.0349%
epoch 111 loss:   tensor(0.1675, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(0.1242, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(0.1587, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(0.1038, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(0.0349, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(0.2675, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(0.1041, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(0.1567, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(0.1093, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(0.1102, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(0.0708, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(0.0620, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(0.0691, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(0.0677, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(0.0420, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(0.0558, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(0.0524, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(0.0463, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(0.0777, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.9410   accuracy:89.4323%
epoch 221 loss:   tensor(0.0305, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(0.0814, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(0.0571, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(0.0535, device='cuda:0', grad_fn=<DivBack

In [34]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[0,0,0,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
optimizer=optim.Adam(model.parameters(),lr=0.0001)
loss_func=nn.BCELoss().to(device)
epochs=300
train(model,m5C_Nano_train_loader,m5C_Nano_test_loader,device,optimizer,loss_func,epochs,0,0)

epoch 1 loss:   tensor(2.6367, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(2.2904, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(2.1701, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(2.0917, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(2.0328, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(1.9862, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(1.9472, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(1.8974, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(1.8830, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(1.8274, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.8521   accuracy:78.3406%
epoch 11 loss:   tensor(1.7943, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(1.7644, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(1.7048, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 loss: 

epoch 108 loss:   tensor(0.0647, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(0.0844, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(0.0433, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.9418   accuracy:89.1703%
epoch 111 loss:   tensor(0.0765, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(0.0967, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(0.0369, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(0.0778, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(0.0638, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(0.0486, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(0.0307, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(0.1087, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(0.0295, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(0.0639, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(0.0334, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(0.0322, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(0.0079, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(0.0515, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(0.0153, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(0.0589, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(0.0330, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(0.0290, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(0.0493, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.9508   accuracy:89.0830%
epoch 221 loss:   tensor(0.0173, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(0.0387, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(0.0449, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(0.0033, device='cuda:0', grad_fn=<DivBack

In [35]:
#20 reads 3 sites
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,2,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
optimizer=optim.Adam(model.parameters(),lr=0.0001)
loss_func=nn.BCELoss().to(device)
epochs=300
seq_reduce=22
reads_reduce=30
train(model,m5C_Nano_train_loader,m5C_Nano_test_loader,device,optimizer,loss_func,epochs,seq_reduce,reads_reduce)

epoch 1 loss:   tensor(2.7702, device='cuda:0', grad_fn=<DivBackward0>)
epoch 2 loss:   tensor(2.7000, device='cuda:0', grad_fn=<DivBackward0>)
epoch 3 loss:   tensor(2.6918, device='cuda:0', grad_fn=<DivBackward0>)
epoch 4 loss:   tensor(2.6882, device='cuda:0', grad_fn=<DivBackward0>)
epoch 5 loss:   tensor(2.6867, device='cuda:0', grad_fn=<DivBackward0>)
epoch 6 loss:   tensor(2.6865, device='cuda:0', grad_fn=<DivBackward0>)
epoch 7 loss:   tensor(2.6850, device='cuda:0', grad_fn=<DivBackward0>)
epoch 8 loss:   tensor(2.6841, device='cuda:0', grad_fn=<DivBackward0>)
epoch 9 loss:   tensor(2.6835, device='cuda:0', grad_fn=<DivBackward0>)
epoch 10 loss:   tensor(2.6852, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 10 :
AUC:0.5963   accuracy:57.0306%
epoch 11 loss:   tensor(2.6846, device='cuda:0', grad_fn=<DivBackward0>)
epoch 12 loss:   tensor(2.6950, device='cuda:0', grad_fn=<DivBackward0>)
epoch 13 loss:   tensor(2.6963, device='cuda:0', grad_fn=<DivBackward0>)
epoch 14 loss: 

epoch 108 loss:   tensor(2.3679, device='cuda:0', grad_fn=<DivBackward0>)
epoch 109 loss:   tensor(2.3599, device='cuda:0', grad_fn=<DivBackward0>)
epoch 110 loss:   tensor(2.3458, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 110 :
AUC:0.6923   accuracy:64.6288%
epoch 111 loss:   tensor(2.3371, device='cuda:0', grad_fn=<DivBackward0>)
epoch 112 loss:   tensor(2.3234, device='cuda:0', grad_fn=<DivBackward0>)
epoch 113 loss:   tensor(2.3154, device='cuda:0', grad_fn=<DivBackward0>)
epoch 114 loss:   tensor(2.3132, device='cuda:0', grad_fn=<DivBackward0>)
epoch 115 loss:   tensor(2.3179, device='cuda:0', grad_fn=<DivBackward0>)
epoch 116 loss:   tensor(2.3221, device='cuda:0', grad_fn=<DivBackward0>)
epoch 117 loss:   tensor(2.3052, device='cuda:0', grad_fn=<DivBackward0>)
epoch 118 loss:   tensor(2.3044, device='cuda:0', grad_fn=<DivBackward0>)
epoch 119 loss:   tensor(2.2921, device='cuda:0', grad_fn=<DivBackward0>)
epoch 120 loss:   tensor(2.2785, device='cuda:0', grad_fn=<DivBack

epoch 212 loss:   tensor(1.6360, device='cuda:0', grad_fn=<DivBackward0>)
epoch 213 loss:   tensor(1.6542, device='cuda:0', grad_fn=<DivBackward0>)
epoch 214 loss:   tensor(1.6325, device='cuda:0', grad_fn=<DivBackward0>)
epoch 215 loss:   tensor(1.5930, device='cuda:0', grad_fn=<DivBackward0>)
epoch 216 loss:   tensor(1.5433, device='cuda:0', grad_fn=<DivBackward0>)
epoch 217 loss:   tensor(1.5234, device='cuda:0', grad_fn=<DivBackward0>)
epoch 218 loss:   tensor(1.6789, device='cuda:0', grad_fn=<DivBackward0>)
epoch 219 loss:   tensor(1.5411, device='cuda:0', grad_fn=<DivBackward0>)
epoch 220 loss:   tensor(1.5354, device='cuda:0', grad_fn=<DivBackward0>)
At epoch 220 :
AUC:0.6928   accuracy:64.2795%
epoch 221 loss:   tensor(1.4644, device='cuda:0', grad_fn=<DivBackward0>)
epoch 222 loss:   tensor(1.4874, device='cuda:0', grad_fn=<DivBackward0>)
epoch 223 loss:   tensor(1.4355, device='cuda:0', grad_fn=<DivBackward0>)
epoch 224 loss:   tensor(1.4492, device='cuda:0', grad_fn=<DivBack

# Test

In [23]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,2,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
#model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_290_222000.pkl'))
model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_260_222000.pkl'))
detailed_test(model,m5C_Nano_test_loader,device,0,0,'m5C_Blocks=[222000],50reads_25sites')

Im total 1145 samples:
AUC:0.9516   accuracy:88.6463%
Precision when positive threshold at 0.5 is :0.8506% (total:629)
Precision when positive threshold at 0.6 is :0.8583% (total:621)
Precision when positive threshold at 0.8 is :0.8699% (total:607)
Precision when positive threshold at 0.7 is :0.8648% (total:614)
Precision when positive threshold at 0.9 is :0.8832% (total:591)
Precision when positive threshold at 0.95 is :0.8948% (total:580)
Precision when positive threshold at 0.98 is :0.8989% (total:564)
Precision when positive threshold at 0.99 is :0.9091% (total:550)
Precision when positive threshold at 0.995 is :0.9148% (total:540)
Precision when positive threshold at 0.999 is :0.9327% (total:505)
Precision when positive threshold at 0.9995 is :0.9389% (total:491)
Precision when positive threshold at 0.9999 is :0.9427% (total:454)
Precision when positive threshold at 0.99995 is :0.9497% (total:437)
Precision when positive threshold at 0.99999 is :0.9509% (total:407)
Precision when 

In [24]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,0,0,1,1],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
#model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_260_2200011.pkl'))
model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_240_220011.pkl'))
detailed_test(model,m5C_Nano_test_loader,device,0,0,'m5C_Blocks=[220011],50reads_25sites')

Im total 1145 samples:
AUC:0.9539   accuracy:88.9083%
Precision when positive threshold at 0.5 is :0.8663% (total:606)
Precision when positive threshold at 0.6 is :0.8659% (total:604)
Precision when positive threshold at 0.8 is :0.8857% (total:586)
Precision when positive threshold at 0.7 is :0.8717% (total:600)
Precision when positive threshold at 0.9 is :0.8995% (total:567)
Precision when positive threshold at 0.95 is :0.9035% (total:549)
Precision when positive threshold at 0.98 is :0.9173% (total:520)
Precision when positive threshold at 0.99 is :0.9246% (total:491)
Precision when positive threshold at 0.995 is :0.9374% (total:463)
Precision when positive threshold at 0.999 is :0.9586% (total:411)
Precision when positive threshold at 0.9995 is :0.9635% (total:384)
Precision when positive threshold at 0.9999 is :0.9643% (total:308)
Precision when positive threshold at 0.99995 is :0.9725% (total:291)
Precision when positive threshold at 0.99999 is :0.9716% (total:211)
Precision when 

In [26]:
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[0,0,0,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
#model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_190_000000.pkl'))
model.load_state_dict(torch.load('./model/m5C_NSWord_keep_model_150_000000.pkl'))
detailed_test(model,m5C_Nano_test_loader,device,0,0,'m5C_Blocks=[000000],50reads_25sites')

Im total 1145 samples:
AUC:0.9532   accuracy:88.8210%
Precision when positive threshold at 0.5 is :0.8649% (total:607)
Precision when positive threshold at 0.6 is :0.8690% (total:603)
Precision when positive threshold at 0.8 is :0.8788% (total:594)
Precision when positive threshold at 0.7 is :0.8719% (total:601)
Precision when positive threshold at 0.9 is :0.8855% (total:585)
Precision when positive threshold at 0.95 is :0.8866% (total:582)
Precision when positive threshold at 0.98 is :0.8996% (total:568)
Precision when positive threshold at 0.99 is :0.9055% (total:561)
Precision when positive threshold at 0.995 is :0.9116% (total:554)
Precision when positive threshold at 0.999 is :0.9244% (total:529)
Precision when positive threshold at 0.9995 is :0.9290% (total:521)
Precision when positive threshold at 0.9999 is :0.9413% (total:494)
Precision when positive threshold at 0.99995 is :0.9424% (total:486)
Precision when positive threshold at 0.99999 is :0.9539% (total:434)
Precision when 

In [27]:
#20 reads 3 sites
device=torch.device('cuda:0')
model=Nano(c_s=4,c_x=3,c_emb=96,c_f=16,c_hidden_att=64,c_o=1,no_heads=8,blocks_lis=[2,2,2,0,0,0],
            row_dropout=0.1,col_dropout=0.1,transition_n=2,inf=1e9,eps=1e-8,
            clear_cache_between_blocks=False).to(device)
seq_reduce=22
reads_reduce=30
#model.load_state_dict(torch.load('./model/m5C_NSWord_20reads_3sites_keep_model_230_222000.pkl'))
model.load_state_dict(torch.load('./model/m5C_NSWord_20reads_3sites_keep_model_300_222000.pkl'))
detailed_test(model,m5C_Nano_test_loader,device,seq_reduce,reads_reduce,'m5C_Blocks=[222000],20reads_3sites')

Im total 1145 samples:
AUC:0.7156   accuracy:66.2009%
Precision when positive threshold at 0.5 is :0.6575% (total:584)
Precision when positive threshold at 0.6 is :0.6717% (total:530)
Precision when positive threshold at 0.8 is :0.7026% (total:417)
Precision when positive threshold at 0.7 is :0.6897% (total:477)
Precision when positive threshold at 0.9 is :0.7207% (total:358)
Precision when positive threshold at 0.95 is :0.7451% (total:306)
Precision when positive threshold at 0.98 is :0.7677% (total:254)
Precision when positive threshold at 0.99 is :0.7692% (total:221)
Precision when positive threshold at 0.995 is :0.7865% (total:192)
Precision when positive threshold at 0.999 is :0.8267% (total:150)
Precision when positive threshold at 0.9995 is :0.8473% (total:131)
Precision when positive threshold at 0.9999 is :0.8646% (total:96)
Precision when positive threshold at 0.99995 is :0.8778% (total:90)
Precision when positive threshold at 0.99999 is :0.8824% (total:68)
Precision when pos