In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Seed the global RNG so that M and every later random draw in this notebook
# (torch.rand, distribution sampling, layer initialisation) is reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)
# Random 2x2 factor used to build the quadratic-form matrix Q.
M = torch.randn(2, 2)

In [None]:
# Gram matrix M M^T: symmetric positive semi-definite by construction,
# so -Q (as used in the utility u) yields a concave quadratic form.
Q = torch.mm(M, M.t())

In [None]:
def u(a1, a2):
    """Quadratic utility of the joint action plus a linear bonus on ``a1``.

    Computes ``(a1 + a2) @ (-Q) @ (a1 + a2) + [0, 1] @ a1`` where ``Q`` is the
    module-level matrix.

    Args:
        a1: 1-D tensor of length 2; the action that also receives the linear term.
        a2: 1-D tensor of length 2; the other action.

    Returns:
        A 0-dim tensor holding the utility value.
    """
    # `.T` on 1-D tensors is a deprecated no-op in PyTorch, so the sums are
    # used directly; for 1-D inputs the result is numerically unchanged.
    joint = a1 + a2
    linear_weights = torch.tensor([0., 1.])
    return joint @ (-Q) @ joint + linear_weights @ a1

In [None]:
# Fixed action of the other player; a constant, so no gradient tracking
# (requires_grad defaults to False).
a2 = torch.tensor([2.0, 0.5])

In [None]:
# Monotonicity check for the FIRST argument of u, with the other action fixed
# at a2: for random pairs (a1, b1) record
#     <grad u(a1, a2) - grad u(b1, a2), a1 - b1>.
# With a concave quadratic (-Q) this should be <= 0 up to rounding.
inners = []
for trial in range(10000):
    # Fresh leaf tensors every iteration, so their .grad starts out as None.
    a1 = torch.rand(2, requires_grad=True)
    b1 = torch.rand(2, requires_grad=True)
    # Named util_* so the temporaries do not share a name with the u1/u2
    # functions defined further down the notebook.
    util_a = u(a1, a2)
    util_b = u(b1, a2)
    util_a.backward()
    util_b.backward()
    ga = a1.grad
    gb = b1.grad
    # torch.dot replaces the deprecated 1-D `.T @`; detaching and calling
    # item() stores a plain float instead of a graph-attached tensor.
    inners.append(torch.dot(ga - gb, (a1 - b1).detach()).item())

# Summarise before `inners` is reused by the next experiment.
inner_products = torch.tensor(inners)
inner_products.max(), inner_products.mean()

In [None]:
# Monotonicity check for the SECOND argument of u, with the first action fixed
# at a2: for random pairs (a1, b1) record
#     <grad u(a2, a1) - grad u(a2, b1), a1 - b1>.
# With a concave quadratic (-Q) this should be <= 0 up to rounding.
inners = []
for trial in range(10000):
    # Fresh leaf tensors every iteration, so their .grad starts out as None.
    a1 = torch.rand(2, requires_grad=True)
    b1 = torch.rand(2, requires_grad=True)
    # Named util_* so the temporaries do not share a name with the u1/u2
    # functions defined further down the notebook.
    util_a = u(a2, a1)
    util_b = u(a2, b1)
    util_a.backward()
    util_b.backward()
    ga = a1.grad
    gb = b1.grad
    # torch.dot replaces the deprecated 1-D `.T @`; detaching and calling
    # item() stores a plain float instead of a graph-attached tensor.
    inners.append(torch.dot(ga - gb, (a1 - b1).detach()).item())

# Build the result tensor once and show its largest and average value.
inner_products = torch.tensor(inners)
inner_products.max(), inner_products.mean()

In [None]:
class NN(nn.Module):
    """Two-layer perceptron with ReLU on both the hidden and the output layer.

    Because the output also passes through a ReLU, every output is >= 0.

    Args:
        n_in: number of input features.
        n_out: number of output features.
        n_hidden: width of the hidden layer (defaults to 10, the previously
            hard-coded value).
    """

    def __init__(self, n_in, n_out, n_hidden=10):
        super().__init__()

        # Registration order matters: code elsewhere picks the last child
        # module as "the last layer", so `out` must be registered after `fc`.
        self.fc = nn.Linear(n_in, n_hidden)
        self.out = nn.Linear(n_hidden, n_out, bias=False)

    def forward(self, v):
        """Map input ``v`` (last dimension ``n_in``) to a non-negative output."""
        hidden = F.relu(self.fc(v))
        return F.relu(self.out(hidden))

In [None]:
# Accumulators for the per-sample inner products: one for each network pair
# and one for the concatenated vectors of both pairs.
inners1, inners2, inners = [], [], []

def u1(v1, a, b):
    """Return ``b.sum() - (a . a) * v1``.

    ``a`` enters through its squared norm scaled by ``v1``; ``b`` enters only
    through the sum of its entries.
    """
    scaled_sq_norm = torch.dot(a, a) * v1
    return b.sum() - scaled_sq_norm

def u2(v2, a, b):
    """Return ``a.sum() - (b . b) * v2``.

    ``b`` enters through its squared norm scaled by ``v2``; ``a`` enters only
    through the sum of its entries.
    """
    scaled_sq_norm = torch.dot(b, b) * v2
    return a.sum() - scaled_sq_norm


# Valuations are sampled uniformly between 0 and 10.
dist = torch.distributions.Uniform(0, 10)
# Four networks with the same architecture but independently initialised
# parameters, constructed in the order pi11, pi12, pi21, pi22.
pi11, pi12, pi21, pi22 = (NN(1, 1) for _ in range(4))


def last_layer_params(model):
    """Return an iterator over the parameters of the model's last child module."""
    return list(model.children())[-1].parameters()


def _last_layer_vector(model):
    """Flatten the model's last-layer parameters into one detached 1-D tensor."""
    return nn.utils.parameters_to_vector(last_layer_params(model)).detach()


def _last_layer_grad(model):
    """Flatten the gradients currently stored on the model's last-layer parameters."""
    return torch.cat(tuple(p.grad.flatten() for p in last_layer_params(model)))


# No optimizer step is taken below, so the network parameters stay fixed;
# only the sampled valuations (and hence the gradients) change per iteration.
for e in range(1000):
    # reset grads so each iteration's backward passes start from a clean slate
    for model in [pi11, pi12, pi21, pi22]:
        for p in model.parameters():
            p.grad = None
    # choose random valuation
    v1, v2 = dist.sample([2, 1])

    a1 = pi11(v1)
    a2 = pi12(v1)
    b1 = pi21(v2)
    b2 = pi22(v2)

    # The other network's output is detached, so each utility only sends
    # gradients into its own network.
    u11 = u1(v1, a1, b1.detach())
    u12 = u1(v1, a2, b2.detach())
    u21 = u2(v2, a1.detach(), b1)
    u22 = u2(v2, a2.detach(), b2)

    for utility in (u11, u12, u21, u22):
        utility.backward()

    # Only the last layer of each network is compared; swapping
    # last_layer_params(model) for model.parameters() would use all layers.
    t11 = _last_layer_vector(pi11)
    t12 = _last_layer_vector(pi12)
    t21 = _last_layer_vector(pi21)
    t22 = _last_layer_vector(pi22)
    g11 = _last_layer_grad(pi11)
    g12 = _last_layer_grad(pi12)
    g21 = _last_layer_grad(pi21)
    g22 = _last_layer_grad(pi22)

    # Stack the vectors of the (pi11, pi21) set and of the (pi12, pi22) set.
    t1 = torch.cat((t11, t21))
    t2 = torch.cat((t12, t22))
    g1 = torch.cat((g11, g21))
    g2 = torch.cat((g12, g22))

    # <gradient difference, parameter difference> per pair and for the stack.
    inners1.append(torch.dot(g12 - g11, t12 - t11))
    inners2.append(torch.dot(g22 - g21, t22 - t21))
    inners.append(torch.dot(g2 - g1, t2 - t1))

    
# Turn the per-iteration scalars into 1-D tensors for easy statistics.
inners1 = torch.tensor(inners1)
inners2 = torch.tensor(inners2)
inners = torch.tensor(inners)

print('inner products, min, mean, max')
# A plain loop instead of a list comprehension: printing is a side effect,
# and the comprehension made the cell also display `[None, None, None]`.
for values in (inners1, inners2, inners):
    print(torch.tensor([values.min(), values.mean(), values.max()]))


In [None]:
# Inner product of pi21's last-layer gradient with the parameter gap t22 - t21.
# torch.dot replaces `.T @`, since `.T` on a 1-D tensor is deprecated.
torch.dot(g21, t22 - t21)

In [None]:
# Output gap between pi22 and pi21 at the most recently sampled v1.
torch.sub(pi22(v1), pi21(v1))

In [None]:
# All parameters of pi11 flattened into a single detached 1-D tensor.
torch.cat([p.detach().reshape(-1) for p in pi11.parameters()])

In [None]:
# Flattened, detached last-layer parameters of pi11 (from the final loop iteration).
t11

In [None]:
# Flattened last-layer gradient of pi11 (from the final loop iteration).
g11

In [None]:
# Inspect pi11's last-layer parameters. The original referenced the undefined
# name `p11` (NameError); the iterator is wrapped in list() so the cell shows
# the parameters rather than a generator repr.
list(last_layer_params(pi11))

In [None]:
# Same inspection without the helper. `children()` returns a generator, which
# cannot be subscripted, so it is materialised with list() before indexing.
list(list(pi11.children())[-1].parameters())

In [None]:
# First recorded inner product for the pi21 / pi22 pair.
inners2[0]

In [None]:
# First recorded inner product over the concatenated vectors of both pairs.
inners[0]

In [None]:
# NOTE(review): `check` is not defined in any cell visible here; on a fresh
# kernel this raises NameError unless it is created elsewhere — confirm or remove.
check[0]

In [None]:
# Gradient gaps: per network pair, then for the concatenated vectors.
(torch.sub(g12, g11), torch.sub(g22, g21), torch.sub(g2, g1))

In [None]:
# Parameter gaps: per network pair, then for the concatenated vectors.
(torch.sub(t12, t11), torch.sub(t22, t21), torch.sub(t2, t1))