In [78]:
import os
import pathlib
import numpy as np
import tensorflow as tf
import pickle
!pip install tflite



In [166]:
# Directory that holds the exported .tflite models; created if it does not exist yet.
tflite_models_dir = pathlib.Path("../../saved_models")
tflite_models_dir.mkdir(parents=True, exist_ok=True)

# Quantized model to inspect. To analyse another model from the same directory
# (e.g. "lenet5_int8.tflite"), change the filename below.
tflite_model_quant_file = str(tflite_models_dir / "ssd_mobilenetV2_fpnlite_UINT8_AP24.tflite")

Extract the builtin operator names and their corresponding operator indices from the `BuiltinOperator` enum in TensorFlow Lite's `schema.fbs` file.

In [143]:
import re

# Path to the schema.fbs file
schema_fbs_file = os.path.join(os.getcwd(), "schema.fbs")

# Function to extract builtin operators from schema.fbs
def extract_builtin_operators(schema_fbs_file):
    """Parse the ``BuiltinOperator`` enum of a TFLite ``schema.fbs`` file.

    Args:
        schema_fbs_file: path to the FlatBuffers schema file.

    Returns:
        dict mapping each enum value (int) to its operator name (str), in the
        order the entries appear in the schema. Empty if the enum is not found.
    """
    builtin_operators = {}

    with open(schema_fbs_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Strip `//` line comments first: a `}` inside a comment would end the
    # non-greedy enum match early, and commented-out `NAME = N` text would
    # otherwise be picked up as an operator.
    content = re.sub(r'//[^\n]*', '', content)

    # Capture the body of `enum BuiltinOperator <optional : type> { ... }`;
    # the word boundary avoids matching enums that merely start with that name.
    enum_pattern = r'\benum\s+BuiltinOperator\b[^{]*\{(.*?)\}'
    for enum_body in re.findall(enum_pattern, content, re.DOTALL):
        for enum_name, enum_value in re.findall(r'(\w+)\s*=\s*(-?\d+)', enum_body):
            builtin_operators[int(enum_value)] = enum_name

    return builtin_operators

# Extract builtin operators
builtin_operators = extract_builtin_operators(schema_fbs_file)

# Print the dictionary
print("Builtin Operators Dictionary:")
for enum_value, enum_name in builtin_operators.items():
    print(f"{enum_value}: {enum_name}")

Builtin Operators Dictionary:
0: ADD
1: AVERAGE_POOL_2D
2: CONCATENATION
3: CONV_2D
4: DEPTHWISE_CONV_2D
5: DEPTH_TO_SPACE
6: DEQUANTIZE
7: EMBEDDING_LOOKUP
8: FLOOR
9: FULLY_CONNECTED
10: HASHTABLE_LOOKUP
11: L2_NORMALIZATION
12: L2_POOL_2D
13: LOCAL_RESPONSE_NORMALIZATION
14: LOGISTIC
15: LSH_PROJECTION
16: LSTM
17: MAX_POOL_2D
18: MUL
19: RELU
20: RELU_N1_TO_1
21: RELU6
22: RESHAPE
23: RESIZE_BILINEAR
24: RNN
25: SOFTMAX
26: SPACE_TO_DEPTH
27: SVDF
28: TANH
29: CONCAT_EMBEDDINGS
30: SKIP_GRAM
31: CALL
32: CUSTOM
33: EMBEDDING_LOOKUP_SPARSE
34: PAD
35: UNIDIRECTIONAL_SEQUENCE_RNN
36: GATHER
37: BATCH_TO_SPACE_ND
38: SPACE_TO_BATCH_ND
39: TRANSPOSE
40: MEAN
41: SUB
42: DIV
43: SQUEEZE
44: UNIDIRECTIONAL_SEQUENCE_LSTM
45: STRIDED_SLICE
46: BIDIRECTIONAL_SEQUENCE_RNN
47: EXP
48: TOPK_V2
49: SPLIT
50: LOG_SOFTMAX
51: DELEGATE
52: BIDIRECTIONAL_SEQUENCE_LSTM
53: CAST
54: PRELU
55: MAXIMUM
56: ARG_MAX
57: MINIMUM
58: LESS
59: NEG
60: PADV2
61: GREATER
62: GREATER_EQUAL
63: LESS_EQUAL
64: S

In [172]:
from tflite.Model import Model
from tflite.TensorType import TensorType

def extract_tensor_values(tensor, model):
    """Decode the constant data backing a TFLite tensor into a flat NumPy array.

    Args:
        tensor: ``tflite.Tensor`` accessor whose data should be read.
        model: ``tflite.Model`` that owns the buffer the tensor points to.

    Returns:
        1-D NumPy array whose dtype matches the tensor's ``TensorType``, read as
        little-endian (the flatbuffer byte order). A tensor whose buffer holds no
        data (e.g. an activation tensor) yields an empty array. The result is not
        reshaped; use ``tensor.ShapeAsNumpy()`` for the logical shape.

    Raises:
        ValueError: if the tensor's type has no NumPy mapping below.
    """
    # TFLite TensorType -> NumPy dtype. getattr() skips enum members that an
    # older ``tflite`` package may not define.
    dtype_by_type = {}
    for type_name, np_dtype in (
        ("FLOAT32", np.float32),
        ("FLOAT16", np.float16),
        ("FLOAT64", np.float64),
        ("INT8", np.int8),
        ("UINT8", np.uint8),
        ("INT16", np.int16),
        ("INT32", np.int32),
        ("INT64", np.int64),
        ("BOOL", np.bool_),
    ):
        type_code = getattr(TensorType, type_name, None)
        if type_code is not None:
            dtype_by_type[type_code] = np.dtype(np_dtype).newbyteorder("<")

    tensor_type = tensor.Type()
    if tensor_type not in dtype_by_type:
        raise ValueError("Unsupported datatype")
    dtype = dtype_by_type[tensor_type]

    buffer = model.Buffers(tensor.Buffer())
    raw = buffer.DataAsNumpy()
    # The generated flatbuffers accessor returns the integer 0 instead of an
    # array when the buffer's data vector is absent.
    if isinstance(raw, int):
        return np.empty((0,), dtype=dtype)

    # Reinterpret the raw bytes as `dtype`. Multi-byte types such as float32
    # must be decoded from groups of bytes, not converted byte by byte.
    return np.frombuffer(np.asarray(raw, dtype=np.uint8).tobytes(), dtype=dtype)

def extract_quantization_params(tensor):
    """Return the quantization scale(s) and zero point(s) of a TFLite tensor.

    Args:
        tensor: ``tflite.Tensor`` accessor.

    Returns:
        dict with keys ``'scale'`` and ``'zero_point'``, each an array (one entry
        per-tensor, or one per channel). Both are empty arrays when the tensor has
        no quantization table or the table omits the vector. Callers therefore
        always receive array-like values.
    """
    empty_scale = np.empty((0,), dtype=np.float32)
    empty_zero_point = np.empty((0,), dtype=np.int64)

    tensor_metadata = tensor.Quantization()
    # The generated accessor returns None when the tensor has no quantization table.
    if tensor_metadata is None:
        return {'scale': empty_scale, 'zero_point': empty_zero_point}

    scale = tensor_metadata.ScaleAsNumpy()
    zero_point = tensor_metadata.ZeroPointAsNumpy()
    # *AsNumpy() accessors return the integer 0 (not an array) when the vector is absent.
    if isinstance(scale, int):
        scale = empty_scale
    if isinstance(zero_point, int):
        zero_point = empty_zero_point

    return {'scale': scale, 'zero_point': zero_point}

def _resolve_builtin_code(operator_code):
    """Return the effective ``BuiltinOperator`` value of a ``tflite.OperatorCode``.

    The TFLite schema stores codes above 127 in ``builtin_code``. Files written by
    older converters only populate the byte-sized ``deprecated_builtin_code``.
    TFLite takes the larger of the two, and so does this helper. ``getattr``
    keeps it working with ``tflite`` package versions that lack the deprecated
    accessor.
    """
    code = operator_code.BuiltinCode()
    deprecated_accessor = getattr(operator_code, "DeprecatedBuiltinCode", None)
    if deprecated_accessor is not None:
        code = max(code, deprecated_accessor())
    return code


def extract_operator_params(subgraph, op, op_idx, model=None, op_code=None):
    """Print an operator's name and the shape, data and quantization of its tensors.

    Args:
        subgraph: ``tflite.SubGraph`` containing ``op``.
        op: ``tflite.Operator`` to describe.
        op_idx: index of ``op`` within ``subgraph``; only used for printing.
        model: ``tflite.Model`` that owns the tensor buffers. It must be passed
            explicitly; there is no fallback to a notebook-level ``model``.
        op_code: builtin operator code of ``op``. Derived from ``model`` when omitted.

    Raises:
        TypeError: if ``model`` is not supplied.
        ValueError: propagated from ``extract_tensor_values`` for tensor types it
            cannot decode.
    """
    if model is None:
        raise TypeError("extract_operator_params() requires the 'model' argument")
    if op_code is None:
        op_code = _resolve_builtin_code(model.OperatorCodes(op.OpcodeIndex()))

    # builtin_operators comes from the schema-parsing cell. Fall back to the raw
    # code for ops that a possibly older schema.fbs does not list.
    op_name = builtin_operators.get(op_code, f"UNKNOWN_OP_{op_code}")
    print(op_idx, op_name)

    tensor_indices = [op.Inputs(j) for j in range(op.InputsLength())]
    tensor_indices += [op.Outputs(j) for j in range(op.OutputsLength())]
    for tensor_idx in tensor_indices:
        # The TFLite schema marks omitted optional inputs with index -1.
        if tensor_idx < 0:
            continue
        tensor = subgraph.Tensors(tensor_idx)
        tensor_data = extract_tensor_values(tensor, model)
        print("shape:", tensor.ShapeAsNumpy())
        print("data:", tensor_data)
        quantization_params = extract_quantization_params(tensor)
        print("scale:", quantization_params['scale'])
        print("zero_point:", quantization_params['zero_point'])


def extract_model_params(model_file):
    """Print every operator of every subgraph in a ``.tflite`` file, with its tensors.

    Args:
        model_file: path to the ``.tflite`` flatbuffer.
    """
    with open(model_file, "rb") as f:
        model_buf = f.read()

    model = Model.GetRootAsModel(model_buf, 0)

    for subgraph_idx in range(model.SubgraphsLength()):
        subgraph = model.Subgraphs(subgraph_idx)
        for op_idx in range(subgraph.OperatorsLength()):
            op = subgraph.Operators(op_idx)
            op_code = _resolve_builtin_code(model.OperatorCodes(op.OpcodeIndex()))
            # Pass model/op_code explicitly rather than relying on notebook globals.
            extract_operator_params(subgraph, op, op_idx, model=model, op_code=op_code)

0 QUANTIZE
shape: [  1 640 640   3]
data: []
scale: [0.00784314]
zero_point: [127]
shape: [  1 640 640   3]
data: []
scale: [0.00784314]
zero_point: [-1]
1 CONV_2D
shape: [  1 640 640   3]
data: []
scale: [0.00784314]
zero_point: [-1]
shape: [32  3  3  3]
data: [  32   12   13   94  105   46  -24  -28   -4   78   76   42  120  127
   54  -42  -48  -14  -52  -89  -38  -69 -113  -65  -50  -55  -31   96
  -31  -49  123  -17  -69   48  -24  -16  109   -9  -44  127   -5  -68
   49  -18  -17   14  -17    8   16  -21  -10  -16   -9   22   89  127
   10   76  116  -13   59   72  -27   45   81  -31   33   75  -48   39
   50  -35   -1   17  -43   -9   12  -55    2  -16  -35   66  -56  -50
   86  -70  -65   51  -41  -42   99  -69  -69  127  -79  -76   57  -71
  -71   74  -42  -39   95  -48  -36   44  -28  -32   68  -38 -115   74
  -33 -125   48  -43 -127   65  -11  -97   52  -23 -114   34  -26 -110
   77   13  -68   38  -24 -103   55    4  -73  -92  -29 -117 -110  -47
 -127  -21   32  -56 -105  -

In [None]:
# Notebook kernels execute cells in the `__main__` module, so this runs when the cell is executed.
if __name__ == "__main__":
    model_path = tflite_model_quant_file
    extract_model_params(model_path)