

China Software Technology Conference

# ByteIR: Toward End-to-End AI Compilation

Liu Yuanqiang, Machine Learning Systems Engineer, ByteDance

2023.12

1. What is ByteIR and advantages of ByteIR
2. Design and technical details
3. LLM training example and performance

### What is ByteIR





ByteIR is our solution for framework-to-hardware compilation.

# AI Compilation: NN Graph to HW

#### AI Workloads

**AI Compilation**

**Graph Compiler (Frontend)**

Intermediate Representation (IR)

Codegen (Backend)

**AI Hardware**





# ByteIR Architecture



### Advantages of ByteIR

1. Embrace open source, upstream first
   - Contribute to llvm, tensorflow, pytorch, torch-mlir, onnx-mlir, stablehlo
2. Good support for PyTorch, both inference and training
3. Friendly to new hardware (ASIC/NPU)
4. Flexible, extensible, high performance


### ByteIR Torch Frontend – Torch MLIR



Github: https://github.com/llvm/torch-mlir

### Torch MLIR Lowering

#### torch

```
import torch
from torch import nn

class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)

    def forward(self, x):
        return self.linear(x)
```

#### torch dialect

```
func.func @forward(%arg0: !torch.vtensor<[20,10],f32>,
    %arg1: !torch.vtensor<[10,20],f32>) -> !torch.vtensor<[20,20],f32> {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[20,10],f32>,
      !torch.vtensor<[10,20],f32> -> !torch.vtensor<[20,20],f32>
  return %0 : !torch.vtensor<[20,20],f32>
}
```

#### torch dialect

```
func.func @forward(%arg0: !torch.tensor,
    %arg1: !torch.tensor) -> !torch.tensor {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
  return %0 : !torch.tensor
}
```

Lowering steps (figure labels): torch.jit.script, MLIR importer, value semantics, lowering

#### stablehlo dialect

```
func.func @forward(%arg0: tensor<20x10xf32>,
    %arg1: tensor<10x20xf32>) -> tensor<20x20xf32> {
  %0 = "stablehlo.dot"(%arg0, %arg1) :
      (tensor<20x10xf32>, tensor<10x20xf32>) -> tensor<20x20xf32>
  return %0 : tensor<20x20xf32>
}
```
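The type rewriting visible in the lowering above (from `!torch.vtensor<[20,10],f32>` to `tensor<20x10xf32>`, and from `torch.aten.mm` to `stablehlo.dot`) can be mimicked with a small string-level sketch. This is only an illustration of the mapping: the helper names `vtensor_to_builtin`, `lower_op`, and the `OP_MAP` table are hypothetical, and the real torch-mlir lowering is implemented as MLIR pattern rewrites, not string manipulation.

```python
import re

# Matches a ranked value-semantic torch tensor type, e.g. !torch.vtensor<[20,10],f32>
_VTENSOR = re.compile(r"!torch\.vtensor<\[([^\]]*)\],\s*(\w+)>")


def vtensor_to_builtin(type_str):
    """Rewrite '!torch.vtensor<[20,10],f32>' as the builtin 'tensor<20x10xf32>'."""
    match = _VTENSOR.fullmatch(type_str.strip())
    if match is None:
        raise ValueError(f"not a ranked !torch.vtensor type: {type_str!r}")
    dims = [d.strip() for d in match.group(1).split(",") if d.strip()]
    # Dimensions and element type are joined with 'x' in the builtin spelling.
    return "tensor<" + "x".join(dims + [match.group(2)]) + ">"


# Hypothetical one-entry op table, just enough for the slide's example.
OP_MAP = {"torch.aten.mm": "stablehlo.dot"}


def lower_op(name, operand_types, result_type):
    """Map one torch-dialect op description to its stablehlo counterpart."""
    return (
        OP_MAP[name],
        [vtensor_to_builtin(t) for t in operand_types],
        vtensor_to_builtin(result_type),
    )


lowered = lower_op(
    "torch.aten.mm",
    ["!torch.vtensor<[20,10],f32>", "!torch.vtensor<[10,20],f32>"],
    "!torch.vtensor<[20,20],f32>",
)
```

Here `lowered` describes the same `stablehlo.dot` signature as the stablehlo dialect listing.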



# ByteIR Torch Frontend



Able to provide coarse-grained operators

#### Corner Cases

```
def forward(self, x):
    x += 1
    return x
```

Succeeds, but the in-place update has no effect on the input tensor.

```
def forward(self, x):
    y = torch.as_strided(x, size, stride)
    return y
```

Fails: there is no abstraction of pointer/storage.

```
def forward(self, x, y):
    x += y
    z = x.view(-1, 4)
    x += 1
    return x, z
```

Fails: the ops really overwrite the same memory.

Dynamic if/for/while is not supported.
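To see why the `x += y; z = x.view(...); x += 1` case is hard for a value-semantic IR, here is a minimal analogy in plain Python. It uses `array` and `memoryview` as stand-ins for a tensor and its view, which is an assumption made purely for illustration; nothing here is PyTorch or torch-mlir code.

```python
from array import array


def forward_aliasing(x, y):
    """Eager-style semantics: z is a view sharing storage with x."""
    for i in range(len(x)):
        x[i] += y[i]              # x += y
    z = memoryview(x)             # stand-in for x.view(...): same memory, no copy
    for i in range(len(x)):
        x[i] += 1                 # x += 1 is also visible through z
    return x, z


def forward_value_semantics(x, y):
    """Every op yields a fresh value, so z is a snapshot taken before the +1."""
    x1 = [a + b for a, b in zip(x, y)]
    z = list(x1)
    x2 = [a + 1 for a in x1]
    return x2, z


eager_x, eager_z = forward_aliasing(array("d", [1, 2, 3, 4]),
                                    array("d", [10, 10, 10, 10]))
value_x, value_z = forward_value_semantics([1.0, 2.0, 3.0, 4.0], [10.0] * 4)
```

With aliasing, `eager_z` observes the final `+ 1`; with naive value semantics, `value_z` does not. A compiler has to either track the shared storage or reject the program to stay correct.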

# ByteIR Compiler Overview

[Figure: ByteIR compiler pipeline. Stablehlo/Mhlo (opt. passes) is lowered to Linalg on tensors (opt. passes), bufferized to Linalg on memrefs (opt. passes), and lowered to scf/affine/vec/... (opt. passes). From there, GPU lowering targets GPU LLVM (NVVM/LLVM), a CUDA emitter, or a C emitter. A separate CAT path feeds the AITemplate backend for Nvidia GPU, and ByRE is serialized to ByteIR IR or a file. The figure distinguishes hybrid, upstream, and ByteIR components.]


### Integrating AITemplate

We introduce the CAT (Composable Algebra Template) dialect.



#### Mhlo-2-CAT Passes

#### Mhlo-2-CAT:

- Convert Mhlo ops to CAT ops (one CAT op corresponds to one AIT op)
- Eliminate redundant transpose/permute ops
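The second step, eliminating redundant transpose/permute ops, can be sketched as a peephole pass over a linear op list. This is an illustrative model only: the `("permute", perm)` tuple encoding and the function names are hypothetical and do not reflect ByteIR's actual pass implementation.

```python
def compose(first, second):
    """Permutation equal to applying `first`, then `second`.

    Convention (as in torch.permute): output dim i comes from input dim perm[i].
    """
    return [first[i] for i in second]


def simplify_permutes(ops):
    """Merge adjacent permutes and drop any that reduce to the identity."""
    out = []
    for op in ops:
        if op[0] == "permute" and out and out[-1][0] == "permute":
            # Fold this permute into the previous one.
            op = ("permute", compose(out.pop()[1], op[1]))
        if op[0] == "permute" and op[1] == list(range(len(op[1]))):
            continue  # identity permute: nothing to do at runtime
        out.append(op)
    return out


# Two back-to-back 2-D transposes cancel out entirely.
cancelled = simplify_permutes(
    [("permute", [1, 0]), ("permute", [1, 0]), ("gemm", None)]
)
# Two rotations of a 3-D tensor merge into a single permute.
merged = simplify_permutes([("permute", [1, 2, 0]), ("permute", [1, 2, 0])])
```

In the first case only the `gemm` remains; in the second, a single `permute` replaces two memory-moving ops.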



### Linalg Tiling and Fusion

```
// Mhlo op:
func.func @fuse_element(%arg0 ..., %arg1 ...)
  %0 = mhlo.some_elemwise_binary_1(%arg0, %arg1)
  %1 = mhlo.some_elemwise_binary_2(%0, %arg1)
  return %1

// After Linalg transformation, Linalg op:
func.func @fuse_element(%arg0 ..., %arg1 ...)
  %1 = linalg.generic {indexing_maps = ...,
         iterator_types = ...} ins(%arg0, %arg1) outs... {
    ^bb0(%in0, %in1, %out)
      %2 = arith.some_elemwise_binary_1(%in0, %in1)
      %3 = arith.some_elemwise_binary_2(%2, %in1)
      linalg.yield %3 ...
  }
  return %1
```
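The effect of the fusion above can be illustrated outside MLIR with a small sketch. The two binary ops are assumed here to be add and multiply (the slide leaves them unspecified), and `traffic` is a deliberately simple model of global-memory element accesses that ignores caches.

```python
def unfused(a, b):
    """Two separate elementwise kernels with a materialized intermediate."""
    tmp = [x + y for x, y in zip(a, b)]      # kernel 1 writes tmp
    return [t * y for t, y in zip(tmp, b)]   # kernel 2 reads tmp back


def fused(a, b):
    """One kernel: the intermediate stays in a register-like local value."""
    return [(x + y) * y for x, y in zip(a, b)]


def traffic(n, is_fused):
    """Elements read plus written in global memory for n-element inputs."""
    if is_fused:
        return 2 * n + n                     # read a, b; write out
    return (2 * n + n) + (2 * n + n)         # each kernel reads 2 and writes 1


a, b = [1, 2, 3, 4], [5, 6, 7, 8]
same_result = unfused(a, b) == fused(a, b)
```

Under this model the fused version moves half as many elements, which is why fusion matters most for IO-bound ops.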

### **ByteIR's Linalg Extension**

#### More ops:

- Alias, Diag, Scan, Scatter, Softmax, TopK
- Support for transformations on the extended ops

#### Enhanced fusion transformations:

- Producer-consumer and input-sharing fusion
- Correct tiling along the reduction axis
- Intermediates as outputs within a fusion
- Intermediate tensor dimension simplification
- Conversion of map ops to generic ops

#### Other introduced transformations:

- Collapse dims transformation
- Fuse operands transformation



#### Benefits:

- Extreme IO-bound op fusion
- Lower overhead for fused ops (Exploit GPU DRAM bandwidth)





# ByteIR Runtime (BRT) Overview



#### **BRT Interface for Hardware**

- Provider: a collection of op implementations (e.g., mm, d2h/h2d memcpy)
- Work Queue: an abstraction (like a CUDA stream) for execution order
- Allocator: memory allocate/free
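The three hardware-facing pieces can be pictured with a toy host-only model. All class and method names below (`register`, `lookup`, `enqueue`, `sync`, `allocate`, `free`) are hypothetical and chosen for this sketch; they are not the BRT API.

```python
from abc import ABC, abstractmethod
from collections import deque


class Allocator(ABC):
    """Memory allocate/free for one kind of device."""

    @abstractmethod
    def allocate(self, nbytes): ...

    @abstractmethod
    def free(self, buf): ...


class WorkQueue(ABC):
    """Defines execution order, in the spirit of a CUDA stream."""

    @abstractmethod
    def enqueue(self, fn, *args): ...

    @abstractmethod
    def sync(self): ...


class Provider:
    """A named collection of op implementations."""

    def __init__(self):
        self._ops = {}

    def register(self, name, fn):
        self._ops[name] = fn

    def lookup(self, name):
        return self._ops[name]


class HostAllocator(Allocator):
    def __init__(self):
        self.live = 0  # number of buffers not yet freed

    def allocate(self, nbytes):
        self.live += 1
        return bytearray(nbytes)

    def free(self, buf):
        self.live -= 1


class FifoWorkQueue(WorkQueue):
    def __init__(self):
        self._pending = deque()

    def enqueue(self, fn, *args):
        self._pending.append((fn, args))  # deferred, like an async launch

    def sync(self):
        while self._pending:              # run in submission order
            fn, args = self._pending.popleft()
            fn(*args)


def memcpy(dst, src):
    dst[:] = src


provider = Provider()
provider.register("memcpy", memcpy)
```

A new backend would, under this model, supply its own allocator, work queue, and op table while the rest of the runtime stays unchanged.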


# ByteIR LLM Training Compilation Pipeline



# Optimization in FW/BW Partition



One case of our partition strategy

### ByteIR PyTorch Compile Example

```
import torch
from byteir import byteir_compile_fx

# make_model, make_data, and compute_loss are user-defined helpers (not shown)
model = make_model(model_name)
# 1. compile with ByteIR
optimized_model = torch.compile(model, backend=byteir_compile_fx)
# 2. execute as usual
data = make_data(optimized_model, model_name, device)
model.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
    # forward compile
    loss = compute_loss(optimized_model, data)
    # backward compile
    loss.backward()
```

### ByteIR Runtime Example

```
import brt
import torch

# byre_model_path and dtype are assumed to be defined by the caller
session = brt.Session()
session.load(byre_model_path)
req = session.new_request_context(torch.cuda.current_stream())
inputs, outputs = [], []
# init input/output data
for offset in session.get_input_arg_offsets():
    inputs.append(torch.randn(session.get_static_shape(offset),
                              dtype=dtype, device="cuda"))
    req.bind_arg(offset, inputs[-1].data_ptr())
req.finish_io_binding()
req.run()
req.sync()
```

#### Performance

- Flash Attention 2
- Elementwise Fusion
- AITemplate
- Reduce Codegen



#### **Conclusion & Future Work**

We introduce ByteIR: a framework-to-hardware compiler solution

- Friendly to PyTorch and GPU/ASIC
- PyTorch 2.0 training/inference demo on LLM

#### Future Work:

- Distributed support
- TensorCore MMA Codegen

Website: https://byteir.ai

Github: https://github.com/bytedance/byteir

# THANKS

https://byteir.ai

### Appendix 1: Reduce Op Optimization (Fusion)

Optimization 1: Fusing reduce op with producer ops
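As a rough picture of fusing a reduce op with its producer, the sketch below compares materializing the producer's output with accumulating it on the fly. The producer (squaring each element) and the function names are assumptions for illustration; the slide does not specify the ops involved.

```python
def reduce_unfused(row, producer):
    """Producer output is written out in full, then reduced."""
    tmp = [producer(v) for v in row]   # intermediate buffer of len(row) elements
    return sum(tmp)


def reduce_fused(row, producer):
    """Producer is evaluated inside the reduction loop; no intermediate buffer."""
    acc = 0
    for v in row:
        acc += producer(v)
    return acc


def square(v):
    return v * v


row = [1, 2, 3, 4]
unfused_total = reduce_unfused(row, square)
fused_total = reduce_fused(row, square)
```

Both versions compute the same sum, but the fused one never stores the intermediate row.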



### Appendix 2: Reduce Op Optimization (Tiling)

Optimization 2: Parallelizing reduce dimension



We use our LinalgExt transformations to achieve the best tiling efficiency.
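Parallelizing the reduce dimension is commonly done as a split reduction: the reduced axis is tiled, each tile produces a partial result independently, and the partials are combined in a final step. The sketch below shows that idea sequentially in plain Python; `split_reduce` and its `num_tiles` parameter are hypothetical, and on a GPU each partial sum would be assigned to its own group of threads.

```python
def split_reduce(values, num_tiles):
    """Sum `values` by tiling the reduce dimension into at most `num_tiles` tiles."""
    if not values:
        return 0, []
    tile = -(-len(values) // num_tiles)    # ceiling division: elements per tile
    # Each partial sum depends only on its own tile, so tiles can run in parallel.
    partials = [sum(values[i:i + tile]) for i in range(0, len(values), tile)]
    # A second, much smaller reduction combines the per-tile results.
    return sum(partials), partials


total, partials = split_reduce(list(range(10)), 4)
```

With ten elements and four tiles, each tile covers three elements (the last covers one), and the combined total equals the plain sum.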