In [1]:
# Ask TensorFlow to enumerate the devices visible to this runtime; the returned
# list is displayed as the cell's output.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 3114889421047074357
 xla_global_id: -1]

In [2]:
import numpy as np

import tensorflow as tf

!pip install -q tensorflow-hub
!pip install -q tensorflow-datasets
import tensorflow_hub as hub
import tensorflow_datasets as tfds

# Report the library versions and execution mode this notebook runs with.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
# tf.config.list_physical_devices is the stable name for this query; the
# tf.config.experimental alias used previously is deprecated.
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")



Version:  2.12.0
Eager mode:  True
Hub version:  0.14.0
GPU is NOT AVAILABLE




In [3]:
# Carve the labelled IMDB training set into a 60% / 40% split: 15,000 examples
# for training and 10,000 for validation. The separate 25,000-example test
# split is loaded as-is.
(train_data,
 validation_data,
 test_data) = tfds.load(name="imdb_reviews",
                        split=('train[:60%]', 'train[60%:]', 'test'),
                        as_supervised=True)

In [4]:
def preprocess(X_batch, n_words=300):
    """Turn a batch of raw review strings into a dense matrix of word tokens.

    Each string is lower-cased, HTML line-break tags and every character
    outside ``a-z`` are replaced by a space, and the result is split on
    whitespace. Every row is then truncated or padded with ``b"<pad>"`` so
    that it holds exactly ``n_words`` tokens.

    Args:
        X_batch: string tensor of documents. The shape arithmetic below
            assumes it is rank 1 (one string per document).
        n_words: number of token columns in the returned tensor.

    Returns:
        A dense string tensor of shape ``[num docs, n_words]``.
    """
    # Build the target shape [num docs, n_words] dynamically: keep only the
    # batch dimension of the input shape, then add n_words as the second entry.
    batch_only = tf.shape(X_batch) * tf.constant([1, 0])
    target_shape = batch_only + tf.constant([0, n_words])

    lowered = tf.strings.lower(X_batch)
    without_breaks = tf.strings.regex_replace(lowered, b"<br\\s*/?>", b" ")
    letters_only = tf.strings.regex_replace(without_breaks, b"[^a-z]", b" ")
    tokens = tf.strings.split(letters_only)  # ragged: one row of words per document
    return tokens.to_tensor(shape=target_shape, default_value=b"<pad>")

# Two toy reviews to sanity-check the tokenizer: punctuation and apostrophes
# become separators, and each row is padded out to the default 300 tokens.
X_example = tf.constant([
    "It's a great, great movie! I loved it.",
    "It was terrible, run away!!!",
])
preprocess(X_example)

<tf.Tensor: shape=(2, 300), dtype=string, numpy=
array([[b'it', b's', b'a', b'great', b'great', b'movie', b'i', b'loved',
        b'it', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>',
        b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>', b'<pad>

In [5]:
from collections import Counter

def get_vocabulary(data_sample, max_size=1000):
    """Build a frequency-ordered vocabulary from a sample of raw documents.

    The sample is tokenized with ``preprocess`` and every token other than the
    padding marker is counted.

    Args:
        data_sample: batch of raw strings accepted by ``preprocess``.
        max_size: maximum number of real words to keep (the padding token is
            added on top of this).

    Returns:
        A list of byte strings: ``b"<pad>"`` first, then up to ``max_size``
        words ordered from most to least frequent.
    """
    token_rows = preprocess(data_sample).numpy()
    # Feeding the tokens in row-major order keeps first-seen order for words
    # with equal counts, which is what most_common() uses to break ties.
    frequencies = Counter(
        token
        for row in token_rows
        for token in row
        if token != b"<pad>"
    )
    most_frequent = [token for token, _ in frequencies.most_common(max_size)]
    return [b"<pad>"] + most_frequent

# Vocabulary of the two toy reviews: the padding token comes first, followed by
# the words from most to least frequent.
get_vocabulary(X_example)

[b'<pad>',
 b'it',
 b'great',
 b's',
 b'a',
 b'movie',
 b'i',
 b'loved',
 b'was',
 b'terrible',
 b'run',
 b'away']

In [6]:
class TextVectorization(tf.keras.layers.Layer):
    """Keras layer that maps raw strings to fixed-length sequences of word ids.

    The vocabulary is learned from a data sample by ``adapt()``. Id 0 is the
    ``b"<pad>"`` token (it is the first vocabulary entry), and the lookup table
    is created with ``n_oov_buckets`` additional buckets for words that are not
    in the vocabulary.

    Note: ``get_config()`` records only the constructor arguments. The learned
    vocabulary and lookup table are not part of the config, so a layer rebuilt
    from its config has to be adapted again before use.
    """

    def __init__(self, max_vocabulary_size=1000, n_oov_buckets=100, dtype=tf.string, **kwargs):
        """Store the vocabulary limits; no vocabulary exists until ``adapt()``.

        Args:
            max_vocabulary_size: maximum number of words kept by ``adapt()``
                (not counting the padding token).
            n_oov_buckets: number of out-of-vocabulary buckets in the table.
            dtype: dtype forwarded to the base Keras layer.
            **kwargs: remaining base-layer arguments (e.g. ``input_shape``).
        """
        super().__init__(dtype=dtype, **kwargs)
        self.max_vocabulary_size = max_vocabulary_size
        self.n_oov_buckets = n_oov_buckets
        # Populated by adapt(); None marks the layer as not yet adapted.
        self.vocab = None
        self.table = None

    def adapt(self, data_sample):
        """Learn the vocabulary from ``data_sample`` and build the lookup table."""
        self.vocab = get_vocabulary(data_sample, self.max_vocabulary_size)
        words = tf.constant(self.vocab)
        # Ids follow vocabulary order, so the padding token gets id 0.
        word_ids = tf.range(len(self.vocab), dtype=tf.int64)
        vocab_init = tf.lookup.KeyValueTensorInitializer(words, word_ids)
        self.table = tf.lookup.StaticVocabularyTable(vocab_init, self.n_oov_buckets)

    def call(self, inputs):
        """Tokenize ``inputs`` and replace every token by its integer id.

        Raises:
            RuntimeError: if the layer is used before ``adapt()`` was called.
        """
        if self.table is None:
            # Fail with an actionable message instead of an attribute error on
            # the missing lookup table.
            raise RuntimeError(
                "TextVectorization.adapt() must be called before the layer is used.")
        preprocessed_inputs = preprocess(inputs)
        return self.table.lookup(preprocessed_inputs)

    def get_config(self):
        """Return the base config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            'max_vocabulary_size': self.max_vocabulary_size,
            'n_oov_buckets': self.n_oov_buckets
        })
        return config

In [7]:
# Vocabulary budget for the real dataset: keep up to 100,000 distinct words and
# use 100 buckets for everything outside the vocabulary.
max_vocabulary_size = 100_000
n_oov_buckets = 100

text_vectorization = TextVectorization(
    max_vocabulary_size=max_vocabulary_size,
    n_oov_buckets=n_oov_buckets,
    input_shape=[],
)

# Pull a single batch of up to 15,000 training examples and learn the
# vocabulary from its text; the labels are unpacked but not used here.
train_examples_batch, train_labels_batch = next(
    iter(train_data.batch(15000))
)
text_vectorization.adapt(train_examples_batch)

In [8]:
class BagOfWords(tf.keras.layers.Layer):
    """Keras layer that turns sequences of token ids into per-token counts.

    For each row of ids the layer counts how often every id in
    ``[0, n_tokens)`` occurs and then drops the column for id 0, so the output
    has ``n_tokens - 1`` columns.
    """

    def __init__(self, n_tokens, dtype=tf.int32, **kwargs):
        """Create the layer.

        Args:
            n_tokens: one-hot depth, i.e. the number of distinct ids to count.
            dtype: dtype forwarded to the base Keras layer.
            **kwargs: remaining base-layer arguments.
        """
        # Forward the caller's dtype. It used to be accepted but replaced by a
        # hard-coded tf.int32, so explicit overrides (including the dtype that
        # get_config() records) were silently ignored.
        super().__init__(dtype=dtype, **kwargs)
        self.n_tokens = n_tokens

    def call(self, inputs):
        """Return the id counts of ``inputs`` without the id-0 column."""
        # One-hot encoding appends an axis of size n_tokens to the id tensor,
        # so memory grows with batch size * sequence length * n_tokens.
        one_hot = tf.one_hot(inputs, self.n_tokens)
        # Sum over the sequence axis to get counts, then drop column 0
        # (presumably the padding id -- the vocabulary built in this notebook
        # puts b"<pad>" first).
        return tf.reduce_sum(one_hot, axis=1)[:, 1:]

    def get_config(self):
        """Return the base config extended with ``n_tokens``."""
        config = super().get_config()
        config.update({
            'n_tokens': self.n_tokens
        })
        return config

In [9]:
# The one-hot depth covers the learned vocabulary plus the out-of-vocabulary
# buckets of the vectorizer.
n_tokens = n_oov_buckets + len(text_vectorization.vocab)
bag_of_words = BagOfWords(n_tokens=n_tokens)

In [None]:
# Pipeline: raw string -> word ids -> word counts -> one output unit.
# The input tensor is called `inputs` so that Python's built-in `input()` is
# not shadowed for the rest of the notebook session.
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
vectorized = text_vectorization(inputs)
bow = bag_of_words(vectorized)
# No activation on the final layer: the model emits a raw score (logit).
outputs = tf.keras.layers.Dense(1)(bow)
model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
model.summary()

In [11]:
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
# The loss is configured with from_logits=True, i.e. the model output is a raw
# score rather than a probability. The string metric 'accuracy' would threshold
# that score at 0.5; the decision boundary for a logit is 0, so the metric is
# given an explicit threshold. It keeps the name 'accuracy' so the history keys
# and model.metrics_names are unchanged.
model.compile(optimizer=opt,
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])
# Save weights only, and only when the callback's monitored quantity improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("text_bagofwords", save_weights_only=True, save_best_only=True)

In [None]:
# Train for 30 epochs on shuffled mini-batches of 128 reviews, validating on
# the held-out 40% after each epoch; the checkpoint callback saves weights
# along the way.
history = model.fit(
    train_data.shuffle(10000).batch(128),
    validation_data=validation_data.batch(128),
    epochs=30,
    callbacks=[checkpoint_cb],
    verbose=1,
)

In [None]:
# Restore the weights stored under the path the checkpoint callback writes to.
model.load_weights("text_bagofwords")

In [None]:
# Score the model on the test split, then print each reported value next to
# its metric name.
results = model.evaluate(test_data.batch(32), verbose=2)

for metric_name, metric_value in zip(model.metrics_names, results):
    print("%s: %.3f" % (metric_name, metric_value))