-
Notifications
You must be signed in to change notification settings - Fork 274
/
test_distributions.py
111 lines (89 loc) · 4.07 KB
/
test_distributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import crypten
import torch
from test.multiprocess_test_case import MultiProcessTestCase
class TestDistributions(object):
"""
This class tests accuracy of distributions provided by random sampling in crypten.
"""
def _check_distribution(
self, func, expected_mean, expected_variance, lb=None, ub=None
):
"""
Checks that the function `func` returns a distribution with the expected
size, mean, and variance.
Arguments:
func - A function that takes a size and returns a random sample as a CrypTensor
expected_mean - The expected mean for the distribution returned by function `func`
expected_variance - The expected variance for the distribution returned by function `func
lb - An expected lower bound on samples from the given distribution. Use None if -Inf.
ub - An expected uppder bound on samples from the given distribution. Use None if +Inf.
"""
name = func.__name__
for size in [(10000,), (1000, 10), (101, 11, 11)]:
sample = func(size)
self.assertTrue(
sample.size() == size, "Incorrect size for %s distribution" % name
)
plain_sample = sample.get_plain_text().float()
mean = plain_sample.mean()
var = plain_sample.var()
self.assertTrue(
math.isclose(mean, expected_mean, rel_tol=1e-1, abs_tol=1e-1),
"incorrect variance for %s distribution: %f" % (name, mean),
)
self.assertTrue(
math.isclose(var, expected_variance, rel_tol=1e-1, abs_tol=1e-1),
"incorrect variance for %s distribution: %f" % (name, var),
)
if lb is not None:
self.assertTrue(
plain_sample.ge(lb).all(),
"Sample detected below lower bound for %s distribution" % name,
)
if ub is not None:
self.assertTrue(
plain_sample.le(ub).all(),
"Sample detected below lower bound for %s distribution" % name,
)
def test_uniform(self):
self._check_distribution(crypten.rand, 0.5, 0.083333, lb=0, ub=1)
def test_normal(self):
self._check_distribution(crypten.randn, 0, 1)
def test_bernoulli(self):
for p in [0.25 * i for i in range(5)]:
def bernoulli(*size):
x = crypten.cryptensor(p * torch.ones(*size))
return x.bernoulli()
self._check_distribution(bernoulli, p, p * (1 - p), lb=0, ub=1)
# Assert all values are in discrete set {0, 1}
tensor = bernoulli((1000,)).get_plain_text()
self.assertTrue(
((tensor == 0) + (tensor == 1)).all(), "Invalid Bernoulli values"
)
# Run all unit tests with both TFP and TTP providers
class TestTFP(MultiProcessTestCase, TestDistributions):
    """Runs the TestDistributions suite with the TrustedFirstParty provider."""

    def setUp(self):
        mpc = crypten.mpc
        # Remember the active provider so tearDown can restore it.
        self._original_provider = mpc.get_default_provider()
        crypten.CrypTensor.set_grad_enabled(False)
        mpc.set_default_provider(mpc.provider.TrustedFirstParty)
        super().setUp()

    def tearDown(self):
        mpc = crypten.mpc
        # Restore the provider and autograd state saved/changed in setUp.
        mpc.set_default_provider(self._original_provider)
        crypten.CrypTensor.set_grad_enabled(True)
        super().tearDown()
class TestTTP(MultiProcessTestCase, TestDistributions):
    """Runs the TestDistributions suite with the TrustedThirdParty provider."""

    def setUp(self):
        mpc = crypten.mpc
        # Remember the active provider so tearDown can restore it.
        self._original_provider = mpc.get_default_provider()
        crypten.CrypTensor.set_grad_enabled(False)
        mpc.set_default_provider(mpc.provider.TrustedThirdParty)
        super().setUp()

    def tearDown(self):
        mpc = crypten.mpc
        # Restore the provider and autograd state saved/changed in setUp.
        mpc.set_default_provider(self._original_provider)
        crypten.CrypTensor.set_grad_enabled(True)
        super().tearDown()