# Python Benchmark: Fusion vs Non-Fusion

In [1]:
import tensorflow as tf
import time

In [3]:
# Inputs
x = tf.random.normal([1024, 1024], dtype=tf.float32)
scale = tf.random.normal([1024], dtype=tf.float32)
bias = tf.random.normal([1024], dtype=tf.float32)

# Non-fused version
@tf.function
def non_fused(x, scale, bias):
    x = tf.multiply(x, scale)
    x = tf.add(x, bias)
    return tf.nn.relu(x)

# Fused version (using XLA)
@tf.function(jit_compile=True)
def fused(x, scale, bias):
    return tf.nn.relu(tf.multiply(x, scale) + bias)

# Warm-up
# non_fused(x, scale, bias)
# fused(x, scale, bias)

# Benchmark
import time

start = time.time()
for _ in range(100):
    non_fused(x, scale, bias)
non_fused_time = time.time() - start

start = time.time()
for _ in range(100):
    fused(x, scale, bias)
fused_time = time.time() - start

print(f"Non-Fused Time: {non_fused_time:.4f}s")
print(f"Fused (XLA) Time: {fused_time:.4f}s")
print(f"Speedup: {non_fused_time / fused_time:.2f}×")


Non-Fused Time: 0.1240s
Fused (XLA) Time: 0.1039s
Speedup: 1.19×
