In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
def cmp(s, dt, t):
  """Print how closely a hand-derived gradient matches the one autograd stored.

  Args:
    s: Label shown at the start of the report line.
    dt: Manually computed gradient tensor.
    t: Tensor whose stored gradient (`t.grad`) is the reference value.

  Raises:
    TypeError: If `t.grad` is None, i.e. no gradient has been stored on `t`.
  """
  if t.grad is None:
    # Fail early with an actionable message instead of an opaque error
    # raised from inside the torch comparison calls below.
    raise TypeError(
      f'{s}: t.grad is None; call retain_grad() on the tensor and run backward() first')
  app = torch.allclose(dt, t.grad)
  # Largest element-wise absolute deviation between the two gradients.
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:20s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [3]:
# Feature matrix: three examples (rows) by four columns; the first column is
# all ones (presumably the intercept feature -- confirm).
X = torch.tensor(
    [[1, 1, 2, 3],
     [1, 4, 5, 6],
     [1, 7, 8, 9]],
    dtype=torch.float32,
)

# Binary labels as a (3, 1) column so they line up row-for-row with X @ theta.
y = torch.tensor([[1], [0], [1]])

In [4]:
# Fixed-seed generator so the initial parameters are the same on every run.
g = torch.Generator()
g.manual_seed(2147483647)
# One weight per column of X, stored as a (4, 1) column vector.
theta = torch.randn(4, 1, generator=g, requires_grad=True)
theta.grad = None  # explicit reset so no stale gradient is carried over when this cell is rerun
theta

tensor([[-0.9800],
        [-1.6578],
        [-0.0572],
        [-0.3409]], requires_grad=True)

In [5]:
# Show theta's stored gradient; no backward pass has run on it at this point.
theta.grad

In [6]:
# Forward pass: sigmoid of a linear score followed by binary cross-entropy,
# split into named intermediates so each one's gradient can be inspected.
theta.grad = None  # backward() accumulates into .grad, so clear it before every rerun
matmul = torch.matmul(X, theta)
preds = 1 / (1 + (-matmul).exp())
branch1 = y * -preds.log()             # contributes where y == 1
branch2 = (1 - y) * -(1 - preds).log()  # contributes where y == 0
losses = branch1 + branch2
risk = losses.mean()

# Ask autograd to keep .grad on every intermediate, not only on the leaf theta.
for t in (theta, matmul, preds, branch1, branch2, losses, risk):
    t.retain_grad()
risk.backward()
risk

tensor(6.6358, grad_fn=<MeanBackward0>)

In [19]:
# Manual backward pass: the chain rule applied by hand, one forward op at a time.
# no_grad stops autograd from recording these ops; the hand-derived gradients
# are plain values and do not need a graph of their own.
with torch.no_grad():
    g_risk = torch.tensor(1.0)
    # risk = losses.mean(): each element receives 1 / (number of elements averaged).
    g_losses = g_risk * (1 / losses.numel()) * torch.ones_like(losses)
    # losses = branch1 + branch2: addition passes the gradient through unchanged.
    g_branch1 = g_losses * 1
    g_branch2 = g_losses * 1
    # branch1 = y * -log(preds) and branch2 = (1 - y) * -log(1 - preds);
    # preds feeds both branches, so the two contributions are summed.
    g_preds = (
        (1 / preds) * -1 * y * g_branch1
        + -1 * (1 / (1 - preds)) * -1 * (1 - y) * g_branch2
    )
    # preds = (1 + exp(-matmul)) ** -1. The factors below are, in order:
    # d(-matmul)/d(matmul), exp(-matmul), d(1 + u)/du, and d(v ** -1)/dv.
    exp_neg = torch.exp(-matmul)
    g_matmul = g_preds * -1 * exp_neg * 1 * -((1 + exp_neg) ** -2)
    # matmul = X @ theta
    g_theta = X.T @ g_matmul

In [20]:
# Compare every hand-derived gradient with the one autograd stored, walking
# from the scalar risk back to the parameters.
checks = [
    ('risk', g_risk, risk),
    ('losses', g_losses, losses),
    ('branch1', g_branch1, branch1),
    ('branch2', g_branch2, branch2),
    ('preds', g_preds, preds),
    ('matmul', g_matmul, matmul),
    ('theta', g_theta, theta),
]
for label, manual_grad, node in checks:
    cmp(label, manual_grad, node)

risk                 | approximate: True  | maxdiff: 0.0
losses               | approximate: True  | maxdiff: 0.0
branch1              | approximate: True  | maxdiff: 0.0
branch2              | approximate: True  | maxdiff: 0.0
preds                | approximate: True  | maxdiff: 9.5367431640625e-07
matmul               | approximate: True  | maxdiff: 5.960464477539063e-08
theta                | approximate: True  | maxdiff: 2.384185791015625e-07
