Backward process of max operation

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(3, 4)
        self.layer2 = nn.Linear(4, 4)
        self.layer3 = nn.Linear(4, 2)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return self.layer3(x)

In [3]:
policy_net = DQN().to(device)
policy_net.eval()

DQN(
  (layer1): Linear(in_features=3, out_features=4, bias=True)
  (layer2): Linear(in_features=4, out_features=4, bias=True)
  (layer3): Linear(in_features=4, out_features=2, bias=True)
)

In [4]:
print(summary(policy_net, (1,3)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1, 4]              16
            Linear-2                 [-1, 1, 4]              20
            Linear-3                 [-1, 1, 2]              10
Total params: 46
Trainable params: 46
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
None


In [5]:
optimizer = optim.AdamW(policy_net.parameters(), lr=0.1, amsgrad=True)

In [7]:
input_tensor = torch.tensor([0.1,0.2,0.3])
output = policy_net(input_tensor)
y_hat = policy_net(input_tensor).max()

y_gt = torch.tensor(0.5)

criterion = nn.MSELoss()
loss = criterion(y_gt, y_hat)

print("before")
for param in policy_net.parameters():
  print(param.data)

loss.backward()
optimizer.step()

print("after")
for param in policy_net.parameters():
  print(param.data)


print("backward gradient")
print("layer1")
print(policy_net.layer1.weight.grad) 
print(policy_net.layer1.bias.grad)
print("layer2")
print(policy_net.layer2.weight.grad) 
print(policy_net.layer2.bias.grad)
print("layer3")
print(policy_net.layer3.weight.grad) 
print(policy_net.layer3.bias.grad)

before
tensor([[ 0.3824,  0.4544,  0.3511],
        [-0.3611, -0.3277,  0.5049],
        [ 0.0302, -0.0337,  0.4506],
        [-0.2599,  0.3161,  0.4811]])
tensor([ 0.2709, -0.3419,  0.1466, -0.1668])
tensor([[-0.1680,  0.1856, -0.1609,  0.2015],
        [ 0.1425,  0.1235,  0.2672, -0.3600],
        [-0.2259, -0.0553,  0.1915, -0.0637],
        [-0.4439, -0.2527,  0.1881,  0.2743]])
tensor([ 0.0328, -0.0654, -0.2330, -0.4830])
tensor([[ 0.1031, -0.2284,  0.1467, -0.3482],
        [ 0.4068,  0.3217,  0.0539, -0.2354]])
tensor([ 0.3963, -0.2409])
after
tensor([[ 0.2820,  0.3539,  0.2507],
        [-0.2609, -0.2274,  0.6043],
        [-0.0698, -0.1335,  0.3502],
        [-0.1596,  0.4158,  0.5806]])
tensor([ 0.1706, -0.2416,  0.0466, -0.0667])
tensor([[-0.0679,  0.0854, -0.0608,  0.1013],
        [ 0.0425,  0.2233,  0.1670, -0.2597],
        [-0.1258, -0.1552,  0.2912, -0.1637],
        [-0.5432, -0.1527,  0.0881,  0.3740]])
tensor([ 0.1326, -0.1652, -0.1329, -0.5821])
tensor([[ 0.0031, -