Update Keras Example for (Parikh et al, 2016) implementation (#2803)
* bug fixes in keras example
* created contributor agreement
* baseline for Parikh model
* initial version of parikh 2016 implemented
* tested asymmetric models
* fixed grievous error in normalization
* use standard SNLI test file
* begin to rework parikh example
* initial version of running example
* start to document the new version
* start to document the new version
* Update Decompositional Attention.ipynb
* fixed calls to similarity
* updated the README
* import sys package duh
* simplified indexing on mapping word to IDs
* stupid python indent error
* added code from tensorflow/tensorflow#3388 for tf bug workaround
1 parent 405a826 · commit 9faea3f
Showing 5 changed files with 1,250 additions and 361 deletions.
@@ -1,139 +1,197 @@
 from __future__ import division, unicode_literals, print_function
-import spacy
 
 import plac
 from pathlib import Path
+import numpy as np
 import ujson as json
-import numpy
-from keras.utils.np_utils import to_categorical
-
-from spacy_hook import get_embeddings, get_word_ids
-from spacy_hook import create_similarity_pipeline
+from keras.utils import to_categorical
+import plac
+import sys
 
 from keras_decomposable_attention import build_model
+from spacy_hook import get_embeddings, KerasSimilarityShim
 
 try:
     import cPickle as pickle
 except ImportError:
     import pickle
 
+import spacy
+
+# workaround for keras/tensorflow bug
+# see https://github.com/tensorflow/tensorflow/issues/3388
+import os
+import importlib
+from keras import backend as K
+
+def set_keras_backend(backend):
+    if K.backend() != backend:
+        os.environ['KERAS_BACKEND'] = backend
+        importlib.reload(K)
+        assert K.backend() == backend
+    if backend == "tensorflow":
+        K.get_session().close()
+        cfg = K.tf.ConfigProto()
+        cfg.gpu_options.allow_growth = True
+        K.set_session(K.tf.Session(config=cfg))
+        K.clear_session()
+
+set_keras_backend("tensorflow")
 
 
 def train(train_loc, dev_loc, shape, settings):
     train_texts1, train_texts2, train_labels = read_snli(train_loc)
     dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
 
     print("Loading spaCy")
-    nlp = spacy.load('en')
+    nlp = spacy.load('en_vectors_web_lg')
     assert nlp.path is not None
 
+    print("Processing texts...")
+    train_X = create_dataset(nlp, train_texts1, train_texts2, 100, shape[0])
+    dev_X = create_dataset(nlp, dev_texts1, dev_texts2, 100, shape[0])
+
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
-    print("Processing texts...")
-    Xs = []
-    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
-        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
-                         max_length=shape[0],
-                         rnn_encode=settings['gru_encode'],
-                         tree_truncate=settings['tree_truncate']))
-    train_X1, train_X2, dev_X1, dev_X2 = Xs
 
     print(settings)
     model.fit(
-        [train_X1, train_X2],
+        train_X,
         train_labels,
-        validation_data=([dev_X1, dev_X2], dev_labels),
-        nb_epoch=settings['nr_epoch'],
-        batch_size=settings['batch_size'])
+        validation_data = (dev_X, dev_labels),
+        epochs = settings['nr_epoch'],
+        batch_size = settings['batch_size'])
 
     if not (nlp.path / 'similarity').exists():
         (nlp.path / 'similarity').mkdir()
     print("Saving to", nlp.path / 'similarity')
     weights = model.get_weights()
+    # remove the embedding matrix. We can reconstruct it.
+    del weights[1]
     with (nlp.path / 'similarity' / 'model').open('wb') as file_:
-        pickle.dump(weights[1:], file_)
-    with (nlp.path / 'similarity' / 'config.json').open('wb') as file_:
+        pickle.dump(weights, file_)
+    with (nlp.path / 'similarity' / 'config.json').open('w') as file_:
         file_.write(model.to_json())
 
 
-def evaluate(dev_loc):
+def evaluate(dev_loc, shape):
     dev_texts1, dev_texts2, dev_labels = read_snli(dev_loc)
-    nlp = spacy.load('en',
-        create_pipeline=create_similarity_pipeline)
+    nlp = spacy.load('en_vectors_web_lg')
+    nlp.add_pipe(KerasSimilarityShim.load(nlp.path / 'similarity', nlp, shape[0]))
 
     total = 0.
     correct = 0.
     for text1, text2, label in zip(dev_texts1, dev_texts2, dev_labels):
         doc1 = nlp(text1)
         doc2 = nlp(text2)
-        sim = doc1.similarity(doc2)
-        if sim.argmax() == label.argmax():
+        sim, _ = doc1.similarity(doc2)
+        if sim == KerasSimilarityShim.entailment_types[label.argmax()]:
             correct += 1
         total += 1
     return correct, total
 
 
-def demo():
-    nlp = spacy.load('en',
-        create_pipeline=create_similarity_pipeline)
-    doc1 = nlp(u'What were the best crime fiction books in 2016?')
-    doc2 = nlp(
-        u'What should I read that was published last year? I like crime stories.')
-    print(doc1)
-    print(doc2)
-    print("Similarity", doc1.similarity(doc2))
+def demo(shape):
+    nlp = spacy.load('en_vectors_web_lg')
+    nlp.add_pipe(KerasSimilarityShim.load(nlp.path / 'similarity', nlp, shape[0]))
+
+    doc1 = nlp(u'The king of France is bald.')
+    doc2 = nlp(u'France has no king.')
+
+    print("Sentence 1:", doc1)
+    print("Sentence 2:", doc2)
+
+    entailment_type, confidence = doc1.similarity(doc2)
+    print("Entailment type:", entailment_type, "(Confidence:", confidence, ")")
 
 
 LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
 def read_snli(path):
     texts1 = []
     texts2 = []
     labels = []
-    with path.open() as file_:
+    with open(path, 'r') as file_:
         for line in file_:
             eg = json.loads(line)
             label = eg['gold_label']
-            if label == '-':
+            if label == '-':  # per Parikh, ignore - SNLI entries
                 continue
             texts1.append(eg['sentence1'])
             texts2.append(eg['sentence2'])
             labels.append(LABELS[label])
-    return texts1, texts2, to_categorical(numpy.asarray(labels, dtype='int32'))
+    return texts1, texts2, to_categorical(np.asarray(labels, dtype='int32'))
 
+def create_dataset(nlp, texts, hypotheses, num_unk, max_length):
+    sents = texts + hypotheses
+
+    sents_as_ids = []
+    for sent in sents:
+        doc = nlp(sent)
+        word_ids = []
+
+        for i, token in enumerate(doc):
+            # skip odd spaces from tokenizer
+            if token.has_vector and token.vector_norm == 0:
+                continue
+
+            if i > max_length:
+                break
+
+            if token.has_vector:
+                word_ids.append(token.rank + num_unk + 1)
+            else:
+                # if we don't have a vector, pick an OOV entry
+                word_ids.append(token.rank % num_unk + 1)
+
+        # there must be a simpler way of generating padded arrays from lists...
+        word_id_vec = np.zeros((max_length), dtype='int')
+        clipped_len = min(max_length, len(word_ids))
+        word_id_vec[:clipped_len] = word_ids[:clipped_len]
+        sents_as_ids.append(word_id_vec)
+
+
+    return [np.array(sents_as_ids[:len(texts)]), np.array(sents_as_ids[len(texts):])]
+
 
 @plac.annotations(
     mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
-    train_loc=("Path to training data", "positional", None, Path),
-    dev_loc=("Path to development data", "positional", None, Path),
+    train_loc=("Path to training data", "option", "t", str),
+    dev_loc=("Path to development or test data", "option", "s", str),
     max_length=("Length to truncate sentences", "option", "L", int),
     nr_hidden=("Number of hidden units", "option", "H", int),
     dropout=("Dropout level", "option", "d", float),
-    learn_rate=("Learning rate", "option", "e", float),
+    learn_rate=("Learning rate", "option", "r", float),
     batch_size=("Batch size for neural network training", "option", "b", int),
-    nr_epoch=("Number of training epochs", "option", "i", int),
-    tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
-    gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
+    nr_epoch=("Number of training epochs", "option", "e", int),
+    entail_dir=("Direction of entailment", "option", "D", str, ["both", "left", "right"])
 )
 def main(mode, train_loc, dev_loc,
-        tree_truncate=False,
-        gru_encode=False,
-        max_length=100,
-        nr_hidden=100,
-        dropout=0.2,
-        learn_rate=0.001,
-        batch_size=100,
-        nr_epoch=5):
+        max_length = 50,
+        nr_hidden = 200,
+        dropout = 0.2,
+        learn_rate = 0.001,
+        batch_size = 1024,
+        nr_epoch = 10,
+        entail_dir="both"):
+
     shape = (max_length, nr_hidden, 3)
     settings = {
         'lr': learn_rate,
         'dropout': dropout,
         'batch_size': batch_size,
         'nr_epoch': nr_epoch,
-        'tree_truncate': tree_truncate,
-        'gru_encode': gru_encode
+        'entail_dir': entail_dir
     }
 
     if mode == 'train':
+        if train_loc == None or dev_loc == None:
+            print("Train mode requires paths to training and development data sets.")
+            sys.exit(1)
         train(train_loc, dev_loc, shape, settings)
     elif mode == 'evaluate':
-        correct, total = evaluate(dev_loc)
+        if dev_loc == None:
+            print("Evaluate mode requires paths to test data set.")
+            sys.exit(1)
+        correct, total = evaluate(dev_loc, shape)
         print(correct, '/', total, correct / total)
     else:
-        demo()
+        demo(shape)
 
 if __name__ == '__main__':
     plac.call(main)
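
As a side note on the data handling above: read_snli() keeps the three SNLI gold labels, skips the unlabelled '-' entries, and one-hot encodes the rest. Below is a minimal sketch of that encoding step; the example sentence pair is invented, while the LABELS mapping and the to_categorical call are taken from the diff.

```python
import numpy as np
from keras.utils import to_categorical

LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}

# One hypothetical SNLI record, as it would look after json.loads();
# records whose gold_label is '-' are skipped by read_snli().
eg = {'sentence1': 'A man is playing a guitar.',
      'sentence2': 'A person is making music.',
      'gold_label': 'entailment'}

labels = [LABELS[eg['gold_label']]]  # -> [0]

# read_snli() calls to_categorical() without num_classes; passing it here
# just makes the (1, 3) one-hot shape explicit for a single example.
y = to_categorical(np.asarray(labels, dtype='int32'), num_classes=3)
print(y)  # [[1. 0. 0.]]
```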