-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
args.py
77 lines (61 loc) · 1.93 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Hints to wrap Kernel arguments to indicate how to manage host-device
memory transfers before & after the kernel call.
"""
import abc
from numba.core.typing.typeof import typeof, Purpose
class ArgHint(metaclass=abc.ABCMeta):
    """
    Abstract base for wrappers placed around a kernel argument.

    A hint stores the wrapped object in ``self.value``; concrete subclasses
    implement :meth:`to_device` to decide how that object is prepared for
    the kernel call and what clean-up work is queued for afterwards.
    """

    def __init__(self, value):
        # The wrapped argument, kept exactly as the caller supplied it.
        self.value = value

    @abc.abstractmethod
    def to_device(self, retr, stream=0):
        """
        Prepare the wrapped value for a kernel call.

        :param retr:
            a list of clean-up work to run after the kernel has been run.
            Implementations append 0-arg callables to it.
        :param stream: a stream to use when copying data
        :return: a value (usually a `DeviceNDArray`) to be passed to
            the kernel
        """

    @property
    def _numba_type_(self):
        # Type the hint as if the wrapped value itself were the argument.
        wrapped = self.value
        return typeof(wrapped, Purpose.argument)
class In(ArgHint):
    """
    Hint for an argument that is only handed to the kernel.

    No copy back into ``self.value`` is queued after the call.
    """

    def to_device(self, retr, stream=0):
        """
        Obtain the device-side object for the wrapped value.

        :param retr:
            list of post-kernel clean-up callables; one keep-alive entry
            is appended here.
        :param stream: stream forwarded to ``auto_device``
        :return: the first element returned by ``auto_device``
        """
        from .cudadrv.devicearray import auto_device

        # The second element of the result is not needed for inputs.
        device_value, _unused = auto_device(self.value, stream=stream)

        def _keep_alive():
            # Dummy writeback functor: it only holds a reference so the
            # device object stays alive until the kernel is called.
            return device_value

        retr.append(_keep_alive)
        return device_value
class Out(ArgHint):
    """
    Hint for an argument whose contents are copied back to the host after
    the kernel call.

    ``auto_device`` is invoked with ``copy=False``.
    NOTE(review): presumably that skips the initial host-to-device upload;
    confirm against ``cudadrv.devicearray.auto_device``.
    """

    def to_device(self, retr, stream=0):
        """
        Obtain the device-side object and queue the copy back to the host.

        :param retr:
            list of post-kernel clean-up callables; a copy-to-host entry
            is appended when ``auto_device`` reports a conversion.
        :param stream: stream used for ``auto_device`` and the copy back
        :return: the device-side object to be passed to the kernel
        """
        from .cudadrv.devicearray import auto_device

        device_value, converted = auto_device(
            self.value, copy=False, stream=stream
        )
        if not converted:
            # Nothing to write back when no conversion took place.
            return device_value

        def _write_back():
            return device_value.copy_to_host(self.value, stream=stream)

        retr.append(_write_back)
        return device_value
class InOut(ArgHint):
    """
    Hint for an argument that is handed to the kernel and copied back to
    the host afterwards.

    ``auto_device`` is called with its default ``copy`` setting.
    """

    def to_device(self, retr, stream=0):
        """
        Obtain the device-side object and queue the copy back to the host.

        :param retr:
            list of post-kernel clean-up callables; a copy-to-host entry
            is appended when ``auto_device`` reports a conversion.
        :param stream: stream used for ``auto_device`` and the copy back
        :return: the device-side object to be passed to the kernel
        """
        from .cudadrv.devicearray import auto_device

        device_value, converted = auto_device(self.value, stream=stream)
        if not converted:
            # Nothing to write back when no conversion took place.
            return device_value

        def _write_back():
            return device_value.copy_to_host(self.value, stream=stream)

        retr.append(_write_back)
        return device_value
def wrap_arg(value, default=InOut):
    """
    Ensure *value* is wrapped in an :class:`ArgHint`.

    :param value: a kernel argument, possibly already an ``ArgHint``
    :param default: callable applied to an unwrapped value (``InOut``
        unless overridden)
    :return: *value* itself when it is already an ``ArgHint``, otherwise
        ``default(value)``
    """
    if isinstance(value, ArgHint):
        # Respect a hint the caller chose explicitly.
        return value
    return default(value)
# Names exported by ``from <this module> import *``.
__all__ = ['In', 'Out', 'InOut', 'ArgHint', 'wrap_arg']