/
base.py
111 lines (89 loc) · 2.56 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# coding: utf-8
# pylint: disable=invalid-name
"""ctypes library and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x
class DGLError(Exception):
"""Error thrown by DGL function"""
pass # pylint: disable=unnecessary-pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions
lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, os.path.basename(lib_path[0])
# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB, _LIB_NAME = _load_lib()
# The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto")
#----------------------------
# helper function in ctypes.
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise DGLError(py_str(_LIB.DGLGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)
def decorate(func, fwrapped):
"""A wrapper call of decorator package, differs to call time
Parameters
----------
func : function
The original function
fwrapped : function
The wrapped function
"""
import decorator
return decorator.decorate(func, fwrapped)