Skip to content

Commit

Permalink
prototyped matvec for shamir
Browse files Browse the repository at this point in the history
  • Loading branch information
kgoss1729 committed Dec 20, 2023
1 parent 84f332e commit f15e819
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions cicada/shamir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,97 @@ def _lsb(self, operand):
return ShamirArrayShare(result.storage.reshape(operand.storage.shape))


def matvec(self, lhs, rhs, *, encoding=None):
"""Privacy-preserving product of a matrix and a vector.
Note
----
This is a collective operation that *must* be called
by all players that are members of :attr:`communicator`.
Parameters
----------
lhs: :class:`ShamirArrayShare` or :class:`numpy.ndarray`, required
Secret shared or public value to be multiplied.
rhs: :class:`ShamirArrayShare` or :class:`numpy.ndarray`, required
Secret shared or public value to be multiplied.
encoding: :class:`object`, optional
Encodes public operands and determines the number of bits to shift
right the results. The protocol's :attr:`encoding` is used by
default if :any:`None`.
Returns
-------
result: :class:`ShamirArrayShare`
Secret-shared product of `lhs` and `rhs`.
"""
encoding = self._require_encoding(encoding)

# Private-private matrix-vector multiplication.
if isinstance(lhs, ShamirArrayShare) and isinstance(rhs, ShamirArrayShare):
lshape = lhs.storage.shape
rshape = rhs.storage.shape
if lshape[1] != rshape[0]:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and {rshape}.")
if len(lshape) != 2:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and lhs must be 2d.")
result = ShamirArrayShare(numpy.zeros((lshape[0],)))
for j in range(lshape[0]):
x = lhs.storage[j]
y = rhs.storage
z = numpy.dot(x, y)
xy = numpy.array((z) % self.field.order, dtype=self.field.dtype)
lc = self._lagrange_coef()
dubdeg = numpy.zeros((len(lc),)+xy.shape, dtype=self.field.dtype)
for i, src in enumerate(self.communicator.ranks):
dubdeg[i]=self.share(src=src, secret=xy, shape=xy.shape, encoding=Identity()).storage #transpose
sharray = numpy.zeros(xy.shape, dtype=self.field.dtype)
for i in range(len(self.communicator.ranks)):
sharray = numpy.array((sharray + dubdeg[i]*lc[i]) % self.field.order, dtype=self.field.dtype)
result.storage[j] = self.right_shift(ShamirArrayShare(sharray), bits=encoding.precision).storage
return result

# Private-public matrix-vector multiplication.
if isinstance(lhs, ShamirArrayShare) and isinstance(rhs, numpy.ndarray):
lshape = lhs.storage.shape
rshape = rhs.shape
if lshape[1] != rshape[0]:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and {rshape}.")
if len(lshape) != 2:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and lhs must be 2d.")
result = ShamirArrayShare(numpy.zeros((lshape[0],)))
for i in range(lshape[0]):
x = lhs.storage[i]
y = encoding.encode(rhs, self.field)
z = numpy.dot(x, y)
xy = numpy.array((z) % self.field.order, dtype=self.field.dtype)
result.storage[i] = self.right_shift(ShamirArrayShare(xy), bits=encoding.precision).storage
return result

# Public-private matrix-vector multiplication.
if isinstance(lhs, numpy.ndarray) and isinstance(rhs, ShamirArrayShare):
result = self.field_multiply(encoding.encode(lhs, self.field), rhs)
result = self.right_shift(result, bits=encoding.precision)
lshape = lhs.shape
rshape = rhs.storage.shape
if lshape[1] != rshape[0]:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and {rshape}.")
if len(lshape) != 2:
raise ValueError("Incompatible shapes of operands for this operation: got {lshape} and lhs must be 2d.")
result = ShamirArrayShare(numpy.zeros((lshape[0],)))
for i in range(lshape[0]):
x = encoding.encode(lhs[i], self.field)
y = rhs.storage
z = numpy.dot(x, y)
xy = numpy.array((z) % self.field.order, dtype=self.field.dtype)
result.storage[i] = self.right_shift(ShamirArrayShare(xy), bits=encoding.precision).storage
return result

raise NotImplementedError(f"Privacy-preserving multiplication not implemented for the given types: {type(lhs)} and {type(rhs)}.") # pragma: no cover



def maximum(self, lhs, rhs):
"""Privacy-preserving elementwise maximum of secret shared arrays.
Expand Down

0 comments on commit f15e819

Please sign in to comment.