<a href="https://colab.research.google.com/github/lain13/keras_basic/blob/master/TextVectorization_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
import tensorflow as tf
import pickle
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras import utils
from tensorflow.keras import models
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization


In [32]:
text_dataset = tf.data.Dataset.from_tensor_slices([
                                                   "this is some clean text", 
                                                   "some more text", 
                                                   "even some more text"]) 
# Fit a TextVectorization layer
vectorizer = TextVectorization(max_tokens=10, output_mode='tf-idf',ngrams=None)    
vectorizer.adapt(text_dataset.batch(1024))

In [33]:
# Vector for word "this"
vectorizer(['this'])

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.        , 0.        , 0.        , 0.        , 0.91629076,
        0.        , 0.        , 0.        , 0.        , 0.        ]],
      dtype=float32)>

In [34]:
# Pickle the config and weights
pickle.dump({'config': vectorizer.get_config(),
             'weights': vectorizer.get_weights()}
            , open("tv_layer.pkl", "wb"))

print ("*"*10)

**********


In [35]:
# Later you can unpickle and use 
# `config` to create object and 
# `weights` to load the trained weights. 

from_disk = pickle.load(open("tv_layer.pkl", "rb"))
new_v = TextVectorization.from_config(from_disk['config'])
# You have to call `adapt` with some dummy data (BUG in Keras)
new_v.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
new_v.set_weights(from_disk['weights'])

# Lets see the Vector for word "this"
print (new_v(["this"]))

tf.Tensor(
[[0.         0.         0.         0.         0.91629076 0.
  0.         0.         0.         0.        ]], shape=(1, 10), dtype=float32)


In [8]:
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

data = [
    "The sky is blue.",
    "Grass is green.",
    "Hunter2 is my password.",
]


In [14]:
# Create vectorizer.
text_dataset = tf.data.Dataset.from_tensor_slices(data)
vectorizer = TextVectorization(
    max_tokens=100000, output_mode='tf-idf', ngrams=None, name='output_layer'
)
vectorizer.adapt(text_dataset.batch(1024))


In [15]:
input = layers.Input(shape=(1, ), dtype=tf.string, name='input_layer')
output = vectorizer(input)
model = models.Model(input, output)
model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_layer (InputLayer)     [(None, 1)]               0         
_________________________________________________________________
output_layer (TextVectorizat (None, 100000)            0         
Total params: 100,000
Trainable params: 0
Non-trainable params: 100,000
_________________________________________________________________


In [17]:
# Save.
filepath = "vectorizer-model"
model.save(filepath, save_format="tf")

INFO:tensorflow:Assets written to: vectorizer-model/assets


In [23]:
# Load.
loaded_model = tf.keras.models.load_model(filepath)
loaded_vectorizer = loaded_model.get_layer('input_layer')



In [21]:
dir(loaded_model)

['_TF_MODULE_IGNORED_PROPERTIES',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_activity_regularizer',
 '_add_trackable',
 '_add_variable_with_custom_getter',
 '_assert_compile_was_called',
 '_assert_weights_created',
 '_auto_track_sub_layers',
 '_autocast',
 '_autographed_call',
 '_base_model_initialized',
 '_build_input_shape',
 '_call_accepts_kwargs',
 '_call_arg_was_passed',
 '_call_fn_arg_defaults',
 '_call_fn_arg_positions',
 '_call_fn_args',
 '_call_full_argspec',
 '_callable_losses',
 '_cast_single_input',
 '_check_call_args',
 '_checkpoint_dependencies',
 '_clear_losses',
 '_compile_was_called',
 '_compiled_tr