# Python Benchmark: Fusion vs Non-Fusion

In [1]:
import tensorflow as tf
import time

In [2]:
# Inputs
x = tf.random.normal([1024, 1024], dtype=tf.float32)
scale = tf.random.normal([1024], dtype=tf.float32)
bias = tf.random.normal([1024], dtype=tf.float32)

# Non-fused version
@tf.function
def non_fused(x, scale, bias):
    x = tf.multiply(x, scale)
    x = tf.add(x, bias)
    return tf.nn.relu(x)

# Fused version (using XLA)
@tf.function(jit_compile=True)
def fused(x, scale, bias):
    return tf.nn.relu(tf.multiply(x, scale) + bias)

# Warm-up
non_fused(x, scale, bias)
fused(x, scale, bias)

# Benchmark
import time

start = time.time()
for _ in range(100):
    non_fused(x, scale, bias)
non_fused_time = time.time() - start

start = time.time()
for _ in range(100):
    fused(x, scale, bias)
fused_time = time.time() - start

print(f"Non-Fused Time: {non_fused_time:.4f}s")
print(f"Fused (XLA) Time: {fused_time:.4f}s")
print(f"Speedup: {non_fused_time / fused_time:.2f}×")


Non-Fused Time: 0.0895s
Fused (XLA) Time: 0.0383s
Speedup: 2.34×


# Fusion Visualization & Confirmation

In [3]:
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
# import tensorflow as tf

In [4]:
# import tensorflow as tf

@tf.function(jit_compile=True)  # Enables XLA operator fusion
def fused_conv_swish(x, weights, bias):
    x = tf.nn.conv2d(x, weights, strides=1, padding='SAME')
    x = tf.nn.bias_add(x, bias)
    return x * tf.nn.sigmoid(x)  # Swish


In [5]:
# Convert model to graph
f = tf.function(fused_conv_swish).get_concrete_function(
    tf.TensorSpec(shape=[1, 64, 64, 3], dtype=tf.float32),
    tf.TensorSpec(shape=[3, 3, 3, 64], dtype=tf.float32),
    tf.TensorSpec(shape=[64], dtype=tf.float32)
)
graph_def = convert_variables_to_constants_v2(f).graph.as_graph_def()

# Print fused operations
for node in graph_def.node:
    print(node.op)

Placeholder
Placeholder
Placeholder
PartitionedCall
Identity
