# 1. 非混合精度训练

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from tqdm import tqdm
import random
import triton
import triton.language as tl
from TritonAdam import TritonAdamW
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
torch.cuda.empty_cache()
triton.__version__

  from .autonotebook import tqdm as notebook_tqdm


'3.2.0'

## 加载模型
- 加载fp32的模型进行测试

In [9]:
model_path = '/mnt/workspace/mdy/models/Qwen2.5-0.5B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).cuda()
iters = 100
for p in model.parameters():
    break


## torch Adam

### 非Fused版本

In [3]:
optimizer = AdamW(model.parameters(), fused=False)
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
print(p) # 刷新再跑，p应该差不多
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)

100%|██████████| 100/100 [00:09<00:00, 10.86it/s]


Parameter containing:
tensor([[-0.1078,  0.1398,  0.0985,  ..., -0.0773, -0.0879, -0.0940],
        [ 0.0785, -0.0957,  0.0860,  ..., -0.1042,  0.1048,  0.0868],
        [-0.0911,  0.0809, -0.0942,  ...,  0.0930, -0.1123,  0.0945],
        ...,
        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495],
        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495],
        [ 0.1392, -0.1389,  0.1346,  ..., -0.1390, -0.1394,  0.1495]],
       device='cuda:0', requires_grad=True)
25.367843627929688


### Fused版本

In [6]:
optimizer = AdamW(model.parameters(), fused=True)
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
print(p) # 刷新再跑，p应该差不多
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)

100%|██████████| 100/100 [00:06<00:00, 14.88it/s]


Parameter containing:
tensor([[-0.1075,  0.1400,  0.0989,  ..., -0.0775, -0.0883, -0.0943],
        [ 0.0791, -0.0955,  0.0848,  ..., -0.1041,  0.1041,  0.0858],
        [-0.0751,  0.0809, -0.0941,  ...,  0.0931, -0.1123,  0.0944],
        ...,
        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495],
        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495],
        [ 0.1393, -0.1390,  0.1346,  ..., -0.1388, -0.1392,  0.1495]],
       device='cuda:0', requires_grad=True)
8.984747886657715


## Triton Adam

### 全部fp32

In [8]:
optimizer = TritonAdamW(model.parameters())
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
print(p) # 刷新再跑，p应该差不多
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)

  0%|          | 0/100 [00:00<?, ?it/s]

finish_custom_init, p_dtype: torch.float32, master_p_dtype: None


100%|██████████| 100/100 [00:07<00:00, 12.86it/s]


Parameter containing:
tensor([[-0.1081,  0.1399,  0.0991,  ..., -0.0780, -0.0882, -0.0937],
        [ 0.0789, -0.0957,  0.0858,  ..., -0.1044,  0.1044,  0.0871],
        [-0.0793,  0.0809, -0.0939,  ...,  0.0929, -0.1125,  0.0944],
        ...,
        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495],
        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495],
        [ 0.1391, -0.1389,  0.1348,  ..., -0.1391, -0.1395,  0.1495]],
       device='cuda:0', requires_grad=True)
8.832446098327637


### 1阶2阶动量为bf16

In [10]:
optimizer = TritonAdamW(model.parameters(), exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
print(p) # 刷新再跑，p应该差不多
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  3%|▎         | 3/100 [00:00<00:08, 12.00it/s]

finish_custom_init, p_dtype: torch.float32, master_p_dtype: None


  5%|▌         | 5/100 [00:00<00:07, 13.55it/s]

100%|██████████| 100/100 [00:06<00:00, 15.30it/s]


Parameter containing:
tensor([[-0.1063,  0.1401,  0.0998,  ..., -0.0788, -0.0872, -0.0942],
        [ 0.0785, -0.0951,  0.0865,  ..., -0.1053,  0.1029,  0.0856],
        [-0.0803,  0.0811, -0.0938,  ...,  0.0920, -0.1115,  0.0943],
        ...,
        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501],
        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501],
        [ 0.1394, -0.1390,  0.1343,  ..., -0.1392, -0.1395,  0.1501]],
       device='cuda:0', requires_grad=True)
7.0299811363220215


# 2. 混合精度训练

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import random
import triton
import triton.language as tl
from TritonAdam import TritonAdamW
from apex.optimizers import FusedAdam as ApexFusedAdam
from transformer_engine.pytorch.optimizers import FusedAdam as TEFusedAdam
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
torch.cuda.empty_cache()
import torch._inductor.runtime.hints

  from .autonotebook import tqdm as notebook_tqdm


## 加载模型
- 加载bf16的模型

In [3]:
model_path = '/mnt/workspace/mdy/models/Qwen2.5-0.5B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
iters = 100
for p in model.parameters():
    break


## Apex
- apex只有标准的混合精度训练，fp32的master weight和1，2阶动量，bf16/fp16的model weight和grad

In [4]:
optimizer = ApexFusedAdam(model.parameters(), capturable=True, master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


13.046334266662598


## Transformer Engine
- 最新的te，支持多种精度（可以点进去看下），比如1，2阶动量支持fp16，int8之类的，但都需要进行scale，但是不支持bf16

### 标准版

In [5]:
optimizer = TEFusedAdam(model.parameters(), master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]


9.840949058532715


### fp16的1，2阶动量

In [6]:
optimizer = TEFusedAdam(model.parameters(), 
                        exp_avg_dtype=torch.float16,
                        exp_avg_sq_dtype=torch.float16,
                        master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


64.45989990234375


### fp16的1，2阶动量 + fp32的grad

In [7]:
optimizer = TEFusedAdam(model.parameters(), 
                        exp_avg_dtype=torch.float16,
                        exp_avg_sq_dtype=torch.float16,
                        use_decoupled_grad=True,
                        master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break

# grad必须和param的精度是一样的，如果是bf16的p，使用fp32的g，那么就需要使用其它属性进行存储
for p in model.parameters():
    p.decoupled_grad = p.grad.float()
    p.grad = None
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]


64.99529266357422


## Triton Adam
- 基本就是对着TE中的进行写的，接口基本都差不多，无缝衔接Megatron框架
- 目前支持master weight是fp32，model weight bf16， grad fp32 和 bf16都可以，1，2阶动量bf16或者fp32。无多余功能，基本满足训练需求

### 标准版

In [8]:
optimizer = TritonAdamW(model.parameters(), master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]

finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32


  0%|          | 0/100 [00:01<?, ?it/s]


9.474020957946777


### bf16的1，2阶动量

In [9]:
optimizer = TritonAdamW(model.parameters(), 
                        exp_avg_dtype=torch.bfloat16,
                        exp_avg_sq_dtype=torch.bfloat16,
                        master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]

finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32


  0%|          | 0/100 [00:00<?, ?it/s]


7.561221122741699


### bf16的1，2阶动量 + fp32的grad
- 这个就是deepseekv3中的配置
- 在megatron中，它会使用一个fp32的grad buffer去存储梯度
- 所有micro batch的梯度都加到这个buffer中，通过hook实现，下面是伪代码
- grad_buffer += p.grad
- p.grad = None
- 当所有micro batch都计算完后
- optimizer.model_p.decoupled_grad = grad_buffer

In [10]:
optimizer = TritonAdamW(model.parameters(), 
                        exp_avg_dtype=torch.bfloat16,
                        exp_avg_sq_dtype=torch.bfloat16,
                        use_decoupled_grad=True,
                        master_weights=True)
torch.cuda.empty_cache()
inp_ids = torch.arange(128).reshape(4,-1).cuda()
for _ in tqdm(range(iters)):
    out = model(inp_ids)
    out.logits.mean().backward()
    optimizer.step()
    break

# grad必须和param的精度是一样的，如果是bf16的p，使用fp32的g，那么就需要使用其它属性进行存储
for p in model.parameters():
    p.decoupled_grad = p.grad.float()
    p.grad = None
ms = triton.testing.do_bench(lambda: optimizer.step(), rep=1000)
print(ms)


  0%|          | 0/100 [00:00<?, ?it/s]


finish_custom_init, p_dtype: torch.bfloat16, master_p_dtype: torch.float32
7.5503644943237305
