Skip to content

Commit

Permalink
[MRG] Sentiment analysis (#490)
Browse files Browse the repository at this point in the history
* sentiment_analysis

* create mkdocs for sentiment analysis example

* quick changes
  • Loading branch information
satyakesav authored and haifeng-jin committed Jan 27, 2019
1 parent fcb45df commit 97c02dc
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 19 deletions.
Expand Up @@ -7,7 +7,7 @@
from autokeras.pretrained.base import Pretrained
from autokeras.text.pretrained_bert.tokenization import BertTokenizer
from autokeras.text.pretrained_bert.modeling import BertForSequenceClassification
from autokeras.utils import download_file_from_google_drive
from autokeras.utils import download_file_from_google_drive, get_device
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler


Expand Down Expand Up @@ -53,19 +53,19 @@ def convert_examples_to_features(examples, max_seq_length, tokenizer):
return features


class TextSentiment(Pretrained):
class SentimentAnalysis(Pretrained):

def __init__(self):

super(TextSentiment, self).__init__()
super(SentimentAnalysis, self).__init__()
self.device = None
self.tokenizer = None
self.model = None
self.load()

def load(self):

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = get_device()
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

output_model_file = os.path.join(tempfile.gettempdir(), 'text_sentiment_pytorch_model.bin')
Expand All @@ -90,6 +90,7 @@ def predict(self, x_predict):
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)

self.model.eval()
sentence_polarity = None
for input_ids, input_mask, segment_ids in eval_dataloader:
input_ids = input_ids.to(self.device)
input_mask = input_mask.to(self.device)
Expand Down
Empty file.
@@ -1,8 +1,8 @@
from autokeras.pretrained.text_sentiment import TextSentiment
from autokeras.pretrained.sentiment_analysis import SentimentAnalysis

text_cls = TextSentiment()
sentiment_cls = SentimentAnalysis()

polarity = text_cls.predict("The model is working well..")
polarity = sentiment_cls.predict("The model is working well..")

print("Polarity of the input sentence is (sentiment is +ve if polarity > 0.5) : ", polarity)

40 changes: 30 additions & 10 deletions mkdocs/docs/start.md
Expand Up @@ -362,6 +362,7 @@ for a given tabular dataset. (Currently, theis module only supports lightgbm cla


#### by Wuyang Chen from [Dr. Atlas Wang's group](http://www.atlaswang.com/) at CSE Department, Texas A&M.
class_id_mapping = {0 : "Business", 1 : "Sci/Tech", 2 : "Sports", 3 : "World"}

`ObjectDetector` in `object_detector.py` is a child class of `Pretrained`. Currently it can load a pretrained SSD model ([Liu, Wei, et al. "Ssd: Single shot multibox detector." European conference on computer vision. Springer, Cham, 2016.](https://arxiv.org/abs/1512.02325)) and find object(s) in a given image.

Expand All @@ -384,16 +385,35 @@ Finally you can make predictions against an image:
Function ```detector.predict()``` requires the path to the image. If the ```output_file_path``` is not given, the ```detector``` will just return the numerical results as a list of dictionaries. Each dictionary is like {"left": int, "top": int, "width": int, "height": int: "category": str, "confidence": float}, where ```left``` and ```top``` is the (left, top) coordinates of the bounding box of the object and ```width``` and ```height``` are width and height of the box. ```category``` is a string representing the class the object belongs to, and the confidence can be regarded as the probability that the model believes its prediction is correct. If the ```output_file_path``` is given, then the results mentioned above will be plotted and saved in a new image file with suffix "_prediction" into the given ```output_file_path```. If you run the example/object_detection/object_detection_example.py, you will get result
```[{'category': 'person', 'width': 331, 'height': 500, 'left': 17, 'confidence': 0.9741123914718628, 'top': 0}]```














### Sentiment Analysis tutorial. [[source]]( https://github.com/jhfjhfj1/autokeras/blob/master/autokeras/pretrained/text_sentiment.py)


The sentiment analysis module provides an interface to find the sentiment of any text input. The pretrained model is obtained by training [Google AI’s BERT model]( https://arxiv.org/abs/1810.04805) on [IMDb dataset]( http://ai.stanford.edu/~amaas/data/sentiment/).

Let’s import the `SentimentAnalysis` module from *sentiment_analysis.py*. It is derived from the super class `Pretrained`.
```python
from autokeras.pretrained.text_sentiment import SentimentAnalysis
sentiment_cls = SentimentAnalysis()
```
During initialization of `SentimentAnalysis`, the pretrained model is loaded into memory i.e. CPU’s or GPU’s, if available.

Now, you may directly call the `predict` function in `SentimentAnalysis` class on any input sentence provided as a string as shown below. The function returns a value between 0 and 1.
```python
polarity = sentiment_cls.predict("The model is working well..")
```
**Note:** If the output value of the `predict` function is close to 0, it implies the statement has negative sentiment, whereas value close to 1 implies positive sentiment.

If you run *sentiment_analysis_example.py*, you should get an output value of 0.9 which implies that the input statement *The model is working well..* has strong positive sentiment.








<!-- [Data with numpy array (.npy) format.]: https://github.com/jhfjhfj1/autokeras/blob/master/examples/a_simple_example/mnist.py
[What if your data are raw image files (*e.g.* .jpg, .png, .bmp)?]: https://github.com/jhfjhfj1/autokeras/blob/master/examples/a_simple_example/load_raw_image.py
Expand Down
4 changes: 2 additions & 2 deletions tests/pretrained/test_sentiment_analysis.py
@@ -1,7 +1,7 @@
from autokeras.pretrained.text_sentiment import TextSentiment
from autokeras.pretrained.sentiment_analysis import SentimentAnalysis

def test_sentiment_analysis():
sentiment_analyzer = TextSentiment()
sentiment_analyzer = SentimentAnalysis()

positive_polarity = sentiment_analyzer.predict("The model is working really well.")
if positive_polarity <= 0.5:
Expand Down

0 comments on commit 97c02dc

Please sign in to comment.