Skip to content

Commit

Permalink
Enable TLS for more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tshead2 committed Mar 11, 2022
1 parent 66a6651 commit 524ab42
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions features/steps/additive_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def operation(communicator):
protocol = cicada.additive.AdditiveProtocol(communicator)

for i in range(count):
SocketCommunicator.run(world_size=context.players, fn=operation)
SocketCommunicator.run(world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@when(u'secret sharing the same value for {count} sessions')
Expand All @@ -49,7 +49,7 @@ def operation(communicator):

context.shares = []
for i in range(count):
context.shares.append(SocketCommunicator.run(world_size=context.players, fn=operation))
context.shares.append(SocketCommunicator.run(world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted))
context.shares = numpy.array(context.shares, dtype=numpy.object)


Expand All @@ -62,7 +62,7 @@ def operation(communicator, count):
shares = [protocol.share(src=0, secret=protocol.encoder.encode(numpy.array(5)), shape=()) for i in range(count)]
return numpy.array([int(share.storage) for share in shares], dtype=numpy.object)

context.shares = numpy.column_stack(SocketCommunicator.run(world_size=context.players, fn=operation, args=(count,)))
context.shares = numpy.column_stack(SocketCommunicator.run(world_size=context.players, fn=operation, args=(count,), identities=context.identities, trusted=context.trusted))


@then(u'the shares should never be repeated')
Expand Down Expand Up @@ -94,7 +94,7 @@ def operation(communicator, secret, player, local):
protocol.encoder.inplace_add(share.storage, protocol.encoder.encode(local))
return protocol.encoder.decode(protocol.reveal(share))

context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local))
context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local), identities=context.identities, trusted=context.trusted)


@when(u'player {} performs local in-place subtraction on the shared secret')
Expand All @@ -108,7 +108,7 @@ def operation(communicator, secret, player, local):
protocol.encoder.inplace_subtract(share.storage, protocol.encoder.encode(local))
return protocol.encoder.decode(protocol.reveal(share))

context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local))
context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local), identities=context.identities, trusted=context.trusted)


@then(u'the group should return {} to within {} digits')
Expand Down Expand Up @@ -145,7 +145,7 @@ def operation(communicator, a, b):
c = protocol.public_private_add(a, b)

return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private addition')
Expand All @@ -160,7 +160,7 @@ def operation(communicator, a, b):
c = protocol.add(a, b)

return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private untruncated multiplication')
Expand All @@ -177,7 +177,7 @@ def operation(communicator, a, b):

logging.debug(f"Comm {communicator.name!r} player {communicator.rank} reveal")
return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private xor')
Expand All @@ -191,7 +191,7 @@ def operation(communicator, a, b):
b = protocol.share(src=1, secret=b, shape=b.shape)
c = protocol.logical_xor(a, b)
return protocol.reveal(c)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private or')
Expand All @@ -205,7 +205,7 @@ def operation(communicator, a, b):
b = protocol.share(src=1, secret=b, shape=b.shape)
c = protocol.logical_or(a, b)
return protocol.reveal(c)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation max')
Expand All @@ -220,7 +220,7 @@ def operation(communicator, a, b):
c_share = protocol.max(a_share, b_share)

return protocol.encoder.decode(protocol.reveal(c_share))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation min')
Expand All @@ -235,7 +235,7 @@ def operation(communicator, a, b):
c_share = protocol.min(a_share, b_share)

return protocol.encoder.decode(protocol.reveal(c_share))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private multiplication')
Expand All @@ -250,7 +250,7 @@ def operation(communicator, a, b):
c = protocol.untruncated_multiply(a, b)
c = protocol.truncate(c)
return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-private equality')
Expand All @@ -264,7 +264,7 @@ def operation(communicator, a, b):
b = protocol.share(src=1, secret=b, shape=b.shape)
c = protocol.equal(a, b)
return protocol.reveal(c)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private-public modulus')
Expand All @@ -277,7 +277,7 @@ def operation(communicator, a, b):
b = numpy.array(b)
c = protocol.private_public_mod(a, b)
return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'operands {a} and {b}')
Expand Down Expand Up @@ -309,7 +309,7 @@ def operation(communicator, a):
a_share = protocol.share(src=0, secret=protocol.encoder.encode(a), shape=a.shape)
b_share = protocol.floor(a_share)
return protocol.encoder.decode(protocol.reveal(b_share))
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@when(u'the unary operation is executed {count} times')
Expand All @@ -334,7 +334,7 @@ def operation(communicator, secret):

for index in range(count):
secret = numpy.array(numpy.random.uniform(-100000, 100000))
results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(secret,))
results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(secret,), identities=context.identities, trusted=context.trusted)
for result in results:
numpy.testing.assert_almost_equal(secret, result, decimal=4)

Expand All @@ -353,7 +353,7 @@ def operation(communicator, bits, src, seed):
secret = protocol.reveal(secret_share)
return bits, secret

result = SocketCommunicator.run(world_size=context.players, fn=operation, args=(bits, src, seed))
result = SocketCommunicator.run(world_size=context.players, fn=operation, args=(bits, src, seed), identities=context.identities, trusted=context.trusted)
for bits, secret in result:
test.assert_equal(secret, numpy.sum(numpy.power(2, numpy.arange(len(bits))[::-1]) * bits))

Expand All @@ -368,7 +368,7 @@ def operation(communicator, a):
b_share = protocol.multiplicative_inverse(a_share)
one_share = protocol.untruncated_multiply(a_share, b_share)
return protocol.reveal(one_share)
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation less')
Expand All @@ -382,7 +382,7 @@ def operation(communicator, a, b):
b = protocol.share(src=0, secret=b, shape=b.shape)
c = protocol.less(a, b)
return protocol.reveal(c)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'unary operation relu')
Expand All @@ -394,7 +394,7 @@ def operation(communicator, a):
a_share = protocol.share(src=0, secret=protocol.encoder.encode(a), shape=a.shape)
relu_share = protocol.relu(a_share)
return protocol.encoder.decode(protocol.reveal(relu_share))
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)

@given(u'unary operation zigmoid')
def step_impl(context):
Expand All @@ -405,7 +405,7 @@ def operation(communicator, a):
a_share = protocol.share(src=0, secret=protocol.encoder.encode(a), shape=a.shape)
zigmoid_share = protocol.zigmoid(a_share)
return protocol.encoder.decode(protocol.reveal(zigmoid_share))
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.unary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)

@when(u'player {} performs private public subtraction on the shared secret')
def step_impl(context, player):
Expand All @@ -417,7 +417,7 @@ def operation(communicator, secret, player, local):
result = protocol.private_public_subtract(share, protocol.encoder.encode(local))
return protocol.encoder.decode(protocol.reveal(result))

context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local))
context.results = SocketCommunicator.run(world_size=context.players, fn=operation, args=(context.secret, player, context.local), identities=context.identities, trusted=context.trusted)


@given(u'binary operation logical_and')
Expand All @@ -431,7 +431,7 @@ def operation(communicator, a, b):
b = protocol.share(src=0, secret=b, shape=b.shape)
c = protocol.logical_and(a, b)
return protocol.reveal(c)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)


@given(u'binary operation private_public_power')
Expand All @@ -444,7 +444,7 @@ def operation(communicator, a, b):
b = numpy.array(b)
c = protocol.private_public_power(a, b)
return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)



Expand All @@ -460,4 +460,4 @@ def operation(communicator, a, b):
c = protocol.untruncated_private_divide(a, b)
c = protocol.truncate(c)
return protocol.encoder.decode(protocol.reveal(c))
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation)
context.binary_operation = functools.partial(SocketCommunicator.run, world_size=context.players, fn=operation, identities=context.identities, trusted=context.trusted)

0 comments on commit 524ab42

Please sign in to comment.