In [10]:
# coding: utf-8

from __future__ import print_function
from load_glove import *
import json
from collections import defaultdict
import numpy as np
import tensorflow as tf
from sys import argv
import os
import argparse
import gzip

In [37]:
# Command-line options. Inside the notebook they are parsed from an empty
# argument list, so every option takes its default.
parser = argparse.ArgumentParser()
parser.add_argument('--runmode', dest='runmode', choices=["train", "test"], default="train")
parser.add_argument('--dataset_name', dest='dataset_name', type=str, default="duc2004")
parser.add_argument('--trained_on', dest='trained_on', type=str, default="duc2004")
parser.add_argument('--cnn', dest='cnn', type=int, default=0)
parser.add_argument('--train_embedding', dest='train_embedding', type=int, default=0)
parser.add_argument('--output_root', dest='output_root', type=str, default="")
parser.add_argument('--evaluation_root', dest='evaluation_root', type=str, default="../../evaluation")
parser.add_argument('--glove_location', dest='glove_location', type=str, default="../../data/glove/glove.6B.100d.txt")

# Parse BEFORE the constants below: TRAIN_EMBEDDING and USE_CNN read from
# `args`, which previously was only assigned at the end of this cell.
#args = parser.parse_args()
args = parser.parse_args([]) # for ipython notebook

# NOTE: a later cell assigns these constants again (INPUT_MAX is 100 there),
# so the values here only matter until that cell runs.
INPUT_MAX = 150
OUTPUT_MAX = 20
VOCAB_MAX = 30000

GLV_RANGE = 0.5
LR_DECAY_AMOUNT = 0.9
starter_learning_rate = 1e-2
hs = 256

BATCH_SIZE = 32
PRINT_EVERY = 5
CHECKPOINT_EVERY = 5000
TRAIN_KEEP_PROB = 0.5
TRAIN_EMBEDDING = args.train_embedding
USE_CNN = args.cnn
KERNEL_SIZE = 7

In [38]:
# Name of the checkpoint/summary directory, assembled from the run options:
# "train-<dataset>" followed by "-cnn" / "-train_embedding" when enabled.
log_components = ["train"]
if args.runmode == "train":
    log_components.append(args.dataset_name)
elif args.runmode == "test":
    # In test mode the directory is keyed by the dataset the model was trained on.
    log_components.append(args.trained_on)
log_components.extend(
    flag_name
    for flag_value, flag_name in ((args.cnn, "cnn"), (args.train_embedding, "train_embedding"))
    if flag_value
)

LOGDIR = "-".join(log_components)
LOGDIR = LOGDIR if args.output_root == "" else os.path.join(args.output_root, LOGDIR)

dataset_file = os.path.join("../../data", args.dataset_name, "data.json")
print("Runmode %s on dataset %s" % (args.runmode, args.dataset_name))

# Prediction output paths are only defined for test runs.
if args.runmode == "test":
    predictions_dir_name = "_".join(["seq2seq", LOGDIR, args.trained_on, args.dataset_name])
    predictions_dir_path = os.path.join(args.evaluation_root, predictions_dir_name)
    predictions_file_path = os.path.join(predictions_dir_path, "prediction.json.gz")

Runmode train on dataset duc2004


In [98]:
# Path of the GloVe text file handed to glove2dict().
GLOVE_LOC = args.glove_location

# Sequence limits: clean() clips inputs/labels to these many tokens and the
# placeholders are sized from them.
INPUT_MAX = 100
OUTPUT_MAX = 20
# At most this many of the most frequent words enter the vocabulary.
VOCAB_MAX = 30000

# Out-of-GloVe words get vectors drawn uniformly from [-GLV_RANGE, GLV_RANGE].
GLV_RANGE = 0.5
# Exponential learning-rate schedule parameters.
LR_DECAY_AMOUNT = 0.9
starter_learning_rate = 1e-2
hs = 256
# Width of the hidden layer built in the graph cell. No visible cell defined
# this name, so the graph cell failed on a fresh kernel.
# NOTE(review): 100 matches the recorded output "h shape:  (?, 100)"; confirm
# whether `hs` (256) was the intended value instead.
HIDDEN_STATE_SIZE = 100

BATCH_SIZE = 32
PRINT_EVERY = 5
CHECKPOINT_EVERY = 5000
TRAIN_KEEP_PROB = 0.5
# Integer flags taken from the command-line options.
TRAIN_EMBEDDING = args.train_embedding
USE_CNN = args.cnn
KERNEL_SIZE = 7
# Number of previous output tokens fed to the model (width of window_placeholder).
WINDOW_SIZE = 4

In [68]:

# Token frequency table; clean() increments it for every token it sees.
word_counter = defaultdict(int)

print("Loading GLOVE vectors...")
words = glove2dict(GLOVE_LOC)

# The embedding width is read off one vector; 'the' serves as the probe word.
GLV_DIM = words['the'].shape[0]
print("...loaded %d dimensional GLOVE vectors!" % GLV_DIM)


Loading GLOVE vectors...
...loaded 100 dimensional GLOVE vectors!


In [103]:

# Characters stripped by clean(). The backslash is spelled "\\" explicitly; the
# previous "\]" relied on an invalid escape sequence (same runtime value, but
# it triggers a warning on recent Python versions).
not_letters_or_digits = u'!"#%\'()*+,-./:;<=>?@[\\]^_`{|}~'
translate_table = dict((ord(char), None) for char in not_letters_or_digits)

# Markup tags removed (in this order) before punctuation is stripped.
_MARKUP_TAGS = ('<d>', '<p>', '<s>', '</d>', '</p>', '</s>')


def clean(text,clip_n=0):
    """Strip markup tags and punctuation from `text` and register its tokens.

    Side effects on module-level state:
      * every token missing from `words` receives a random vector with
        components drawn uniformly from [-GLV_RANGE, GLV_RANGE];
      * `word_counter` is incremented once per token occurrence. All tokens of
        the text are counted, including those beyond `clip_n`.

    Args:
        text: the raw string to clean.
        clip_n: when > 0, keep only the first `clip_n` whitespace-separated
            tokens.

    Returns:
        If `clip_n` > 0, the first `clip_n` tokens joined by single spaces.
        Otherwise the stripped text as-is (original whitespace preserved).
    """
    res = text
    for tag in _MARKUP_TAGS:
        res = res.replace(tag, '')
    res = res.translate(translate_table)

    tokens = res.split()
    for word in tokens:
        if word not in words:
            words[word] = np.array([random.uniform(-GLV_RANGE, GLV_RANGE) for _ in range(GLV_DIM)])
        word_counter[word] += 1

    if clip_n > 0:
        return ' '.join(tokens[:clip_n])
    return res

print("Loading dataset...")
# Only the JSON parse needs the file handle; everything below works on `data`.
with open(dataset_file) as fp:
    data = json.load(fp)

train_o = [x for x in data if x['set'] == 'train']
dev_o = [x for x in data if x['set'] == 'dev']
test_o = [x for x in data if x['set'] == 'test']


def _expand_labels(examples):
    """Return one (clean input, clean label, example index) triple per label.

    The input text is cleaned once per label, exactly as before, so
    `word_counter` counts its tokens once per reference label.
    """
    return [(clean(x['data'], INPUT_MAX), clean(label, OUTPUT_MAX), idx)
            for idx, x in enumerate(examples)
            for label in x['label']]


DATA_SEED = 111948
# Seed before cleaning so the random vectors that clean() assigns to words
# missing from GloVe are reproducible; previously the generator was seeded
# only after they had been drawn.
random.seed(DATA_SEED)
train = _expand_labels(train_o)
dev = _expand_labels(dev_o)
test = _expand_labels(test_o)

# Most frequent words first; ties are broken by the word itself.
valid_words = sorted(((v, k) for k, v in word_counter.items()), reverse=True)
print(len(valid_words), "valid words found.")
valid_words = ['<PAD>'] + [x[1] for x in valid_words[:VOCAB_MAX]] + ['<EOS>','<UNK>','<SOS>']
unk_idx = valid_words.index('<UNK>')
# Word -> vocabulary id; any word outside the vocabulary maps to <UNK>.
vwd = defaultdict(lambda : unk_idx)
for idx,word in enumerate(valid_words):
    vwd[word] = idx

# One embedding row per vocabulary entry, in vocabulary order.
initial_matrix = np.array([words[x] for x in valid_words])


def sent_to_idxs(s):
    """Encode a label as [<SOS>] + word ids, padded with <EOS> ids.

    Returns:
        (ids, (sen_len, pad_word)): `sen_len` is the number of words in `s`
        and `pad_word` is OUTPUT_MAX - sen_len, the number of <EOS> ids
        appended. A non-positive `pad_word` appends nothing.
    """
    base = [vwd[word] for word in s.split()]
    sen_len = len(base)
    pad_word = OUTPUT_MAX - sen_len
    base = [vwd['<SOS>']] + base + pad_word * [vwd['<EOS>']]
    return base, (sen_len, pad_word)


def sent_to_idxs_nopad(sentence):
    """Encode a sentence as a plain list of word ids (no <SOS>, no padding)."""
    return [vwd[word] for word in sentence.split()]


def _encode_examples(examples):
    """Return (inputs, labels, label lengths) for a list of example triples.

    Each label is encoded once and its ids and length pair are split apart.
    """
    inputs = [sent_to_idxs_nopad(source) for source, _, _ in examples]
    encoded = [sent_to_idxs(label) for _, label, _ in examples]
    label_ids = [ids for ids, _ in encoded]
    label_lens = [lens for _, lens in encoded]
    return inputs, label_ids, label_lens


# Re-seed immediately before shuffling so the training order is the same as
# it was when this was the only seeding call.
random.seed(DATA_SEED)
random.shuffle(train)

train_x, train_y, train_len = _encode_examples(train)
dev_x, dev_y, dev_len = _encode_examples(dev)
test_x, test_y, test_len = _encode_examples(test)
    


Loading dataset...
17957 valid words found.


In [105]:
# Vocabulary ids of the padding and start-of-sentence tokens.
PAD_INDEX, SOS_INDEX = map(valid_words.index, ('<PAD>', '<SOS>'))

In [101]:
# Build the model on an empty default graph.
tf.reset_default_graph()
global_step = tf.Variable(0, trainable=False)
VOCAB_SIZE = len(valid_words)
# Staircase exponential decay driven by global_step (see train_step below).
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, len(train_x)/BATCH_SIZE, LR_DECAY_AMOUNT, staircase=True)

# Inputs: padded source token ids, the label ids ([<SOS>] + OUTPUT_MAX tokens),
# the window of previous output token ids, and the label column to predict.
input_placeholder = tf.placeholder(tf.int32, (None, INPUT_MAX))
input_step_placeholder = tf.placeholder(tf.int32, (None, INPUT_MAX))
labels_placeholder = tf.placeholder(tf.int32, (None, OUTPUT_MAX+1))
window_placeholder =  tf.placeholder(tf.int32, (None, WINDOW_SIZE))
step_placeholder =  tf.placeholder(tf.int32, ())
CONTEXT_SIZE = (INPUT_MAX + WINDOW_SIZE) * GLV_DIM

# The context is the embeddings of every input token followed by the
# embeddings of the window tokens, flattened into one vector per example.
embedding = tf.Variable(initial_matrix, dtype=tf.float32, trainable=TRAIN_EMBEDDING)
input_embed = tf.nn.embedding_lookup(embedding, input_placeholder)
#print("input_embed shape: ", input_embed.shape)
window_embed = tf.nn.embedding_lookup(embedding, window_placeholder)
#print("window_embed shape: ", window_embed.shape)
context_concat = tf.concat([input_embed, window_embed], axis=1)
#print("context_concat shape: ", context_concat.shape)
context = tf.reshape(context_concat, [-1, CONTEXT_SIZE])
#print("context shape: ", context.shape)

# One tanh hidden layer followed by a linear projection onto the vocabulary.
W1 = tf.get_variable("W1", shape=[CONTEXT_SIZE, HIDDEN_STATE_SIZE], initializer=tf.contrib.layers.xavier_initializer())
b1 = tf.Variable(tf.constant(0.0, shape=[HIDDEN_STATE_SIZE]))
h = tf.nn.tanh(tf.matmul(context, W1) + b1)
print("h shape: ", h.shape)

W2 = tf.get_variable("W2", shape=[HIDDEN_STATE_SIZE, VOCAB_SIZE], initializer=tf.contrib.layers.xavier_initializer())
preds = tf.matmul(h, W2)
print(preds.shape)
print(labels_placeholder.shape)

# Cross-entropy against the label column selected by step_placeholder.
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=preds,labels=labels_placeholder[:,step_placeholder])
loss = tf.reduce_mean(ce)
tf.summary.scalar('loss', loss)

optimizer = tf.train.AdamOptimizer(learning_rate)
gvs = optimizer.compute_gradients(loss)
grads = [g for g,v in gvs]
tvars = [v for g,v in gvs]
grads, _= tf.clip_by_global_norm(grads,5)
# Pass global_step so each update increments it. Without this it stayed at 0
# and the exponential decay above never changed the learning rate.
train_step = optimizer.apply_gradients(zip(grads,tvars), global_step=global_step)

h shape:  (?, 100)
(?, 17961)
(?, 21)


In [110]:
# Scratch check of the initial decoder window built in sample(): padding ids
# followed by the start-of-sentence id.
([PAD_INDEX] * (WINDOW_SIZE-1)) + [SOS_INDEX]

[0, 0, 0, 17960]

In [95]:
def sample(context_vector):
    """Greedily decode OUTPUT_MAX words for a single input sequence.

    Must be called while a default TensorFlow session is active, because it
    evaluates `preds` with `Tensor.eval`.

    Args:
        context_vector: sequence of input token ids. It is clipped or padded
            with PAD_INDEX to INPUT_MAX ids to fit `input_placeholder`.

    Returns:
        The OUTPUT_MAX generated words joined by single spaces. Decoding does
        not stop at '<EOS>', so the result may contain repeated '<EOS>' tokens.
    """
    # input_placeholder has a fixed width, so the ids are laid into a padded row.
    ids = np.asarray(context_vector).reshape(-1)[:INPUT_MAX]
    inputs = np.full((1, INPUT_MAX), PAD_INDEX, dtype=np.int32)
    inputs[0, :len(ids)] = ids

    sentence = []
    window = ([PAD_INDEX] * (WINDOW_SIZE-1)) + [SOS_INDEX]
    for _ in range(OUTPUT_MAX):
        # `preds` depends only on the input ids and the window, so the label
        # and step placeholders are not fed here.
        feed_dict = {
            input_placeholder: inputs,
            # The placeholder is (batch, WINDOW_SIZE): feed one row.
            window_placeholder: np.array(window).reshape([1, -1]),
        }
        probs = preds.eval(feed_dict=feed_dict)
        word_index = int(np.argmax(probs))
        # Slide the window: newest prediction in, oldest token out.
        window.append(word_index)
        window.pop(0)
        sentence.append(valid_words[word_index])
    return ' '.join(sentence)

def idxs_to_sent(idxs, vocab=None):
    """Turn a sequence of vocabulary ids back into a space-separated sentence.

    '<EOS>' and '<SOS>' markers are dropped; every other word is kept.

    Args:
        idxs: iterable of integer ids into the vocabulary.
        vocab: optional list mapping id -> word. Defaults to the module-level
            `valid_words`.

    Returns:
        The decoded sentence. (The previous version built this string but
        never returned it, so callers always received None.)
    """
    if vocab is None:
        vocab = valid_words
    decoded = (vocab[i] for i in idxs)
    return ' '.join(word for word in decoded if word not in ('<EOS>', '<SOS>'))


def sample_batch(context_vectors):
    """Greedily decode one summary per input sequence, a whole batch at a time.

    Must be called while a default TensorFlow session is active.

    Args:
        context_vectors: list of input token-id sequences. Each is clipped or
            padded with PAD_INDEX to INPUT_MAX ids.

    Returns:
        A list with one string per input: the generated words up to (not
        including) the first '<EOS>', or all OUTPUT_MAX words if none appears.
    """
    num_sent = len(context_vectors)
    if num_sent == 0:
        return []

    mat = np.full((num_sent, INPUT_MAX), PAD_INDEX, dtype=np.int32)
    for idx, row in enumerate(context_vectors):
        row = np.asarray(row).reshape(-1)[:INPUT_MAX]
        mat[idx, :len(row)] = row

    # Every sentence starts from the window [PAD, ..., PAD, SOS].
    windows = np.full((num_sent, WINDOW_SIZE), PAD_INDEX, dtype=np.int32)
    windows[:, -1] = SOS_INDEX
    sentences = [[] for _ in range(num_sent)]

    for _ in range(OUTPUT_MAX):
        feed_dict = {
            input_placeholder: mat,
            window_placeholder: windows,
        }
        # `preds` is (batch, vocab): one distribution per sentence for the
        # next word only, so the argmax is taken along the vocabulary axis.
        probs = preds.eval(feed_dict=feed_dict)
        next_ids = np.argmax(probs, axis=1).astype(np.int32)
        for sentence, word_index in zip(sentences, next_ids):
            sentence.append(valid_words[word_index])
        # Slide each window by one token.
        windows = np.concatenate([windows[:, 1:], next_ids.reshape(-1, 1)], axis=1)

    stops = [sentence.index('<EOS>') if '<EOS>' in sentence else OUTPUT_MAX for sentence in sentences]
    return [' '.join(sentence[:maxe]) for maxe,sentence in zip(stops,sentences)]


In [46]:
def try_restoring_checkpoint(session, saver):
    """Restore model parameters from the checkpoint recorded in LOGDIR, if any.

    Args:
        session: the session to restore the variables into.
        saver: the saver whose `restore` method performs the load.

    When `tf.train.get_checkpoint_state(LOGDIR)` reports no checkpoint path,
    a message is printed and the session is left untouched. An
    OutOfRangeError while reading the state is printed and followed by
    `exit(1)`.
    """
    print('trying to restore checkpoints...')
    try:
        state = tf.train.get_checkpoint_state(LOGDIR)
    except tf.errors.OutOfRangeError as err:
        print('Cannot restore checkpoint: ', err)
        exit(1)

    checkpoint_path = state.model_checkpoint_path if state else None
    if checkpoint_path:
        print('Loading checkpoint ', checkpoint_path)
        saver.restore(session, checkpoint_path)
        print('...loaded.')
    else:
        print('No model at %s, starting with fresh parameters' % LOGDIR)

In [96]:
def _previous_tokens(labels, position, window_size, pad_index):
    """Return, for every row, the `window_size` ids preceding column `position`.

    Rows are left-padded with `pad_index` when fewer than `window_size` ids
    precede `position`, so position 1 yields [pad, ..., pad, labels[:, 0]].
    """
    labels = np.asarray(labels)
    left_pad = np.full((labels.shape[0], window_size), pad_index, dtype=labels.dtype)
    padded = np.concatenate([left_pad, labels], axis=1)
    # Column p of `labels` is column p + window_size of `padded`.
    return padded[:, position:position + window_size]


def _visible_words(ids):
    """Join the vocabulary words for `ids`, leaving out <EOS>/<SOS> markers."""
    return ' '.join([x for x in [valid_words[x] for x in ids] if x not in ['<EOS>','<SOS>']])


with tf.Session() as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
    saver =  tf.train.Saver()
    try_restoring_checkpoint(sess, saver)
    data_size = len(train_x)
    for i in range(data_size * 10):
        # Batches wrap around the training set; the last slice may be short.
        start_idx = (i*BATCH_SIZE) % data_size
        end_idx = start_idx+BATCH_SIZE

        # Lay the variable-length inputs into a zero-padded (batch, INPUT_MAX) matrix.
        batch_rows = train_x[start_idx:end_idx]
        inputs_batch = np.zeros(shape=(len(batch_rows),INPUT_MAX))
        for idx,row in enumerate(batch_rows):
            row = row[:INPUT_MAX]
            inputs_batch[idx,:len(row)] = np.array(row)
        labels_batch = train_y[start_idx:end_idx]

        # One update per target position. Column 0 of the labels is always
        # <SOS>, so training covers columns 1..OUTPUT_MAX, the positions that
        # sample() predicts. Each step feeds the WINDOW_SIZE label ids before
        # the target column, which the graph needs to compute `preds`.
        for step in range(1, OUTPUT_MAX + 1):
            feed_dict = {
                input_placeholder: inputs_batch,
                labels_placeholder: labels_batch,
                window_placeholder: _previous_tokens(labels_batch, step, WINDOW_SIZE, PAD_INDEX),
                step_placeholder: step
            }
            _, bl, summary = sess.run([train_step, loss, merged], feed_dict=feed_dict)
        if args.runmode == "train":
            # Only the summary of the batch's last step is recorded.
            summary_writer.add_summary(summary, i)
        if i % PRINT_EVERY == 0:
            print(i,bl)
            print('TRAIN_SAMPLE: ', sample(train_x[start_idx]))
            print('TRAIN_LABEL: ', _visible_words(train_y[start_idx]))
            if dev_x:
                # Pick one of the first (up to) ten dev examples.
                index = int(random.random()*min(10, len(dev_x)))
                print('DEV   SAMPLE: ', sample(dev_x[index]))
                print('DEV   LABEL: ', _visible_words(dev_y[index]))

            print('\n')
        # The earlier condition `i %2000*len(train_x)/BATCH_SIZE == 0` parsed as
        # ((i % 2000) * len(train_x)) / BATCH_SIZE == 0; use the configured
        # interval instead.
        if i != 0 and i % CHECKPOINT_EVERY == 0:
            print("Saving checkpoint...")
            saver.save(sess, os.path.join(LOGDIR, 'model-checkpoint-'), global_step=i)
            summary_writer.flush()
    summary_writer.close()

trying to restore checkpoints...
No model at train-duc2004, starting with fresh parameters
0 3.06758
TRAIN_SAMPLE:  <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
TRAIN_LABEL:  primakov says economic crisis will not privatization efforts
DEV   SAMPLE:  <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
DEV   LABEL:  cambodian king announces coalition government with hun sen as sole premier


5 0.806815
TRAIN_SAMPLE:  <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
TRAIN_LABEL:  chinese dissident yao zhenxian flees to us to escape arrest in china
DEV   SAMPLE:  <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
DEV   LABEL:  disputes over presidency block efforts to form a new government


10 1.02936
TRAIN_SA

KeyboardInterrupt: 

In [50]:
# NOTE(review): `train_mat` is not assigned in any visible cell — presumably a
# leftover from an earlier version of the training loop; confirm or remove
# before running this notebook on a fresh kernel.
print(len(train_mat))

32

In [58]:
# Scratch inspection: the id at index 1 of the second label row in the slice
# train_y[start_idx:end_idx]. Relies on start_idx/end_idx still holding the
# values left behind by the training loop.
train_y[start_idx:end_idx][1][1]

2874