In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
def cmp(s, dt, t):
    """Print how a hand-derived gradient `dt` compares with autograd's `t.grad`.

    s  -- label shown in the first column of the report line.
    dt -- manually computed gradient tensor.
    t  -- tensor whose `.grad` has already been populated by `backward()`.

    The report line shows bitwise equality, `torch.allclose` agreement and the
    largest absolute elementwise difference. All three comparisons broadcast,
    so a `dt` whose shape merely broadcasts against `t.grad` is not flagged.
    """
    reference = t.grad
    is_exact = (dt == reference).all().item()
    is_close = torch.allclose(dt, reference)
    worst_gap = torch.max(torch.abs(dt - reference)).item()
    print(f'{s:15s} | exact: {str(is_exact):5s} | approx: {str(is_close):5s} | maxdiff: {worst_gap}')

In [66]:
# Shared fixture for the gradient checks: one linear layer and one fixed batch.
torch.manual_seed(42)  # seed first so the layer init and the batch are reproducible
in_size, out_size = 2, 3
threshold = 2.0  # offset combined with the mean sum of squared activations
layer = nn.Linear(in_size, out_size, bias=True)
# 10 random rows passed through F.normalize (default arguments), then flagged
# so that autograd tracks gradients with respect to the input as well.
x = F.normalize(torch.randn(10, in_size)).requires_grad_(True)

parameters = [layer]  # modules whose gradients are cleared before each backward pass

In [67]:
# Positive
# Forward pass for the "positive" case: log(1 + exp(z)) with
# z = mean-over-batch(sum of squared ReLU activations) - threshold.
# Each step is kept as its own tensor so every intermediate can be gradient-checked.
actv = layer(x)
y = F.relu(actv)
y_square = y.square()
y_sum = y_square.sum(dim=1) # sum each vector
y_mean = y_sum.mean() # mean over batch
y_corr = y_mean - threshold
log_p = torch.log(1 + torch.exp(y_corr))

vals = [x, actv, y, y_square, y_sum, y_mean, y_corr, log_p]
for p in parameters:
    p.zero_grad()
# x is a leaf that requires grad but is not covered by `parameters`; clear it
# explicitly so its gradient does not accumulate across passes or cell re-runs.
x.grad = None
for v in vals:
    # keep .grad on the intermediates so they can be inspected after backward()
    v.retain_grad()

log_p.backward()

In [68]:
# Manual backward pass through the positive forward graph, one op at a time.
# Every d<name> is built with the same shape as <name>, so the checks below
# do not depend on broadcasting inside cmp().
dy_corr = torch.sigmoid(y_corr)  # d/dz log(1 + exp(z)) = sigmoid(z)
dy_mean = dy_corr * 1.0  # y_corr = y_mean - threshold
dy_sum = torch.ones_like(y_sum) * (dy_mean / y_sum.shape[0])  # mean gives each row 1/N
dy_square = dy_sum.unsqueeze(1).expand_as(y_square)  # sum over dim=1 copies the row gradient
dy = dy_square * 2.0 * y  # d(y^2)/dy = 2y
# ReLU derivative is the indicator actv > 0 (sign() would yield -1 for negative inputs).
dactv = dy * (actv > 0).to(dy.dtype)
dweight = dactv.t() @ x
dbias = dactv.sum(dim=0)

cmp('dy_corr', dy_corr, y_corr)
cmp('dy_mean', dy_mean, y_mean)
cmp('dy_sum', dy_sum, y_sum)
cmp('dy_square', dy_square, y_square)
cmp('dy', dy, y)
cmp('dactv', dactv, actv)
cmp('dweight', dweight, layer.weight)
cmp('dbias', dbias, layer.bias)

dy_corr         | exact: True  | approx: True  | maxdiff: 0.0
dy_mean         | exact: True  | approx: True  | maxdiff: 0.0
dy_sum          | exact: True  | approx: True  | maxdiff: 0.0
dy_square       | exact: True  | approx: True  | maxdiff: 0.0
dy              | exact: True  | approx: True  | maxdiff: 0.0
dactv           | exact: True  | approx: True  | maxdiff: 0.0
dweight         | exact: True  | approx: True  | maxdiff: 0.0
dbias           | exact: True  | approx: True  | maxdiff: 0.0


In [71]:
# Collapsed form of the same gradients. The 1/batch factor is applied after
# the matmul here (the step-by-step version folds it in before), so the
# result can differ from autograd in the last bits.
dlogp = torch.sigmoid(y_corr)
dy_unscaled = dlogp * 2 * y
dweight = torch.matmul(dy_unscaled.t(), x) / x.shape[0]
dbias = dy_unscaled.mean(dim=0)

cmp('dweight', dweight, layer.weight)
cmp('dbias', dbias, layer.bias)

dweight         | exact: False | approx: True  | maxdiff: 1.4901161193847656e-08
dbias           | exact: False | approx: True  | maxdiff: 1.4901161193847656e-08


In [73]:
# Negative
# Forward pass for the "negative" case: log(1 + exp(z)) with
# z = threshold - mean-over-batch(sum of squared ReLU activations).
# Each step is kept as its own tensor so every intermediate can be gradient-checked.
actv = layer(x)
y = F.relu(actv)
y_square = y.square()
y_sum = y_square.sum(dim=1) # sum each vector
y_mean = y_sum.mean() # mean over batch
y_corr = threshold - y_mean
log_p = torch.log(1 + torch.exp(y_corr))

vals = [x, actv, y, y_square, y_sum, y_mean, y_corr, log_p]
for p in parameters:
    p.zero_grad()
# x is a leaf that requires grad but is not covered by `parameters`; clear it
# explicitly so its gradient does not accumulate across passes or cell re-runs.
x.grad = None
for v in vals:
    # keep .grad on the intermediates so they can be inspected after backward()
    v.retain_grad()

log_p.backward()

In [74]:
# Manual backward pass through the negative forward graph, one op at a time.
# Every d<name> is built with the same shape as <name>, so the checks below
# do not depend on broadcasting inside cmp().
dy_corr = torch.sigmoid(y_corr)  # d/dz log(1 + exp(z)) = sigmoid(z)
dy_mean = dy_corr * -1.0  # y_corr = threshold - y_mean
dy_sum = torch.ones_like(y_sum) * (dy_mean / y_sum.shape[0])  # mean gives each row 1/N
dy_square = dy_sum.unsqueeze(1).expand_as(y_square)  # sum over dim=1 copies the row gradient
dy = dy_square * 2.0 * y  # d(y^2)/dy = 2y
# ReLU derivative is the indicator actv > 0 (sign() would yield -1 for negative inputs).
dactv = dy * (actv > 0).to(dy.dtype)
dweight = dactv.t() @ x
dbias = dactv.sum(dim=0)

cmp('dy_corr', dy_corr, y_corr)
cmp('dy_mean', dy_mean, y_mean)
cmp('dy_sum', dy_sum, y_sum)
cmp('dy_square', dy_square, y_square)
cmp('dy', dy, y)
cmp('dactv', dactv, actv)
cmp('dweight', dweight, layer.weight)
cmp('dbias', dbias, layer.bias)

dy_corr         | exact: True  | approx: True  | maxdiff: 0.0
dy_mean         | exact: True  | approx: True  | maxdiff: 0.0
dy_sum          | exact: True  | approx: True  | maxdiff: 0.0
dy_square       | exact: True  | approx: True  | maxdiff: 0.0
dy              | exact: True  | approx: True  | maxdiff: 0.0
dactv           | exact: True  | approx: True  | maxdiff: 0.0
dweight         | exact: True  | approx: True  | maxdiff: 0.0
dbias           | exact: True  | approx: True  | maxdiff: 0.0


In [75]:
# Collapsed form of the negative-case gradients. The 1/batch factor is applied
# after the matmul here (the step-by-step version folds it in before), so the
# result can differ from autograd in the last bits.
dlogp = torch.sigmoid(y_corr)
dy_unscaled = -dlogp * 2 * y
dweight = torch.matmul(dy_unscaled.t(), x) / x.shape[0]
dbias = dy_unscaled.mean(dim=0)

cmp('dweight', dweight, layer.weight)
cmp('dbias', dbias, layer.bias)

dweight         | exact: False | approx: True  | maxdiff: 5.960464477539063e-08
dbias           | exact: True  | approx: True  | maxdiff: 0.0
