# PyTorch Tutorial (Cornell ECE 6980)

Authors: *Hongzheng Chen*, *Zhanqiu Hu*

In [None]:
! which python
! python --version
! nvidia-smi

In [None]:
! pip3 install torch numpy slapo tabulate transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
print("PyTorch version:", torch.__version__)
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

## PyTorch Basics

Construct the model in a hierarchical way.

### Self-Attention

$$\mathrm{CoreAttention} \left(Q, K, V\right) = \mathrm{softmax} \left( \frac{QK^\mathrm{T}}{\sqrt{d_k}} \right) \cdot V$$


<div>
<img src="https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/transformer.png" width="80%"/>
</div>

In [None]:
def scaled_dot_product(q, k, v):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        q: Query tensor; the last dimension is the per-head feature size d_k.
        k: Key tensor, multiplied (transposed over its last two dims) with ``q``.
        v: Value tensor, weighted by the attention probabilities.

    Returns:
        The attention output ``matmul(dropout(softmax(scores)), v)``.
    """
    # (bs, head, seq, hs // head)
    d_k = q.shape[-1]
    # d_k is a plain Python int; torch.sqrt only accepts tensors, so scale
    # by a Python float instead.
    attn_score = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # (bs, head, seq, seq)
    attn_probs = F.softmax(attn_score, dim=-1)
    # NOTE: F.dropout defaults to training=True, so dropout is applied on every call.
    attn_probs = F.dropout(attn_probs, 0.1)
    # (bs, head, seq, hs // head)
    attn = torch.matmul(attn_probs, v)
    return attn

In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention (Q/K/V projections plus core attention).

    Args:
        hidden_size: Feature size of the input; it is split evenly across
            ``n_heads`` (the ``view`` in ``permute_for_scores`` fails otherwise).
        n_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.n_heads = n_heads

    def permute_for_scores(self, x):
        """Split the hidden dimension into heads and move the head axis forward."""
        # x: (batch_size, seq_len, hidden_size)
        x = x.view(x.shape[0], x.shape[1], self.n_heads, -1)
        # output: (bs, head, seq, hs // head)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply multi-head self-attention and merge the heads again.

        Returns a tensor of shape (batch_size, seq_len, hidden_size).
        """
        # hidden_states: (batch_size, seq_len, hidden_size)
        # qkv layers
        q = self.permute_for_scores(self.q_proj(hidden_states))
        k = self.permute_for_scores(self.k_proj(hidden_states))
        v = self.permute_for_scores(self.v_proj(hidden_states))
        # core attention
        output = scaled_dot_product(q, k, v)
        # permute/reshape are not in-place: their results must be re-assigned.
        # (bs, head, seq, hs // head) -> (bs, seq, head, hs // head)
        output = output.permute(0, 2, 1, 3)
        # Merge the heads: (bs, seq, hs). reshape is used instead of view
        # because the permuted tensor is no longer contiguous.
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

### Attention Layer

In [None]:
class Projection(nn.Module):
    """Linear projection followed by dropout, a residual add and layer norm.

    Args:
        intermediate_size: Input feature size of the linear layer.
        hidden_size: Output feature size of the linear layer and the
            normalized dimension of the layer norm.
        p: Dropout probability.
    """

    def __init__(self, intermediate_size, hidden_size, p=0.1):
        super().__init__()
        self.dense = nn.Linear(in_features=intermediate_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=p)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, add ``input_tensor`` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        residual_sum = projected + input_tensor
        return self.layer_norm(residual_sum)


class Attention(nn.Module):
    """Self-attention followed by the residual output projection.

    Args:
        hidden_size: Feature size used for both the attention and the projection.
        intermediate_size: Not used by this module.
        n_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, intermediate_size, n_heads):
        super().__init__()
        self.self_attn = SelfAttention(hidden_size, n_heads=n_heads)
        # The projection maps hidden_size -> hidden_size.
        self.proj = Projection(intermediate_size=hidden_size, hidden_size=hidden_size)

    def forward(self, hidden_states):
        """Attend over ``hidden_states`` and project with a residual connection."""
        attended = self.self_attn(hidden_states)
        return self.proj(attended, hidden_states)

In [None]:
class FFN(nn.Module):
    """Feed forward network (FFN) with GELU activation.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Feature size between the two linear layers.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.activation = nn.GELU()
        self.projection = Projection(intermediate_size, hidden_size)

    def forward(self, data):
        """Expand, activate, then project back with a residual on ``data``."""
        out = self.linear1(data)
        out = self.activation(out)
        # Projection.forward takes (hidden_states, input_tensor); the block
        # input is the residual that is added before the layer norm.
        out = self.projection(out, data)
        return out


class TransformerLayer(nn.Module):
    """One transformer layer: an attention block followed by a feed-forward block.

    Args:
        hidden_size: Feature size of the layer input and output.
        intermediate_size: Inner feature size passed to both sub-blocks.
        n_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, intermediate_size, n_heads):
        super().__init__()
        self.attention = Attention(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            n_heads=n_heads,
        )
        self.ffn = FFN(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        """Run attention, then feed its output through the FFN."""
        return self.ffn(self.attention(hidden_states))

In [None]:
transformer_layer = TransformerLayer(hidden_size=768, intermediate_size=3072, n_heads=12)
print(transformer_layer)

In [None]:
from transformers import AutoConfig, BertLMHeadModel

config = AutoConfig.from_pretrained("bert-base-uncased")
bert_model = BertLMHeadModel(config)
print(config)
print(bert_model)

In [None]:
print(bert_model.bert.encoder.layer[0])

### Training

In [None]:
def train(model, device="cuda", bs=8, seq_length=512, steps=40):
    """Run a few AdamW steps on a constant dummy batch and log the loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            token_type_ids, labels=labels)`` whose output has a ``.loss``.
        device: Device that holds the dummy batch and the model.
        bs: Batch size of the dummy batch.
        seq_length: Sequence length of the dummy batch.
        steps: Number of optimizer steps to run.

    The loss of every step is written to the module-level ``writer`` under
    "Loss/train"; every 10th step is also printed.
    """
    # data preparation
    input_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    attention_mask = torch.ones(bs, seq_length, dtype=torch.float32, device=device)
    token_type_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    labels = input_ids.clone()
    # The batch never changes, so build the positional inputs once.
    inputs = (input_ids, attention_mask, token_type_ids)
    # model preparation
    model.to(device)
    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    # training loop
    for step in range(steps):
        # Clear gradients of the previous step; backward() accumulates into .grad.
        optimizer.zero_grad()
        loss = model(*inputs, labels=labels).loss
        loss.backward()
        optimizer.step()
        writer.add_scalar("Loss/train", loss, step)

        if step % 10 == 0:
            print(f"step {step} loss: {loss.item()}")

In [None]:
%tensorboard --logdir logs
train(bert_model)

## TorchScript

PyTorch uses **dynamic graph** representation (**eager mode** / define-by-run), which means the graph is built on-the-fly.


> 💡 **Graph mode** / define-and-run: TensorFlow, Caffe

![](https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/dynamic_graph.gif)

We need some ways to capture the dynamic graph into a static graph so that we can conduct more optimizations.

### Just-in-Time (JIT) compilation

![](https://d3i71xaburhd42.cloudfront.net/e99921410790e1876a6089d039a960a8ea3b3f66/3-Figure1-1.png)

TorchScript
* First generation of PyTorch compiler
* Can support both **training and inference**
* Out-of-the-box optimization tool

Two different modes:
* Tracing mode: `torch.jit.trace`
* Scripting mode: `torch.jit.script`

### Tracing Mode

Runs a model with certain inputs and "traces / records" all the operations that are executed into a graph.

We use the MLP example to illustrate.

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron with a configurable activation.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Feature size between the two linear layers.
        hidden_act: ``"gelu"`` selects GELU; any other value selects ReLU.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.hidden_act = hidden_act
        self.linear1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.linear2 = nn.Linear(in_features=intermediate_size, out_features=hidden_size)

    def forward(self, data):
        """Apply linear1, the selected activation and linear2."""
        hidden = self.linear1(data)
        # The branch depends only on a module attribute, not on the input data.
        if self.hidden_act != "gelu":
            hidden = F.relu(hidden)
        else:
            hidden = F.gelu(hidden)
        return self.linear2(hidden)

In [None]:
# Use the last visible GPU. Without CUDA, device_count() is 0 and the
# f-string would produce the invalid device "cuda:-1", so fall back to the CPU.
if torch.cuda.is_available():
    device = f"cuda:{torch.cuda.device_count() - 1}"
else:
    device = "cpu"
inp = torch.rand((16, 512, 768)).to(device) # (bs, seq, hs)
mlp = MLP(768, 3072, "gelu").to(device)
# Tracing runs the module once on the example input and records the executed ops.
traced_mlp = torch.jit.trace(mlp, (inp,))
print(traced_mlp)

The above is the **structural representation** that describes the module hierarchy. We can check the class type of the traced module.

In [None]:
print(type(traced_mlp), isinstance(traced_mlp, nn.Module))

We can print out the **graph representation** of the traced module, the intermediate representation (IR) mostly follows LLVM's convention.

* Graph: Similar to `llvm::Function`
* Block: Only dataflow is inside a block
* Node: Instruction
    * Analogous to `mlir::Operation`
    * Can have nested blocks inside
    * e.g., `prim::GetAttr`, `prim::CallMethod`, `prim::Constant`, `aten::gelu`
* Value: Input arguments / Output results
    * The edges in the graph
    * Static single-assignment (SSA) form: Each value has precisely one defining node
    * e.g., `%x: type` (statically typed!)

You can refer to the implementation file [ir.h](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h) for more details.

In [None]:
print(traced_mlp.graph)

We can even print out the executable Python code from the TorchScript IR.

In [None]:
print(traced_mlp.code)

### Scripting Mode

Parses the Python source code of the model, and compiles the code into a graph.
* A subset of Python grammar
* Has a Lexer and Parser that parse Python syntax directly
    * Useful to deploy to somewhere without Python environment (no need to link CPython)
    * It cannot catch up with the latest Python grammar, poor maintainability
    * Limits what can apply in the program

In [None]:
scripted_mlp = torch.jit.script(mlp)
print(scripted_mlp)

Or you can use the Python decorator
```python
# decorate a class
@torch.jit.script
class MLP(nn.Module):
    ...

# decorate a function
@torch.jit.script
def foo(x):
    ...
```

In [None]:
print(type(scripted_mlp), isinstance(scripted_mlp, nn.Module))

Control flow nodes: `prim::If` and `prim::Loop`
* The outputs of the if-statement serve a role similar to the $\Phi$ nodes in traditional SSA control-flow graphs
* Same as `mlir::affine::yield`

In [None]:
print(scripted_mlp.graph)

In [None]:
print(scripted_mlp.code)

### Case Study: Operator fusion

$g(f(x_1, \cdots, x_n)) = (g\circ f)(x_1, \cdots, x_n)$

* Reduce kernel launch overheads
* Keep intermediate results in register instead of writing back to memory
* TorchScript incorporates [NVFuser](https://pytorch.org/blog/introducing-nvfuser-a-deep-learning-compiler-for-pytorch/) as the backend fusion framework which is by default enabled

We want to fuse the linear bias add and the GELU operation, since both are element-wise operations.

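> 💡 Do NOT use the standard Python library `timeit` or `time` directly to benchmark PyTorch execution time on GPU: CUDA kernels are launched asynchronously, so the timer may stop before the GPU work has finished. If you do use them, call `torch.cuda.synchronize()` before reading the clock; `torch.utils.benchmark` handles the synchronization for you.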

In [None]:
import torch.utils.benchmark as benchmark
torch.cuda.empty_cache() # clear cache
print(benchmark.Timer('mlp(inp)', globals={'mlp': mlp, 'inp': inp.detach().clone()}, label='Vanilla').timeit(1000))
torch.cuda.empty_cache()
print(benchmark.Timer('traced_mlp(inp)', globals={'traced_mlp': traced_mlp, 'inp': inp.detach().clone()}, label='Traced').timeit(1000))
torch.cuda.empty_cache()
print(benchmark.Timer('scripted_mlp(inp)', globals={'scripted_mlp': scripted_mlp, 'inp': inp.detach().clone()}, label='Scripted').timeit(1000))

Another attempt below:

In [None]:
# One possible answer to "need to modify somewhere": apply the bias of linear1
# as a separate add so that the bias add and the activation are two adjacent
# element-wise operations (whether they actually get fused depends on the backend).
class NewMLP(nn.Module):
    """MLP variant whose first bias add is separated from the matrix multiply.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Feature size between the two linear layers.
        hidden_act: ``"gelu"`` selects GELU; any other value selects ReLU.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.hidden_act = hidden_act
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, data):
        """Compute linear2(act(data @ W1^T + b1))."""
        # Matrix multiply only; the bias is deliberately added in a separate step.
        out = F.linear(data, self.linear1.weight)
        out = out + self.linear1.bias
        if self.hidden_act == "gelu":
            out = F.gelu(out)
        else:
            out = F.relu(out)
        out = self.linear2(out)
        return out

In [None]:
mlp = NewMLP(768, 3072, "gelu").to(device)
traced_mlp = torch.jit.trace(mlp, (inp,))
scripted_mlp = torch.jit.script(mlp)
torch.cuda.empty_cache()
print(benchmark.Timer('mlp(inp)', globals={'mlp': mlp, 'inp': inp.detach().clone()}, label='Vanilla').timeit(1000))
torch.cuda.empty_cache()
print(benchmark.Timer('traced_mlp(inp)', globals={'traced_mlp': traced_mlp, 'inp': inp.detach().clone()}, label='Traced').timeit(1000))
torch.cuda.empty_cache()
print(benchmark.Timer('scripted_mlp(inp)', globals={'scripted_mlp': scripted_mlp, 'inp': inp.detach().clone()}, label='Scripted').timeit(1000))

In [None]:
print(traced_mlp.graph)

In [None]:
print(traced_mlp.graph_for(inp))

In [None]:
print(scripted_mlp.graph_for(inp))

> 💡 Prefer scripting a whole module rather than scripting a function, since scripting a function only includes the forward pass.
>
> Check the implementation of Megatron-LM fused kernel: https://github.com/NVIDIA/Megatron-LM/blob/master/megatron/model/fused_bias_gelu.py

### Limitation

> [JIT should not force users to write ugly code](https://github.com/pytorch/pytorch/issues/48108)

* Generalization problem:
    * Dynamic control flow: The path taken depends on the input data of the forward function
    * Capture variables as constants (e.g., Dropout)
* Only use basic syntax of Python: no/few custom structures, no builtins, no inheritance, no `Union`, no `**kwargs`, no `lambda`, no dynamic types, etc.

1. Dynamic control flow

In [None]:
def f(x):
    """Return sqrt(x) if the sum of ``x`` is positive, otherwise x squared."""
    # The taken branch depends on the values stored in ``x``.
    if x.sum() > 0:
        return torch.sqrt(x)
    return torch.square(x)
m = torch.jit.trace(f, torch.tensor(3))
print(m.code)

2. Coding style not supported

In [None]:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L269-L277
def foo(hidden_state, layer_past=None, attention_mask=None):
    """Placeholder that hands its three arguments back as a tuple."""
    # do something
    # ...
    outputs = (hidden_state, layer_past, attention_mask)
    return outputs

In [None]:
traced_foo = torch.jit.trace(foo, (inp, None, inp))

In [None]:
scripted_foo = torch.jit.script(foo)
print(scripted_foo.graph)

In [None]:
scripted_foo(inp, None, inp)

### Takeaway

* While optimization is done at the push of a button, code quality is the cost of scriptability and traceability.
* No transparency on the optimizations. Compiler passes make code complicated and hard to debug.

<!-- https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py -->

## torch.fx

> James K. Reed, Zachary DeVito, Horace He, Ansley Ussery, Jason Ansel, *[Torch.fx: Practical Program Capture and Transformation for Deep Learning in Python](https://arxiv.org/abs/2112.08429)*, MLSys, 2022.

In [None]:
from torch import fx

### Design principles

* Prefer making program capture and transformation easy for typical models at the cost of working for all possible programs. **Avoid complexity to support longtail**, esoteric use cases.
* Work with tools and concepts that ML practitioners are already familiar with such as Python data structures and the publicly documented operators in PyTorch. (**Fully Pythonic**)
* Make the process of program capture **highly configurable** so users can implement their own solutions for long-tail uses. Allowing users to make one-off configurations is simpler than handling the general case.

> PyTorch is primarily used as an **eager execution** framework and program capture is only used for some specific transforms; It does not need to work for an entire program.
> * [TorchDynamo](https://pytorch.org/docs/master/dynamo/): Only captures what can be captured and leaves the rest to the native Python runtime

### Symbolic tracing

Use **abstract values (Proxy)** rather than example inputs.

The static control flow is directly eliminated.

In [None]:
fx_traced_mlp = fx.symbolic_trace(mlp)
print(fx_traced_mlp)

To print out the graph IR, use `.graph`

In [None]:
print(fx_traced_mlp.graph)

`%name: [#users=x] = <node_type>[target=mod_or_func_name](args = (%x1,), kwargs = {...})`

| Node type | Description |
| :--: | :-- |
| placeholder | input |
| call_module | call a sub-`nn.Module` |
| call_function | call a Python or PyTorch internal function (e.g., `operator.xxx`, `nn.functional.xxx`) |
| call_method | call a class method |
| get_attr | get a class attribute (e.g., parameter) |
| output | return |

* No primitive operations
* `args` and `kwargs` support immediate values that are natively supported in Python
* IR is much simpler

![](https://d3i71xaburhd42.cloudfront.net/febc8c8018372f96867a7a56dc1b52cd682596c0/9-Figure5-1.png)

### Graph Traversal

In [None]:
for node in fx_traced_mlp.graph.nodes:
    print(node, node.op, node.target, node.args, node.kwargs)

In [None]:
fx_traced_mlp.graph.print_tabular()

### Graph Manipulation

#### Replace a function

In [None]:
for node in fx_traced_mlp.graph.nodes:
    if node.op == 'call_function' and node.target == F.gelu:
        node.target = F.relu
print(fx_traced_mlp.graph)

#### Replace a module

In [None]:
for node in fx_traced_mlp.graph.nodes:
    if node.op == 'call_module' and node.target == 'linear2': # string match
        fx_traced_mlp.register_module('new_linear2', nn.Linear(3072, 3072, bias=False).to(device)) # be careful with the device
        node.target = 'new_linear2'
        break
fx_traced_mlp.delete_all_unused_submodules()
print(fx_traced_mlp.graph)
# Need to recompile after modifying the graph
fx_traced_mlp.graph.lint()
fx_traced_mlp.recompile()
print(fx_traced_mlp)

#### Insert a node

An incorrect implementation:

In [None]:
# Deliberately flawed attempt to insert a ReLU right after the 'linear2' call.
for node in fx_traced_mlp.graph.nodes:
    if node.op == 'call_module' and node.target == 'linear2':
        with fx_traced_mlp.graph.inserting_after(node):
            # new_node takes `node` as its argument, so new_node is itself a user of `node`.
            new_node = fx_traced_mlp.graph.call_function(F.relu, args=(node,))
            # Pitfall: this rewrites every user of `node`, new_node included,
            # so new_node would end up with itself as its own input.
            node.replace_all_uses_with(new_node)
        break

In [None]:
for node in fx_traced_mlp.graph.nodes:
    if node.op == 'output':
        with fx_traced_mlp.graph.inserting_before(node):
            new_node = fx_traced_mlp.graph.call_function(F.relu, args=(node.args[0],))
            node.args = (new_node,)
        break
fx_traced_mlp.graph.lint()
fx_traced_mlp.recompile()
print(fx_traced_mlp)

Run code as usual

In [None]:
fx_traced_mlp(inp)

#### Shape Propagation

In [None]:
# https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py
from torch.fx.passes.shape_prop import ShapeProp

ShapeProp(fx_traced_mlp).propagate(inp)
for node in fx_traced_mlp.graph.nodes:
    print(node, node.meta['tensor_meta'].shape)

#### Visualization

Need `pydot` to be installed.

```python
from torch.fx.passes.graph_drawer import FxGraphDrawer

g = FxGraphDrawer(fx_traced_mlp, "MLP")
g.get_main_dot_graph().create_svg()
```

Check torch.fx [codebase](https://github.com/pytorch/pytorch/tree/master/torch/fx) to see more use cases. Also see fx [tutorial](https://pytorch.org/docs/stable/fx.html).

### Limitation

* Dynamic control flow

In [None]:
def func_to_trace(x):
    """Return relu(x) if the sum of ``x`` is positive, otherwise -x."""
    # The condition is computed from the input values (dynamic control flow).
    has_positive_sum = x.sum() > 0
    if has_positive_sum:
        return torch.relu(x)
    return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)

* Non-torch functions
    * Use `wrap` to specify *leaf functions* that you do not want to trace into
    * Similarly, [`Tracer`](https://github.com/pytorch/pytorch/blob/master/torch/fx/_symbolic_trace.py#L376) can be customized to have some *leaf_modules*.

In [None]:
from math import sqrt

def normalize(x):
    """Divide ``x`` by the square root of its length (``len(x)``)."""
    length = len(x)
    scale = sqrt(length)
    return x / scale

# It's valid Python code
normalize(torch.rand(3, 4))

traced = fx.symbolic_trace(normalize)

In [None]:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)

* Full graph capturing

In [None]:
from transformers import AutoConfig, BertLMHeadModel
config = AutoConfig.from_pretrained('bert-base-uncased')
model = BertLMHeadModel(config)
print(model)

In [None]:
import torch.fx as fx
fx.symbolic_trace(model)

## Slapo

### Challenges of both methods

Compiler optimizations:
* [C1] Programmability: All or nothing. Follow specific coding styles.
* [C2] Debuggability: Hard to reason about the optimizations in a flattened optimized graph.

Manual optimizations:
* [C3] Generality: Need to modify the model definition or even rewrite the model.
* [C4] Tunability: Tune for different configurations.

### <u>S</u>chedule <u>LA</u>nguage for <u>P</u>rogressive <u>O</u>ptimization
1. Decouple model schedule from definition [C3]
2. Auto-tuner and auto-scheduler [C4]
3. Progressive optimization with a "trace-by-need" approach [C1]
4. Structure-preserved scheduling [C2]

In [None]:
import slapo

sch = slapo.create_schedule(model)

In [None]:
print(sch["bert.encoder.layer.0.attention"].mod)

In [None]:
subsch = sch["bert.encoder.layer.0.intermediate"]
print(subsch.mod)

### Operator Fusion (Dataflow Graph Transformation)

In [None]:
subsch["dense"].decompose()
print(subsch.mod)

In [None]:
subsch.trace(flatten=True)
print(subsch.mod)

In [None]:
def fusion_pattern(bias, output):
    """Pattern handed to ``subsch.find``: add the two inputs, then apply GELU."""
    return F.gelu(bias + output)
subgraph = subsch.find(fusion_pattern)
print(subgraph)

In [None]:
from slapo.pattern import call_module

# fuzzy matching
def fusion_pattern(bias, output):
    """Pattern handed to ``subsch.find``: an add followed by a module call."""
    out = bias + output
    # presumably matches a call to any sub-module whose name fits this regex
    # — TODO confirm against the slapo.pattern documentation
    out = call_module(r"intermediate_.*", out)
    return out

subgraph = subsch.find(fusion_pattern)
print(subgraph)

In [None]:
subsch.fuse(subgraph)
print(subsch.mod)

**Exercise**: Given the following subschedule, try to find and fuse the pattern: BiasAdd+Dropout+ResidualAdd+LayerNorm

In [None]:
# Index the root schedule `sch`: the path below is relative to the whole model,
# whereas `subsch` only covers "bert.encoder.layer.0.intermediate".
subsch_out = sch["bert.encoder.layer.0.output"]
# Your implementation
# ...

### Module replacement (Structural Transformation)

In [None]:
new_linear = nn.Linear(3072, 3072, bias=False).to(device)
subsch["dense"].replace(new_linear)
print(subsch.mod)

Quantization can also be achieved by module replacement.

Check more primitives on Slapo's [documentation](https://awslabs.github.io/slapo/) webpage.

## PyTorch 2.0

Introduce the `torch.compile()` method to unify all the compilation techniques in the PyTorch ecosystem. See more on the [website](https://pytorch.org/get-started/pytorch-2.0/).

* TorchDynamo as the frontend: Only captures what can be captured and leaves the rest to the native Python runtime
* torch.fx as the mid-end IR
* TorchInductor as the backend (+Triton, OpenMP, TVM, etc.)