# Basic sparse example: a 2-D convolution (cross-correlation) kernel as a sparse matrix

In [None]:
import numpy as np
from scipy.linalg import circulant
from scipy.sparse import csr_matrix

In [None]:
# A 2x2 image holding the odd numbers 1, 3, 5, 7, zero-padded by one pixel on every side.
# Its padded 2-D size is recorded before it is flattened row by row.
x = np.pad((2 * np.arange(4) + 1).reshape(2, 2), pad_width=1, mode='constant')
n_rows_in_x, n_columns_in_x = x.shape
x = x.ravel()
n_elements_in_x = x.size
x

In [None]:
# The 2x2 kernel [[10, 11], [12, 13]] and its size.
w = np.arange(10, 14).reshape(2, 2)
n_rows_in_w, n_columns_in_w = w.shape
w

In [None]:
# Embed the kernel in the top-left corner of a zero canvas shaped like the padded image.
# The canvas is allocated flat; writing through a 2-D view of it fills the same buffer.
f = np.zeros(n_rows_in_x * n_columns_in_x)
f.reshape(n_rows_in_x, n_columns_in_x)[:n_rows_in_w, :n_columns_in_w] = w
f

In [None]:
# circulant(f) has f as its first column (entry [i, j] is f[(i - j) % n]). Its transpose
# therefore has f as row 0, and each later row is the previous one rotated right by one place.
w = csr_matrix(circulant(f).T)
w.toarray()

In [None]:
# Number of kernel placements that fit entirely inside the padded image ("valid" output size)
# when the kernel's anchor advances `stride` pixels per step.
stride = 1
n_rows_in_output = int(np.floor((n_rows_in_x - n_rows_in_w) / stride + 1))
n_columns_in_output = int(np.floor((n_columns_in_x - n_columns_in_w) / stride + 1))

# Row k of the operator applies the kernel anchored at flat pixel k. Output (r, c) is therefore
# the row anchored at pixel (r * stride, c * stride). Both the anchor rows and the anchor columns
# must step by `stride`; taking consecutive rows (step 1) is only correct when stride == 1.
anchor_rows = np.arange(n_rows_in_output) * stride
anchor_columns = np.arange(n_columns_in_output) * stride
anchors = (anchor_rows[:, None] * n_columns_in_x + anchor_columns[None, :]).ravel()

indices = np.zeros(w.shape[0], dtype=bool)
indices[anchors] = True

# Sanity check: every output position must map to its own distinct operator row.
if n_rows_in_output * n_columns_in_output != np.sum(indices):
    raise Exception('Sum of indices should match the values in output')
indices

In [None]:
# Keep only the operator rows flagged in `indices`, i.e. the kernel placements that lie
# fully inside the padded image.
w = w[np.flatnonzero(indices), :]
w.toarray()

In [None]:
# Expected operator for the kernel [[10, 11], [12, 13]] on the zero-padded 4x4 image.
# There is one row per valid output position (3x3 of them), each holding the kernel at that
# position's anchor pixel.
goal = [
    [10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 13],
]
# Count the entries that disagree with the expected matrix (0 means an exact match).
np.sum((w.toarray() - np.asarray(goal)) != 0)

# Real data: the same construction applied to arrays loaded from x.npy, y.npy and w.npy

In [None]:
from typing import Optional

import numpy as np
from scipy.linalg import circulant
from scipy.sparse import csr_matrix

In [None]:
def load(x: str) -> np.ndarray:
    """Load ``<x>.npy`` and return ``array[0, 0, :, :]`` from it.

    The result is a 2-D slice when the stored array is 4-D; the stored array needs at
    least four dimensions for the indexing to succeed. The slice's shape is printed as a
    quick sanity check.
    """
    array = np.load('{}.npy'.format(x))[0, 0, :, :]
    print(array.shape)
    return array

# Read the image, y and the kernel; load() prints the shape of each slice it returns.
# NOTE(review): y is presumably the reference output; it is not used in the cells shown here.
x = load('x')
y = load('y')
w = load('w')

# Zero-pad the image by one pixel on every side, record its padded 2-D size, then flatten it.
x = np.pad(x, pad_width=1, mode='constant')
n_rows_in_x, n_columns_in_x = x.shape
x = x.ravel()
n_rows_in_w, n_columns_in_w = w.shape

In [None]:
# Embed the kernel in the top-left corner of a zero canvas shaped like the padded image.
# The canvas is allocated flat; writing through a 2-D view of it fills the same buffer.
f = np.zeros(n_rows_in_x * n_columns_in_x)
f.reshape(n_rows_in_x, n_columns_in_x)[:n_rows_in_w, :n_columns_in_w] = w

In [None]:
# Row k of the transposed circulant is f rotated right by k places, i.e. the kernel anchored
# at flat pixel k.
# NOTE(review): circulant() materialises a dense (f.size, f.size) array before the CSR
# conversion, so memory grows with the square of the padded image's pixel count.
w = csr_matrix(circulant(f).T)
w.shape

In [None]:
# Number of kernel placements that fit entirely inside the padded image ("valid" output size)
# when the kernel's anchor advances `stride` pixels per step.
stride = 1
n_rows_in_output = int(np.floor((n_rows_in_x - n_rows_in_w) / stride + 1))
n_columns_in_output = int(np.floor((n_columns_in_x - n_columns_in_w) / stride + 1))

# Row k of the operator applies the kernel anchored at flat pixel k. Output (r, c) is therefore
# the row anchored at pixel (r * stride, c * stride). Both the anchor rows and the anchor columns
# must step by `stride`; taking consecutive rows (step 1) is only correct when stride == 1.
anchor_rows = np.arange(n_rows_in_output) * stride
anchor_columns = np.arange(n_columns_in_output) * stride
anchors = (anchor_rows[:, None] * n_columns_in_x + anchor_columns[None, :]).ravel()

indices = np.zeros(w.shape[0], dtype=bool)
indices[anchors] = True

# Sanity check: every output position must map to its own distinct operator row.
if n_rows_in_output * n_columns_in_output != np.sum(indices):
    raise Exception('Sum of indices should match the values in output')

# Keep only the rows for valid output positions.
w = w[indices, :]
w.shape