In [43]:
torch.manual_seed(42)


dimension = 1000
M = torch.randn([dimension,dimension])
M += M.T.clone()
M = M/2

v = torch.randn([dimension,])
v_normalized = v / torch.norm(v, p=2)


In [61]:
import gpytorch
import torch
lanczos_iters = 2

def Hess_Vec_Orig(M,v):
#     v_normalized = v / torch.norm(v, p=2)
    return torch.matmul(M,v)

def Hess_Vec(M):
    def matvec(v):
#         v_normalized = v / torch.norm(v, p=2)
        return torch.matmul(M, v)
    return matvec

P = M.shape[0]

# Create the closure by calling Hess_Vec with matrix M
matvec_closure = Hess_Vec(M)

# Perform Lanczos tridiagonalization using the closure
Q, T_gpy = gpytorch.utils.lanczos.lanczos_tridiag(
    matvec_closure,
    max_iter=lanczos_iters,
    dtype=torch.float32,
    device='cpu',
    matrix_shape=(P, P)
)


T = torch.zeros([lanczos_iters, lanczos_iters])
# r = torch.randn([dimension,])
r = v[:]
q_old = torch.zeros_like(r)  # Ensure q_old is a vector of the same size as r
b = torch.norm(r, p=2)
u_list = []
for i in range(lanczos_iters):
    q = r / b
    u_list.append(q)
    u = Hess_Vec_Orig(M, q) - b * q_old
    alpha = torch.dot(u, q)
    T[i, i] = alpha
    r = u - alpha * q
    
    for j in range(len(u_list)):
        r -= torch.dot(r, u_list[j]) * u_list[j]
    
    b = torch.norm(r, p=2)
    if i < lanczos_iters - 1:
        T[i, i+1] = b
        T[i+1, i] = b
    q_old = q
    
    if b < 1e-6:
        break




In [62]:
print(T_gpy)
torch.linalg.eigh(T_gpy)

tensor([[ 0.3130, 23.6814],
        [23.6814, -1.4521]])


torch.return_types.linalg_eigh(
eigenvalues=tensor([-24.2674,  23.1283]),
eigenvectors=tensor([[-0.6938, -0.7202],
        [ 0.7202, -0.6938]]))

In [63]:
print(T)
torch.linalg.eigh(T)

tensor([[-0.6388, 22.2502],
        [22.2502, -1.2837]])


torch.return_types.linalg_eigh(
eigenvalues=tensor([-23.2138,  21.2912]),
eigenvectors=tensor([[-0.7020, -0.7122],
        [ 0.7122, -0.7020]]))

In [8]:
for i in range(lanczos_iters):
    print(i)
    for j in range(i+1,lanczos_iters):
        print(torch.dot(u_list[i],u_list[j]))

0
tensor(-5.5879e-09)
tensor(1.1176e-08)
tensor(3.7253e-09)
1
tensor(9.3132e-09)
tensor(4.6566e-09)
2
tensor(9.3132e-10)
3


In [12]:
v = torch.randn([dimension,])
assert torch.allclose(Hess_Vec_Orig(M, v), matvec_closure(v))


In [13]:
torch.allclose(Hess_Vec_Orig(M, v), matvec_closure(v))

True