# Source: GitHub file view, branch master (1 contributor; 136 lines, 107 sloc, 5.77 KB).
from typing import Dict

import numpy as np
import torch
import torch.optim as optim
from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \
    StanfordSentimentTreeBankDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy, F1Measure
from allennlp.training.trainer import Trainer

from realworldnlp.predictors import SentenceClassifierPredictor
# Model in AllenNLP represents a model that is trained.
class LstmClassifier(Model):
    """Sentence classifier: embed tokens, encode them to one vector, project to labels.

    The encoded sentence vector is fed through a single linear layer that
    produces one logit per entry in the vocabulary's ``'labels'`` namespace.
    Accuracy and precision/recall/F1 (for ``positive_label``) are tracked.
    """

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary,
                 positive_label: int = 4) -> None:
        """Build the classifier.

        Args:
            word_embeddings: converts word IDs to their vector representations.
            encoder: reduces a sequence of vectors to a single vector; its
                ``get_output_dim()`` sizes the input of the output layer.
            vocab: vocabulary; the size of its ``'labels'`` namespace sets the
                number of output logits.
            positive_label: label index that F1Measure treats as the positive
                class when computing precision/recall/F1.
        """
        # Model is a torch.nn.Module: its __init__ must run before any submodule
        # is assigned below, or PyTorch raises AttributeError. It also records
        # ``vocab`` so that ``model.vocab`` is available to callers.
        super().__init__(vocab)
        # We need the embeddings to convert word IDs to their vector representations
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # After converting a sequence of vectors to a single vector, we feed it into
        # a fully-connected linear layer to reduce the dimension to the total number of labels.
        self.linear = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                      out_features=vocab.get_vocab_size('labels'))
        # Monitor the metrics - we use accuracy, as well as prec, rec, f1 for 4 (very positive)
        self.accuracy = CategoricalAccuracy()
        self.f1_measure = F1Measure(positive_label)
        # We use the cross entropy loss because this is a classification task.
        # Note that PyTorch's CrossEntropyLoss combines softmax and log likelihood loss,
        # which makes it unnecessary to add a separate softmax layer.
        self.loss_function = torch.nn.CrossEntropyLoss()

    # Instances are fed to forward after batching.
    # Fields are passed through arguments with the same name.
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Compute label logits and, when ``label`` is given, the training loss.

        Args:
            tokens: the batched ``tokens`` text field (a dict of index tensors
                keyed by token indexer; exact shapes depend on the indexer).
            label: gold label indices for the batch, or ``None`` at prediction time.

        Returns:
            A dict with ``"logits"`` always, plus ``"loss"`` when ``label`` is given.
            Metrics are updated as a side effect whenever ``label`` is given.
        """
        # In deep NLP, when sequences of tensors in different lengths are batched together,
        # shorter sequences get padded with zeros to make them equal length.
        # Masking is the process to ignore extra zeros added by padding
        mask = get_text_field_mask(tokens)
        # Forward pass
        embeddings = self.word_embeddings(tokens)
        encoder_out = self.encoder(embeddings, mask)
        logits = self.linear(encoder_out)
        # In AllenNLP, the output of forward() is a dictionary.
        # Your output dictionary must contain a "loss" key for your model to be trained.
        output = {"logits": logits}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_measure(logits, label)
            output["loss"] = self.loss_function(logits, label)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return accuracy, precision, recall and F1, optionally resetting the metrics."""
        precision, recall, f1_measure = self.f1_measure.get_metric(reset)
        return {'accuracy': self.accuracy.get_metric(reset),
                'precision': precision,
                'recall': recall,
                'f1_measure': f1_measure}
def main(embedding_dim: int = 128,
         hidden_dim: int = 128,
         num_epochs: int = 20,
         patience: int = 10) -> None:
    """Train an LSTM sentiment classifier on the Stanford Sentiment Treebank.

    After training, prints the predicted label for one example sentence.

    Args:
        embedding_dim: size of each word-embedding vector (LSTM input size).
        hidden_dim: size of the LSTM hidden state (encoder output size).
        num_epochs: maximum number of training epochs.
        patience: epochs without validation improvement before stopping early.
    """
    reader = StanfordSentimentTreeBankDatasetReader()

    train_dataset = reader.read('data/stanfordSentimentTreebank/trees/train.txt')
    dev_dataset = reader.read('data/stanfordSentimentTreebank/trees/dev.txt')

    # You can optionally specify the minimum count of tokens/labels.
    # `min_count={'tokens':3}` here means that any tokens that appear less than three times
    # will be ignored and not included in the vocabulary.
    vocab = Vocabulary.from_instances(train_dataset + dev_dataset,
                                      min_count={'tokens': 3})

    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=embedding_dim)
    # BasicTextFieldEmbedder takes a dict - we need an embedding just for tokens,
    # not for labels, which are used as-is as the "answer" of the sentence classification
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

    # Seq2VecEncoder is a neural network abstraction that takes a sequence of something
    # (usually a sequence of embedded word vectors), processes it, and returns a single
    # vector. Oftentimes this is an RNN-based architecture (e.g., LSTM or GRU), but
    # AllenNLP also supports CNNs and other simple architectures (for example,
    # just averaging over the input vectors).
    encoder = PytorchSeq2VecWrapper(
        torch.nn.LSTM(embedding_dim, hidden_dim, batch_first=True))

    model = LstmClassifier(word_embeddings, encoder, vocab)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    iterator = BucketIterator(batch_size=32, sorting_keys=[("tokens", "num_tokens")])
    # The iterator converts tokens to IDs while batching, so it must be given
    # the vocabulary before the trainer asks it for batches.
    iterator.index_with(vocab)

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=train_dataset,
                      validation_dataset=dev_dataset,
                      patience=patience,
                      num_epochs=num_epochs)
    trainer.train()

    predictor = SentenceClassifierPredictor(model, dataset_reader=reader)
    logits = predictor.predict('This is the best movie ever!')['logits']
    # np.argmax yields a NumPy integer; use a plain int as the vocabulary index.
    label_id = int(np.argmax(logits))
    print(model.vocab.get_token_from_index(label_id, 'labels'))
# Run training and the demo prediction only when executed as a script.
if __name__ == '__main__':
    main()