## PyTorch dispatcher tutorial

**What is the PyTorch Dispatcher?**

At its core, the PyTorch dispatcher is a central routing mechanism within the PyTorch framework. Think of it as a sophisticated traffic controller for operator calls. When you execute a PyTorch operation like `torch.add(a, b)`, the dispatcher's job is to figure out exactly which piece of code (called a kernel) should handle the operation. It decides based on properties of the input tensors `a` and `b` (such as their device, layout, or whether they are quantized) and on the current context (thread-local state such as which features are enabled).

Its primary role is to decouple the definition of an operation (what `torch.add` means conceptually) from its various implementations (how to perform addition on CPU tensors or CUDA tensors, how to handle automatic differentiation, etc.). It manages the registered kernels and invokes the correct one for a given set of inputs and system state.

The dispatcher is crucial for handling cross-cutting concerns – features that apply across many different operators. These include:

- **Device Type**: Running the operation on CPU, CUDA, MPS, XLA, etc.
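- **Data Type**: Dtypes like float32, float16, and int64 are *not* dispatch keys; a backend kernel usually branches on dtype internally (in ATen, via the `AT_DISPATCH_*` macros).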
- **Autograd**: Enabling automatic differentiation by tracking computation graphs.
- **Other Features**: Supporting things like TorchScript tracing, quantization, functionalization, and more.

**Why is it Needed?**

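Without the dispatcher, every operator would need its own conditional chain over every combination of device and feature, and adding a new backend or feature would mean editing every operator:
```c++
// Anti-pattern: hand-written routing inside a single operator.
Tensor add(const Tensor& a, const Tensor& b) {
    if (a.device().type() == kCPU && b.device().type() == kCPU) {
        if (a.requires_grad() || b.requires_grad()) {
            // Call CPU autograd addition kernel
        } else {
            // Call plain CPU addition kernel
        }
    } else if (a.device().type() == kCUDA && b.device().type() == kCUDA) {
        // ... the same branches again for CUDA
    }
    // ... and again for MPS, XLA, quantized tensors, tracing, ...
}
```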
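The dispatcher replaces that branching with a table lookup. The sketch below is a toy Python model, not PyTorch's implementation: `ToyOperator`, the `"ToyAccel"` backend, and the plain-string keys are hypothetical, lists stand in for tensors, and the key is passed explicitly instead of being derived from the inputs. It shows why the design scales: supporting another backend is one more registration, with no edits to existing code.

```python
from typing import Callable, Dict


class ToyOperator:
    """A toy operator: a name plus a table mapping dispatch keys to kernels."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.kernels: Dict[str, Callable] = {}

    def register(self, key: str, kernel: Callable) -> None:
        # One table entry per key; supporting a new backend is one more call.
        self.kernels[key] = kernel

    def __call__(self, key: str, *args):
        # A table lookup replaces the if/else chain over devices and features.
        if key not in self.kernels:
            raise NotImplementedError(f"{self.name} has no kernel for key {key!r}")
        return self.kernels[key](*args)


add = ToyOperator("toy::add")
# Python lists stand in for tensors in this sketch.
add.register("CPU", lambda a, b: [x + y for x, y in zip(a, b)])
# A hypothetical second backend; registering it touches no other code.
add.register("ToyAccel", lambda a, b: [float(x + y) for x, y in zip(a, b)])

print(add("CPU", [1, 2], [3, 4]))       # [4, 6]
print(add("ToyAccel", [1, 2], [3, 4]))  # [4.0, 6.0]
```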

**Key Concepts**

- **Operator**: A fundamental operation in PyTorch, usually exposed as a function in the `torch` namespace (e.g., `torch.add`, `torch.matmul`, `torch.relu`). Each operator has a schema that specifies its name, inputs, and outputs.
- **Kernel**: A C++ function that implements an operator for a particular dispatch key. For example, `torch.add` has a CPU kernel, a CUDA kernel, an Autograd kernel, and so on.
- **Dispatch Key**: A value of the `c10::DispatchKey` enum representing a backend (e.g., `CPU`, `CUDA`, `QuantizedCPU`) or a feature (e.g., `Autograd`). Every tensor carries a set of dispatch keys; on each call, the dispatcher combines the key sets of the tensor arguments with thread-local state and runs the kernel for the highest-priority key. Some keys, such as `CompositeImplicitAutograd`, are alias keys: a kernel registered under one of them covers many runtime keys at once.
- **Operator Schema**: A formal definition of an operator's signature, including its name (with an overload name if applicable), its arguments (name and type), and its return values. Example: `aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`, where `Tensor` is the overload name and arguments after `*` are keyword-only. Schemas ensure type safety and consistency across different kernels.
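To make the parts of a schema concrete, here is a toy parser for simple, fully qualified schema strings like the `aten::add.Tensor` example above. It is a sketch only: `parse_schema`, `Arg`, and `Schema` are hypothetical names, it assumes no commas inside argument types, and it treats alias annotations such as `Tensor(a!)` as opaque type text. PyTorch's real schema parser handles far more.

```python
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Arg:
    type: str
    name: str
    default: Optional[str]
    kwarg_only: bool


@dataclass
class Schema:
    namespace: str
    name: str
    overload: str
    args: List[Arg]
    returns: List[str]


# namespace::name[.overload](args) -> returns
_SCHEMA_RE = re.compile(
    r"^(?P<ns>\w+)::(?P<name>\w+)(?:\.(?P<overload>\w+))?"
    r"\((?P<args>.*)\)\s*->\s*(?P<returns>.+)$"
)


def parse_schema(text: str) -> Schema:
    """Toy parser for simple schemas; not a replacement for PyTorch's parser."""
    m = _SCHEMA_RE.match(text.strip())
    if m is None:
        raise ValueError(f"unrecognized schema: {text!r}")
    args: List[Arg] = []
    kwarg_only = False
    for piece in filter(None, (p.strip() for p in m.group("args").split(","))):
        if piece == "*":
            # Every argument after a bare '*' must be passed by keyword.
            kwarg_only = True
            continue
        decl, _, default = piece.partition("=")
        arg_type, arg_name = decl.rsplit(maxsplit=1)
        args.append(Arg(arg_type, arg_name, default.strip() or None, kwarg_only))
    returns = [r.strip() for r in m.group("returns").strip("() ").split(",")]
    return Schema(m.group("ns"), m.group("name"), m.group("overload") or "", args, returns)


s = parse_schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
print(s.namespace, s.name, s.overload)  # aten add Tensor
print([(a.type, a.name, a.default, a.kwarg_only) for a in s.args])
```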

## Dispatch keys

Dispatch keys are the identifiers the dispatcher uses to select a kernel. They are defined in the `c10::DispatchKey` enum in [c10/core/DispatchKey.h](https://github.com/pytorch/pytorch/blob/f252f9df5e0fa5d942218108cc5983bb72182086/c10/core/DispatchKey.h#L136). The enum also establishes a priority ordering: keys for more specialized or wrapping functionality (such as autograd or functionalization) generally have higher priority than the backend keys (such as `CPU` or `XLA`) whose kernels do the actual computation, so wrapper kernels run first and then hand off to lower-priority keys.
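The selection rule can be sketched in a few lines. This toy model assumes a short, simplified priority list (lowest first) instead of the real `c10::DispatchKey` enum, and represents key sets as Python sets: take the union of the keys on every tensor argument, add thread-local included keys, remove thread-local excluded keys, and pick the highest-priority key left. `TOY_PRIORITY` and `dispatch_key` are hypothetical names.

```python
from typing import AbstractSet, Iterable, List, Set

# Simplified priority order, lowest first; the real enum has many more keys.
TOY_PRIORITY: List[str] = [
    "CPU", "XLA", "BackendSelect", "Functionalize", "AutogradCPU", "AutogradXLA",
]


def dispatch_key(
    tensor_key_sets: Iterable[AbstractSet[str]],
    tls_included: AbstractSet[str] = frozenset(),
    tls_excluded: AbstractSet[str] = frozenset(),
) -> str:
    """Return the highest-priority key after merging tensor and thread-local keys."""
    keys: Set[str] = set()
    for ks in tensor_key_sets:
        keys |= ks  # union of the keys carried by every tensor argument
    keys |= tls_included  # keys forced on by thread-local state
    keys -= tls_excluded  # keys switched off by thread-local state
    if not keys:
        raise RuntimeError("no dispatch key left to dispatch on")
    return max(keys, key=TOY_PRIORITY.index)


# A toy XLA tensor carrying a backend key plus two wrapper keys.
xla_tensor = {"XLA", "AutogradXLA", "Functionalize"}
print(dispatch_key([xla_tensor]))  # AutogradXLA
# With the autograd key excluded (as inside an autograd kernel), the next key wins.
print(dispatch_key([xla_tensor], tls_excluded={"AutogradXLA"}))  # Functionalize
```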

## Registering Kernels with the Dispatcher

Kernels are the concrete implementations of operators. To make the dispatcher aware of them, you register them, primarily with the `TORCH_LIBRARY` and `TORCH_LIBRARY_IMPL` C++ macros. These macros are the standard way to register operators and kernels from C++, whether the code is compiled as part of PyTorch itself or as a C++ extension.

`TORCH_LIBRARY(ns, m)`: Defines a library of operators under the namespace `ns`. PyTorch's built-in operators live in the `aten` namespace; extensions use their own namespace (here, `xla_ops`), and only one `TORCH_LIBRARY` block may exist per namespace. Inside the `TORCH_LIBRARY` block, call `m.def()` to define each operator's schema.

```c++
#include <torch/library.h>

// Define the schema for the custom operator xla_ops::ax_by,
// which computes a * self + b * other.
TORCH_LIBRARY(xla_ops, m) {
    m.def("ax_by(Tensor self, Tensor other, Scalar a, Scalar b) -> Tensor");
}
```

`TORCH_LIBRARY_IMPL(ns, key, m)`: Registers kernel implementations for operators in namespace `ns` under one dispatch key. Inside the block, `m.impl()` links an operator name (matching a schema defined in `TORCH_LIBRARY`) to a C++ kernel function. `key` is written as the bare name of a `c10::DispatchKey` enumerator (e.g., `CPU`, `CUDA`, `XLA`, `Autograd`).


```c++
#include <iostream>

#include <ATen/ATen.h>
#include <torch/library.h>

// Kernel for xla_ops::ax_by under the XLA key: returns a * self + b * other.
// Scalar arguments in the schema map to `const at::Scalar&` in C++.
at::Tensor xla_ax_by(const at::Tensor& self, const at::Tensor& other,
                     const at::Scalar& a, const at::Scalar& b) {
    std::cout << "Executing xla_ops::ax_by kernel!" << std::endl;
    // at::add(x, y, alpha) computes x + alpha * y.
    return at::add(self.mul(a), other, /*alpha=*/b);
}

// Register the kernel function for the XLA dispatch key.
TORCH_LIBRARY_IMPL(xla_ops, XLA, m) {
    // Link "ax_by" (matching the schema) to xla_ax_by. The dispatcher infers a
    // schema from the C++ signature and checks that it matches the definition.
    m.impl("ax_by", &xla_ax_by);
}
```
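The split between `TORCH_LIBRARY` (define a schema once) and `TORCH_LIBRARY_IMPL` (attach kernels per key) can be mirrored in a toy Python registry. This is a sketch, not PyTorch's API: `ToyLibrary` and its `define`, `impl`, and `call` methods are hypothetical, the signature check only compares argument counts (a crude stand-in for the C++ schema inference, which checks full types), lists stand in for tensors, and the dispatch key is passed explicitly.

```python
import inspect
from typing import Callable, Dict


class ToyLibrary:
    """Toy stand-in for one operator namespace: define schemas, then attach kernels."""

    def __init__(self, ns: str) -> None:
        self.ns = ns
        self.arg_counts: Dict[str, int] = {}               # op name -> schema arg count
        self.kernels: Dict[str, Dict[str, Callable]] = {}  # op name -> {key: kernel}

    def define(self, schema: str) -> None:
        # Like m.def(): record the operator name and how many arguments it takes.
        name, rest = schema.split("(", 1)
        arg_text = rest.split(") ->", 1)[0]
        args = [a for a in arg_text.split(",") if a.strip() not in ("", "*")]
        self.arg_counts[name.strip()] = len(args)
        self.kernels[name.strip()] = {}

    def impl(self, name: str, key: str, fn: Callable) -> None:
        # Like m.impl(): the operator must already be defined, and the kernel's
        # parameter count must match the schema (a crude signature check).
        if name not in self.arg_counts:
            raise KeyError(f"{self.ns}::{name} has no schema; define it first")
        n_params = len(inspect.signature(fn).parameters)
        if n_params != self.arg_counts[name]:
            raise TypeError(
                f"kernel takes {n_params} args but schema has {self.arg_counts[name]}"
            )
        self.kernels[name][key] = fn

    def call(self, name: str, key: str, *args):
        # The key is passed explicitly here; the real dispatcher computes it.
        return self.kernels[name][key](*args)


lib = ToyLibrary("xla_ops")
lib.define("ax_by(Tensor self, Tensor other, Scalar a, Scalar b) -> Tensor")
# Lists of floats stand in for tensors; the kernel computes a * self + b * other.
lib.impl("ax_by", "XLA", lambda self, other, a, b: [a * x + b * y for x, y in zip(self, other)])
print(lib.call("ax_by", "XLA", [1.0, 2.0], [10.0, 20.0], 2, 3))  # [32.0, 64.0]
```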



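To log a trace of which kernels the dispatcher calls, build PyTorch from source with the `HAS_TORCH_SHOW_DISPATCH_TRACE` macro defined:

```bash
# Compile in support for dispatcher trace logging
export CFLAGS="-DHAS_TORCH_SHOW_DISPATCH_TRACE"
# Build a wheel...
python setup.py bdist_wheel
# ...or install the build in development (editable) mode
python setup.py develop
```

The trace is then switched on at runtime by setting the `TORCH_SHOW_DISPATCH_TRACE` environment variable to `1`, as in the notebook cells below.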

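With tracing enabled, each line names the operator (`op`) and the dispatch key whose kernel runs (`key`). `[call]` marks a call that enters the dispatcher with the full key set, `[callBoxed]` is its boxed counterpart, `[redispatch]` marks a kernel handing the same operator on to a lower-priority key, and indentation shows nesting.

```python
%env TORCH_SHOW_DISPATCH_TRACE=1
```

```
env: TORCH_SHOW_DISPATCH_TRACE=1
```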


Running the import cell already triggers several dispatcher calls. Each `aten::ones` call has no tensor inputs, so it is routed through `BackendSelect` and then redispatched to the `CPU` kernel:

```python
import torch
import torch_xla
import torch_xla.runtime
import time
```

```
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar]
```

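Creating a CPU tensor with a factory function follows the same `BackendSelect` → `CPU` path, with `aten::randn` calling `aten::empty.memory_format` and `aten::normal_` internally:

```python
t = torch.randn(4, 4)
```

```
 [call] op=[aten::randn], key=[BackendSelect]
  [redispatch] op=[aten::randn], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::normal_], key=[CPU]
```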
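Factory functions such as `torch.randn` and `torch.ones` have no tensor inputs, so there is no tensor key set to read the backend from. The `BackendSelect` kernel fills that gap by computing the backend from the `device` argument and redispatching, which is the `BackendSelect` → `CPU` step in the randn trace. The toy sketch below models only that step; `backend_select` and the `DEVICE_TO_KEY` mapping are hypothetical simplifications, and it assumes the default device is CPU.

```python
from typing import Dict, List, Optional

# Hypothetical mapping from a device type to a backend dispatch key.
DEVICE_TO_KEY: Dict[str, str] = {"cpu": "CPU", "cuda": "CUDA", "xla": "XLA"}


def backend_select(op: str, device: Optional[str], trace: List[str]) -> str:
    """Toy BackendSelect step: with no tensor inputs, derive the key from `device`."""
    trace.append(f"[call] op=[{op}], key=[BackendSelect]")
    device_type = (device or "cpu").split(":")[0]  # no device given -> CPU
    key = DEVICE_TO_KEY[device_type]
    trace.append(f"  [redispatch] op=[{op}], key=[{key}]")
    return key


log: List[str] = []
backend_select("toy::randn", None, log)     # like torch.randn(4, 4)
backend_select("toy::randn", "xla:0", log)  # like torch.randn(4, 4, device="xla:0")
print("\n".join(log))
```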


Moving the tensor to the XLA device enters the dispatcher at `AutogradCPU`, since the source tensor lives on the CPU; `aten::_to_copy` is then redispatched through `BackendSelect` to the `XLA` kernel:

```python
x = t.to('xla')
```

```
 [call] op=[aten::to.dtype_layout], key=[AutogradCPU]
  [call] op=[aten::_to_copy], key=[AutogradCPU]
   [redispatch] op=[aten::_to_copy], key=[BackendSelect]
    [redispatch] op=[aten::_to_copy], key=[XLA]
     [call] op=[aten::to.dtype_layout], key=[BackendSelect]
      [redispatch] op=[aten::to.dtype_layout], key=[CPU]
     [call] op=[aten::to.dtype_layout], key=[BackendSelect]
      [redispatch] op=[aten::to.dtype_layout], key=[CPU]
```


Adding a scalar to the XLA tensor passes through the `AutogradXLA` and `Functionalize` kernels before the `XLA` kernel is invoked through a boxed call (`[callBoxed]`). The `XLA` kernel in turn issues its own calls, including CPU-side `aten::to.dtype` and `aten::item`:

```python
y = x + 1
```

```
 [call] op=[aten::add.Tensor], key=[AutogradXLA]
  [redispatch] op=[aten::add.Tensor], key=[Functionalize]
   [callBoxed] op=[aten::add.Tensor], key=[XLA]
    [call] op=[aten::result_type.Tensor], key=[XLA]
    [call] op=[aten::result_type.Tensor], key=[XLA]
    [call] op=[aten::to.dtype], key=[CPU]
     [call] op=[aten::_to_copy], key=[BackendSelect]
      [redispatch] op=[aten::_to_copy], key=[CPU]
       [call] op=[aten::empty_strided], key=[BackendSelect]
        [redispatch] op=[aten::empty_strided], key=[CPU]
       [call] op=[aten::copy_], key=[CPU]
    [call] op=[aten::item], key=[CPU]
     [call] op=[aten::_local_scalar_dense], key=[CPU]
```
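The `[call]`/`[redispatch]` nesting in the `x + 1` trace can be modeled in a few lines: each wrapper kernel does its own work, removes its key from the key set, and redispatches to the next-highest key until a backend kernel runs. This toy sketch assumes a simplified three-key priority list and string keys; `redispatch`, `KERNELS`, and the kernel bodies are hypothetical, and real autograd and functionalization kernels do far more than pass the call along.

```python
from typing import AbstractSet, Callable, Dict, List

# Simplified priority order, lowest first (an assumption, not the real enum).
PRIORITY: List[str] = ["XLA", "Functionalize", "AutogradXLA"]
trace: List[str] = []


def redispatch(ks: AbstractSet[str], x: float, y: float, depth: int, verb: str) -> float:
    """Run the kernel for the highest-priority key left in `ks`, logging the step."""
    key = max(ks, key=PRIORITY.index)
    trace.append(" " * depth + f"[{verb}] op=[toy::add], key=[{key}]")
    return KERNELS[key](ks, x, y, depth)


def autograd_kernel(ks: AbstractSet[str], x: float, y: float, depth: int) -> float:
    # A real autograd kernel would record graph information before handing off.
    return redispatch(ks - {"AutogradXLA"}, x, y, depth + 1, "redispatch")


def functionalize_kernel(ks: AbstractSet[str], x: float, y: float, depth: int) -> float:
    # A real functionalization kernel would rewrite mutations before handing off.
    return redispatch(ks - {"Functionalize"}, x, y, depth + 1, "redispatch")


def xla_kernel(ks: AbstractSet[str], x: float, y: float, depth: int) -> float:
    # The backend kernel does the actual computation and does not redispatch.
    return x + y


KERNELS: Dict[str, Callable[..., float]] = {
    "AutogradXLA": autograd_kernel,
    "Functionalize": functionalize_kernel,
    "XLA": xla_kernel,
}

result = redispatch(frozenset(PRIORITY), 2.0, 1.0, depth=0, verb="call")
print(result)  # 3.0
print("\n".join(trace))
```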


## Registering kernels with dispatcher keys

## Boxed vs Unboxed kernels

## Debugging dispatcher

| Tool/Technique  | Description | Use Case |
|---|---|---|
| `torch._C._dispatch_dump("namespace::op_name")` | Returns a string listing every kernel registered for a specific operator overload, with the dispatch key each is registered for and its source location. | Verify that a specific kernel (e.g., a custom CUDA kernel) is registered for the correct operator and dispatch key, and see competing kernels. |
| `torch._C._dispatch_dump_table("namespace::op_name")` | Returns a string showing the computed dispatch table for an operator: for each dispatch key, the registered kernel that would be selected under the priority and fallback rules. | Understand kernel priorities, see which kernel runs for a given key, and check whether fallbacks are being used. |
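The `[kernel]` and `[default backend kernel]` tags that `_dispatch_dump_table` prints come from how the table is computed: a key with its own registration uses it, and otherwise the dispatcher can fall back to a kernel registered under an alias key such as `CompositeExplicitAutogradNonFunctional`. The toy below models only that one rule, as a sketch; `compute_table` and the kernel names are hypothetical, and the real computation has more cases (e.g., `CompositeImplicitAutograd`, fallthrough kernels, backend fallbacks).

```python
from typing import Dict, Iterable, Optional, Tuple

BACKENDS = ["CPU", "CUDA", "XLA", "MPS"]


def compute_table(
    direct: Dict[str, str],
    default_backend: Optional[str],
    keys: Iterable[str] = BACKENDS,
) -> Dict[str, Tuple[str, str]]:
    """Map each backend key to (kernel, how it was chosen); toy version."""
    table: Dict[str, Tuple[str, str]] = {}
    for key in keys:
        if key in direct:
            table[key] = (direct[key], "kernel")  # registered for this exact key
        elif default_backend is not None:
            table[key] = (default_backend, "default backend kernel")  # alias-key fallback
        # Otherwise no entry: calling the op with this key would raise an error.
    return table


table = compute_table(
    direct={"CPU": "add_cpu", "XLA": "add_xla"},
    default_backend="add_composite_explicit",
)
for key, (kernel, how) in table.items():
    print(f"{key}: {kernel} [{how}]")
```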

```python
print(torch._C._dispatch_dump("aten::add.Tensor"))
```

```
name: aten::add.Tensor
schema: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
debug: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
alias analysis kind: FROM_SCHEMA
MkldnnCPU: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterMkldnnCPU_0.cpp:171 :: (Tensor _0, Tensor _1, Scalar _2) -> Tensor _0 [ boxed unboxed ]
Named: registered at /home/bbahl_google_com/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:11 :: (none) [ fallthrough boxed ]
ZeroTensor: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterZeroTensor_0.cpp:123 :: (Tensor _0, Tensor _1, Scalar _2) -> Tensor _0 [ boxed unboxed ]
Tracer: registered at /home/bbahl_google_com/pytorch/torch/csrc/autograd/generated/TraceType_2.cpp:17827 :: (Tensor _0, Tensor _1, Scalar _2) -> Tensor _0 [ boxed unboxed ]
FuncTorchBatched: registered at /home/bbahl_google_com/pytorch/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp:352 :: (Ten
```

In the computed table, `CPU` and `XLA` have dedicated kernels (`[kernel]`), while keys such as `CUDA`, `HIP`, `MPS`, and `IPU` fall back to the `CompositeExplicitAutogradNonFunctional` kernel (`[default backend kernel]`):

```python
print(torch._C._dispatch_dump_table("aten::add.Tensor"))
```

```
Undefined: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1380 [default backend kernel]
CPU: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCPU_0.cpp:1316 [kernel]
CUDA: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1380 [default backend kernel]
HIP: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1380 [default backend kernel]
XLA: registered at bazel-out/k8-opt/bin/torch_xla/csrc/RegisterXLA.cpp:5050 [kernel]
MPS: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1380 [default backend kernel]
IPU: registered at /home/bbahl_google_com/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1380 [default backend kernel]
XPU: registered at /home/bbahl_google_com
```