In [6]:
import tensorflow as tf

# Load the trained receiver without compiling: only inference is needed, so the
# saved optimizer/loss configuration is not restored.
model = tf.keras.models.load_model("receiver_model_UMi.h5", compile=False)

# `summary()` prints the layer table itself and returns None; wrapping it in
# print() would append a stray "None" line to the output.
model.summary()

Model: "DeepRx_real"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 Y_real_imag (InputLayer)    [(None, 14, 96, 2)]          0         []                            
                                                                                                  
 Hr_real_imag (InputLayer)   [(None, 14, 96, 2)]          0         []                            
                                                                                                  
 concatenate (Concatenate)   (None, 14, 96, 4)            0         ['Y_real_imag[0][0]',         
                                                                     'Hr_real_imag[0][0]']        
                                                                                                  
 conv_in (Conv2D)            (None, 14, 96, 64)           2368      ['concatenate[0][0]'

In [10]:
# Report every model input (position, layer name, symbolic shape); the leading
# None in each shape is the unspecified batch dimension.
for position, tensor in enumerate(model.inputs):
    input_name, input_shape = tensor.name, tensor.shape
    print(f"Input {position}: name={input_name}, shape={input_shape}")


Input 0: name=Y_real_imag, shape=(None, 14, 96, 2)
Input 1: name=Hr_real_imag, shape=(None, 14, 96, 2)


In [11]:
import tensorflow as tf
import tf2onnx

# Load Keras model (without compiling since we only care about inference)
model = tf.keras.models.load_model("receiver_model_UMi.h5", compile=False)

# Check the model inputs
print("Model inputs:", model.inputs)

# One TensorSpec per model input, in model.inputs order, carrying that input's
# own shape, dtype and name so the ONNX graph inputs mirror the Keras inputs.
# Deriving the specs from model.inputs (instead of hard-coding two entries and
# float32) keeps the conversion correct if the input list or dtypes change.
# name.split(":")[0] drops a TF-style ":0" tensor suffix when one is present.
spec = tuple(
    tf.TensorSpec(inp.shape, inp.dtype, name=inp.name.split(":")[0])
    for inp in model.inputs
)

# Convert to ONNX (opset 13)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

# Save ONNX file; this path is what the TensorRT build step reads.
with open("receiver_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

print("✅ Conversion complete: receiver_model.onnx")


Model inputs: [<KerasTensor: shape=(None, 14, 96, 2) dtype=float32 (created by layer 'Y_real_imag')>, <KerasTensor: shape=(None, 14, 96, 2) dtype=float32 (created by layer 'Hr_real_imag')>]


2025-09-27 11:18:23.043778: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 2
2025-09-27 11:18:23.043866: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2025-09-27 11:18:23.193716: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2025-09-27 11:18:23.328115: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 2
2025-09-27 11:18:23.328201: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2025-09-27 11:18:23.328617: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GP

✅ Conversion complete: receiver_model.onnx


In [12]:
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Batch-size range the engine must accept (min / optimal / max). The remaining
# dimensions are read from each ONNX input below rather than hard-coded.
MIN_BATCH, OPT_BATCH, MAX_BATCH = 1, 1, 8

# Builder + network + parser (explicit-batch network for the ONNX parser;
# NOTE(review): this flag is deprecated/ignored on TensorRT 10+ — confirm version)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

# Load ONNX model; print every parser error before failing so the cause is visible
with open("receiver_model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("Failed to parse ONNX model")

# Builder config
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # optional: allow FP16 kernels

# Optimization profile for the dynamic batch dimension. Each input keeps its
# own non-batch dims (e.g. (-1, 14, 96, 2) -> (14, 96, 2)), so inputs with
# different shapes are not forced onto one hard-coded shape.
profile = builder.create_optimization_profile()
for i in range(network.num_inputs):
    inp = network.get_input(i)
    feature_dims = tuple(inp.shape)[1:]
    if any(dim < 0 for dim in feature_dims):
        raise RuntimeError(
            f"Input {inp.name!r} has dynamic non-batch dims {tuple(inp.shape)}; "
            "set its profile shapes explicitly."
        )
    profile.set_shape(
        inp.name,
        min=(MIN_BATCH, *feature_dims),
        opt=(OPT_BATCH, *feature_dims),
        max=(MAX_BATCH, *feature_dims),
    )

config.add_optimization_profile(profile)

# Build engine
engine_bytes = builder.build_serialized_network(network, config)
if engine_bytes is None:
    raise RuntimeError("Engine build failed. Check input shapes or ONNX opset.")

# Save TRT engine
with open("receiver_model.trt", "wb") as f:
    f.write(engine_bytes)

print("✅ TRT engine created successfully")


✅ TRT engine created successfully
