In [27]:
import numpy as np

# Random test problem: a kernel built from J complex exponential terms,
#     k(tau) = sum_j alpha_j * exp(-beta_j * |tau|),
# evaluated on N sorted random times t. The coefficients are drawn in
# complex-conjugate pairs, so k(tau) is real up to rounding (K still carries a
# complex dtype). The dense K is then compared with its semiseparable form
#     K = tril(u v^T, -1) + triu(v u^T, 1) + diag(diag),
# where u[n, j] = alpha_j * exp(-beta_j * t_n) and v[n, j] = exp(beta_j * t_n).

SEED = 42
rng = np.random.default_rng(SEED)  # fixed seed: Restart & Run All reproduces every output

N = 10  # number of time samples
J = 8   # number of exponential terms, drawn as J//2 conjugate pairs
# Later code sizes arrays with J, so alpha/beta must have exactly J entries.
assert J % 2 == 0, "J must be even: alpha and beta are built from conjugate pairs"

alpha = rng.random(J//2) + 1.0j * rng.random(J//2)
alpha = np.append(alpha, np.conj(alpha))
beta = rng.random(J//2) + 1.0j * rng.random(J//2)
beta = np.append(beta, np.conj(beta))
t = np.sort(rng.uniform(0, 100, N))
y0 = np.sin(t)  # right-hand side used by the solve checks
u = alpha*np.exp(-beta * t[:, None])
v = np.exp(beta * t[:, None])
# Diagonal: k(0) = sum(alpha) plus 0.01 (presumably a white-noise / jitter term).
diag = 0.01 + np.sum(alpha) + np.zeros(N)

# Dense kernel matrix: the J terms are summed along the trailing axis.
K = np.sum(alpha*np.exp(-beta*np.abs(t[:, None] - t[None, :])[:, :, None]), axis=-1)
K[np.diag_indices_from(K)] = diag

# Semiseparable reconstruction: strictly-lower part from u v^T, strictly-upper from v u^T.
K0 = np.tril(np.dot(u, v.T), -1) + np.triu(np.dot(v, u.T), 1)
K0[np.diag_indices_from(K0)] = diag
print("Semiseparable error: {0}".format(np.max(np.abs(K - K0))))

# Semiseparable Cholesky-style factorization K = L L^T (plain transpose, no
# conjugation), with
#     L = tril(u (v * X)^T, -1) + diag(D).
# Only D (length N) and X (N x J) are computed, at O(J^2) work per row. The
# J x J matrix S carries the contribution of all earlier rows; St is that
# contribution decayed to the current time sample.
dt = np.diff(t)
phi = np.exp(-beta * dt[:, None])  # phi[n-1, j] = exp(-beta_j * (t[n] - t[n-1]))
D = np.empty(N, dtype=alpha.dtype)
X = np.empty((N, J), dtype=alpha.dtype)

# Row 0 has no earlier rows to subtract.
D[0] = np.sqrt(diag[0])
X[0] = 1.0 / D[0]
S = np.outer(X[0], X[0])

# Rows 1..N-1: decay S by one step, then solve for the diagonal entry D[n]
# and the per-term coefficients X[n].
for n in range(1, N):
    St = np.outer(phi[n-1], phi[n-1]) * S
    D[n] = np.sqrt(diag[n] - np.sum(alpha * alpha[:, None] * St))
    X[n] = (1.0 - np.sum(alpha * St, axis=1)) / D[n]
    S = St + np.outer(X[n], X[n])

# Rebuild the dense L from (D, X) and compare L L^T against K.
L = np.tril(np.dot(u, (v*X).T), -1)
np.fill_diagonal(L, D)

print("Cholesky error: {0}".format(np.abs(L.dot(L.T) - K).max()))
# Since K = L L^T with diag(L) = D: log|det K| = 2 * Re(sum(log D)); compare with slogdet.
print(2 * np.log(D).sum().real, np.linalg.slogdet(K))

Semiseparable error: 3.552713688986415e-15
Cholesky error: 3.552918894473127e-15
14.699555675 ((1+0j), 14.699555674959717)


In [28]:
# Solve K z = y0 with the factor (D, X, phi) from the previous cell, in two
# O(N J) sweeps: forward substitution L w = y0, then backward substitution
# L^T z = w. Each sweep carries a length-J vector f holding the accumulated
# off-diagonal contributions, decayed by phi between neighbouring samples.

# Forward sweep: y is the right-hand side, z receives w = L^{-1} y0.
y = np.array(y0)
z = np.empty(N, dtype=alpha.dtype)
z[0] = y[0] / D[0]
f = 0.0
for n in range(1, N):
    f = phi[n-1] * (f + alpha * X[n-1] * z[n-1])
    z[n] = (y[n] - f.sum()) / D[n]
print("Forward sub error: {0}".format(np.abs(z - np.linalg.solve(L, y)).max()))

# Backward sweep: y now holds w, z receives the final solution L^{-T} w.
y = z.copy()
z = np.empty(N, dtype=alpha.dtype)
z[N-1] = y[N-1] / D[N-1]
f = 0.0
for n in reversed(range(N - 1)):
    f = phi[n] * (f + alpha * z[n+1])
    z[n] = (y[n] - (f * X[n]).sum()) / D[n]
print("Backward sub error: {0}".format(np.abs(z - np.linalg.solve(L.T, y)).max()))

# End-to-end check against a dense solve of the original system.
print("Full solve error: {0}".format(np.abs(np.linalg.solve(K, y0) - z).max()))

Forward sub error: 1.6657865655285256e-16
Backward sub error: 2.498002032016893e-16
Full solve error: 1.666274353858095e-16
