In [None]:
import torch
import torch.nn as nn

batch_size, hidden_size =4, 8

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.z1 = None
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w2 = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        print(z1)
        self.z1 = z1.clone().detach()
        z2 = self.w2(z1)
        return z2

In [None]:
model = Net()

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params / 1e9:.1f}G\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))

In [2]:
from torch.optim import SGD

fp32_model= Net().to("cuda")
lr = 1e-0
optimizer = SGD(fp32_model.parameters(), lr=lr)
# print(lr)  #1.0

In [3]:
fp32_model.w1.weight

Parameter containing:
tensor([[-0.1316,  0.0437,  0.3298, -0.1907,  0.2609,  0.2583,  0.1089,  0.3062],
        [-0.0842,  0.2763, -0.3430, -0.1335, -0.0554,  0.2907,  0.1734,  0.0947],
        [ 0.0371, -0.0304,  0.1951, -0.3194, -0.3333, -0.0392, -0.0669, -0.1214],
        [-0.3199,  0.2473, -0.0599,  0.0925, -0.2743,  0.1447,  0.1477, -0.0938],
        [ 0.0281,  0.3230,  0.2939,  0.2553, -0.0854,  0.2849,  0.2011, -0.2587],
        [-0.1827, -0.0532, -0.0823,  0.2022,  0.0876,  0.1245,  0.3236,  0.2002],
        [ 0.1658,  0.2610,  0.0143, -0.1296,  0.1480,  0.2264,  0.0457,  0.3042],
        [-0.0736,  0.0768, -0.2696, -0.2455, -0.2354,  0.1226, -0.2516,  0.0230]],
       device='cuda:0', requires_grad=True)

In [13]:
fp32_model.w2.weight

Parameter containing:
tensor([[-4.8533, -1.6335, -3.1606, -2.5019, -2.7972,  1.6775, -2.6429, -2.4407]],
       device='cuda:0', requires_grad=True)

In [5]:
import torch

# example input sizes
#batch_size, hidden_size =4, 8

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.float, device="cuda") 

# do forward
z2 = fp32_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"

tensor([[-0.6157,  1.1022, -0.7334,  0.6155,  0.1529,  0.6115,  0.2666,  0.0737],
        [-0.8183, -0.5411, -0.4788, -0.5367, -0.6478,  0.1682, -0.5512, -0.3209],
        [ 0.6724, -0.3931,  0.2992,  0.0016,  0.3233,  0.2226,  0.0595, -0.6364],
        [-0.6500, -0.2179, -0.4946, -0.4452,  0.0377,  0.3786, -0.2514, -0.6484]],
       device='cuda:0', grad_fn=<MmBackward0>)


'logits type = torch.float32'

In [6]:
# craete dummy data (bsz=4)
#y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda") #batch_size =4
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.float32, device="cuda") #batch_size =4
#y = torch.tensor([[1.9]], dtype=torch.float32, device="cuda")
#y = torch.tensor([[1.9], [0.5]], dtype=torch.float32, device="cuda")
# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

'loss type = torch.float32'

In [7]:
print(L)
print(z2)
print(y)
loss = torch.sum((z2-y)**2/batch_size)
print(loss)

tensor(28.0975, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor([[-0.0857],
        [-0.7735],
        [ 0.3088],
        [-0.3974]], device='cuda:0', grad_fn=<MmBackward0>)
tensor([[1.9000],
        [9.5000],
        [0.9000],
        [1.2000]], device='cuda:0')
tensor(28.0975, device='cuda:0', grad_fn=<SumBackward0>)


In [8]:
L.backward()
w2_weight = fp32_model.w2.weight.clone().detach()
w1_weight = fp32_model.w1.weight.clone().detach()
print(f'before: {fp32_model.w2.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w2.weight}\n')

before: Parameter containing:
tensor([[ 0.2819,  0.3416,  0.3337, -0.0009,  0.2527, -0.1617,  0.1073, -0.1597]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[-4.8533, -1.6335, -3.1606, -2.5019, -2.7972,  1.6775, -2.6429, -2.4407]],
       device='cuda:0', requires_grad=True)



In [9]:
DL_Dz2= 2 * (z2 - y) / batch_size # DL/Dz2  (BWD-activation: layer2), in case of MSE
print(DL_Dz2) # [4,1] [output_size=1, batch_size] DL/Dz2
Dz2_Dw2 = fp32_model.z1.clone().detach() #Dz2/Dw2
print(Dz2_Dw2.shape) #[4,8] [batch_size, hidden_size] 
print(Dz2_Dw2)
#DL_Dw2 = DL_Dz2.T * Dz2_Dw2
DL_Dw2 = torch.matmul(DL_Dz2.T, Dz2_Dw2) #[1,4] * [4,8] batch_size, hidden_size
print(DL_Dw2.shape) # [1,8] [output_size=1, hidden_size]
print(DL_Dw2)
print(w2_weight) # [hidden_size=8 , output_size=1]
print(w2_weight - lr * DL_Dw2)

tensor([[-0.9929],
        [-5.1368],
        [-0.2956],
        [-0.7987]], device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([4, 8])
tensor([[-0.6157,  1.1022, -0.7334,  0.6155,  0.1529,  0.6115,  0.2666,  0.0737],
        [-0.8183, -0.5411, -0.4788, -0.5367, -0.6478,  0.1682, -0.5512, -0.3209],
        [ 0.6724, -0.3931,  0.2992,  0.0016,  0.3233,  0.2226,  0.0595, -0.6364],
        [-0.6500, -0.2179, -0.4946, -0.4452,  0.0377,  0.3786, -0.2514, -0.6484]],
       device='cuda:0')
torch.Size([1, 8])
tensor([[ 5.1351,  1.9751,  3.4944,  2.5010,  3.0499, -1.8392,  2.7502,  2.2809]],
       device='cuda:0', grad_fn=<MmBackward0>)
tensor([[ 0.2819,  0.3416,  0.3337, -0.0009,  0.2527, -0.1617,  0.1073, -0.1597]],
       device='cuda:0')
tensor([[-4.8533, -1.6335, -3.1606, -2.5019, -2.7972,  1.6775, -2.6429, -2.4407]],
       device='cuda:0', grad_fn=<SubBackward0>)


In [10]:
fp32_model.w1.weight # w1 = [hidden_size, hidden_size] [8,8]

Parameter containing:
tensor([[ 1.1200, -1.7049, -1.4670,  1.9101,  0.4296, -1.2090,  0.7812,  0.4647],
        [ 1.4326, -1.8428, -2.5207,  2.4125,  0.1490, -1.4875,  0.9882,  0.2869],
        [ 1.5189, -2.1006, -1.9323,  2.1679, -0.1337, -1.7764,  0.7291,  0.0664],
        [-0.3240,  0.2529, -0.0541,  0.0857, -0.2748,  0.1494,  0.1455, -0.0943],
        [ 1.1499, -1.2443, -1.3167,  2.1383,  0.0658, -1.0303,  0.8037, -0.1166],
        [-0.9005,  0.9496,  0.9482, -1.0026, -0.0091,  0.9660, -0.0619,  0.1092],
        [ 0.6421, -0.4044, -0.6695,  0.6699,  0.2122, -0.3320,  0.3015,  0.3645],
        [-0.7828,  1.0677,  0.7487, -1.4360, -0.3309,  0.9541, -0.6326, -0.0668]],
       device='cuda:0', requires_grad=True)

In [11]:
DL_Dz2= 2 * (z2 - y) / batch_size # DL/Dz2  (BWD-activation: layer2), in case of MSE
print(DL_Dz2) # [4,1] [batch_size=4, output_size=1]
print(w2_weight.shape) # [1,8]
temp = torch.matmul(DL_Dz2, w2_weight) #DL/Dz2 * w2
print(temp.shape) # [4,8]
print(x.shape) # [4,8]
DL_Dw1 = torch.matmul(temp.T, x) # [8,4] * [4,8] = [8,8]
print(DL_Dw1.shape) #[8,8]
print(w1_weight - lr * DL_Dw1)

tensor([[-0.9929],
        [-5.1368],
        [-0.2956],
        [-0.7987]], device='cuda:0', grad_fn=<DivBackward0>)
torch.Size([1, 8])
torch.Size([4, 8])
torch.Size([4, 8])
torch.Size([8, 8])
tensor([[ 1.1200, -1.7049, -1.4670,  1.9101,  0.4296, -1.2090,  0.7812,  0.4647],
        [ 1.4326, -1.8428, -2.5207,  2.4125,  0.1490, -1.4875,  0.9882,  0.2869],
        [ 1.5189, -2.1006, -1.9323,  2.1679, -0.1337, -1.7764,  0.7291,  0.0664],
        [-0.3240,  0.2529, -0.0541,  0.0857, -0.2748,  0.1494,  0.1455, -0.0943],
        [ 1.1499, -1.2443, -1.3167,  2.1383,  0.0658, -1.0303,  0.8037, -0.1166],
        [-0.9005,  0.9496,  0.9482, -1.0026, -0.0091,  0.9660, -0.0619,  0.1092],
        [ 0.6421, -0.4044, -0.6695,  0.6699,  0.2122, -0.3320,  0.3015,  0.3645],
        [-0.7828,  1.0677,  0.7487, -1.4360, -0.3309,  0.9541, -0.6326, -0.0668]],
       device='cuda:0', grad_fn=<SubBackward0>)


In [1]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(512, 512, bias=False)
        self.w2 = nn.Linear(512, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        z2 = self.w2(z1)
        return z2

from torch.optim import SGD

fp32_model= Net().to("cuda")
optimizer = SGD(fp32_model.parameters(), lr=1e-2)


### Float2Half

In [2]:
fp16_model = Net().half().to("cuda")
fp16_model.load_state_dict(fp32_model.state_dict())

<All keys matched successfully>

### Forward

In [3]:
import torch

# example input sizes
batch_size, hidden_size = 4, 512

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.half, device="cuda") 

# do forward
z2 = fp16_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"

'logits type = torch.float16'

In [4]:
# craete dummy data (bsz=4)
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda")

# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

'loss type = torch.float16'

### Backward

In [5]:
# loss scaling
L *= 1024

# do backward
L.backward()

### Update Weight

In [6]:
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

before: Parameter containing:
tensor([[ 0.0014, -0.0054, -0.0113,  ...,  0.0319, -0.0107, -0.0092],
        [-0.0171,  0.0104, -0.0103,  ..., -0.0259, -0.0431, -0.0075],
        [-0.0423, -0.0418,  0.0213,  ..., -0.0251,  0.0348,  0.0121],
        ...,
        [ 0.0361, -0.0298, -0.0226,  ..., -0.0069, -0.0387,  0.0304],
        [ 0.0351,  0.0207,  0.0058,  ..., -0.0041, -0.0299,  0.0108],
        [ 0.0368, -0.0269,  0.0004,  ..., -0.0361, -0.0273, -0.0195]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[ 0.0014, -0.0054, -0.0113,  ...,  0.0319, -0.0107, -0.0092],
        [-0.0171,  0.0104, -0.0103,  ..., -0.0259, -0.0431, -0.0075],
        [-0.0423, -0.0418,  0.0213,  ..., -0.0251,  0.0348,  0.0121],
        ...,
        [ 0.0361, -0.0298, -0.0226,  ..., -0.0069, -0.0387,  0.0304],
        [ 0.0351,  0.0207,  0.0058,  ..., -0.0041, -0.0299,  0.0108],
        [ 0.0368, -0.0269,  0.0004,  ..., -0.0361, -0.0273, -0.0195]],
       device='cuda:0', requ

In [7]:
#print(f'before: {fp16_model.w1.weight}\n')
#optimizer.step()
#print(f'after: {fp16_model.w1.weight}\n')

In [8]:
# copy gradient to FP32 model
fp32_model.w1.weight.grad = fp16_model.w1.weight.grad.float()
fp32_model.w2.weight.grad = fp16_model.w2.weight.grad.float()

In [9]:
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

before: Parameter containing:
tensor([[ 0.0014, -0.0054, -0.0113,  ...,  0.0319, -0.0107, -0.0092],
        [-0.0171,  0.0104, -0.0103,  ..., -0.0259, -0.0431, -0.0075],
        [-0.0423, -0.0418,  0.0213,  ..., -0.0251,  0.0348,  0.0121],
        ...,
        [ 0.0361, -0.0298, -0.0226,  ..., -0.0069, -0.0387,  0.0304],
        [ 0.0351,  0.0207,  0.0058,  ..., -0.0041, -0.0299,  0.0108],
        [ 0.0368, -0.0269,  0.0004,  ..., -0.0361, -0.0273, -0.0195]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[-0.1877,  0.5187, -0.0540,  ...,  0.0834, -0.7445,  0.2320],
        [ 0.4882, -1.3909,  0.1039,  ..., -0.1635,  1.9181, -0.6525],
        [-0.2927,  0.6520, -0.0352,  ...,  0.0430, -0.9364,  0.3314],
        ...,
        [-1.0620,  3.0127, -0.2704,  ...,  0.2918, -4.3012,  1.4316],
        [ 0.2854, -0.6731,  0.0623,  ..., -0.0722,  0.9413, -0.3086],
        [ 0.8568, -2.3006,  0.1856,  ..., -0.2593,  3.1552, -1.0664]],
       device='cuda:0', requ

In [15]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(512, 512, bias=False)
        self.w2 = nn.Linear(512, 1, bias=False)
    
    def forward(self, x):
        z1 = self.w1(x)
        z2 = self.w2(z1)
        return z2

from torch.optim import SGD

fp32_model= Net().to("cuda")
optimizer = SGD(fp32_model.parameters(), lr=1e-2)
#optimizer = SGD(fp32_model.parameters(), lr=1e-0)

### Float2Half
fp16_model = Net().half().to("cuda")
fp16_model.load_state_dict(fp32_model.state_dict())

### Forward
import torch

# example input sizes
batch_size, hidden_size = 4, 512

# create dummy data (bsz=4, hid=256)
x = torch.randn(batch_size,hidden_size, dtype=torch.half, device="cuda") 

# do forward
z2 = fp16_model(x)

# check dtypr of output logits
f"logits type = {z2.dtype}"


# craete dummy data (bsz=4)
y = torch.tensor([[1.9], [9.5], [0.9], [1.2]], dtype=torch.half, device="cuda")

# compute mean square error loss
L = torch.nn.functional.mse_loss(z2, y)

# check dtype of loss
f"loss type = {L.dtype}"

### Backward
# loss scaling
#L *= 1024

# do backward
L.backward()

print(f'fp32 grad: {fp32_model.w1.weight.grad}\n')
### Update Weight
print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')
print(f'fp32 grad: {fp32_model.w1.weight.grad}\n')


print(f'f16 grad: {fp16_model.w1.weight.grad}\n')

# copy gradient to FP32 model
fp32_model.w1.weight.grad = fp16_model.w1.weight.grad.float()
fp32_model.w2.weight.grad = fp16_model.w2.weight.grad.float()

print(f'before: {fp32_model.w1.weight}\n')
optimizer.step()
print(f'after: {fp32_model.w1.weight}\n')

"""
print(f'before: {fp16_model.w1.weight}\n')
print(fp16_model.w1.weight.grad)
optimizer.step()
print(fp16_model.w1.weight.grad)
print(f'after: {fp16_model.w1.weight}\n')
"""

fp32 grad: None

before: Parameter containing:
tensor([[-0.0324,  0.0434, -0.0429,  ...,  0.0384, -0.0291, -0.0340],
        [ 0.0010, -0.0153,  0.0097,  ...,  0.0374, -0.0176, -0.0270],
        [-0.0366, -0.0296,  0.0203,  ..., -0.0179, -0.0196, -0.0239],
        ...,
        [-0.0436, -0.0131,  0.0017,  ..., -0.0360,  0.0416, -0.0143],
        [-0.0170,  0.0044, -0.0296,  ..., -0.0230,  0.0356, -0.0095],
        [ 0.0061,  0.0173,  0.0191,  ..., -0.0371,  0.0421,  0.0111]],
       device='cuda:0', requires_grad=True)

after: Parameter containing:
tensor([[-0.0324,  0.0434, -0.0429,  ...,  0.0384, -0.0291, -0.0340],
        [ 0.0010, -0.0153,  0.0097,  ...,  0.0374, -0.0176, -0.0270],
        [-0.0366, -0.0296,  0.0203,  ..., -0.0179, -0.0196, -0.0239],
        ...,
        [-0.0436, -0.0131,  0.0017,  ..., -0.0360,  0.0416, -0.0143],
        [-0.0170,  0.0044, -0.0296,  ..., -0.0230,  0.0356, -0.0095],
        [ 0.0061,  0.0173,  0.0191,  ..., -0.0371,  0.0421,  0.0111]],
       devi

"\nprint(f'before: {fp16_model.w1.weight}\n')\nprint(fp16_model.w1.weight.grad)\noptimizer.step()\nprint(fp16_model.w1.weight.grad)\nprint(f'after: {fp16_model.w1.weight}\n')\n"