In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
print("PyTorch version:", torch.__version__)

In [None]:
def scaled_dot_product(q, k, v, dropout_p=0.1, training=True):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; the callers in this file pass
            (bs, head, seq, hs // head).
        k: Key tensor with the same layout as ``q``.
        v: Value tensor with the same layout as ``q``.
        dropout_p: Dropout probability applied to the attention
            probabilities. Defaults to 0.1, the value that used to be
            hard-coded.
        training: Whether dropout is active. Defaults to True, which keeps
            the previous always-on behaviour; pass False for deterministic
            inference.

    Returns:
        The attention-weighted values, with the same leading dimensions as
        ``q`` and the last dimension of ``v``.
    """
    # (bs, head, seq, hs // head)
    d_k = q.shape[-1]
    # d_k is a plain Python int, and torch.sqrt only accepts tensors, so the
    # scale factor is computed with ordinary float arithmetic.
    scale = d_k ** 0.5
    # (bs, head, seq, seq)
    attn_score = torch.matmul(q, k.transpose(-2, -1)) / scale
    attn_probs = F.softmax(attn_score, dim=-1)
    attn_probs = F.dropout(attn_probs, p=dropout_p, training=training)
    # (bs, head, seq, hs // head)
    return torch.matmul(attn_probs, v)


class SelfAttention(nn.Module):
    """Multi-head self-attention without the output projection.

    Projects the input into queries, keys and values, splits each into
    ``n_heads`` heads, runs ``scaled_dot_product`` and merges the heads back
    into a single hidden dimension.

    Args:
        hidden_size: Size of the last dimension of the input.
        n_heads: Number of attention heads; must evenly divide
            ``hidden_size``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error
        # during the first forward pass.
        if n_heads <= 0 or hidden_size % n_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"a positive n_heads ({n_heads})"
            )
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.n_heads = n_heads

    def permute_for_scores(self, x):
        """Split the hidden dimension into heads.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (bs, head, seq, hs // head).
        """
        x = x.view(*x.shape[:-1], self.n_heads, -1)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply multi-head self-attention.

        Args:
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        # qkv layers
        q = self.permute_for_scores(self.q_proj(hidden_states))
        k = self.permute_for_scores(self.k_proj(hidden_states))
        v = self.permute_for_scores(self.v_proj(hidden_states))
        # core attention: (bs, head, seq, hs // head)
        output = scaled_dot_product(q, k, v)
        # permute() returns a new tensor, so the result must be reassigned:
        # (bs, seq, head, hs // head)
        output = output.permute(0, 2, 1, 3)
        # The permuted tensor is non-contiguous, which view() rejects;
        # reshape() copies when necessary. Result: (bs, seq, hs)
        return output.reshape(output.shape[0], output.shape[1], -1)


class Projection(nn.Module):
    """Output projection followed by a residual connection and LayerNorm.

    Args:
        hidden_size: Input and output feature size of the dense layer, and
            the normalized dimension of the LayerNorm.
        p: Dropout probability applied after the dense layer.
    """

    def __init__(self, hidden_size, p=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=p)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalize its sum with ``input_tensor``.

        Args:
            hidden_states: Tensor fed through the dense layer and dropout.
            input_tensor: Residual tensor added before normalization.

        Returns:
            The layer-normalized sum of the projected states and the residual.
        """
        projected = self.dropout(self.dense(hidden_states))
        residual_sum = projected + input_tensor
        return self.layer_norm(residual_sum)


class Attention(nn.Module):
    """Self-attention block: ``SelfAttention`` followed by ``Projection``.

    Args:
        hidden_size: Feature size passed to both sub-modules.
        n_heads: Number of heads passed to ``SelfAttention``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.self_attn = SelfAttention(hidden_size, n_heads)
        self.proj = Projection(hidden_size)

    def forward(self, hidden_states):
        """Run self-attention and project it with ``hidden_states`` as residual."""
        attended = self.self_attn(hidden_states)
        # The block input doubles as the residual inside the projection.
        return self.proj(attended, hidden_states)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network with a selectable activation.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size between the two linear layers.
        hidden_act: ``"gelu"`` selects GELU; any other value selects ReLU.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.hidden_act = hidden_act
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, data):
        """Apply linear1, the activation, then linear2 to ``data``."""
        hidden = self.linear1(data)
        # The branch depends only on a module attribute, not on the input.
        if self.hidden_act != "gelu":
            hidden = F.relu(hidden)
        else:
            hidden = F.gelu(hidden)
        return self.linear2(hidden)

## TorchScript

* First generation of PyTorch compiler
* Can support both **training and inference**
* Out-of-the-box optimization tool

PyTorch is a dynamic graph execution framework, so we first need to construct the computation graph before we can apply further optimizations.

Just-in-Time (JIT) compilation.

* Tracing mode: `torch.jit.trace`
* Scripting mode: `torch.jit.script`

### Tracing Mode
Runs a model with certain inputs and "traces / records" all the operations that are executed into a graph.

We use the MLP example to illustrate.

In [None]:
# Use the highest-indexed GPU when one is available. Without a GPU,
# device_count() - 1 would produce the invalid device string "cuda:-1",
# so fall back to the CPU instead.
if torch.cuda.is_available():
    device = f"cuda:{torch.cuda.device_count() - 1}"
else:
    device = "cpu"
inp = torch.rand((16, 512, 768)).to(device) # (bs, seq, hs)
mlp = MLP(768, 3072, "gelu").to(device)
# Tracing runs the module once on the example input and records the
# operations that were executed.
traced_mlp = torch.jit.trace(mlp, (inp,))
print(traced_mlp)

The above is the **structural representation** that describes the module hierarchy. We can check the class type of the traced module.

In [None]:
# Print the traced object's class and whether it is an nn.Module instance.
print(type(traced_mlp), isinstance(traced_mlp, nn.Module))

We can print out the **graph representation** of the traced module. Its intermediate representation (IR) mostly follows LLVM's conventions.

* Graph: Similar to `llvm::Function`
* Block: Only dataflow is inside a block
* Node: Instruction
    * Analogous to `mlir::Operation`
    * Can have nested blocks inside
    * e.g., `prim::GetAttr`, `prim::CallMethod`, `prim::Constant`, `aten::gelu`
* Value: Input arguments / Output results
    * The edges in the graph
    * Static single-assignment (SSA) form: Each value has precisely one defining node
    * e.g., `%x: type` (statically typed!)

You can refer to the implementation file [ir.h](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h) for more details.

In [None]:
# Print the graph (IR) recorded for the traced module.
print(traced_mlp.graph)

We can even print out the executable Python code.

In [None]:
# Print the Python-syntax rendering of the traced module.
print(traced_mlp.code)

### Scripting Mode

Parses the Python source code of the model, and compiles the code into a graph.
* a subset of Python grammar
* Has a Lexer and Parser that parse Python syntax directly
    * Useful to deploy to somewhere without Python environment (no need to link CPython)
    * It cannot keep up with the latest Python grammar, which hurts maintainability
    * Limits the programmer

In [None]:
# Compile the same module instance in scripting mode (no example input is
# needed, unlike torch.jit.trace) and print its module hierarchy.
scripted_mlp = torch.jit.script(mlp)
print(scripted_mlp)

In [None]:
# Print the scripted object's class and whether it is an nn.Module instance.
print(type(scripted_mlp), isinstance(scripted_mlp, nn.Module))

Control flow nodes: `prim::If` and `prim::Loop`
* The outputs of the if-statement serve a similar role to the $\Phi$ nodes in traditional SSA control-flow graphs
* Same as `mlir::affine::yield`

In [None]:
# Print the graph (IR) of the scripted module.
print(scripted_mlp.graph)

In [None]:
# Print the Python-syntax rendering of the scripted module.
print(scripted_mlp.code)

### Operator fusion

NVFuser

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py

**code quality is the cost of scriptability**

No transparency

Note: do NOT use timeit or time in Python standard library

In [None]:
import torch.utils.benchmark as benchmark

# Time the eager, traced and scripted MLP with 1000 runs each. Every variant
# receives its own detached copy of the input so the measurements do not
# share a tensor.
for _label, _stmt, _name, _module in (
    ('Vanilla', 'mlp(inp)', 'mlp', mlp),
    ('Traced', 'traced_mlp(inp)', 'traced_mlp', traced_mlp),
    ('Scripted', 'scripted_mlp(inp)', 'scripted_mlp', scripted_mlp),
):
    _timer = benchmark.Timer(
        _stmt,
        globals={_name: _module, 'inp': inp.detach().clone()},
        label=_label,
    )
    print(_timer.timeit(1000))
    # torch.cuda.synchronize() raises when CUDA is unavailable, so only
    # release cached memory and wait for the GPU when one is in use.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [None]:
class NewMLP(nn.Module):
    """MLP variant whose first bias is added as a separate op.

    ``linear1`` is created without a bias; an explicit ``bias`` parameter is
    added to its output before the activation.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size between the two linear layers, and
            the length of the standalone bias.
        hidden_act: ``"gelu"`` selects GELU; any other value selects ReLU.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(intermediate_size))
        self.hidden_act = hidden_act
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, data):
        """Apply linear1, the explicit bias, the activation, then linear2."""
        hidden = self.linear1(data)
        biased = hidden + self.bias
        # The branch depends only on a module attribute, not on the input.
        if self.hidden_act != "gelu":
            activated = F.relu(biased)
        else:
            activated = F.gelu(biased)
        return self.linear2(activated)

In [None]:
# Rebuild the eager, traced and scripted variants from NewMLP, rebinding the
# names used by the earlier cells.
mlp = NewMLP(768, 3072, "gelu").to(device)
traced_mlp = torch.jit.trace(mlp, (inp,))
scripted_mlp = torch.jit.script(mlp)

# Time each variant with 1000 runs on its own detached copy of the input.
for _label, _stmt, _name, _module in (
    ('Vanilla', 'mlp(inp)', 'mlp', mlp),
    ('Traced', 'traced_mlp(inp)', 'traced_mlp', traced_mlp),
    ('Scripted', 'scripted_mlp(inp)', 'scripted_mlp', scripted_mlp),
):
    _timer = benchmark.Timer(
        _stmt,
        globals={_name: _module, 'inp': inp.detach().clone()},
        label=_label,
    )
    print(_timer.timeit(1000))
    # torch.cuda.synchronize() raises when CUDA is unavailable, so only
    # release cached memory and wait for the GPU when one is in use.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Static control flow: the branch does not depend on the input data.

### Limitation

[JIT should not force users to write ugly code](https://github.com/pytorch/pytorch/issues/48108)

Middle None argument

## torch.fx

MLSys paper (Horace He)

In [None]:
from torch import fx

In [None]:
# Symbolically trace the current `mlp` module with torch.fx and print the result.
print(fx.symbolic_trace(mlp))