In [None]:
import os
import random
import sys

# Make the project-local modules under ../src importable from this notebook.
sys.path.append("../src/")

# Set before torch/allennlp are imported, as in the original cell order.
# NOTE(review): presumably the usual workaround for duplicate OpenMP runtimes — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from allennlp.common.file_utils import cached_path
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.data.iterators import BucketIterator
from allennlp.predictors import SentenceTaggerPredictor
from allennlp.training.trainer import Trainer

import torch
import torch.optim as optim

In [None]:
from multi_criterion_dataset_reader import MultiCriterionDatasetReader
from multi_criterion_tokenizer import MultiCriterionTokenizer
from tokenize_predictor import TokenizePredictor

In [None]:
# Train a small MultiCriterionTokenizer on the sample dataset.
#
# Produces the notebook-level names `reader`, `vocab`, `model` and `trainer`;
# later cells use `model` and `reader` for prediction.

# --- Reproducibility ---------------------------------------------------------
# Seed the RNGs *before* any weights are created or batches are drawn so that
# Restart Kernel -> Run All trains the same model every time.
# torch.manual_seed covers parameter initialisation; random.seed is for
# Python-level shuffling (presumably what the AllenNLP iterator uses — confirm).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# --- Configuration -----------------------------------------------------------
TRAIN_PATH = '../data/sample/train_dataset.json'
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
NUM_CRITERIA = 7  # FIXME: num embeddings (hard-coded; should be derived from the data)
LEARNING_RATE = 0.1
BATCH_SIZE = 1
PATIENCE = 10
NUM_EPOCHS = 100
CUDA_DEVICE = -1  # -1 selects the CPU

# --- Data --------------------------------------------------------------------
reader = MultiCriterionDatasetReader()
train_dataset = reader.read(cached_path(TRAIN_PATH))
vocab = Vocabulary.from_instances(train_dataset)

# --- Model -------------------------------------------------------------------
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
criterion_embedding = Embedding(num_embeddings=NUM_CRITERIA,
                                embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model = MultiCriterionTokenizer(word_embeddings, criterion_embedding, lstm, vocab)

# --- Training ----------------------------------------------------------------
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Bucket instances by the token count of the "sentence" field.
iterator = BucketIterator(batch_size=BATCH_SIZE,
                          sorting_keys=[("sentence", "num_tokens")])
iterator.index_with(vocab)

# NOTE(review): no validation_dataset is passed, so `patience` presumably has
# nothing to early-stop on and all NUM_EPOCHS epochs will run — confirm.
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  patience=PATIENCE,
                  num_epochs=NUM_EPOCHS,
                  cuda_device=CUDA_DEVICE)
trainer.train()


In [None]:
# Smoke-test the trained model on a single example.
# NOTE(review): the second argument is presumably the name of the segmentation
# criterion to apply — confirm against TokenizePredictor.predict.
sample_text = "選挙管理委員会"
sample_criterion = "sudachib"

predictor = TokenizePredictor(model, dataset_reader=reader)

# Bare last expression so the notebook renders the prediction as cell output.
predictor.predict(sample_text, sample_criterion)