Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
91 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import numpy as np | ||
|
||
from allennlp.models.archival import load_archive | ||
from allennlp.predictors.predictor import Predictor | ||
from lit_nlp import dev_server | ||
from lit_nlp import server_flags | ||
from lit_nlp.api import dataset as lit_dataset | ||
from lit_nlp.api import model as lit_model | ||
from lit_nlp.api import types as lit_types | ||
|
||
from examples.sentiment.sst_classifier import LstmClassifier | ||
from examples.sentiment.sst_reader import StanfordSentimentTreeBankDatasetReaderWithTokenizer | ||
|
||
|
||
class SSTData(lit_dataset.Dataset): | ||
"""Stanford Sentiment Treebank, binary version (SST-2). | ||
See https://www.tensorflow.org/datasets/catalog/glue#gluesst2. | ||
""" | ||
|
||
def __init__(self, labels): | ||
self._labels = labels | ||
self._examples = [ | ||
{'sentence': 'This is the best movie ever!!!', 'label': '4'}, | ||
{'sentence': 'A good movie.', 'label': '3'}, | ||
{'sentence': 'A mediocre movie.', 'label': '1'}, | ||
{'sentence': 'It was such an awful movie...', 'label': '0'} | ||
] | ||
|
||
def spec(self): | ||
return { | ||
'sentence': lit_types.TextSegment(), | ||
'label': lit_types.CategoryLabel(vocab=self._labels) | ||
} | ||
|
||
|
||
class SentimentClassifierModel(lit_model.Model): | ||
def __init__(self): | ||
cuda_device = 0 | ||
archive_file = 'model/model.tar.gz' | ||
predictor_name = 'sentence_classifier_predictor' | ||
|
||
archive = load_archive( | ||
archive_file=archive_file, | ||
cuda_device=cuda_device | ||
) | ||
|
||
predictor = Predictor.from_archive(archive, predictor_name=predictor_name) | ||
|
||
self.predictor = predictor | ||
label_map = archive.model.vocab.get_index_to_token_vocabulary('labels') | ||
self.labels = [label for _, label in sorted(label_map.items())] | ||
|
||
def predict_minibatch(self, inputs): | ||
for inst in inputs: | ||
pred = self.predictor.predict(inst['sentence']) | ||
tokens = self.predictor._tokenizer.tokenize(inst['sentence']) | ||
yield { | ||
'tokens': tokens, | ||
'probas': np.array(pred['probs']), | ||
'cls_emb': np.array(pred['cls_emb']) | ||
} | ||
|
||
|
||
def input_spec(self): | ||
return { | ||
"sentence": lit_types.TextSegment(), | ||
"label": lit_types.CategoryLabel(vocab=self.labels, required=False) | ||
} | ||
|
||
def output_spec(self): | ||
return { | ||
"tokens": lit_types.Tokens(), | ||
"probas": lit_types.MulticlassPreds(parent="label", vocab=self.labels), | ||
"cls_emb": lit_types.Embeddings() | ||
} | ||
|
||
|
||
def main(): | ||
model = SentimentClassifierModel() | ||
models = {"sst": model} | ||
datasets = {"sst": SSTData(labels=model.labels)} | ||
|
||
lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags()) | ||
lit_demo.serve() | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters