In [1]:
import random
from typing import Iterator, List, Dict

import torch
import torch.optim as optim
import numpy as np

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

# Seed every RNG the pipeline can draw from, not only torch: AllenNLP's data
# iterators (e.g. BucketIterator's batch shuffling) use Python's `random`, so
# seeding torch alone does not make training runs reproducible.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f81b00250f0>

In [2]:
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN

    Each whitespace-separated item is a ``word###TAG`` pair. Blank lines are skipped.
    """
    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        # Default: one integer id per token, stored under the "tokens" key.
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        """
        Build an ``Instance`` with a ``"sentence"`` TextField and, when ``tags`` is
        non-empty, a ``"labels"`` SequenceLabelField aligned with the sentence.
        Passing ``tags=None`` produces an unlabeled instance (e.g. for prediction).
        """
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield one tagged ``Instance`` per non-blank line of ``file_path``.

        Raises ``ValueError`` if an item on a line is not of the form ``word###TAG``.
        """
        with open(file_path, encoding="utf-8") as data_file:
            for line_number, line in enumerate(data_file, start=1):
                pairs = line.strip().split()
                if not pairs:
                    # A blank line (e.g. a trailing newline) would otherwise make
                    # the `sentence, tags = zip(...)` unpacking below fail.
                    continue
                # Split on the *last* separator so the tag is always the final field.
                word_tag_pairs = [pair.rsplit("###", 1) for pair in pairs]
                for pair, parts in zip(pairs, word_tag_pairs):
                    if len(parts) != 2:
                        raise ValueError(
                            f"{file_path}:{line_number}: expected 'word###TAG', got {pair!r}")
                sentence, tags = zip(*word_tag_pairs)
                yield self.text_to_instance([Token(word) for word in sentence], list(tags))

In [3]:
class LstmTagger(Model):
    """
    Token-level tagger: embed the sentence, encode it with a Seq2SeqEncoder and
    project every encoded position onto the "labels" vocabulary.
    """
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        encoder_dim = encoder.get_output_dim()
        num_tags = vocab.get_vocab_size('labels')
        # One logit per tag for each encoded token.
        self.hidden2tag = torch.nn.Linear(in_features=encoder_dim,
                                          out_features=num_tags)
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Return ``{"tag_logits": ...}``. When ``labels`` is given, also update the
        accuracy metric and add a masked cross-entropy ``"loss"`` entry.
        """
        token_mask = get_text_field_mask(sentence)
        encoded = self.encoder(self.word_embeddings(sentence), token_mask)
        logits = self.hidden2tag(encoded)
        result = {"tag_logits": logits}
        if labels is None:
            return result

        self.accuracy(logits, labels, token_mask)
        result["loss"] = sequence_cross_entropy_with_logits(logits, labels, token_mask)
        return result

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report token accuracy, optionally resetting the running counts."""
        current_accuracy = self.accuracy.get_metric(reset)
        return {"accuracy": current_accuracy}

In [4]:
# Both tutorial splits live side by side in the AllenNLP repository.
TAGGER_DATA_URL = ('https://raw.githubusercontent.com/allenai/allennlp'
                   '/master/tutorials/tagger/')

reader = PosDatasetReader()
train_dataset = reader.read(cached_path(TAGGER_DATA_URL + 'training.txt'))
validation_dataset = reader.read(cached_path(TAGGER_DATA_URL + 'validation.txt'))
# Build the vocabulary from the instances of both splits.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)


93B [00:00, 81553.48B/s]             
2it [00:00, 8839.42it/s]
93B [00:00, 83724.03B/s]             
2it [00:00, 7281.78it/s]
100%|██████████| 4/4 [00:00<00:00, 30066.70it/s]


In [5]:
# Model / optimisation hyper-parameters. Kept tiny because the toy training
# set has only two sentences (see the reader's "2it" progress output above).
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

# NOTE(review): the construction order below (Embedding, then LSTM, then the
# Linear layer inside LstmTagger) decides how the seeded torch RNG is consumed
# for weight initialisation; reordering these lines changes the trained model.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
# Keyed by "tokens" to match the reader's default token-indexer name.
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
# AllenNLP's PytorchSeq2SeqWrapper expects batch-first recurrent modules.
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model = LstmTagger(word_embeddings, lstm, vocab)

# Use GPU 0 when available; cuda_device=-1 makes the Trainer run on CPU.
if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1

optimizer = optim.SGD(model.parameters(), lr=0.1)
# Group instances of similar sentence length into the same batch to reduce padding.
iterator = BucketIterator(batch_size=2, sorting_keys=[("sentence", "num_tokens")])
# The iterator needs the vocabulary to turn tokens/labels into ids.
iterator.index_with(vocab)

# Train for at most 1000 epochs; patience=10 stops early once the validation
# metric (the Trainer's default) has not improved for 10 consecutive epochs.
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=1000,
                  cuda_device=cuda_device)

trainer.train()
# The predictor tokenizes the raw string itself (presumably with spaCy by
# default, which would explain the en_core_web_sm download in this cell's
# output — TODO confirm) and runs the trained model on it.
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
tag_logits = predictor.predict("The dog ate the apple")['tag_logits']
# Highest-scoring tag id for each token.
tag_ids = np.argmax(tag_logits, axis=-1)

accuracy: 0.3333, loss: 1.1685 ||: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
accuracy: 0.3333, loss: 1.1592 ||: 100%|██████████| 1/1 [00:00<00:00, 36.34it/s]
accuracy: 0.3333, loss: 1.1604 ||: 100%|██████████| 1/1 [00:00<00:00, 62.36it/s]
accuracy: 0.3333, loss: 1.1516 ||: 100%|██████████| 1/1 [00:00<00:00, 289.22it/s]
accuracy: 0.3333, loss: 1.1529 ||: 100%|██████████| 1/1 [00:00<00:00, 45.65it/s]
accuracy: 0.3333, loss: 1.1445 ||: 100%|██████████| 1/1 [00:00<00:00, 299.76it/s]
accuracy: 0.3333, loss: 1.1458 ||: 100%|██████████| 1/1 [00:00<00:00, 52.91it/s]
accuracy: 0.3333, loss: 1.1379 ||: 100%|██████████| 1/1 [00:00<00:00, 233.57it/s]
accuracy: 0.3333, loss: 1.1391 ||: 100%|██████████| 1/1 [00:00<00:00, 48.10it/s]
accuracy: 0.3333, loss: 1.1316 ||: 100%|██████████| 1/1 [00:00<00:00, 74.05it/s]
accuracy: 0.3333, loss: 1.1329 ||: 100%|██████████| 1/1 [00:00<00:00, 48.44it/s]
accuracy: 0.3333, loss: 1.1259 ||: 100%|██████████| 1/1 [00:00<00:00, 314.96it/s]
accuracy: 0.3333, loss: 

accuracy: 0.4444, loss: 1.0534 ||: 100%|██████████| 1/1 [00:00<00:00, 52.95it/s]
accuracy: 0.4444, loss: 1.0517 ||: 100%|██████████| 1/1 [00:00<00:00, 329.22it/s]
accuracy: 0.4444, loss: 1.0531 ||: 100%|██████████| 1/1 [00:00<00:00, 67.19it/s]
accuracy: 0.4444, loss: 1.0513 ||: 100%|██████████| 1/1 [00:00<00:00, 234.32it/s]
accuracy: 0.4444, loss: 1.0528 ||: 100%|██████████| 1/1 [00:00<00:00, 56.69it/s]
accuracy: 0.4444, loss: 1.0510 ||: 100%|██████████| 1/1 [00:00<00:00, 349.64it/s]
accuracy: 0.4444, loss: 1.0524 ||: 100%|██████████| 1/1 [00:00<00:00, 61.46it/s]
accuracy: 0.4444, loss: 1.0507 ||: 100%|██████████| 1/1 [00:00<00:00, 126.48it/s]
accuracy: 0.4444, loss: 1.0521 ||: 100%|██████████| 1/1 [00:00<00:00, 66.02it/s]
accuracy: 0.4444, loss: 1.0504 ||: 100%|██████████| 1/1 [00:00<00:00, 315.12it/s]
accuracy: 0.4444, loss: 1.0518 ||: 100%|██████████| 1/1 [00:00<00:00, 48.39it/s]
accuracy: 0.4444, loss: 1.0501 ||: 100%|██████████| 1/1 [00:00<00:00, 327.04it/s]
accuracy: 0.4444, loss

accuracy: 0.4444, loss: 1.0370 ||: 100%|██████████| 1/1 [00:00<00:00, 74.03it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 120.06it/s]
accuracy: 0.4444, loss: 1.0366 ||: 100%|██████████| 1/1 [00:00<00:00, 74.65it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 114.04it/s]
accuracy: 0.4444, loss: 1.0362 ||: 100%|██████████| 1/1 [00:00<00:00, 53.86it/s]
accuracy: 0.4444, loss: 1.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 423.50it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 54.63it/s]
accuracy: 0.4444, loss: 1.0341 ||: 100%|██████████| 1/1 [00:00<00:00, 450.27it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 55.44it/s]
accuracy: 0.4444, loss: 1.0336 ||: 100%|██████████| 1/1 [00:00<00:00, 445.59it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 60.64it/s]
accuracy: 0.4444, loss: 1.0332 ||: 100%|██████████| 1/1 [00:00<00:00, 315.65it/s]
accuracy: 0.4444, loss

accuracy: 0.4444, loss: 1.0061 ||: 100%|██████████| 1/1 [00:00<00:00, 87.54it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 337.19it/s]
accuracy: 0.4444, loss: 1.0051 ||: 100%|██████████| 1/1 [00:00<00:00, 142.12it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 329.97it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 87.62it/s]
accuracy: 0.4444, loss: 1.0024 ||: 100%|██████████| 1/1 [00:00<00:00, 331.43it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 144.66it/s]
accuracy: 0.4444, loss: 1.0014 ||: 100%|██████████| 1/1 [00:00<00:00, 332.20it/s]
accuracy: 0.4444, loss: 1.0023 ||: 100%|██████████| 1/1 [00:00<00:00, 128.11it/s]
accuracy: 0.4444, loss: 1.0005 ||: 100%|██████████| 1/1 [00:00<00:00, 333.78it/s]
accuracy: 0.4444, loss: 1.0014 ||: 100%|██████████| 1/1 [00:00<00:00, 118.44it/s]
accuracy: 0.4444, loss: 0.9995 ||: 100%|██████████| 1/1 [00:00<00:00, 331.72it/s]
accuracy: 0.4444, 

accuracy: 0.4444, loss: 0.9355 ||: 100%|██████████| 1/1 [00:00<00:00, 155.54it/s]
accuracy: 0.4444, loss: 0.9329 ||: 100%|██████████| 1/1 [00:00<00:00, 312.08it/s]
accuracy: 0.4444, loss: 0.9334 ||: 100%|██████████| 1/1 [00:00<00:00, 174.80it/s]
accuracy: 0.4444, loss: 0.9307 ||: 100%|██████████| 1/1 [00:00<00:00, 443.65it/s]
accuracy: 0.4444, loss: 0.9313 ||: 100%|██████████| 1/1 [00:00<00:00, 171.71it/s]
accuracy: 0.4444, loss: 0.9286 ||: 100%|██████████| 1/1 [00:00<00:00, 442.20it/s]
accuracy: 0.4444, loss: 0.9292 ||: 100%|██████████| 1/1 [00:00<00:00, 167.77it/s]
accuracy: 0.4444, loss: 0.9264 ||: 100%|██████████| 1/1 [00:00<00:00, 309.02it/s]
accuracy: 0.4444, loss: 0.9270 ||: 100%|██████████| 1/1 [00:00<00:00, 128.98it/s]
accuracy: 0.4444, loss: 0.9242 ||: 100%|██████████| 1/1 [00:00<00:00, 310.05it/s]
accuracy: 0.5556, loss: 0.9248 ||: 100%|██████████| 1/1 [00:00<00:00, 104.71it/s]
accuracy: 0.4444, loss: 0.9220 ||: 100%|██████████| 1/1 [00:00<00:00, 236.63it/s]
accuracy: 0.5556

accuracy: 0.6667, loss: 0.7936 ||: 100%|██████████| 1/1 [00:00<00:00, 54.50it/s]
accuracy: 0.6667, loss: 0.7903 ||: 100%|██████████| 1/1 [00:00<00:00, 113.00it/s]
accuracy: 0.6667, loss: 0.7901 ||: 100%|██████████| 1/1 [00:00<00:00, 53.66it/s]
accuracy: 0.6667, loss: 0.7868 ||: 100%|██████████| 1/1 [00:00<00:00, 115.29it/s]
accuracy: 0.6667, loss: 0.7867 ||: 100%|██████████| 1/1 [00:00<00:00, 47.40it/s]
accuracy: 0.6667, loss: 0.7833 ||: 100%|██████████| 1/1 [00:00<00:00, 114.94it/s]
accuracy: 0.6667, loss: 0.7831 ||: 100%|██████████| 1/1 [00:00<00:00, 54.24it/s]
accuracy: 0.6667, loss: 0.7798 ||: 100%|██████████| 1/1 [00:00<00:00, 113.11it/s]
accuracy: 0.6667, loss: 0.7796 ||: 100%|██████████| 1/1 [00:00<00:00, 55.11it/s]
accuracy: 0.6667, loss: 0.7763 ||: 100%|██████████| 1/1 [00:00<00:00, 115.55it/s]
accuracy: 0.6667, loss: 0.7761 ||: 100%|██████████| 1/1 [00:00<00:00, 73.68it/s]
accuracy: 0.6667, loss: 0.7728 ||: 100%|██████████| 1/1 [00:00<00:00, 184.84it/s]
accuracy: 0.6667, loss

accuracy: 0.7778, loss: 0.6130 ||: 100%|██████████| 1/1 [00:00<00:00, 66.95it/s]
accuracy: 0.7778, loss: 0.6110 ||: 100%|██████████| 1/1 [00:00<00:00, 114.32it/s]
accuracy: 0.7778, loss: 0.6093 ||: 100%|██████████| 1/1 [00:00<00:00, 67.21it/s]
accuracy: 0.7778, loss: 0.6074 ||: 100%|██████████| 1/1 [00:00<00:00, 115.33it/s]
accuracy: 0.7778, loss: 0.6057 ||: 100%|██████████| 1/1 [00:00<00:00, 52.73it/s]
accuracy: 0.7778, loss: 0.6038 ||: 100%|██████████| 1/1 [00:00<00:00, 115.34it/s]
accuracy: 0.7778, loss: 0.6021 ||: 100%|██████████| 1/1 [00:00<00:00, 52.99it/s]
accuracy: 0.7778, loss: 0.6002 ||: 100%|██████████| 1/1 [00:00<00:00, 115.58it/s]
accuracy: 0.7778, loss: 0.5985 ||: 100%|██████████| 1/1 [00:00<00:00, 48.90it/s]
accuracy: 0.7778, loss: 0.5966 ||: 100%|██████████| 1/1 [00:00<00:00, 110.62it/s]
accuracy: 0.7778, loss: 0.5949 ||: 100%|██████████| 1/1 [00:00<00:00, 62.71it/s]
accuracy: 0.7778, loss: 0.5930 ||: 100%|██████████| 1/1 [00:00<00:00, 115.43it/s]
accuracy: 0.7778, loss

accuracy: 1.0000, loss: 0.4350 ||: 100%|██████████| 1/1 [00:00<00:00, 60.61it/s]
accuracy: 1.0000, loss: 0.4337 ||: 100%|██████████| 1/1 [00:00<00:00, 227.32it/s]
accuracy: 1.0000, loss: 0.4316 ||: 100%|██████████| 1/1 [00:00<00:00, 53.72it/s]
accuracy: 1.0000, loss: 0.4303 ||: 100%|██████████| 1/1 [00:00<00:00, 232.04it/s]
accuracy: 1.0000, loss: 0.4283 ||: 100%|██████████| 1/1 [00:00<00:00, 73.42it/s]
accuracy: 1.0000, loss: 0.4269 ||: 100%|██████████| 1/1 [00:00<00:00, 398.51it/s]
accuracy: 1.0000, loss: 0.4249 ||: 100%|██████████| 1/1 [00:00<00:00, 74.69it/s]
accuracy: 1.0000, loss: 0.4235 ||: 100%|██████████| 1/1 [00:00<00:00, 423.11it/s]
accuracy: 1.0000, loss: 0.4216 ||: 100%|██████████| 1/1 [00:00<00:00, 42.03it/s]
accuracy: 1.0000, loss: 0.4202 ||: 100%|██████████| 1/1 [00:00<00:00, 420.86it/s]
accuracy: 1.0000, loss: 0.4182 ||: 100%|██████████| 1/1 [00:00<00:00, 76.13it/s]
accuracy: 1.0000, loss: 0.4168 ||: 100%|██████████| 1/1 [00:00<00:00, 67.01it/s]
accuracy: 1.0000, loss:

accuracy: 1.0000, loss: 0.2855 ||: 100%|██████████| 1/1 [00:00<00:00, 41.11it/s]
accuracy: 1.0000, loss: 0.2838 ||: 100%|██████████| 1/1 [00:00<00:00, 171.51it/s]
accuracy: 1.0000, loss: 0.2830 ||: 100%|██████████| 1/1 [00:00<00:00, 42.29it/s]
accuracy: 1.0000, loss: 0.2813 ||: 100%|██████████| 1/1 [00:00<00:00, 146.43it/s]
accuracy: 1.0000, loss: 0.2806 ||: 100%|██████████| 1/1 [00:00<00:00, 45.45it/s]
accuracy: 1.0000, loss: 0.2788 ||: 100%|██████████| 1/1 [00:00<00:00, 144.70it/s]
accuracy: 1.0000, loss: 0.2781 ||: 100%|██████████| 1/1 [00:00<00:00, 65.73it/s]
accuracy: 1.0000, loss: 0.2764 ||: 100%|██████████| 1/1 [00:00<00:00, 279.43it/s]
accuracy: 1.0000, loss: 0.2757 ||: 100%|██████████| 1/1 [00:00<00:00, 42.28it/s]
accuracy: 1.0000, loss: 0.2739 ||: 100%|██████████| 1/1 [00:00<00:00, 144.51it/s]
accuracy: 1.0000, loss: 0.2732 ||: 100%|██████████| 1/1 [00:00<00:00, 40.78it/s]
accuracy: 1.0000, loss: 0.2715 ||: 100%|██████████| 1/1 [00:00<00:00, 162.96it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.1846 ||: 100%|██████████| 1/1 [00:00<00:00, 108.78it/s]
accuracy: 1.0000, loss: 0.1831 ||: 100%|██████████| 1/1 [00:00<00:00, 419.26it/s]
accuracy: 1.0000, loss: 0.1830 ||: 100%|██████████| 1/1 [00:00<00:00, 86.65it/s]
accuracy: 1.0000, loss: 0.1816 ||: 100%|██████████| 1/1 [00:00<00:00, 427.16it/s]
accuracy: 1.0000, loss: 0.1815 ||: 100%|██████████| 1/1 [00:00<00:00, 93.65it/s]
accuracy: 1.0000, loss: 0.1800 ||: 100%|██████████| 1/1 [00:00<00:00, 356.75it/s]
accuracy: 1.0000, loss: 0.1800 ||: 100%|██████████| 1/1 [00:00<00:00, 93.66it/s]
accuracy: 1.0000, loss: 0.1785 ||: 100%|██████████| 1/1 [00:00<00:00, 374.52it/s]
accuracy: 1.0000, loss: 0.1784 ||: 100%|██████████| 1/1 [00:00<00:00, 140.17it/s]
accuracy: 1.0000, loss: 0.1770 ||: 100%|██████████| 1/1 [00:00<00:00, 380.44it/s]
accuracy: 1.0000, loss: 0.1770 ||: 100%|██████████| 1/1 [00:00<00:00, 101.69it/s]
accuracy: 1.0000, loss: 0.1755 ||: 100%|██████████| 1/1 [00:00<00:00, 239.91it/s]
accuracy: 1.0000, l

accuracy: 1.0000, loss: 0.1239 ||: 100%|██████████| 1/1 [00:00<00:00, 66.55it/s]
accuracy: 1.0000, loss: 0.1228 ||: 100%|██████████| 1/1 [00:00<00:00, 254.28it/s]
accuracy: 1.0000, loss: 0.1229 ||: 100%|██████████| 1/1 [00:00<00:00, 64.32it/s]
accuracy: 1.0000, loss: 0.1219 ||: 100%|██████████| 1/1 [00:00<00:00, 234.65it/s]
accuracy: 1.0000, loss: 0.1220 ||: 100%|██████████| 1/1 [00:00<00:00, 54.90it/s]
accuracy: 1.0000, loss: 0.1210 ||: 100%|██████████| 1/1 [00:00<00:00, 189.85it/s]
accuracy: 1.0000, loss: 0.1211 ||: 100%|██████████| 1/1 [00:00<00:00, 67.03it/s]
accuracy: 1.0000, loss: 0.1201 ||: 100%|██████████| 1/1 [00:00<00:00, 339.34it/s]
accuracy: 1.0000, loss: 0.1203 ||: 100%|██████████| 1/1 [00:00<00:00, 74.89it/s]
accuracy: 1.0000, loss: 0.1193 ||: 100%|██████████| 1/1 [00:00<00:00, 360.43it/s]
accuracy: 1.0000, loss: 0.1194 ||: 100%|██████████| 1/1 [00:00<00:00, 74.17it/s]
accuracy: 1.0000, loss: 0.1184 ||: 100%|██████████| 1/1 [00:00<00:00, 348.51it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0880 ||: 100%|██████████| 1/1 [00:00<00:00, 65.43it/s]
accuracy: 1.0000, loss: 0.0873 ||: 100%|██████████| 1/1 [00:00<00:00, 296.50it/s]
accuracy: 1.0000, loss: 0.0874 ||: 100%|██████████| 1/1 [00:00<00:00, 65.57it/s]
accuracy: 1.0000, loss: 0.0868 ||: 100%|██████████| 1/1 [00:00<00:00, 301.53it/s]
accuracy: 1.0000, loss: 0.0869 ||: 100%|██████████| 1/1 [00:00<00:00, 65.98it/s]
accuracy: 1.0000, loss: 0.0862 ||: 100%|██████████| 1/1 [00:00<00:00, 289.04it/s]
accuracy: 1.0000, loss: 0.0864 ||: 100%|██████████| 1/1 [00:00<00:00, 65.71it/s]
accuracy: 1.0000, loss: 0.0857 ||: 100%|██████████| 1/1 [00:00<00:00, 335.68it/s]
accuracy: 1.0000, loss: 0.0858 ||: 100%|██████████| 1/1 [00:00<00:00, 65.64it/s]
accuracy: 1.0000, loss: 0.0852 ||: 100%|██████████| 1/1 [00:00<00:00, 283.15it/s]
accuracy: 1.0000, loss: 0.0853 ||: 100%|██████████| 1/1 [00:00<00:00, 65.73it/s]
accuracy: 1.0000, loss: 0.0846 ||: 100%|██████████| 1/1 [00:00<00:00, 292.76it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0660 ||: 100%|██████████| 1/1 [00:00<00:00, 130.22it/s]
accuracy: 1.0000, loss: 0.0655 ||: 100%|██████████| 1/1 [00:00<00:00, 328.63it/s]
accuracy: 1.0000, loss: 0.0656 ||: 100%|██████████| 1/1 [00:00<00:00, 72.42it/s]
accuracy: 1.0000, loss: 0.0652 ||: 100%|██████████| 1/1 [00:00<00:00, 296.06it/s]
accuracy: 1.0000, loss: 0.0653 ||: 100%|██████████| 1/1 [00:00<00:00, 102.09it/s]
accuracy: 1.0000, loss: 0.0648 ||: 100%|██████████| 1/1 [00:00<00:00, 274.48it/s]
accuracy: 1.0000, loss: 0.0650 ||: 100%|██████████| 1/1 [00:00<00:00, 92.72it/s]
accuracy: 1.0000, loss: 0.0645 ||: 100%|██████████| 1/1 [00:00<00:00, 277.58it/s]
accuracy: 1.0000, loss: 0.0646 ||: 100%|██████████| 1/1 [00:00<00:00, 83.59it/s]
accuracy: 1.0000, loss: 0.0642 ||: 100%|██████████| 1/1 [00:00<00:00, 275.87it/s]
accuracy: 1.0000, loss: 0.0643 ||: 100%|██████████| 1/1 [00:00<00:00, 115.28it/s]
accuracy: 1.0000, loss: 0.0638 ||: 100%|██████████| 1/1 [00:00<00:00, 256.75it/s]
accuracy: 1.0000, l

accuracy: 1.0000, loss: 0.0517 ||: 100%|██████████| 1/1 [00:00<00:00, 92.94it/s]
accuracy: 1.0000, loss: 0.0514 ||: 100%|██████████| 1/1 [00:00<00:00, 235.15it/s]
accuracy: 1.0000, loss: 0.0515 ||: 100%|██████████| 1/1 [00:00<00:00, 85.01it/s]
accuracy: 1.0000, loss: 0.0511 ||: 100%|██████████| 1/1 [00:00<00:00, 277.51it/s]
accuracy: 1.0000, loss: 0.0512 ||: 100%|██████████| 1/1 [00:00<00:00, 80.47it/s]
accuracy: 1.0000, loss: 0.0509 ||: 100%|██████████| 1/1 [00:00<00:00, 214.89it/s]
accuracy: 1.0000, loss: 0.0510 ||: 100%|██████████| 1/1 [00:00<00:00, 87.35it/s]
accuracy: 1.0000, loss: 0.0507 ||: 100%|██████████| 1/1 [00:00<00:00, 284.84it/s]
accuracy: 1.0000, loss: 0.0508 ||: 100%|██████████| 1/1 [00:00<00:00, 82.26it/s]
accuracy: 1.0000, loss: 0.0505 ||: 100%|██████████| 1/1 [00:00<00:00, 214.16it/s]
accuracy: 1.0000, loss: 0.0506 ||: 100%|██████████| 1/1 [00:00<00:00, 130.63it/s]
accuracy: 1.0000, loss: 0.0502 ||: 100%|██████████| 1/1 [00:00<00:00, 233.28it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0419 ||: 100%|██████████| 1/1 [00:00<00:00, 66.20it/s]
accuracy: 1.0000, loss: 0.0417 ||: 100%|██████████| 1/1 [00:00<00:00, 163.67it/s]
accuracy: 1.0000, loss: 0.0418 ||: 100%|██████████| 1/1 [00:00<00:00, 55.00it/s]
accuracy: 1.0000, loss: 0.0415 ||: 100%|██████████| 1/1 [00:00<00:00, 109.87it/s]
accuracy: 1.0000, loss: 0.0416 ||: 100%|██████████| 1/1 [00:00<00:00, 49.96it/s]
accuracy: 1.0000, loss: 0.0414 ||: 100%|██████████| 1/1 [00:00<00:00, 109.41it/s]
accuracy: 1.0000, loss: 0.0415 ||: 100%|██████████| 1/1 [00:00<00:00, 50.11it/s]
accuracy: 1.0000, loss: 0.0412 ||: 100%|██████████| 1/1 [00:00<00:00, 109.10it/s]
accuracy: 1.0000, loss: 0.0413 ||: 100%|██████████| 1/1 [00:00<00:00, 53.66it/s]
accuracy: 1.0000, loss: 0.0411 ||: 100%|██████████| 1/1 [00:00<00:00, 111.37it/s]
accuracy: 1.0000, loss: 0.0411 ||: 100%|██████████| 1/1 [00:00<00:00, 56.44it/s]
accuracy: 1.0000, loss: 0.0409 ||: 100%|██████████| 1/1 [00:00<00:00, 243.06it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 64.19it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|██████████| 1/1 [00:00<00:00, 294.32it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|██████████| 1/1 [00:00<00:00, 56.75it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|██████████| 1/1 [00:00<00:00, 281.65it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|██████████| 1/1 [00:00<00:00, 56.57it/s]
accuracy: 1.0000, loss: 0.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 273.55it/s]
accuracy: 1.0000, loss: 0.0346 ||: 100%|██████████| 1/1 [00:00<00:00, 56.17it/s]
accuracy: 1.0000, loss: 0.0344 ||: 100%|██████████| 1/1 [00:00<00:00, 265.82it/s]
accuracy: 1.0000, loss: 0.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 56.88it/s]
accuracy: 1.0000, loss: 0.0343 ||: 100%|██████████| 1/1 [00:00<00:00, 271.65it/s]
accuracy: 1.0000, loss: 0.0344 ||: 100%|██████████| 1/1 [00:00<00:00, 57.04it/s]
accuracy: 1.0000, loss: 0.0342 ||: 100%|██████████| 1/1 [00:00<00:00, 289.10it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0298 ||: 100%|██████████| 1/1 [00:00<00:00, 64.14it/s]
accuracy: 1.0000, loss: 0.0296 ||: 100%|██████████| 1/1 [00:00<00:00, 266.98it/s]
accuracy: 1.0000, loss: 0.0297 ||: 100%|██████████| 1/1 [00:00<00:00, 67.75it/s]
accuracy: 1.0000, loss: 0.0295 ||: 100%|██████████| 1/1 [00:00<00:00, 115.81it/s]
accuracy: 1.0000, loss: 0.0296 ||: 100%|██████████| 1/1 [00:00<00:00, 64.91it/s]
accuracy: 1.0000, loss: 0.0294 ||: 100%|██████████| 1/1 [00:00<00:00, 113.06it/s]
accuracy: 1.0000, loss: 0.0295 ||: 100%|██████████| 1/1 [00:00<00:00, 53.11it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 112.83it/s]
accuracy: 1.0000, loss: 0.0294 ||: 100%|██████████| 1/1 [00:00<00:00, 45.41it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 108.61it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 71.24it/s]
accuracy: 1.0000, loss: 0.0292 ||: 100%|██████████| 1/1 [00:00<00:00, 116.27it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0258 ||: 100%|██████████| 1/1 [00:00<00:00, 174.25it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|██████████| 1/1 [00:00<00:00, 439.29it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|██████████| 1/1 [00:00<00:00, 172.89it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 380.64it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 48.84it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 151.50it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 175.02it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 416.76it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 176.63it/s]
accuracy: 1.0000, loss: 0.0254 ||: 100%|██████████| 1/1 [00:00<00:00, 437.45it/s]
accuracy: 1.0000, loss: 0.0254 ||: 100%|██████████| 1/1 [00:00<00:00, 75.32it/s]
accuracy: 1.0000, loss: 0.0253 ||: 100%|██████████| 1/1 [00:00<00:00, 443.28it/s]
accuracy: 1.0000, 

accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 171.25it/s]
accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 220.34it/s]
accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 162.00it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 371.90it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 162.10it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 357.72it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 167.79it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 360.18it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 171.53it/s]
accuracy: 1.0000, loss: 0.0223 ||: 100%|██████████| 1/1 [00:00<00:00, 351.19it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 162.09it/s]
accuracy: 1.0000, loss: 0.0223 ||: 100%|██████████| 1/1 [00:00<00:00, 360.34it/s]
accuracy: 1.0000

accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 147.12it/s]
accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 284.51it/s]
accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 121.91it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 279.19it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 85.21it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 284.98it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 148.77it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 269.45it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 84.83it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 282.98it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 96.05it/s]
accuracy: 1.0000, loss: 0.0198 ||: 100%|██████████| 1/1 [00:00<00:00, 196.84it/s]
accuracy: 1.0000, l

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/home/texuanw/softwares/anaconda3/lib/python3.6/site-packages/en_core_web_sm -->
/home/texuanw/softwares/anaconda3/lib/python3.6/site-packages/spacy/data/en_core_web_sm
You can now load the model via spacy.load('en_core_web_sm')
['DET', 'NN', 'V', 'DET', 'NN']


In [None]:
# Map each predicted tag id back to its string in the "labels" namespace.
predicted_tags = [model.vocab.get_token_from_index(tag_id, 'labels') for tag_id in tag_ids]
print(predicted_tags)

In [6]:
# Here's how to save the model.
MODEL_PATH = "/tmp/model.th"
VOCAB_DIR = "/tmp/vocabulary"
with open(MODEL_PATH, 'wb') as f:
    torch.save(model.state_dict(), f)
vocab.save_to_files(VOCAB_DIR)

# And here's how to reload the model.
# Build *fresh* modules for the reloaded model. Reusing `word_embeddings` / `lstm`
# from the trained model would make model2 share those parameters with `model`,
# so the equality check at the end would pass even if loading were broken.
vocab2 = Vocabulary.from_files(VOCAB_DIR)
token_embedding2 = Embedding(num_embeddings=vocab2.get_vocab_size('tokens'),
                             embedding_dim=EMBEDDING_DIM)
word_embeddings2 = BasicTextFieldEmbedder({"tokens": token_embedding2})
lstm2 = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model2 = LstmTagger(word_embeddings2, lstm2, vocab2)
with open(MODEL_PATH, 'rb') as f:
    # Load onto CPU first so a GPU-saved checkpoint also loads on a CPU-only
    # machine; the model is moved to the GPU below if one is in use.
    model2.load_state_dict(torch.load(f, map_location='cpu'))
if cuda_device > -1:
    model2.cuda(cuda_device)

# The reloaded model must reproduce the original model's predictions.
predictor2 = SentenceTaggerPredictor(model2, dataset_reader=reader)
tag_logits2 = predictor2.predict("The dog ate the apple")['tag_logits']
np.testing.assert_array_almost_equal(tag_logits2, tag_logits)