In [2]:
# In AllenNLP they use type annotations for just about everything.
from typing import Iterator, List, Dict

# AllenNLP is built on top of PyTorch, so they use its code freely.
import torch
import torch.optim as optim
import numpy as np

# In AllenNLP we represent each training example as an Instance containing Fields of various types. 
# Here each example will have a TextField containing the sentence
# and a SequenceLabelField containing the corresponding part-of-speech tags.
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField

# Typically to solve a problem like this using AllenNLP, you'll have to implement two classes. 
# The first is a DatasetReader, which contains the logic for reading a file of data and producing a stream of Instances.
from allennlp.data.dataset_readers import DatasetReader

# Frequently we'll want to load datasets or models from URLs. 
# The cached_path helper downloads such files, caches them locally, and returns the local path. 
# It also accepts local file paths (which it just returns as-is).
from allennlp.common.file_utils import cached_path

# There are various ways to represent a word as one or more indices. 
# For example, you might maintain a vocabulary of unique words and give each word a corresponding id.
# Or you might have one id per character in the word and represent each word as a sequence of ids. 
# AllenNLP has a TokenIndexer abstraction for this representation.
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

# Whereas a TokenIndexer represents a rule for how to turn a token into indices, a Vocabulary contains the corresponding mappings from strings to integers. 
# For example, your token indexer might specify to represent a token as a sequence of character ids, in which case the Vocabulary would contain the mapping {character -> id}. 
# In this particular example we use a SingleIdTokenIndexer that assigns each token a unique id, and so the Vocabulary will just contain a mapping {token -> id} (as well as the reverse mapping).
from allennlp.data.vocabulary import Vocabulary

# Besides DatasetReader, the other class you'll typically need to implement is Model. 
# A Model is a PyTorch Module that takes tensor inputs and produces a dict of tensor outputs (including the training loss you want to optimize).
from allennlp.models import Model

# Our model will consist of an embedding layer, followed by an LSTM, then by a feedforward layer.
# AllenNLP includes abstractions for all of these that smartly handle padding and batching, as well as various utility functions.
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

# We'll want to track accuracy on the training and validation datasets.
from allennlp.training.metrics import CategoricalAccuracy

# For training we'll need a DataIterator that can intelligently batch our data.
from allennlp.data.iterators import BucketIterator

# And we'll use AllenNLP's full-featured Trainer.
from allennlp.training.trainer import Trainer

# Finally, we'll want to make predictions on new inputs, more about this below.
from allennlp.predictors import SentenceTaggerPredictor
# Seed every random number generator the pipeline may touch so that
# Restart Kernel -> Run All reproduces the same weights and batch order:
#   * torch  - parameter initialisation of the embedding / LSTM / linear layers
#   * numpy  - seeded for completeness
#   * random - presumably used by AllenNLP's BucketIterator when shuffling
#              batches (NOTE(review): verify against the installed allennlp)
import random

SEED = 1
random.seed(SEED)
np.random.seed(SEED)
# Trailing ';' suppresses the <torch._C.Generator ...> repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x10ff45a50>

In [7]:
# Our first order of business is to implement our DatasetReader subclass.
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN

    Each whitespace-separated item is ``word###TAG``; the tag is whatever
    follows the last ``###``. Blank lines are skipped.
    """

    # The only parameter our DatasetReader needs is a dict of TokenIndexers that specify how to convert tokens into indices.
    # By default we'll just generate a single index for each token (which we'll call "tokens") that's just a unique id for each distinct token.
    # (This is just the standard "word to index" mapping you'd use for NLP tasks.)
    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    # DatasetReader.text_to_instance takes the inputs corresponding to a training example
    # (in this case the tokens of the sentence and the corresponding part-of-speech tags),
    # instantiates the corresponding Fields (in this case a TextField for the sentence and a SequenceLabelField for its tags),
    # and returns the Instance containing those fields.
    # Notice that the tags are optional, since we'd like to be able to create instances from unlabeled data to make predictions on them.
    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    # The other piece we have to implement is _read, which takes a filename and produces a stream of Instances.
    # Most of the work has already been done in text_to_instance.
    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield one Instance per non-blank line of ``file_path``.

        Raises:
            ValueError: if an item on a line has no ``###`` separator.
        """
        with open(file_path, encoding="utf-8") as data_file:
            for line_number, line in enumerate(data_file, start=1):
                # str.split() with no argument also strips surrounding whitespace.
                pairs = line.split()
                if not pairs:
                    # Blank line (e.g. a trailing newline at EOF): nothing to tag.
                    # Unpacking zip(*[]) here would otherwise raise ValueError.
                    continue
                words: List[str] = []
                tags: List[str] = []
                for pair in pairs:
                    # Split on the LAST "###" so a word that itself contains "###"
                    # still yields exactly one (word, tag) pair.
                    word, separator, tag = pair.rpartition("###")
                    if not separator:
                        raise ValueError(
                            f"{file_path}:{line_number}: expected 'word###TAG', got {pair!r}"
                        )
                    words.append(word)
                    tags.append(tag)
                yield self.text_to_instance([Token(word) for word in words], tags)

In [10]:
# The other class you'll basically always have to implement is Model, which is a subclass of torch.nn.Module.
# How it works is largely up to you; it mostly just needs a forward method that takes tensor inputs and produces a dict of tensor outputs that includes the loss you'll use to train the model.
# As mentioned above, our model will consist of an embedding layer, a sequence encoder, and a feedforward network.
class LstmTagger(Model):
    """
    Sequence tagger: embeds the sentence, encodes it with a Seq2SeqEncoder,
    and projects every encoded position onto the "labels" vocabulary.
    """

    # One thing that might seem unusual is that we're going to pass in the embedder and the sequence encoder as constructor parameters.
    # This allows us to experiment with different embedders and encoders without having to change the model code.
    def __init__(self,

                 # The embedding layer is specified as an AllenNLP TextFieldEmbedder which represents a general way of turning tokens into tensors.
                 # (Here we know that we want to represent each unique word with a learned tensor, but using the general class allows us to easily experiment with different types of embeddings, for example ELMo.)
                 word_embeddings: TextFieldEmbedder,

                 # Similarly, the encoder is specified as a general Seq2SeqEncoder even though we know we want to use an LSTM.
                 # Again, this makes it easy to experiment with other sequence encoders, for example a Transformer.
                 encoder: Seq2SeqEncoder,

                 # Every AllenNLP model also expects a Vocabulary, which contains the namespaced mappings of tokens to indices and labels to indices.
                 vocab: Vocabulary) -> None:

        # Notice that we have to pass the vocab to the base class constructor.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder

        # The feed forward layer is not passed in as a parameter, but is constructed by us.
        # Notice that it looks at the encoder to find the correct input dimension and looks at the vocabulary (and, in particular, at the label -> index mapping) to find the correct output dimension.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))

        # The last thing to notice is that we also instantiate a CategoricalAccuracy metric, which we'll use to track accuracy during each training and validation epoch.
        self.accuracy = CategoricalAccuracy()

    # Next we need to implement forward, which is where the actual computation happens.
    # Each Instance in your dataset will get (batched with other instances and) fed into forward.
    # The forward method expects dicts of tensors as input, and it expects their names to be the names of the fields in your Instance.
    # In this case we have a sentence field and (possibly) a labels field, so we'll construct our forward accordingly.
    # It returns a dict with "tag_logits" and, only when labels are given, "loss".
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths.
        # Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding.
        # Here we just use the utility function get_text_field_mask, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)

        # We start by passing the sentence tensor (each sentence a sequence of token ids) to the word_embeddings module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)

        # We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)

        # Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        # As before, the labels are optional, as we might want to run this model to make predictions on unlabeled data.
        # If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    # We included an accuracy metric that gets updated each forward pass.
    # That means we need to override the get_metrics method to pull the data out of it.
    # Behind the scenes, the CategoricalAccuracy metric stores the number of predictions and the number of correct predictions, updating those counts during each call to forward.
    # Each call to get_metric returns the calculated accuracy and (optionally) resets the counts, which is what allows us to track accuracy anew for each epoch.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}

In [11]:
# With both a DatasetReader and a Model defined, training can begin.
# Step one is a dataset reader; passing token_indexers=None makes it fall back
# to its default single-id "tokens" indexer.
reader = PosDatasetReader(token_indexers=None)

In [12]:
# Which we can use to read in the training data and validation data.
# Here we read them in from a URL, but you could read them in from local files if your data was local.
# We use cached_path to cache the files locally (and to hand reader.read the path to the local cached version).
# The URL is pinned to the v0.9.0 release tag instead of `master`: this notebook imports
# allennlp.data.iterators, which exists only in pre-1.0 releases, and the tutorial data
# files are not kept on the default branch. NOTE(review): match the tag to your installed allennlp.
TAGGER_DATA_ROOT = ('https://raw.githubusercontent.com/allenai/allennlp'
                    '/v0.9.0/tutorials/tagger/')
train_dataset = reader.read(cached_path(TAGGER_DATA_ROOT + 'training.txt'))
validation_dataset = reader.read(cached_path(TAGGER_DATA_ROOT + 'validation.txt'))

93B [00:00, 24537.35B/s]             
2it [00:00, 1893.59it/s]
93B [00:00, 44671.36B/s]             
2it [00:00, 1826.79it/s]


In [13]:
# Build the Vocabulary (the token -> id and label -> id mappings) from both
# splits pooled into a single list, so every token or label seen in either split gets an id.
vocab = Vocabulary.from_instances([*train_dataset, *validation_dataset])

100%|██████████| 4/4 [00:00<00:00, 13336.42it/s]


In [14]:
# Model sizes: the width of each word embedding and of the LSTM hidden state.
# Both are kept tiny because the toy dataset has only a handful of sentences.
EMBEDDING_DIM = HIDDEN_DIM = 6

In [15]:
# BasicTextFieldEmbedder maps each index name to an embedder. The reader's
# default indexer is called "tokens", so that single key is all we need.
# The Vocabulary tells us how many rows the embedding matrix needs; EMBEDDING_DIM sets its width.
# (Pre-trained vectors such as GloVe could be loaded instead, but are unnecessary for this toy data.)
token_embedding = Embedding(
    num_embeddings=vocab.get_vocab_size('tokens'),
    embedding_dim=EMBEDDING_DIM,
)
word_embeddings = BasicTextFieldEmbedder(token_embedders={"tokens": token_embedding})

In [16]:
# The sequence encoder: a plain PyTorch LSTM wrapped so it speaks AllenNLP's
# Seq2SeqEncoder interface (configuration files would do this wrapping for you).
# AllenNLP works batch-first, so the LSTM is built with batch_first=True.
pytorch_lstm = torch.nn.LSTM(input_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM, batch_first=True)
lstm = PytorchSeq2SeqWrapper(pytorch_lstm)

In [17]:
# Instantiate the model: calling the LstmTagger class runs its __init__ (the
# constructor) with these components and returns a new model object.
model = LstmTagger(word_embeddings=word_embeddings, encoder=lstm, vocab=vocab)

In [18]:
# Training needs an optimizer first; plain PyTorch stochastic gradient descent is enough here.
LEARNING_RATE = 0.1
optimizer = optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [19]:
# Batching is handled by a BucketIterator: it sorts instances by the given
# field lengths (here the token count of "sentence") so each batch holds
# sequences of similar length and needs little padding.
BATCH_SIZE = 2
iterator = BucketIterator(sorting_keys=[("sentence", "num_tokens")], batch_size=BATCH_SIZE)

In [20]:
# We also tell the iterator to index its instances with our vocabulary;
# that is, to convert their strings to integers using the mapping we created above.
iterator.index_with(vocab)

In [21]:
# Build the Trainer: run for up to NUM_EPOCHS epochs, stopping early once
# PATIENCE epochs pass without the validation metric improving.
# The default validation metric is loss (lower is better); another metric and
# direction can be chosen instead (e.g. accuracy, where higher is better).
PATIENCE = 10
NUM_EPOCHS = 1000
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    num_epochs=NUM_EPOCHS,
    patience=PATIENCE,
)

In [22]:
# Launch training. Each epoch prints a progress bar showing "loss" and "accuracy";
# a well-behaved model shows loss falling and accuracy rising.
# The returned metrics are kept in a variable for later inspection and displayed as the cell output.
training_metrics = trainer.train()
training_metrics

accuracy: 0.3333, loss: 1.1685 ||: 100%|██████████| 1/1 [00:00<00:00, 15.06it/s]
accuracy: 0.3333, loss: 1.1592 ||: 100%|██████████| 1/1 [00:00<00:00, 119.59it/s]
accuracy: 0.3333, loss: 1.1604 ||: 100%|██████████| 1/1 [00:00<00:00, 99.38it/s]
accuracy: 0.3333, loss: 1.1516 ||: 100%|██████████| 1/1 [00:00<00:00, 130.42it/s]
accuracy: 0.3333, loss: 1.1529 ||: 100%|██████████| 1/1 [00:00<00:00, 103.91it/s]
accuracy: 0.3333, loss: 1.1445 ||: 100%|██████████| 1/1 [00:00<00:00, 126.24it/s]
accuracy: 0.3333, loss: 1.1458 ||: 100%|██████████| 1/1 [00:00<00:00, 83.63it/s]
accuracy: 0.3333, loss: 1.1379 ||: 100%|██████████| 1/1 [00:00<00:00, 146.69it/s]
accuracy: 0.3333, loss: 1.1391 ||: 100%|██████████| 1/1 [00:00<00:00, 113.38it/s]
accuracy: 0.3333, loss: 1.1316 ||: 100%|██████████| 1/1 [00:00<00:00, 124.19it/s]
accuracy: 0.3333, loss: 1.1329 ||: 100%|██████████| 1/1 [00:00<00:00, 86.62it/s]
accuracy: 0.3333, loss: 1.1259 ||: 100%|██████████| 1/1 [00:00<00:00, 157.83it/s]
accuracy: 0.3333, lo

accuracy: 0.4444, loss: 1.0534 ||: 100%|██████████| 1/1 [00:00<00:00, 143.18it/s]
accuracy: 0.4444, loss: 1.0517 ||: 100%|██████████| 1/1 [00:00<00:00, 241.65it/s]
accuracy: 0.4444, loss: 1.0531 ||: 100%|██████████| 1/1 [00:00<00:00, 89.87it/s]
accuracy: 0.4444, loss: 1.0513 ||: 100%|██████████| 1/1 [00:00<00:00, 210.12it/s]
accuracy: 0.4444, loss: 1.0528 ||: 100%|██████████| 1/1 [00:00<00:00, 201.45it/s]
accuracy: 0.4444, loss: 1.0510 ||: 100%|██████████| 1/1 [00:00<00:00, 134.59it/s]
accuracy: 0.4444, loss: 1.0524 ||: 100%|██████████| 1/1 [00:00<00:00, 137.69it/s]
accuracy: 0.4444, loss: 1.0507 ||: 100%|██████████| 1/1 [00:00<00:00, 196.39it/s]
accuracy: 0.4444, loss: 1.0521 ||: 100%|██████████| 1/1 [00:00<00:00, 158.53it/s]
accuracy: 0.4444, loss: 1.0504 ||: 100%|██████████| 1/1 [00:00<00:00, 407.17it/s]
accuracy: 0.4444, loss: 1.0518 ||: 100%|██████████| 1/1 [00:00<00:00, 126.64it/s]
accuracy: 0.4444, loss: 1.0501 ||: 100%|██████████| 1/1 [00:00<00:00, 127.81it/s]
accuracy: 0.4444,

accuracy: 0.4444, loss: 1.0370 ||: 100%|██████████| 1/1 [00:00<00:00, 126.32it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 502.43it/s]
accuracy: 0.4444, loss: 1.0366 ||: 100%|██████████| 1/1 [00:00<00:00, 128.29it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 236.34it/s]
accuracy: 0.4444, loss: 1.0362 ||: 100%|██████████| 1/1 [00:00<00:00, 190.74it/s]
accuracy: 0.4444, loss: 1.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 130.19it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|██████████| 1/1 [00:00<00:00, 73.74it/s]
accuracy: 0.4444, loss: 1.0341 ||: 100%|██████████| 1/1 [00:00<00:00, 108.34it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|██████████| 1/1 [00:00<00:00, 53.89it/s]
accuracy: 0.4444, loss: 1.0336 ||: 100%|██████████| 1/1 [00:00<00:00, 115.95it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 62.25it/s]
accuracy: 0.4444, loss: 1.0332 ||: 100%|██████████| 1/1 [00:00<00:00, 170.15it/s]
accuracy: 0.4444, l

accuracy: 0.4444, loss: 1.0061 ||: 100%|██████████| 1/1 [00:00<00:00, 173.56it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 399.50it/s]
accuracy: 0.4444, loss: 1.0051 ||: 100%|██████████| 1/1 [00:00<00:00, 116.78it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 188.79it/s]
accuracy: 0.4444, loss: 1.0042 ||: 100%|██████████| 1/1 [00:00<00:00, 102.38it/s]
accuracy: 0.4444, loss: 1.0024 ||: 100%|██████████| 1/1 [00:00<00:00, 294.90it/s]
accuracy: 0.4444, loss: 1.0033 ||: 100%|██████████| 1/1 [00:00<00:00, 93.38it/s]
accuracy: 0.4444, loss: 1.0014 ||: 100%|██████████| 1/1 [00:00<00:00, 367.66it/s]
accuracy: 0.4444, loss: 1.0023 ||: 100%|██████████| 1/1 [00:00<00:00, 130.29it/s]
accuracy: 0.4444, loss: 1.0005 ||: 100%|██████████| 1/1 [00:00<00:00, 354.46it/s]
accuracy: 0.4444, loss: 1.0014 ||: 100%|██████████| 1/1 [00:00<00:00, 216.78it/s]
accuracy: 0.4444, loss: 0.9995 ||: 100%|██████████| 1/1 [00:00<00:00, 560.51it/s]
accuracy: 0.4444,

accuracy: 0.4444, loss: 0.9355 ||: 100%|██████████| 1/1 [00:00<00:00, 182.35it/s]
accuracy: 0.4444, loss: 0.9329 ||: 100%|██████████| 1/1 [00:00<00:00, 138.13it/s]
accuracy: 0.4444, loss: 0.9334 ||: 100%|██████████| 1/1 [00:00<00:00, 86.34it/s]
accuracy: 0.4444, loss: 0.9307 ||: 100%|██████████| 1/1 [00:00<00:00, 259.53it/s]
accuracy: 0.4444, loss: 0.9313 ||: 100%|██████████| 1/1 [00:00<00:00, 95.28it/s]
accuracy: 0.4444, loss: 0.9286 ||: 100%|██████████| 1/1 [00:00<00:00, 200.57it/s]
accuracy: 0.4444, loss: 0.9292 ||: 100%|██████████| 1/1 [00:00<00:00, 98.80it/s]
accuracy: 0.4444, loss: 0.9264 ||: 100%|██████████| 1/1 [00:00<00:00, 209.64it/s]
accuracy: 0.4444, loss: 0.9270 ||: 100%|██████████| 1/1 [00:00<00:00, 107.97it/s]
accuracy: 0.4444, loss: 0.9242 ||: 100%|██████████| 1/1 [00:00<00:00, 280.27it/s]
accuracy: 0.5556, loss: 0.9248 ||: 100%|██████████| 1/1 [00:00<00:00, 96.90it/s]
accuracy: 0.4444, loss: 0.9220 ||: 100%|██████████| 1/1 [00:00<00:00, 155.67it/s]
accuracy: 0.5556, lo

accuracy: 0.6667, loss: 0.7936 ||: 100%|██████████| 1/1 [00:00<00:00, 78.86it/s]
accuracy: 0.6667, loss: 0.7903 ||: 100%|██████████| 1/1 [00:00<00:00, 278.01it/s]
accuracy: 0.6667, loss: 0.7901 ||: 100%|██████████| 1/1 [00:00<00:00, 181.75it/s]
accuracy: 0.6667, loss: 0.7868 ||: 100%|██████████| 1/1 [00:00<00:00, 135.83it/s]
accuracy: 0.6667, loss: 0.7867 ||: 100%|██████████| 1/1 [00:00<00:00, 108.37it/s]
accuracy: 0.6667, loss: 0.7833 ||: 100%|██████████| 1/1 [00:00<00:00, 164.99it/s]
accuracy: 0.6667, loss: 0.7831 ||: 100%|██████████| 1/1 [00:00<00:00, 104.65it/s]
accuracy: 0.6667, loss: 0.7798 ||: 100%|██████████| 1/1 [00:00<00:00, 198.42it/s]
accuracy: 0.6667, loss: 0.7796 ||: 100%|██████████| 1/1 [00:00<00:00, 116.01it/s]
accuracy: 0.6667, loss: 0.7763 ||: 100%|██████████| 1/1 [00:00<00:00, 267.22it/s]
accuracy: 0.6667, loss: 0.7761 ||: 100%|██████████| 1/1 [00:00<00:00, 91.64it/s]
accuracy: 0.6667, loss: 0.7728 ||: 100%|██████████| 1/1 [00:00<00:00, 210.12it/s]
accuracy: 0.6667, 

accuracy: 0.7778, loss: 0.6130 ||: 100%|██████████| 1/1 [00:00<00:00, 85.48it/s]
accuracy: 0.7778, loss: 0.6110 ||: 100%|██████████| 1/1 [00:00<00:00, 335.01it/s]
accuracy: 0.7778, loss: 0.6093 ||: 100%|██████████| 1/1 [00:00<00:00, 93.69it/s]
accuracy: 0.7778, loss: 0.6074 ||: 100%|██████████| 1/1 [00:00<00:00, 218.97it/s]
accuracy: 0.7778, loss: 0.6057 ||: 100%|██████████| 1/1 [00:00<00:00, 146.47it/s]
accuracy: 0.7778, loss: 0.6038 ||: 100%|██████████| 1/1 [00:00<00:00, 359.07it/s]
accuracy: 0.7778, loss: 0.6021 ||: 100%|██████████| 1/1 [00:00<00:00, 123.53it/s]
accuracy: 0.7778, loss: 0.6002 ||: 100%|██████████| 1/1 [00:00<00:00, 115.26it/s]
accuracy: 0.7778, loss: 0.5985 ||: 100%|██████████| 1/1 [00:00<00:00, 88.67it/s]
accuracy: 0.7778, loss: 0.5966 ||: 100%|██████████| 1/1 [00:00<00:00, 189.08it/s]
accuracy: 0.7778, loss: 0.5949 ||: 100%|██████████| 1/1 [00:00<00:00, 96.93it/s]
accuracy: 0.7778, loss: 0.5930 ||: 100%|██████████| 1/1 [00:00<00:00, 249.54it/s]
accuracy: 0.7778, lo

accuracy: 1.0000, loss: 0.4350 ||: 100%|██████████| 1/1 [00:00<00:00, 84.86it/s]
accuracy: 1.0000, loss: 0.4337 ||: 100%|██████████| 1/1 [00:00<00:00, 228.67it/s]
accuracy: 1.0000, loss: 0.4316 ||: 100%|██████████| 1/1 [00:00<00:00, 90.29it/s]
accuracy: 1.0000, loss: 0.4303 ||: 100%|██████████| 1/1 [00:00<00:00, 198.85it/s]
accuracy: 1.0000, loss: 0.4283 ||: 100%|██████████| 1/1 [00:00<00:00, 78.56it/s]
accuracy: 1.0000, loss: 0.4269 ||: 100%|██████████| 1/1 [00:00<00:00, 224.10it/s]
accuracy: 1.0000, loss: 0.4249 ||: 100%|██████████| 1/1 [00:00<00:00, 48.70it/s]
accuracy: 1.0000, loss: 0.4235 ||: 100%|██████████| 1/1 [00:00<00:00, 250.06it/s]
accuracy: 1.0000, loss: 0.4216 ||: 100%|██████████| 1/1 [00:00<00:00, 112.38it/s]
accuracy: 1.0000, loss: 0.4202 ||: 100%|██████████| 1/1 [00:00<00:00, 134.57it/s]
accuracy: 1.0000, loss: 0.4182 ||: 100%|██████████| 1/1 [00:00<00:00, 92.21it/s]
accuracy: 1.0000, loss: 0.4168 ||: 100%|██████████| 1/1 [00:00<00:00, 184.98it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.2855 ||: 100%|██████████| 1/1 [00:00<00:00, 75.36it/s]
accuracy: 1.0000, loss: 0.2838 ||: 100%|██████████| 1/1 [00:00<00:00, 343.54it/s]
accuracy: 1.0000, loss: 0.2830 ||: 100%|██████████| 1/1 [00:00<00:00, 160.71it/s]
accuracy: 1.0000, loss: 0.2813 ||: 100%|██████████| 1/1 [00:00<00:00, 345.07it/s]
accuracy: 1.0000, loss: 0.2806 ||: 100%|██████████| 1/1 [00:00<00:00, 84.55it/s]
accuracy: 1.0000, loss: 0.2788 ||: 100%|██████████| 1/1 [00:00<00:00, 209.03it/s]
accuracy: 1.0000, loss: 0.2781 ||: 100%|██████████| 1/1 [00:00<00:00, 79.94it/s]
accuracy: 1.0000, loss: 0.2764 ||: 100%|██████████| 1/1 [00:00<00:00, 293.53it/s]
accuracy: 1.0000, loss: 0.2757 ||: 100%|██████████| 1/1 [00:00<00:00, 92.44it/s]
accuracy: 1.0000, loss: 0.2739 ||: 100%|██████████| 1/1 [00:00<00:00, 105.22it/s]
accuracy: 1.0000, loss: 0.2732 ||: 100%|██████████| 1/1 [00:00<00:00, 94.69it/s]
accuracy: 1.0000, loss: 0.2715 ||: 100%|██████████| 1/1 [00:00<00:00, 114.81it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.1846 ||: 100%|██████████| 1/1 [00:00<00:00, 136.50it/s]
accuracy: 1.0000, loss: 0.1831 ||: 100%|██████████| 1/1 [00:00<00:00, 205.93it/s]
accuracy: 1.0000, loss: 0.1830 ||: 100%|██████████| 1/1 [00:00<00:00, 138.91it/s]
accuracy: 1.0000, loss: 0.1816 ||: 100%|██████████| 1/1 [00:00<00:00, 97.48it/s]
accuracy: 1.0000, loss: 0.1815 ||: 100%|██████████| 1/1 [00:00<00:00, 71.26it/s]
accuracy: 1.0000, loss: 0.1800 ||: 100%|██████████| 1/1 [00:00<00:00, 178.57it/s]
accuracy: 1.0000, loss: 0.1800 ||: 100%|██████████| 1/1 [00:00<00:00, 62.63it/s]
accuracy: 1.0000, loss: 0.1785 ||: 100%|██████████| 1/1 [00:00<00:00, 140.93it/s]
accuracy: 1.0000, loss: 0.1784 ||: 100%|██████████| 1/1 [00:00<00:00, 75.47it/s]
accuracy: 1.0000, loss: 0.1770 ||: 100%|██████████| 1/1 [00:00<00:00, 116.69it/s]
accuracy: 1.0000, loss: 0.1770 ||: 100%|██████████| 1/1 [00:00<00:00, 67.55it/s]
accuracy: 1.0000, loss: 0.1755 ||: 100%|██████████| 1/1 [00:00<00:00, 175.35it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.1239 ||: 100%|██████████| 1/1 [00:00<00:00, 87.10it/s]
accuracy: 1.0000, loss: 0.1228 ||: 100%|██████████| 1/1 [00:00<00:00, 225.59it/s]
accuracy: 1.0000, loss: 0.1229 ||: 100%|██████████| 1/1 [00:00<00:00, 98.24it/s]
accuracy: 1.0000, loss: 0.1219 ||: 100%|██████████| 1/1 [00:00<00:00, 537.32it/s]
accuracy: 1.0000, loss: 0.1220 ||: 100%|██████████| 1/1 [00:00<00:00, 76.10it/s]
accuracy: 1.0000, loss: 0.1210 ||: 100%|██████████| 1/1 [00:00<00:00, 122.73it/s]
accuracy: 1.0000, loss: 0.1211 ||: 100%|██████████| 1/1 [00:00<00:00, 56.68it/s]
accuracy: 1.0000, loss: 0.1201 ||: 100%|██████████| 1/1 [00:00<00:00, 149.30it/s]
accuracy: 1.0000, loss: 0.1203 ||: 100%|██████████| 1/1 [00:00<00:00, 63.23it/s]
accuracy: 1.0000, loss: 0.1193 ||: 100%|██████████| 1/1 [00:00<00:00, 98.70it/s]
accuracy: 1.0000, loss: 0.1194 ||: 100%|██████████| 1/1 [00:00<00:00, 107.57it/s]
accuracy: 1.0000, loss: 0.1184 ||: 100%|██████████| 1/1 [00:00<00:00, 128.01it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0880 ||: 100%|██████████| 1/1 [00:00<00:00, 73.36it/s]
accuracy: 1.0000, loss: 0.0873 ||: 100%|██████████| 1/1 [00:00<00:00, 141.12it/s]
accuracy: 1.0000, loss: 0.0874 ||: 100%|██████████| 1/1 [00:00<00:00, 117.19it/s]
accuracy: 1.0000, loss: 0.0868 ||: 100%|██████████| 1/1 [00:00<00:00, 215.98it/s]
accuracy: 1.0000, loss: 0.0869 ||: 100%|██████████| 1/1 [00:00<00:00, 123.95it/s]
accuracy: 1.0000, loss: 0.0862 ||: 100%|██████████| 1/1 [00:00<00:00, 172.95it/s]
accuracy: 1.0000, loss: 0.0864 ||: 100%|██████████| 1/1 [00:00<00:00, 63.16it/s]
accuracy: 1.0000, loss: 0.0857 ||: 100%|██████████| 1/1 [00:00<00:00, 72.02it/s]
accuracy: 1.0000, loss: 0.0858 ||: 100%|██████████| 1/1 [00:00<00:00, 56.41it/s]
accuracy: 1.0000, loss: 0.0852 ||: 100%|██████████| 1/1 [00:00<00:00, 127.50it/s]
accuracy: 1.0000, loss: 0.0853 ||: 100%|██████████| 1/1 [00:00<00:00, 61.92it/s]
accuracy: 1.0000, loss: 0.0846 ||: 100%|██████████| 1/1 [00:00<00:00, 169.15it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0660 ||: 100%|██████████| 1/1 [00:00<00:00, 126.90it/s]
accuracy: 1.0000, loss: 0.0655 ||: 100%|██████████| 1/1 [00:00<00:00, 90.52it/s]
accuracy: 1.0000, loss: 0.0656 ||: 100%|██████████| 1/1 [00:00<00:00, 163.08it/s]
accuracy: 1.0000, loss: 0.0652 ||: 100%|██████████| 1/1 [00:00<00:00, 213.40it/s]
accuracy: 1.0000, loss: 0.0653 ||: 100%|██████████| 1/1 [00:00<00:00, 137.19it/s]
accuracy: 1.0000, loss: 0.0648 ||: 100%|██████████| 1/1 [00:00<00:00, 96.39it/s]
accuracy: 1.0000, loss: 0.0650 ||: 100%|██████████| 1/1 [00:00<00:00, 77.51it/s]
accuracy: 1.0000, loss: 0.0645 ||: 100%|██████████| 1/1 [00:00<00:00, 292.59it/s]
accuracy: 1.0000, loss: 0.0646 ||: 100%|██████████| 1/1 [00:00<00:00, 86.58it/s]
accuracy: 1.0000, loss: 0.0642 ||: 100%|██████████| 1/1 [00:00<00:00, 217.11it/s]
accuracy: 1.0000, loss: 0.0643 ||: 100%|██████████| 1/1 [00:00<00:00, 104.20it/s]
accuracy: 1.0000, loss: 0.0638 ||: 100%|██████████| 1/1 [00:00<00:00, 175.25it/s]
accuracy: 1.0000, lo

accuracy: 1.0000, loss: 0.0517 ||: 100%|██████████| 1/1 [00:00<00:00, 76.57it/s]
accuracy: 1.0000, loss: 0.0514 ||: 100%|██████████| 1/1 [00:00<00:00, 237.37it/s]
accuracy: 1.0000, loss: 0.0515 ||: 100%|██████████| 1/1 [00:00<00:00, 66.76it/s]
accuracy: 1.0000, loss: 0.0511 ||: 100%|██████████| 1/1 [00:00<00:00, 270.50it/s]
accuracy: 1.0000, loss: 0.0512 ||: 100%|██████████| 1/1 [00:00<00:00, 87.76it/s]
accuracy: 1.0000, loss: 0.0509 ||: 100%|██████████| 1/1 [00:00<00:00, 250.69it/s]
accuracy: 1.0000, loss: 0.0510 ||: 100%|██████████| 1/1 [00:00<00:00, 120.32it/s]
accuracy: 1.0000, loss: 0.0507 ||: 100%|██████████| 1/1 [00:00<00:00, 192.61it/s]
accuracy: 1.0000, loss: 0.0508 ||: 100%|██████████| 1/1 [00:00<00:00, 73.82it/s]
accuracy: 1.0000, loss: 0.0505 ||: 100%|██████████| 1/1 [00:00<00:00, 185.56it/s]
accuracy: 1.0000, loss: 0.0506 ||: 100%|██████████| 1/1 [00:00<00:00, 69.07it/s]
accuracy: 1.0000, loss: 0.0502 ||: 100%|██████████| 1/1 [00:00<00:00, 174.16it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0419 ||: 100%|██████████| 1/1 [00:00<00:00, 93.83it/s]
accuracy: 1.0000, loss: 0.0417 ||: 100%|██████████| 1/1 [00:00<00:00, 221.49it/s]
accuracy: 1.0000, loss: 0.0418 ||: 100%|██████████| 1/1 [00:00<00:00, 128.63it/s]
accuracy: 1.0000, loss: 0.0415 ||: 100%|██████████| 1/1 [00:00<00:00, 231.10it/s]
accuracy: 1.0000, loss: 0.0416 ||: 100%|██████████| 1/1 [00:00<00:00, 104.52it/s]
accuracy: 1.0000, loss: 0.0414 ||: 100%|██████████| 1/1 [00:00<00:00, 104.18it/s]
accuracy: 1.0000, loss: 0.0415 ||: 100%|██████████| 1/1 [00:00<00:00, 41.06it/s]
accuracy: 1.0000, loss: 0.0412 ||: 100%|██████████| 1/1 [00:00<00:00, 186.81it/s]
accuracy: 1.0000, loss: 0.0413 ||: 100%|██████████| 1/1 [00:00<00:00, 74.68it/s]
accuracy: 1.0000, loss: 0.0411 ||: 100%|██████████| 1/1 [00:00<00:00, 118.71it/s]
accuracy: 1.0000, loss: 0.0411 ||: 100%|██████████| 1/1 [00:00<00:00, 49.12it/s]
accuracy: 1.0000, loss: 0.0409 ||: 100%|██████████| 1/1 [00:00<00:00, 55.37it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0349 ||: 100%|██████████| 1/1 [00:00<00:00, 112.70it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|██████████| 1/1 [00:00<00:00, 172.71it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|██████████| 1/1 [00:00<00:00, 98.16it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|██████████| 1/1 [00:00<00:00, 229.94it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|██████████| 1/1 [00:00<00:00, 67.65it/s]
accuracy: 1.0000, loss: 0.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 144.29it/s]
accuracy: 1.0000, loss: 0.0346 ||: 100%|██████████| 1/1 [00:00<00:00, 76.87it/s]
accuracy: 1.0000, loss: 0.0344 ||: 100%|██████████| 1/1 [00:00<00:00, 86.50it/s]
accuracy: 1.0000, loss: 0.0345 ||: 100%|██████████| 1/1 [00:00<00:00, 69.26it/s]
accuracy: 1.0000, loss: 0.0343 ||: 100%|██████████| 1/1 [00:00<00:00, 126.98it/s]
accuracy: 1.0000, loss: 0.0344 ||: 100%|██████████| 1/1 [00:00<00:00, 49.33it/s]
accuracy: 1.0000, loss: 0.0342 ||: 100%|██████████| 1/1 [00:00<00:00, 101.60it/s]
accuracy: 1.0000, loss

accuracy: 1.0000, loss: 0.0298 ||: 100%|██████████| 1/1 [00:00<00:00, 71.12it/s]
accuracy: 1.0000, loss: 0.0296 ||: 100%|██████████| 1/1 [00:00<00:00, 268.90it/s]
accuracy: 1.0000, loss: 0.0297 ||: 100%|██████████| 1/1 [00:00<00:00, 164.02it/s]
accuracy: 1.0000, loss: 0.0295 ||: 100%|██████████| 1/1 [00:00<00:00, 191.71it/s]
accuracy: 1.0000, loss: 0.0296 ||: 100%|██████████| 1/1 [00:00<00:00, 97.66it/s]
accuracy: 1.0000, loss: 0.0294 ||: 100%|██████████| 1/1 [00:00<00:00, 195.23it/s]
accuracy: 1.0000, loss: 0.0295 ||: 100%|██████████| 1/1 [00:00<00:00, 56.06it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 235.25it/s]
accuracy: 1.0000, loss: 0.0294 ||: 100%|██████████| 1/1 [00:00<00:00, 63.72it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 271.65it/s]
accuracy: 1.0000, loss: 0.0293 ||: 100%|██████████| 1/1 [00:00<00:00, 80.78it/s]
accuracy: 1.0000, loss: 0.0292 ||: 100%|██████████| 1/1 [00:00<00:00, 283.28it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0258 ||: 100%|██████████| 1/1 [00:00<00:00, 121.86it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|██████████| 1/1 [00:00<00:00, 191.02it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|██████████| 1/1 [00:00<00:00, 93.40it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 281.14it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 63.84it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 204.96it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|██████████| 1/1 [00:00<00:00, 68.12it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 156.79it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|██████████| 1/1 [00:00<00:00, 61.02it/s]
accuracy: 1.0000, loss: 0.0254 ||: 100%|██████████| 1/1 [00:00<00:00, 167.26it/s]
accuracy: 1.0000, loss: 0.0254 ||: 100%|██████████| 1/1 [00:00<00:00, 37.48it/s]
accuracy: 1.0000, loss: 0.0253 ||: 100%|██████████| 1/1 [00:00<00:00, 171.88it/s]
accuracy: 1.0000, los

accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 102.42it/s]
accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 272.18it/s]
accuracy: 1.0000, loss: 0.0226 ||: 100%|██████████| 1/1 [00:00<00:00, 111.26it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 565.88it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 76.56it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 262.97it/s]
accuracy: 1.0000, loss: 0.0225 ||: 100%|██████████| 1/1 [00:00<00:00, 87.84it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 318.50it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 153.27it/s]
accuracy: 1.0000, loss: 0.0223 ||: 100%|██████████| 1/1 [00:00<00:00, 295.69it/s]
accuracy: 1.0000, loss: 0.0224 ||: 100%|██████████| 1/1 [00:00<00:00, 84.52it/s]
accuracy: 1.0000, loss: 0.0223 ||: 100%|██████████| 1/1 [00:00<00:00, 248.57it/s]
accuracy: 1.0000, l

accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 123.96it/s]
accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 238.65it/s]
accuracy: 1.0000, loss: 0.0201 ||: 100%|██████████| 1/1 [00:00<00:00, 119.30it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 238.22it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 145.37it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 269.18it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|██████████| 1/1 [00:00<00:00, 122.52it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 234.63it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 147.41it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 225.08it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|██████████| 1/1 [00:00<00:00, 130.86it/s]
accuracy: 1.0000, loss: 0.0198 ||: 100%|██████████| 1/1 [00:00<00:00, 303.56it/s]
accuracy: 1.0000

{'training_duration': '00:00:37',
 'training_start_epoch': 0,
 'training_epochs': 999,
 'epoch': 999,
 'training_accuracy': 1.0,
 'training_loss': 0.01809469424188137,
 'validation_accuracy': 1.0,
 'validation_loss': 0.01804003119468689,
 'best_epoch': 999,
 'best_validation_accuracy': 1.0,
 'best_validation_loss': 0.01804003119468689}

In [23]:
# As in the original PyTorch tutorial, we'd like to look at the predictions our model generates.
# AllenNLP's Predictor abstraction takes raw inputs, converts them to instances, feeds them through the model,
# and returns JSON-serializable results. We could write our own Predictor, but the built-in
# SentenceTaggerPredictor already fits this task.
# It needs the trained model (to make predictions) and our dataset reader (to build instances from text).
# Building it also loads spaCy's 'en_core_web_sm' model for tokenization; the output below shows it
# being downloaded and installed the first time.
predictor = SentenceTaggerPredictor(model=model, dataset_reader=reader)

Spacy models 'en_core_web_sm' not found.  Downloading and installing.



    Linking successful
    /anaconda2/envs/ipykernel_py3/lib/python3.7/site-packages/en_core_web_sm
    -->
    /anaconda2/envs/ipykernel_py3/lib/python3.7/site-packages/spacy/data/en_core_web_sm

    You can now load the model via spacy.load('en_core_web_sm')



In [55]:
# It has a predict method that just needs a sentence and returns (a JSON-serializable version of) the output dict from forward.
# SentenceTaggerPredictor splits the sentence with spaCy, so the trailing "." becomes its own token:
# "The dog ate the apple." -> ["The", "dog", "ate", "the", "apple", "."], i.e. 6 tokens, not 5 words.
# tag_logits is therefore a 6 x 3 nested list of logits: one row per token, one column per possible tag.
tag_logits = predictor.predict("The dog ate the apple.")['tag_logits']

In [56]:
# The predicted tag for each token is the id of its highest-scoring logit, so take the argmax over the last (tag) axis.
tag_ids = np.asarray(tag_logits).argmax(axis=-1)

In [57]:
# And then use our vocabulary to map each predicted tag id back to its tag string in the 'labels' namespace.
# int() turns the numpy integer from argmax into a plain Python int for the vocabulary lookup.
# NOTE(review): the final "." presumably never appeared in the training sentences, so its predicted tag
# is not meaningful. Confirm against the training data.
predicted_tags = [model.vocab.get_token_from_index(int(tag_id), 'labels') for tag_id in tag_ids]
# Leave the list as the cell's last expression so Jupyter renders it; no print() needed.
predicted_tags

['DET', 'NN', 'V', 'DET', 'NN', 'NN']
