In [1]:
# Install Model maker
!pip install tflite-model-maker



In [3]:
# Imports and check that we are using TF2.x
import numpy as np
import os
import pandas as pd

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.text_classifier import DataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')

In [25]:
# Download the dataset as a CSV and store as data_file
import os
data_file = os.path.abspath('spam_dataset.csv')

In [24]:
data_file.head()

AttributeError: 'str' object has no attribute 'head'

In [26]:
# Use a model spec from model maker. Options are 'mobilebert_classifier', 'bert_classifier' and 'average_word_vec'
# The first 2 use the BERT model, which is accurate, but larger and slower to train
# Average Word Vec is kinda like transfer learning where there are pre-trained word weights
# and dictionaries
spec = model_spec.get('average_word_vec')
spec.num_words = 2000
spec.seq_len = 20
spec.wordvec_dim = 7

In [29]:
# Load the CSV using DataLoader.from_csv to make the training_data
data = DataLoader.from_csv(
      filename=data_file,
      text_column='text', 
      label_column='label', 
      model_spec=spec,
      delimiter=',',
      shuffle=True,
      is_training=True)

train_data, test_data = data.split(0.9)

In [30]:
# Build the model
model = text_classifier.create(train_data, model_spec=spec, epochs=50, validation_data=test_data)

Epoch 2/2
Epoch 3/3
Epoch 4/4
Epoch 5/5
Epoch 6/6
Epoch 7/7
Epoch 8/8
Epoch 9/9
Epoch 10/10
Epoch 11/11
Epoch 12/12
Epoch 13/13
Epoch 14/14
Epoch 15/15
Epoch 16/16
Epoch 17/17
Epoch 18/18
Epoch 19/19
Epoch 20/20
Epoch 21/21
Epoch 22/22
Epoch 23/23
Epoch 24/24
Epoch 25/25
Epoch 26/26
Epoch 27/27
Epoch 28/28
Epoch 29/29
Epoch 30/30
Epoch 31/31
Epoch 32/32
Epoch 33/33
Epoch 34/34
Epoch 35/35
Epoch 36/36
Epoch 37/37
Epoch 38/38
Epoch 39/39
Epoch 40/40
Epoch 41/41
Epoch 42/42
Epoch 43/43
Epoch 44/44
Epoch 45/45
Epoch 46/46
Epoch 47/47
Epoch 48/48
Epoch 49/49
Epoch 50/50


In [31]:
loss, accuracy = model.evaluate(train_data)



In [32]:
# This will export to TFLite format with the model only. 
# if you see a .json file in this directory, it is NOT the JSON model for TFJS
# See below for how to generate that.
# Please note that if you run this cell to create the tflite model then the 
# export to TFJS will fail. You'll need to rerun the model training first
model.export(export_dir='/mm_spam')
# If you want the labels and the vocab, for example for iOS, you can use this
model.export(export_dir='/mm_spam/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

# You can find your files in colab by clicking the 'folder' tab to the left of
# this code window, and then navigating 'up' a directory to find the root
# directory listing -- and from there you should see /mm_spam/

In [33]:
# Use this section for export to TFJS
# Please note that if you run the above cell to create the tflite model then the 
# export to TFJS will fail. You'll need to rerun the model training first
model.export(export_dir="/mm_js/", export_format=[ExportFormat.TFJS, ExportFormat.LABEL, ExportFormat.VOCAB])

AttributeError: 'TextClassifier' object has no attribute 'get_output_details'

In [34]:
# Optional extra
# You can use this cell to export details for projector.tensorflow.org
# Where you can explore the embeddings that were learned for this dataset
embeddings = model.model.layers[0]
weights = embeddings.get_weights()[0]
tokenizer = model.model_spec.vocab

import io

out_v = io.open('vecs.tsv', 'w', encoding='utf-8')
out_m = io.open('meta.tsv', 'w', encoding='utf-8')
for word in tokenizer:
  #word = tokenizer.decode([word_num])
  value = tokenizer[word]
  embeddings = weights[value]
  out_m.write(word + "\n")
  out_v.write('\t'.join([str(x) for x in embeddings]) + "\n")
out_v.close()
out_m.close()


try:
  from google.colab import files
except ImportError:
  pass
else:
  files.download('vecs.tsv')
  files.download('meta.tsv')

In [35]:
get_input_details()

NameError: name 'get_input_details' is not defined