In [6]:
import torch
import torch.nn as nn
from contextlib import contextmanager
import torch, functools
import torch.nn.functional as F


In [3]:
class ToyModel(nn.Module):
    """Minimal two-layer network used to observe per-layer dtypes.

    Data flow: bias-free ``fc1`` -> ReLU -> LayerNorm -> bias-free ``fc2``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        hidden = 10  # fixed width of the single hidden layer
        # Registration order (fc1, ln, fc2, relu) is the order in which
        # named_modules() / state_dict() enumerate the submodules.
        self.fc1 = nn.Linear(in_features, hidden, bias=False)
        self.ln = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Project ``x`` to the hidden width, activate, normalize, project out."""
        activated = self.relu(self.fc1(x))
        return self.fc2(self.ln(activated))

In [4]:
def _output_dtypes(out):
    """Return the dtype(s) carried by a module's forward output.

    * a bare tensor -> its ``dtype``
    * a tuple/list  -> list of dtypes of its tensor members
    * a dict        -> list of dtypes of its tensor values
    * anything else (``None``, scalars, ...) -> empty list
    """
    if torch.is_tensor(out):
        return out.dtype
    if isinstance(out, dict):
        out = out.values()
    elif not isinstance(out, (tuple, list)):
        # Non-container, non-tensor outputs have no dtype to report; iterating
        # them blindly would raise TypeError in the middle of the forward pass.
        return []
    return [t.dtype for t in out if torch.is_tensor(t)]


def add_dtype_hooks(model):
    """Attach a hook to every leaf module that prints input/output dtypes.

    A leaf is a module without child modules. After each forward call of a
    leaf, one line is printed with the module's qualified name, the dtypes of
    its positional tensor inputs and the dtype(s) of its output.

    Args:
        model: the ``nn.Module`` whose leaf modules should be instrumented.

    Returns:
        List of hook handles, one per leaf; call ``.remove()`` on each to
        detach the hooks again.
    """
    hooks = []
    for name, mod in model.named_modules():
        # Skip containers: only modules without children are instrumented.
        if next(mod.children(), None) is not None:
            continue

        # `n=name` binds the current name at definition time; a plain closure
        # would see only the last value of the loop variable.
        def _hook(mod, inp, out, n=name):
            in_types = [t.dtype for t in inp if torch.is_tensor(t)]
            out_types = _output_dtypes(out)
            print(f"{n:>30}  in={in_types}  out={out_types}")

        hooks.append(mod.register_forward_hook(_hook))
    return hooks

In [9]:
# Demo: run one forward/backward pass under autocast and report the dtype seen
# at each leaf module, of the logits, of the loss, and of a parameter gradient.
in_features = 5
out_features = 4

# Pick an available backend instead of hard-coding one, so the cell also runs
# on machines without MPS.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# float16 on GPU backends; bfloat16 is autocast's default low-precision dtype on CPU.
dtype = torch.bfloat16 if device == "cpu" else torch.float16

torch.manual_seed(0)  # reproducible weights and inputs across re-runs
model = ToyModel(in_features, out_features).to(device)
x = torch.randn(1, in_features, device=device)
# cross_entropy expects class indices (or a valid probability distribution);
# raw randn values are neither, so sample one class index per example.
targets = torch.randint(out_features, (1,), device=device)

hooks = add_dtype_hooks(model)
try:
    with torch.autocast(device_type=device, dtype=dtype):
        logits = model(x)
        loss = F.cross_entropy(logits, targets)
finally:
    # Detach the printing hooks so later forward passes are not instrumented.
    for hook in hooks:
        hook.remove()

print("logits:", logits.dtype)
print("loss  :", loss.dtype)
loss.backward()
print("grad fc1:", model.fc1.weight.grad.dtype)

                           fc1  in=[torch.float32]  out=torch.float16
                          relu  in=[torch.float16]  out=torch.float16
                            ln  in=[torch.float16]  out=torch.float32
                           fc2  in=[torch.float32]  out=torch.float16
logits: torch.float16
loss  : torch.float32
grad fc1: torch.float32
