In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class NormalizeInput(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm(dim=1, keepdim=True) + 1e-8)

In [3]:
model = nn.Sequential(
            NormalizeInput(),

            nn.Linear(36, 64),
            nn.ReLU6(),

            nn.Linear(64, 32),
            nn.LeakyReLU(0.1),

            nn.Linear(32, 2),
        )

In [4]:
with open("./MLP-ONLINE-PICO-1749415744.pth", "rb") as f:
    model.load_state_dict(torch.load(f))

In [5]:
import tensorflow as tf

model_tf = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation=tf.nn.relu6, name='linear1', input_shape=(36,)),
    tf.keras.layers.Dense(32, name='linear2'),
    tf.keras.layers.LeakyReLU(alpha=0.1),
    tf.keras.layers.Dense(2, name='linear3'),
])

# Triggers layer builds
with tf.device('/CPU:0'):
    model_tf(tf.zeros((1, 36)))

2025-06-09 00:22:39.897564: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749421359.966873   36997 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749421359.980938   36997 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749421360.088846   36997 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749421360.088974   36997 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749421360.088977   36997 computation_placer.cc:177] computation placer alr

In [6]:
import torch
import numpy as np

state_dict = torch.load("MLP-ONLINE-PICO-1749415744.pth", map_location="cpu")

layer_names = ['linear1', 'linear2', 'linear3']
pt_layer_keys = ['1', '3', '5']

for tf_name, pt_idx in zip(layer_names, pt_layer_keys):
    W = state_dict[f'{pt_idx}.weight'].numpy().T  # [in, out] for TF
    b = state_dict[f'{pt_idx}.bias'].numpy()
    model_tf.get_layer(tf_name).set_weights([W, b])

model_tf.trainable = False

In [7]:
# Test if they are equivalent
x = np.random.rand(1, 36).astype(np.float32)
x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)

pt_output = model(torch.tensor(x)).detach().numpy()
with tf.device('/CPU:0'):
    tf_output = model_tf(x)

print("PyTorch output:", pt_output)
print("TensorFlow output:", tf_output)

# assert np.allclose(pt_output, tf_output, atol=1e-6), "Outputs are not close enough!"

PyTorch output: [[1480.7955 1873.3892]]
TensorFlow output: tf.Tensor([[1480.7955 1873.3893]], shape=(1, 2), dtype=float32)


In [8]:
def representative_dataset():
    for _ in range(100):
        data = tf.random.uniform(shape=(1, 36), dtype=tf.float32)
        yield [data]

In [9]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_tf)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

# Ensure full integer quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()

# Save
with open("model_int8.tflite", "wb") as f:
    f.write(tflite_model)

INFO:tensorflow:Assets written to: /tmp/tmptw3cy93v/assets


INFO:tensorflow:Assets written to: /tmp/tmptw3cy93v/assets


Saved artifact at '/tmp/tmptw3cy93v'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 36), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)
Captures:
  140037757143760: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140037757144912: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140037757143568: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140037757142416: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140037757145488: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140037757142608: TensorSpec(shape=(), dtype=tf.resource, name=None)


W0000 00:00:1749421367.156449   36997 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1749421367.156491   36997 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-06-09 00:22:47.157244: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmptw3cy93v
2025-06-09 00:22:47.162733: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-06-09 00:22:47.162751: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmptw3cy93v
I0000 00:00:1749421367.178296   36997 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-06-09 00:22:47.178994: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-06-09 00:22:47.220788: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmptw3cy93v
2025-06-09 00:22:47.228907: I tensorflow/cc/saved_model/loader.cc:471] SavedModel 