In [1]:
import numpy as np
import time
import data
import tqdm
import tensorflow as tf
# Let TensorFlow grow GPU memory on demand rather than reserving the whole device up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.InteractiveSession(config=config)
import tensorflow_fold as td
from conv_lstm_cell import *


FLAGS = tf.app.flags.FLAGS

# Command-line flags, declared in this order as (definer, name, default, help).
# NOTE(review): `model`, `stack_cells`, `cell_size`, `conv_kernel` and `conv_channels` are not
# read by the graph-building cell of this notebook, which hard-codes num_units = [100, 100, 100].
_FLAG_SPECS = (
    (tf.app.flags.DEFINE_integer, 'batch_size',    8,     """batchsize"""),
    (tf.app.flags.DEFINE_integer, 'epochs',        10,    """epoch count"""),
    (tf.app.flags.DEFINE_integer, 'truncate',      200,   """truncate input sequences to this length"""),
    (tf.app.flags.DEFINE_string,  'data_dir',      "/mnt/permanent/Home/nessie/velkey/data/", """data store basedir"""),
    (tf.app.flags.DEFINE_string,  'log_dir',       "/mnt/permanent/Home/nessie/velkey/logs/", """logging directory root"""),
    (tf.app.flags.DEFINE_string,  'run_name',      "ce_b8_3x100_static_trun200", """naming: loss_fn, batch size, architecture, optimizer"""),
    (tf.app.flags.DEFINE_string,  'data_type',     "sentence/", """can be sentence/, word/"""),
    (tf.app.flags.DEFINE_string,  'model',         "lstm", """can be lstm, convlstm"""),
    (tf.app.flags.DEFINE_integer, 'stack_cells',   2,     """how many lstms to stack in each dimensions"""),
    (tf.app.flags.DEFINE_integer, 'cell_size',     1000,  """only valid with lstm model, size of the LSTM cell"""),
    (tf.app.flags.DEFINE_integer, 'conv_kernel',   0,     """convolutional kernel size for convlstm, if 0, vocab size is used"""),
    (tf.app.flags.DEFINE_integer, 'conv_channels', 64,    """convolutional output channels for convlstm"""),
    (tf.app.flags.DEFINE_string,  'loss',          "crossentropy", """can be l1, l2, crossentropy"""),
    (tf.app.flags.DEFINE_string,  'optimizer',     "ADAM", """can be ADAM, RMS, SGD"""),
    (tf.app.flags.DEFINE_float,   'learning_rate', 0.001, """starting learning rate"""),
)
for _define, _name, _default, _help in _FLAG_SPECS:
    _define(_name, _default, _help)
del _FLAG_SPECS, _define, _name, _default, _help


vocabulary = data.vocabulary(FLAGS.data_dir + 'vocabulary')
vsize=len(vocabulary)
print(vocabulary)

# Character -> position map. A dict gives O(1) lookups instead of the O(vsize) list.index
# scan that was done for every character of every sentence. Iterating in reverse makes the
# earliest position win for any duplicate entry, matching list.index.
_char_to_index = {c: i for i, c in reversed(list(enumerate(vocabulary)))}


def index(char):
    """Return the position of `char` in `vocabulary`.

    Raises:
        ValueError: if `char` is not in the vocabulary (same exception type as list.index).
    """
    try:
        return _char_to_index[char]
    except KeyError:
        raise ValueError("%r is not in vocabulary" % (char,)) from None


def char(i):
    """Return the vocabulary entry at position `i`."""
    return vocabulary[i]


class data():
    """Cycling readers over the train/test/validation splits of a tagged-sentence corpus.

    Each split is the text file ``folder + name``, one record per line in the form
    ``sentence<TAB>labels`` with one digit label per sentence character.
    ``self.data[name]`` is an endless generator of ``(onehot_rows, tags)`` pairs, and
    ``self.size[name]`` is the raw line count of that file.

    NOTE(review): this class name shadows the ``data`` module imported in the first cell.
    """
    def __init__(self, folder, truncate):
        self.data_dir = folder
        self.data = dict()
        self.size = dict()
        self.datasets = ["train", "test", "validation"]
        self.truncate = truncate

        for dataset in self.datasets:
            path = folder + dataset
            self.data[dataset] = self.sentence_reader(path)
            # Raw line count, including lines sentence_reader later skips; the handle is closed.
            with open(path) as f:
                self.size[dataset] = sum(1 for _ in f)

    def sentence_reader(self, file):
        """Yield ``(onehot_rows, tags)`` pairs from `file` forever, cycling over its records.

        Lines without a tab, with an empty sentence, or whose label count differs from
        the sentence length are skipped. The sentence and its tags are both truncated to
        ``self.truncate`` characters.

        Raises:
            ValueError: on the first ``next()`` if the file contains no usable record
                (instead of looping forever without yielding), or if a kept label
                character is not an integer digit.
        """
        with open(file) as f:
            # rstrip instead of line[:-1]: the last line may lack a trailing newline, and
            # chopping its final character would make its lengths mismatch and drop it.
            lines = [line.rstrip('\r\n') for line in f]

        records = []
        for line in lines:
            fields = line.split('\t')
            if len(fields) < 2:
                continue  # blank or malformed line
            sentence, labels = fields[0], fields[1]
            if len(sentence) == len(labels) and len(labels) != 0:
                records.append((sentence, labels))

        if not records:
            raise ValueError("no usable sentence<TAB>labels records in %s" % file)

        while True:
            for sentence, labels in records:
                # Truncate before encoding so characters past the limit are never one-hot encoded.
                tags = [int(num) for num in labels[:self.truncate]]
                yield (self.onehot(sentence[:self.truncate]), tags)

    def onehot(self, string):
        """One-hot encode `string` against the global vocabulary.

        Returns a list with one length-``vsize`` row per character (empty list for "").
        """
        onehot = np.zeros([len(string), vsize])
        # Explicit int dtype: an empty list would otherwise become a float index array.
        columns = np.array([index(c) for c in string], dtype=int)
        onehot[np.arange(len(string)), columns] = 1
        return list(onehot)

    
def model_information():
    """Print every trainable variable's name, shape and parameter count; return the total.

    The grand total is also printed as the final line.
    """
    per_variable = []
    for var in tf.trainable_variables():
        var_shape = var.get_shape()
        print(var.name, var_shape)
        n_params = 1
        for dimension in var_shape:  # iterates tf.Dimension objects; .value is the size
            n_params *= dimension.value
        print("\tparams: ", n_params)
        per_variable.append(n_params)
    grand_total = sum(per_variable)
    print(grand_total)
    return grand_total

[' ', '!', '"', '$', '%', '&', "'", '(', ')', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '§', '°', 'Á', 'É', 'Í', 'Ó', 'Ö', 'Ú', 'Ü', 'á', 'ä', 'é', 'ë', 'í', 'ó', 'ö', 'ú', 'ü', 'Ő', 'ő', 'ű']


In [2]:
# Readers for the train/test/validation splits found under data_dir + data_type.
store = data(folder=FLAGS.data_dir + FLAGS.data_type, truncate=FLAGS.truncate)
def pad(record):
    """Left-pad one ``(onehot_rows, tags)`` record with zeros up to FLAGS.truncate steps.

    Padding is prepended, so the real characters occupy the last time steps.
    Returns ``(inputs, labels)`` as numpy arrays with FLAGS.truncate rows/entries.
    """
    rows, tags = record[0], record[1]
    missing = FLAGS.truncate - len(tags)
    padded_rows = np.pad(rows, pad_width=((missing, 0), (0, 0)), mode="constant", constant_values=0)
    padded_tags = np.pad(tags, pad_width=(missing, 0), mode="constant", constant_values=0)
    return padded_rows, padded_tags

def get_padded_batch(dataset="train"):
    """Draw FLAGS.batch_size records from ``store.data[dataset]`` and stack them padded.

    Returns ``(inputs, labels)`` with shapes (batch_size, truncate, vsize) and
    (batch_size, truncate, 1).
    """
    shape = (FLAGS.batch_size, FLAGS.truncate)
    batch_inputs = np.zeros(shape + (vsize,))
    batch_labels = np.zeros(shape + (1,))
    for row in range(FLAGS.batch_size):
        padded_inputs, padded_labels = pad(next(store.data[dataset]))
        batch_inputs[row] = padded_inputs
        batch_labels[row, :, 0] = padded_labels
    return batch_inputs, batch_labels

In [3]:
# Inputs: one-hot characters and 0/1 per-character labels, both padded to FLAGS.truncate.
x = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, FLAGS.truncate, vsize))
y = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, FLAGS.truncate, 1))
labels = y

# Stacked LSTM layer sizes, one MultiRNNCell per direction.
num_units = [100, 100, 100]
with tf.variable_scope("fw"):
    fw_cells = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=units) for units in num_units])
with tf.variable_scope("bw"):
    bw_cells = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(num_units=units) for units in num_units])

outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cells,cell_bw=bw_cells, inputs=x, dtype=tf.float32)

# Forward and backward top-layer outputs concatenated on the feature axis.
RNN_out = tf.concat(outputs, -1)
# Width-1 convolution = a per-time-step linear layer mapping the concatenated features to one
# logit. Its input depth is derived from num_units instead of hard-coding 200, so changing
# the top layer size no longer breaks the graph.
filters = tf.get_variable(shape=[1, 2 * num_units[-1], 1], dtype=tf.float32, name="filters")
bias = tf.get_variable(shape=[1], dtype=tf.float32, name="bias")

logits = tf.nn.conv1d(RNN_out,filters=filters,stride=1,padding='SAME') + bias

predictions = tf.nn.sigmoid(logits)

In [7]:
# Fraction of (batch, time) positions holding a real character: padded rows of `x` are all
# zeros, so summing the one-hot input counts the real characters.
# NOTE(review): valid_ratio is not used below -- loss and metrics are averaged over padded
# positions too, which inflates accuracy on short sentences.
valid_chars_in_batch = tf.reduce_sum(x)
all_chars_in_batch = tf.size(x) / vsize
valid_ratio = valid_chars_in_batch / tf.cast(all_chars_in_batch, tf.float32)

# L1/L2 regress the sigmoid output (bounded to [0, 1]) onto the {0, 1} labels. On raw logits
# they would pull negatives to logit 0, i.e. exactly onto the 0.5 decision threshold used below.
l1_loss = tf.reduce_mean(tf.abs(tf.subtract(predictions, tf.cast(labels, tf.float32))))
# Squared error (this was previously a copy of the L1 expression).
l2_loss = tf.reduce_mean(tf.square(tf.subtract(predictions, tf.cast(labels, tf.float32))))
# Cross-entropy takes the logits directly for numerical stability.
cross_entropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.cast(labels, tf.float32)))

if FLAGS.loss == "l1":
    loss = l1_loss
elif FLAGS.loss == "l2":
    loss = l2_loss
elif FLAGS.loss == "crossentropy":
    loss = cross_entropy
else:
    # `raise NotImplemented` would itself fail with TypeError: NotImplemented is not an exception.
    raise NotImplementedError("unsupported loss: %r" % FLAGS.loss)

path = FLAGS.log_dir + FLAGS.run_name
tf.summary.scalar("batch_loss", loss)

#Accuracy: percentage of (batch, time) positions where (prediction > 0.5) equals the label.
acc = tf.reduce_sum(tf.cast(tf.equal(tf.less(0.5,predictions), tf.cast(labels, tf.bool)),tf.int32))*100/tf.size(labels)
tf.summary.scalar("batch_accuracy", acc)

# Recall: true positives / positive labels in the batch.
# NOTE(review): a batch without positive labels gives 0/0 = NaN.
label_matches = tf.equal(tf.less(0.5, predictions), tf.cast(labels, tf.bool))
correct_trues = tf.reduce_sum(tf.cast(tf.logical_and(label_matches, tf.cast(labels, tf.bool)), tf.int32))
all_trues = tf.reduce_sum(labels)
recall = tf.cast(correct_trues,tf.float32) / tf.cast(all_trues, tf.float32)
tf.summary.scalar("recall", recall)

# Merge the per-batch training summaries only. The validation/test summaries below are
# created afterwards on purpose, so running summary_op never needs their placeholders fed.
summary_op = tf.summary.merge_all()

if FLAGS.optimizer == "ADAM":
    opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
elif FLAGS.optimizer == "RMS":
    opt = tf.train.RMSPropOptimizer(learning_rate=FLAGS.learning_rate)
elif FLAGS.optimizer == "SGD":
    opt = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
else:
    raise NotImplementedError("unsupported optimizer: %r" % FLAGS.optimizer)

train_op = opt.minimize(loss)

# Built after minimize(): Saver() collects the variable list at construction time, so creating
# it earlier left the optimizer's slot variables (e.g. Adam moments) out of every checkpoint.
saver = tf.train.Saver(max_to_keep=20)

# validation summary: fed with dataset-averaged losses computed in Python.
validation_loss_placeholder = tf.placeholder(tf.float32, name="validation")
validation_loss_summary = tf.summary.scalar('validation_loss', validation_loss_placeholder)
test_loss_placeholder = tf.placeholder(tf.float32, name="test")
# Own tag so test results no longer overwrite the validation curve in TensorBoard.
test_loss_summary = tf.summary.scalar('test_loss', test_loss_placeholder)

# Created last so the graph written to TensorBoard includes the optimizer ops as well.
writer = tf.summary.FileWriter(path, graph=tf.get_default_graph())


def get_metrics_on_dataset(dataset, train_step):
    """Evaluate loss, accuracy and recall over roughly one pass of `dataset`.

    Draws ``store.size[dataset] / FLAGS.batch_size`` batches (at least one) from
    ``store.data[dataset]``, runs them through the current graph and averages the
    per-batch values. For "validation" and "test" the averaged loss is also written to
    TensorBoard at `train_step`.

    Args:
        dataset: key of ``store.data``, e.g. "validation" or "test".
        train_step: global step recorded with the TensorBoard summary.

    Returns:
        (mean loss, mean accuracy in percent, mean recall).
    """
    losses, accs, recalls = [], [], []
    # At least one batch, so a split smaller than a batch does not average an empty list (NaN).
    n_batches = max(1, int(store.size[dataset] / FLAGS.batch_size))
    for _ in tqdm.trange(n_batches):
        # Draw from the requested split (this was hard-coded to "validation", so the
        # "test" numbers were really validation numbers).
        b_data, b_label = get_padded_batch(dataset)
        batch_loss, accuracy, rec = sess.run([loss, acc, recall], {x: b_data, y: b_label})
        losses.append(batch_loss)
        accs.append(accuracy)
        recalls.append(rec)

    avg_loss = np.average(losses)

    if dataset == "validation":
        valid_summary = sess.run(validation_loss_summary,feed_dict={validation_loss_placeholder: avg_loss})
        writer.add_summary(valid_summary, train_step)
    elif dataset == "test":
        test_summary = sess.run(test_loss_summary,feed_dict={test_loss_placeholder: avg_loss})
        writer.add_summary(test_summary, train_step)

    return avg_loss, np.average(accs), np.average(recalls)
    
    
class stopper():
    """Patience-based early stopping for a monitored value where lower is better.

    Once more than `patience` logged values after the first occurrence of the best
    (minimum) value are strictly worse than it, `should_stop` becomes True and stays True.
    """
    def __init__(self, patience=20):
        self.patience = patience
        self.log = []
        self.should_stop = False

    def add(self, value):
        """Record `value` and return whether training should stop."""
        self.log.append(value)
        return self.check()

    def check(self):
        """Re-evaluate the stopping condition on the current log and return `should_stop`."""
        best = min(self.log)
        since_best = self.log[self.log.index(best):]
        worse_count = sum(v > best for v in since_best)
        if worse_count > self.patience:
            self.should_stop = True
        return self.should_stop
    
# NOTE(review): validation runs every 5000 batches, so patience=20 means 100000 batches, which
# exceeds the 90630 total steps shown in the run output -- early stopping cannot trigger there.
early = stopper(patience=20)
batches_per_epoch = int(store.size["train"] / FLAGS.batch_size)
steps = FLAGS.epochs * batches_per_epoch

In [None]:
# Train for `steps` batches: summaries every 5 batches, validation plus early-stopping check
# every 5000, checkpoints every 10000 and at the end, then one evaluation on the test split.
sess.run(tf.global_variables_initializer())
for i in tqdm.trange(steps, unit="batches"):
    b_data, b_label = get_padded_batch("train")
    write_summary = i % 5 == 0
    # Evaluate the merged summaries only on steps that write them (the unused
    # `predictions` fetch is dropped as well).
    fetches = [train_op, loss] + ([summary_op] if write_summary else [])
    results = sess.run(fetches, {x: b_data, y: b_label})
    batch_loss = results[1]
    assert not np.isnan(batch_loss)

    if write_summary:
        writer.add_summary(results[2], i)

    if i % 5000 == 0:
        l, a, r = get_metrics_on_dataset("validation", i)
        print("loss: ", l, " accuracy: ", a, "% recall: ", r)
        if early.add(l):
            break

    if i % 10000 == 0:
        save_path = saver.save(sess, path + "/model.ckpt", global_step=i)

# The periodic save above misses the batches after the last multiple of 10000 and the
# state at an early stop, so also checkpoint the final weights.
if steps > 0:
    save_path = saver.save(sess, path + "/model.ckpt", global_step=i)

print("Testing...")
l, a, r = get_metrics_on_dataset("test", steps)
print("loss: ", l, " accuracy: ", a, "% recall: ", r)

  0%|          | 0/90630 [00:00<?, ?batches/s]
  0%|          | 0/503 [00:00<?, ?it/s][A
  0%|          | 1/503 [00:00<00:53,  9.47it/s][A
  1%|          | 3/503 [00:00<00:48, 10.28it/s][A
  1%|          | 5/503 [00:00<00:44, 11.22it/s][A
  1%|▏         | 7/503 [00:00<00:42, 11.69it/s][A
  2%|▏         | 9/503 [00:00<00:41, 12.02it/s][A
  2%|▏         | 11/503 [00:00<00:40, 12.21it/s][A
  3%|▎         | 13/503 [00:01<00:39, 12.37it/s][A
  3%|▎         | 15/503 [00:01<00:39, 12.47it/s][A
  3%|▎         | 17/503 [00:01<00:38, 12.55it/s][A
  4%|▍         | 19/503 [00:01<00:38, 12.62it/s][A
  4%|▍         | 21/503 [00:01<00:38, 12.68it/s][A
  5%|▍         | 23/503 [00:01<00:38, 12.49it/s][A
  5%|▍         | 25/503 [00:01<00:38, 12.56it/s][A
  5%|▌         | 27/503 [00:02<00:38, 12.40it/s][A
  6%|▌         | 29/503 [00:02<00:38, 12.47it/s][A
  6%|▌         | 31/503 [00:02<00:37, 12.52it/s][A
  7%|▋         | 33/503 [00:02<00:37, 12.56it/s][A
  7%|▋         | 35/503 [00:02<

 61%|██████    | 307/503 [00:24<00:15, 12.40it/s][A
 61%|██████▏   | 309/503 [00:24<00:15, 12.40it/s][A
 62%|██████▏   | 311/503 [00:25<00:15, 12.39it/s][A
 62%|██████▏   | 313/503 [00:25<00:15, 12.40it/s][A
 63%|██████▎   | 315/503 [00:25<00:15, 12.40it/s][A
 63%|██████▎   | 317/503 [00:25<00:14, 12.41it/s][A
 63%|██████▎   | 319/503 [00:25<00:14, 12.41it/s][A
 64%|██████▍   | 321/503 [00:25<00:14, 12.42it/s][A
 64%|██████▍   | 323/503 [00:26<00:14, 12.42it/s][A
 65%|██████▍   | 325/503 [00:26<00:14, 12.43it/s][A
 65%|██████▌   | 327/503 [00:26<00:14, 12.41it/s][A
 65%|██████▌   | 329/503 [00:26<00:14, 12.42it/s][A
 66%|██████▌   | 331/503 [00:26<00:13, 12.42it/s][A
 66%|██████▌   | 333/503 [00:26<00:13, 12.43it/s][A
 67%|██████▋   | 335/503 [00:26<00:13, 12.43it/s][A
 67%|██████▋   | 337/503 [00:27<00:13, 12.42it/s][A
 67%|██████▋   | 339/503 [00:27<00:13, 12.43it/s][A
 68%|██████▊   | 341/503 [00:27<00:13, 12.43it/s][A
 68%|██████▊   | 343/503 [00:27<00:12, 12.44it

loss:  0.91398  accuracy:  4.30181411531 % recall:  1.0


  6%|▌         | 5000/90630 [23:49<6:47:53,  3.50batches/s]
  0%|          | 0/503 [00:00<?, ?it/s][A
  0%|          | 2/503 [00:00<00:37, 13.52it/s][A
  1%|          | 4/503 [00:00<00:37, 13.42it/s][A
  1%|          | 6/503 [00:00<00:37, 13.24it/s][A
  1%|▏         | 7/503 [00:00<00:40, 12.35it/s][A
  2%|▏         | 9/503 [00:00<00:39, 12.50it/s][A
  2%|▏         | 11/503 [00:00<00:38, 12.62it/s][A
  3%|▎         | 13/503 [00:01<00:38, 12.70it/s][A
  3%|▎         | 15/503 [00:01<00:38, 12.79it/s][A
  3%|▎         | 17/503 [00:01<00:37, 12.84it/s][A
  4%|▍         | 19/503 [00:01<00:37, 12.85it/s][A
  4%|▍         | 21/503 [00:01<00:37, 12.87it/s][A
  5%|▍         | 23/503 [00:01<00:37, 12.65it/s][A
  5%|▍         | 25/503 [00:01<00:37, 12.69it/s][A
  5%|▌         | 27/503 [00:02<00:38, 12.52it/s][A
  6%|▌         | 29/503 [00:02<00:38, 12.35it/s][A
  6%|▌         | 31/503 [00:02<00:38, 12.21it/s][A
  7%|▋         | 33/503 [00:02<00:38, 12.27it/s][A
  7%|▋         | 3

 61%|██████    | 307/503 [00:24<00:15, 12.47it/s][A
 61%|██████▏   | 309/503 [00:24<00:15, 12.46it/s][A
 62%|██████▏   | 311/503 [00:24<00:15, 12.46it/s][A
 62%|██████▏   | 313/503 [00:25<00:15, 12.47it/s][A
 63%|██████▎   | 315/503 [00:25<00:15, 12.47it/s][A
 63%|██████▎   | 317/503 [00:25<00:14, 12.48it/s][A
 63%|██████▎   | 319/503 [00:25<00:14, 12.48it/s][A
 64%|██████▍   | 321/503 [00:25<00:14, 12.48it/s][A
 64%|██████▍   | 323/503 [00:25<00:14, 12.47it/s][A
 65%|██████▍   | 325/503 [00:26<00:14, 12.47it/s][A
 65%|██████▌   | 327/503 [00:26<00:14, 12.47it/s][A
 65%|██████▌   | 329/503 [00:26<00:13, 12.47it/s][A
 66%|██████▌   | 331/503 [00:26<00:13, 12.48it/s][A
 66%|██████▌   | 333/503 [00:26<00:13, 12.47it/s][A
 67%|██████▋   | 335/503 [00:26<00:13, 12.47it/s][A
 67%|██████▋   | 337/503 [00:27<00:13, 12.46it/s][A
 67%|██████▋   | 339/503 [00:27<00:13, 12.45it/s][A
 68%|██████▊   | 341/503 [00:27<00:13, 12.45it/s][A
 68%|██████▊   | 343/503 [00:27<00:12, 12.45it

loss:  0.0310581  accuracy:  98.9054423459 % recall:  0.838144


  6%|▌         | 5055/90630 [24:44<6:58:54,  3.40batches/s]

11.0