In [6]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
device

device(type='cuda')

In [8]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(10, 16*1024)
        self.fc2 = nn.Linear(16*1024, 16*1024)
        self.fc3 = nn.Linear(16* 1024, 1)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))
    

In [9]:
model = SimpleModel().to(device)

In [None]:
master_params = [param.data.clone().float() for param in model.parameters()]

In [None]:
model.half()

In [None]:
class MasterParams(nn.Module):
    def __init__(self, master_params):
        super().__init__()

        for i, param in enumerate(master_params):
            self.register_parameter(f"param_{i}", nn.Parameter(param))



In [None]:
master_model = MasterParams(master_params)


In [None]:
optimizer = torch.optim.SGD(master_model.parameters(), lr=1e-3)

In [10]:
inputs = torch.randn(1024+512*3, 10, device=device)
targets = torch.randn(1024+512*3, 1, device=device)

In [None]:
S = 128.0

In [None]:
def train():
    for epoch in range(10):
    
        # copy weights from master model into fp16, with half precision
        for p_master, p_model in zip(master_model.parameters(), model.parameters()):
            p_model.data.copy_(p_master.data.float())
    
        optimizer.zero_grad()
    
        # fp16 forward
        outputs = model(inputs.half())
        loss = nn.MSELoss()(outputs, targets.half())
        scaled_loss = S * loss 
    
        scaled_loss.backward()
    
        # unscale grads
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.div_(S)
    
        for p_master, p_model in zip(master_model.parameters(), model.parameters()):
            if p_model.grad is not None:
                p_master.grad = p_model.grad.float()
        optimizer.step() # on unscaled grads
    
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    
    for p_master, p_model in zip(master_model.parameters(), model.parameters()):
        p_model.data.copy_(p_master.data.float())
    

In [None]:
%timeit -n 1 -r 2 train()

Epoch 1, Loss: 0.994140625
Epoch 2, Loss: 0.9755859375
Epoch 3, Loss: 0.974609375
Epoch 4, Loss: 0.974609375
Epoch 5, Loss: 0.974609375
Epoch 6, Loss: 0.974609375
Epoch 7, Loss: 0.974609375
Epoch 8, Loss: 0.974609375
Epoch 9, Loss: 0.974609375
Epoch 10, Loss: 0.974609375
Epoch 1, Loss: 0.974609375
Epoch 2, Loss: 0.974609375
Epoch 3, Loss: 0.974609375
Epoch 4, Loss: 0.974609375
Epoch 5, Loss: 0.974609375
Epoch 6, Loss: 0.974609375
Epoch 7, Loss: 0.974609375
Epoch 8, Loss: 0.974609375
Epoch 9, Loss: 0.974609375
Epoch 10, Loss: 0.974609375


2.12 s ± 1.61 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)


In [11]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [12]:
def train_full():
    for epoch in range(10):
    
        # # copy weights from master model into fp16, with half precision
        # for p_master, p_model in zip(master_model.parameters(), model.parameters()):
        #     p_model.data.copy_(p_master.data.float())
    
        optimizer.zero_grad()
    
        # fp16 forward
        outputs = model(inputs)
        loss = nn.MSELoss()(outputs, targets)
        # scaled_loss = S * loss 
    
        # scaled_loss.backward()
        loss.backward()
    
        # unscale grads
        # for p in model.parameters():
        #     if p.grad is not None:
        #         p.grad.data.div_(S)
    
        # for p_master, p_model in zip(master_model.parameters(), model.parameters()):
        #     if p_model.grad is not None:
        #         p_master.grad = p_model.grad.data.float()
        optimizer.step() # on unscaled grads
    
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    
    # for p_master, p_model in zip(master_model.parameters(), model.parameters()):
    #     p_model.data.copy_(p_master.data.float())
    

In [13]:
%timeit -n 1 -r 2 train_full()

Epoch 1, Loss: 1.0338482856750488
Epoch 2, Loss: 0.9935110211372375
Epoch 3, Loss: 0.9907287955284119
Epoch 4, Loss: 0.9905089735984802
Epoch 5, Loss: 0.9904901385307312
Epoch 6, Loss: 0.9904884696006775
Epoch 7, Loss: 0.9904883503913879
Epoch 8, Loss: 0.9904882311820984
Epoch 9, Loss: 0.9904882311820984
Epoch 10, Loss: 0.9904882311820984
Epoch 1, Loss: 0.9904883503913879
Epoch 2, Loss: 0.9904883503913879
Epoch 3, Loss: 0.9904883503913879
Epoch 4, Loss: 0.9904883503913879
Epoch 5, Loss: 0.9904883503913879
Epoch 6, Loss: 0.9904883503913879
Epoch 7, Loss: 0.9904882311820984
Epoch 8, Loss: 0.9904882311820984
Epoch 9, Loss: 0.9904882311820984
Epoch 10, Loss: 0.9904882311820984
11.2 s ± 116 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)


Epoch 1, Loss: 1.0338482856750488
Epoch 2, Loss: 0.9935110211372375
Epoch 3, Loss: 0.9907287955284119
Epoch 4, Loss: 0.9905089735984802
Epoch 5, Loss: 0.9904901385307312
Epoch 6, Loss: 0.9904884696006775
Epoch 7, Loss: 0.9904883503913879
Epoch 8, Loss: 0.9904882311820984
Epoch 9, Loss: 0.9904882311820984
Epoch 10, Loss: 0.9904882311820984
Epoch 1, Loss: 0.9904883503913879
Epoch 2, Loss: 0.9904883503913879
Epoch 3, Loss: 0.9904883503913879
Epoch 4, Loss: 0.9904883503913879
Epoch 5, Loss: 0.9904883503913879
Epoch 6, Loss: 0.9904883503913879
Epoch 7, Loss: 0.9904882311820984
Epoch 8, Loss: 0.9904882311820984
Epoch 9, Loss: 0.9904882311820984
Epoch 10, Loss: 0.9904882311820984


11.2 s ± 116 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)