In [1]:
import torch
import time
import torch.nn.functional as F

In [2]:
def add_new_col_to_inverse(X, old_inverse, new_col, N=None, n=None):
    """Extend a Gram-matrix inverse by one column using the block-inverse formula.

    Given ``old_inverse`` (used as the inverse of ``X.T @ X``) and a new column
    ``new_col``, assemble the bordered matrix whose blocks are built from the
    Schur complement of the new column.

    Args:
        X: 2-D tensor whose columns form the existing design matrix.
        old_inverse: square tensor, treated as ``inverse(X.T @ X)``.
        new_col: column tensor to append (used as a 2-D ``(rows, 1)`` tensor).
        N: optional tensor; when given, the top-left block becomes
            ``N - top_left`` instead of ``top_left``.
        n: value used with ``N``; the bottom-right entry becomes ``n - corner``.

    Returns:
        The bordered matrix, one row and one column larger than ``old_inverse``.
    """
    cross = torch.matmul(X.T, new_col)
    solved = torch.matmul(old_inverse, cross)
    # Reciprocal of the Schur complement of the new column.
    schur = torch.matmul(new_col.T, new_col) - torch.matmul(cross.T, solved)
    corner = 1. / schur
    border = corner * solved
    # `solved * solved.T` broadcasts a column against a row, i.e. an outer product.
    top_left = old_inverse + corner * solved * solved.T
    if N is None:
        upper = torch.cat([top_left, -border], 1)
        lower = torch.cat([-border.T, corner], 1)
    else:
        upper = torch.cat([N - top_left, -border], 1)
        lower = torch.cat([-border.T, n - corner], 1)
    return torch.cat([upper, lower], 0)

In [3]:
# Toy problem: L rows in `a`, U rows in `b`, each of width `dim`.
L, U, dim = 2, 3, 5
# NOTE(review): the second positional argument of F.normalize is `p`, not `dim`,
# so each row is scaled to unit L1 norm (dim defaults to 1). Confirm L2 was not intended.
a = F.normalize(torch.randn(L, dim), p=1)
b = F.normalize(torch.randn(U, dim), p=1)
# Pairwise dot-product scores between rows of a and b, and among rows of b.
m_ab = a @ b.T
m_ba = m_ab.T
m_bb = b @ b.T
# Scores from each row of b to every row of a (first L columns) and of b (last U columns).
m_ba_bb = torch.cat([m_ba, m_bb], dim=1)
p_ab = torch.softmax(m_ab, dim=1)

In [4]:
# Manual row-wise softmax of m_ba_bb: exponentiate, then divide each row by its sum.
exp_m_ba_bb = torch.exp(m_ba_bb)
# Per-row sum, kept as a column so it broadcasts across the row in the division below.
normaliser = exp_m_ba_bb.sum(1, keepdim=True)
exp_m_ba_bb / normaliser

tensor([[0.2019, 0.1529, 0.2330, 0.2122, 0.2001],
        [0.1795, 0.1780, 0.2086, 0.2177, 0.2162],
        [0.1745, 0.1889, 0.1975, 0.2170, 0.2220]])

In [5]:
# Row-softmax of [m_ba | m_bb], computed once and split at column L:
# the first L columns come from m_ba (Tbar_ul), the rest from m_bb (Tbar_uu).
Tbar = torch.softmax(m_ba_bb, 1)
Tbar_uu, Tbar_ul = Tbar[:, L:], Tbar[:, :L]
# Unnormalised (exp only) counterpart of Tbar_uu.
Tbar_uu_unnorm = exp_m_ba_bb[:, L:]

In [6]:
# Diagonal matrices built from the per-row softmax normaliser:
# N holds the reciprocals, Nm1 the normalisers themselves (so Nm1 is the inverse of N).
N = torch.diag(1 / normaliser[:, 0])
Nm1 = torch.diag(normaliser[:, 0])

In [7]:
# inverse(I - Tbar_uu) written with unnormalised exponentials: since
# Tbar_uu = inverse(Nm1) @ Tbar_uu_unnorm, I - Tbar_uu = inverse(Nm1) @ (Nm1 - Tbar_uu_unnorm),
# whose inverse is inverse(Nm1 - Tbar_uu_unnorm) @ Nm1.
middle_alternative = torch.mm(torch.inverse(Nm1 - Tbar_uu_unnorm), Nm1)

In [8]:
# inverse(I - Tbar_uu); the identity is sized from Tbar_uu rather than hard-coded.
middle = torch.inverse(torch.eye(Tbar_uu.shape[0]) - Tbar_uu)

In [9]:
# Chain the three factors: p_ab @ inverse(I - Tbar_uu) @ Tbar_ul.
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)
p_aba

tensor([[0.5175, 0.4825],
        [0.5154, 0.4846]])

In [10]:
# Same product using middle_alternative; the printed result matches the cell computing p_aba.
torch.matmul(torch.matmul(p_ab, middle_alternative), Tbar_ul)

tensor([[0.5175, 0.4825],
        [0.5154, 0.4846]])

In [23]:
# Random dim x U matrix, scaled by 0.3; its Gram matrix X.T @ X is U x U.
X = torch.randn([dim, U]) * 0.3

In [24]:
# New random column appended to X, and the Gram matrix of the enlarged matrix.
xnew = torch.randn([dim, 1]) * 0.3
Xnew = torch.cat([X, xnew], 1)
XtXnew = torch.mm(Xnew.T, Xnew)

In [53]:
# Extra diagonal entry (0.7) appended to the diagonal of Nm1, giving a matrix one size larger.
nm1 = torch.FloatTensor([0.7])
Nm1_new = torch.diag(torch.cat([torch.diagonal(Nm1), nm1], 0))
# Gram matrix of the original (un-extended) X.
XtX = torch.mm(X.T, X)

In [54]:
# Direct inverse of (Nm1 - XtX); the same expression is passed as `old_inverse` in later cells.
torch.inverse(Nm1 - XtX)

tensor([[0.1913, 0.0022, 0.0124],
        [0.0022, 0.1869, 0.0019],
        [0.0124, 0.0019, 0.1951]])

In [152]:
def add_new_col_to_inverse_new(X, old_inverse, new_col, N=None, n=None):
    """Experimental: compute only the scalar corner term ``d`` of a bordered inverse.

    The bordered-matrix assembly that used to follow the ``return`` was
    unreachable and referenced an undefined name, so it has been removed;
    the returned value is unchanged.

    Args:
        X: 2-D tensor with at least as many rows as ``N``.
        old_inverse: square tensor with one row/column per column of ``X``.
        new_col: column tensor with as many rows as ``X``.
        N: square tensor whose element-wise square root is stacked on top of
            zero rows to match the row count of ``X``. Required despite the
            ``None`` default (``torch.sqrt(None)`` raises).
        n: value added to ``new_col`` before the quadratic forms are taken.

    Returns:
        ``1 / (v.T @ v - v.T @ middle @ v)`` with ``v = new_col + n``, as a
        ``(1, 1)`` tensor for column-shaped inputs.
    """
    B = old_inverse
    # sqrt(N) padded with zero rows up to the row count of X.
    Nlong = torch.cat([torch.sqrt(N), torch.zeros([X.shape[0] - N.shape[0], N.shape[1]])], 0)
    X = Nlong - X
    # Note: Nlong + X here equals 2 * Nlong minus the original X, because X was just rebound.
    middle = torch.matmul(torch.matmul(Nlong + X, B), (Nlong + X).T)
    v = new_col + n
    other = torch.matmul(torch.matmul(v.T, middle), v)
    return 1. / (torch.matmul(v.T, v) - other)

# Evaluate on the toy problem, passing inverse(Nm1 - XtX) as the old inverse and nm1 as n.
add_new_col_to_inverse_new(X, torch.inverse(Nm1 - XtX), xnew, N=Nm1, n=nm1)

tensor([[-0.1303]])

In [127]:
# Display the extra diagonal entry.
nm1

tensor([0.7000])

In [163]:
def add_new_col_to_inverse_new(X, old_inverse, new_col, N=None, n=None):
    """Experimental variant returning ``n - d`` for a shifted design matrix.

    ``X`` is replaced by ``sqrt(N)`` (zero-padded to the row count of ``X``)
    minus ``X``, the new column by ``n - new_col``, and the reciprocal Schur
    complement ``d`` is formed from those with ``old_inverse``.

    Args:
        X: 2-D tensor with at least as many rows as ``N``.
        old_inverse: square tensor with one row/column per column of ``X``.
        new_col: column tensor with as many rows as ``X``.
        N: square tensor; required despite the ``None`` default.
        n: value the new column is subtracted from, and from which ``d`` is
            subtracted in the result.

    Returns:
        ``n - d`` as a ``(1, 1)`` tensor for column-shaped inputs.
    """
    root = torch.sqrt(N)
    padding = torch.zeros([X.shape[0] - N.shape[0], N.shape[1]])
    shifted_X = torch.cat([root, padding], 0) - X
    shifted_col = n - new_col
    cross = torch.matmul(shifted_X.T, shifted_col)
    solved = torch.matmul(old_inverse, cross)
    schur = torch.matmul(shifted_col.T, shifted_col) - torch.matmul(cross.T, solved)
    return n - 1. / schur

# Evaluate the redefined function on the same toy inputs.
# NOTE(review): the printed value (5.2744) differs from the bottom-right entry (1.5992)
# of the inverse(Nm1_new - XtXnew) output shown elsewhere in the notebook.
print(add_new_col_to_inverse_new(X, torch.inverse(Nm1 - XtX), xnew, N=Nm1, n=nm1))

tensor([[5.2744]])


In [189]:
# Baseline timing: direct inverse of Nm1 + X.T @ X.
%timeit torch.inverse(Nm1 + torch.matmul(X.T, X))

15.2 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [188]:
import numpy as np
def woodbury(A, U, V, k):
    """Invert ``A + U @ V`` with the Woodbury identity, for diagonal ``A``.

    Uses ``(A + U V)^-1 = A^-1 - A^-1 U (I_k + V A^-1 U)^-1 V A^-1``.

    Args:
        A: square tensor; only its diagonal is read, so it is treated as diagonal.
        U: tensor with ``A.shape[0]`` rows and ``k`` columns.
        V: tensor with ``k`` rows and ``A.shape[0]`` columns.
        k: size of the inner identity (the shared dimension of ``U`` and ``V``).

    Returns:
        Tensor of the same shape as ``A``.
    """
    A_inv = torch.diag(1. / torch.diag(A))  # Fast matrix inversion of a diagonal.
    # Build the identity with the dtype/device of the inputs so CUDA tensors work too.
    eye_k = torch.eye(k, dtype=A_inv.dtype, device=A_inv.device)
    B_inv = torch.inverse(eye_k + V @ A_inv @ U)
    return A_inv - (A_inv @ U @ B_inv @ V @ A_inv)

# Time the Woodbury route with A = Nm1, U = X.T, V = X and k = number of rows of X.
%timeit woodbury(Nm1, X.T, X, X.shape[0])

64.3 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [182]:
# Can possibly use the Woodbury identity, with A = Nm1, U = X.T, C = I, V = X:
#   (A + U C V)^-1 = A^-1 - A^-1 U (C^-1 + V A^-1 U)^-1 V A^-1
Am1 = torch.inverse(Nm1)
u = X.T
c = torch.eye(X.shape[0])
v = X

left = torch.mm(Am1, u)
# The middle factor must be inverted; c is the identity, so C^-1 = c.
middle = torch.inverse(c + torch.mm(torch.mm(v, Am1), u))
right = torch.mm(v, Am1)

Am1 - torch.mm(torch.mm(left, middle), right)

tensor([[ 0.1611, -0.0021, -0.0120],
        [-0.0021,  0.1594, -0.0018],
        [-0.0120, -0.0018,  0.1530]])

In [173]:
# NOTE(review): `N1` is not defined in any visible cell, so this cell fails on a fresh kernel.
# Computes N1 - N1 @ X.T @ (I + X @ X.T) @ X.
# NOTE(review): this looks like another Woodbury-style attempt, but the middle factor is
# not inverted and there is no trailing N1 factor — confirm the intended formula.
o = torch.mm(N1, X.T)
p = torch.eye(X.shape[0]) + torch.matmul(X, X.T)
o = torch.mm(o, p)
o = torch.mm(o, X)
N1 - o

tensor([[ 0.0475, -0.0217, -0.1195],
        [-0.0214,  0.0674, -0.0202],
        [-0.1180, -0.0203, -0.0151]])

In [167]:
# Reference: direct inverse of the enlarged matrix (Nm1_new - XtXnew).
torch.inverse(Nm1_new - XtXnew)

tensor([[0.1920, 0.0023, 0.0131, 0.0324],
        [0.0023, 0.1869, 0.0021, 0.0082],
        [0.0131, 0.0021, 0.1960, 0.0387],
        [0.0324, 0.0082, 0.0387, 1.5992]])

In [394]:
import torch
import time
import torch.nn.functional as F


# Sizes used by the inverse-timing cells, plus a column index.
n, p = 128, 4000
col = 0

# Two random embedding sets: 100 rows in `a`, 500 rows in `b`, 16 features each.
a = 0.5 * torch.randn(100, 16)
b = 0.5 * torch.randn(500, 16)

# Dot-product scores a->b, b->a and b->b.
match_ab = a @ b.T
match_ba = match_ab.T
match_bb = b @ b.T

# Row-wise softmax of each score matrix.
p_ab = torch.softmax(match_ab, dim=1)
p_bb = torch.softmax(match_bb, dim=1)
p_ba = torch.softmax(match_ba, dim=1)

In [395]:
# Joint row-softmax over [match_ba | match_bb], floored at 1e-8.
# Note: after clamping, rows need not sum exactly to 1.
match_ba_bb = torch.cat([match_ba, match_bb], dim=1)
p_ba_bb = torch.clamp(F.softmax(match_ba_bb, dim=1), min=1e-8)

In [396]:
# NOTE(review): `N` is rebound here from the diagonal matrix defined in an earlier cell
# to an integer row count; cells that use `N` as a matrix will break if re-run after this.
N = a.shape[0]
M = b.shape[0]
# First N columns come from match_ba, the remaining columns from match_bb.
Tbar_ul, Tbar_uu = p_ba_bb[:, :N], p_ba_bb[:, N:]

In [430]:
# Exact inverse(I - Tbar_uu); the trailing comment is a disabled variant that subtracts I.
middle = torch.inverse(torch.eye(Tbar_uu.shape[1]) - Tbar_uu)# - torch.eye(Tbar_uu.shape[1])

In [433]:
# Truncated Neumann series I + T + T^2 for inverse(I - T), with T = Tbar_uu,
# followed by dividing every row by the corresponding row sum of Tbar_uu.
middle = torch.eye(Tbar_uu.shape[1])
for i in (1, 2):
    middle = middle + torch.matrix_power(Tbar_uu, i)
middle = middle / Tbar_uu.sum(dim=1, keepdim=True)

In [434]:
# p_ab @ middle @ Tbar_ul, using whichever `middle` was assigned last.
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)

In [432]:
# NOTE(review): this cell's execution count (432) is lower than those of the two cells
# shown before it (433, 434), so the printed value predates them — re-run to refresh.
p_aba

tensor([[0.0115, 0.0085, 0.0108,  ..., 0.0093, 0.0088, 0.0090],
        [0.0093, 0.0101, 0.0103,  ..., 0.0093, 0.0085, 0.0092],
        [0.0094, 0.0082, 0.0150,  ..., 0.0095, 0.0084, 0.0092],
        ...,
        [0.0094, 0.0086, 0.0111,  ..., 0.0118, 0.0089, 0.0093],
        [0.0094, 0.0083, 0.0103,  ..., 0.0094, 0.0102, 0.0089],
        [0.0089, 0.0083, 0.0103,  ..., 0.0090, 0.0081, 0.0117]])

In [437]:
# Rescale every row of p_aba to sum to 1 (display only; p_aba itself is not modified).
p_aba / p_aba.sum(1, keepdim=True)

tensor([[0.0151, 0.0088, 0.0112,  ..., 0.0096, 0.0092, 0.0085],
        [0.0094, 0.0128, 0.0100,  ..., 0.0097, 0.0083, 0.0090],
        [0.0097, 0.0081, 0.0217,  ..., 0.0103, 0.0081, 0.0089],
        ...,
        [0.0096, 0.0090, 0.0118,  ..., 0.0157, 0.0092, 0.0091],
        [0.0098, 0.0083, 0.0100,  ..., 0.0099, 0.0126, 0.0081],
        [0.0084, 0.0082, 0.0101,  ..., 0.0090, 0.0074, 0.0153]])

In [165]:
# Score matrices between and within the two embedding sets.
match_ab = a @ b.T
match_ba = match_ab.T
match_bb = b @ b.T

# Row-wise softmax of each score matrix.
p_ab = torch.softmax(match_ab, dim=1)
p_bb = F.softmax(match_bb, dim=1)
p_ba = torch.softmax(match_ba, dim=1)

# Joint softmax over [match_ba | match_bb], floored at 1e-8, then split at column N.
match_ba_bb = torch.cat([match_ba, match_bb], dim=1)
p_ba_bb = F.softmax(match_ba_bb, dim=1).clamp(min=1e-8)
N, M = a.shape[0], b.shape[0]
Tbar_ul, Tbar_uu = p_ba_bb[:, :N], p_ba_bb[:, N:]

# Identity on the same device as Tbar_uu; only the commented-out variants below use it.
I = torch.eye(M)
if Tbar_uu.is_cuda:
    I = I.cuda()
# middle = torch.inverse(I - Tbar_uu + 1e-8)
# Active variant: inverse of the raw b-b score matrix with 1e-8 added to every entry.
middle = torch.inverse(match_bb + 1e-8)
# middle = torch.inverse(I - p_bb + 1e-8)
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)
# Rescale every row to sum to 1.
p_aba = p_aba / p_aba.sum(dim=1, keepdim=True)
# p_aba = torch.matmul(torch.matmul(p_ab, p_bb), p_ba)

In [162]:
# Row sums of p_aba.
# NOTE(review): the printed output has 10 entries, while `a` has 100 rows in the visible
# setup cell — the output presumably comes from an earlier run with a smaller `a`.
p_aba.sum(1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [163]:
# Display the current `middle` matrix.
middle

tensor([[ 1.2201, -0.3730,  0.6471,  ..., -0.3973, -0.3934, -0.2369],
        [-2.3164,  0.3392, -4.5347,  ...,  0.9172,  2.8041, -0.4114],
        [ 3.0603,  0.5993,  3.7068,  ..., -0.6661, -5.3733,  2.5896],
        ...,
        [ 2.4616,  0.9285,  3.8963,  ..., -0.7809, -5.4913,  1.0392],
        [-1.4368, -0.3912, -3.5948,  ...,  0.0509,  3.4498, -0.8072],
        [-1.6878, -0.0854, -2.4238,  ...,  0.5706,  0.5472, -1.4204]])

In [128]:
# NOTE(review): stale output — the printed entries (about 6.09 everywhere) are not
# row-normalised, and the execution count (128) is out of sequence with nearby cells.
p_aba

tensor([[6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        ...,
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888]])

In [123]:
# Display the b-to-b softmax matrix.
p_bb

tensor([[0.0003, 0.0002, 0.0003,  ..., 0.0003, 0.0002, 0.0002],
        [0.0002, 0.0003, 0.0003,  ..., 0.0002, 0.0003, 0.0003],
        [0.0003, 0.0003, 0.0003,  ..., 0.0002, 0.0002, 0.0003],
        ...,
        [0.0003, 0.0002, 0.0002,  ..., 0.0003, 0.0002, 0.0003],
        [0.0002, 0.0003, 0.0002,  ..., 0.0002, 0.0003, 0.0002],
        [0.0002, 0.0003, 0.0003,  ..., 0.0003, 0.0002, 0.0003]])

In [None]:
# Display the a-to-b softmax matrix (this cell has not been executed: no count, no output).
p_ab

In [113]:
# NOTE(review): the printed `middle` (1.01 on the diagonal, 0.01 elsewhere) differs from
# the other `middle` output in this notebook; it comes from an earlier kernel state.
middle

tensor([[1.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 1.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 1.0100,  ..., 0.0100, 0.0100, 0.0100],
        ...,
        [0.0100, 0.0100, 0.0100,  ..., 1.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 1.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 1.0100]])

In [114]:
# Display the first-N-columns slice of the joint softmax.
Tbar_ul

tensor([[0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        ...,
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002]])

In [115]:
# Display p_aba; in the printed output every visible entry is 0.0100.
p_aba

tensor([[0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        ...,
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100]])

In [96]:
# Row-normalise p_aba for display.
# NOTE(review): the printed output is 2x2, so it presumably belongs to the toy problem
# (L = 2) rather than the 100-row setup shown nearby.
(p_aba / p_aba.sum(1, keepdim=True))

tensor([[0.5052, 0.4948],
        [0.5018, 0.4982]])

In [55]:
# Two-step round trip a -> b -> a: product of the row-softmaxes of match_ab and its transpose.
match_ab = a @ b.T
p_ab = torch.softmax(match_ab, dim=1)
p_ba = torch.softmax(match_ab.T, dim=1)
p_aba = p_ab @ p_ba

In [56]:
# Row sums of p_aba; a product of two row-stochastic matrices has rows summing to 1.
p_aba.sum(1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [139]:
# Time a dense inverse of the Gram matrix X.T @ X (p x p).
# Needs `n` and `p` from the setup cell; the recorded NameError came from running this
# cell before that one.
# NOTE(review): with n = 128 and p = 4000 the Gram matrix has rank at most n, so it is
# singular and torch.inverse may fail or return meaningless values — confirm the sizes.
X = torch.randn([n, p]) * 0.5
# X = torch.softmax(X, 0)
B = torch.matmul(X.T, X)
if torch.cuda.is_available():
    # Round trip through the GPU, kept from the original; skipped on CPU-only machines
    # where the unconditional .cuda() call would raise.
    B = B.cuda()
    B = B.cpu()
now = time.time()
B = torch.inverse(B)
time.time() - now

NameError: name 'n' is not defined

In [7]:
# Random column with n rows, bound to both `new_col` and `v`.
new_col = v = torch.randn([n, 1]) * 0.5
# new_col = v = torch.softmax(torch.randn([n, 1]) * 0.5, 0)

In [8]:
def add_new_col_to_inverse(X, old_inverse, new_col):
    """Update ``inverse(X.T @ X)`` after appending ``new_col`` as a column of ``X``.

    Applies the block-matrix inverse formula, so only matrix-vector products
    with the existing inverse are needed instead of a fresh inversion.

    Args:
        X: 2-D tensor holding the existing columns.
        old_inverse: square tensor, treated as ``inverse(X.T @ X)``.
        new_col: column tensor ``(rows, 1)`` to append.

    Returns:
        The inverse for ``[X, new_col]``, one row and one column larger than
        ``old_inverse``.
    """
    projected = torch.matmul(X.T, new_col)
    weights = torch.matmul(old_inverse, projected)
    # Schur complement of the new column and its reciprocal (bottom-right entry).
    complement = torch.matmul(new_col.T, new_col) - torch.matmul(projected.T, weights)
    corner = 1. / complement
    edge = corner * weights
    # Column-times-row broadcast gives the outer product of `weights` with itself.
    block = old_inverse + corner * weights * weights.T
    top = torch.cat([block, -edge], 1)
    bottom = torch.cat([-edge.T, corner], 1)
    return torch.cat([top, bottom], 0)

In [9]:
# Time the incremental update. The function takes (X, old_inverse, new_col);
# the previous two-argument call would raise TypeError with the definition in this notebook.
now = time.time()
add_new_col_to_inverse(X, B, v)
time.time() - now

0.338165283203125

In [29]:
# Reference: append the column and invert the enlarged Gram matrix directly.
# NOTE(review): the printed result is 4x4, so X had 3 columns when this ran, not p = 4000.
Xnew = torch.cat([X, v], 1)
torch.inverse(torch.matmul(Xnew.T, Xnew))

tensor([[ 1.2506, -0.6256,  0.1035,  1.0205],
        [-0.6256,  0.8223,  0.1486, -0.5141],
        [ 0.1035,  0.1486,  0.4402,  0.3598],
        [ 1.0205, -0.5141,  0.3598,  2.0012]])