<a href="https://colab.research.google.com/github/mohitraosatya/enhanced-cpu-fallback-ttbuda-demo/blob/main/partial_cpu_fallback_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math

class MockOp:
    """
    A single node (layer) of the mock computational graph.

    Attributes:
        name: Label used when the op is printed or logged.
        op_type: Operator kind string (e.g. "Matmul"); devices match on this.
        shape: Shape associated with the op, e.g. (batch, features).
        device: Name of the device chosen for this op; stays None until
            partition_graph assigns one.
    """

    def __init__(self, name, op_type, shape):
        self.name, self.op_type, self.shape = name, op_type, shape
        # No placement decision has been made at construction time.
        self.device = None

    def __repr__(self):
        details = f"type={self.op_type} shape={self.shape} device={self.device}"
        return f"<Op {self.name} {details}>"

class MockGraph:
    """
    The whole "model": an ordered list of ops, each feeding into the next.

    A real graph would typically be a DAG; a linear chain keeps the demo simple.
    """

    def __init__(self, ops):
        self.ops = ops

    def __repr__(self):
        # One op per line, in list order.
        return "\n".join(map(repr, self.ops))

class MockDevice:
    """
    Stand-in for a hypothetical Tenstorrent device with partial op coverage.
    """

    def __init__(self, name, supported_ops):
        self.name = name
        # Collection (set or list) of op_type strings this device accepts.
        self.supported_ops = supported_ops

    def can_run(self, op_type):
        """Return True when op_type is among the supported op types."""
        supported = self.supported_ops
        return op_type in supported

    def run_op(self, op):
        """
        Simulate executing op on this device by printing a log line.

        Real code would invoke TT-Buda or TT-Metal calls here.
        """
        message = f"[{self.name}] Running {op.name} (type={op.op_type}) on device..."
        print(message)

class CPUDevice:
    """
    Fallback device backed by the CPU; treated as able to run any op.
    """

    def __init__(self):
        self.name = "CPU"

    def run_op(self, op):
        """Simulate executing op on the CPU by printing a fallback log line."""
        message = f"[CPU] Fallback for {op.name} (type={op.op_type}). Running on CPU..."
        print(message)

def partition_graph(graph, device, cpu):
    """
    Split the graph's ops into consecutive runs that share one device.

    Each op is placed on `device` when `device.can_run(op.op_type)` is true
    and on `cpu` otherwise; the chosen device name is also written to
    `op.device`.

    Returns:
        A list of `(device_name, ops)` tuples in execution order, where each
        `ops` list holds the adjacent ops assigned to that device name.
    """
    segments = []

    for node in graph.ops:
        # Pick the placement first so both cases share the grouping logic.
        target = device.name if device.can_run(node.op_type) else cpu.name
        node.device = target

        if segments and segments[-1][0] == target:
            # Same device as the previous op: extend the open segment.
            segments[-1][1].append(node)
        else:
            # Device changed (or this is the first op): open a new segment.
            segments.append((target, [node]))

    return segments

def run_partitioned_graph(partitions, device, cpu):
    """
    Run every partition in order on the device its name selects.

    A partition whose name equals `device.name` runs on `device`; any other
    partition runs on `cpu`. No real data is moved: a placeholder string
    stands in for the output handed from one partition to the next.

    Returns:
        "DeviceOutput" or "CPUOutput" depending on where the last partition
        ran, or None when `partitions` is empty.
    """
    last_output = None  # placeholder for the previous partition's result

    for target_name, segment in partitions:
        # A real implementation would transfer last_output to the target here.
        if target_name == device.name:
            runner, label, produced = device, device.name, "DeviceOutput"
        else:
            runner, label, produced = cpu, "CPU", "CPUOutput"

        print(f"\n--- Running partition on {label} with {len(segment)} ops ---")
        for node in segment:
            runner.run_op(node)
        last_output = produced

    print("\nAll partitions completed.")
    return last_output

# -----------------------
# Example usage (the "main" flow):
# -----------------------

# 1. Define a mock list of ops in the model
ops = [
    MockOp("MatMul1", "Matmul", (32, 64)),
    MockOp("LayerNorm1", "LayerNorm", (32, 64)),
    MockOp("Unsupported1", "WeirdOp", (32, 64)),  # Not supported by the mock device
    MockOp("Softmax1", "Softmax", (32, 64)),
    MockOp("Unsupported2", "CustomAttentionOp", (32, 64)),  # Another unsupported op
    MockOp("MatMul2", "Matmul", (32, 128)),
]

# 2. Create a mock "device" with partial support
device_supported_ops = {"Matmul", "Softmax", "LayerNorm"}  # e.g. "WeirdOp" isn't supported
mock_device = MockDevice("MockTenstorrentDevice", device_supported_ops)

# 3. Create a CPU fallback device
cpu_device = CPUDevice()

# 4. Build the graph
graph = MockGraph(ops)
# Format everything into one string: passing the graph as a separate print()
# argument would insert the default " " separator, indenting only the first
# op line and leaving a trailing space after the last one.
print(f"Initial Graph:\n{graph}\n")

# 5. Partition the graph
partitions = partition_graph(graph, mock_device, cpu_device)
print("Partitions (device, ops_list):")
for dev_name, subg in partitions:
    print(f"  {dev_name} => {[op.name for op in subg]}")

# 6. Execute the partitioned graph
# Left as a bare expression so a notebook cell still echoes the returned value.
run_partitioned_graph(partitions, mock_device, cpu_device)


Initial Graph:
 <Op MatMul1 type=Matmul shape=(32, 64) device=None>
<Op LayerNorm1 type=LayerNorm shape=(32, 64) device=None>
<Op Unsupported1 type=WeirdOp shape=(32, 64) device=None>
<Op Softmax1 type=Softmax shape=(32, 64) device=None>
<Op Unsupported2 type=CustomAttentionOp shape=(32, 64) device=None>
<Op MatMul2 type=Matmul shape=(32, 128) device=None> 

Partitions (device, ops_list):
  MockTenstorrentDevice => ['MatMul1', 'LayerNorm1']
  CPU => ['Unsupported1']
  MockTenstorrentDevice => ['Softmax1']
  CPU => ['Unsupported2']
  MockTenstorrentDevice => ['MatMul2']

--- Running partition on MockTenstorrentDevice with 2 ops ---
[MockTenstorrentDevice] Running MatMul1 (type=Matmul) on device...
[MockTenstorrentDevice] Running LayerNorm1 (type=LayerNorm) on device...

--- Running partition on CPU with 1 ops ---
[CPU] Fallback for Unsupported1 (type=WeirdOp). Running on CPU...

--- Running partition on MockTenstorrentDevice with 1 ops ---
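[MockTenstorrentDevice] Running Softmax1 (type=Softmax) on device...

--- Running partition on CPU with 1 ops ---
[CPU] Fallback for Unsupported2 (type=CustomAttentionOp). Running on CPU...

--- Running partition on MockTenstorrentDevice with 1 ops ---
[MockTenstorrentDevice] Running MatMul2 (type=Matmul) on device...

All partitions completed.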

'DeviceOutput'