In [9]:
import torch
from IPython.display import Image

In [2]:
# Gradient scaler for mixed-precision training; its scale()/step()/update()
# usage pattern is shown in the "loss scaling" section of this notebook.
scaler = torch.cuda.amp.GradScaler()

## demo

In [13]:
# Create operands in the default dtype (assumed here to be float32) on the GPU.
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

# Note: torch.autocast(device_type="cuda") is the device-agnostic spelling of
# the same context manager.
with torch.cuda.amp.autocast():
    # torch.mm is on autocast's list of ops that run in float16: the inputs are
    # float32, but the op executes in float16 and returns a float16 tensor.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    print('in autocast', e_float16.dtype, e_float16.device)
    # Mixed input dtypes (float32 x float16) are handled as well.
    f_float16 = torch.mm(d_float32, e_float16)
    # Report the device of the tensor being described (f_float16, not e_float16).
    print('in autocast', f_float16.dtype, f_float16.device)

# Outside the autocast region nothing is cast automatically, so f_float16 is
# converted explicitly with .float() before being multiplied with d_float32.
g_float32 = torch.mm(d_float32, f_float16.float())
print('out autocast', g_float32.dtype, g_float32.device)

in autocast torch.float16 cuda:0
in autocast torch.float16 cuda:0
out autocast torch.float32 cuda:0


## basics

In [16]:
# Render the fp32-vs-fp16 figure (path is relative to this notebook), constrained to 600 px wide.
Image(url='../imgs/fp32-fp16.png', width=600)

- fp32 vs. fp16
- fp16 is fast and memory-efficient；
    - 更快的 compute throughput （8x）
    - 更高的 memory throughput (2x)
    - 更小的显存占用 (1/2x)
- fp32 offers precision and range benefits.
- 因此需要混合；
    - 需要 fp32 的场景：
        - reductions，exponentiation；
        - large + small：weight updates, reductions again;
            - 1+0.0001

In [21]:
# Reduction in float16: 4096 elements of value 16 sum to 4096 * 16 = 65536,
# which exceeds the largest finite float16 value (65504), so the result is inf.
# torch.full builds the already-filled tensor directly, replacing the legacy
# typed constructor (uninitialised memory) followed by an in-place fill_.
a = torch.full((4096,), 16, dtype=torch.float16, device="cuda")
a.sum()

tensor(inf, device='cuda:0', dtype=torch.float16)

In [25]:
# Same reduction in float32: 4096 * 16 = 65536 is well within range, so the
# sum is exact. torch.full replaces the legacy typed constructor + fill_.
b = torch.full((4096,), 16, dtype=torch.float32, device="cuda")
b.sum()

tensor(65536., device='cuda:0')

In [26]:
# "large + small" in float16: the spacing between representable values just
# above 1.0 is 2**-10 (~0.00098), so adding 0.0001 rounds back to 1.0 and the
# update is lost. Tensors are built with explicit dtype/device arguments
# instead of the legacy torch.cuda.HalfTensor constructor.
para = torch.tensor([1.], dtype=torch.float16, device="cuda")
update = torch.tensor([.0001], dtype=torch.float16, device="cuda")
para + update

tensor([1.], device='cuda:0', dtype=torch.float16)

In [27]:
# The same "large + small" addition in float32 keeps the update: 1.0001 is
# resolvable at float32 precision. Tensors are built with explicit dtype/device
# arguments instead of the legacy torch.cuda.FloatTensor constructor.
para = torch.tensor([1.], dtype=torch.float32, device="cuda")
update = torch.tensor([.0001], dtype=torch.float32, device="cuda")
para + update

tensor([1.0001], device='cuda:0')

In [28]:
# Render the amp_32_16 figure (path is relative to this notebook), constrained to 600 px wide.
Image(url='../imgs/amp_32_16.png', width=600)

## loss scaling

```
scaler = GradScaler()

# forward
with autocast():
    output = model(input)
    loss = loss_fn(output, target)

# backward
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

- 针对的是 loss（loss scaling）
- gradients 的两个版本，见下图
    - fp16 gradients
    - fp32 gradients

In [4]:
# Figure referenced by the notes above (fp16 vs. fp32 gradients). The image is
# hot-linked from a remote URL, so displaying it requires network access.
Image(url='https://blog.paperspace.com/content/images/2022/05/image-16.png', width=400)