In [2]:
import numpy as np
from scipy.linalg.interpolative import interp_decomp


In [9]:
A = np.array([[0,1,2],[3,4,5],[6,7,8]])
i = np.array([0,1,2])

In [10]:
A[2,2]
A[i,i]
i[2]

2

In [11]:
print(A[i[:,None],i[None,:]])
print(A)

[[0 1 2]
 [3 4 5]
 [6 7 8]]
[[0 1 2]
 [3 4 5]
 [6 7 8]]


In [12]:
A = np.array([[0,1,2],[3,4,5],[6,7,8]])

### Check that Interpolative Decomposition works with complex matrices

In [3]:
# Define a 2x5 random complex matrix
rows, cols = 5, 2
real_part = np.random.rand(rows, cols)  # Random real part
imag_part = np.random.rand(rows, cols)  # Random imaginary part

# Combine them into a complex matrix
m1 = real_part + 1j * imag_part

real_part = np.random.rand(cols, rows)  # Random real part
imag_part = np.random.rand(cols, rows)  # Random imaginary part

m2 = real_part + 1j * imag_part

m_tot = m1 @ m2

In [4]:
# Use scipy's implementation of the interpolative decomposition
# Instead of the matrix cross interpolation M = C @ P^-1 @ R
# it factorizes as M = A @ P with A = M[:, idx]
def interpolative_decomposition(M, eps_or_k=1e-5, k_min=2):
    r = min(M.shape)
    if r <= k_min:
        k = r
        idx, proj = interp_decomp(M, eps_or_k=k) #eps_or_k = precision of decomposition
    elif isinstance(eps_or_k, int): #checks if eps is an integer
        k = min(r, eps_or_k)
        idx, proj = interp_decomp(M, eps_or_k=k)
    else:
        k, idx,  proj = interp_decomp(M, eps_or_k=eps_or_k)
        if k <= k_min:
            k = min(r, k_min) #is it not enough to put k = k_min? 
                              #r>k_min otherwise first condition would have been true
            idx, proj = interp_decomp(M, eps_or_k=k)
    A = M[:, idx[:k]]
    P = np.concatenate([np.eye(k), proj], axis=1)[:, np.argsort(idx)]
    return A, P, k, idx[:k]

# k is the 'compressed' rank = number of pivot columns
# idx is the array with entries the indeces of the pivot columns
# proj = matrix R s.t. M[:,idx[:k]]*R = M[:,idx[k:]] 
# P = matrix s.t.  M[:,idx[:k]]*P = M (approximated)

In [7]:
A, P, k, idx = interpolative_decomposition(m_tot, eps_or_k=2)

m_interp = A @ P

print(m_interp)
print()
print(m_tot)

[[-0.41928128+1.00080127j  0.0652775 +1.58916245j -0.64497361+1.98811855j
  -0.56438473+1.3876732j  -0.29940375+1.62778386j]
 [ 0.12338775+0.5492222j   0.1098847 +0.66315231j  0.18806878+1.25949579j
  -0.02697813+0.74233348j -0.00183318+0.92444097j]
 [-0.24059206+0.20502703j -0.2957317 +0.52886219j -0.40654617+0.52867641j
  -0.38824251+0.33687345j -0.35155227+0.51929688j]
 [-0.04395127+0.2300687j  -0.11904514+0.36899942j -0.10077391+0.57565565j
  -0.16465109+0.32690823j -0.1678344 +0.46106j   ]
 [-0.24899779+0.98599264j  0.39391027+1.62171128j -0.19176679+1.92695023j
  -0.27947528+1.45218008j  0.10804916+1.65898575j]]

[[-0.41928128+1.00080127j  0.0652775 +1.58916245j -0.64497361+1.98811855j
  -0.56438473+1.3876732j  -0.29940375+1.62778386j]
 [ 0.12338775+0.5492222j   0.1098847 +0.66315231j  0.18806878+1.25949579j
  -0.02697813+0.74233348j -0.00183318+0.92444097j]
 [-0.24059206+0.20502703j -0.2957317 +0.52886219j -0.40654617+0.52867641j
  -0.38824251+0.33687345j -0.35155227+0.51929688j