In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
# Problem sizes and threshold shared by the gradient-check cells.
m, d, s, k = 10, 3, 4, 5  # rows of X, feature width of X/y/W, columns of U, number of top-k picks
gamma = 0                 # threshold subtracted before the ReLU on the top-k scores

In [9]:
# Build a small projection -> top-k -> ReLU -> gather pipeline, run backward,
# and keep every intermediate gradient so later cells can re-derive them by hand.
torch.manual_seed(0)  # reproducible inputs; re-running this cell gives identical gradients

# Leaf inputs (each created exactly once).
X = torch.rand((m, d), requires_grad=True)
y = torch.rand((1, d), requires_grad=True)
U = torch.rand((m, s), requires_grad=True)
W = torch.rand((d, d), requires_grad=True)

# Shared linear projection of X and y.
X1 = X @ W                    # (m, d)
y1 = y @ W                    # (1, d)
X1.retain_grad()
y1.retain_grad()

# Score of every row of X1 against y1.
z = X1 @ y1.T                 # (m, 1)
z.retain_grad()

# Keep the k largest scores; H holds the row indices of z they came from.
z1, H = z.topk(k=k, dim=0)    # both (k, 1)
z1.retain_grad()
# Selection matrix: H1[i, j] == 1 iff row j of z was the i-th pick.
H1 = F.one_hot(H.squeeze(dim=-1), num_classes=m).float()  # (k, m)

fz1 = F.relu(z1 - gamma)      # (k, 1)
fz1.retain_grad()

# Broadcast each selected score across the s columns. fz1 is already a
# floating-point tensor that requires grad, so no cast or requires_grad_() is needed.
z3 = fz1.expand(k, s)         # (k, s)
z3.retain_grad()

# Gather the rows of U chosen by top-k.
R = H1 @ U                    # (k, s)
R.retain_grad()

R1 = R * z3
R1.retain_grad()

loss = (R1 ** 2).sum()
loss.backward()

In [10]:
# loss = sum(R1 ** 2)  =>  d(loss)/d(R1) = 2 * R1.
dR1 = 2 * R1
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dR1, R1.grad, torch.isclose(dR1, R1.grad)

(tensor([[3.6804, 4.6247, 4.9408, 3.7853],
         [5.1414, 4.8188, 0.2205, 3.6477],
         [3.6515, 3.9962, 3.0998, 3.3756],
         [1.5739, 4.1258, 0.2877, 3.1390],
         [2.6154, 0.2699, 0.0204, 1.0984]], grad_fn=<MulBackward0>),
 tensor([[3.6804, 4.6247, 4.9408, 3.7853],
         [5.1414, 4.8188, 0.2205, 3.6477],
         [3.6515, 3.9962, 3.0998, 3.3756],
         [1.5739, 4.1258, 0.2877, 3.1390],
         [2.6154, 0.2699, 0.0204, 1.0984]]),
 tensor([[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]))

In [11]:
# R1 = R * z3 is elementwise, so d(loss)/dR = dR1 * z3. This equals
# dR1.view(1, -1) @ diag(z3.flatten()) reshaped to (k, s), but does not
# materialize a (k*s, k*s) diagonal matrix.
dR = dR1 * z3
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dR, R.grad, torch.isclose(dR, R.grad)

(tensor([[10.3588, 13.0167, 13.9062, 10.6540],
         [13.8420, 12.9733,  0.5935,  9.8206],
         [ 9.4193, 10.3087,  7.9963,  8.7077],
         [ 4.0429, 10.5980,  0.7391,  8.0633],
         [ 6.2631,  0.6463,  0.0487,  2.6303]], grad_fn=<ViewBackward>),
 tensor([[10.3588, 13.0167, 13.9062, 10.6540],
         [13.8420, 12.9733,  0.5935,  9.8206],
         [ 9.4193, 10.3087,  7.9963,  8.7077],
         [ 4.0429, 10.5980,  0.7391,  8.0633],
         [ 6.2631,  0.6463,  0.0487,  2.6303]]),
 tensor([[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]))

In [12]:
# R1 = R * z3 is elementwise, so d(loss)/d(z3) = dR1 * R (no diagonal matrix needed).
dz3 = dR1 * R
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dz3, z3.grad, torch.isclose(dz3, z3.grad)

(tensor([[2.4063e+00, 3.7996e+00, 4.3366e+00, 2.5454e+00],
         [4.9094e+00, 4.3125e+00, 9.0263e-03, 2.4712e+00],
         [2.5843e+00, 3.0954e+00, 1.8624e+00, 2.2086e+00],
         [4.8217e-01, 3.3133e+00, 1.6114e-02, 1.9180e+00],
         [1.4283e+00, 1.5207e-02, 8.6469e-05, 2.5190e-01]],
        grad_fn=<ViewBackward>),
 tensor([[2.4063e+00, 3.7996e+00, 4.3366e+00, 2.5454e+00],
         [4.9094e+00, 4.3125e+00, 9.0263e-03, 2.4712e+00],
         [2.5843e+00, 3.0954e+00, 1.8624e+00, 2.2086e+00],
         [4.8217e-01, 3.3133e+00, 1.6114e-02, 1.9180e+00],
         [1.4283e+00, 1.5207e-02, 8.6469e-05, 2.5190e-01]]),
 tensor([[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]))

In [13]:
# z3 = fz1.expand(k, s) repeats each row's single value across s columns,
# so the gradient w.r.t. fz1 sums the incoming gradient over those columns.
dfz1 = dz3.sum(dim=-1, keepdim=True)
# isclose rather than ==: the summation order may differ from autograd's.
dfz1, fz1.grad, torch.isclose(dfz1, fz1.grad)

(tensor([[13.0879],
         [11.7021],
         [ 9.7507],
         [ 5.7295],
         [ 1.6954]], grad_fn=<SumBackward1>),
 tensor([[13.0879],
         [11.7021],
         [ 9.7507],
         [ 5.7295],
         [ 1.6954]]),
 tensor([[True],
         [True],
         [True],
         [True],
         [True]]))

In [14]:
# ReLU backward: pass the gradient only where z1 is strictly above gamma.
# PyTorch's relu gradient is 0 at the kink (z1 == gamma), so a strict '>'
# keeps this check consistent with autograd even at that boundary.
dz1 = dfz1 * (z1 > gamma).to(dfz1.dtype)
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dz1, z1.grad, torch.isclose(dz1, z1.grad)

(tensor([[13.0879],
         [11.7021],
         [ 9.7507],
         [ 5.7295],
         [ 1.6954]], grad_fn=<MulBackward0>),
 tensor([[13.0879],
         [11.7021],
         [ 9.7507],
         [ 5.7295],
         [ 1.6954]]),
 tensor([[True],
         [True],
         [True],
         [True],
         [True]]))

In [15]:
# topk backward scatters each selected gradient back to the row of z it came
# from; multiplying by H1.T (m, k) does that scatter and leaves unselected rows at 0.
dz = H1.T @ dz1
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dz, z.grad, torch.isclose(dz, z.grad)

(tensor([[ 0.0000],
         [ 9.7507],
         [ 5.7295],
         [13.0879],
         [ 0.0000],
         [11.7021],
         [ 0.0000],
         [ 1.6954],
         [ 0.0000],
         [ 0.0000]], grad_fn=<MmBackward>),
 tensor([[ 0.0000],
         [ 9.7507],
         [ 5.7295],
         [13.0879],
         [ 0.0000],
         [11.7021],
         [ 0.0000],
         [ 1.6954],
         [ 0.0000],
         [ 0.0000]]),
 tensor([[True],
         [True],
         [True],
         [True],
         [True],
         [True],
         [True],
         [True],
         [True],
         [True]]))

In [16]:
# z = X1 @ y1.T  =>  d(loss)/d(X1) = dz @ y1.
dX1 = dz @ y1
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dX1, X1.grad, torch.isclose(dX1, X1.grad)

(tensor([[ 0.0000,  0.0000,  0.0000],
         [ 6.1413,  9.0546,  7.1276],
         [ 3.6086,  5.3205,  4.1882],
         [ 8.2431, 12.1535,  9.5670],
         [ 0.0000,  0.0000,  0.0000],
         [ 7.3703, 10.8666,  8.5540],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.0678,  1.5744,  1.2393],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]], grad_fn=<MmBackward>),
 tensor([[ 0.0000,  0.0000,  0.0000],
         [ 6.1413,  9.0546,  7.1276],
         [ 3.6086,  5.3205,  4.1882],
         [ 8.2431, 12.1535,  9.5670],
         [ 0.0000,  0.0000,  0.0000],
         [ 7.3703, 10.8666,  8.5540],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.0678,  1.5744,  1.2393],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]),
 tensor([[True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True],
         [T

In [17]:
# z = X1 @ y1.T  =>  d(loss)/d(y1) = (X1.T @ dz).T, shape (1, d).
dy1 = (X1.T @ dz).T
# isclose rather than ==: hand-derived and autograd values can differ by rounding.
dy1, y1.grad, torch.isclose(dy1, y1.grad)

(tensor([[41.7457, 58.3268, 43.5270]], grad_fn=<PermuteBackward>),
 tensor([[41.7457, 58.3268, 43.5270]]),
 tensor([[True, True, True]]))

In [18]:
# W feeds both X1 = X @ W and y1 = y @ W, so its gradient is the sum of the
# contributions from the two paths.
dW = X.T @ dX1 + y.T @ dy1
# isclose rather than ==: autograd may accumulate the two paths with different rounding.
dW, W.grad, torch.isclose(dW, W.grad)

(tensor([[51.2413, 73.0053, 55.5842],
         [20.5465, 29.6704, 22.8945],
         [35.6151, 51.3017, 39.4887]], grad_fn=<AddBackward0>),
 tensor([[51.2413, 73.0053, 55.5842],
         [20.5465, 29.6704, 22.8945],
         [35.6151, 51.3017, 39.4887]]),
 tensor([[True, True, True],
         [True, True, True],
         [True, True, True]]))