In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class NormalizeInput(nn.Module):
    """Rescale the input so each slice along dim 1 has (almost) unit L2 norm.

    A small constant is added to the norm before dividing, so an all-zero
    row is mapped to zeros instead of producing NaN.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # keepdim=True keeps the reduced axis so the division broadcasts.
        row_norms = x.norm(dim=1, keepdim=True)
        safe_denominator = row_norms + 1e-8
        return x / safe_denominator

In [3]:
# PyTorch reference network: normalize the 36 input features, then a
# 36 -> 64 -> 32 -> 2 MLP. The order of the modules must not change: the
# Linear layers sit at Sequential indices 1, 3 and 5.
model = nn.Sequential(
    NormalizeInput(),
    nn.Linear(in_features=36, out_features=64),
    nn.ReLU6(),
    nn.Linear(in_features=64, out_features=32),
    nn.LeakyReLU(negative_slope=0.1),
    nn.Linear(in_features=32, out_features=2),
)

In [4]:
# Restore the trained parameters into the PyTorch reference model.
# map_location="cpu" remaps every stored tensor onto the CPU, so the load does
# not depend on the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
with open("./MLP-ONLINE-PICO-1749673828.pth", "rb") as f:
    model.load_state_dict(torch.load(f, map_location="cpu"))

In [5]:
import tensorflow as tf

# Keras mirror of the PyTorch MLP (36 -> 64 -> 32 -> 2).
# NOTE: there is no counterpart of NormalizeInput here, so this model (and
# anything converted from it) must be fed inputs that are already
# L2-normalized.
# The input shape is declared with an explicit Input object rather than the
# `input_shape=` layer argument, which newer Keras versions warn about.
model_tf = tf.keras.Sequential([
    tf.keras.Input(shape=(36,)),
    tf.keras.layers.Dense(64, activation=tf.nn.relu6, name='linear1'),
    tf.keras.layers.Dense(32, name='linear2'),
    tf.keras.layers.LeakyReLU(alpha=0.1),
    tf.keras.layers.Dense(2, name='linear3'),
])

# Triggers layer builds
with tf.device('/CPU:0'):
    model_tf(tf.zeros((1, 36)))

2025-06-12 00:20:07.383115: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749680407.402097   21091 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749680407.407067   21091 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749680407.421921   21091 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749680407.421943   21091 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749680407.421945   21091 computation_placer.cc:177] computation placer alr

In [6]:
import torch
import numpy as np

# Copy the trained parameters from the PyTorch checkpoint into the Keras layers.
state_dict = torch.load("MLP-ONLINE-PICO-1749673828.pth", map_location="cpu")

# Keras layer names, paired position-by-position with the key prefix of the
# matching entry in the PyTorch state dict.
layer_names = ['linear1', 'linear2', 'linear3']
pt_layer_keys = ['1', '3', '5']

for tf_name, pt_idx in zip(layer_names, pt_layer_keys):
    weight_key = f'{pt_idx}.weight'
    bias_key = f'{pt_idx}.bias'
    # Transpose so the kernel is laid out as [in, out] for TF.
    kernel = state_dict[weight_key].numpy().T
    bias = state_dict[bias_key].numpy()
    model_tf.get_layer(tf_name).set_weights([kernel, bias])

# The Keras copy is only used for inference / conversion; freeze it.
model_tf.trainable = False

In [15]:
# Test if they are equivalent
# The Keras model has no normalization layer, so the sample is L2-normalized
# here before it is handed to both models.
# A fixed seed makes the comparison reproducible across runs.
x = np.random.default_rng(0).random((1, 36), dtype=np.float32)
x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)

pt_output = model(torch.tensor(x)).detach().numpy()
with tf.device('/CPU:0'):
    tf_output = model_tf(x)

print("PyTorch output:", pt_output)
print("TensorFlow output:", tf_output)

# Compare with a relative tolerance: the outputs are float32, so a purely
# absolute tolerance such as 1e-6 is not meaningful for large output values.
assert np.allclose(pt_output, tf_output.numpy(), rtol=1e-4, atol=1e-4), "Outputs are not close enough!"

PyTorch output: [[1663.2979 1690.6519]]
TensorFlow output: tf.Tensor([[1663.298 1690.652]], shape=(1, 2), dtype=float32)


In [8]:
def representative_dataset(num_samples: int = 1000, seed: int = 0):
    """Yield calibration samples for TFLite post-training quantization.

    Each sample is a random, non-negative ``(1, 36)`` float32 array scaled to
    unit L2 norm. The Keras model being converted has no normalization layer
    and therefore expects pre-normalized inputs, so the calibration data has
    to be normalized too; otherwise the quantization ranges are estimated on
    inputs with a different scale than the ones seen at inference time.

    NOTE(review): random data is only a stand-in — calibrating on real,
    recorded input vectors would give more faithful ranges.

    Args:
        num_samples: Number of calibration samples to yield.
        seed: Seed for the random generator, so calibration is reproducible.

    Yields:
        A one-element list holding a ``(1, 36)`` float32 NumPy array.
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_samples):
        data = rng.random((1, 36), dtype=np.float32)
        # Same normalization as the PyTorch NormalizeInput module.
        data = data / (np.linalg.norm(data, axis=1, keepdims=True) + 1e-8)
        yield [data]

In [9]:
# Convert the Keras model into a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(model_tf)

# Ensure full integer quantization: int8 builtin ops only, int8 input/output.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

# Post-training quantization, calibrated with the representative samples.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

tflite_model = converter.convert()

# Save the serialized model next to the notebook.
with open("model_int8.tflite", "wb") as out_file:
    out_file.write(tflite_model)

INFO:tensorflow:Assets written to: /tmp/tmp15yxf14e/assets


INFO:tensorflow:Assets written to: /tmp/tmp15yxf14e/assets


Saved artifact at '/tmp/tmp15yxf14e'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 36), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)
Captures:
  140272447568784: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140272447569744: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140272447564176: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140272447567440: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140272447568592: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140272447567632: TensorSpec(shape=(), dtype=tf.resource, name=None)


W0000 00:00:1749680411.501413   21091 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1749680411.501454   21091 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-06-12 00:20:11.501944: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmp15yxf14e
2025-06-12 00:20:11.502408: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-06-12 00:20:11.502417: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmp15yxf14e
I0000 00:00:1749680411.506818   21091 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-06-12 00:20:11.507455: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-06-12 00:20:11.527868: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmp15yxf14e
2025-06-12 00:20:11.534307: I tensorflow/cc/saved_model/loader.cc:471] SavedModel 

In [12]:
# Sanity check on a 36-value feature vector: its L2 norm should come out at
# ~1 if the vector is already normalized.
# NOTE(review): presumably a sample captured from the device — confirm.
sample_features = np.array([
    0.000000, 0.000000, 0.021984, 0.080698, 0.271321, 0.303353,
    0.000000, 0.000000, 0.029773, 0.124845, 0.451808, 0.644560,
    0.000000, 0.000000, 0.023496, 0.083553, 0.239094, 0.340424,
    0.000000, 0.000000, 0.000000, 0.034788, 0.063870, 0.069226,
    0.000000, 0.000000, 0.000000, 0.014055, 0.000000, 0.020920,
    0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
])
np.linalg.norm(sample_features)

np.float64(0.9999994879548689)