-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
_backend.py
107 lines (83 loc) · 2.95 KB
/
_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def _convert_arrays(array, func):
# Converts array or arrays
if isinstance(array, (list, tuple)):
# The same object encountered multiple times in the container is
# converted into the same object.
d = {}
ret = []
for arr in array:
if arr is None:
ret.append(None)
else:
arr2 = d.get(id(arr))
if arr2 is None:
arr2 = func(arr)
d[id(arr)] = arr2
ret.append(arr2)
return type(array)(ret)
else:
return func(array)
class _DummyContext(object):
def __enter__(self):
pass
def __exit__(self, typ, value, traceback):
pass
_dummy_context = _DummyContext()
# TODO(niboshi): Write more detailed description about interface/usage.
class Device(object):
    """A base class of unified devices.

    Chainer has the following concrete implementations:

    - :class:`chainer.backend.CpuDevice`
    - :class:`chainer.backend.GpuDevice`
    - :class:`chainer.backend.Intel64Device`
    - :class:`chainer.backend.ChainerxDevice`

    Subclasses must override :attr:`xp`, :attr:`supported_array_types`,
    :meth:`__eq__` and :meth:`send_array`.
    """

    @property
    def xp(self):
        """Array module corresponding to the device."""
        raise NotImplementedError(
            'Device implementation must override this property.')

    @property
    def supported_array_types(self):
        """Array types supported by the device.

        Returns:
            tuple of array types which the device's module functions can
            handle.
        """
        raise NotImplementedError(
            'Device implementation must override this property.')

    def __enter__(self):
        """A dummy definition that simply raises RuntimeError.

        :meth:`chainer.using_device` should be used instead.
        """
        raise RuntimeError(
            'Device class does not support runtime context using `with` '
            'statement. Use chainer.using_device instead.')

    def __exit__(self, exc_type, exc_value, traceback):
        """A dummy definition that should never be called."""
        # Definition of __exit__ is needed to raise a custom error on
        # __enter__: the `with` statement requires both methods to exist
        # before it calls __enter__.
        pass

    # NOTE: defining __eq__ without __hash__ makes instances unhashable on
    # Python 3 unless a subclass also defines __hash__.
    def __eq__(self, other):
        raise NotImplementedError(
            'Device implementation must override this method.')

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly in
        # terms of the subclass's __eq__.
        return not (self == other)

    def create_context(self):
        """Returns a context manager in which the device is made current.

        The base implementation returns a shared no-op context manager.

        .. seealso::
            :meth:`chainer.using_device` calls this method internally.
        """
        return _dummy_context

    def send_array(self, array):
        """Transfers a single array to the device.

        :meth:`send` delegates to this method for each array it transfers.

        Args:
            array: An array of NumPy, CuPy, or ChainerX.

        Returns:
            The transferred array.
        """
        raise NotImplementedError(
            'Device implementation must override this method.')

    def send(self, arrays):
        """Transfers given arrays to the device.

        Args:
            arrays: Array or arrays of NumPy, CuPy, or ChainerX. A list or
                tuple is converted element-wise; ``None`` entries are kept,
                and the same array object appearing more than once is
                transferred once and shared in the result.

        Returns:
            Transferred arrays.
        """
        return _convert_arrays(arrays, self.send_array)

    def use(self):
        """Makes the device current in the current thread.

        The base implementation does nothing.
        """
        pass