In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import cholesky

from datasets import load_dataset
from energies import LowRankEnergy
from sampling import AdaptiveSampler


In [2]:
def get_components_naive(G, indices):
    """Recompute the selection components of ``G`` from scratch (reference version).

    Used by ``check_components`` to validate the components that an energy
    object maintains incrementally.

    Parameters
    ----------
    G : square array_like
        Matrix whose rows/columns ``indices`` are selected. The Cholesky step
        requires ``G[indices][:, indices]`` to be Hermitian positive definite.
    indices : sequence of int
        Selected row/column indices, in slot order.

    Returns
    -------
    L : (k, k) ndarray
        Lower Cholesky factor of ``G_SS = G[indices][:, indices]``.
    W : (k, n) ndarray
        ``G_SS^{-1} @ G[indices, :]``.
    f : (k,) ndarray
        Real part of the diagonal of ``G_SS^{-1}``.
    d : (n,) ndarray
        Real part of ``diag(G)[j] - sum_i conj(W[i, j]) * G[indices[i], j]``;
        for Hermitian ``G`` this is the diagonal of the residual
        ``G - G[:, S] @ G_SS^{-1} @ G[S, :]``.
    """
    # asarray so that `*` below is always elementwise (np.matrix would make it a matmul).
    G = np.asarray(G)
    G_S = G[indices, :]                    # selected rows, (k, n)
    G_SS = G[np.ix_(indices, indices)]     # selected principal submatrix, (k, k)
    G_SS_inv = np.linalg.inv(G_SS)         # inverted once, reused for W and f

    L = cholesky(G_SS, lower=True)
    W = G_SS_inv @ G_S
    f = np.diag(G_SS_inv).real
    d = np.diag(G).real - np.sum(W.conj() * G_S, axis=0).real
    return L, W, f, d


def check_components(energy, t=None, verbose=True, nround=4):
    """Check an energy's incrementally maintained components against a naive recomputation.

    Rebuilds ``L``, ``W``, ``f`` and ``d`` from ``energy.G`` with
    ``get_components_naive`` over the selected indices. It then compares them
    with ``energy.L``, ``energy.W``, ``energy.f`` and ``energy.d`` using
    ``np.allclose`` and prints one status line per component.

    Parameters
    ----------
    energy : object
        Must expose ``G``, ``indices`` (one entry per slot), and the arrays
        ``L``, ``W``, ``f`` and ``d``.
    t : int or None
        Slot to leave out of the comparison, e.g. right after
        ``energy.downdate(t)``. The notebook output shows ``energy.indices``
        unchanged after a downdate, so the vacated slot must be skipped
        explicitly. ``None`` compares every slot.
    verbose : bool
        On mismatch, print the differing arrays. On success with ``t`` given,
        print the slot-``t`` entries.
    nround : int
        Number of decimals used when printing arrays.

    Returns
    -------
    bool
        True when all four components match.
    """
    # Slot positions to compare, and the column indices stored in them. A
    # comprehension (instead of list concatenation) also works when
    # energy.indices is a numpy array.
    slots = [j for j in range(len(energy.indices)) if j != t]
    indices = [energy.indices[j] for j in slots]

    L_naive, W_naive, f_naive, d_naive = get_components_naive(energy.G, indices)

    # (name, incremental value, naive value, full array to show on mismatch or None)
    comparisons = [
        ("L", energy.L[np.ix_(slots, slots)], L_naive, None),
        ("W", energy.W[slots, :], W_naive, energy.W),
        ("f", energy.f[slots], f_naive, energy.f),
        # d is indexed by column (one entry per column of G), so compare it in full.
        ("d", energy.d, d_naive, None),
    ]
    matches = [bool(np.allclose(computed, naive)) for _, computed, naive, _ in comparisons]
    all_match = all(matches)
    print("all match: ", all_match)

    for (name, computed, naive, full), match in zip(comparisons, matches):
        print(f"{name} close:", match)
        if not match and verbose:
            print(f"Computed {name}:\n", np.round(computed, nround))
            print(f"Naive {name}:\n", np.round(naive, nround))
            if full is not None:
                print(f"full self.{name}:\n", np.round(full, nround))

    if all_match and t is not None and verbose:
        print("All components match naive computation!")
        print("L @ t:\n", np.round(energy.L[t, :], nround))
        print("W @ t:\n", np.round(energy.W[t, :], nround))
        print("f @ t:\n", np.round(energy.f[t], nround))
        # d is indexed by column, not by slot: look up the column held in slot t.
        print("d @ t:\n", np.round(energy.d[energy.indices[t]], nround))
    return all_match


## Interpolative Update/Downdate Unit Tests

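Algorithms 9.4 and 9.5. Each check compares the components `L`, `W`, `f`, `d` that the energy object maintains incrementally against a naive recomputation from `G` (`check_components`).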

In [3]:
# Problem size: X is m x n, and the build phase is asked for k points
# (the output below lists k = 20 entries in energy.indices).
m, n, k = 40, 70, 20

# Seeded complex Gaussian data; the real part is drawn first, then the imaginary part.
rand_state = np.random.RandomState(42)
X = rand_state.randn(m, n) + 1j * rand_state.randn(m, n)

energy = LowRankEnergy(X, p=2)
sampler = AdaptiveSampler(energy, seed=42)
sampler.build_phase(k, method="search")

# After building, every component should agree with the naive recomputation.
check_components(energy)

all match:  True
L close: True
W close: True
f close: True
d close: True


In [4]:
# NOTE: this cell mutates `energy` in place; re-running it starts from the
# already-updated selection rather than the build-phase one.
print("initial indices: ", energy.indices)
for t in range(len(energy.indices)):
    print("----- Testing index t =", t, " -----")

    # Downdate slot t, then verify every remaining slot against the naive
    # recomputation (check_components skips slot t).
    print("Check downdate at t =", t)
    energy.downdate(t)
    check_components(energy, t=t, verbose=False)
    print("indices after downdate:  ", energy.indices)
    print()

    # Refill slot t with the first column at or after n // 2 that is not listed
    # in energy.indices. Wrapping modulo n keeps the search from stepping past
    # the last column; it terminates because fewer than n columns are selected.
    i = n // 2
    while i in energy.indices:
        i = (i + 1) % n
    print("Check update at t =", t, " with i =", i)
    energy.update(t, i)
    check_components(energy, t=None, verbose=False)
    print("indices after update:  ", energy.indices)
    print()

initial indices:  [np.int64(56), np.int64(18), np.int64(21), np.int64(67), np.int64(40), np.int64(28), np.int64(43), np.int64(63), np.int64(16), np.int64(4), np.int64(0), np.int64(3), np.int64(35), np.int64(37), np.int64(9), np.int64(64), np.int64(24), np.int64(60), np.int64(69), np.int64(8)]
----- Testing index t = 0  -----
Check downdate at t = 0
all match:  True
L close: True
W close: True
f close: True
d close: True
indices after downdate:   [np.int64(56), np.int64(18), np.int64(21), np.int64(67), np.int64(40), np.int64(28), np.int64(43), np.int64(63), np.int64(16), np.int64(4), np.int64(0), np.int64(3), np.int64(35), np.int64(37), np.int64(9), np.int64(64), np.int64(24), np.int64(60), np.int64(69), np.int64(8)]

Check update at t = 0  with i = 36
t, f[t]: 0 1e-15
all match:  True
L close: True
W close: True
f close: True
d close: True
indices after update:   [36, np.int64(18), np.int64(21), np.int64(67), np.int64(40), np.int64(28), np.int64(43), np.int64(63), np.int64(16), np.int6

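## Swap Tests

A fresh, smaller problem ($m = 10$, $n = 90$, $k = 5$) used to exercise the swap phase, first with the `search` method and then with the `sampling` method.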

In [5]:
# Smaller problem for the swap tests: X is m x n, build phase asked for k points.
m, n, k = 10, 90, 5

# Seeded complex Gaussian data; the real part is drawn first, then the imaginary part.
rand_state = np.random.RandomState(42)
X = rand_state.randn(m, n) + 1j * rand_state.randn(m, n)

# Fresh energy/sampler; the swap-phase cells below reuse this `sampler`.
energy = LowRankEnergy(X, p=2)
sampler = AdaptiveSampler(energy, seed=42)
sampler.build_phase(k, method="search")

In [6]:
# Swap phase with method "search" (debug output off). The next cell calls
# swap_phase on the same sampler, so it presumably starts from this cell's result.
sampler.swap_phase("search", debug=False)

In [7]:
# Swap phase with method "sampling", debug output on. Per the output below, it
# reports the best energy seen during the pass and the energy at the end of the
# pass, then re-initializes the Energy object to the best indices found.
sampler.swap_phase("sampling", debug=True)

Best energy found during sampling swap: 26.996541784469144
Energy at end of sampling swap: 29.100993954668592
Re-initializing Energy object to best found indices...
