-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
_backend.py
130 lines (102 loc) · 3.65 KB
/
_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def _convert_arrays(array, func):
# Converts array or arrays
if isinstance(array, (list, tuple)):
# The same object encountered multiple times in the container is
# converted into the same object.
d = {}
ret = []
for arr in array:
if arr is None:
ret.append(None)
else:
arr2 = d.get(id(arr))
if arr2 is None:
arr2 = func(arr)
d[id(arr)] = arr2
ret.append(arr2)
return type(array)(ret)
else:
return func(array)
class _DummyContext(object):
def __enter__(self):
pass
def __exit__(self, typ, value, traceback):
pass
_dummy_context = _DummyContext()
# TODO(niboshi): Write more detailed description about interface/usage.
class Device(object):
    """A base class of unified devices.

    Chainer has the following concrete implementations:

    - :class:`chainer.backend.CpuDevice`
    - :class:`chainer.backend.GpuDevice`
    - :class:`chainer.backend.Intel64Device`
    - :class:`chainer.backend.ChainerxDevice`

    Implementations are expected to override :attr:`xp`, :attr:`name`,
    :attr:`supported_array_types`, :meth:`__eq__`, :meth:`send_array` and
    :meth:`is_array_supported`; the base versions raise
    :class:`NotImplementedError`.
    """

    @property
    def xp(self):
        """Array module corresponding to the device."""
        raise NotImplementedError(
            'Device implementation must override this property.')

    @property
    def name(self):
        """A unique name of the device."""
        raise NotImplementedError(
            'Device implementation must override this property.')

    @property
    def supported_array_types(self):
        """Array types supported by the device.

        Returns:
            tuple of array types which the device's module functions can
            handle.
        """
        raise NotImplementedError(
            'Device implementation must override this property.')

    def __str__(self):
        """Returns the device's :attr:`name`."""
        return self.name

    def __enter__(self):
        """A dummy definition that simply raises RuntimeError.

        :meth:`chainer.using_device` should be used instead.
        """
        raise RuntimeError(
            'Device class does not support runtime context using `with` '
            'statement. Use chainer.using_device instead.')

    def __exit__(self, exc_type, exc_value, traceback):
        """A dummy definition that should never be called."""
        # Definition of __exit__ is needed to raise a custom error on
        # __enter__.
        pass

    def __eq__(self, other):
        """Compares two devices. Must be overridden by implementations."""
        raise NotImplementedError(
            'Device implementation must override this method.')

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``, so define it
        # explicitly in terms of ``==``.
        return not (self == other)

    def __hash__(self):
        """Hashes the device by its :attr:`name`.

        Under Python 3, a subclass that overrides ``__eq__`` without also
        defining ``__hash__`` has ``__hash__`` set to ``None`` and becomes
        unhashable, so such subclasses should define ``__hash__`` too.
        """
        return hash(self.name)

    def create_context(self):
        """Returns a context manager in which the device is made current.

        The base implementation returns a shared no-op context manager.

        .. seealso::
            :meth:`chainer.using_device` calls this method internally.
        """
        return _dummy_context

    def send_array(self, array):
        """Transfers a single array to the device.

        :meth:`send` delegates to this method for every array it
        transfers, so device implementations must override it.

        Args:
            array: An array of NumPy, CuPy, or ChainerX.

        Returns:
            The transferred array.
        """
        raise NotImplementedError(
            'Device implementation must override this method.')

    def send(self, arrays):
        """Transfers given arrays to the device.

        Args:
            arrays: Array or arrays of NumPy, CuPy, or ChainerX. A list or
                tuple is converted element-wise into a container of the
                same type, with ``None`` elements passed through.

        Returns:
            Transferred arrays.
        """
        return _convert_arrays(arrays, self.send_array)

    def use(self):
        """Makes the device current in the current thread.

        The base implementation does nothing.
        """
        pass

    def is_array_supported(self, array):
        """Returns if the specified array is compatible with the device.

        Args:
            array (:ref:`ndarray`): An array to be checked

        Returns:
            ``True`` if the array is compatible with the device. Otherwise
            ``False`` is returned.
        """
        raise NotImplementedError(
            'Device implementation must override this method.')