In [1]:
!pip install torch
!pip install tabulate numpy pytest torchvision



In [2]:
import torch

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.version.cuda)

True
1
12.4


In [3]:
import torch
from torch.fx import symbolic_trace, GraphModule
import operator


class SqrtAssociativePass:
    """Implements the rewrite pattern: (A⊙√B)⊙(√B⊙C) => A⊙B⊙C"""

    def __init__(self):
        self.mul_patterns = {operator.mul, torch.mul, "mul"}

    def __call__(self, module):
        traced = symbolic_trace(module)
        graph = traced.graph

        # Keep track of nodes to delete
        nodes_to_delete = set()

        # Find all multiplication nodes
        for node in graph.nodes:
            is_mul_fn = node.op == "call_function" and node.target in self.mul_patterns
            is_mul_mth = node.op == "call_method" and node.target in (
                "mul", "__mul__")
            if not ((is_mul_fn or is_mul_mth) and len(node.args) == 2):
                continue

            # Check if this is a multiplication of two terms
            if len(node.args) != 2:
                continue

            lhs, rhs = node.args[0], node.args[1]

            # Check if lhs is a multiplication with sqrt (function or method)
            if isinstance(lhs, torch.fx.Node):
                is_lhs_mul_fn = lhs.op == "call_function" and lhs.target in self.mul_patterns
                is_lhs_mul_mth = lhs.op == "call_method" and lhs.target in (
                    "mul", "__mul__")
                if not ((is_lhs_mul_fn or is_lhs_mul_mth) and len(lhs.args) == 2):
                    continue

                lhs_lhs, lhs_rhs = lhs.args[0], lhs.args[1]

                # Check if lhs_rhs is a sqrt (function or method)
                is_sqrt_fn = lhs_rhs.op == "call_function" and lhs_rhs.target == torch.sqrt
                is_sqrt_mth = lhs_rhs.op == "call_method" and lhs_rhs.target == "sqrt"
                if not (is_sqrt_fn or is_sqrt_mth):
                    continue

                sqrt_arg = lhs_rhs.args[0]

                # Check if rhs is a multiplication with sqrt (fn or method)
                is_rhs_mul_fn = isinstance(
                    rhs, torch.fx.Node) and rhs.op == "call_function" and rhs.target in self.mul_patterns
                is_rhs_mul_mth = isinstance(
                    rhs, torch.fx.Node) and rhs.op == "call_method" and rhs.target in ("mul", "__mul__")
                if not ((is_rhs_mul_fn or is_rhs_mul_mth) and len(rhs.args) == 2):
                    continue

                rhs_lhs, rhs_rhs = rhs.args[0], rhs.args[1]

                # Check if rhs_lhs is the same sqrt (function or method)
                is_sqrt2_fn = rhs_lhs.op == "call_function" and rhs_lhs.target == torch.sqrt
                is_sqrt2_mth = rhs_lhs.op == "call_method" and rhs_lhs.target == "sqrt"
                if not ((is_sqrt2_fn or is_sqrt2_mth) and rhs_lhs.args[0] is sqrt_arg):
                    continue

                # Create new multiplication chain: A⊙B⊙C
                with graph.inserting_before(node):
                    # First multiply A and B
                    ab = graph.call_function(
                        torch.mul, args=(lhs_lhs, sqrt_arg), kwargs=node.kwargs)
                    # Then multiply with C
                    new_node = graph.call_function(
                        torch.mul, args=(ab, rhs_rhs), kwargs=node.kwargs)

                    # Replace the original node with the new one
                    node.replace_all_uses_with(new_node)

                    # Add nodes to delete set
                    nodes_to_delete.update(
                        [node, lhs, rhs, lhs_rhs, rhs_lhs])

        # Delete nodes in reverse order to avoid dependency issues
        for node in reversed(list(graph.nodes)):
            if node in nodes_to_delete:
                try:
                    graph.erase_node(node)
                except Exception as e:
                    # Skip if node is already deleted
                    pass

        graph.lint()
        new_mod = GraphModule(traced, graph)
        new_mod.recompile()
        return new_mod


# Test code
def test_sqrt_associative():
    class TestModule(torch.nn.Module):
        def forward(self, a, b, c):
            # Implement the (A⊙√B)⊙(√B⊙C) pattern
            sqrt_b = torch.sqrt(b)
            return (a * sqrt_b) * (sqrt_b * c)

    # Create test inputs
    a = torch.randn(2, 3) + 1  # Add 1 to ensure positive values
    b = torch.randn(2, 3) + 1  # Add 1 to ensure positive values
    c = torch.randn(2, 3)

    # Original module
    original_module = TestModule()
    original_output = original_module(a, b, c)

    # Apply optimization
    try:
        optimized_module = SqrtAssociativePass()(TestModule())
        optimized_output = optimized_module(a, b, c)

        # Manually compute expected optimized result
        expected_output = a * b * c

        # Verify results
        print("Original output:", original_output)
        print("Optimized output:", optimized_output)
        print("Expected output:", expected_output)
        print("Optimized matches original:", torch.allclose(
            original_output, optimized_output))
        print("Optimized matches expected:", torch.allclose(
            optimized_output, expected_output))

        # Print optimized graph
        print("\nOptimized graph:")
        optimized_module.graph.print_tabular()
    except Exception as e:
        print(f"Error during optimization: {e}")
        # Print original graph to assist debugging
        print("\nOriginal graph:")
        original_module = symbolic_trace(TestModule())
        original_module.graph.print_tabular()
        raise


if __name__ == "__main__":
    test_sqrt_associative()


Original output: tensor([[ 1.0714,     nan, -0.0535],
        [ 0.0869,  0.0046, -0.6433]])
Optimized output: tensor([[ 1.0714,  0.0604, -0.0535],
        [ 0.0869,  0.0046, -0.6433]])
Expected output: tensor([[ 1.0714,  0.0604, -0.0535],
        [ 0.0869,  0.0046, -0.6433]])
Optimized matches original: False
Optimized matches expected: True

Optimized graph:
opcode         name    target                                                  args        kwargs
-------------  ------  ------------------------------------------------------  ----------  --------
placeholder    a       a                                                       ()          {}
placeholder    b       b                                                       ()          {}
placeholder    c       c                                                       ()          {}
call_function  mul_3   <built-in method mul of type object at 0x7f235161ff00>  (a, b)      {}
call_function  mul_4   <built-in method mul of type object at 0x

In [4]:
import torch
import operator
from torch.fx import symbolic_trace


def swap_bitshift_reducesum(gm: torch.fx.GraphModule):
    """Rewrite Commutative: ReduceSum(BitShift(x)) → BitShift(ReduceSum(x)) bit shift."""
    BITSHIFT_FNS = {torch.bitwise_left_shift, operator.lshift}
    # Only left shift is included as it distributes over sum (e.g., sum(x << k) == sum(x) << k).
    # Right shift does NOT distribute over addition and can lead to incorrect rewrites.
    modified = False

    for node in list(gm.graph.nodes):
        is_sum = (
            (node.op == "call_method" and node.target == "sum") or
            (node.op == "call_function" and node.target in (torch.sum, sum))
        )

        if not is_sum or len(node.args) == 0:
            continue

        input_node = node.args[0]
        is_shift = (
            (input_node.op == "call_function" and input_node.target in BITSHIFT_FNS) or
            (input_node.op == "call_method" and input_node.target == "__lshift__")
        )
        if not is_shift:
            continue

        x, shift = input_node.args

        with gm.graph.inserting_before(node):
            new_sum = gm.graph.call_function(
                torch.sum,
                args=(x,),
                kwargs=node.kwargs
            )

            new_bitshift = gm.graph.call_function(
                torch.bitwise_left_shift,
                args=(new_sum, shift),
                kwargs=input_node.kwargs
            )

            node.replace_all_uses_with(new_bitshift)
            # Erase old nodes
            gm.graph.erase_node(node)
            gm.graph.erase_node(input_node)
            modified = True

    if modified:
        # Clean up any dead nodes and rebuild the module
        gm.graph.lint()
        new_gm = torch.fx.GraphModule(gm, gm.graph)
        new_gm.recompile()
        return new_gm
    else:
        return gm


def run_bitshift_swap():
    class BitShiftSwap(torch.nn.Module):
        def forward(self, x):
            # Original subgraph: ReduceSum(BitShift(A))
            t = torch.bitwise_left_shift(x, 1)  # BitShift(x,1)
            out = t.sum()                       # ReduceSum(t)
            return out

    gm = symbolic_trace(BitShiftSwap().eval())
    print("=== FX before rewrite ===")
    print(gm.graph)

    gm = swap_bitshift_reducesum(gm)
    print("=== FX after rewrite ===")
    print(gm.graph)

    compiled_gm = torch.compile(gm, backend="inductor")

    x = torch.randint(0, 16, (4, 4), dtype=torch.int32)

    out_uncompiled = gm(x)
    out = compiled_gm(x)

    print("input:\n", x)
    print("Uncompiled output:", out_uncompiled)
    print("Compiled output:", out)

    assert torch.equal(out_uncompiled, out), "Outputs do not match!"


if __name__ == "__main__":
    run_bitshift_swap()


=== FX before rewrite ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %bitwise_left_shift : [num_users=1] = call_function[target=torch.bitwise_left_shift](args = (%x, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%bitwise_left_shift,), kwargs = {})
    return sum_1
=== FX after rewrite ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %sum_2 : [num_users=1] = call_function[target=torch.sum](args = (%x,), kwargs = {})
    %bitwise_left_shift_1 : [num_users=1] = call_function[target=torch.bitwise_left_shift](args = (%sum_2, 1), kwargs = {})
    return bitwise_left_shift_1
input:
 tensor([[ 1, 14,  6, 10],
        [ 0,  7, 13, 10],
        [12, 15, 11, 12],
        [ 8, 13,  8, 13]], dtype=torch.int32)
Uncompiled output: tensor(306)
Compiled output: tensor(306)


In [5]:
import torch
from torch.fx import symbolic_trace, GraphModule
import operator


class CommutativePass:
    """Implements the rewrite pattern: ReduceProd(Exp(A)) => Exp(ReduceSum(A))"""

    def __init__(self):
        self.reduce_prod_patterns = {torch.prod, "prod"}
        self.exp_patterns = {torch.exp, "exp"}

    def __call__(self, module):
        traced = symbolic_trace(module)
        graph = traced.graph

        # Keep track of nodes to delete
        nodes_to_delete = set()

        # Find all prod nodes (function and method)
        for node in graph.nodes:
            is_prod = (
                (node.op == "call_function" and node.target in self.reduce_prod_patterns) or (
                    node.op == "call_method" and node.target == "prod")
            )
            if not is_prod or len(node.args) < 1:
                continue

            # Get the argument of reduce_prod
            arg = node.args[0]

            # Check if the argument is an exp operation
            if not (
                isinstance(arg, torch.fx.Node) and (
                    (arg.op == "call_function" and arg.target in self.exp_patterns)
                    or (arg.op == "call_method" and arg.target == "exp")
                )
            ):
                continue

            # Get the argument of exp
            exp_arg = arg.args[0]

            # Extract only the reduction dimensions from prod
            reduce_kwargs = {}
            for key in ('dim', 'keepdim', 'dtype'):
                if key in node.kwargs:
                    reduce_kwargs[key] = node.kwargs[key]

            # Create new operations: Exp(ReduceSum(A))
            with graph.inserting_before(node):
                # Sum only over the same dimensions as the original prod
                reduce_sum = graph.call_function(
                    torch.sum, args=(exp_arg,), kwargs=reduce_kwargs)
                # Then compute Exp(ReduceSum(A))
                new_node = graph.call_function(
                    torch.exp, args=(reduce_sum,))

                # Replace the original node with the new one
                node.replace_all_uses_with(new_node)

                # Add nodes to delete set
                nodes_to_delete.update([node, arg])

        # Delete nodes in reverse order to avoid dependency issues
        for node in reversed(list(graph.nodes)):
            if node in nodes_to_delete:
                try:
                    graph.erase_node(node)
                except Exception as e:
                    # Skip if node is already deleted
                    pass

        graph.lint()
        new_module = GraphModule(traced, graph)
        new_module.recompile()
        return new_module


# Test code
def test_commutative():
    class TestModule(torch.nn.Module):
        def forward(self, x):
            # Implement the ReduceProd(Exp(A)) pattern
            exp_x = torch.exp(x)
            return torch.prod(exp_x)

    # Create test inputs
    x = torch.randn(2, 3)

    # Original module
    original_module = TestModule()
    original_output = original_module(x)

    # Apply optimization
    try:
        optimized_module = CommutativePass()(TestModule())
        optimized_output = optimized_module(x)

        # Manually compute expected optimized result
        expected_output = torch.exp(torch.sum(x))

        # Verify results
        print("Input x:", x)
        print("Original output:", original_output)
        print("Optimized output:", optimized_output)
        print("Expected output:", expected_output)
        print("Optimized matches original:", torch.allclose(
            original_output, expected_output))
        print("Optimized matches expected:", torch.allclose(
            optimized_output, expected_output))

        # Print optimized graph
        print("\nOptimized graph:")
        optimized_module.graph.print_tabular()
    except Exception as e:
        print(f"Error during optimization: {e}")
        # Print original graph to assist debugging
        print("\nOriginal graph:")
        original_module = symbolic_trace(TestModule())
        original_module.graph.print_tabular()
        raise


if __name__ == "__main__":
    test_commutative()


Input x: tensor([[ 0.6909, -0.7713, -0.5386],
        [-0.7573, -0.6858, -0.4379]])
Original output: tensor(0.0821)
Optimized output: tensor(0.0821)
Expected output: tensor(0.0821)
Optimized matches original: True
Optimized matches expected: True

Optimized graph:
opcode         name    target                                                  args      kwargs
-------------  ------  ------------------------------------------------------  --------  --------
placeholder    x       x                                                       ()        {}
call_function  sum_1   <built-in method sum of type object at 0x7f235161ff00>  (x,)      {}
call_function  exp_1   <built-in method exp of type object at 0x7f235161ff00>  (sum_1,)  {}
output         output  output                                                  (exp_1,)  {}


In [6]:
import torch
from torch.fx import symbolic_trace, GraphModule
import operator
import pytest


class DistributiveRulePass:
    """Implement distributive rule: A ⊙ C + A ⊙ B → A ⊙ (B + C)"""

    def __init__(self):
        # Possible representations for addition and multiplication
        self.add_patterns = {operator.add, torch.add, "add", "__add__"}
        self.mul_patterns = {operator.mul, torch.mul, "mul", "__mul__"}

    def _is_add_node(self, node):
        if not isinstance(node, torch.fx.Node):
            return False
        return (node.op == "call_function" and node.target in self.add_patterns) or \
               (node.op == "call_method" and node.target in self.add_patterns)

    def _is_mul_node(self, node):
        if not isinstance(node, torch.fx.Node):
            return False
        return (node.op == "call_function" and node.target in self.mul_patterns) or \
            (node.op == "call_method" and node.target in self.mul_patterns)

    def __call__(self, module):
        # Symbolically trace the module
        traced = symbolic_trace(module)
        graph = traced.graph

        # First collect all matching “A*C + A*B” nodes
        matches = []
        for node in graph.nodes:
            if not self._is_add_node(node) or len(node.args) != 2:
                continue

            lhs, rhs = node.args[0], node.args[1]

            # Both sides must be multiplication nodes
            if not (self._is_mul_node(lhs) and self._is_mul_node(rhs)):
                continue
            if len(lhs.args) != 2 or len(rhs.args) != 2:
                continue

            # Find common factor in any operand position
            # A*C + A*B  or  C*A + B*A  or  A*C + B*A  or  C*A + A*B
            A, B, C = None, None, None
            for potential_A in lhs.args:
                if potential_A in rhs.args:
                    A = potential_A
                    # Get remaining terms from both sides
                    C = lhs.args[0] if lhs.args[1] is A else lhs.args[1]
                    B = rhs.args[0] if rhs.args[1] is A else rhs.args[1]
                    break
            if A is None:
                continue

            matches.append((node, A, B, C, lhs, rhs))

        # Apply each matched transformation, then erase the old nodes
        nodes_to_delete = set()
        for add_node, A, B, C, lhs_node, rhs_node in matches:
            # A*(B+C)
            with graph.inserting_before(add_node):
                sum_bc = graph.call_function(torch.add, args=(B, C))
                fused = graph.call_function(torch.mul, args=(A, sum_bc))
                add_node.replace_all_uses_with(fused)
                nodes_to_delete.update({add_node, lhs_node, rhs_node})

        # Erase dead nodes
        for node in reversed(list(graph.nodes)):
            if node in nodes_to_delete and not node.users:
                graph.erase_node(node)

        graph.lint()

        new_mod = GraphModule(traced, graph)
        new_mod.recompile()
        return new_mod

# Test code


class TestModule(torch.nn.Module):
    def __init__(self, variation):
        super().__init__()
        self.variation = variation

    def forward(self, a, b, c):
        # Implement the A ⊙ C + A ⊙ B pattern here
        if self.variation == 0:
            return (c * a) + (b * a)
        elif self.variation == 1:
            return (a * c) + (b * a)
        else:
            return (c * a) + (a * b)


@pytest.mark.parametrize("variation", [0, 1, 2])
def test_distributive_rule(variation):
    # Create test inputs
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    c = torch.randn(2, 3)

    # Original module
    model = TestModule(variation).eval()
    original_output = model(a, b, c)

    # Apply optimization
    try:
        optimized_module = DistributiveRulePass()(model)
        optimized_output = optimized_module(a, b, c)

        # Verify results
        print("Original output:", original_output)
        print("Optimized output:", optimized_output)
        print("Optimized matches original:", torch.allclose(
            original_output, optimized_output))

        # Print optimized graph
        print("\nOptimized graph:")
        optimized_module.graph.print_tabular()
    except Exception as e:
        print(f"Error during optimization: {e}")
        # Print original graph to assist debugging
        print("\nOriginal graph:")
        original_module = symbolic_trace(TestModule())
        original_module.graph.print_tabular()
        raise


if __name__ == "__main__":
    for v in [0, 1, 2]:
        test_distributive_rule(v)


Original output: tensor([[ 0.4573, -0.6412,  0.2246],
        [-0.2847, -1.4171,  0.0560]])
Optimized output: tensor([[ 0.4573, -0.6412,  0.2246],
        [-0.2847, -1.4171,  0.0560]])
Optimized matches original: True

Optimized graph:
opcode         name    target                                                  args        kwargs
-------------  ------  ------------------------------------------------------  ----------  --------
placeholder    a       a                                                       ()          {}
placeholder    b       b                                                       ()          {}
placeholder    c       c                                                       ()          {}
call_function  add_1   <built-in method add of type object at 0x7f235161ff00>  (b, c)      {}
call_function  mul_2   <built-in method mul of type object at 0x7f235161ff00>  (a, add_1)  {}
output         output  output                                                  (mul_2,)    {}
Or

In [7]:
import torch
from torch.fx import symbolic_trace, GraphModule
import operator
import pytest


class DistributiveRule2Pass:
    """Implement distributive rule 2: A + A ⊙ B → A ⊙ (B + 1)"""

    def __init__(self):
        # Possible representations for addition and multiplication
        self.add_patterns = {operator.add, torch.add, "add", "__add__"}
        self.mul_patterns = {operator.mul, torch.mul, "mul", "__mul__"}

    def _is_add_node(self, node):
        # Check if node is a torch.fx.Node first
        if not isinstance(node, torch.fx.Node):
            return False
        return (node.op == "call_function" and node.target in self.add_patterns) or \
               (node.op == "call_method" and node.target in self.add_patterns)

    def _is_mul_node(self, node):
        # Check if node is a torch.fx.Node first
        if not isinstance(node, torch.fx.Node):
            return False
        return (node.op == "call_function" and node.target in self.mul_patterns) or \
               (node.op == "call_method" and node.target in self.mul_patterns)

    def __call__(self, module):
        # Symbolically trace the module
        traced = symbolic_trace(module)
        graph = traced.graph

        # First collect all matching "A + A*B" nodes
        matches = []
        for node in graph.nodes:
            if not self._is_add_node(node) or len(node.args) != 2:
                continue

            lhs, rhs = node.args[0], node.args[1]

            # Check pattern: A + (A * B)
            if self._is_mul_node(rhs) and len(rhs.args) == 2:
                if lhs is rhs.args[0]:  # A + (A * B)
                    matches.append((node, lhs, rhs, rhs.args[1]))
                elif lhs is rhs.args[1]:  # A + (B * A)
                    matches.append((node, lhs, rhs, rhs.args[0]))

            # Check pattern: (A * B) + A
            if self._is_mul_node(lhs) and len(lhs.args) == 2:
                if rhs is lhs.args[0]:  # (A * B) + A
                    matches.append((node, rhs, lhs, lhs.args[1]))
                elif rhs is lhs.args[1]:  # (B * A) + A
                    matches.append((node, rhs, lhs, lhs.args[0]))

        # Apply each matched transformation, then erase the old nodes
        nodes_to_delete = set()
        for add_node, a_term, mul_node, b_term in matches:
            # A*(B+1)
            with graph.inserting_before(add_node):
                # First create a tensor of ones matching B's shape
                one_node = graph.call_function(torch.ones_like, args=(b_term,))

                # Then compute B + 1
                b_plus_one = graph.call_function(torch.add, args=(b_term, one_node))

                # Finally compute A * (B + 1)
                fused = graph.call_function(torch.mul, args=(a_term, b_plus_one))

                add_node.replace_all_uses_with(fused)
                nodes_to_delete.update({add_node, mul_node})

        # Erase dead nodes
        for node in reversed(list(graph.nodes)):
            if node in nodes_to_delete and not node.users:
                graph.erase_node(node)

        graph.lint()

        new_mod = GraphModule(traced, graph)
        new_mod.recompile()
        return new_mod


# Test code
class TestModule(torch.nn.Module):
    def __init__(self, variation):
        super().__init__()
        self.variation = variation

    def forward(self, a, b):
        # Implement the A + A*B pattern in different variations
        if self.variation == 0:
            return a + a * b  # A + (A * B)
        elif self.variation == 1:
            return a + b * a  # A + (B * A)
        elif self.variation == 2:
            return a * b + a  # (A * B) + A
        else:
            return b * a + a  # (B * A) + A


@pytest.mark.parametrize("variation", [0, 1, 2, 3])
def test_distributive_rule2(variation):
    # Create test inputs
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)

    # Original module
    model = TestModule(variation).eval()
    original_output = model(a, b)

    # Apply optimization
    try:
        optimized_module = DistributiveRule2Pass()(model)
        optimized_output = optimized_module(a, b)

        # Manually compute expected optimized result
        expected_output = a * (b + 1)

        # Verify results
        print(f"Testing variation {variation}")
        print("Original output:", original_output)
        print("Optimized output:", optimized_output)
        print("Expected output:", expected_output)
        print("Optimized matches original:", torch.allclose(original_output, optimized_output))
        print("Optimized matches expected:", torch.allclose(optimized_output, expected_output))

        # Print optimized graph
        print("\nOptimized graph:")
        optimized_module.graph.print_tabular()
    except Exception as e:
        print(f"Error during optimization: {e}")
        # Print original graph to assist debugging
        print("\nOriginal graph:")
        traced = symbolic_trace(model)
        traced.graph.print_tabular()
        raise


if __name__ == "__main__":
    for v in range(4):
        test_distributive_rule2(v)
        print("\n" + "="*50 + "\n")

Testing variation 0
Original output: tensor([[-5.9558,  1.2854, -0.1099],
        [ 0.3256,  0.3261,  1.8915]])
Optimized output: tensor([[-5.9558,  1.2854, -0.1099],
        [ 0.3256,  0.3261,  1.8915]])
Expected output: tensor([[-5.9558,  1.2854, -0.1099],
        [ 0.3256,  0.3261,  1.8915]])
Optimized matches original: True
Optimized matches expected: True

Optimized graph:
opcode         name       target                                                        args            kwargs
-------------  ---------  ------------------------------------------------------------  --------------  --------
placeholder    a          a                                                             ()              {}
placeholder    b          b                                                             ()              {}
call_function  ones_like  <built-in method ones_like of type object at 0x7f235161ff00>  (b,)            {}
call_function  add_1      <built-in method add of type object at 0x7f23516

In [8]:
import torch
import operator
from torch.fx import GraphModule, symbolic_trace


def swap_recip_associative(gm: GraphModule):
    """Rewrite Recip(A)*Recip(A*B) → square(Recip(A)) * Recip(B)."""
    nodes_to_delete = set()
    modified = False

    for node in list(gm.graph.nodes):
        # 1) match outer multiplication
        if node.op != "call_function" or node.target not in (torch.mul, operator.mul):
            continue
        lhs, rhs = node.args
        if not (is_reciprocal(lhs) and is_reciprocal(rhs)):
            continue

        # 2) figure out which side is Recip(A) vs Recip(A*B)
        a_node = b_node = recip_a_node = mul_node = None
        arg1, arg2 = get_recip_arg(lhs), get_recip_arg(rhs)

        # case A*B on rhs
        if (isinstance(arg2, torch.fx.Node)
                and arg2.op == "call_function"
                and arg2.target in (torch.mul, operator.mul)
                and arg1 in arg2.args):
            recip_a_node = lhs
            mul_node = arg2
            a_node = arg1
            b_node = arg2.args[1] if arg2.args[0] is arg1 else arg2.args[0]

        # case A*B on lhs
        elif (isinstance(arg1, torch.fx.Node)
              and arg1.op == "call_function"
              and arg1.target in (torch.mul, operator.mul)
              and arg2 in arg1.args):
            recip_a_node = rhs
            mul_node = arg1
            a_node = arg2
            b_node = arg1.args[1] if arg1.args[0] is arg2 else arg1.args[0]
        else:
            continue  # no match

        # 3) insert the new ops before the old `node`
        with gm.graph.inserting_before(node):
            square_recip = gm.graph.call_function(
                torch.square, args=(recip_a_node,))
            recip_b = gm.graph.call_function(
                torch.reciprocal, args=(b_node,))
            new_mul = gm.graph.call_function(
                torch.mul, args=(square_recip, recip_b))

        # 4) redirect uses, mark old nodes for deletion
        node.replace_all_uses_with(new_mul)
        nodes_to_delete.update({node, mul_node, lhs, rhs})
        modified = True

    # 5) bulk‑erase dead nodes in reverse order
    if modified:
        for n in reversed(list(gm.graph.nodes)):
            if n in nodes_to_delete and not n.users:
                gm.graph.erase_node(n)
        gm.graph.lint()
        new_module = GraphModule(gm, gm.graph)
        new_module.recompile()
        return new_module

    return gm


def is_reciprocal(node):
    if node.op != "call_function":
        return False
    if node.target == torch.reciprocal:
        return True
    if node.target == operator.truediv:
        return is_constant_one(node.args[0])
    if node.target in (torch.pow, operator.pow):
        return len(node.args) >= 2 and float(node.args[1]) == -1.0
    return False


def get_recip_arg(node):
    if node.target == torch.reciprocal:
        return node.args[0]
    if node.target == operator.truediv:
        return node.args[1]
    if node.target in (torch.pow, operator.pow) and float(node.args[1]) == -1.0:
        return node.args[0]
    return None


def is_constant_one(node):
    return (node.op == "call_function"
            and node.target == torch.tensor
            and isinstance(node.args[0], (int, float))
            and float(node.args[0]) == 1.0)


def run_associative_swap():
    class RecipAssociativeModel(torch.nn.Module):
        def forward(self, A, B):
            # Implements: Recip(A) ⊙ Recip(A ⊙ B)
            recip_a = torch.reciprocal(A)  # 1/A
            a_mul_b = A * B                # A*B
            recip_a_mul_b = torch.reciprocal(a_mul_b)  # 1/(A*B)
            return recip_a * recip_a_mul_b  # (1/A) * (1/(A*B))

    gm = symbolic_trace(RecipAssociativeModel().eval())

    print("=== Original Graph ===")
    print(gm.graph)

    optimized = swap_recip_associative(gm)
    print("=== Rewritten Graph ===")
    print(optimized.graph)

    compiled_gm = torch.compile(optimized, backend='inductor')

    A = torch.tensor([2.0, 4.0, 5.0])
    B = torch.tensor([3.0, 5.0, 7.0])

    original_output = gm(A, B)
    optimized_output = optimized(A, B)
    compiled_output = compiled_gm(A, B)

    print("Original output:", original_output)
    print("Optimized output:", optimized_output)
    print("Compiled output:", compiled_output)

    assert torch.allclose(original_output, optimized_output,
                          rtol=1e-5), "Outputs do not match!"
    assert torch.allclose(original_output, compiled_output,
                          rtol=1e-5), "Outputs do not match!"


if __name__ == "__main__":
    run_associative_swap()


=== Original Graph ===
graph():
    %a : [num_users=2] = placeholder[target=A]
    %b : [num_users=1] = placeholder[target=B]
    %reciprocal : [num_users=1] = call_function[target=torch.reciprocal](args = (%a,), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%a, %b), kwargs = {})
    %reciprocal_1 : [num_users=1] = call_function[target=torch.reciprocal](args = (%mul,), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (%reciprocal, %reciprocal_1), kwargs = {})
    return mul_1
=== Rewritten Graph ===
graph():
    %a : [num_users=1] = placeholder[target=A]
    %b : [num_users=1] = placeholder[target=B]
    %reciprocal : [num_users=1] = call_function[target=torch.reciprocal](args = (%a,), kwargs = {})
    %square : [num_users=1] = call_function[target=torch.square](args = (%reciprocal,), kwargs = {})
    %reciprocal_2 : [num_users=1] = call_function[target=torch.reciprocal](args = (%b,), kwargs = {})
    %mul_2 : [num

In [9]:
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, GraphModule
import operator

ELEMENTWISE_OPS = {
    torch.add, torch.sub, torch.mul, torch.div,
    torch.relu, torch.sigmoid, torch.tanh,
    torch.neg, torch.exp, torch.log,
    operator.add, operator.sub, operator.mul, operator.truediv
}

METHOD_OPS = {"relu", "sigmoid", "tanh", "add", "mul", "div", "sub"}


def is_elementwise(node):
    if getattr(node.target, "__name__", None) == "fused_function":
        return False
    return (node.op == "call_function" and node.target in ELEMENTWISE_OPS) or \
           (node.op == "call_method" and node.target in METHOD_OPS)


def find_elementwise_chains(graph):
    chains = []
    visited = set()

    for node in graph.nodes:
        if node in visited:
            continue
        if not is_elementwise(node):
            continue

        chain = [node]
        visited.add(node)
        current = node

        while True:
            users = list(current.users)
            if len(users) != 1:
                break

            next_node = users[0]
            if not is_elementwise(next_node) or next_node in visited:
                break

            only_depends_on_current = all(
                not (isinstance(arg, torch.fx.Node) and arg.op !=
                     "get_attr" and arg is not current)
                for arg in next_node.args
            )

            if not only_depends_on_current:
                break

            chain.append(next_node)
            visited.add(next_node)
            current = next_node

        if len(chain) >= 2:
            chains.append(chain)
    return chains


def generate_fused_fn(chain, gm):
    # Pre-capture tensor constants
    constants = {}

    # Check if this chain is safe to fuse
    for i, node in enumerate(chain):
        if node.op == "call_function" and len(node.args) >= 2:
            second_arg = node.args[1]
            if isinstance(second_arg, torch.fx.Node) and second_arg.op == "get_attr":
                constants[(i, 1)] = getattr(gm, second_arg.target)
            elif not isinstance(second_arg, torch.fx.Node):
                constants[(i, 1)] = second_arg
            else:
                return None  # No fusion if second arg is another variable node

    def fused_function(x):
        result = x
        for i, node in enumerate(chain):
            if node.op == "call_function":
                if len(node.args) == 1:  # unary op
                    result = node.target(result)
                elif len(node.args) >= 2:  # binary op
                    if (i, 1) in constants:
                        result = node.target(result, constants[(i, 1)])
                    else:
                        second_arg = node.args[1]
                        result = node.target(result, second_arg)

            elif node.op == "call_method":
                method = getattr(result, node.target)
                if len(node.args) == 0:
                    result = method()
                else:
                    method_args = node.args[1:]
                    result = method(*method_args)

        return result

    fused_function.__name__ = "fused_function"
    fused_function.flops_per_element = len(chain)
    fused_function.is_fused_function = True
    return fused_function


def fuse_elementwise_chains(gm: GraphModule):
    graph = gm.graph
    chains = find_elementwise_chains(graph)

    for chain in chains:
        # If the head node is no longer in the graph (erased by a prior fusion), skip
        if chain[0] not in graph.nodes:
            continue

        first = chain[0]
        last = chain[-1]
        input_val = first.args[0]

        fused_fn = generate_fused_fn(chain, gm)
        if fused_fn is None:
            continue

        with graph.inserting_before(first):
            fused = graph.call_function(fused_fn, args=(input_val,))

        last.replace_all_uses_with(fused)

        for node in reversed(chain):
            graph.erase_node(node)

    gm.recompile()
    return gm


class TestModel(nn.Module):
    def forward(self, x):
        return torch.sigmoid(torch.relu(torch.add(x, 1)))


if __name__ == "__main__":
    model = TestModel()
    traced = symbolic_trace(model)

    print("=== Before Fusion ===")
    print(traced.graph)

    optimized = fuse_elementwise_chains(traced)

    print("=== After Fusion ===")
    print(optimized.graph)

    x = torch.randn(3, 4)
    ref_out = model(x)
    opt_out = optimized(x)

    print("Ref Output:", ref_out)
    print("Fused Output:", opt_out)

    assert torch.allclose(ref_out, opt_out, rtol=1e-4), "Mismatch in outputs"

    compiled = torch.compile(optimized)
    compiled_out = compiled(x)
    print("Compiled Output:", compiled_out)
    assert torch.allclose(compiled_out, ref_out, rtol=1e-4)


=== Before Fusion ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.add](args = (%x, 1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.sigmoid](args = (%relu,), kwargs = {})
    return sigmoid
=== After Fusion ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fused_function : [num_users=1] = call_function[target=__main__.fused_function](args = (%x,), kwargs = {})
    return fused_function
Ref Output: tensor([[0.6061, 0.5000, 0.8335, 0.6420],
        [0.6649, 0.5932, 0.5594, 0.9262],
        [0.7824, 0.7213, 0.8650, 0.8978]])
Fused Output: tensor([[0.6061, 0.5000, 0.8335, 0.6420],
        [0.6649, 0.5932, 0.5594, 0.9262],
        [0.7824, 0.7213, 0.8650, 0.8978]])
Compiled Output: tensor([[0.6061, 0.5000, 0.8335, 0.6420],
        [0.6649, 0.5932, 0.5594, 0.9262],
        [0.7824, 0.7213, 0.8650, 

In [10]:
import torch

def apply_all_rewrites(gm: torch.fx.GraphModule):
    # Commutative rules
    gm = swap_bitshift_reducesum(gm)
    gm = CommutativePass()(gm)

    # Associative rules
    gm = swap_recip_associative(gm)
    gm = SqrtAssociativePass()(gm)

    # Distributive rules
    gm = DistributiveRulePass()(gm)
    gm = DistributiveRule2Pass()(gm)

    return gm


def optimize_graph(gm: torch.fx.GraphModule):
    gm = apply_all_rewrites(gm)
    gm = fuse_elementwise_chains(gm)
    return gm


In [11]:
# ─── EVALUATION SECTION ────────────────────────────────────
import torch
import torch.nn as nn
import numpy as np
from torch.fx import symbolic_trace, GraphModule
from torch.fx.passes.shape_prop import ShapeProp
from typing import Callable, Dict, Tuple, List

# ─── Helpers ────────────────────────────────────────────────────────────────
def _numel_from_meta(node):
    return int(np.prod(node.meta["tensor_meta"].shape))

def _numel_from_meta_simple(shape: tuple) -> int:
    return int(np.prod(shape))

# ─── Zero‑cost ops ─────────────────────────────────────────────────────────
_ZERO_COST = {
    torch.ones_like,
    torch.zeros_like,
    torch.tensor,
    torch.flatten,    # view only
}

# ─── Recognize both Python and ATen sum/prod ────────────────────────────────
_SUM_FNS  = { torch.ops.aten.sum, torch.sum }
_PROD_FNS = { torch.ops.aten.prod, torch.prod }

# ─── ATen mapping: specific operators → flop functions ──────────────────────
ATEN_FLOP_MAP: Dict[Callable, Callable] = {
    # element‑wise
    torch.ops.aten.add.Tensor:  lambda n: _numel_from_meta(n),
    torch.ops.aten.sub.Tensor:  lambda n: _numel_from_meta(n),
    torch.ops.aten.mul.Tensor:  lambda n: _numel_from_meta(n),
    torch.ops.aten.div.Tensor:  lambda n: _numel_from_meta(n),
    torch.ops.aten.neg:         lambda n: _numel_from_meta(n),
    torch.ops.aten.exp:         lambda n: _numel_from_meta(n),
    torch.ops.aten.log:         lambda n: _numel_from_meta(n),
    torch.ops.aten.sqrt:        lambda n: _numel_from_meta(n),
    torch.ops.aten.sigmoid:     lambda n: 2 * _numel_from_meta(n),
    torch.ops.aten.tanh:        lambda n: 2 * _numel_from_meta(n),
    torch.ops.aten.square:      lambda n: _numel_from_meta(n),

    # constant factories
    torch.ops.aten.ones_like:   lambda n: 0,
    torch.ops.aten.zeros_like:  lambda n: 0,

    # matmul
    torch.ops.aten.matmul:      lambda n: 2
                                  * n.args[0].shape[0]
                                  * n.args[0].shape[1]
                                  * n.args[1].shape[-1],

    # reductions via ATen
    torch.ops.aten.sum:         lambda n: _numel_from_meta(n.args[0]),
    torch.ops.aten.prod:        lambda n: _numel_from_meta(n.args[0]),

    torch.bitwise_left_shift: lambda n: _numel_from_meta(n),
    torch.ops.aten.bitwise_left_shift: lambda n: _numel_from_meta(n),
}

# ─── nn.Module mapping → flop functions ────────────────────────────────────
MODULE_FLOP_MAP: Dict[type, Callable] = {
    nn.Conv1d:    lambda m, o: 2 * o[0] * o[1] * o[2] * m.in_channels * m.kernel_size[0],
    nn.Conv2d:    lambda m, o: 2 * o[0] * o[1] * o[2] * o[3]
                                     * m.in_channels
                                     * m.kernel_size[0]
                                     * m.kernel_size[1],
    nn.Conv3d:    lambda m, o: 2 * o[0] * o[1] * o[2] * o[3] * o[4]
                                     * m.in_channels
                                     * m.kernel_size[0]
                                     * m.kernel_size[1]
                                     * m.kernel_size[2],
    nn.Linear:    lambda m, o: 2 * o[0] * m.in_features * m.out_features,
    nn.BatchNorm1d: lambda m, o: 2 * _numel_from_meta_simple(o),
    nn.BatchNorm2d: lambda m, o: 2 * _numel_from_meta_simple(o),
    nn.BatchNorm3d: lambda m, o: 2 * _numel_from_meta_simple(o),
}

# ─── The FX‑based FLOP counter ─────────────────────────────────────────────
def fx_count_flops(
    model: torch.nn.Module,
    inputs: tuple,
    custom_aten: Dict[Callable, Callable] = None,
    custom_modules: Dict[type, Callable] = None,
) -> int:
    # 1) Trace & shape‑propagate
    if isinstance(model, GraphModule):
        gm = model
    else:
        gm = symbolic_trace(model)
    ShapeProp(gm).propagate(*inputs)

    aten_map = {**ATEN_FLOP_MAP,    **(custom_aten or {})}
    mod_map  = {**MODULE_FLOP_MAP, **(custom_modules or {})}

    total_flops = 0
    for node in gm.graph.nodes:
        # 0) skip zero‑cost
        if node.op == "call_function" and node.target in _ZERO_COST:
            continue

        # 1) fused‑function fast path
        if node.op == "call_function" and getattr(node.target, "is_fused_function", False):
            ne = _numel_from_meta(node)
            total_flops += node.target.flops_per_element * ne
            continue

        # 2) nn.Modules
        if node.op == "call_module":
            sub = gm.get_submodule(node.target)
            for T, fn in mod_map.items():
                if isinstance(sub, T):
                    total_flops += int(fn(sub, node.meta["tensor_meta"].shape))
                    break
            continue

        # 3) explicit ATen (elementwise, matmul, constant factories, aten.sum/prod)
        if node.op == "call_function" and node.target in aten_map:
            total_flops += int(aten_map[node.target](node))
            continue

        # 4) reductions via Python‐level sum/prod
        if node.op == "call_function" and node.target in _SUM_FNS:
            # cost = #elements of the input
            total_flops += int(np.prod(node.args[0].meta["tensor_meta"].shape))
            continue
        if node.op == "call_function" and node.target in _PROD_FNS:
            total_flops += int(np.prod(node.args[0].meta["tensor_meta"].shape))
            continue

        # 5/6) generic fallback: any other call_*(function|method) that has tensor_meta
        if (node.op in ("call_function","call_method") and
            "tensor_meta" in node.meta):
            total_flops += _numel_from_meta(node)
            continue

    return total_flops


In [12]:
import time

def profile_model(
    model: nn.Module, 
    dummy_input: torch.Tensor,  # Take pre-generated input tensor
    num_runs: int = 8,
    warmup: int = 3
) -> Dict[str, float]:
    device = dummy_input.device
    model = model.to(device).eval()
    
    # Warmup
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(dummy_input)
        torch.cuda.synchronize() if device.type == "cuda" else None

    # Latency measurement
    times = []
    for _ in range(num_runs + 3):
        if device.type == "cuda":
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            _ = model(dummy_input)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
        else:
            t0 = time.time()
            _ = model(dummy_input)
            times.append((time.time() - t0) * 1000)
    
    # Process times
    times = sorted(times)
    cut = int(len(times) * 0.1)
    avg_time = np.mean(times[cut:-cut])

    # Memory measurement
    peak_mem = 0.0
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            _ = model(dummy_input)
            torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**2)

    # FLOP counting
    total_flops = 0
    try:
        gm = symbolic_trace(model)
        fx.passes.shape_prop.ShapeProp(gm).propagate(dummy_input)
        total_flops = fx_count_flops(gm, (dummy_input,)) / 1e9
    except Exception as e:
        print(f"FLOP counting failed: {str(e)}")

    return {
        'avg_time_ms': avg_time,
        'peak_mem_mb': peak_mem,
        'total_flops_g': total_flops
    }


In [13]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.fx as fx

# ─── models ──────────────────────────────────────────────────────────────────
class DeepElementwiseModel(nn.Module):
    def forward(self, x):
        for _ in range(40):
            x = torch.sigmoid(torch.relu(x + 1))
        return x.mean()
        
class EnhancedResNet(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.resnet18(weights=None)
        self.features = nn.Sequential(
            base.conv1, base.bn1, base.relu, base.maxpool,
            base.layer1, base.layer2
        )
        self.pattern_conv = nn.Conv2d(128, 128, 3, padding=1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.features(x)
        for _ in range(30):
            a = x
            b = self.pattern_conv(x)
            c = self.pattern_conv(x)
            x = a*c + a*b
            x = torch.sigmoid(torch.relu(x + 1))
        return self.fc(x.mean(dim=[2, 3]))

class EnhancedVGG(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.vgg16(weights=None)
        self.features = base.features
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        x = self.features(x)
        # a short elementwise chain
        x = torch.tanh(torch.exp(x + 2))
        x = torch.sigmoid(x)
        x = torch.sqrt(x + 1)
        return self.fc(x.mean(dim=[2, 3]))

In [14]:
class BitShiftTestModel(nn.Module):
    def forward(self, x: torch.Tensor):
        shifted = torch.bitwise_left_shift(x, 1)
        return torch.sum(shifted)

In [15]:
class RewriteTriggerResNet(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=None)
        self.features = nn.Sequential(*list(backbone.children())[:-1])  # Remove FC
        self.head = nn.Linear(512, 10)

    def forward(self, image, B, C):
        # === Real path (ResNet on image input) ===
        x = self.features(image)  
        x = torch.flatten(x, 1)

        # === Rewrite patterns on vector-shaped inputs ===
        BIG_B = B.repeat(1, 64) 
        BIG_C = C.repeat(1, 64)
        
        for _ in range(100):  # Repeat 100x to amplify FLOP impact

            # 1. prod(exp(x)) → exp(sum(x))
            _ = torch.prod(torch.exp(BIG_B), dim=1)
    
            # 2. (B * sqrt(C)) * (sqrt(C) * B)
            s = torch.sqrt(BIG_C)
            _ = (BIG_B * s) * (s * BIG_B)
    
            # 3. sum(bitshift(x)) → bitshift(sum(x))
            shifted = torch.bitwise_left_shift(BIG_B.to(torch.int32), 2)
            summed_int = torch.sum(shifted, dim=1, keepdim=True)
            _ = summed_int.float()
    
            # 4. A * B + A * C → A * (B + C)
            _ = BIG_B * BIG_C + BIG_B * BIG_B
    
            # 5. A + A * B → A * (B + 1)
            _ = BIG_B + BIG_B * BIG_C
    
            # 6. (1/B) * (1/(B*C)) → (1/B)^2 * (1/C)
            _ = torch.reciprocal(BIG_B) * torch.reciprocal(BIG_B * BIG_C)

        return self.head(x)

class RewriteTriggerResNetWrapped(nn.Module):
    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        A = x[:, :, :, :224]  # [32,3,224,224]
        B = x[:, 0, 0, 224:224+8192]    # [32,8192]
        C = x[:, 0, 0, 224+8192:224+16384]  # [32,8192]
        return self.base_model(A, B, C)

In [16]:
class ComprehensiveRewriteModel(nn.Module):
    """
    A single‑module harness containing all of the algebraic rewrites:
      1) Bitshift→Sum
      2) Prod(Exp)→Exp(Sum)
      3) Recip‑Associative
      4) Sqrt‑Associative
      5) Distributive A*B + A*C → A*(B+C)
      6) Distributive2 A + A*B → A*(1+B)
    """
    def forward(self, x):
        # Build two helper tensors B and C from x
        B = x + 1.0
        C = x + 2.0

        # 1) Bitshift → Sum
        x_int   = x.to(torch.int32)
        t1      = torch.bitwise_left_shift(x_int, 1)
        sum_int = torch.sum(t1, dim=1, keepdim=True)
        p1      = sum_int.float()

        # 2) Prod(Exp) → Exp(Sum)
        t2 = torch.exp(B)
        p2 = torch.prod(t2, dim=1, keepdim=True)

        # 3) Recip‑Associative: Recip(A) * Recip(A * B) → (Recip(A))² * Recip(B)
        rA  = torch.reciprocal(B)
        rAB = torch.reciprocal(B * C)
        p3  = rA * rAB

        # 4) Sqrt‑Associative: (A*√B)*(√B*C) → A*B*C
        sB = torch.sqrt(B)
        m1 = x * sB
        m2 = sB * C
        p4 = m1 * m2

        # 5) Distributive: A*B + A*C → A*(B+C)
        d5 = x * B + x * C

        # 6) Distributive2: A + A*B → A*(1 + B)
        d6 = x + x * C

        # Combine everything and collapse to a scalar
        out = p1 + p2 + p3 + p4 + d5 + d6
        return out.mean()

In [17]:
import torch
import pandas as pd
from torch.fx import symbolic_trace

def evaluate_all_models_with_breakdown():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary = []

    model_configs = [
        {
            "name": "RewriteResNet",
            "model": RewriteTriggerResNetWrapped(RewriteTriggerResNet()).to(device),
            "make_input": lambda: torch.cat([
                torch.randn(32, 3, 224, 224, device=device),
                torch.randint(0, 16, (32, 1, 1, 16384), device=device).float()
                    .expand(-1, 3, 224, -1)
            ], dim=3),
            "passes": [
                ("Commutative",        CommutativePass),
                ("SqrtAssociative",    SqrtAssociativePass),
                ("BitShift→Sum",       swap_bitshift_reducesum),
                ("Distributive",       DistributiveRulePass),
                ("Distributive2",      DistributiveRule2Pass),
                ("Recip‑Assoc",        swap_recip_associative),
            ]
        },
        {
            "name": "ComprehensiveRewrite",
            "model": ComprehensiveRewriteModel().to(device),
            "make_input": lambda: torch.randn(32, 3, 224, 224, device=device),
            "passes": [
                ("Commutative",        CommutativePass),
                ("SqrtAssociative",    SqrtAssociativePass),
                ("BitShift→Sum",       swap_bitshift_reducesum),
                ("Distributive",       DistributiveRulePass),
                ("Distributive2",      DistributiveRule2Pass),
                ("Recip‑Assoc",        swap_recip_associative),
            ]
        },
        # summary‑only models
        {
            "name": "EnhancedResNet",
            "model": EnhancedResNet().to(device),
            "make_input": lambda: torch.randn(32, 3, 224, 224, device=device),
            "passes": None
        },
        {
            "name": "EnhancedVGG",
            "model": EnhancedVGG().to(device),
            "make_input": lambda: torch.randn(32, 3, 224, 224, device=device),
            "passes": None
        },
        {
            "name": "DeepElementwise",
            "model": DeepElementwiseModel().to(device),
            "make_input": lambda: torch.randn(32, 3, 224, 224, device=device),
            "passes": None
        },
        {
            "name": "BitShiftTest",
            "model": BitShiftTestModel().to(device),
            "make_input": lambda: torch.randint(0,16,(32,3,224,224),
                                                device=device, dtype=torch.int32),
            "passes": None
        },
    ]

    for cfg in model_configs:
        name    = cfg["name"]
        model   = cfg["model"].eval()
        dummy   = cfg["make_input"]()

        # 1) baseline
        gm0       = symbolic_trace(model)
        base_stat = profile_model(gm0, dummy)

        # 2) global optimize_graph
        gm1       = optimize_graph(symbolic_trace(model))
        opt_stat  = profile_model(gm1, dummy)

        if cfg["passes"] is None:
            # just gather for final summary
            Δf = 100 * (base_stat["total_flops_g"] - opt_stat["total_flops_g"]) / base_stat["total_flops_g"]
            Δt = 100 * (base_stat["avg_time_ms"]   - opt_stat["avg_time_ms"])   / base_stat["avg_time_ms"]
            Δm = 100 * (base_stat["peak_mem_mb"]   - opt_stat["peak_mem_mb"])   / base_stat["peak_mem_mb"]
            summary.append({
                "Model":        name,
                "Baseline FLOPs": base_stat["total_flops_g"],
                "Opt FLOPs":      opt_stat["total_flops_g"],
                "ΔFLOPs %":       f"{Δf:+.1f}%",
                "Baseline ms":    base_stat["avg_time_ms"],
                "Opt ms":         opt_stat["avg_time_ms"],
                "ΔLat %":         f"{Δt:+.1f}%",
                "Baseline MB":    base_stat["peak_mem_mb"],
                "Opt MB":         opt_stat["peak_mem_mb"],
                "ΔMem %":         f"{Δm:+.1f}%"
            })

        else:
            # detailed per‑pass breakdown
            print(f"\n--- {name} ---")
            print(f"Baseline  →  FLOPs {base_stat['total_flops_g']:.2f}G   "
                  f"Latency {base_stat['avg_time_ms']:.2f}ms   "
                  f"Mem {base_stat['peak_mem_mb']:.1f}MB\n")

            for pass_name, Pass in cfg["passes"]:
                gm_p   = symbolic_trace(model)
                gm_opt = Pass()(gm_p) if isinstance(Pass, type) else Pass(gm_p)
                pstat  = profile_model(gm_opt, dummy)

                Δf_p = 100 * (base_stat["total_flops_g"] - pstat["total_flops_g"]) / base_stat["total_flops_g"]
                Δt_p = 100 * (base_stat["avg_time_ms"]   - pstat["avg_time_ms"])   / base_stat["avg_time_ms"]
                Δm_p = 100 * (base_stat["peak_mem_mb"]   - pstat["peak_mem_mb"])   / base_stat["peak_mem_mb"]

                print(f"{pass_name:<18}"
                      f" FLOPs {pstat['total_flops_g']:.2f}G ({Δf_p:+.1f}%)  "
                      f"Time {pstat['avg_time_ms']:.2f}ms ({Δt_p:+.1f}%)  "
                      f"Mem {pstat['peak_mem_mb']:.1f}MB ({Δm_p:+.1f}%)")

    # 3) finally print the summary for the rest
    if summary:
        df = pd.DataFrame(summary).set_index("Model")
        print("\n=== Summary for the other models ===")
        display(df.round(2))


In [18]:
evaluate_all_models_with_breakdown()


--- RewriteResNet ---
Baseline  →  FLOPs 146.49G   Latency 195.10ms   Mem 2453.1MB

Commutative        FLOPs 144.81G (+1.1%)  Time 187.85ms (+3.7%)  Mem 2453.1MB (+0.0%)
SqrtAssociative    FLOPs 143.13G (+2.3%)  Time 175.59ms (+10.0%)  Mem 2453.1MB (+0.0%)
BitShift→Sum       FLOPs 144.81G (+1.1%)  Time 188.02ms (+3.6%)  Mem 2453.1MB (+0.0%)
Distributive       FLOPs 144.81G (+1.1%)  Time 187.91ms (+3.7%)  Mem 2453.1MB (+0.0%)
Distributive2      FLOPs 146.49G (+0.0%)  Time 199.65ms (-2.3%)  Mem 2453.1MB (+0.0%)
Recip‑Assoc        FLOPs 146.49G (+0.0%)  Time 191.95ms (+1.6%)  Mem 2453.1MB (+0.0%)

--- ComprehensiveRewrite ---
Baseline  →  FLOPs 0.12G   Latency 0.88ms   Mem 252.3MB

Commutative        FLOPs 0.12G (+2.7%)  Time 0.89ms (-1.0%)  Mem 252.3MB (+0.0%)
SqrtAssociative    FLOPs 0.11G (+8.1%)  Time 0.83ms (+5.4%)  Mem 252.3MB (+0.0%)
BitShift→Sum       FLOPs 0.12G (+2.7%)  Time 0.89ms (-1.4%)  Mem 252.3MB (+0.0%)
Distributive       FLOPs 0.11G (+4.1%)  Time 0.86ms (+2.4%)  Mem 252

Unnamed: 0_level_0,Baseline FLOPs,Opt FLOPs,ΔFLOPs %,Baseline ms,Opt ms,ΔLat %,Baseline MB,Opt MB,ΔMem %
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
EnhancedResNet,508.11,508.02,+0.0%,11.57,10.76,+7.0%,2390.6,2023.1,+15.4%
EnhancedVGG,982.19,982.19,+0.0%,15.27,15.29,-0.1%,3531.38,3531.38,+0.0%
DeepElementwise,0.58,0.58,+0.0%,2.13,2.08,+2.3%,166.18,166.18,+0.0%
BitShiftTest,0.01,0.0,+50.0%,0.12,0.1,+16.7%,184.56,166.18,+10.0%
