In [1]:
import tensorflow as tf
from common.model.utils_ori import *
import numpy as np
import json

In [2]:
#tf.enable_eager_execution()

In [3]:
ROOT = "../../data/protein/classification/full_750/"
DATA_PATH = ROOT+"3_kmers"
EMBEDDING_PATH = "../../data/protein/classification/data_sources/protVec_100d_3grams.csv"
MODEL_PATH = "../../weights/protein/classification/full_750/3_kmers"
SEQUENCE_LENGTH=748

In [4]:
with open(ROOT+"classToIndex.json") as f:
    data = json.load(f)
NUM_CLASSES = max(data.values())+1

In [5]:
def input_fn():
    return get_batches(extract_seq_and_label, DATA_PATH, 2, running_mode="train", 
                       args=[[SEQUENCE_LENGTH], False], balance=False)

In [49]:
class ResnetIdentityBlock(tf.keras.Model):

    def __init__(self, num_outputs, kernel_size, strides, dilation_rate=1, dropout=0.2, downsample = True, 
                 act=tf.nn.relu):
        super(ResnetIdentityBlock, self).__init__()
        self.act = act
        self.conv1 = tf.layers.Conv1D(num_outputs, kernel_size, strides=strides, dilation_rate=1, 
                            activation=tf.nn.relu, name="conv1")
        self.conv2 = tf.layers.Conv1D(num_outputs, kernel_size, strides=1, dilation_rate=dilation_rate, 
                            activation=tf.nn.relu, name="conv2", padding="SAME")
        self.conv3 = tf.layers.Conv1D(num_outputs, kernel_size, strides=strides, dilation_rate=1, 
                            activation=tf.nn.relu, name="conv3")
        self.dropout1 = tf.layers.Dropout(dropout)
        self.dropout2 = tf.layers.Dropout(dropout)
        self.bn1 = tf.layers.BatchNormalization(name="bn1")
        self.bn2 = tf.layers.BatchNormalization(name="bn2")
        if downsample:            
            self.downsample = tf.layers.Conv1D(num_outputs, kernel_size, strides=strides, dilation_rate=1, 
                              activation=tf.nn.relu, name="conv3")
        self.name = name

    def call(self, input, training=True):
        residual = input

        out = self.conv1(input)
        out = self.bn1(out)
        out = self.act(out)
        out = self.dropout1(out, training=training)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.dropout2(out, training=training)
        if self.downsample is not None:
            residual = self.downsample(input)

        out += residual
        out = self.act(out)
        
        return out

In [50]:
def model(x, is_training=False):
    embeddings = np.loadtxt(open(EMBEDDING_PATH, "rb"), delimiter="\t", skiprows=1, usecols=[i for i in range(1,101)])
    embedding_weights = tf.get_variable(name="Embedding_weights", shape=[embeddings.shape[0], embeddings.shape[1]],
                                    initializer=tf.constant_initializer(embeddings),trainable=False)

    embedded_seq = tf.nn.embedding_lookup(embedding_weights, x)
    print(embedded_seq.shape)
    resnet_block1 = ResnetIdentityBlock(128, 5, 3)(embedded_seq)
    resnet_block2 = ResnetIdentityBlock(256, 5, 3)(resnet_block1)    
    resnet_block3 = ResnetIdentityBlock(512, 5, 3)(resnet_block2)   
    resnet_block4 = ResnetIdentityBlock(1024, 5, 3)(resnet_block3)
    flat = tf.reduce_sum(resnet_block4, [1])
    out = tf.layers.dense(inputs=flat, units=NUM_CLASSES)
    return out

In [51]:
def model_fn(features, labels, mode):
    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(features, is_training=False)
        predictions = {'classes': tf.argmax(logits, axis=1),
                       'probabilities': tf.nn.softmax(logits)}
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={'classify': tf.estimator.export.PredictOutput(predictions)})
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
        logits = model(features, is_training=True)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, axis=1))
        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(features, is_training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
                                          loss=loss,
                                          eval_metric_ops={'accuracy': tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, axis=1))})

    

In [52]:
enzyme_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=MODEL_PATH)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '../../weigths/protein/classification/full_750/3_kmers', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000024FCA5BE748>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [32]:
tensors_to_log = {'train_accuracy': 'train_accuracy'}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=1)
enzyme_classifier.train(input_fn=input_fn, hooks=[logging_hook])

Loading files from ../../data/protein/classification/full_750/3_kmers
Found 60 file(s)
Loading process will use 4 CPUs
INFO:tensorflow:Calling model_fn.
(2, ?, 100)
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.


KeyboardInterrupt: 