Classification of AG News articles using the Universal Sentence Encoder (USE)

Dataset: https://www.tensorflow.org/datasets/catalog/ag_news_subset

Universal Sentence Encoder paper: https://arxiv.org/abs/1803.11175

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import random
# Fix the Python and TensorFlow global RNG seeds so shuffling and weight
# initialisation are repeatable across Restart & Run All.
seed = 54
for seed_everything_with in (random.seed, tf.random.set_seed):
    seed_everything_with(seed)

# Record library versions alongside the results (tf, tfds, hub in that order).
for library in (tf, tfds, hub):
    print(library.__version__)

2.6.4
4.3.0
0.12.0


In [2]:
tfds.disable_progress_bar()
train_full_ds, test_ds = tfds.load(
    'ag_news_subset',
    split=['train', 'test'],
    as_supervised=True,
)

# Hold out the last 20% of the official train split (in the order tfds reads
# it) as a validation set; the first 80% is used for training.
train_size = int(0.8 * train_full_ds.cardinality().numpy())
train_ds, val_ds = train_full_ds.take(train_size), train_full_ds.skip(train_size)

2022-12-15 16:42:03.542015: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata".


[1mDownloading and preparing dataset 11.24 MiB (download: 11.24 MiB, generated: 35.79 MiB, total: 47.03 MiB) to /root/tensorflow_datasets/ag_news_subset/1.0.0...[0m
[1mDataset ag_news_subset downloaded and prepared to /root/tensorflow_datasets/ag_news_subset/1.0.0. Subsequent calls will reuse this data.[0m


2022-12-15 16:42:53.770999: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [3]:
# Number of examples in the train / validation / test splits.
tuple(ds.cardinality().numpy() for ds in (train_ds, val_ds, test_ds))

(96000, 24000, 7600)

In [4]:
# Input pipelines: shuffle + batch the training data, batch the evaluation
# splits, and end every pipeline with prefetch() so the next batch is prepared
# while the model runs on the current one.
BATCH_SIZE = 64
SHUFFLE_BUFFER = 1024
AUTOTUNE = tf.data.AUTOTUNE

# NOTE: this cell rebinds train_ds / val_ds / test_ds in place, so re-running it
# without first re-running the loading cell would batch the data twice.
train_ds = (
    train_ds
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
# cache() goes BEFORE prefetch(): prefetch should be the last stage so it reads
# from the in-memory cache, instead of the cache wrapping the prefetch buffer.
val_ds = (
    val_ds
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(AUTOTUNE)
)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [5]:
# Pre-trained Universal Sentence Encoder: takes raw strings and emits a
# 512-d embedding per example (see the summary below). Its weights show up as
# non-trainable, so only the small dense head on top is trained.
hub_layer = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder/4",
    input_shape=[],
    dtype=tf.string,
)

model = tf.keras.Sequential([
    hub_layer,
    # two identical hidden layers
    *[tf.keras.layers.Dense(64, activation='swish') for _ in range(2)],
    # one probability per AG News class
    tf.keras.layers.Dense(4, activation='softmax'),
])
model.summary()


2022-12-15 16:43:31.023266: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer (KerasLayer)     (None, 512)               256797824 
_________________________________________________________________
dense (Dense)                (None, 64)                32832     
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 260       
Total params: 256,835,076
Trainable params: 37,252
Non-trainable params: 256,797,824
_________________________________________________________________


In [6]:
# Integer class labels + softmax outputs -> sparse categorical cross-entropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

# epochs=500 is only an upper bound: training stops once val_loss has not
# improved for 9 epochs, and the best weights seen are restored.
es = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=9,
    restore_best_weights=True,
    verbose=1,
)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=500,
    callbacks=[es],
    verbose=2,
)

Epoch 1/500


2022-12-15 16:44:05.322185: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


1500/1500 - 44s - loss: 0.3747 - acc: 0.8743 - val_loss: 0.3255 - val_acc: 0.8848
Epoch 2/500
1500/1500 - 32s - loss: 0.3180 - acc: 0.8877 - val_loss: 0.3168 - val_acc: 0.8869
Epoch 3/500
1500/1500 - 32s - loss: 0.3105 - acc: 0.8888 - val_loss: 0.3167 - val_acc: 0.8855
Epoch 4/500
1500/1500 - 32s - loss: 0.3042 - acc: 0.8904 - val_loss: 0.3118 - val_acc: 0.8868
Epoch 5/500
1500/1500 - 32s - loss: 0.2987 - acc: 0.8926 - val_loss: 0.3091 - val_acc: 0.8883
Epoch 6/500
1500/1500 - 32s - loss: 0.2940 - acc: 0.8943 - val_loss: 0.3116 - val_acc: 0.8883
Epoch 7/500
1500/1500 - 32s - loss: 0.2898 - acc: 0.8948 - val_loss: 0.3055 - val_acc: 0.8908
Epoch 8/500
1500/1500 - 32s - loss: 0.2861 - acc: 0.8964 - val_loss: 0.3034 - val_acc: 0.8914
Epoch 9/500
1500/1500 - 32s - loss: 0.2818 - acc: 0.8979 - val_loss: 0.2966 - val_acc: 0.8936
Epoch 10/500
1500/1500 - 32s - loss: 0.2777 - acc: 0.8990 - val_loss: 0.2963 - val_acc: 0.8938
Epoch 11/500
1500/1500 - 32s - loss: 0.2734 - acc: 0.9005 - val_loss: 0

In [7]:
# Held-out test set; returns [loss, acc] in the order given to compile().
model.evaluate(x=test_ds)



[0.30084723234176636, 0.898552656173706]