In [1]:
import sys, os
import pandas as pd
import numpy as np

In [2]:
# Put the parent directory first on the import path so the local
# `textgenrnn` package is the one that gets imported.
sys.path.insert(0, os.pardir)

In [3]:
from textgenrnn.textgenrnn import textgenrnn

Using TensorFlow backend.


In [4]:
# Folder that holds the review CSV files. Set the TEXTGENRNN_DATA_FOLDER
# environment variable to point the notebook at another machine's data; the
# original hardcoded location stays as the fallback so existing setups keep
# working unchanged.
data_folder = os.environ.get(
    'TEXTGENRNN_DATA_FOLDER',
    '/media/student/8d1913cf-1155-47a5-a7db-b9a51f445d8f/student/data'
)

In [5]:
# Check that the data folder is reachable before trying to read from it.
os.path.exists(data_folder)

True

In [6]:
# List the entries of the data folder to see which files are available.
os.listdir(data_folder)

['yelp_dataset',
 'business.csv',
 'jigsaw-toxic-comment-classification-challenge.zip',
 'restaurants.csv',
 'restaurant_reviews.csv',
 'review.csv']

In [7]:
# Read only the star rating and the review text of the first 20,000 rows.
# NOTE(review): with index_col=0 the first selected column presumably becomes
# the index, and the following reset_index() cell turns it back into a
# regular column - confirm against the pandas version in use.
reviews_csv_path = data_folder + '/restaurant_reviews.csv'
reviews = pd.read_csv(
    reviews_csv_path,
    sep=',',
    usecols=['stars', 'text'],
    index_col=0,
    nrows=20000,
)

In [8]:
# Move the current index back into a regular column and fall back to a
# plain 0..n-1 integer index.
reviews = reviews.reset_index()

In [9]:
# Preview the first rows of the loaded reviews.
reviews.head()

Unnamed: 0,stars,text
0,5.0,Went in for a lunch. Steak sandwich was delici...
1,4.0,I'll be the first to admit that I was not exci...
2,3.0,Tracy dessert had a big name in Hong Kong and ...
3,1.0,This place has gone down hill. Clearly they h...
4,4.0,"Like walking back in time, every Saturday morn..."


In [10]:
# Raw review texts and star ratings of all loaded rows. Both names are
# reassigned further down, once the balanced and shuffled subset is built.
texts, labels = reviews.text.values, reviews.stars.values

In [11]:
# Lookup tables keyed by star rating; they are filled by the per-star loop
# that follows (review texts in one, star values in the other).
star2texts, star2labels = dict(), dict()

In [12]:
# Split the reviews by star rating and store, for every rating, the review
# texts and the matching star values as arrays.
for star in (1.0, 2.0, 3.0, 4.0, 5.0):
    reviews_current = reviews.loc[reviews['stars'] == star]
    star2texts[star] = reviews_current.text.values
    star2labels[star] = reviews_current.stars.values

In [13]:
# Report how many reviews were collected per star rating: the rating, the
# count, then an empty separator line.
for star in (1.0, 2.0, 3.0, 4.0, 5.0):
    print(star, len(star2texts[star]), '', sep='\n')

1.0
2259

2.0
1825

3.0
2680

4.0
5217

5.0
8019



In [14]:
# Maximum number of reviews taken per star rating when the balanced
# training subset is assembled.
num_texts = 800

In [15]:
# Assemble a class-balanced list of (text, star) pairs by taking at most
# `num_texts` reviews from each star rating, in their stored order.
texts_labels = []

for star in (1.0, 2.0, 3.0, 4.0, 5.0):
    texts_labels_current = list(
        zip(star2texts[star][:num_texts], star2labels[star][:num_texts])
    )
    texts_labels.extend(texts_labels_current)

In [16]:
# Sanity check: total number of (text, star) pairs collected.
len(texts_labels)

4000

In [17]:
# Shuffle the (text, star) pairs in place so the ratings are interleaved
# instead of grouped by star. A dedicated, seeded RandomState makes the
# resulting order reproducible across kernel restarts without altering
# NumPy's global random state.
np.random.RandomState(42).shuffle(texts_labels)

In [18]:
# Separate the shuffled pairs back into two parallel lists.
texts = [text for text, _ in texts_labels]
labels = [stars for _, stars in texts_labels]

In [19]:
# Inspect the first review text of the shuffled subset.
texts[0]

"When one thinks of sushi in Northeast Ohio, a place on the border of Parma next to a Subway and a convenience store is not going to come to mind. Thankfully though, the check cashing place is no longer there. As a result I had driven past it many a time without a second thought. A few months ago I figured that I might as well give it a shot.\r\n\r\nI'm glad I did, to put it lightly. The service has been attentive each time I have been there, and the decor is nice. What surprised me is that rarely have I seen anyone else eating there. Either I'm there at the wrong times or others have made the same assumption I did. I had miso soup a couple of times while they prepared the sushi, and it is just the right blend of ingredients. I'm a snob when it comes to miso so impressing me with that is a big deal.\r\n\r\nAs for the star, the sushi, it is made fresh with care and it shows. The $3 happy hour rolls are delicious, and the specialty rolls are even better. The Parma rolls in particular hav

In [20]:
# Star rating that belongs to the first review of the shuffled subset.
labels[0]

5.0

In [21]:
# Map each star rating onto a sentiment score between -1 and +1, with
# 3 stars as the neutral midpoint.
label2sentiment = dict(zip(
    (1.0, 2.0, 3.0, 4.0, 5.0),
    (-1, -0.5, 0, +0.5, +1),
))

In [22]:
# Translate every star label into its sentiment score; an unknown label
# raises KeyError.
sentiments = list(map(label2sentiment.__getitem__, labels))

In [23]:
# Show the first ten mapped sentiment values as a quick check.
print(sentiments[:10])

[1, 0, 0.5, -1, -0.5, 0.5, 0.5, -0.5, -1, 0]


In [24]:
# Settings forwarded to model.train_on_texts() in the training cell.
word_level = False
new_model = False
max_length = 40
num_epochs = 2
gen_epochs = 1
train_size = 0.9

# Model instance created with the library's default constructor arguments.
model = textgenrnn()

In [None]:
# Train the model on the review texts, conditioned on a sentiment value per
# text. The conditioning signal is `sentiments` (star ratings mapped into
# [-1, 1] via label2sentiment), not the raw star `labels` (1.0-5.0): the
# generation cell samples with sentiment values between -1 and +1, so
# training has to use that same scale.
model.train_on_texts(
    texts,
    sentiment_values=sentiments,
    word_level=word_level,
    num_epochs=num_epochs,
    gen_epochs=gen_epochs,
    max_length=max_length,
    new_model=new_model,
    train_size=train_size)

Training on 2,226,814 character sequences.
Epoch 1/2

In [None]:
# Save the trained model; no arguments are given, so save() runs with
# whatever defaults the library defines.
model.save()

In [None]:
# Generate one sample text per sentiment value, sweeping from the most
# negative to the most positive setting.
print('Samples')

sentiment_values = [-1, -0.8, -0.5, -0.1, 0, +0.1, +0.5, +0.8, +1]

for sentiment_value in sentiment_values:
    print('Sentiment:', sentiment_value)
    generated = model.generate(1, sentiment_value, return_as_list=True)
    print(generated[0], end='\n\n')