In [1]:
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_hub as hub

**Load Dataset**

In [3]:
# Load Dataset
df_train = pd.read_csv('/content/train.csv')
df_test = pd.read_csv('/content/test.csv')

# Separate dataset into data and label
train_data = df_train.loc[:,'report'].values
train_label = df_train.loc[:,'label'].values
test_data = df_test.loc[:,'report'].values
test_label = df_test.loc[:,'label'].values

# Change label to categorical
train_label = tf.keras.utils.to_categorical(train_label,5)
test_label = tf.keras.utils.to_categorical(test_label,5)

**Load pretrained word embbedings**

In [4]:
# load pretrained word embeddings
model = "https://tfhub.dev/google/nnlm-id-dim128/2"
hub_layer = hub.KerasLayer(model, input_shape=[], dtype=tf.string, trainable=True)

**Sequence Model**

In [5]:
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dense(5, activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy', metrics=['accuracy']
              )
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer (KerasLayer)     (None, 128)               112461824 
_________________________________________________________________
dense (Dense)                (None, 32)                4128      
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 165       
Total params: 112,466,117
Trainable params: 112,466,117
Non-trainable params: 0
_________________________________________________________________


**Training**

In [7]:
num_epochs = 10
history = model.fit(train_data, train_label, epochs=num_epochs, validation_data=(test_data, test_label), verbose = 1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


**Save Model**

In [22]:
tf.saved_model.save(model, 'classification_model') #saving model


FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.



FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.


INFO:tensorflow:Assets written to: classification_model/assets


INFO:tensorflow:Assets written to: classification_model/assets


In [23]:
!zip -r /content/classification_model.zip /content/classification_model

  adding: content/classification_model/ (stored 0%)
  adding: content/classification_model/assets/ (stored 0%)
  adding: content/classification_model/assets/tokens.txt (deflated 51%)
  adding: content/classification_model/assets/.ipynb_checkpoints/ (stored 0%)
  adding: content/classification_model/.ipynb_checkpoints/ (stored 0%)
  adding: content/classification_model/saved_model.pb (deflated 85%)
  adding: content/classification_model/variables/ (stored 0%)
  adding: content/classification_model/variables/variables.index (deflated 58%)
  adding: content/classification_model/variables/.ipynb_checkpoints/ (stored 0%)
  adding: content/classification_model/variables/variables.data-00000-of-00001 (deflated 53%)


In [25]:
from google.colab import files
files.download("/content/classification_model.zip") # download file from google colab

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

**Converting to tflite**

In [24]:
converter = tf.lite.TFLiteConverter.from_saved_model('classification_model')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]

tflite_model = converter.convert()

In [26]:
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)