In [259]:
from scipy.sparse.linalg import svds
from scipy.linalg import svd
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [260]:
def random_matrix(n, p=0):
    """Return an ``n`` x ``n`` matrix of uniform [0, 1) samples with some entries zeroed.

    Parameters
    ----------
    n : int
        Number of rows and columns.
    p : float, optional
        Fraction of the ``n * n`` entries to overwrite with 0. Exactly
        ``int(n * n * p)`` distinct positions, chosen at random, are zeroed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, n)``.

    Raises
    ------
    ValueError
        If ``p`` is outside ``[0, 1]``. A negative ``p`` would otherwise turn
        into a negative slice bound and zero almost the whole matrix.
    """
    if not 0 <= p <= 1:
        raise ValueError("p should be between 0 and 1")
    r_matrix = np.random.uniform(0, 1, (n, n))
    zero_count = int(n * n * p)
    # A prefix of a random permutation gives distinct flat (row-major) positions.
    zero_positions = np.random.permutation(n * n)[:zero_count]
    r_matrix.flat[zero_positions] = 0
    return r_matrix

In [261]:
def truncated_svd(A, k):
    """Return the ``k`` leading singular triplets of ``A``, largest first.

    Parameters
    ----------
    A : 2-D array
        Matrix to decompose. ``k`` is validated against ``A.shape[0]``.
    k : int
        Number of singular values requested, ``1 <= k <= A.shape[0]``.

    Returns
    -------
    (U, s, V)
        Factors such that ``U @ np.diag(s) @ V`` approximates ``A``, with
        ``s`` in descending order. When ``k == A.shape[0]`` the factors are
        those returned by the dense ``scipy.linalg.svd``.

    Raises
    ------
    ValueError
        If ``k`` is not in ``1 .. A.shape[0]``.
    """
    n_rows = A.shape[0]
    if not 0 < k < n_rows + 1:
        raise ValueError("k should between 0 and A.shape[0] + 1")
    if k == n_rows:
        # svds (default solver) needs k < min(A.shape), so the full
        # decomposition has to come from the dense routine.
        return svd(A)
    U, s, V = svds(A, k)
    # Sort explicitly instead of relying on the order svds happens to return.
    order = np.argsort(s)[::-1]
    return U[:, order], s[order], V[order]

In [262]:
def calculate_singular_values(matrix):
    """Return the singular values of ``matrix`` as a 1-D array, largest first.

    Only the singular values are requested from ``scipy.linalg.svd``
    (``compute_uv=False``), so the U and Vh factors are never built.
    """
    return svd(matrix, compute_uv=False)

def plot_singular_values(values):
    """Draw ``values`` as a line plot through ``plt.plot``.

    Only ``values`` is passed, so matplotlib plots it against the element
    index. The call goes through the pyplot interface and returns nothing.
    """
    plt.plot(values)

def power_of_two(x):
    """Return True when ``x`` is a positive power of two (1, 2, 4, 8, ...).

    Zero, negative numbers and values with an odd factor other than 1
    return False.
    """
    if x < 1:
        return False
    # Strip factors of two; a power of two reduces to exactly 1.
    while x % 2 == 0:
        x //= 2
    return x == 1

In [263]:
class Node:
    """Internal tree node whose four children are the quadrants of a square block.

    ``lp`` / ``rp`` are written to the top-left / top-right quadrant by
    ``eval`` and ``ld`` / ``rd`` to the bottom-left / bottom-right quadrant.
    Each child must provide an ``eval(length)`` method.
    """

    def __init__(self, lp=None, rp=None, ld=None, rd=None):
        self.lp, self.rp, self.ld, self.rd = lp, rp, ld, rd

    def eval(self, length):
        """Assemble the dense ``length`` x ``length`` block from the four children.

        Every child is evaluated with ``length // 2`` and its result is
        written into the matching quadrant of a zero-initialised array.
        """
        half = length // 2
        upper, lower = slice(None, half), slice(half, None)
        result = np.zeros((length, length))
        quadrants = (
            (upper, upper, self.lp),
            (upper, lower, self.rp),
            (lower, upper, self.ld),
            (lower, lower, self.rd),
        )
        for rows, cols, child in quadrants:
            result[rows, cols] = child.eval(half)
        return result


class Leaf:
    """Tree leaf storing one block as SVD factors or as an all-zero marker.

    Parameters
    ----------
    U, V, s :
        Factors whose product ``U @ np.diag(s) @ V`` reproduces the block.
    zeros : bool
        When True the block is treated as all zeros and the factors are
        ignored by ``eval``.
    """

    def __init__(self, U=None, V=None, s=None, zeros=False):
        self.U = U
        self.s = s
        self.V = V
        self.zeros = zeros

    def eval(self, length):
        """Return the dense block represented by this leaf.

        A zero leaf yields a ``length`` x ``length`` array of zeros, so the
        result is a matrix even when the leaf is the root of the tree.
        Otherwise the stored factors are multiplied out.
        """
        if self.zeros:
            return np.zeros((length, length))
        return self.U @ np.diag(self.s) @ self.V

In [264]:
# Tolerance added to the singular-value cutoff (min_value) inside compress().
eps = 1e-12

In [265]:
def compress(matrix, min_value, max_rank, left, up, length):
    """Recursively build the compression tree for one square block.

    Parameters
    ----------
    matrix : 2-D array
        The ``length`` x ``length`` block to compress.
    min_value : float
        Singular values below ``min_value + eps`` are dropped.
    max_rank : int
        A block becomes a leaf when its ``(max_rank + 1)``-th singular value
        falls below the cutoff; capped at ``length - 1`` for small blocks.
    left, up : int
        Offsets of the block; they are only forwarded (shifted by the half
        length) to the recursive calls.
    length : int
        Side length of ``matrix``.

    Returns
    -------
    Leaf or Node
        A ``Leaf`` for a 1 x 1 block, for a block with no singular value at
        or above the cutoff (zero leaf), or for a block whose last computed
        singular value is below the cutoff. Otherwise a ``Node`` with the
        four compressed quadrants.
    """
    # A single entry cannot be split any further: keep it as is.
    if length == 1:
        return Leaf(U=matrix, s=np.array([1]), V=np.array([1]))

    # Small blocks cannot supply max_rank + 1 values; fall back to all of them.
    if max_rank + 1 >= length:
        max_rank = length - 1

    left_vectors, sigma, right_vectors = truncated_svd(matrix, k=max_rank + 1)
    cutoff = min_value + eps
    if np.abs(sigma[-1]) < cutoff:
        kept = sigma[np.abs(sigma) >= cutoff]
        rank = kept.shape[0]
        if rank == 0:
            return Leaf(zeros=True)
        return Leaf(U=left_vectors[:, :rank], s=kept, V=right_vectors[:rank])

    # Rank is still too high: split into quadrants and recurse on each.
    half = length // 2
    return Node(
        lp=compress(matrix[:half, :half], min_value, max_rank, left, up, half),
        rp=compress(matrix[:half, half:], min_value, max_rank, left + half, up, half),
        ld=compress(matrix[half:, :half], min_value, max_rank, left, up + half, half),
        rd=compress(matrix[half:, half:], min_value, max_rank, left + half, up + half, half),
    )


def compress_matrix(matrix, min_value, max_rank):
    """Validate the input and build the compression tree for ``matrix``.

    Parameters
    ----------
    matrix : 2-D square array
        Matrix to compress; its side length must pass ``power_of_two``.
    min_value : float
        Singular-value cutoff; its absolute value is passed on.
    max_rank : int
        Rank limit, ``0 <= max_rank < matrix.shape[0]``.

    Returns
    -------
    The root returned by ``compress`` for the whole matrix.

    Raises
    ------
    ValueError
        If the matrix is not 2-D and square, if ``max_rank`` is out of range,
        or if the dimension is rejected by ``power_of_two``.
    """
    # Shape is checked first so a non-square input is reported as such
    # instead of tripping the rank check.
    if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Matrix should be square")
    n = matrix.shape[0]
    if max_rank > n - 1:
        raise ValueError("Maximum rank should be strictly less than matrix dimension")
    if max_rank < 0:
        raise ValueError("Maximum rank should be non-negative")
    if not power_of_two(n):
        raise ValueError("Matrix dimension should be power of two")
    return compress(matrix, abs(min_value), max_rank, 0, 0, n)


In [266]:
# 64 x 64 random test matrix; 0.5 is passed as the fraction of entries to zero.
# NOTE(review): no random seed is set in the visible cells, so the printed
# values change between runs.
mat = random_matrix(64, 0.5)
# Singular values of the whole matrix; v[-1] is used as the compression cutoff.
v = calculate_singular_values(mat)
# plot_singular_values(v)
print(mat)

[[0.42904098 0.         0.55331774 ... 0.41800577 0.9807372  0.        ]
 [0.15114335 0.73254286 0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.62041932 0.        ]
 ...
 [0.32542939 0.00957363 0.41554258 ... 0.9104641  0.28185133 0.        ]
 [0.87371543 0.         0.         ... 0.86924712 0.         0.78143102]
 [0.         0.         0.99032384 ... 0.96723787 0.28824327 0.09878646]]


In [267]:
# Blank 64 x 64 canvas filled with 255. The dtype is pinned to uint8 because
# the platform default integer (int64 on most systems) is not a dtype that
# Image.fromarray accepts.
image_mat = np.full((64, 64), 255, dtype=np.uint8)
# Compress `mat` with the smallest singular value as cutoff and max_rank 1.
tree_root = compress_matrix(mat, v[-1], 1)
image = Image.fromarray(image_mat)
image.show()

In [268]:
# Rebuild the dense 64 x 64 matrix from the compressed tree for comparison with `mat`.
mat2 = tree_root.eval(64)
print(mat2)

[[0.42904098 0.         0.55331774 ... 0.41762454 0.9807372  0.        ]
 [0.15114335 0.73254286 0.         ... 0.01261793 0.         0.        ]
 [0.         0.         0.         ... 0.         0.62041932 0.        ]
 ...
 [0.32542939 0.00957363 0.41554258 ... 0.9104641  0.28185133 0.        ]
 [0.87371543 0.         0.         ... 0.86924712 0.         0.78143102]
 [0.         0.         0.99032384 ... 0.96723787 0.28824327 0.09878646]]


In [269]:
# Rank-10 truncated SVD of `mat` computed directly with svds; print the factor shapes.
a, b, c = svds(mat, k=10)
print(a.shape)
print(b.shape)
print(c.shape)

(64, 10)
(10,)
(10, 64)


In [270]:
# Dense approximation of `mat` rebuilt from the 10 singular triplets a, b, c.
print(a @ np.diag(b) @ c)

[[ 0.35249905  0.04384438  0.43138162 ...  0.43186511  0.61336161
   0.23833659]
 [-0.09797805  0.63287735  0.410886   ...  0.00599594 -0.11831209
   0.27894651]
 [ 0.19498072  0.0173474   0.21684499 ...  0.26927657  0.47923087
   0.26371855]
 ...
 [ 0.41716158  0.18053695  0.43751597 ...  0.72698849  0.67179983
   0.28881499]
 [ 0.18054536  0.2051022   0.04409298 ...  0.72995069  0.23291675
   0.33550911]
 [-0.02033188 -0.10187654  0.66299714 ...  0.74386544  0.55609905
  -0.00937125]]
