Skip to content
Switch branches/tags
Go to file
Cannot retrieve contributors at this time
from itertools import chain
from typing import Dict
import numpy as np
import torch
import torch.optim as optim
from import TextFieldTensors
from import MultiProcessDataLoader
from import BucketBatchSampler
from import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask
from import CategoricalAccuracy, F1Measure
from import GradientDescentTrainer
from allennlp_models.classification.dataset_readers.stanford_sentiment_tree_bank import \
from realworldnlp.predictors import SentenceClassifierPredictor
# Model in AllenNLP represents a model that is trained.
class LstmClassifier(Model):
    """Sentence classifier: embed tokens, encode them to one vector, project to label logits.

    The pipeline is ``embedder -> encoder -> linear``. Accuracy and the
    precision/recall/F1 of one designated "positive" label are tracked
    whenever gold labels are passed to :meth:`forward`.
    """

    def __init__(self,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary,
                 positive_label: str = '4') -> None:
        """Build the classifier.

        Args:
            embedder: Converts the ``tokens`` text field into per-token vectors.
            encoder: Collapses the sequence of token vectors into a single vector.
            vocab: Vocabulary providing the ``labels`` namespace (label count and
                label-to-index mapping).
            positive_label: Label whose precision/recall/F1 are monitored;
                defaults to ``'4'`` (very positive).
        """
        # The base-class initializer has to run before any sub-module is assigned
        # as an attribute; it also stores the vocabulary as ``self.vocab``.
        super().__init__(vocab)

        # We need the embeddings to convert word IDs to their vector representations
        self.embedder = embedder
        self.encoder = encoder

        # After converting a sequence of vectors to a single vector, we feed it into
        # a fully-connected linear layer to reduce the dimension to the total number of labels.
        self.linear = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                      out_features=vocab.get_vocab_size('labels'))

        # Monitor the metrics - we use accuracy, as well as prec, rec, f1 for 4 (very positive)
        positive_index = vocab.get_token_index(positive_label, namespace='labels')
        self.accuracy = CategoricalAccuracy()
        self.f1_measure = F1Measure(positive_index)

        # We use the cross entropy loss because this is a classification task.
        # Note that PyTorch's CrossEntropyLoss combines softmax and log likelihood loss,
        # which makes it unnecessary to add a separate softmax layer.
        self.loss_function = torch.nn.CrossEntropyLoss()

    # Instances are fed to forward after batching.
    # Fields are passed through arguments with the same name.
    def forward(self,
                tokens: TextFieldTensors,
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Run a batch through the model.

        Args:
            tokens: Batched, indexed ``tokens`` text field.
            label: Gold label indices for the batch, or ``None`` at prediction time.

        Returns:
            A dict with ``"logits"``, ``"cls_emb"`` (the encoder output) and
            ``"probs"`` (softmax over the last dimension of the logits). When
            ``label`` is given, the metrics are updated and a ``"loss"`` entry
            is added.
        """
        # In deep NLP, when sequences of tensors in different lengths are batched together,
        # shorter sequences get padded with zeros to make them equal length.
        # Masking is the process to ignore extra zeros added by padding
        mask = get_text_field_mask(tokens)

        # Forward pass
        embeddings = self.embedder(tokens)
        encoder_out = self.encoder(embeddings, mask)
        logits = self.linear(encoder_out)
        probs = torch.softmax(logits, dim=-1)

        # In AllenNLP, the output of forward() is a dictionary.
        # Your output dictionary must contain a "loss" key for your model to be trained.
        output = {"logits": logits, "cls_emb": encoder_out, "probs": probs}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_measure(logits, label)
            output["loss"] = self.loss_function(logits, label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy together with the positive-label F1 metrics.

        Args:
            reset: Passed through to each metric's ``get_metric``.
        """
        # NOTE(review): assumes F1Measure.get_metric returns a dict
        # (precision/recall/f1), as in recent AllenNLP releases - confirm
        # against the installed version.
        return {'accuracy': self.accuracy.get_metric(reset),
                **self.f1_measure.get_metric(reset)}
def main():
    """Train the LSTM sentiment classifier on SST and print a prediction for one sentence.

    Relies on the module-level constants ``EMBEDDING_DIM`` and ``HIDDEN_DIM``
    for the embedding size and LSTM hidden size.
    """
    reader = StanfordSentimentTreeBankDatasetReader()

    # NOTE(review): both paths are empty in this copy of the file; they must point
    # at the SST train/dev tree files before this function can run.
    train_path = ''
    dev_path = ''

    sampler = BucketBatchSampler(batch_size=32, sorting_keys=["tokens"])
    train_data_loader = MultiProcessDataLoader(reader, train_path, batch_sampler=sampler)
    dev_data_loader = MultiProcessDataLoader(reader, dev_path, batch_sampler=sampler)

    # You can optionally specify the minimum count of tokens/labels.
    # `min_count={'tokens':3}` here means that any tokens that appear less than three times
    # will be ignored and not included in the vocabulary.
    vocab = Vocabulary.from_instances(chain(train_data_loader.iter_instances(),
                                            dev_data_loader.iter_instances()),
                                      min_count={'tokens': 3})

    # The loaders need the vocabulary to turn instances into index tensors
    # when batches are produced for the trainer.
    train_data_loader.index_with(vocab)
    dev_data_loader.index_with(vocab)

    # The embedding size has to match the input size of the LSTM below.
    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=EMBEDDING_DIM)

    # BasicTextFieldEmbedder takes a dict - we need an embedding just for tokens,
    # not for labels, which are used as-is as the "answer" of the sentence classification
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

    # Seq2VecEncoder is a neural network abstraction that takes a sequence of something
    # (usually a sequence of embedded word vectors), processes it, and returns a single
    # vector. Oftentimes this is an RNN-based architecture (e.g., LSTM or GRU), but
    # AllenNLP also supports CNNs and other simple architectures (for example,
    # just averaging over the input vectors).
    encoder = PytorchSeq2VecWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

    model = LstmClassifier(word_embeddings, encoder, vocab)

    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # NOTE(review): only the arguments that can be derived from this function are
    # passed; epochs, patience and device are left at the library defaults -
    # confirm against the intended training configuration.
    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optimizer,
        data_loader=train_data_loader,
        validation_data_loader=dev_data_loader)

    # Train before predicting; otherwise the prediction comes from random weights.
    trainer.train()

    predictor = SentenceClassifierPredictor(model, dataset_reader=reader)
    logits = predictor.predict('This is the best movie ever!')['logits']
    label_id = np.argmax(logits)

    print(model.vocab.get_token_from_index(label_id, 'labels'))
if __name__ == '__main__':