In [1]:
import torch
import time
import torch.nn.functional as F

In [2]:
def add_new_col_to_inverse(X, old_inverse, new_col, N=None, n=None):
    B = old_inverse
    v = new_col
    u1 = torch.matmul(X.T, v)
    u2 = torch.matmul(B, u1)
    F22inv = 1. / (torch.matmul(v.T, v) - torch.matmul(u1.T, u2))
    u3 = F22inv * u2
    F11inv = B + F22inv * u2 * u2.T
    if N is not None:
        Bnew = torch.cat([
            torch.cat([N - F11inv, -u3], 1),
            torch.cat([-u3.T, n - F22inv], 1)
        ], 0)
    else:
        Bnew = torch.cat([
            torch.cat([F11inv, -u3], 1),
            torch.cat([-u3.T, F22inv], 1)
        ], 0)
    return Bnew

In [3]:
L = 2
U = 3
dim = 5
a = torch.randn(L, dim)
b = torch.randn(U, dim)
a, b = [F.normalize(i, 1) for i in [a, b]]
m_ab = torch.mm(a, b.T)
m_ba = m_ab.T
m_bb = torch.mm(b, b.T)
m_ba_bb = torch.cat([m_ba, m_bb], 1)
p_ab = torch.softmax(m_ab, 1)

In [4]:
exp_m_ba_bb = torch.exp(m_ba_bb)
normaliser = exp_m_ba_bb.sum(1, keepdim=True)
exp_m_ba_bb / normaliser

tensor([[0.2018, 0.2061, 0.2307, 0.2076, 0.1539],
        [0.2065, 0.2021, 0.2111, 0.2300, 0.1504],
        [0.1915, 0.1992, 0.1769, 0.1700, 0.2624]])

In [5]:
Tbar_uu, Tbar_ul = torch.softmax(m_ba_bb, 1)[:, 2:], torch.softmax(m_ba_bb, 1)[:, :2]
Tbar_uu_unnorm = exp_m_ba_bb[:, 2:]

In [6]:
N = torch.diag(1 / normaliser[:, 0])
Nm1 = torch.diag(normaliser[:, 0])

In [7]:
middle_alternative = torch.mm(torch.inverse(Nm1 - Tbar_uu_unnorm), Nm1)

In [8]:
middle = torch.inverse(torch.eye(3) - Tbar_uu)

In [9]:
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)
p_aba

tensor([[0.4971, 0.5029],
        [0.4971, 0.5029]])

In [10]:
torch.matmul(torch.matmul(p_ab, middle_alternative), Tbar_ul)

tensor([[0.4971, 0.5029],
        [0.4971, 0.5029]])

In [29]:
X = torch.randn([dim, U]) * 0.3

In [30]:
xnew = torch.randn([dim, 1]) * 0.3
Xnew = torch.cat([X, xnew], 1)
XtXnew = torch.mm(Xnew.T, Xnew)

In [31]:
nm1 = torch.FloatTensor([0.7])
Nm1_new = torch.diag(torch.cat([torch.diagonal(Nm1), nm1], 0))
XtX = torch.mm(X.T, X)

In [32]:
torch.inverse(Nm1 - XtX)

tensor([[ 0.2041, -0.0121, -0.0103],
        [-0.0121,  0.1979,  0.0036],
        [-0.0103,  0.0036,  0.2216]])

In [18]:
X = torch.randn([128, 4096]) * 0.1
# X = torch.randn([4, 129])
X = F.normalize(X)
XtX = torch.matmul(X.T, X)
# XtXnew = torch.matmul(X[:, 1:].T, X[:, 1:])
eXtX = torch.exp(XtX)
eXtX -= eXtX.min()
# eXtXnew = torch.exp(XtXnew)

In [24]:
def torch_inv():
    return torch.inverse(eXtX)

%timeit torch_inv()

579 ms ± 7.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [31]:
def torch_inv2():
    u = torch.cholesky(eXtX)
    return torch.cholesky_inverse(u)
%timeit torch_inv2()

470 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%timeit u = torch.cholesky(eXtX)

180 ms ± 7.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
torch.mm(u, u.T)

tensor([[0.0413, 0.0183, 0.0127,  ..., 0.0157, 0.0157, 0.0126],
        [0.0183, 0.0516, 0.0125,  ..., 0.0111, 0.0166, 0.0144],
        [0.0127, 0.0125, 0.0454,  ..., 0.0168, 0.0142, 0.0184],
        ...,
        [0.0157, 0.0111, 0.0168,  ..., 0.0465, 0.0145, 0.0170],
        [0.0157, 0.0166, 0.0142,  ..., 0.0145, 0.0436, 0.0137],
        [0.0126, 0.0144, 0.0184,  ..., 0.0170, 0.0137, 0.0439]])

In [26]:

torch.potrf(eXtX)

AttributeError: module 'torch' has no attribute 'potrf'

In [25]:
import scipy.linalg

M = eXtX.numpy()
def scipy_inv(M):
    return np.linalg.inv(M)
    zz , _ = scipy.linalg.lapack.dpotrf(M, False, False)
    inv_M , info = scipy.linalg.lapack.dpotri(zz)
    return inv_M
%timeit scipy_inv(M)

2.94 s ± 90.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
inv_M

array([[ 5.16016609e+03, -7.17094981e+01,  3.67227918e+01, ...,
        -1.39174757e+01,  9.63197227e+00,  3.24160249e+01],
       [ 1.82854533e-02,  2.65514458e+03, -9.94787787e+00, ...,
        -3.58465863e+01,  1.19178624e+01,  2.36879707e+01],
       [ 1.26858354e-02,  1.24610066e-02,  3.53307662e+03, ...,
         8.30143124e+01,  5.84684897e+01, -1.16068985e+02],
       ...,
       [ 1.57354474e-02,  1.11050010e-02,  1.68003440e-02, ...,
         3.70626661e+03,  1.43761181e+01,  2.37045821e+01],
       [ 1.57182813e-02,  1.65951848e-02,  1.41780972e-02, ...,
         1.44798160e-02,  4.38297233e+03,  2.60644095e+01],
       [ 1.25547051e-02,  1.44284368e-02,  1.84476972e-02, ...,
         1.70407891e-02,  1.36840343e-02,  4.21725737e+03]])

In [22]:
import numpy as np

np.linalg.cholesky(M)

array([[ 2.0313248e-01,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 9.0017378e-02,  2.0856513e-01,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 6.2451042e-02,  3.2792285e-02,  2.0104480e-01, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       ...,
       [ 7.7463962e-02,  1.9811068e-02,  5.6270979e-02, ...,
         1.6426392e-02,  0.0000000e+00,  0.0000000e+00],
       [ 7.7379458e-02,  4.6171136e-02,  3.8954563e-02, ...,
        -5.3331343e-05,  1.5105100e-02,  0.0000000e+00],
       [ 6.1805502e-02,  4.2504072e-02,  6.5627530e-02, ...,
        -9.2000715e-05, -9.3355819e-05,  1.5398731e-02]], dtype=float32)

In [29]:
inv_M

array([[ 3.51934044e+22,  5.62912448e+22,  3.32529420e+22, ...,
        -1.55015399e+05,  1.90540316e+05, -6.83050750e+05],
       [-7.57587026e-04,  9.00368835e+22,  5.31875047e+22, ...,
        -2.48000540e+05,  3.04861960e+05, -1.09240252e+06],
       [ 1.96077349e-03,  6.69133267e-04,  3.14194748e+22, ...,
        -1.46437006e+05,  1.79968692e+05, -6.45486137e+05],
       ...,
       [-1.49789630e-04, -3.62291816e-03,  6.04315009e-03, ...,
         7.83422954e+02, -3.26315742e+01, -1.37612101e+02],
       [-1.92725344e-03, -3.70545220e-03, -4.79730405e-03, ...,
         1.21067581e-03,  8.82375437e+02,  2.54072718e+01],
       [-1.81561953e-03, -4.84478893e-03, -1.69411406e-03, ...,
         4.36846633e-03, -7.56252208e-04,  1.13137123e+03]])

In [31]:
zz

array([[ 1.74134039e-01, -4.35059699e-03,  1.12601390e-02, ...,
        -1.49789630e-04, -1.92725344e-03, -1.81561953e-03],
       [-7.57587026e-04,  1.72435779e-01,  4.16457418e-03, ...,
        -3.62291816e-03, -3.70545220e-03, -4.84478893e-03],
       [ 1.96077349e-03,  6.69133267e-04,  1.72395117e-01, ...,
         6.04315009e-03, -4.79730405e-03, -1.69411406e-03],
       ...,
       [-1.49789630e-04, -3.62291816e-03,  6.04315009e-03, ...,
         3.61386612e-02,  1.21067581e-03,  4.36846633e-03],
       [-1.92725344e-03, -3.70545220e-03, -4.79730405e-03, ...,
         1.21067581e-03,  3.36754769e-02, -7.56252208e-04],
       [-1.81561953e-03, -4.84478893e-03, -1.69411406e-03, ...,
         4.36846633e-03, -7.56252208e-04,  2.97301728e-02]])

In [258]:
X = torch.randn([4096, 128]) * 0.05
def old(X):
    return torch.inverse(torch.exp(torch.matmul(X, X.T)))

%timeit old(X)

705 ms ± 76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [259]:
old(X).shape

torch.Size([4096, 4096])

In [None]:
def new(X):
    
    symeig = torch.symeig(torch.exp(torch.matmul(X, X.T)), True)

In [260]:
torch.symeig(torch.exp(torch.matmul(X, X.T)), True)

torch.return_types.symeig(
eigenvalues=tensor([8.9242e-03, 9.1426e-03, 9.1826e-03,  ..., 1.4072e+01, 1.4283e+01,
        4.0980e+03]),
eigenvectors=tensor([[-1.2142e-02,  1.0688e-02, -4.1807e-03,  ...,  1.0539e-02,
          1.1228e-02,  1.5622e-02],
        [ 5.4051e-03, -1.8964e-03,  3.7369e-03,  ...,  1.0931e-02,
         -2.3277e-02,  1.5624e-02],
        [ 3.8326e-03, -6.3312e-03, -7.8157e-03,  ..., -2.2687e-02,
         -2.0749e-02,  1.5607e-02],
        ...,
        [ 3.8040e-03, -2.9918e-03, -8.4959e-03,  ...,  3.2905e-04,
          1.9703e-02,  1.5632e-02],
        [-2.8663e-03, -1.0480e-03,  7.3053e-03,  ..., -6.7421e-03,
         -2.5815e-02,  1.5624e-02],
        [ 1.3776e-02,  7.8041e-03, -7.0944e-05,  ..., -1.7609e-02,
         -1.6254e-02,  1.5626e-02]]))

In [261]:
%timeit torch.symeig(torch.exp(torch.matmul(X, X.T)), True)

10.9 s ± 314 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
torch.inverse()

In [None]:
c = 1.0 / (1.0 + torch.mm(u.T, torch.mm(old_inv, v.T)))
new_inv = 

In [245]:
u.shape

torch.Size([128])

In [209]:
eXtX

tensor([[1.0224, 0.9655, 1.0029,  ..., 0.9941, 1.0127, 0.9819],
        [0.9655, 1.0879, 1.0113,  ..., 1.0133, 0.9849, 1.0195],
        [1.0029, 1.0113, 1.0152,  ..., 0.9897, 1.0153, 0.9953],
        ...,
        [0.9941, 1.0133, 0.9897,  ..., 1.0278, 0.9684, 0.9949],
        [1.0127, 0.9849, 1.0153,  ..., 0.9684, 1.0607, 1.0039],
        [0.9819, 1.0195, 0.9953,  ..., 0.9949, 1.0039, 1.0234]])

In [210]:
vals, vecs = torch.symeig(eXtX, True)

In [213]:
XtX0 = torch.matmul(X[:, :-1].T, X[:, :-1])
eXtX0 = torch.exp(XtX0)
vals0, vecs0 = torch.symeig(eXtX0, True)

In [193]:
goal = torch.inverse(torch.mm(torch.mm(vecs, torch.diag(vals)), vecs.T))
goal[-1, -1]

tensor(2566537.)

In [194]:
known = torch.inverse(torch.mm(torch.mm(vecs0, torch.diag(vals0)), vecs0.T))

In [None]:
v

In [None]:
# %timeit torch.inverse(eXtX)

In [116]:
torch.all(XtX == XtX.T)

tensor(True)

In [132]:
torch.mm(torch.mm(vecs, torch.diag(vals)), vecs.T)

tensor([[1.0348, 0.9936, 0.9993,  ..., 1.0011, 0.9995, 0.9995],
        [0.9936, 1.0373, 0.9996,  ..., 1.0034, 0.9957, 1.0048],
        [0.9993, 0.9996, 1.0306,  ..., 1.0016, 0.9989, 1.0001],
        ...,
        [1.0011, 1.0034, 1.0016,  ..., 1.0285, 0.9987, 0.9958],
        [0.9995, 0.9957, 0.9989,  ..., 0.9987, 1.0331, 0.9967],
        [0.9995, 1.0048, 1.0001,  ..., 0.9958, 0.9967, 1.0325]])

In [137]:
torch.mm(torch.mm(vecs, torch.diag(vals)), vecs.T)

tensor([[1.0348, 0.9936, 0.9993,  ..., 1.0011, 0.9995, 0.9995],
        [0.9936, 1.0373, 0.9996,  ..., 1.0034, 0.9957, 1.0048],
        [0.9993, 0.9996, 1.0306,  ..., 1.0016, 0.9989, 1.0001],
        ...,
        [1.0011, 1.0034, 1.0016,  ..., 1.0285, 0.9987, 0.9958],
        [0.9995, 0.9957, 0.9989,  ..., 0.9987, 1.0331, 0.9967],
        [0.9995, 1.0048, 1.0001,  ..., 0.9958, 0.9967, 1.0325]])

In [50]:
f = torch.exp(XtX - XtX.min())
f / f.sum(1, keepdim=True)

tensor([[0.5380, 0.2225, 0.2395],
        [0.2297, 0.4388, 0.3315],
        [0.2511, 0.3366, 0.4123]])

In [51]:
XtX - XtX.min()

tensor([[0.8831, 0.0000, 0.0738],
        [0.0000, 0.6472, 0.3668],
        [0.0738, 0.3668, 0.5696]])

In [49]:
torch.softmax(XtX, 1)

tensor([[0.5380, 0.2225, 0.2395],
        [0.2297, 0.4388, 0.3315],
        [0.2511, 0.3366, 0.4123]])

In [46]:
torch.inverse(torch.exp(XtX))

tensor([[ 0.7726, -0.1278, -0.3662],
        [-0.1278,  1.8606, -1.4413],
        [-0.3662, -1.4413,  2.1615]])

In [37]:
(torch.inverse(XtX))

tensor([[4.3831, 3.1722, 2.7911],
        [3.1722, 5.3071, 1.2520],
        [2.7911, 1.2520, 5.6468]])

In [152]:
def add_new_col_to_inverse_new(X, old_inverse, new_col, N=None, n=None):
    B = old_inverse
    v = new_col
    # Decomposition of diagonal matrix N
    Nlong = torch.cat([torch.sqrt(N), torch.zeros([X.shape[0] - N.shape[0], N.shape[1]])], 0)
    X = Nlong - X
    
    
    # d = 1. / (torch.matmul(v.T, v) - torch.matmul(u1.T, u2))
    middle = torch.matmul(torch.matmul(Nlong + X, (B)), (Nlong + X).T)
#     print(middle.shape, N.shape, B.shape)
    v = v + n
    other = torch.matmul(torch.matmul(v.T, middle), v)
    d = 1. / (torch.matmul(v.T, v) - other)
    return d
    u3 = d * u2
    F11inv = B + d * u2 * u2.T
    Bnew = torch.cat([
        torch.cat([F11inv, -u3], 1),
        torch.cat([-u3.T, d], 1)
    ], 0)
    return Bnew

add_new_col_to_inverse_new(X, torch.inverse(Nm1 - XtX), xnew, N=Nm1, n=nm1)

tensor([[-0.1303]])

In [127]:
nm1

tensor([0.7000])

In [163]:
def add_new_col_to_inverse_new(X, old_inverse, new_col, N=None, n=None):
    B = old_inverse
    v = new_col
    # Decomposition of diagonal matrix N
    Nlong = torch.cat([torch.sqrt(N), torch.zeros([X.shape[0] - N.shape[0], N.shape[1]])], 0)
    X = (Nlong - X)
    v = n - v
    u1 = torch.matmul(X.T, v)
    u2 = torch.matmul(B, u1)
    d = 1. / (torch.matmul(v.T, v) - torch.matmul(u1.T, u2))
    return n - d

print(add_new_col_to_inverse_new(X, torch.inverse(Nm1 - XtX), xnew, N=Nm1, n=nm1))

tensor([[5.2744]])


In [189]:
%timeit torch.inverse(Nm1 + torch.matmul(X.T, X))

15.2 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [188]:
import numpy as np
def woodbury(A, U, V, k):
    A_inv = torch.diag(1. / torch.diag(A))  # Fast matrix inversion of a diagonal.
    B_inv = torch.inverse(torch.eye(k) + V @ A_inv @ U)
    return A_inv - (A_inv @ U @ B_inv @ V @ A_inv)

%timeit woodbury(Nm1, X.T, X, X.shape[0])

64.3 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [182]:
# Can possible use Woodbury identity..
Am1 = torch.inverse(Nm1)
u = X.T
c = torch.eye(X.shape[0])
v = X

left = torch.mm(Am1, u)
middle = (c + torch.mm(torch.mm(v, Am1), u))
right = torch.mm(v, Am1)

Am1 - torch.mm(torch.mm(left, middle), right)

tensor([[ 0.1611, -0.0021, -0.0120],
        [-0.0021,  0.1594, -0.0018],
        [-0.0120, -0.0018,  0.1530]])

In [173]:
o = torch.mm(N1, X.T)
p = torch.eye(X.shape[0]) + torch.matmul(X, X.T)
o = torch.mm(o, p)
o = torch.mm(o, X)
N1 - o

tensor([[ 0.0475, -0.0217, -0.1195],
        [-0.0214,  0.0674, -0.0202],
        [-0.1180, -0.0203, -0.0151]])

In [167]:
torch.inverse(Nm1_new - XtXnew)

tensor([[0.1920, 0.0023, 0.0131, 0.0324],
        [0.0023, 0.1869, 0.0021, 0.0082],
        [0.0131, 0.0021, 0.1960, 0.0387],
        [0.0324, 0.0082, 0.0387, 1.5992]])

In [394]:
import torch
import time
import torch.nn.functional as F


n, p = 128, 4000
col = 0

a = torch.randn(100, 16) * 0.5
b = torch.randn(500, 16) * 0.5

match_ab = torch.matmul(a, torch.t(b))
match_ba = torch.t(match_ab)
match_bb = torch.matmul(b, torch.t(b))

p_ab = F.softmax(match_ab, dim=1)
p_bb = F.softmax(match_bb, dim=1)
p_ba = F.softmax(match_ba, dim=1)

In [395]:
match_ba_bb = torch.cat([match_ba, match_bb], dim=1)
p_ba_bb = torch.clamp(F.softmax(match_ba_bb, dim=1), min=1e-8)

In [396]:
N = a.shape[0]
M = b.shape[0]
Tbar_ul, Tbar_uu = p_ba_bb[:, :N], p_ba_bb[:, N:]

In [430]:
middle = torch.inverse(torch.eye(Tbar_uu.shape[1]) - Tbar_uu)# - torch.eye(Tbar_uu.shape[1])

In [433]:
middle = torch.eye(Tbar_uu.shape[1])
for i in range(1, 3):
    middle += torch.matrix_power(Tbar_uu, i)
middle /= Tbar_uu.sum(1, keepdim=True)

In [434]:
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)

In [432]:
p_aba

tensor([[0.0115, 0.0085, 0.0108,  ..., 0.0093, 0.0088, 0.0090],
        [0.0093, 0.0101, 0.0103,  ..., 0.0093, 0.0085, 0.0092],
        [0.0094, 0.0082, 0.0150,  ..., 0.0095, 0.0084, 0.0092],
        ...,
        [0.0094, 0.0086, 0.0111,  ..., 0.0118, 0.0089, 0.0093],
        [0.0094, 0.0083, 0.0103,  ..., 0.0094, 0.0102, 0.0089],
        [0.0089, 0.0083, 0.0103,  ..., 0.0090, 0.0081, 0.0117]])

In [437]:
p_aba / p_aba.sum(1, keepdim=True)

tensor([[0.0151, 0.0088, 0.0112,  ..., 0.0096, 0.0092, 0.0085],
        [0.0094, 0.0128, 0.0100,  ..., 0.0097, 0.0083, 0.0090],
        [0.0097, 0.0081, 0.0217,  ..., 0.0103, 0.0081, 0.0089],
        ...,
        [0.0096, 0.0090, 0.0118,  ..., 0.0157, 0.0092, 0.0091],
        [0.0098, 0.0083, 0.0100,  ..., 0.0099, 0.0126, 0.0081],
        [0.0084, 0.0082, 0.0101,  ..., 0.0090, 0.0074, 0.0153]])

In [165]:
match_ab = torch.matmul(a, torch.t(b))
p_ab = F.softmax(match_ab, dim=1)

match_ba = torch.t(match_ab)
match_bb = torch.matmul(b, torch.t(b))
p_bb = torch.softmax(match_bb, dim=1)
p_ba = F.softmax(match_ba, dim=1)

match_ba_bb = torch.cat([match_ba, match_bb], dim=1)
p_ba_bb = torch.clamp(F.softmax(match_ba_bb, dim=1), min=1e-8)
N = a.shape[0]
M = b.shape[0]
Tbar_ul, Tbar_uu = p_ba_bb[:, :N], p_ba_bb[:, N:]
I = torch.eye(M)
I = I.cuda() if Tbar_uu.is_cuda else I
# middle = torch.inverse(I - Tbar_uu + 1e-8)
middle = torch.inverse(match_bb + 1e-8)
# middle = torch.inverse(I - p_bb + 1e-8)
p_aba = torch.matmul(torch.matmul(p_ab, middle), Tbar_ul)
p_aba /= p_aba.sum(1, keepdim=True)
# p_aba = torch.matmul(torch.matmul(p_ab, p_bb), p_ba)

In [162]:
p_aba.sum(1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [163]:
middle

tensor([[ 1.2201, -0.3730,  0.6471,  ..., -0.3973, -0.3934, -0.2369],
        [-2.3164,  0.3392, -4.5347,  ...,  0.9172,  2.8041, -0.4114],
        [ 3.0603,  0.5993,  3.7068,  ..., -0.6661, -5.3733,  2.5896],
        ...,
        [ 2.4616,  0.9285,  3.8963,  ..., -0.7809, -5.4913,  1.0392],
        [-1.4368, -0.3912, -3.5948,  ...,  0.0509,  3.4498, -0.8072],
        [-1.6878, -0.0854, -2.4238,  ...,  0.5706,  0.5472, -1.4204]])

In [128]:
p_aba

tensor([[6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        ...,
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888],
        [6.0889, 6.0888, 6.0887,  ..., 6.0888, 6.0890, 6.0888]])

In [123]:
p_bb

tensor([[0.0003, 0.0002, 0.0003,  ..., 0.0003, 0.0002, 0.0002],
        [0.0002, 0.0003, 0.0003,  ..., 0.0002, 0.0003, 0.0003],
        [0.0003, 0.0003, 0.0003,  ..., 0.0002, 0.0002, 0.0003],
        ...,
        [0.0003, 0.0002, 0.0002,  ..., 0.0003, 0.0002, 0.0003],
        [0.0002, 0.0003, 0.0002,  ..., 0.0002, 0.0003, 0.0002],
        [0.0002, 0.0003, 0.0003,  ..., 0.0003, 0.0002, 0.0003]])

In [None]:
p_ab

In [113]:
middle

tensor([[1.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 1.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 1.0100,  ..., 0.0100, 0.0100, 0.0100],
        ...,
        [0.0100, 0.0100, 0.0100,  ..., 1.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 1.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 1.0100]])

In [114]:
Tbar_ul

tensor([[0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        ...,
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
        [0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002]])

In [115]:
p_aba

tensor([[0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        ...,
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100],
        [0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100]])

In [96]:
(p_aba / p_aba.sum(1, keepdim=True))

tensor([[0.5052, 0.4948],
        [0.5018, 0.4982]])

In [55]:
match_ab = torch.matmul(a, torch.t(b))
p_ab = F.softmax(match_ab, dim=1)
p_ba = F.softmax(torch.t(match_ab), dim=1)
p_aba = torch.matmul(p_ab, p_ba)

In [56]:
p_aba.sum(1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [139]:
X = torch.randn([n, p]) * 0.5
# X = torch.softmax(X, 0)
B = torch.matmul(X.T, X)
B = B.cuda()
B = B.cpu()
now = time.time()
B = torch.inverse(B)
time.time() - now

NameError: name 'n' is not defined

In [7]:
new_col = v = torch.randn([n, 1]) * 0.5
# new_col = v = torch.softmax(torch.randn([n, 1]) * 0.5, 0)

In [8]:
def add_new_col_to_inverse(X, old_inverse, new_col):
    B = old_inverse
    v = new_col
    
    u1 = torch.matmul(X.T, v)
    u2 = torch.matmul(B, u1)
    F22inv = 1. / (torch.matmul(v.T, v) - torch.matmul(u1.T, u2))
    u3 = F22inv * u2
    F11inv = B + F22inv * u2 * u2.T
    Bnew = torch.cat([
        torch.cat([F11inv, -u3], 1),
        torch.cat([-u3.T, F22inv], 1)
    ], 0)
    return Bnew

In [9]:
now = time.time()
add_new_col_to_inverse(B, v)
time.time() - now

0.338165283203125

In [29]:
Xnew = torch.cat([X, v], 1)
torch.inverse(torch.matmul(Xnew.T, Xnew))

tensor([[ 1.2506, -0.6256,  0.1035,  1.0205],
        [-0.6256,  0.8223,  0.1486, -0.5141],
        [ 0.1035,  0.1486,  0.4402,  0.3598],
        [ 1.0205, -0.5141,  0.3598,  2.0012]])