# Basic sparse example

In [None]:
import numpy as np
from scipy.linalg import circulant
from scipy.sparse import csr_matrix

In [None]:
x = (np.arange(4) * 2 + 1).reshape((2, 2))
x = np.pad(x, 1, 'constant')
n_rows_in_x, n_columns_in_x = x.shape
x = x.reshape((-1))
n_elements_in_x = len(x)
x

In [None]:
w = (np.arange(4) + 10).reshape((2, 2))
n_rows_in_w, n_columns_in_w = w.shape
w

In [None]:
f = np.zeros((n_rows_in_x, n_columns_in_x))
f[:n_rows_in_w, :n_columns_in_w] = w
f = f.reshape(-1)
f

In [None]:
w = circulant(f).T
w = csr_matrix(w)
w.toarray()

In [None]:
stride = 1
n_rows_in_output = int(np.floor((n_rows_in_x - n_rows_in_w) / stride + 1))
n_columns_in_output = int(np.floor((n_columns_in_x - n_columns_in_w) / stride + 1))

indices = np.zeros(w.shape[0], dtype=bool)
for index_row, index_start in enumerate(range(0, indices.shape[0], n_columns_in_x)):
    if index_row >= n_rows_in_output:
        break
    index_end = index_start + n_columns_in_output
    indices[index_start:index_end] = [True] * n_columns_in_output

if n_rows_in_output * n_columns_in_output != np.sum(indices):
    raise Exception('Sum of indices should match the values in output')
indices

In [None]:
w = w[indices, :]
w.toarray()

In [None]:
goal = [[10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13]]
np.sum(w.toarray() - goal != 0)

# Real data

In [9]:
from typing import Optional
from timeit import default_timer

import numpy as np
from scipy.linalg import circulant
from scipy.sparse import csr_matrix
from tqdm import tqdm

In [2]:
def load(x: str) -> np.ndarray:
    x = np.load('{}.npy'.format(x))
    x = x[0, 0, :, :]
    print(x.shape)
    return x

x = load('x')
y = load('y')
w = load('w')

x = np.pad(x, 1, 'constant')
n_rows_in_x, n_columns_in_x = x.shape
x = x.reshape((-1))
n_rows_in_w, n_columns_in_w = w.shape

(120, 214)
(120, 214)
(3, 3)


In [3]:
time_start = default_timer()
f = np.zeros((n_rows_in_x, n_columns_in_x))
f[:n_rows_in_w, :n_columns_in_w] = w
f = f.reshape(-1)
print(default_timer() - time_start)

0.0002920829865615815


In [4]:
time_start = default_timer()
w = circulant(f).T
w = csr_matrix(w)
print(w.shape)
print(default_timer() - time_start)

(26352, 26352)
33.8711567989958


In [5]:
time_start = default_timer()
stride = 1
n_rows_in_output = int(np.floor((n_rows_in_x - n_rows_in_w) / stride + 1))
n_columns_in_output = int(np.floor((n_columns_in_x - n_columns_in_w) / stride + 1))

indices = np.zeros(w.shape[0], dtype=bool)
for index_row, index_start in enumerate(range(0, indices.shape[0], n_columns_in_x)):
    if index_row >= n_rows_in_output:
        break
    index_end = index_start + n_columns_in_output
    indices[index_start:index_end] = [True] * n_columns_in_output

if n_rows_in_output * n_columns_in_output != np.sum(indices):
    raise Exception('Sum of indices should match the values in output')
    
w = w[indices, :]
print(w.shape)
print(default_timer() - time_start)

(25680, 26352)
0.010819596005603671


In [6]:
x = x.reshape((-1, 1))
y = y.reshape((-1, 1))

In [18]:
time_start = default_timer()
original_w = w
refined_w = original_w.copy()

for i in tqdm(range(y.shape[0])):
    pass
    # w_tf, num_iter_tf = _trim_layer(X=X, y=Y[i, :], rho=5, alpha=1.8, lmbda=4)
    # refined_w[:, i] = w_tf

print('number of non-zero values in the original weight matrix = ', np.sum(original_w != 0))
print('number of non-zero values in the refined weight matrix = ', np.sum(refined_w != 0))
print(default_timer() - time_start)

100%|██████████| 25680/25680 [00:00<00:00, 2719461.88it/s]

number of non-zero values in the original weight matrix =  231120
number of non-zero values in the refined weight matrix =  231120
0.03951365299872123



