In [2]:
import re
from Bio.SeqRecord import SeqRecord
# from django import forms
from Bio.Seq import Seq
import numpy as np
import pandas as pd
from tensorflow import keras
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Embedding,Dense,Flatten,Dropout,Add,Bidirectional,LSTM,Conv1D,GlobalMaxPool1D,MaxPooling1D,BatchNormalization,Activation
import tensorflow as tf
from sklearn import preprocessing
from itertools import product
import joblib

In [3]:
def encode_matrix(seq_matrix):
    """将字符编码为整数
    """
    ind_to_char = ['A','T','C','G','N']
    char_to_ind = {char: i for i, char in enumerate(ind_to_char)}
    return [char_to_ind[i] for i in seq_matrix]

In [4]:
from tensorflow.keras import backend as K
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras.layers import Layer
#https://zhuanlan.zhihu.com/p/97525394
class Attention3d(Layer):
    def __init__(self, step_dim,
                 W_regularizer=None, b_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        """
        Keras Layer that implements an Attention mechanism for temporal data.
        Supports Masking.
        Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]
        # Input shape
            3D tensor with shape: `(samples, steps, features)`.
        # Output shape
            2D tensor with shape: `(samples, features)`.
        :param kwargs:
        Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
        The dimensions are inferred based on the output shape of the RNN.
        Example:
            # 1
            model.add(LSTM(64, return_sequences=True))
            model.add(Attention())
            # next add a Dense layer (for classification/regression) or whatever...
            # 2
            hidden = LSTM(64, return_sequences=True)(words)
            sentence = Attention()(hidden)
            # next add a Dense layer (for classification/regression) or whatever...
        """
        #self.supports_masking = True
        

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.step_dim = step_dim
        self.features_dim = 0

        super(Attention3d, self).__init__(**kwargs)
    def get_config(self):
         config = {"W_regularizer":self.W_regularizer,
                   "b_regularizer":self.b_regularizer,"W_constraint":self.W_constraint,"b_constraint":self.b_constraint,
                    "bias":self.bias,"step_dim":self.step_dim,"features_dim":self.features_dim}
         base_config = super(Attention3d, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=initializers.get('glorot_uniform'),
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]

        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        e = K.reshape(K.dot(K.reshape(x, (-1, features_dim)), K.reshape(self.W, (features_dim, 1))), (-1, step_dim))  # e = K.dot(x, self.W)
        if self.bias:
            e += self.b
        e = K.tanh(e)

        a = K.exp(e)
        # apply mask after the exp. will be re-normalized next
        if mask is not None:
            # cast the mask to floatX to avoid float64 upcasting in theano
            a *= K.cast(mask, K.floatx())
        # in some cases especially in the early stages of training the sum may be almost zero
        # and this results in NaN's. A workaround is to add a very small positive number ε to the sum.
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)

        c = K.sum(a * x, axis=1)
        return c

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.features_dim

In [5]:
def resnet_identity_block(input_data, filters, kernel_size):
    # CNN层
    x = Conv1D(filters, kernel_size, strides=1, padding='same')(input_data)
    x = BatchNormalization()(x)  #批次标准化
    x = Activation('relu')(x)
    # 第二层没有激活函数
    x = Conv1D(filters, kernel_size, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    # 两个张量相加
    x = Add()([x, input_data])
    # 对相加的结果使用ReLU激活
    x = Activation('relu')(x)
    # 返回结果
    return x

def resnet_convolutional_block(input_data, filters, kernel_size):
    # CNN层
    x = Conv1D(filters, kernel_size, strides=2, padding='valid')(input_data)
    x = BatchNormalization()(x)  #批次标准化
    x = Activation('relu')(x)
    # 第二层没有激活函数
    x = Conv1D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    X = Conv1D(filters, kernel_size, strides=2, padding='valid')(input_data)
    # 两个张量相加
    x = Add()([x, X])
    # 对相加的结果使用ReLU激活
    x = Activation('relu')(x)
    # 返回结果
    return x

In [6]:
def define_model():
    maxlen = 200
    max_features = 5
    embedding_dims = 32
    class_num = 1
    last_activation = 'sigmoid'
    input = Input((maxlen,))
    embedding = Embedding(max_features, embedding_dims, input_length=maxlen)(input)
    y = Conv1D(32, 8, strides=1, padding='same')(embedding)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = MaxPooling1D(pool_size=2, strides=1)(y)
    y = resnet_convolutional_block(y, 64, 8)   #卷积残差快  https://blog.csdn.net/qq_31050167/article/details/79161077
    y = resnet_identity_block(y, 64, 8)   #恒等残差块
    y = resnet_identity_block(y, 64, 8) 
    y = GlobalMaxPool1D()(y)

    x = Bidirectional(LSTM(32, return_sequences=True))(embedding)  # LSTM
    x = Bidirectional(LSTM(32, return_sequences=True))(x)
    x = Attention3d(maxlen)(x)
    x = Dropout(0.5)(x)

    t = tf.keras.layers.Concatenate()([x,y])
    t = Dense(16,activation='relu')(t)
    output = Dense(class_num, activation=last_activation)(t)
    model = Model(inputs=input, outputs=output)
    model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  metrics=['accuracy'])
    return model

In [7]:
def model_layer1():
    model_1 = define_model()
    model_1.load_weights('ResNet+LSTM+Attention(layer1).h5')
    return model_1

In [8]:
def model_layer2():
    model_2 = define_model()
    model_2.load_weights('ResNet+LSTM+Attention(layer2).h5')
    return model_2

In [9]:
def predict_DNA_enhancer_layer1(seq):
    predicted_id = []
    predicted_range = []
#     predicted_end = []
    predicted_probability = []
    predicted_probability_non = []
    predicted_probability_strong = []
    predicted_probability_weak = []
    predicted_result = []
    for record in seq:
        record_id = record.id
        record_seq = record.seq
        record_seq = record_seq.upper()
        record_seq = str(record_seq)
        seq_set, index_set = seq_split(record_seq)
        for s, ind in zip(seq_set,index_set):
            predicted_id.append(record_id)
            predicted_range.append(ind)
#             predicted_end.append(ind[1])
            feature = encode_matrix(s)
            feature = np.array(feature)
            feature = np.expand_dims(feature,axis=0)
            res_1 = model_layer1().predict(feature)
            Res_1 = np.squeeze(res_1)
            predicted_probability.append(np.round(Res_1,4))
            predicted_probability_non.append(np.round(1-Res_1,4))
            if Res_1>0.5:
                res_2 = predict_DNA_enhancer_layer2(feature)
                Res_2 = np.squeeze(res_2)
                if Res_2>0.5:
                    predicted_result.append("Strong Enhancer Site")
                else:
                    predicted_result.append("Weak Enhancer Site")
                predicted_probability_strong.append(np.round(Res_2,4))
                predicted_probability_weak.append(np.round(1-Res_2,4))
            else:
                predicted_probability_strong.append("-")
                predicted_probability_weak.append("-")
                predicted_result.append("Not an Enhancer Site")
    return predicted_id,predicted_range,predicted_probability,predicted_probability_non,predicted_probability_strong,predicted_probability_weak,predicted_result

In [10]:
def seq_split(seq):
    pattern = r'[ACGT]+'
    seq_len = len(seq)
    check =re.match(pattern=pattern ,string=seq, flags=re.I)
    seq_check_len = check.span()[1]
    if seq_len!=seq_check_len:
        return print('error')
    else:
        char_list = []
        index_set = []
        if seq_len<200:
            ss = seq[seq_len-200:].rjust(200, 'N')
#             current_index = [1,seq_len]
            char_list.append(ss)
            index_set.append(str(1)+"-"+str(seq_len))
        elif seq_len==200:
            ss = seq
            current_index = [1,seq_len]
            char_list.append(ss)
            index_set.append(str(1)+"-"+str(seq_len))
        else:
            for i in range(seq_len-200):
                ss = seq[i:i+200]
#                 current_index = [i+1,200+i]
                char_list.append(ss)
                index_set.append(str(i+1)+"-"+str(200+i))
    return char_list, index_set

In [12]:
def predict_DNA_enhancer_layer2(feature):
    res_2 = model_layer2().predict(feature)
    return res_2

In [13]:
if __name__ == '__main__':
    print('testing')
    ss = [SeqRecord(
        Seq("CTGCTCTCCTCGCTCTATAAAAGTCAGAGTGCCTAAGCTGTTAATTTGCAAACCCCTTCTTAATCTACCCTCTATTCATAGTTTATATCCAGAACTATGGTTTAATATAATCGTAAGGCCATTGACTTTTGAATACGTAGCTCCAGTCTTAGTCTCACTGGACTAGGTCCTATATCTAACCACTACCACAAAGTCTCCTA"),
        id="uc001aci.2",
        name=">Chr11_6627824_6628024",
        description="toxic membrane protein, small"), ]
    print('hello')
    for s in ss:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        print(s.id)
    print(predict_DNA_enhancer_layer1(ss))

testing
hello
uc001aci.2
(['uc001aci.2'], ['1-200'], [0.1633], [0.8367], ['-'], ['-'], ['Not an Enhancer Site'])


In [14]:
if __name__ == '__main__':
    print('testing')
    ss = [SeqRecord(
        Seq("TGAGGAAGCACCAGTACAGGGATAAGAGATGAAGAGACAGGCCAGGTCAGGCTCACCAAGCAGGTAACCGGAACCTTTAATTTTATTATGTGGAATGCTTAATGCAGAGTTAATAGGGGCTAGAGTGGCTAGGAGAGGGGACTACTGAGATAAATAACAGGAGACAGTAATGAGTTACATGTGGATTTGGGGGGCTGCAG"),
        id="uc001aci.2",
        name=">Chr11_6627824_6628024",
        description="toxic membrane protein, small"), ]
    print('hello')
    for s in ss:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        print(s.id)
    print(predict_DNA_enhancer_layer1(ss))

testing
hello
uc001aci.2
(['uc001aci.2'], ['1-200'], [0.8981], [0.1019], [0.1472], [0.8528], ['Weak Enhancer Site'])


In [15]:
if __name__ == '__main__':
    print('testing')
    ss = [SeqRecord(
        Seq("ATGCTGCCAGAAGGAAAAGGGGTGGAATTAATGAAACTGGAAGGTTGTGGTGCTGGTTTGAGGAGTAAAGTATGGGGGCCAAAGTTGGCTATATGCTGGATATGAAGAGGGGGTTAATTCCTTGCAGGTCTTCTTGAGATAGAAGTCCAGGCCCTGAGGTGGCAGGCAGCCTGATAGTGAACAGAACCCTTGTGCCCATA"),
        id="uc001aci.2",
        name=">Chr11_6627824_6628024",
        description="toxic membrane protein, small"), ]
    print('hello')
    for s in ss:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        print(s.id)
    print(predict_DNA_enhancer_layer1(ss))

testing
hello
uc001aci.2
(['uc001aci.2'], ['1-200'], [0.9903], [0.0097], [0.8364], [0.1636], ['Strong Enhancer Site'])


In [16]:
if __name__ == '__main__':
    print('testing')
    ss = [SeqRecord(
        Seq("ATGCTGCTACTCAGAAGGAAAAGGGGTGGAATTAATGAAACTGGAAGGTTGTGGTGCTGGTTTGAGGAGTAAAGTATGGGGGCCAAAGTTGGCTATATGCTGGATATGAAGAGGGGGTTAATTCCTTGCAGGTCTTCTTGAGATAGAAGTCCAGGCCCTGAGGTGGCAGGCAGCCTGATAGTGAACAGAACCCTTGTGCCCATA"),
        id="uc001aci.2",
        name=">Chr11_6627824_6628024",
        description="toxic membrane protein, small"), ]
    print('hello')
    for s in ss:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        print(s.id)
    print(predict_DNA_enhancer_layer1(ss))

testing
hello
uc001aci.2
(['uc001aci.2', 'uc001aci.2', 'uc001aci.2', 'uc001aci.2'], ['1-200', '2-201', '3-202', '4-203'], [0.9686, 0.9344, 0.9839, 0.9322], [0.0314, 0.0656, 0.0161, 0.0678], [0.7755, 0.6537, 0.7788, 0.6426], [0.2245, 0.3463, 0.2212, 0.3574], ['Strong Enhancer Site', 'Strong Enhancer Site', 'Strong Enhancer Site', 'Strong Enhancer Site'])


In [17]:
if __name__ == '__main__':
    print('testing')
    ss = [SeqRecord(
        Seq("ATGCTGCCAGAAGGAAAAGGGGTGGAATTAATGAAACTGGAGGTTGTGGTGCTGGTTTGAGGAGTAAAGTGGGCCAAAGTTGGCTATATGCTGGATATGAAGAGGGGGTTAATTCCTTGCAGGTCTTCTTGAGATAGAAGTCCAGGCCCTGAGGTGGCAGGCAGCCTGATAGTGAACAGAACCCTTGTGCCCATA"),
        id="uc001aci.2",
        name=">Chr11_6627824_6628024",
        description="toxic membrane protein, small"), ]
    print('hello')
    for s in ss:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        print(s.id)
    print(predict_DNA_enhancer_layer1(ss))

testing
hello
uc001aci.2
(['uc001aci.2'], ['1-195'], [0.2042], [0.7958], ['-'], ['-'], ['Not an Enhancer Site'])
