Skip to content

Commit

Permalink
some progress on sparse sparse but no dice
Browse files Browse the repository at this point in the history
  • Loading branch information
twhughes committed Dec 14, 2019
1 parent 1703471 commit 264b414
Showing 1 changed file with 79 additions and 47 deletions.
126 changes: 79 additions & 47 deletions ceviche/primitives.py
Expand Up @@ -32,6 +32,20 @@ def make_sparse(entries, indices, N):
coo = sp.coo_matrix((entries, indices), shape=shape, dtype=np.complex128)
return coo.tocsc()

def make_sparse_MxN(entries, indices, shape):
    """Construct a sparse, complex CSC matrix of arbitrary (rectangular) shape.

    Args:
        entries: numpy array with shape (K,) giving the values of the stored
            matrix entries.
        indices: numpy array with shape (2, K); row 0 holds the row index and
            row 1 the column index of each stored entry.
        shape: (num_rows, num_cols) tuple giving the shape of the matrix.

    Returns:
        scipy.sparse CSC matrix with dtype complex128.  Entries sharing the
        same (row, column) index are summed by the COO -> CSC conversion.
    """
    # Build in COO form (values plus coordinate pairs), then convert to CSC.
    coo = sp.coo_matrix((entries, indices), shape=shape, dtype=np.complex128)
    return coo.tocsc()

def transpose_indices(indices):
    """Return the indices of the transposed sparse matrix.

    Args:
        indices: array with shape (2, K) whose first row holds row indices and
            whose second row holds column indices.

    Returns:
        Array of the same shape with the two rows swapped, i.e. every (i, j)
        pair becomes (j, i).
    """
    # Flipping along axis 0 swaps the row-index row with the column-index row.
    return np.flip(indices, axis=0)
Expand Down Expand Up @@ -146,13 +160,27 @@ def sp_mult(entries, indices, x):
A = make_sparse(entries, indices, N=x.size)
return A.dot(x)

def grad_sp_mult_entries_reverse(b, entries, indices, x):
    """Build the reverse-mode (VJP) rule of ``sp_mult`` w.r.t. ``entries``.

    For ``b = A x`` with ``A[ik[k], jk[k]] = entries[k]``, the derivative of
    ``b[i]`` with respect to ``entries[k]`` is ``x[jk[k]]`` when ``i == ik[k]``
    and zero otherwise, so the cotangent of ``entries[k]`` is
    ``v[ik[k]] * x[jk[k]]``.

    Args:
        b: primal output of ``sp_mult`` (unused).
        entries: array with shape (K,) of matrix values (unused; the product
            is linear in the entries).
        indices: array with shape (2, K) of (row, column) indices of A.
        x: dense vector that A is multiplied with.

    Returns:
        Function mapping the output cotangent ``v`` to an array of shape (K,)
        holding the cotangent of ``entries``.
    """
    ik, jk = indices

    def vjp(v):
        # Direct gather; equivalent to multiplying v and x by one-hot
        # selection matrices, without building any sparse matrices.
        return v[ik] * x[jk]

    return vjp

def grad_sp_mult_x_reverse(ans, entries, indices, x):
def grad_sp_mult_x_reverse(b, entries, indices, x):
indices_T = transpose_indices(indices)
def vjp(v):
return sp_mult(entries, indices_T, v)
Expand Down Expand Up @@ -216,56 +244,60 @@ def grad_sp_solve_x_forward(g, x, entries, indices, b):
""" ==========================Sparse Matrix-Sparse Matrix Multiplication ========================== """

@ag.primitive
def spsp_mult(entries_a, indices_a, entries_x, indices_x, N):
    """Multiply a sparse matrix (A) by a sparse matrix (X): A @ X = B.

    Args:
        entries_a: numpy array with shape (num_non_zeros,) giving values for
            non-zero matrix entries into A.
        indices_a: numpy array with shape (2, num_non_zeros) giving i, j
            indices for non-zero matrix entries into A.
        entries_x: numpy array with shape (num_non_zeros,) giving values for
            non-zero matrix entries into X.
        indices_x: numpy array with shape (2, num_non_zeros) giving i, j
            indices for non-zero matrix entries into X.
        N: all matrices are assumed of shape (N, N) (needs to be specified
            because no dense vector is supplied to infer it from).

    Returns:
        entries_b: numpy array with shape (num_non_zeros,) giving values for
            non-zero matrix entries into the result B.
        indices_b: numpy array with shape (2, num_non_zeros) giving i, j
            indices for non-zero matrix entries into the result B.
    """
    A = make_sparse(entries_a, indices_a, N=N)
    X = make_sparse(entries_x, indices_x, N=N)
    B = A.dot(X)
    # Convert the sparse product back to the (entries, indices) representation.
    entries_b, indices_b = get_entries_indices(B)
    return entries_b, indices_b

def grad_spsp_mult_entries_a_reverse(b_out, entries_a, indices_a, entries_x, indices_x, N):
    """Build the reverse-mode (VJP) rule of ``spsp_mult`` w.r.t. ``entries_a``.

    With ``B = A @ X`` and ``A[ik[k], jk[k]] = entries_a[k]``, the derivative
    of ``B[i, j]`` with respect to ``entries_a[k]`` is ``X[jk[k], j]`` when
    ``i == ik[k]`` and zero otherwise.  Contracting with the cotangent ``V``
    of B gives ``sum_j V[ik[k], j] * X[jk[k], j]``, i.e. ``(V @ X.T)[ik, jk]``.
    No conjugation is applied, matching the ``v[i] * x[j]`` convention of the
    sparse matrix-vector rule in this file.

    Args:
        b_out: primal output ``(entries_b, indices_b)`` of ``spsp_mult``.
        entries_a: array with shape (K,) of values of A (unused; the product
            is linear in these).
        indices_a: array with shape (2, K) of (row, column) indices of A.
        entries_x: array of values of X.
        indices_x: array with shape (2, num_non_zeros) of indices of X.
        N: all matrices have shape (N, N).

    Returns:
        Function mapping the output cotangent ``(entries_v, indices_v)`` to an
        array of shape (K,) holding the cotangent of ``entries_a``.
    """
    ik, jk = indices_a
    # The sparsity pattern of the cotangent is that of the primal output.
    # NOTE(review): assumes the index half of the incoming cotangent carries
    # no usable positions (indices are not differentiable) -- confirm.
    _, indices_b = b_out

    X = sp.coo_matrix((entries_x, indices_x), shape=(N, N), dtype=np.complex128).tocsr()
    # Row k of X_rows is X[jk[k], :]; independent of v, so built only once.
    X_rows = X[jk, :]

    def vjp(v):
        entries_v = v[0]
        V = sp.coo_matrix((entries_v, indices_b), shape=(N, N), dtype=np.complex128).tocsr()
        # Row-wise inner product of V[ik[k], :] and X[jk[k], :] for every k.
        row_sums = V[ik, :].multiply(X_rows).sum(axis=1)
        return np.asarray(row_sums).ravel()

    return vjp

ag.extend.defvjp(spsp_mult, grad_spsp_mult_entries_a_reverse, None, None)

def grad_spsp_mult_entries_a_forward(g, out_b, entries_a, indices_a, entries_x, indices_x, N):
    """Forward-mode (JVP) rule of ``spsp_mult`` w.r.t. ``entries_a``.

    ``A @ X`` is linear in the entries of A, so the tangent of the product is
    the same product with the entries of A replaced by their tangent ``g``.

    Args:
        g: tangent of ``entries_a``.
        out_b: primal output of ``spsp_mult`` (unused).
        entries_a: values of A (unused).
        indices_a: array with shape (2, num_non_zeros) of indices of A.
        entries_x: values of X.
        indices_x: array with shape (2, num_non_zeros) of indices of X.
        N: all matrices have shape (N, N).

    Returns:
        The ``(entries, indices)`` pair returned by ``spsp_mult`` for the
        tangent matrix.
    """
    # NOTE(review): the second element of the returned pair is an index array,
    # not a tangent, and its ordering is presumed to match ``out_b`` -- confirm.
    return spsp_mult(g, indices_a, entries_x, indices_x, N)

# def grad_sp_mult_x_forward(g, b, entries, indices, x):
# return sp_mult(entries, indices, g)

ag.extend.defjvp(spsp_mult, grad_spsp_mult_entries_a_forward, None, None)

Expand Down Expand Up @@ -392,8 +424,8 @@ def vjp(v):

## Setup

N = 4          # size of matrix dimensions. matrix shape = (N, N)
M = N**2 - 1   # number of non-zeros (one short of fully dense; kept nearly dense for numerical stability)

# these are the default values used within the test functions
indices_const = make_rand_indeces(N, M)
Expand All @@ -408,25 +440,25 @@ def out_fn(output_vector):
def fn_spsp_entries(entries):
    """Scalar test function of the entries of A for the sparse-sparse product.

    Multiplies A(entries) by the constant sparse matrix, solves a linear
    system with the resulting product matrix and the constant right-hand
    side, and reduces the solution with ``out_fn``.
    """
    entries_c, indices_c = spsp_mult(entries, indices_const, entries_const, indices_const, N=N)
    # Feed the product through sp_solve so the check covers a downstream
    # primitive and not only the raw product entries.
    x = sp_solve(entries_c, indices_c, b_const)
    return out_fn(x)

entries = make_rand_complex(M)

# Testing Gradients of 'Sparse-Sparse Multiply entries Reverse-mode'

grad_rev = ceviche.jacobian(fn_spsp_entries, mode='reverse')(entries)[0]
grad_true = grad_num(fn_spsp_entries, entries)

# doesn't pass yet
np.testing.assert_almost_equal(grad_rev, grad_true, decimal=DECIMAL)

# Testing Gradients of 'Sparse-Sparse Multiply entries Forward-mode'

# TODO: forward-mode check is disabled; it doesn't pass for more complicated functions
# grad_for = ceviche.jacobian(fn_spsp_entries, mode='forward')(entries)[0]
# grad_true = grad_num(fn_spsp_entries, entries)
# print(grad_for, grad_true)
# np.testing.assert_almost_equal(grad_for, grad_true, decimal=DECIMAL)

## TESTS SPARSE MATRX CREATION

Expand Down

0 comments on commit 264b414

Please sign in to comment.