In [22]:
# Initial variables
import torch
import torch.nn.functional as F

# Configuration: seed and problem sizes, named so they can be tuned in one place.
SEED = 0
N_SAMPLES = 10   # rows of X and Y
N_FEATURES = 4   # columns of X and Y; W is square so X @ W has Y's shape

# Seed the global torch RNG before any tensor is sampled so the values below
# (and therefore every printed gradient) are reproducible.
torch.manual_seed(SEED)

# Leaf tensors with requires_grad=True: autograd will populate X.grad, W.grad
# and Y.grad when the objective is back-propagated. Draw order (X, W, Y) is
# significant for reproducibility.
X = torch.randn(N_SAMPLES, N_FEATURES, requires_grad=True)
W = torch.randn(N_FEATURES, N_FEATURES, requires_grad=True)
Y = torch.randn(N_SAMPLES, N_FEATURES, requires_grad=True)

In [23]:
# Objective: f = tr(Rᵀ R) with R = relu(X @ W) - Y, i.e. the squared
# Frobenius norm of the residual. R is computed once and reused.
residual = F.relu(X @ W) - Y
f = torch.trace(residual.t() @ residual)

# .backward() *accumulates* into .grad, so clear any stale gradients first;
# otherwise re-running this cell would add a second copy of every gradient.
for leaf in (X, W, Y):
    leaf.grad = None
f.backward()

# Report the autograd gradients (same text as one print per tensor).
print('===================')
for name, leaf in (('W', W), ('X', X), ('Y', Y)):
    print(f'{name} torch gradient: \n{leaf.grad}')
    print('===================')

W torch gradient: 
tensor([[ 18.2980,   2.7573,   2.3914,  -0.1974],
        [ 11.0817,   6.6428,   2.5163, -20.3225],
        [ -8.6662,   3.4506,  -1.8979,  -3.3608],
        [-21.1681,  -6.6739,  -1.0693,  27.0278]])
X torch gradient: 
tensor([[  1.1002,   0.0860,   5.3377,   0.2788],
        [  0.9583,  10.4633, -13.5234, -16.3639],
        [ -0.8712,  -0.9272,  -0.7764,   2.0790],
        [ -1.4504,   5.6914,   0.7613,  -0.9693],
        [ -1.2892,  -3.4714,  -1.9788,   4.8091],
        [ -4.0523,  -4.3127,  -3.6114,   9.6703],
        [ -0.7312,  -0.7782,  -0.6516,   1.7449],
        [ -0.8191,  -0.8718,  -0.7300,   1.9547],
        [  1.0350,   2.9930,  -6.6743,  -7.5333],
        [ -2.4616,  -2.4243,  -2.1164,   5.7128]])
Y torch gradient: 
tensor([[ 2.8885e+00,  4.1639e+00,  3.4134e+00,  3.0501e+00],
        [-1.0589e+01, -2.7045e+00, -2.1849e+00, -1.7039e-01],
        [ 6.5523e-01, -1.5214e+00, -3.1982e+00, -1.5687e+00],
        [-1.5009e+00, -3.8551e+00,  4.9843e-01,  1.2764

Homework content:

![](homework_4.png)

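推导 (Derivation). Let $Z = XW$, $R = \operatorname{ReLU}(Z) - Y$ and $f = \operatorname{tr}(R^\top R) = \lVert R \rVert_F^2$, so $\partial f / \partial R = 2R$. The chain rule through the element-wise ReLU (derivative $\mathbb{1}[Z > 0]$) then gives

$$\frac{\partial f}{\partial W} = 2\,X^\top\big(R \odot \mathbb{1}[Z>0]\big),\qquad \frac{\partial f}{\partial X} = 2\,\big(R \odot \mathbb{1}[Z>0]\big)\,W^\top,\qquad \frac{\partial f}{\partial Y} = -2R.$$

The step-by-step derivations are in the images below: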

![](partial_f_partial_W.png)

![](partial_f_partial_X.png)

![](partial_f_partial_Y.png)


In [24]:
# Closed-form matrix partial derivatives of f = tr(Rᵀ R), R = relu(X @ W) - Y,
# compared against the autograd gradients from the previous cell.
#
# Computed under no_grad: these are plain reference values, so there is no
# reason to record them in the autograd graph (previously each carried a grad_fn).
with torch.no_grad():
    pre_activation = X @ W
    # dReLU/dz = 1 where z > 0, else 0 (matches torch's ReLU backward at z == 0).
    relu_mask = pre_activation > 0
    upstream = 2 * (F.relu(pre_activation) - Y)   # df/dR = 2R
    masked_upstream = upstream * relu_mask        # df/dZ = 2R ⊙ 1[Z > 0]

    grad_W = X.t() @ masked_upstream              # df/dW = Xᵀ (df/dZ)
    grad_X = masked_upstream @ W.t()              # df/dX = (df/dZ) Wᵀ
    grad_Y = -upstream                            # df/dY = -2R

# Side-by-side comparison; each label now names the tensor being compared
# (previously the X and Y sections were mislabelled as "W").
print('===================')
for name, autograd_grad, manual_grad in (
    ('W', W.grad, grad_W),
    ('X', X.grad, grad_X),
    ('Y', Y.grad, grad_Y),
):
    print(f'{name} torch gradient: \n{autograd_grad} \n{name} matrix partial derivative results: \n{manual_grad}')
    print(f'match results: {torch.allclose(autograd_grad, manual_grad)}')
    print('===================')

W torch gradient: 
tensor([[ 18.2980,   2.7573,   2.3914,  -0.1974],
        [ 11.0817,   6.6428,   2.5163, -20.3225],
        [ -8.6662,   3.4506,  -1.8979,  -3.3608],
        [-21.1681,  -6.6739,  -1.0693,  27.0278]]) 
W matrix partial derivative results: 
tensor([[ 18.2980,   2.7573,   2.3914,  -0.1974],
        [ 11.0817,   6.6428,   2.5163, -20.3225],
        [ -8.6662,   3.4506,  -1.8979,  -3.3608],
        [-21.1681,  -6.6739,  -1.0693,  27.0278]], grad_fn=<MmBackward>)
match results: True
X torch gradient: 
tensor([[  1.1002,   0.0860,   5.3377,   0.2788],
        [  0.9583,  10.4633, -13.5234, -16.3639],
        [ -0.8712,  -0.9272,  -0.7764,   2.0790],
        [ -1.4504,   5.6914,   0.7613,  -0.9693],
        [ -1.2892,  -3.4714,  -1.9788,   4.8091],
        [ -4.0523,  -4.3127,  -3.6114,   9.6703],
        [ -0.7312,  -0.7782,  -0.6516,   1.7449],
        [ -0.8191,  -0.8718,  -0.7300,   1.9547],
        [  1.0350,   2.9930,  -6.6743,  -7.5333],
        [ -2.4616,  -2.4243, 