Permalink
Browse files

Refactor unit tests.

  • Loading branch information...
1 parent c7899c3 commit c966a7ff241b76f1df46178d73cb5ee9763b6d30 @bjornedstrom committed Oct 3, 2012
Showing with 5,484 additions and 41 deletions.
  1. +46 −0 test/generate_tests.py
  2. +0 −41 test/test.py
  3. +5,438 −0 test/test_vectors.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+
+# This will generate a test suite.
+
+def generate():
+ FILES = [
+ ('test/data/ShortMsgKAT_224.txt', 'sha3.SHA3224'),
+ ('test/data/ShortMsgKAT_256.txt', 'sha3.SHA3256'),
+ ('test/data/ShortMsgKAT_384.txt', 'sha3.SHA3384'),
+ ('test/data/ShortMsgKAT_512.txt', 'sha3.SHA3512'),
+ ('test/data/LongMsgKAT_224.txt', 'sha3.SHA3224'),
+ ]
+
+ print """
+# This file generated by generate_tests.py
+
+import sha3
+import unittest
+
+class SHA3Tests(unittest.TestCase):
+"""
+
+ for path, instance_str in FILES:
+ contents = file(path).read().split('Len = ')
+ for test in contents:
+ lines = test.split('\n')
+ if lines and len(lines) and not lines[0].startswith('#'):
+ length = int(lines[0])
+ if length % 8 == 0 and length != 0:
+ msg = lines[1].split(' = ')[-1].lower()
+ md = lines[2].split(' = ')[-1].lower()
+
+ print """ def test_%s_%s(self):
+ inst = %s()
+ inst.update(%r.decode('hex'))
+ assert inst.hexdigest() == %r
+""" % (path.split('/')[-1].split('.')[0], length, instance_str, msg, md)
+
+ print """
+
+if __name__ == '__main__':
+ unittest.main()
+"""
+
+if __name__ == '__main__':
+ generate()
View
@@ -1,41 +0,0 @@
-import sha3
-import unittest
-
-class SHA3Tests(unittest.TestCase):
-
- FILES = [
- ('test/data/ShortMsgKAT_224.txt', sha3.SHA3224),
- ('test/data/ShortMsgKAT_256.txt', sha3.SHA3256),
- ('test/data/ShortMsgKAT_384.txt', sha3.SHA3384),
- ('test/data/ShortMsgKAT_512.txt', sha3.SHA3512),
- ('test/data/LongMsgKAT_224.txt', sha3.SHA3224),
- ]
-
- def test_from_files(self):
- num_tests = 0
- for path, instance in self.FILES:
- print path
- contents = file(path).read().split('Len = ')
- for test in contents:
- lines = test.split('\n')
- if lines and len(lines) and not lines[0].startswith('#'):
- length = int(lines[0])
- if length % 8 == 0 and length != 0:
- msg = lines[1].split(' = ')[-1]
- md = lines[2].split(' = ')[-1]
-
- h = instance()
- h.update(msg.decode('hex'))
- try:
- assert h.hexdigest().upper() == md
- num_tests += 1
- except:
- print path
- print test
- print (msg.decode('hex'), h.hexdigest().upper(), md)
- raise
- print 'Ran %d tests.' % (num_tests,)
-
-
-if __name__ == '__main__':
- unittest.main()
Oops, something went wrong.

0 comments on commit c966a7f

Please sign in to comment.