In [3]:
import torch
import torch.nn as nn
import os

In [4]:
class Model(nn.Module):
    """Three-layer MLP: Linear -> GELU -> Linear -> GELU -> Linear.

    Args:
        in_features: Width of the input vectors.
        hidden_features: Width of both hidden layers.
        out_features: Width of the output vectors.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        # Keep the layer widths on the instance for later inspection.
        self.in_features, self.hidden_features, self.out_features = (
            in_features,
            hidden_features,
            out_features,
        )

        # Consecutive width pairs define the linear layers; a GELU follows
        # every linear layer except the final output projection.
        widths = [in_features, hidden_features, hidden_features, out_features]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.GELU())
        stack.pop()  # drop the activation after the last linear layer

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)

In [6]:
def train_without_amp():
    """Train ``Model`` on random data with hand-written mixed precision.

    The model runs forward/backward in fp16 while an fp32 "master" copy of
    the weights is what the optimizer actually updates. A static loss scale
    keeps small fp16 gradients from underflowing; gradients are unscaled in
    fp32 before the optimizer step, and the updated master weights are then
    copied back into the fp16 model.

    Configuration is read from the environment variables ``in_features``,
    ``out_features``, ``hidden_features`` and ``n_epochs`` (all integers).

    Raises:
        KeyError: If one of the environment variables is missing.
        ValueError: If one of them is not a valid integer.
    """
    # Environment values are always strings; tensor sizes and range() need ints.
    in_features = int(os.environ['in_features'])
    out_features = int(os.environ['out_features'])
    hidden_features = int(os.environ['hidden_features'])
    n_epochs = int(os.environ['n_epochs'])

    # Static loss scale, applied before backward and divided out afterwards.
    loss_scale = 8192.0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Synthetic classification data: random inputs, random class indices.
    X = torch.randn((1000, in_features))
    y = torch.randint(0, out_features, (1000,))
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True
    )

    model = Model(in_features, hidden_features, out_features).to(device)
    # Clone so the fp32 master weights own their storage instead of aliasing
    # whatever the model held before it was converted to half precision.
    fp32_master_weights = [
        p.detach().clone().float().requires_grad_(True)
        for p in model.parameters()
    ]
    model = model.half()

    # The labels are class indices, so cross-entropy is the matching loss;
    # MSELoss cannot pair (batch, out_features) outputs with (batch,) labels.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(fp32_master_weights, lr=1e-3)

    model.train()

    for _ in range(n_epochs):
        for batch, label in dataloader:
            optimizer.zero_grad()
            model.zero_grad()

            batch = batch.to(device).half()
            label = label.to(device)  # stays integer: class indices

            output = model(batch)
            # Evaluate the loss in fp32 for numerical stability.
            loss = criterion(output.float(), label)
            (loss * loss_scale).backward()

            # Move fp16 gradients onto the fp32 master weights, unscaling in
            # fp32 so the division itself cannot underflow.
            overflow = False
            for fp16_param, fp32_param in zip(model.parameters(), fp32_master_weights):
                if fp16_param.grad is None:
                    fp32_param.grad = None
                    continue
                grad = fp16_param.grad.detach().float() / loss_scale
                if not torch.isfinite(grad).all():
                    overflow = True
                    break
                fp32_param.grad = grad

            if overflow:
                # inf/NaN gradients (fp16 overflow under the static scale):
                # skip this update rather than corrupt the master weights.
                continue

            optimizer.step()

            # Publish the updated fp32 weights back into the fp16 model;
            # copy_ performs the fp32 -> fp16 cast.
            with torch.no_grad():
                for fp16_param, fp32_param in zip(model.parameters(), fp32_master_weights):
                    fp16_param.copy_(fp32_param)

In [7]:
from torch.amp import GradScaler

In [8]:
def train_with_amp():
    """Train ``Model`` on random data using PyTorch automatic mixed precision.

    The forward pass and loss run under ``torch.amp.autocast`` with float16,
    and a ``GradScaler`` scales the loss before backward, unscales the
    gradients before the optimizer step, and adapts the scale over time.

    Configuration is read from the environment variables ``in_features``,
    ``out_features``, ``hidden_features`` and ``n_epochs`` (all integers).

    Raises:
        KeyError: If one of the environment variables is missing.
        ValueError: If one of them is not a valid integer.
    """
    # Environment values are always strings; tensor sizes and range() need ints.
    in_features = int(os.environ['in_features'])
    out_features = int(os.environ['out_features'])
    hidden_features = int(os.environ['hidden_features'])
    n_epochs = int(os.environ['n_epochs'])

    # autocast and GradScaler must target the device the model lives on.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Synthetic classification data: random inputs, random class indices.
    X = torch.randn((1000, in_features))
    y = torch.randint(0, out_features, (1000,))
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True
    )

    model = Model(in_features, hidden_features, out_features).to(device)

    # The labels are class indices, so cross-entropy is the matching loss;
    # MSELoss cannot pair (batch, out_features) outputs with (batch,) labels.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler(device=device.type)

    model.train()

    for _ in range(n_epochs):
        for batch, label in dataloader:
            batch = batch.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            # device_type is a required argument of torch.amp.autocast.
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
                output = model(batch)
                loss = criterion(output, label)

            # scale() only returns the scaled loss; backward() must be called
            # on it, otherwise no gradients exist when step() runs.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()