In [7]:
import matplotlib.pyplot as plt
import torch
from rbi.utils.nets import IndependentGaussianNet
from rbi.loss.loss_fn import NLLLoss
from rbi.defenses.regularized_loss import GaussianNoiseJacobiRegularizer
from rbi.utils.fisher_info import fisher_info
from rbi.utils.autograd_tools import batch_jacobian

In [8]:
means = torch.randn((1000, 1))*2
p_true = torch.distributions.Normal(means.squeeze(), 2)
observations = means + torch.randn(1000, 10)*0.5

$$ p(X, \mu) = p(X|\mu)p(\mu)$$

In [13]:
net = torch.nn.Sequential(torch.nn.Linear(10, 50), torch.nn.ReLU(), torch.nn.Linear(50,50), torch.nn.ReLU(), torch.nn.Linear(50,1))

In [15]:
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
for i in range(1000):
    optim.zero_grad()
    pred_mean = net(observations)
    p_impl = torch.distributions.Normal(means, 0.5)
    loss = -p_impl.log_prob(pred_mean).mean()
    if (i % 100) == 0:
        print(loss.detach())
    loss.backward()
    optim.step()

tensor(0.2623)
tensor(0.2596)
tensor(0.2568)
tensor(0.2538)
tensor(0.2512)
tensor(0.2489)
tensor(0.2471)
tensor(0.2456)
tensor(0.2444)
tensor(0.2433)


In [76]:
mean_star = torch.randn((1,1))*2
p_true = torch.distributions.Normal(mean_star.squeeze(), 0.5)
observations = mean_star + torch.randn(10000, 10)*0.5

In [77]:
observations = observations.requires_grad_(True)
bias = net(observations)
diff = torch.autograd.grad(bias.sum(), observations)[0].mean(0)
var_lower_bound = diff@diff.T*1/4
variance = net(observations).var()
print("Bias: ",torch.mean(bias-mean_star).detach(), "Variance: ", variance.detach(), "CRLB: ", var_lower_bound)

Bias:  tensor(0.0154) Variance:  tensor(0.0427) CRLB:  tensor(0.0263)


In [89]:
Js = torch.autograd.grad(net(observations).sum(), observations)[0].unsqueeze(-1)

torch.Size([10000, 10])

In [104]:
(torch.transpose(Js, -2,-1)@Js).mean()*4

tensor(1.5291)

tensor(-16357.8281, grad_fn=<SumBackward0>) tensor(0.0354, grad_fn=<VarBackward0>) tensor(0.0273)


In [331]:
J_expected = batch_jacobian(lambda x: net.predict(x), p_true.sample((1000,10))).mean(0)

In [332]:
var_lower_bound = J_expected@torch.transpose(J_expected, -2,-1)*4

In [333]:
var_lower_bound

tensor([[0.5530]])

In [334]:
net.predict(p_true.sample((1000,10))).var()

tensor(0.8371, grad_fn=<VarBackward0>)

In [250]:
net2

In [256]:
covariance_of_estiamtor = torch.transpose(Js, -2,-1)@torch.linalg.inv(Fs)@Js

In [265]:
cov_est = (torch.linalg.inv(Fs)@Js)@torch.transpose(Js, -2,-1)

In [272]:
mean_variance_lower_bound = cov_est[:, 0][:,0]

In [273]:
mean_variance_lower_bound

tensor([9.9893e-04, 2.5532e-05, 1.6243e-04, 1.5652e-03, 9.9026e-05, 3.3554e-04,
        6.8057e-04, 1.9034e-04, 3.9373e-04, 1.4017e-03, 3.0827e-05, 2.3596e-04,
        1.1210e-03, 1.2128e-02, 2.3221e-03, 4.4122e-04, 2.9795e-05, 2.0009e-04,
        1.4027e-04, 2.9821e-03, 8.3756e-04, 3.2888e-04, 6.4394e-04, 8.2891e-04,
        1.5898e-04, 5.5301e-04, 1.5125e-03, 3.9884e-03, 1.0122e-04, 2.9633e-04,
        1.4927e-03, 7.2020e-05, 2.6287e-05, 2.2541e-03, 6.5329e-04, 7.4049e-05,
        3.3679e-04, 2.9090e-03, 7.5915e-05, 1.5355e-04, 1.1225e-03, 1.5683e-04,
        6.0192e-04, 1.0075e-04, 4.3179e-04, 9.8452e-04, 2.1295e-04, 2.3030e-03,
        6.7949e-05, 4.6616e-04, 1.9059e-04, 1.0526e-04, 6.9441e-04, 3.8035e-04,
        3.0109e-04, 1.2808e-04, 7.1613e-05, 1.8134e-03, 3.1026e-03, 2.8814e-04,
        2.0272e-03, 1.0387e-04, 4.5939e-04, 3.6789e-04, 3.3912e-04, 6.3587e-04,
        1.1583e-04, 8.2043e-05, 1.3256e-04, 2.8748e-04, 5.7719e-04, 1.5080e-03,
        1.0621e-03, 1.4752e-03, 1.2676e-

In [243]:
net(observations).mean.var()

tensor(0.0457, grad_fn=<VarBackward0>)

In [53]:
torch.mean((net(observations).base_dist.loc - observations.mean(-1).unsqueeze(-1))**2)

tensor(0.0128, grad_fn=<MeanBackward0>)

In [68]:
net2 = IndependentGaussianNet(10, 1)
loss_fn = NLLLoss(net2)
defense = GaussianNoiseJacobiRegularizer(net, loss_fn, 0.2)
defense.activate()

In [69]:
optim = torch.optim.Adam(net2.parameters(), lr=1e-3)
for i in range(1000):
    optim.zero_grad()
    loss = loss_fn(observations, means)
    if (i % 100) == 0:
        print(loss.detach())
    loss.backward()
    optim.step()

tensor(4.4554)
tensor(1.3074)
tensor(1.2321)
tensor(1.2172)
tensor(1.2070)
tensor(1.1984)
tensor(1.1916)
tensor(1.1857)
tensor(1.1810)
tensor(1.1766)


In [219]:
F = defense._compute_fisher(observations[0].unsqueeze(0), net2(observations[0].unsqueeze(0))).detach() + torch.eye(10)*0.001

In [220]:
F_inv = torch.linalg.inv(F)

In [221]:
torch.linalg.eigvalsh(F_inv)

tensor([[7.2797e-02, 9.0554e-01, 9.9991e+02, 9.9994e+02, 9.9998e+02, 9.9998e+02,
         1.0000e+03, 1.0000e+03, 1.0001e+03, 1.0003e+03]])

In [222]:
p = torch.distributions.MultivariateNormal(observations[0], F_inv)

In [227]:
p.sample((10,)).mean(-1)

tensor([[ 0.5746],
        [13.2651],
        [ 1.9085],
        [ 3.7335],
        [ 9.9578],
        [ 0.5435],
        [14.9603],
        [-1.0185],
        [-0.3865],
        [-0.0564]])

In [203]:
net2(observations[1].unsqueeze(0) + 0.1*torch.randn(100,1)).mean

tensor([[-1.7794],
        [-1.7057],
        [-1.9992],
        [-1.6891],
        [-1.8192],
        [-1.7018],
        [-1.7421],
        [-1.9387],
        [-1.6105],
        [-1.7220],
        [-1.9977],
        [-1.6809],
        [-1.7099],
        [-1.6032],
        [-1.8177],
        [-1.8952],
        [-1.7103],
        [-1.7752],
        [-1.8190],
        [-1.7495],
        [-1.8155],
        [-1.7258],
        [-1.6928],
        [-1.7232],
        [-1.7981],
        [-1.7462],
        [-1.8113],
        [-1.6597],
        [-1.6997],
        [-1.8801],
        [-1.7440],
        [-1.7890],
        [-1.5049],
        [-1.8259],
        [-1.7430],
        [-1.7593],
        [-1.6150],
        [-1.6529],
        [-1.7349],
        [-1.8002],
        [-1.5405],
        [-1.6300],
        [-1.9168],
        [-1.8031],
        [-1.5336],
        [-1.7707],
        [-1.7355],
        [-1.8818],
        [-1.6828],
        [-1.8101],
        [-1.9117],
        [-1.6224],
        [-1.

In [168]:
net2(p.sample((100,))).mean

tensor([[-1.7549],
        [-1.5716],
        [-1.5987],
        [-1.2383],
        [-0.8455],
        [-1.7040],
        [-1.1621],
        [-1.6986],
        [-1.4179],
        [-1.4974],
        [-1.8474],
        [-2.1858],
        [-0.8035],
        [-1.5460],
        [-1.7395],
        [-1.9392],
        [-1.5967],
        [-1.3721],
        [-1.3601],
        [-1.6426],
        [-1.1538],
        [-1.6184],
        [-1.9061],
        [-1.5908],
        [-1.4998],
        [-1.7900],
        [-2.1121],
        [-1.7119],
        [-1.4837],
        [-1.3873],
        [-1.7578],
        [-1.6242],
        [-1.6461],
        [-1.6029],
        [-1.9264],
        [-1.6585],
        [-1.4569],
        [-1.4690],
        [-1.5868],
        [-1.4912],
        [-2.0929],
        [-1.4760],
        [-1.8047],
        [-1.5554],
        [-1.7501],
        [-1.5341],
        [-1.7682],
        [-2.0359],
        [-1.6556],
        [-2.0150],
        [-1.8584],
        [-1.8833],
        [-1.

In [65]:
torch.mean((net2(observations).base_dist.loc - observations.median(-1).values.unsqueeze(-1))**2)

tensor(0.0284, grad_fn=<MeanBackward0>)

In [67]:
torch.mean((net2(observations).base_dist.loc - observations.mean(-1).unsqueeze(-1))**2)

tensor(0.0127, grad_fn=<MeanBackward0>)

In [38]:
observations.median(-1)

torch.return_types.median(
values=tensor([-4.4001e+00, -3.2258e+00, -2.1351e+00, -2.3956e+00, -2.6880e+00,
         1.4055e+00,  2.0184e+00, -3.2590e+00, -2.3928e-01,  3.8684e+00,
         1.1583e+00,  4.7608e-01, -2.6399e-01,  1.6707e+00, -2.7737e+00,
        -3.0504e+00,  6.3489e-01,  2.7195e+00,  6.5361e+00,  1.4531e+00,
        -4.3600e-01, -2.3690e+00, -1.7195e+00, -1.2741e+00, -2.0122e+00,
        -1.5859e+00,  3.9849e-01,  2.2415e+00,  1.2183e+00, -3.0741e+00,
         1.4684e+00,  5.6325e-01,  1.8900e-01,  1.4461e+00,  3.4787e-01,
         2.2897e+00,  1.8708e+00, -5.5868e-01, -2.2019e+00,  1.2163e+00,
        -2.3448e+00,  1.4460e+00, -1.0969e+00,  1.2090e+00, -9.8263e-01,
         1.8689e-01, -2.3600e+00,  2.8903e+00,  1.2310e+00,  1.1050e+00,
         1.9051e+00, -4.9865e-01,  8.2530e-01, -2.6378e-01,  5.9296e-01,
         1.5881e+00,  3.2970e+00, -1.9274e+00, -1.2162e+00, -5.6231e-01,
         1.7530e+00, -2.2519e-01, -1.2262e+00,  7.7451e-01, -3.5798e-01,
         1.1902e+