In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import bp_pas.params as params
import bp_pas.corpus as corpus
import bp_pas.stats as stats
from bp_pas.eval import evaluate

import sys
import gflags
import tensorflow as tf
import numpy as np

In [3]:
params.load_defaults()
flags = gflags.FLAGS
arg_str = '' #'--train_data single.utf8 --dev_data single.utf8 --test_data single.utf8'
args = [''] + arg_str.split(' ')
argv = flags(args)

In [58]:
# tmp params to add to flags
unk_threshold = 10
unk_token = '<UNK>'
arg_embedding_size = 32
pred_embedding_size = 8


In [5]:
# Load data
print(flags.train_data)
ntc = corpus.NTCLoader()
train_data = ntc.load_corpus(flags.train_data, flags.max_train_instances)
test_data  = ntc.load_corpus(flags.test_data, flags.max_train_instances)
dev_data   = ntc.load_corpus(flags.dev_data, flags.max_train_instances)
print('{} train sentences.'.format(len(train_data)))
print('{} test sentences.'.format(len(test_data)))
print('{} dev sentences.'.format(len(dev_data)))
stats.corpus_statistics(train_data)
stats.show_case_dist(train_data)

data/NTC_1.5/processed/train.utf8
1 train sentences.
1 test sentences.
1 dev sentences.

CORPUS STATISTICS
	Docs: 1  Sents: 41005  Words: 1048292
	Predicates: 117565  Arguments 181613



CASE DISTRIBUTION
	Ga	BST: 357  DEP: 63280  INTRA-ZERO: 21328  INTER-ZERO: 0  EXOPHORA: 0
	O	BST: 283  DEP: 41807  INTRA-ZERO: 3398  INTER-ZERO: 0  EXOPHORA: 0
	Ni	BST: 786  DEP: 13807  INTRA-ZERO: 1034  INTER-ZERO: 0  EXOPHORA: 0

	Predicates: 106708


In [6]:
from collections import Counter
class Vocab:
    
    def __init__(self, init_words=[], unk_token = '<UNK>', unk_threshold=10):
        self.word2int = {}
        self.int2word = []
        self.frozen = False
        self.add(unk_token)
        if len(init_words) > 0:
            self.build(init_words, unk_threshold)

    def add(self, elem):
        assert not self.frozen
        if elem not in self.word2int:
            self.int2word.append(elem)
            self.word2int[elem] = len(self) - 1
    
    def build(self, words, unk_threshold=10):
        counter = Counter()
        for w in words:
            counter[w] += 1
        for w,c in counter.items():
            if c > unk_threshold:
                self.add(w)
    
    def freeze(self):
        self.frozen = True

    def index(self, elem):
        if self.frozen and elem not in self.word2int:
            return 0
        else:
            assert elem in self.word2int
            return self.word2int[elem]
    
    def word(self, index):
        assert index < self.int2word
        return self.int2word[index]
    
    def __len__(self):
        return len(self.int2word)


In [172]:
# Setup vocabulary
all_sents = train_data[0] + test_data[0] + dev_data[0]


print('Collecting all tokens...')
word_tokens = list(word.form
                   for sent in all_sents
                   for word in sent)
print('Total tokens: {}\n'.format(len(word_tokens)))

print('Building (argument) vocabulary...')
arg_vocab = Vocab(init_words=word_tokens, 
                  unk_token=unk_token, 
                  unk_threshold=unk_threshold)  #[unk_token] + get_vocab(word_tokens)
arg_vocab.freeze()
arg_vocabulary_size = len(arg_vocab)
print('Vocabulary size: {}\n'.format(arg_vocabulary_size))

print('Collecting predicates...')
pred_vocab = Vocab(init_words=[word.form
                               for sent in all_sents
                               for word in sent if word.is_prd])
pred_vocab.freeze()
pred_vocabulary_size = len(pred_vocab)
print('Number of predicates: {}\n'.format(pred_vocabulary_size))

print('Collecting argument types...')
arg_types = ['NIL'] + list(set([arg.arg_type 
                      for sent in all_sents 
                      for pas in sent.pas
                      for arg in pas.args]))
#for sent in train_data[0]:
#    for pas in sent.pas:
#        print(pas.pred.word_index)
#        for arg in pas.args:
#            print(arg)
print('Arg types: ', arg_types)
num_types = len(arg_types)

Collecting all tokens...
Total tokens: 1783582

Building (argument) vocabulary...
Vocabulary size: 10676

Collecting predicates...
Number of predicates: 2530

Collecting argument types...
Arg types:  ['NIL', 'O', 'NI', 'GA']


In [271]:
# Convert data structures into numpy ndarrays for placeholders
def sent2nump(sent):
    sent_ids = [arg_vocab.index(w.form) for w in sent]
    pases = [w.form for w in sent if w.is_prd]
    pred_ids = [pred_vocab.index(w.form) for w in sent if w.is_prd]
#    print(pas_ids)
#    print(pases)
    label_mats = []
    for pas in sent.pas:
        lm = label_matrix(pas, arg_types, len(sent))
        label_mats.append(lm)
    lm = np.concatenate(label_mats, axis=1) #.shape
    return [sent_ids], [pred_ids], lm
#        break

def label_matrix(pas, arg_types, sent_len):
    zeros = np.zeros((len(arg_types), sent_len))
    for arg in pas.args:
        zeros[arg_types.index(arg.arg_type), arg.word_index] = 1.0
    for j in range(sent_len):
        found = False
        for i in range(1, len(arg_types)):
            if zeros[i, j] == 1.0:
                found = True
        if not found:
            zeros[0,j] = 1.0
    return zeros







#def pas2matrix(pas):
    

def sent2ints(sent):
    sent_ids = [vocab.index(w.form) for w in sent]
    pas_ids = []
    return sent_ids, pas_ids

#def sents2batch(sents, vocab, batch_size):
#    for sent in sents:
#
#    int_sents = [ for sent in train_data[0]]
#    return np.array([int_sents[0]]), np.array([int_sents[0]])


#sents, labels = sents2batch(train_data, vocab, batch_size=1)
#print(batched_sent_dicts.shape)

d = train_data[0][0]
#print(len(d))
#print(' '.join([w.form for w in d.words]))
#for w in d:
#    if w.is_prd:
#        print(w, '\t', w.arg_types, '\t  ', w.arg_indices)
#sent2nump(d)

#GA_INDEX = 0
#O_INDEX = 1
#NI_INDEX = 2

# pred id 10 has NI and GA
# pred id 25 has GA and O
# pred id 31 has GA

train_dicts = [sent2nump(td) for td in train_data[0][1:5]]
#sent_ids, pred_ids = sent2nump(train_data[0][0])
#print(sent_ids)
#print(pred_ids)

In [342]:
# Setup the context embedding
tf.reset_default_graph()

# Construct embedding matrix for contex/arguments
arg_embeddings = tf.Variable(tf.random_uniform([arg_vocabulary_size, arg_embedding_size], -1.0, 1.0, dtype=tf.float64))

# Construct embedding matrix for predicates
pred_embeddings = tf.Variable(tf.random_uniform([pred_vocabulary_size, num_types * pred_embedding_size], -1.0, 1.0, dtype=tf.float64))

# Setup placeholders
batch_size = 1
sent_placeholder = tf.placeholder(tf.int32, shape=(batch_size, None))
pred_placeholder = tf.placeholder(tf.int32, shape=(batch_size, None))
gold_placeholder = tf.placeholder(tf.float64, shape=(num_types, None))

def context_embeddings(sent, arg_embeddings, output_dim):
    # Shape [batch_size, max_sent_len, emb_dim]
    embed_rep = tf.gather(arg_embeddings, sent)
    # List of length max_sent_len, comprising [batch_size, emb_dim] tensors
#    X = tf.unstack(embed_rep, axis=1, num=72)
#    return X
    X = embed_rep
#    return X
    fw_cell = tf.contrib.rnn.LSTMCell(num_units=output_dim/2, state_is_tuple=True)
    bw_cell = tf.contrib.rnn.LSTMCell(num_units=output_dim/2, state_is_tuple=True)
    outputs, states  = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=fw_cell,
        cell_bw=bw_cell,
        dtype=tf.float64,
#        sequence_length=X_lengths,
        inputs=X)
    output_fw, output_bw = outputs
    states_fw, states_bw = states
    final_rep = tf.concat(outputs, 2)
    return final_rep


def pred_full_scoring(preds, pred_embeddings, context_mat):
    # Reshape the context mat to remove the batch dim
    context_tensor = tf.reshape(context_mat, shape=[-1, tf.shape(context_mat)[-1]])

    # Shape [batch_size, num_tokens, pred_embed_dim]
    # actually ,remove batch for now
    pred_mat = tf.gather(pred_embeddings, preds)
    pred_mat = tf.squeeze(pred_mat)

    # Split the pred embeddings into individual pred tensors
    score_tensor = tf.map_fn(lambda x: scoring(x, context_tensor), pred_mat)
    
    # Reshape score tensor for loss tensor
    score_tensor = tf.transpose(tf.reshape(score_tensor, shape=[-1, tf.shape(score_tensor)[-1]]))
    return score_tensor


def scoring(pred_tensor, context_tensor):
#    pred_tensor = tf.reshape(pred_tensor, shape=[1, -1, 4])
    pred_tensor = tf.reshape(pred_tensor, shape=[tf.shape(context_tensor)[-1], -1])
    return tf.matmul(context_tensor, pred_tensor)

#def pred_single_scoring(single_pred_set, single_context, num_slots, num_preds):
#    return single_pred_set
#    single_pred = tf.unstack(single_pred_set, num=num_preds)[0]
#    return tf.transpose(single_pred)
#    return tf.matmul(single_context, tf.transpose(single_pred))

def srl_loss(gold, preds, pred_embeddings, context_mat):
    prediction = pred_full_scoring(pred_placeholder, pred_embeddings, context_rep)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=tf.traspose(prediction), 
                                                   labels=tf.transpose(gold))
    return loss

context_rep = context_embeddings(sent_placeholder, arg_embeddings, output_dim=8)
pred_scores = pred_full_scoring(pred_placeholder, pred_embeddings, context_rep)


TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, ?), dtype=int32) is not an element of this graph.

In [339]:
from bp_pas.eval import evaluate

sess = tf.Session()
sess.run(tf.global_variables_initializer())

opt = tf.train.GradientDescentOptimizer(learning_rate=0.001)
loss = srl_loss(gold_placeholder, pred_placeholder, pred_embeddings, context_rep)
opt_op = opt.minimize(loss)

num_epochs = 10000
for epoch in range(num_epochs):
    epoch_loss = 0
    for train_dict in train_dicts:
        sent_ids, pred_ids, gold_labels = train_dict
#        print(gold_labels.shape)
        _, closs = sess.run([opt_op, loss], feed_dict = {
            sent_placeholder: sent_ids,
            pred_placeholder: pred_ids,
            gold_placeholder: gold_labels
        })
        epoch_loss += closs.mean()
    if epoch % 100 == 0:
        decoded = [pas for pas in decode(sess, sent).pas for sent in train_data[0][:5]]
        gold = [pas for pas in sent.pas for sent in train_data[0][:5]]
        evaluate(decoded, gold)
        print('{0:>4}: {1:.2f}'.format(epoch, epoch_loss))
    

#feed_dict = {sent_placeholder: sent_ids, pred_placeholder: pred_ids}
#context_out = sess.run(context_rep, feed_dict=feed_dict)
#print(context_out.shape)
#out = sess.run(pred_scores, feed_dict=feed_dict)
#print(out)
#print(out.shape)
#print(out[0].shape)

# 1, 9, 8

#want 1, 9, 1



# context mat is 71, 8

# pred mat is 8, 36

# after mult is 71 x 36


Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
   0: 2.84
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 100: 0.72
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 200: 0.49
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 300: 0.42
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 400: 0.39
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 500: 0.36
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 600: 0.33
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 700: 0.31
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 800: 0.29
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
 900: 0.27
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1000: 0.25
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1100: 0.23
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1200: 0.22
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1300: 0.21
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1400: 0.19
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1500: 0.18
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1600: 0.17
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1700: 0.17
Prec = (0/60) = 0.0
Rec  = (0/15) = 0.0
1800: 0.16
Prec = (0/60) = 0.0
Rec  = (0/1

KeyboardInterrupt: 

17
	 0
	 3
	 34
36
	 0
	 0
	 25
40
	 38
	 25
	 68
46
	 28
	 36
	 44
52
	 0
	 0
	 4
54
	 0
	 68
	 0
57
	 69
	 0
	 0
63
	 69
	 0
	 69
69
	 23
	 68
	 36


<bp_pas.ling.sent.Sentence at 0x1906d7908>