# 1. PyTorch native AMP (automatic mixed precision)

In [None]:
"""
PyTorch native AMP (automatic mixed precision):
    source: https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html
    -- automatic mixed precision: some ops run in float16 or bfloat16, while others stay in float32

Explanation:
0. Without AMP/autocast, the default precision is float32 (32 bits).
1. Inside a `torch.autocast` context, AMP automatically casts eligible operations to a lower-precision dtype.
2. AMP is usually paired with a gradient scaler, which helps prevent gradients with small magnitudes from
    flushing to zero ("underflowing") when training with mixed precision.
"""

import torch

use_amp = True
# Single source of truth for the device type: torch.autocast requires a str
# `device_type`, and the GradScaler should be built for the same device.
device = "cuda"
model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Build the scaler for the same device autocast uses below (rather than
# hard-coding it separately). Other scalers, e.g. from timm, can be swapped in.
scaler = torch.amp.GradScaler(device, enabled=use_amp)

# 0 epochs: this loop is for illustration only and never executes.
for epoch in range(0):
    # `inputs` instead of `input` to avoid shadowing the builtin.
    for inputs, target in zip(data, targets):

        # (1) `device_type` must be given to `torch.autocast`; the legacy, CUDA-only
        #     `torch.cuda.amp.autocast` (now deprecated) does not take one.
        # (2) `dtype` is the target low-precision dtype. If it is None (the default),
        #     the device's default autocast dtype is used.
        # (3) `torch.cuda.amp.autocast(...)` is equivalent to `torch.autocast("cuda", ...)`;
        #     the newer `torch.autocast` also works on CPU and other devices.
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            # Output dtype depends on the model's last op: e.g. float16 if it is a
            # linear layer, since linear layers autocast to float16.
            output = model(inputs)
            # Loss dtype depends on the loss op: e.g. float32 for mse_loss, which
            # autocasts to float32.
            loss = loss_fn(output, target)

        # Exit autocast before backward(); backward ops automatically run in the
        # same dtype autocast chose for the corresponding forward ops.
        # Scale the loss and call ``backward()`` on it to create scaled gradients.
        scaler.scale(loss).backward()
        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)
        # Update the scale for the next iteration.
        scaler.update()
        optimizer.zero_grad()

In [None]:
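"""
Another example:
    A nested `with torch.autocast(device_type=..., enabled=False)` does not skip the enclosed operations;
    it only turns autocasting off for them, so they run in their default precision.
"""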

# `device_type` is a required argument of torch.autocast; omitting it raises TypeError.
with torch.autocast(device_type="cuda", enabled=True):
    'op1'   # with autocast
    # A nested enabled=False region does not skip its ops; it only turns
    # autocasting off for them. Floating-point tensors produced in the enclosing
    # enabled region may be float16 here; if mixing them with float32 tensors
    # causes type-mismatch errors, cast them back first (e.g. `.float()`).
    with torch.autocast(device_type="cuda", enabled=False):
        'op2'   # without autocast
    'op3'   # with autocast again once the nested region exits
'op4'   # outside every autocast region: without autocast
