# Automatic Mixed Precision (AMP)

## 1. What AMP Does

AMP enables **mixed-precision training** by automatically choosing lower precision (`float16` or `bfloat16`) where it is safe, and full precision (`float32`) where necessary.
This reduces memory use and often increases training speed while preserving model accuracy.
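
As a rough illustration of the memory claim, the sketch below compares the storage needed for the same tensor in each dtype. The tensor shape and the helper name `nbytes` are arbitrary choices for this example. Note that autocast leaves model parameters in `float32`; the savings in practice come mainly from activations and intermediate results.

```python
import torch

# The same 1024 x 1024 tensor stored in three dtypes
x32 = torch.zeros(1024, 1024, dtype=torch.float32)
x16 = x32.to(torch.float16)
xbf = x32.to(torch.bfloat16)


def nbytes(t: torch.Tensor) -> int:
    """Storage size of the tensor's elements in bytes."""
    return t.numel() * t.element_size()


sizes = {
    "float32": nbytes(x32),
    "float16": nbytes(x16),
    "bfloat16": nbytes(xbf),
}
print(sizes)  # both 16-bit dtypes need half the bytes of float32
```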

---

## 2. What `torch.amp.autocast` Does

`autocast` intercepts operations during the forward pass and:

* **casts them to lower precision** when that is numerically safe (e.g., matrix multiplications, convolutions)
* **keeps them in float32** when precision is important (e.g., softmax, layer norm, most loss functions)

This gives the benefits of mixed precision **without manually rewriting operations**.
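
A minimal sketch of this behavior, assuming CPU autocast with `bfloat16` so that it runs without a GPU (on CUDA the same pattern applies with `device_type='cuda'`). The layer sizes are arbitrary. The point is that parameters and inputs stay in `float32`, while the output of an autocast-eligible op such as a linear layer comes back in the lower-precision dtype.

```python
import torch

layer = torch.nn.Linear(8, 4)   # parameters are created in float32
x = torch.randn(2, 8)           # float32 input

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)              # linear is run in the autocast dtype

print(layer.weight.dtype)  # parameters are not converted
print(x.dtype)             # the input tensor itself is not modified
print(out.dtype)           # output of the autocast-eligible op
```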

---

## 3. Basic Syntax

```python
import torch
from torch.amp import autocast

with autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
    loss = criterion(output, target)
```

or, to use the device's default autocast dtype (`float16` on CUDA, `bfloat16` on CPU):

```python
with autocast('cuda'):
    ...
```

---

## 4. When to Use Autocast

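Wrap **only the forward pass** in `autocast`:

* `model(input)`
* loss computation

Run the backward pass **outside** the autocast block. Backward ops automatically run in the dtype that autocast chose for the corresponding forward ops.

When using float16, scale the loss with **GradScaler** before calling `backward()` so that small gradients do not underflow.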

---

## 5. Dynamic AMP Training Loop

The loop below picks the autocast dtype at runtime and uses `GradScaler` only when that dtype is float16.

```python
import torch
from torch.amp import autocast, GradScaler
from torch.utils.data import TensorDataset, DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pick the autocast dtype:
# - CUDA GPUs with bf16 support (Ampere or newer, e.g. A100/H100/RTX 40xx): bfloat16
# - older CUDA GPUs: float16
# - CPU: bfloat16
if device == "cpu" or torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
    print("dtype is bf16")
else:
    dtype = torch.float16
    print("dtype is f16")

# Simple model
model = torch.nn.Linear(10, 2).to(device)

# Dummy dataset
X = torch.randn(40, 10)
Y = torch.randint(0, 2, (40,))     # CrossEntropyLoss expects integer labels

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# GradScaler is needed only when dtype=float16
use_scaler = (dtype == torch.float16)
scaler = GradScaler() if use_scaler else None

epochs = 10

for epoch in range(epochs):
    model.train()

    for x, y in loader:
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)

        # Forward pass in mixed precision
        with autocast(device_type=device, dtype=dtype):
            output = model(x)
            loss = criterion(output, y)

        # Backward pass
        if use_scaler:                     # float16 only
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:                               # bfloat16 or float32
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch} finished")
```

---

## 6. Why Use `GradScaler` Only for float16?

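* `float16` has a **very limited exponent range** (5 exponent bits): values above 65504 overflow to `inf`, and values below roughly 6e-8 flush to zero, so small gradients underflow easily.
  `GradScaler` multiplies the loss by a large factor before `backward()` so that gradients stay representable, then unscales them before the optimizer step.

* `bfloat16` has the **same exponent range as float32** (8 exponent bits), so gradients almost never underflow. It gives up mantissa precision instead of range.
  Therefore **no scaler is needed**.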

Summary:

| dtype      | Use GradScaler? | Reason                       |
| ---------- | --------------- | ---------------------------- |
| `float16`  | Yes             | narrow dynamic range         |
| `bfloat16` | No              | float32-level exponent range |
| `float32`  | No              | full precision               |
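
The dynamic-range argument above can be checked directly. The sketch below is illustrative only: the value `1e-8` stands in for a small gradient, and `2**16` is used as the scale because it is `GradScaler`'s default initial scale.

```python
import torch

tiny_grad = torch.tensor(1e-8)   # a small float32 "gradient"
scale = 2.0 ** 16                # GradScaler's default initial scale

as_fp16 = tiny_grad.to(torch.float16)    # below float16's smallest subnormal
as_bf16 = tiny_grad.to(torch.bfloat16)   # within bfloat16's exponent range

# Scale first, then cast: the value is now representable in float16
scaled_fp16 = (tiny_grad * scale).to(torch.float16)
recovered = scaled_fp16.float() / scale  # unscale in float32

too_big = torch.tensor(70000.0).to(torch.float16)  # above float16 max (65504)

print(as_fp16.item(), as_bf16.item(), recovered.item(), too_big.item())
```

Without scaling, the float16 copy of the gradient is exactly zero; with scaling, it survives the cast and can be recovered to within float16's rounding error.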

---

## 7. Explanation of the Key Steps

| Step  | What Happens                                     | Why It Matters                                                                 |
| ----- | ------------------------------------------------ | ------------------------------------------------------------------------------ |
| **1** | `with autocast(..., dtype=dtype)`                | Enables mixed precision only for forward/loss.                                 |
| **2** | `scaler.scale(loss).backward()` (float16 only)   | Scales the loss so that small gradients do not underflow to zero.              |
| **3** | `scaler.step(optimizer)`                         | Unscales gradients, then steps; skips the step if any gradient is inf/NaN.     |
| **4** | `scaler.update()`                                | Lowers the scale after a skipped step, raises it after a run of good steps.    |
| **5** | In bf16 mode: normal backward                    | bf16 does not need scaling.                                                    |
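
To make steps 3 and 4 more concrete, here is a simplified pure-Python model of dynamic loss scaling. It is a sketch of the idea, not PyTorch's implementation: the class name `ToyLossScale` is hypothetical, and the default values are assumed to mirror `GradScaler`'s documented defaults (`init_scale=65536`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
class ToyLossScale:
    """Simplified model of dynamic loss scaling (not PyTorch's actual code)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step is applied."""
        if found_inf:
            # Overflow: shrink the scale and skip this optimizer step
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Long run of clean steps: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


# Short growth interval so the behavior is visible in a few steps
toy = ToyLossScale(growth_interval=3)
history = []
for found_inf in [False, False, False, True, False]:
    applied = toy.update(found_inf)
    history.append((applied, toy.scale))

print(history)
```

In this run the scale doubles after three clean steps, halves again when an overflow is reported, and the step with the overflow is skipped.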

---

## 8. Best Practices

1. Use `bfloat16` whenever available on your GPU.
2. Use `GradScaler` only when training in float16.
3. Never wrap the backward pass inside autocast.
4. For specific layers that must run in float32, disable autocast manually:

   ```python
   with autocast(device_type='cuda', enabled=False):
       # Inputs may still be float16/bfloat16 here, so cast them explicitly
       output = numerically_sensitive_layer(x.float())
   ```
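
The effect of a disabled region can be observed with a small sketch. CPU autocast with `bfloat16` is assumed here so that it runs without a GPU, and `torch.mm` stands in for the numerically sensitive layer. A tensor produced in the outer autocast region keeps its low-precision dtype, so it has to be cast back to `float32` before use inside the disabled region.

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = torch.mm(a, b)  # runs in bfloat16

    with torch.autocast(device_type="cpu", enabled=False):
        # h is still bfloat16 here; cast it so the op runs in float32
        out = torch.mm(h.float(), b)

print(h.dtype, out.dtype)
```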

---

## 9. Inference Example (No Scaler Needed)

```python
model.eval()
with torch.no_grad(), autocast('cuda', dtype=dtype):
    output = model(x)
```

---

## 10. CPU Autocast (Mostly for bf16)

```python
with autocast(device_type='cpu', dtype=torch.bfloat16):
    output = model(x)
```

---