In [1]:
import torch
import numpy as np  

In [2]:
torch.manual_seed(0)
x = torch.randn(10, 4, requires_grad=True)
W = torch.randn(4, 4, requires_grad=True)
y = torch.randn(10, 4, requires_grad=True)

In [15]:
z = torch.maximum(torch.matmul(x, W), torch.tensor(0)) - y
f = torch.trace(torch.matmul(z.T, z))
print(f)

tensor(99.9048, grad_fn=<TraceBackward>)


In [16]:
f.backward()

In [21]:
print(W.grad)
print(x.grad)
print(y.grad)

tensor([[ 18.2980,   2.7573,   2.3914,  -0.1974],
        [ 11.0817,   6.6428,   2.5163, -20.3225],
        [ -8.6662,   3.4506,  -1.8979,  -3.3608],
        [-21.1681,  -6.6739,  -1.0693,  27.0278]])
tensor([[  1.1002,   0.0860,   5.3377,   0.2788],
        [  0.9583,  10.4633, -13.5234, -16.3639],
        [ -0.8712,  -0.9272,  -0.7764,   2.0790],
        [ -1.4504,   5.6914,   0.7613,  -0.9693],
        [ -1.2892,  -3.4714,  -1.9788,   4.8091],
        [ -4.0523,  -4.3127,  -3.6114,   9.6703],
        [ -0.7312,  -0.7782,  -0.6516,   1.7449],
        [ -0.8191,  -0.8718,  -0.7300,   1.9547],
        [  1.0350,   2.9930,  -6.6743,  -7.5333],
        [ -2.4616,  -2.4243,  -2.1164,   5.7128]])
tensor([[ 2.8885e+00,  4.1639e+00,  3.4134e+00,  3.0501e+00],
        [-1.0589e+01, -2.7045e+00, -2.1849e+00, -1.7039e-01],
        [ 6.5523e-01, -1.5214e+00, -3.1982e+00, -1.5687e+00],
        [-1.5009e+00, -3.8551e+00,  4.9843e-01,  1.2764e+00],
        [-6.6077e-03, -1.0689e+00,  1.8791e+00, -4

In [28]:
def d_relu(x):
    temp = torch.zeros_like(x)
    temp[x>0] = 1
    temp[x<=0] = 0
    return temp

In [29]:
d_re = d_relu(torch.matmul(x, W))
print(d_re)

tensor([[1., 0., 0., 1.],
        [1., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 1., 0.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 1.]])


In [34]:
hand_dw = 2 * torch.matmul(x.T, z * d_re)
hand_dx = torch.matmul(2 * z * d_re, W.T)
hand_dy = -2 * z
print(hand_dw, hand_dx, hand_dy)

tensor([[ 18.2980,   2.7573,   2.3914,  -0.1974],
        [ 11.0817,   6.6428,   2.5163, -20.3225],
        [ -8.6662,   3.4506,  -1.8979,  -3.3608],
        [-21.1681,  -6.6739,  -1.0693,  27.0278]], grad_fn=<MulBackward0>) tensor([[  1.1002,   0.0860,   5.3377,   0.2788],
        [  0.9583,  10.4633, -13.5234, -16.3639],
        [ -0.8712,  -0.9272,  -0.7764,   2.0790],
        [ -1.4504,   5.6914,   0.7613,  -0.9693],
        [ -1.2892,  -3.4714,  -1.9788,   4.8091],
        [ -4.0523,  -4.3127,  -3.6114,   9.6703],
        [ -0.7312,  -0.7782,  -0.6516,   1.7449],
        [ -0.8191,  -0.8718,  -0.7300,   1.9547],
        [  1.0350,   2.9930,  -6.6743,  -7.5333],
        [ -2.4616,  -2.4243,  -2.1164,   5.7128]], grad_fn=<MmBackward>) tensor([[ 2.8885e+00,  4.1639e+00,  3.4134e+00,  3.0501e+00],
        [-1.0589e+01, -2.7045e+00, -2.1849e+00, -1.7039e-01],
        [ 6.5523e-01, -1.5214e+00, -3.1982e+00, -1.5687e+00],
        [-1.5009e+00, -3.8551e+00,  4.9843e-01,  1.2764e+00],
    

In [36]:
print(torch.equal(W.grad, hand_dw))
print(torch.equal(x.grad, hand_dx))
print(torch.equal(y.grad, hand_dy))

True
True
True
