Refactor Dispatcher to remove unnecessary indirection #6349

Merged on Nov 25, 2020 (13 commits)
6 changes: 2 additions & 4 deletions docs/source/developer/repomap.rst
@@ -101,10 +101,8 @@ Dispatching
- :ghfile:`numba/core/dispatcher.py` - Dispatcher objects are compiled functions
produced by ``@jit``. A dispatcher has different implementations
for different type signatures.
- :ghfile:`numba/_dispatcher.{h,c}` - C interface to C++ dispatcher
implementation
- :ghfile:`numba/_dispatcherimpl.cpp` - C++ dispatcher implementation (for
speed on common data types)
- :ghfile:`numba/_dispatcher.cpp` - C++ dispatcher implementation (for speed on
common data types)


Compiler Pipeline
168 changes: 128 additions & 40 deletions numba/_dispatcher.c → numba/_dispatcher.cpp
@@ -1,12 +1,13 @@
#include "_pymodule.h"

#include <string.h>
#include <time.h>
#include <assert.h>
#include <cstring>
#include <ctime>
#include <cassert>
#include <vector>

#include "_dispatcher.h"
#include "_typeof.h"
#include "frameobject.h"
#include "core/typeconv/typeconv.hpp"

/*
* The following call_trace and call_trace_protected functions
@@ -97,46 +98,129 @@ else \
} \
}

typedef std::vector<Type> TypeTable;
typedef std::vector<PyObject*> Functions;

typedef struct DispatcherObject{
/* The Dispatcher class is the base class of all dispatchers in the CPU and
CUDA targets. Its main responsibilities are:

- Resolving the best overload to call for a given set of arguments, and
- Calling the resolved overload.

This logic is implemented within this class for efficiency (lookup of the
appropriate overload needs to be fast) and ease of implementation (calling
directly into a compiled function using a function pointer is easier within
the C++ code where the overload has been resolved). */
class Dispatcher {
public:
PyObject_HEAD
/* Holds borrowed references to PyCFunction objects */
dispatcher_t *dispatcher;
char can_compile; /* Can auto compile */
char can_fallback; /* Can fallback */
/* Whether compilation of new overloads is permitted */
char can_compile;
/* Whether fallback to object mode is permitted */
char can_fallback;
/* Whether types must match exactly when resolving overloads.
If not, conversions (e.g. float32 -> float64) are permitted when
searching for a match. */
char exact_match_required;
/* Borrowed reference */
PyObject *fallbackdef;
/* Whether to fold named arguments and default values (false for lifted loops)*/
/* Whether to fold named arguments and default values
(false for lifted loops) */
int fold_args;
/* Whether the last positional argument is a stararg */
int has_stararg;
/* Tuple of argument names */
PyObject *argnames;
/* Tuple of default values */
PyObject *defargs;
} DispatcherObject;
/* Number of arguments to function */
int argct;
/* Used for selecting overloaded function implementations */
TypeManager *tm;
/* An array of overloads */
Functions functions;
/* A flattened array of argument types to all overloads
* (invariant: sizeof(overloads) == argct * sizeof(functions)) */
TypeTable overloads;

/* Add a new overload. Parameters:

- args: An array of Type objects, one for each parameter
- callable: The callable implementing this overload. */
void addDefinition(Type args[], PyObject *callable) {
overloads.reserve(argct + overloads.size());
for (int i=0; i<argct; ++i) {
overloads.push_back(args[i]);
}
functions.push_back(callable);
}

/* Given a list of types, find the overloads that have a matching signature.
Returns the best match, as well as the number of matches found.

Parameters:

- sig: an array of Type objects, one for each parameter.
- matches: the number of matches found (mutated by this function).
- allow_unsafe: whether to match overloads that would require an unsafe
cast.
- exact_match_required: Whether all arguments types must match the
overload's types exactly. When false,
overloads that would require a type conversion
can also be matched. */
PyObject* resolve(Type sig[], int &matches, bool allow_unsafe,
bool exact_match_required) const {
const int ovct = functions.size();
int selected;
matches = 0;
if (0 == ovct) {
// No overloads registered
return NULL;
}
if (argct == 0) {
// Nullary function: trivial match on first overload
matches = 1;
selected = 0;
}
Comment on lines +180 to +184

Member: I'm not sure about this. If there are multiple definitions, it should be an ambiguous resolution.

Member Author: How could there be multiple definitions with no arguments?

Member: ah ok. btw, there shouldn't be multiple definitions for no arguments. It will have to be a mistake.

else {
matches = tm->selectOverload(sig, &overloads[0], selected, argct,
ovct, allow_unsafe,
exact_match_required);
}
if (matches == 1) {
return functions[selected];
}
return NULL;
}

/* Remove all overloads */
void clear() {
functions.clear();
overloads.clear();
}

};


static int
Dispatcher_traverse(DispatcherObject *self, visitproc visit, void *arg)
Dispatcher_traverse(Dispatcher *self, visitproc visit, void *arg)
{
Py_VISIT(self->defargs);
return 0;
}

static void
Dispatcher_dealloc(DispatcherObject *self)
Dispatcher_dealloc(Dispatcher *self)
{
Py_XDECREF(self->argnames);
Py_XDECREF(self->defargs);
dispatcher_del(self->dispatcher);
self->clear();
Py_TYPE(self)->tp_free((PyObject*)self);
}


static int
Dispatcher_init(DispatcherObject *self, PyObject *args, PyObject *kwds)
Dispatcher_init(Dispatcher *self, PyObject *args, PyObject *kwds)
{
PyObject *tmaddrobj;
void *tmaddr;
@@ -158,7 +242,8 @@ Dispatcher_init(DispatcherObject *self, PyObject *args, PyObject *kwds)
Py_INCREF(self->argnames);
Py_INCREF(self->defargs);
tmaddr = PyLong_AsVoidPtr(tmaddrobj);
self->dispatcher = dispatcher_new(tmaddr, argct);
self->tm = static_cast<TypeManager*>(tmaddr);
self->argct = argct;
self->can_compile = 1;
self->can_fallback = can_fallback;
self->fallbackdef = NULL;
@@ -168,15 +253,15 @@ Dispatcher_init(DispatcherObject *self, PyObject *args, PyObject *kwds)
}

static PyObject *
Dispatcher_clear(DispatcherObject *self, PyObject *args)
Dispatcher_clear(Dispatcher *self, PyObject *args)
{
dispatcher_clear(self->dispatcher);
self->clear();
Py_RETURN_NONE;
}

static
PyObject*
Dispatcher_Insert(DispatcherObject *self, PyObject *args)
Dispatcher_Insert(Dispatcher *self, PyObject *args)
{
PyObject *sigtup, *cfunc;
int i, sigsz;
@@ -194,22 +279,22 @@ Dispatcher_Insert(DispatcherObject *self, PyObject *args)
}

sigsz = PySequence_Fast_GET_SIZE(sigtup);
sig = malloc(sigsz * sizeof(int));
sig = new int[sigsz];

for (i = 0; i < sigsz; ++i) {
sig[i] = PyLong_AsLong(PySequence_Fast_GET_ITEM(sigtup, i));
}

/* The reference to cfunc is borrowed; this only works because the
derived Python class also stores an (owned) reference to cfunc. */
dispatcher_add_defn(self->dispatcher, sig, (void*) cfunc);
self->addDefinition(sig, cfunc);

/* Add pure python fallback */
if (!self->fallbackdef && objectmode){
self->fallbackdef = cfunc;
}

free(sig);
delete[] sig;

Py_RETURN_NONE;
}
@@ -277,7 +362,7 @@ int search_new_conversions(PyObject *dispatcher, PyObject *args, PyObject *kws)

/* A custom, fast, inlinable version of PyCFunction_Call() */
static PyObject *
call_cfunc(DispatcherObject *self, PyObject *cfunc, PyObject *args, PyObject *kws, PyObject *locals)
call_cfunc(Dispatcher *self, PyObject *cfunc, PyObject *args, PyObject *kws, PyObject *locals)
{
PyCFunctionWithKeywords fn;
PyThreadState *tstate;
@@ -342,7 +427,7 @@ call_cfunc(DispatcherObject *self, PyObject *cfunc, PyObject *args, PyObject *kw

static
PyObject*
compile_and_invoke(DispatcherObject *self, PyObject *args, PyObject *kws, PyObject *locals)
compile_and_invoke(Dispatcher *self, PyObject *args, PyObject *kws, PyObject *locals)
{
/* Compile a new one */
PyObject *cfa, *cfunc, *retval;
@@ -371,7 +456,7 @@ compile_and_invoke(DispatcherObject *self, PyObject *args, PyObject *kws, PyObje
}

static int
find_named_args(DispatcherObject *self, PyObject **pargs, PyObject **pkws)
find_named_args(Dispatcher *self, PyObject **pargs, PyObject **pkws)
{
PyObject *oldargs = *pargs, *newargs;
PyObject *kws = *pkws;
@@ -485,7 +570,7 @@ find_named_args(DispatcherObject *self, PyObject **pargs, PyObject **pkws)
}

static PyObject*
Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)
Dispatcher_call(Dispatcher *self, PyObject *args, PyObject *kws)
{
PyObject *tmptype, *retval = NULL;
int *tys = NULL;
@@ -496,6 +581,11 @@ Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)
PyObject *cfunc;
PyThreadState *ts = PyThreadState_Get();
PyObject *locals = NULL;

/* If compilation is enabled, ensure that an exact match is found and if
* not compile one */
int exact_match_required = self->can_compile ? 1 : self->exact_match_required;

if (ts->use_tracing && ts->c_profilefunc) {
locals = PyEval_GetLocals();
if (locals == NULL) {
@@ -515,7 +605,7 @@ Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)
if (argct < (Py_ssize_t) (sizeof(prealloc) / sizeof(int)))
tys = prealloc;
else
tys = malloc(argct * sizeof(int));
tys = new int[argct];

for (i = 0; i < argct; ++i) {
tmptype = PySequence_Fast_GET_ITEM(args, i);
@@ -530,14 +620,13 @@ Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)
}
}

/* If compilation is enabled, ensure that an exact match is found and if
* not compile one */
int exact_match_required = self->can_compile ? 1 : self->exact_match_required;

/* We only allow unsafe conversions if compilation of new specializations
has been disabled. */
cfunc = dispatcher_resolve(self->dispatcher, tys, &matches,
!self->can_compile, exact_match_required);
has been disabled.

Note that the number of matches is returned in matches by resolve, which
accepts it as a reference. */
cfunc = self->resolve(tys, matches, !self->can_compile,
exact_match_required);

if (matches == 0 && !self->can_compile) {
/*
@@ -552,9 +641,8 @@ Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)
}
if (res > 0) {
/* Retry with the newly registered conversions */
cfunc = dispatcher_resolve(self->dispatcher, tys, &matches,
!self->can_compile,
exact_match_required);
cfunc = self->resolve(tys, matches, !self->can_compile,
exact_match_required);
}
}

@@ -584,7 +672,7 @@ Dispatcher_call(DispatcherObject *self, PyObject *args, PyObject *kws)

CLEANUP:
if (tys != prealloc)
free(tys);
delete[] tys;
Py_DECREF(args);

return retval;
@@ -598,15 +686,15 @@ static PyMethodDef Dispatcher_methods[] = {
};

static PyMemberDef Dispatcher_members[] = {
{"_can_compile", T_BOOL, offsetof(DispatcherObject, can_compile), 0, NULL },
{(char*)"_can_compile", T_BOOL, offsetof(Dispatcher, can_compile), 0, NULL },
{NULL} /* Sentinel */
};


static PyTypeObject DispatcherType = {
PyVarObject_HEAD_INIT(NULL, 0)
"_dispatcher.Dispatcher", /* tp_name */
sizeof(DispatcherObject), /* tp_basicsize */
sizeof(Dispatcher), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)Dispatcher_dealloc, /* tp_dealloc */
0, /* tp_print */
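
For readers following the new `Dispatcher` class in this diff, the bookkeeping it introduces (a flattened table of argument types alongside a parallel vector of callables) can be sketched on its own. The block below is a simplified illustration, not the code from this PR: `TypeCode`, `Callable` and `OverloadTable` are hypothetical stand-ins for `Type`, `PyObject*` and `Dispatcher`, and the exact-match loop is an assumption that replaces `TypeManager::selectOverload`, which can also match overloads through type conversions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins; the real code uses numba's Type and PyObject*.
typedef int TypeCode;
typedef const char* Callable;

struct OverloadTable {
    int argct;                        // number of arguments per overload
    std::vector<Callable> functions;  // one entry per overload
    std::vector<TypeCode> overloads;  // flattened: argct entries per overload

    explicit OverloadTable(int n) : argct(n) {}

    // Append one overload: argct type codes plus the callable implementing it.
    void addDefinition(const TypeCode args[], Callable fn) {
        overloads.insert(overloads.end(), args, args + argct);
        functions.push_back(fn);
    }

    // Exact-match-only selection (an assumption made for this sketch).
    // Reports the number of matching overloads through `matches` and returns
    // the callable only when exactly one overload matched.
    Callable resolve(const TypeCode sig[], int &matches) const {
        matches = 0;
        if (functions.empty()) {
            return nullptr;  // no overloads registered
        }
        if (argct == 0) {
            matches = 1;     // nullary function: trivial match on first overload
            return functions[0];
        }
        Callable selected = nullptr;
        for (std::size_t i = 0; i < functions.size(); ++i) {
            // Row i of the flattened table starts at index i * argct.
            const TypeCode *row = &overloads[i * static_cast<std::size_t>(argct)];
            bool same = true;
            for (int j = 0; j < argct; ++j) {
                if (row[j] != sig[j]) {
                    same = false;
                    break;
                }
            }
            if (same) {
                ++matches;
                selected = functions[i];
            }
        }
        return matches == 1 ? selected : nullptr;
    }
};
```

In this sketch the flattened table holds `argct * functions.size()` entries at all times, which is the indexing that `&overloads[0]` relies on when the real class passes the table to `selectOverload`. An ambiguous signature (more than one match) yields a null result with `matches` greater than one, so the caller can distinguish it from "no match".
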
33 changes: 0 additions & 33 deletions numba/_dispatcher.h

This file was deleted.
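
The contents of the deleted header are not shown on this page, so the following is only a guess at the general shape of the indirection the PR title refers to. The removed calls in the diff (`dispatcher_new`, `dispatcher_del`, `dispatcher_clear`, `dispatcher_add_defn`, `dispatcher_resolve`) suggest a C-style shim over a C++ object reached through an opaque handle. The sketch below illustrates that pattern with hypothetical names (`Impl`, `handle_t`, `handle_*`) and contrasts it with calling the class directly, as the merged code does with `self->resolve(...)` and `self->clear()`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical C++ implementation class (not numba's).
class Impl {
public:
    void add(int v) { values.push_back(v); }
    int count() const { return static_cast<int>(values.size()); }
    void clear() { values.clear(); }
private:
    std::vector<int> values;
};

// "Before" shape: an opaque handle type that is never defined, plus free
// functions that cast the handle back to the implementation on every call.
typedef struct opaque_handle handle_t;

handle_t *handle_new() {
    return reinterpret_cast<handle_t*>(new Impl());
}

void handle_del(handle_t *h) {
    delete reinterpret_cast<Impl*>(h);
}

void handle_add(handle_t *h, int v) {
    reinterpret_cast<Impl*>(h)->add(v);
}

int handle_count(handle_t *h) {
    return reinterpret_cast<Impl*>(h)->count();
}

// "After" shape: the caller is itself C++ and uses the class directly,
// so the shim functions and the handle type are no longer needed.
int count_directly() {
    Impl impl;
    impl.add(3);
    impl.add(4);
    return impl.count();
}
```

Once the calling file is compiled as C++ (the rename from `_dispatcher.c` to `_dispatcher.cpp` in this PR), a shim layer of this kind has no remaining C callers to serve, which is consistent with the header being deleted outright.
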