In [1]:
import io
import os
import re
import shutil
import string
import tensorflow as tf

In [2]:
# Download the IMDB movie-review archive (cached under ~/.keras/datasets by
# default) and extract it. untar=True extracts the archive so the reviews
# end up in an 'aclImdb' folder next to the returned path.
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
dataset = tf.keras.utils.get_file(
    "aclImdb_v1.tar.gz",
    url,
    untar=True,
)

In [3]:
# Locate the extracted IMDB data and drop the unlabeled 'unsup' reviews.
# text_dataset_from_directory infers one class per subdirectory, so leaving
# 'unsup' in place would add a spurious third class.
dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
train_dir = os.path.join(dataset_dir, 'train')
remove_dir = os.path.join(train_dir, 'unsup')
# Guard the removal so this cell is idempotent: once the archive is cached,
# get_file does not re-extract it, so 'unsup' is already gone on re-runs and
# an unconditional rmtree would raise FileNotFoundError.
if os.path.isdir(remove_dir):
  shutil.rmtree(remove_dir)

In [4]:
batch_size = 1024
seed = 123
# Both subsets must use the same seed and validation_split so that the
# 80/20 training/validation split is disjoint and reproducible.
split_kwargs = dict(batch_size=batch_size, validation_split=0.2, seed=seed)
train_ds = tf.keras.utils.text_dataset_from_directory(
    train_dir, subset='training', **split_kwargs)
val_ds = tf.keras.utils.text_dataset_from_directory(
    train_dir, subset='validation', **split_kwargs)
# Keep decoded batches in memory after the first pass and overlap input
# preparation with training.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.


In [5]:
def custom_standardization(input_data):
  """Normalize raw review text before tokenization.

  Lowercases the text, replaces literal '<br />' HTML line breaks with a
  space, then deletes every ASCII punctuation character.

  Args:
    input_data: string tensor of raw text.

  Returns:
    A string tensor of the same shape with the normalization applied.
  """
  # Character class matching any single punctuation character.
  punctuation_class = '[%s]' % re.escape(string.punctuation)
  lowered = tf.strings.lower(input_data)
  without_breaks = tf.strings.regex_replace(lowered, '<br />', ' ')
  return tf.strings.regex_replace(without_breaks, punctuation_class, '')

In [6]:
vocab_size = 10000
seque_len = 100  # every review is padded/truncated to this many tokens
embedding_dim = 16
# Use custom_standardization rather than the built-in
# 'lower_and_strip_punctuation'. The built-in strips '<', '/' and '>' from
# the '<br />' tags in IMDB reviews and leaves 'br' behind as a spurious
# token. The custom function replaces those tags with spaces first.
text_layer = tf.keras.layers.TextVectorization(
    max_tokens=vocab_size,
    standardize=custom_standardization,
    output_mode='int',
    output_sequence_length=seque_len)
# Build the vocabulary from the training text only (labels dropped).
text_ds = train_ds.map(lambda x, y: x)
text_layer.adapt(text_ds)
model = tf.keras.Sequential([
    text_layer,
    tf.keras.layers.Embedding(vocab_size, embedding_dim, name="embedding"),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(2)  # raw logits, one per class
])

In [7]:
# The final Dense(2) layer emits raw logits (no softmax), hence
# from_logits=True; the labels are integer class ids, hence the sparse loss.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
model.fit(train_ds, validation_data=val_ds, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f335d226d68>

In [17]:
# Export the learned embedding matrix as two parallel TSV files:
# vectors.tsv holds one tab-separated embedding row per token, and
# metadata.tsv holds the matching token on the same line.
weights = model.get_layer('embedding').get_weights()[0]
vocab = text_layer.get_vocabulary()
# Context managers guarantee both files are flushed and closed even if a
# write fails part-way through.
with io.open('vectors.tsv', 'w', encoding='utf-8') as out_v, \
     io.open('metadata.tsv', 'w', encoding='utf-8') as out_m:
  for index, word in enumerate(vocab):
    if index == 0:
      continue  # skip 0, it's padding.
    vec = weights[index]
    out_v.write('\t'.join([str(x) for x in vec]) + "\n")
    out_m.write(word + "\n")