<a href="https://colab.research.google.com/github/h4ck4l1/datasets/blob/main/NLP_with_RNN_and_Attention/pretrained_embeddings_using_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install -q -U tensorflow-text
!pip3 install -q -U tf-models-official
import tensorflow as tf
import numpy as np
import os
from tensorflow import keras
import tensorflow_datasets as tfds
import plotly.graph_objects as go
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization
# Colab-only step: authenticate the current user for this runtime.
# NOTE(review): presumably required so cloud-hosted resources can be read
# from the TPU runtime -- confirm whether it is still needed.
from google.colab import auth
auth.authenticate_user()
# Wipe whatever the install / auth steps above printed so the cell ends clean.
from IPython.display import clear_output
clear_output()

In [None]:
# Ask TF Hub to load models in their uncompressed form.
# NOTE(review): per the TF Hub caching docs this makes models get read straight
# from remote storage instead of a local download cache -- confirm this is the
# reason it is set here (TPU workers cannot see the Colab filesystem).
os.environ.update({"TFHUB_MODEL_LOAD_FORMAT": "UNCOMPRESSED"})

In [None]:
# TPU bring-up. The four steps below must run in exactly this order.
# 1. Locate the TPU cluster (no arguments: relies on the resolver's defaults).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
# 2. Connect this runtime to the resolved cluster.
tf.config.experimental_connect_to_cluster(resolver)
# 3. Initialise the TPU system behind that resolver.
tf.tpu.experimental.initialize_tpu_system(resolver)
# 4. Build the distribution strategy; later cells create the model under strategy.scope().
strategy = tf.distribute.TPUStrategy(resolver)



In [None]:
# TensorFlow Datasets name of the corpus this notebook loads and trains on.
tfds_name = "imdb_reviews"

In [None]:
# Load the dataset on the local host rather than on a TPU worker.
with tf.device("/job:localhost"):
    # batch_size=-1 asks tfds for each split as one full in-memory batch, which
    # is what tf.data.Dataset.from_tensor_slices consumes later on.
    # with_info=True returns the DatasetInfo that belongs to the data just
    # loaded, so no second builder has to be constructed just for metadata.
    in_memory_ds,ds_info = tfds.load(
        tfds_name,
        batch_size=-1,
        shuffle_files=True,
        with_info=True,
    )

In [None]:
def bert_preprocessing_model(seq_length=512):
    """Build a Keras model that turns raw strings into packed BERT inputs.

    The model takes a batch of scalar strings, tokenizes them with the
    TF Hub `bert_en_uncased_preprocess/3` module and packs the single
    tokenized segment via the module's `bert_pack_inputs`.

    Args:
        seq_length: Value forwarded to `bert_pack_inputs` as `seq_length`.

    Returns:
        A `keras.Model` mapping a string tensor to the packer's output.
    """
    preprocessor = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

    raw_text = keras.layers.Input(shape=(),dtype=tf.string)

    tokenizer_layer = hub.KerasLayer(preprocessor.tokenize,name="tokenizer")
    packer_layer = hub.KerasLayer(
        preprocessor.bert_pack_inputs,
        arguments={"seq_length": seq_length},
        name="packer",
    )

    # The packer expects a list of tokenized segments; only one is supplied here.
    packed = packer_layer([tokenizer_layer(raw_text)])
    return keras.Model(inputs=raw_text,outputs=packed)

In [None]:
AUTO = tf.data.AUTOTUNE
def load_dataset_from_memory(in_memory_ds,split,ds_info,batch_size,bert_preprocessing):
    """Build a batched, BERT-preprocessed tf.data pipeline for one split.

    Args:
        in_memory_ds: Mapping from split name to that split's tensors; each
            split must provide 'text' and 'label' entries.
        split: Split name. Any name containing "train" is shuffled and
            repeated indefinitely.
        ds_info: Dataset info object; `ds_info.splits[split].num_examples`
            supplies the split size.
        batch_size: Number of examples per batch. Incomplete final batches
            are dropped so every batch has the same size.
        bert_preprocessing: Callable applied to each batch of raw text.

    Returns:
        A `(dataset, num_examples)` tuple. Dataset elements are
        `(bert_preprocessing(text_batch), label_batch)` pairs.
    """
    ds = tf.data.Dataset.from_tensor_slices(in_memory_ds[split])
    ds_size = ds_info.splits[split].num_examples
    is_training = "train" in split
    if is_training:
        ds = ds.shuffle(ds_size//4)
        ds = ds.repeat()
    ds = ds.batch(batch_size,drop_remainder=True)
    ds = ds.map(
        lambda tfds_dict: (bert_preprocessing(tfds_dict['text']),tfds_dict['label']),
        num_parallel_calls=AUTO,
    )
    # Only cache finite pipelines. The training stream is repeated forever, so
    # an in-memory cache placed after repeat() would never finish filling: it
    # would just keep growing and would never actually be read back.
    if not is_training:
        ds = ds.cache()
    ds = ds.prefetch(AUTO)
    return ds,ds_size

In [None]:
class SentimentModel(keras.Model):
    """Binary classifier: a trainable small-BERT encoder plus a two-layer dense head.

    The encoder's 'pooled_output' feeds a 64-unit ReLU layer followed by a
    single sigmoid unit.
    """

    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        # trainable=True: the encoder weights are fine-tuned with the head.
        self.encoder = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2",
            trainable=True,
        )
        self.dense_1 = keras.layers.Dense(units=64,activation="relu")
        self.out = keras.layers.Dense(units=1,activation="sigmoid")

    def call(self,processed_text):
        """Run preprocessed BERT inputs through the encoder and the dense head."""
        encoder_outputs = self.encoder(processed_text)
        hidden = self.dense_1(encoder_outputs['pooled_output'])
        return self.out(hidden)


In [None]:
# Global batch size.
# NOTE(review): presumably 16 examples per replica x 8 TPU replicas -- confirm
# against strategy.num_replicas_in_sync.
BATCH_SIZE = 16*8
# Number of passes over the training split; shared by the LR schedule and fit().
EPOCHS = 10

bert_preprocessing = bert_preprocessing_model()

with strategy.scope():
    train_ds,train_size = load_dataset_from_memory(in_memory_ds,"train",ds_info,BATCH_SIZE,bert_preprocessing)
    valid_ds,valid_size = load_dataset_from_memory(in_memory_ds,"test",ds_info,BATCH_SIZE,bert_preprocessing)
    train_steps = train_size // BATCH_SIZE
    valid_steps = valid_size // BATCH_SIZE
    # The learning-rate schedule must span the whole run, not a single epoch.
    # Sized with train_steps alone, the decay finishes during epoch 1 and every
    # later epoch trains at the constant end_lr.
    total_train_steps = train_steps * EPOCHS
    model = SentimentModel()
    optimizer = optimization.create_optimizer(
        init_lr=2e-4,
        num_train_steps=total_train_steps,
        # Warm-up stays at one eighth of the schedule length.
        num_warmup_steps=total_train_steps//8,
        optimizer_type="adamw",
        end_lr=1e-4
    )
    model.compile(
        loss="binary_crossentropy",
        optimizer=optimizer,
        metrics=["accuracy"],
        steps_per_execution=10
    )

# The training dataset repeats forever, so steps_per_epoch defines an epoch.
history = model.fit(train_ds,epochs=EPOCHS,validation_data=valid_ds,steps_per_epoch=train_steps,validation_steps=valid_steps)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
# Training vs. validation loss, one point per epoch from the fit history.
fig = go.Figure(
    data=[
        go.Scatter(y=history.history['loss'],mode="lines",name="Train Loss"),
        go.Scatter(y=history.history['val_loss'],mode="lines",name="Validation Loss"),
    ]
)
fig.show()

In [None]:
# Training vs. validation accuracy, one point per epoch from the fit history.
fig = go.Figure(
    data=[
        go.Scatter(y=history.history['accuracy'],mode="lines",name="Train Accuracy"),
        go.Scatter(y=history.history['val_accuracy'],mode="lines",name="Validation Accuracy"),
    ]
)
fig.show()