From a94aa2d9aa58a7c2df289588eb4f16d83725ce8f Mon Sep 17 00:00:00 2001 From: Mark Florisson Date: Fri, 5 Apr 2013 15:27:18 +0100 Subject: [PATCH] Add test for hash-based vtable creation --- numba/exttypes/tests/test_vtables.py | 79 +++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/numba/exttypes/tests/test_vtables.py b/numba/exttypes/tests/test_vtables.py index bd2c32ace9d..35c996b94bb 100644 --- a/numba/exttypes/tests/test_vtables.py +++ b/numba/exttypes/tests/test_vtables.py @@ -1 +1,78 @@ -__author__ = 'mark' +# -*- coding: utf-8 -*- +from __future__ import print_function, division, absolute_import + +import numba as nb +from numba import * +from numba.minivect.minitypes import FunctionType +from numba.exttypes import virtual +from numba.exttypes import ordering +from numba.exttypes import methodtable +from numba.exttypes.signatures import Method +from numba.testing.test_support import parametrize, main + +class py_class(object): + pass + +def myfunc1(a): + pass + +def myfunc2(a, b): + pass + +def myfunc3(a, b, c): + pass + +types = list(nb.numeric) + [object_] + +array_types = [t[:] for t in types] +array_types += [t[:, :] for t in types] +array_types += [t[:, :, :] for t in types] + +all_types = types + array_types + +def method(func, name, sig): + return Method(func, name, sig, False, False) + +make_methods1 = lambda: [ + method(myfunc1, 'method', FunctionType(argtype, [argtype])) + for argtype in all_types] + +make_methods2 = lambda: [ + method(myfunc2, 'method', FunctionType(argtype1, [argtype1, argtype2])) + for argtype1 in all_types + for argtype2 in all_types] + + +def make_table(methods): + table = methodtable.VTabType(py_class, []) + table.create_method_ordering() + + for i, method in enumerate(make_methods1()): + key = method.name, method.signature.args + method.lfunc_pointer = i + table.specialized_methods[key] = method + + assert len(methods) == len(table.specialized_methods) + + return table + +def make_hashtable(methods): + table = make_table(methods) + hashtable = virtual.build_hashing_vtab(table) + return hashtable + +#------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------ + +@parametrize(make_methods1()) +def test_specializations(methods): + hashtable = make_hashtable(methods) + print(hashtable) + + for i, method in enumerate(methods): + key = virtual.sep201_signature_string(method.signature, method.name) + assert hashtable.find_method(key), (i, method, key) + +if __name__ == '__main__': + main()