-
Notifications
You must be signed in to change notification settings - Fork 76
/
test_seclists.py
147 lines (133 loc) · 5.66 KB
/
test_seclists.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import operator
import unittest
from mpyc.runtime import mpc
from mpyc.seclists import seclist, secindex
class Arithmetic(unittest.TestCase):
    """Tests for mpyc.seclists: seclist over SecFld, SecInt and SecFxp, plus secindex.

    Secret results are opened with ``mpc.output(...)`` and resolved with
    ``mpc.run(...)`` before being compared against plain Python values.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Runs once for the whole class; mpc.logging(False) presumably silences
        # the MPyC runtime's log output while the tests run.
        mpc.logging(False)

    def test_secfld(self) -> None:
        """Build and mutate a seclist over the prime field GF(101) with list-style operations."""
        secfld = mpc.SecFld(101)
        # Both constructor forms (explicit empty list, keyword sectype) compare
        # equal to a plain empty list.
        s = seclist([], secfld)
        self.assertEqual(s, [])
        s = seclist(sectype=secfld)
        self.assertEqual(s, [])
        # Trailing comments track the (secret) contents of s.
        s.append(False)          # s = [0]
        s.append(secfld(100))    # s = [0, 100]
        s[1:2] = (1,)            # s = [0, 1]
        s += [2, 3]              # s = [0, 1, 2, 3]
        s.reverse()              # s = [3, 2, 1, 0]
        s = [5, 4] + s           # s = [5, 4, 3, 2, 1, 0]  (plain list on the left)
        s.reverse()              # s = [0, 1, 2, 3, 4, 5]
        s = s + [6, 7]           # s = [0, 1, 2, 3, 4, 5, 6, 7]
        del s[0]                 # s = [1, 2, 3, 4, 5, 6, 7]
        s.remove(4)              # s = [1, 2, 3, 5, 6, 7]
        s[5] = 9                 # s = [1, 2, 3, 5, 6, 9]
        del s[2:4]               # s = [1, 2, 6, 9]
        self.assertEqual(mpc.run(mpc.output(list(s))), [1, 2, 6, 9])
        # Return value unused: only exercises mpc.peek on a list of secure values
        # (presumably a debugging aid that reveals them — confirm).
        mpc.peek(list(s))
        # Mixing secure field types is rejected: a SecFld(101) element in a list
        # declared with secfld2, and concatenating seclists of the two types.
        secfld2 = mpc.SecFld()
        self.assertRaises(TypeError, seclist, [secfld(1)], secfld2)
        # No elements and no sectype raises ValueError (presumably because the
        # secure type cannot be inferred from an empty list).
        self.assertRaises(ValueError, seclist, [])
        self.assertRaises(TypeError, operator.add, seclist([secfld(1)]), seclist([secfld2(1)]))

    def test_secint(self) -> None:
        """Exercise seclist over SecInt, including secret indices.

        Secret indices appear in two forms: a list of secure bits such as
        ``[secint(0), secint(1)]`` (a unit vector; this one selects position 1,
        as the assertion after ``s[i] = 5`` shows), and a single secint such as
        ``secint(1)``.
        """
        secint = mpc.SecInt()
        s = seclist([], secint)
        self.assertEqual(s, [])
        s = seclist(sectype=secint)
        self.assertEqual(s, [])
        s.append(False)          # s = [0]
        s.sort()                 # s = [0]  (single element)
        s.append(secint(7))      # s = [0, 7]
        s[0] = secint(13)        # s = [13, 7]
        self.assertEqual(mpc.run(mpc.output(list(s))), [13, 7])
        # Assignment through a secret unit-vector index.
        i = [secint(0), secint(1)]
        s[i] = 5                 # s = [13, 5]
        self.assertEqual(mpc.run(mpc.output(s[1])), 5)
        # Swap the two entries via secret unit-vector indices (there and back),
        # then once more via public indices.
        i0 = [secint(1), secint(0)]
        i1 = [secint(0), secint(1)]
        s[i0], s[i1] = s[i1], s[i0]
        self.assertEqual(mpc.run(mpc.output(list(s))), [5, 13])
        s[i0], s[i1] = s[i1], s[i0]
        self.assertEqual(mpc.run(mpc.output(list(s))), [13, 5])
        s[0], s[1] = s[1], s[0]
        self.assertEqual(mpc.run(mpc.output(list(s))), [5, 13])
        s.append(secint(8))  # s = [5, 13, 8]
        s.reverse()  # s = [8, 13, 5]
        s.insert(secint(0), 9)  # s = [9, 8, 13, 5]
        del s[secint(1)]  # s = [9, 13, 5]
        s.pop(secint(2))  # s = [9, 13]
        s.insert(0, 99)  # s = [99, 9, 13]
        s.pop(0)  # s = [9, 13]
        # Removing a value that does not occur (11) raises ValueError.
        self.assertRaises(ValueError, s.remove, secint(11))
        s *= 2  # s = [9, 13, 9, 13]
        s.remove(9)  # s = [13, 9, 13]
        s[0:1] = []  # s = [9, 13]
        s = 1 * s + s * 0  # s = [9, 13]
        self.assertEqual(mpc.run(mpc.output(list(s))), [9, 13])
        # Secret indexing with a single secint instead of a unit vector.
        self.assertEqual(mpc.run(mpc.output(s[secint(1)])), 13)
        s[secint(1)] = secint(21)  # s = [9, 21]
        self.assertEqual(mpc.run(mpc.output(s[1])), 21)
        # Malformed accesses on the 2-element list s. Unit vectors whose length
        # does not fit s raise IndexError (insert presumably expects len(s) + 1
        # positions, the others len(s)). A float value, or a seclist of another
        # secure type, raises TypeError.
        self.assertRaises(IndexError, s.insert, [secint(1), secint(0)], 42)
        self.assertRaises(IndexError, s.pop, [secint(1)])
        self.assertRaises(IndexError, s.__getitem__, [secint(1)])
        self.assertRaises(TypeError, s.__setitem__, [secint(1)], 42.5)
        self.assertRaises(TypeError, s.__setitem__, slice(0, 2), seclist([0], mpc.SecFxp()))
        self.assertRaises(IndexError, s.__setitem__, [secint(1)], 42)
        self.assertRaises(IndexError, s.__delitem__, [secint(1)])
        # Histogram-style accumulation through secret secint indices:
        # position 3 is hit three times, position 4 four times.
        s = seclist([0]*7, secint)
        for a in [secint(3)]*3 + [secint(4)]*4:
            s[a] += 1
        self.assertEqual(mpc.run(mpc.output(list(s))), [0, 0, 0, 3, 4, 0, 0])
        # Plain membership (`x in s`) raises NotImplementedError; s.contains(x)
        # is the supported form, its result opened with mpc.output.
        with self.assertRaises(NotImplementedError):
            0 in s
        self.assertTrue(mpc.run(mpc.output(s.contains(0))))
        self.assertFalse(mpc.run(mpc.output(s.contains(9))))
        # Search helpers: count; find (first position, -1 when absent, as the
        # empty-list case shows); index (ValueError when absent).
        self.assertEqual(mpc.run(mpc.output(s.count(0))), 5)
        self.assertEqual(mpc.run(mpc.output(s.find(3))), 3)
        self.assertEqual(mpc.run(mpc.output(s.index(4))), 4)
        self.assertRaises(ValueError, s.index, 9)
        self.assertEqual(mpc.run(mpc.output(seclist([], secint).find(9))), -1)
        self.assertRaises(ValueError, seclist([], secint).index, 0)
        # NOTE(review): the keyed sort's result is never asserted on its own.
        # For these non-negative values it should already give ascending order,
        # so the plain s.sort() that follows would mask a wrongly ordered result.
        # Consider asserting between the two calls.
        s.sort(lambda a: -a**2, reverse=True)
        s.sort()
        self.assertEqual(mpc.run(mpc.output(list(s))), 5*[0] + [3, 4])
        # Comparisons between seclists: reflexive cases, against [], and against
        # proper prefixes (a prefix compares smaller, i.e. lexicographic order).
        self.assertFalse(mpc.run(mpc.output(s < s)))
        self.assertTrue(mpc.run(mpc.output(s <= s)))
        self.assertTrue(mpc.run(mpc.output(s == s)))
        self.assertFalse(mpc.run(mpc.output(s > s)))
        self.assertTrue(mpc.run(mpc.output(s >= s)))
        self.assertFalse(mpc.run(mpc.output(s != s)))
        self.assertFalse(mpc.run(mpc.output(s < [])))
        self.assertFalse(mpc.run(mpc.output(s <= [])))
        self.assertTrue(mpc.run(mpc.output(s >= [])))
        self.assertTrue(mpc.run(mpc.output(s > [])))
        self.assertFalse(mpc.run(mpc.output(s < s[:-1])))
        self.assertTrue(mpc.run(mpc.output(s[:-1] < s)))
        self.assertTrue(mpc.run(mpc.output(s[:-1] != s)))
        t = s.copy()
        t[-1] += 1  # t = [0, 0, 0, 0, 0, 3, 5]
        self.assertTrue(mpc.run(mpc.output(s < t)))
        t[1] -= 1  # t = [0, -1, 0, 0, 0, 3, 5]
        self.assertFalse(mpc.run(mpc.output(s < t)))
        self.assertFalse(mpc.run(mpc.output(s[:-1] <= t)))
        # Round trip through mpc.transfer with senders=0; the received list must
        # compare equal to the original.
        s = seclist([1, 2, 3, 4], secint)
        t = mpc.run(mpc.transfer(s, senders=0))
        self.assertTrue(mpc.run(mpc.output(s == t)))

    def test_secfxp(self) -> None:
        """Compare seclists over SecFxp and use seclist as a sort key for mpc.sorted."""
        secfxp = mpc.SecFxp()
        s = seclist([5, -3, 2, 5, 5], secfxp)
        self.assertFalse(mpc.run(mpc.output(s < s)))
        t = s[:]  # slice copy: the change below must not affect s
        t[-1] += 1  # t = [5, -3, 2, 5, 6]
        self.assertTrue(mpc.run(mpc.output(s < t)))
        # Sort secret pairs with key=seclist and check the opened result against
        # Python's sorted() on the plaintext pairs.
        s = [[1, 0], [0, 1], [0, 0], [1, 1]]
        ss = mpc.sorted([[secfxp(a) for a in _] for _ in s], key=seclist)
        self.assertEqual([mpc.run(mpc.output(_)) for _ in ss], sorted(s))

    def test_secindex(self) -> None:
        """Smoke-test addition of two secindex values.

        NOTE(review): k is computed but never checked, so this only verifies that
        ``secindex(...) + secindex(...)`` does not raise. i and j look like unit
        vectors selecting positions 2 and 1; presumably k should encode position 3.
        TODO: confirm secindex's representation and add an assertion on k.
        """
        secint = mpc.SecInt()
        i = secindex([secint(0), secint(0), secint(1), secint(0)])
        j = secindex([secint(0), secint(1), secint(0)])
        k = i + j
if __name__ == "__main__":
    # Only reached when this file is executed directly, not when a test runner
    # imports it. Naming the current module explicitly is equivalent to
    # unittest.main()'s default of '__main__' inside this branch.
    unittest.main(module=__name__)