In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approx: {str(app):5s} | maxdiff: {maxdiff}')

In [4]:
torch.manual_seed(42)
in_size = 2
out_size = 3
layer = nn.Linear(in_size, out_size, bias=False)
x = torch.randn(10, in_size)
x.requires_grad = True
target = torch.randn(1) * 0.1

parameters = [layer]

In [5]:
# Positive
x_norm = F.normalize(x, dim=1)
y = layer(x_norm)
out = F.relu(y)
out_square = out.square()
out_mean = out_square.mean()
out_norm = out_mean - target
loss = torch.log(1 + torch.exp(out_norm))

vals = [x_norm, y, out, out_square, out_mean, out_norm, loss]
for p in parameters:
    p.zero_grad()
for v in vals:
    v.retain_grad()

loss.backward()

In [8]:
dout_norm = F.sigmoid(out_norm)
dout_mean = dout_norm
dout_square = dout_mean  / out_square.numel()
dout = dout_square * 2 * out
dy = dout * y.sign()

dweight = (out * out.sign() * F.sigmoid(out.square().mean() - target)).t() @ x_norm * 2/out.numel()

cmp('out_norm', dout_norm, out_norm)
cmp('out_mean', dout_mean, out_mean)
cmp('out_square', dout_square, out_square)
cmp('out', dout, out)
cmp('y', dy, y)
cmp('weight', dweight, layer.weight)

out_norm        | exact: True  | approx: True  | maxdiff: 0.0
out_mean        | exact: True  | approx: True  | maxdiff: 0.0
out_square      | exact: True  | approx: True  | maxdiff: 0.0
out             | exact: True  | approx: True  | maxdiff: 0.0
y               | exact: True  | approx: True  | maxdiff: 0.0
weight          | exact: False | approx: True  | maxdiff: 3.725290298461914e-09


In [9]:
# Negative
x_norm = F.normalize(x, dim=1)
y = layer(x_norm)
out = F.relu(y)
out_square = out.square()
out_mean = out_square.mean()
out_norm = target - out_mean
loss = torch.log(1 + torch.exp(out_norm))

vals = [x_norm, y, out, out_square, out_mean, out_norm, loss]
for p in parameters:
    p.zero_grad()
for v in vals:
    v.retain_grad()

loss.backward()

In [18]:
dout_norm = F.sigmoid(out_norm)
dout_mean = -dout_norm
dout_square = dout_mean  / out_square.numel()
dout = dout_square * 2 * out
dy = dout * y.sign()
dweight = dy.t() @ x_norm
dweight = (out * -F.sigmoid(target-out.square().mean())).t() @ x_norm * 2/out.numel()
# dweight = (out * out.sign() * F.sigmoid(out.square().mean() - target)).t() @ x_norm * 2/out.numel()

cmp('out_norm', dout_norm, out_norm)
cmp('out_mean', dout_mean, out_mean)
cmp('out_square', dout_square, out_square)
cmp('out', dout, out)
cmp('y', dy, y)
cmp('weight', dweight, layer.weight)

out_norm        | exact: True  | approx: True  | maxdiff: 0.0
out_mean        | exact: True  | approx: True  | maxdiff: 0.0
out_square      | exact: True  | approx: True  | maxdiff: 0.0
out             | exact: True  | approx: True  | maxdiff: 0.0
y               | exact: True  | approx: True  | maxdiff: 0.0
weight          | exact: False | approx: True  | maxdiff: 1.1175870895385742e-08
