In [1]:
# bring the cwd to the main folder
import os
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)

In [2]:
import tensorflow as tf
from tensorflow import saved_model
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

In [3]:
from tf_agents.environments import tf_py_environment

In [4]:
print(tf.__version__)

2.8.0


In [5]:
AGENT_ID  = 'TFLite_conversion'


In [6]:
def load_policy(agent_id, agent_dir ='agent_checkpoints/'):
    policy_dir = os.path.join(os.getcwd(), agent_dir + agent_id)
    print(policy_dir)
    policy = saved_model.load(policy_dir)
    return policy

In [7]:
policy = load_policy(AGENT_ID)

/home/thinh/Sync/python-projects/echo_gym/echo_gym/bat_snake_env/agent_checkpoints/TFLite_conversion


In [11]:
def load_policy_lite(agent_id, agent_dir ='agent_checkpoints/'):
    policy_dir = os.path.join(os.getcwd(), agent_dir + agent_id)
    print(policy_dir)
    converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir) # path to the SavedModel directory
    converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops
    ]
    policy_lite = converter.convert()
    return policy_lite

In [12]:
policy_lite = load_policy_lite(AGENT_ID)

/home/thinh/Sync/python-projects/echo_gym/echo_gym/bat_snake_env/agent_checkpoints/TFLite_conversion


In [21]:
open('lite_policy.tflite', 'wb').write(policy_lite)

224456

In [13]:
policy.signatures

_SignatureMap({'action': <ConcreteFunction signature_wrapper(*, 0/step_type, 0/observation, 0/reward, 0/discount) at 0x7FD6FF7094F0>, 'get_initial_state': <ConcreteFunction signature_wrapper(*, batch_size) at 0x7FD6FC5452B0>, 'get_train_step': <ConcreteFunction signature_wrapper() at 0x7FD6FC581370>, 'get_metadata': <ConcreteFunction signature_wrapper() at 0x7FD6FF6E5550>})

In [23]:
interpreter = tf.lite.Interpreter(model_content=policy_lite)
interpreter.allocate_tensors()

In [31]:
interpreter.get_input_details()

[{'name': 'action_0/step_type:0',
  'index': 0,
  'shape': array([1], dtype=int32),
  'shape_signature': array([-1], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'action_0/discount:0',
  'index': 1,
  'shape': array([1], dtype=int32),
  'shape_signature': array([-1], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'action_0/observation:0',
  'index': 2,
  'shape': array([  1, 100], dtype=int32),
  'shape_signature': array([ -1, 100], dtype=int32),
  'dtype': numpy.float64,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dty

In [32]:
interpreter.get_output_details()

[{'name': 'StatefulPartitionedCall:0',
  'index': 37,
  'shape': array([1], dtype=int32),
  'shape_signature': array([-1], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]

In [38]:
input = interpreter.get_input_details()[2]

In [40]:
output = interpreter.get_output_details()[0]

In [70]:
input_data = tf.convert_to_tensor(np.zeros((1,100), dtype=np.float64))

In [71]:
input_data

<tf.Tensor: shape=(1, 100), dtype=float64, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]])>

In [72]:
interpreter.set_tensor(input['index'], input_data)

In [73]:
interpreter.invoke()

In [74]:
interpreter.get_tensor(output['index'])

array([0], dtype=int32)