In [9]:
import torch
from IPython.display import Image

In [2]:
# Gradient scaler for mixed-precision training: scaler.scale(loss) multiplies the
# loss by a scale factor before backward() so small float16 gradients do not
# underflow to zero. enabled=True is GradScaler's default, written out explicitly.
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.GradScaler in favour
# of torch.amp.GradScaler("cuda") — consider switching once the minimum torch version allows.
scaler = torch.cuda.amp.GradScaler(enabled=True)

## demo: autocast output dtypes

In [13]:
# Create some tensors in the default dtype (assumed here to be float32).
# c_float32 is not used in this cell; it is kept in case other cells refer to it.
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

# torch.autocast(device_type="cuda") is the device-generic form of
# torch.cuda.amp.autocast(); on CUDA both default to float16.
with torch.autocast(device_type="cuda"):
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    print('in autocast', e_float16.dtype, e_float16.device)
    # Also handles mixed input types: the float32 and float16 inputs are both
    # cast to float16 before the matmul.
    f_float16 = torch.mm(d_float32, e_float16)
    print('in autocast', f_float16.dtype, f_float16.device)

# Outside autocast, torch.mm no longer casts its inputs, so f_float16 has to be
# cast back to float32 explicitly before it can be multiplied with d_float32.
g_float32 = torch.mm(d_float32, f_float16.float())
print('out autocast', g_float32.dtype, g_float32.device)

in autocast torch.float16 cuda:0
in autocast torch.float16 cuda:0
out autocast torch.float32 cuda:0


## loss scaling

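```python
# forward: eligible ops run in float16 inside autocast
with torch.cuda.amp.autocast():
    logits = model(batch_x)
    loss = criterion(logits, batch_y)

# backward
optimizer.zero_grad()
# loss scaling: scale the loss before backward so small fp16 gradients don't underflow
scaler.scale(loss).backward()
# unscale the gradients and step (the step is skipped if they contain inf/NaN),
# then adjust the scale factor for the next iteration
scaler.step(optimizer)
scaler.update()
```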

- 缩放的对象是 loss（即 loss scaling）：`scaler.scale(loss)` 先把 loss 乘以缩放因子再 backward，使 fp16 中数值很小的梯度不会下溢为 0
- gradients 的两个版本，见下图
    - fp16 gradients
    - fp32 gradients

In [4]:
# Figure for the loss-scaling notes (fp16 vs fp32 gradients), from the Paperspace blog.
# With url= and no embed=True, IPython links the image rather than embedding it,
# so the rendered notebook needs network access to show it.
LOSS_SCALING_FIG_URL = 'https://blog.paperspace.com/content/images/2022/05/image-16.png'
Image(url=LOSS_SCALING_FIG_URL, width=400)