# Prepare NLTK movie-review training data for SeqGAN

In [1]:
import json
import nltk
import numpy as np
import pandas as pd

from collections import Counter
from nltk.corpus import movie_reviews

In [2]:
# Params
# Output files written at the end of the notebook.
TRAIN_DATA_PATH = 'gan_train.txt'   # token-id matrix, one review per line
VOCAB_PATH = 'vocabulary.json'      # token -> id mapping (JSON)

# Data shaping.
MAX_SENTENCE_LENGTH = 16  # tokens kept per review (padded or cropped)
VOCAB_SIZE = 5000         # intended vocabulary size (most frequent words + '<unk>')
TRAINING_EXAMPLES = 2000  # number of reviews taken from the corpus

In [3]:
# Download the NLTK movie-review corpus (reports "already up-to-date" if present).
nltk.download('movie_reviews')

# One row per (file, category) pair: the lazily loaded token sequence and its label.
review_data = [
    (movie_reviews.words(file_id), category)
    for file_id in movie_reviews.fileids()
    for category in movie_reviews.categories(file_id)
]

# Load dataset.
# .copy() detaches the slice from `reviews_all`, so later cells that add columns
# (e.g. 'processed', 'vector') modify an independent frame rather than a view.
# NOTE(review): fileids() appears to be grouped by label (the tail below is all
# 'pos'); if TRAINING_EXAMPLES is lowered, consider a seeded shuffle first so the
# subset is not single-label.
reviews_all = pd.DataFrame(review_data, columns=['review', 'sentiment'])
df = reviews_all.iloc[:TRAINING_EXAMPLES].copy()
df.tail(3)

[nltk_data] Downloading package movie_reviews to /root/nltk_data...
[nltk_data]   Package movie_reviews is already up-to-date!


Unnamed: 0,review,sentimnet
1997,"(glory, --, starring, matthew, broderick, ,, d...",pos
1998,"(steven, spielberg, ', s, second, epic, film, ...",pos
1999,"(truman, (, "", true, -, man, "", ), burbank, is...",pos


In [4]:
def preprocess(sentence, max_length, pad_value=0):
    """Pad or crop a token sequence to exactly ``max_length`` items.

    Args:
        sentence: Sliceable sequence of tokens (list, tuple, or an NLTK corpus
            view). It is never modified.
        max_length: Target length; must be non-negative.
        pad_value: Item appended when ``sentence`` is shorter than
            ``max_length``. Defaults to 0, the padding value used before.

    Returns:
        A new list containing exactly ``max_length`` items.

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError('max_length must be non-negative, got {}'.format(max_length))
    # Crop, then copy into a fresh list, so the caller's sequence (e.g. an entry
    # of df['review']) is never extended in place.
    tokens = list(sentence[:max_length])
    # Pad up to the requested length (the parameter, not a global constant).
    tokens.extend([pad_value] * (max_length - len(tokens)))
    return tokens

# Fit every review to a fixed window of MAX_SENTENCE_LENGTH tokens;
# Series.apply forwards `args` as extra positional arguments to preprocess.
df['processed'] = df['review'].apply(preprocess, args=(MAX_SENTENCE_LENGTH,))

In [5]:
# Inspect the last rows, now including the fixed-length 'processed' column.
df.tail(n=5)

Unnamed: 0,review,sentimnet,processed
1995,"(wow, !, what, a, movie, ., it, ', s, everythi...",pos,"[wow, !, what, a, movie, ., it, ', s, everythi..."
1996,"(richard, gere, can, be, a, commanding, actor,...",pos,"[richard, gere, can, be, a, commanding, actor,..."
1997,"(glory, --, starring, matthew, broderick, ,, d...",pos,"[glory, --, starring, matthew, broderick, ,, d..."
1998,"(steven, spielberg, ', s, second, epic, film, ...",pos,"[steven, spielberg, ', s, second, epic, film, ..."
1999,"(truman, (, "", true, -, man, "", ), burbank, is...",pos,"[truman, (, "", true, -, man, "", ), burbank, is..."


In [6]:
# Create vocabulary: id 0 is reserved for '<unk>' (out-of-vocabulary tokens);
# ids 1..VOCAB_SIZE-1 go to the most frequent tokens, so the vocabulary has
# VOCAB_SIZE entries in total.
UNK_TOKEN = '<unk>'
all_tokens = [
    t
    for sent in df['processed'].tolist()
    for t in sent
    # Count real word tokens only: non-string padding items and a literal
    # '<unk>' must not receive their own id.
    if isinstance(t, str) and t != UNK_TOKEN
]
counted = Counter(all_tokens).most_common(VOCAB_SIZE - 1)
vocabulary = {word: idx for idx, (word, _count) in enumerate(counted, start=1)}
vocabulary[UNK_TOKEN] = 0
inverse_vocabulary = {v: k for k, v in vocabulary.items()}

In [7]:
# Convert tokens to ids; tokens missing from the vocabulary map to the '<unk>' id.
unk_id = vocabulary['<unk>']
df['vector'] = df['processed'].apply(lambda tokens: [vocabulary.get(w, unk_id) for w in tokens])
train_set = np.array(df['vector'].tolist(), dtype=np.int64)
# Every row must hold exactly MAX_SENTENCE_LENGTH ids; ragged rows would not
# form the 2-D integer matrix that is written to TRAIN_DATA_PATH.
assert train_set.shape == (len(df), MAX_SENTENCE_LENGTH), train_set.shape
train_set.shape

(2000, 16)

In [8]:
# Persist the id matrix as space-separated integers, one row per review.
np.savetxt(TRAIN_DATA_PATH, train_set, fmt='%d')

# Persist the token -> id mapping as a single JSON object.
with open(VOCAB_PATH, 'w') as vocab_file:
    vocab_file.write(json.dumps(vocabulary))

In [9]:
!head -5 gan_train.txt

67 16 71 653 2342 306 10 3 1454 552 2 2343 13 468 825 7
1 169 240 5 12 469 20 44 654 15 2344 1455 7 17 5 12
17 8 43 53 228 15 170 3 2345 20 553 2346 23 1 2347 4
6 1033 23 2348 6 8 331 554 7 5 48 171 14 655 2 1034
81 16 3 1456 2349 113 2350 2351 1457 3 202 39 3 2352 2353 1458
