In [1]:
import datetime
import io
import re
import string

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# --- Data loading ---
TRAIN_DIR = 'aclImdb/train'  # directory with one subfolder per class label
SEED = 42                    # shared by both subsets so the split is consistent
VAL_SPLIT = 0.2              # fraction of TRAIN_DIR held out for validation
BATCH_SIZE = 1024

# --- Text vectorization ---
VOCAB_SIZE = 10_000          # max_tokens of the vectorizer / input_dim of the Embedding
MAX_SEQUENCE_LENGTH = 100    # token ids emitted per review (output_sequence_length)

# --- Model ---
EMBEDDING_DIM = 16

# --- tf.data pipeline tuning ---
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [3]:
def load_split(subset):
    """Load one subset of TRAIN_DIR as a cached, prefetched tf.data pipeline.

    Args:
        subset: 'training' or 'validation'; selects which side of the
            VAL_SPLIT partition is returned.

    Returns:
        A batched dataset of (text, label) pairs with `.cache()` and
        `.prefetch(AUTOTUNE)` applied.
    """
    # Both subsets are loaded with the same validation_split and seed; the
    # Keras docs require a shared seed so the two subsets do not overlap.
    dataset = keras.preprocessing.text_dataset_from_directory(
        TRAIN_DIR,
        batch_size=BATCH_SIZE,
        validation_split=VAL_SPLIT,
        subset=subset,
        seed=SEED,
    )
    return dataset.cache().prefetch(buffer_size=AUTOTUNE)


train_ds = load_split('training')
val_ds = load_split('validation')

Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.


In [4]:
def custom_standardization(input_data):
    """Normalize raw review text before tokenization.

    Lowercases the input, replaces every literal '<br />' tag with a single
    space, and then deletes each character found in `string.punctuation`.

    Args:
        input_data: string tensor of raw text.

    Returns:
        A string tensor with the same shape as `input_data`.
    """
    # Character class matching any single punctuation character; re.escape
    # keeps characters such as ']' and '\\' from breaking the class.
    punctuation_class = '[{}]'.format(re.escape(string.punctuation))

    text = tf.strings.lower(input_data)
    text = tf.strings.regex_replace(text, '<br />', ' ')
    text = tf.strings.regex_replace(text, punctuation_class, '')
    return text

# Map raw review strings to fixed-length sequences of integer token ids.
# The vocabulary is capped at VOCAB_SIZE entries; id 0 is the padding slot
# (which is why the export cell below skips index 0).
vectorize_layer = keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=VOCAB_SIZE,
    standardize=custom_standardization,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH,
)

# Build the vocabulary from the training text only (labels dropped), so the
# validation subset does not influence it.
text_ds = train_ds.map(lambda text, _label: text)
vectorize_layer.adapt(text_ds)

In [5]:
# Bag-of-embeddings classifier: raw strings in, one unnormalized logit out.
model = keras.Sequential()
model.add(vectorize_layer)  # strings -> int token ids
# Named so the trained weights can be retrieved with get_layer('embedding').
model.add(layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM, name='embedding'))
model.add(layers.GlobalAveragePooling1D())  # average over the sequence axis
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1))  # no activation: output is a raw logit

# from_logits=True matches the activation-free final Dense layer above.
model.compile(
    optimizer='adam',
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [6]:
# Give each training run its own timestamped subdirectory under `logs/`, so
# re-running the notebook does not write new event files into the previous
# run's directory. `--logdir logs` still finds every run, listed separately.
LOG_DIR = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)

In [7]:
# Train for a fixed 15 epochs, evaluating on the held-out subset after each
# epoch and streaming metrics to the TensorBoard callback.
history = model.fit(
    x=train_ds,
    epochs=15,
    validation_data=val_ds,
    callbacks=[tensorboard_callback],
)

Epoch 1/15
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


In [8]:
# Optional: uncomment to load the TensorBoard notebook extension.
# %load_ext tensorboard

In [9]:
# Optional: uncomment to browse the training logs under `logs/` inline.
# %tensorboard --logdir logs

In [10]:
# Vocabulary list (position i is the word for token id i) and the trained
# embedding matrix (row i is the vector for token id i).
vocab = vectorize_layer.get_vocabulary()
weights = model.get_layer(name='embedding').get_weights()[0]

In [11]:
# Export the embeddings as two line-aligned TSV files: each line of
# vectors.tsv holds one tab-separated vector, and the same line of
# metadata.tsv holds its word (presumably for the TensorBoard Embedding
# Projector; confirm the intended consumer).
# The `with` block closes and flushes both files even if a write raises,
# instead of leaking the handles as the manual close() calls would.
with open('vectors.tsv', 'w', encoding='utf-8') as out_v, \
        open('metadata.tsv', 'w', encoding='utf-8') as out_m:
    for index, word in enumerate(vocab):
        if index == 0:
            continue  # skip 0, it's padding.
        vec = weights[index]
        out_v.write('\t'.join([str(x) for x in vec]) + "\n")
        out_m.write(word + "\n")