From 0ed8610672424e1f9ab562bba276ae35d3ddcca1 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sun, 20 May 2018 14:54:01 +0000
Subject: [PATCH] Make lstm tagger not depend on pytorch

---
 examples/lstm_tagger.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/examples/lstm_tagger.py b/examples/lstm_tagger.py
index baab55317..5d411a97d 100644
--- a/examples/lstm_tagger.py
+++ b/examples/lstm_tagger.py
@@ -20,18 +20,9 @@ from thinc.neural.optimizers import SGD
 from thinc.neural.util import get_array_module
 
-import torch
-import torch.nn
-import torch.autograd
-
 from thinc.extra.datasets import ancora_pos_tags
 
 
-def PyTorchBiLSTM(nO, nI, depth):
-    model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True)
-    return with_square_sequences(PyTorchWrapperRNN(model))
-
-
 def FeatureExtracter(lang, attrs=[LOWER, SHAPE, PREFIX, SUFFIX], tokenized=True):
     nlp = spacy.blank(lang)
     nlp.vocab.lex_attr_getters[PREFIX] = lambda string: string[:3]
@@ -102,14 +93,12 @@ def debug(X, drop=0.):
     L2=("L2 regularization penalty", "option", "L", float),
 )
 def main(width=32, depth=1, vector_length=32,
-         min_batch_size=1, max_batch_size=4, learn_rate=0.001,
+         min_batch_size=16, max_batch_size=16, learn_rate=0.001,
          momentum=0.9, dropout=0.5, dropout_decay=1e-4, nb_epoch=20, L2=1e-6):
     cfg = dict(locals())
     print(cfg)
     train_data, check_data, nr_tag = ancora_pos_tags()
-    train_data = train_data[:100]
-    check_data = check_data[:100]
     extracter = FeatureExtracter('es', attrs=[LOWER, SHAPE, PREFIX, SUFFIX])
     with Model.define_operators({'**': clone, '>>': chain, '+': add,
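
Note: with the PyTorchBiLSTM helper deleted, the example builds the tagger from
thinc layers alone and no longer imports torch, and the debugging caps on the
Ancora data size are lifted. If you still want the PyTorch-backed BiLSTM as an
opt-in, the removed helper can live in user code. A minimal sketch, reusing the
deleted lines; the two thinc import locations are assumptions, since the
original file's import block is outside these hunks:

    import torch.nn
    from thinc.api import with_square_sequences         # assumed import path
    from thinc.extra.wrappers import PyTorchWrapperRNN  # assumed import path

    def PyTorchBiLSTM(nO, nI, depth):
        # Bidirectional LSTM: nO//2 units per direction, so the concatenated
        # forward+backward output has width nO. The wrapper adapts the torch
        # module to thinc's model API, and with_square_sequences handles the
        # padding/unpadding of variable-length sequence batches.
        model = torch.nn.LSTM(nI, nO // 2, depth, bidirectional=True)
        return with_square_sequences(PyTorchWrapperRNN(model))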