-
Notifications
You must be signed in to change notification settings - Fork 548
/
pytrt.pyx
134 lines (118 loc) · 5.42 KB
/
pytrt.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import cython
import numpy as np
cimport numpy as np
from libcpp.string cimport string
from pytrt cimport TrtGooglenet
from pytrt cimport TrtMtcnnDet
cdef class PyTrtGooglenet:
    """Python wrapper around the C++ ``TrtGooglenet`` TensorRT runner.

    Construct it with an engine path plus the per-sample input/output
    shapes, call ``forward()`` on a batch of exactly one image, and call
    ``destroy()`` when finished.
    """
    cdef TrtGooglenet *c_trtnet      # owned C++ object; NULL until __init__
    cdef tuple data_dims, prob_dims  # per-sample input / 'prob' output shapes

    def __cinit__(PyTrtGooglenet self):
        # Give the pointer a defined value even if __init__ never runs or
        # fails part way, so __dealloc__/destroy() can test it safely.
        self.c_trtnet = NULL

    def __dealloc__(PyTrtGooglenet self):
        # The C++ object is created with ``new`` in __init__; free it when
        # the Python wrapper is collected so it is not leaked.
        # NOTE(review): this only runs the C++ destructor; engine resources
        # are released by the explicit destroy() call, as before.
        if self.c_trtnet != NULL:
            del self.c_trtnet
            self.c_trtnet = NULL

    def __init__(PyTrtGooglenet self,
                 str engine_path, tuple shape0, tuple shape1):
        """Create the C++ runner and load the TensorRT engine.

        Args:
            engine_path: path of the serialized engine file.
            shape0: 3-tuple, per-sample shape of the input binding.
            shape1: 3-tuple, per-sample shape of the 'prob' output.

        Raises:
            AssertionError: if either shape is not of length 3.
        """
        assert len(shape0) == 3 and len(shape1) == 3
        self.c_trtnet = new TrtGooglenet()
        self.data_dims = shape0
        self.prob_dims = shape1
        # The C++ side takes the dims as plain ``int[3]`` arrays.
        cdef int[:] v0 = np.array(shape0, dtype=np.intc)
        cdef int[:] v1 = np.array(shape1, dtype=np.intc)
        cdef string c_str = engine_path.encode('UTF-8')
        self.c_trtnet.initEngine(c_str, &v0[0], &v1[0])

    def forward(PyTrtGooglenet self,
                np.ndarray[np.float32_t, ndim=4] np_imgs not None):
        """Do a forward() computation on the input batch of imgs.

        Args:
            np_imgs: float32 array of shape ``(1,) + data_dims``.

        Returns:
            ``{'prob': array}`` where the array is float32 with shape
            ``(1,) + prob_dims``.

        Raises:
            RuntimeError: if the C++ runner was never created.
            AssertionError: if the batch size is not 1.
            ValueError: if the per-image shape differs from ``data_dims``.
        """
        if self.c_trtnet == NULL:
            raise RuntimeError('engine is not initialized')
        assert np_imgs.shape[0] == 1  # only accept batch_size = 1
        # C++ forward() receives only a raw pointer and cannot check the
        # buffer size itself, so reject inputs whose shape differs from
        # the one the engine was initialized with.
        img_dims = (np_imgs.shape[1], np_imgs.shape[2], np_imgs.shape[3])
        if img_dims != self.data_dims:
            raise ValueError('input shape %s does not match engine input %s'
                             % (img_dims, self.data_dims))
        if not np_imgs.flags['C_CONTIGUOUS']:
            np_imgs = np.ascontiguousarray(np_imgs)
        # np.zeros already returns a C-contiguous array.
        np_prob = np.zeros((1,) + self.prob_dims, dtype=np.float32)
        cdef float[:, :, :, ::1] v_imgs = np_imgs
        cdef float[:, :, :, ::1] v_prob = np_prob
        self.c_trtnet.forward(&v_imgs[0, 0, 0, 0], &v_prob[0, 0, 0, 0])
        return {'prob': np_prob}

    def destroy(PyTrtGooglenet self):
        """Ask the C++ runner to release its resources."""
        # Calling through a NULL pointer would crash the interpreter.
        if self.c_trtnet != NULL:
            self.c_trtnet.destroy()
cdef class PyTrtMtcnn:
    """Python wrapper around the C++ ``TrtMtcnnDet`` TensorRT runner.

    Which MTCNN stage is loaded (det1, det2 or det3) is chosen from a
    substring of the engine path. Call ``set_batchsize()`` before
    ``forward()``, and ``destroy()`` when finished.
    """
    cdef TrtMtcnnDet *c_trtnet  # owned C++ object; NULL until __init__
    cdef int batch_size         # batch size last passed to set_batchsize()
    cdef int num_bindings       # 3 (prob1, boxes) or 4 (+ landmarks)
    cdef tuple data_dims, prob1_dims, boxes_dims, marks_dims

    def __cinit__(PyTrtMtcnn self):
        # Give the pointer a defined value even if __init__ never runs or
        # fails part way, so __dealloc__/destroy() can test it safely.
        self.c_trtnet = NULL

    def __dealloc__(PyTrtMtcnn self):
        # The C++ object is created with ``new`` in __init__ (also when
        # __init__ later raises ValueError); free it with the wrapper.
        # NOTE(review): this only runs the C++ destructor; engine resources
        # are released by the explicit destroy() call, as before.
        if self.c_trtnet != NULL:
            del self.c_trtnet
            self.c_trtnet = NULL

    def __init__(PyTrtMtcnn self,
                 str engine_path,
                 tuple shape0, tuple shape1, tuple shape2, tuple shape3=None):
        """Create the C++ runner and load one MTCNN engine.

        Args:
            engine_path: path of the serialized engine; it must contain
                'det1', 'det2' or 'det3', which selects the init routine.
            shape0: 3-tuple, per-sample shape of the input.
            shape1: 3-tuple, per-sample shape of the 'prob1' output.
            shape2: 3-tuple, per-sample shape of the 'boxes' output.
            shape3: optional 3-tuple, per-sample shape of the 'landmarks'
                output; when given (non-empty) the net uses 4 bindings.

        Raises:
            AssertionError: if a shape is not of length 3.
            ValueError: if engine_path names none of det1/det2/det3.
        """
        self.num_bindings = 4 if shape3 else 3
        assert len(shape0) == 3 and len(shape1) == 3 and len(shape2) == 3
        if shape3:
            assert len(shape3) == 3
        else:
            shape3 = (0, 0, 0)  # set to a dummy shape
        self.c_trtnet = new TrtMtcnnDet()
        self.batch_size = 0
        self.data_dims = shape0
        self.prob1_dims = shape1
        self.boxes_dims = shape2
        self.marks_dims = shape3
        # The C++ side takes the dims as plain ``int[3]`` arrays.
        cdef int[:] v0 = np.array(shape0, dtype=np.intc)
        cdef int[:] v1 = np.array(shape1, dtype=np.intc)
        cdef int[:] v2 = np.array(shape2, dtype=np.intc)
        cdef int[:] v3 = np.array(shape3, dtype=np.intc)
        cdef string c_str = engine_path.encode('UTF-8')
        if 'det1' in engine_path:
            self.c_trtnet.initDet1(c_str, &v0[0], &v1[0], &v2[0])
        elif 'det2' in engine_path:
            self.c_trtnet.initDet2(c_str, &v0[0], &v1[0], &v2[0])
        elif 'det3' in engine_path:
            self.c_trtnet.initDet3(c_str, &v0[0], &v1[0], &v2[0], &v3[0])
        else:
            raise ValueError('engine is neither of det1, det2 or det3!')

    def set_batchsize(PyTrtMtcnn self, int batch_size):
        """Set the batch size used by subsequent forward() calls.

        Raises:
            RuntimeError: if the C++ runner was never created.
        """
        if self.c_trtnet == NULL:
            raise RuntimeError('engine is not initialized')
        self.c_trtnet.setBatchSize(batch_size)
        self.batch_size = batch_size

    def _forward_3(PyTrtMtcnn self,
                   np.ndarray[np.float32_t, ndim=4] np_imgs not None,
                   np.ndarray[np.float32_t, ndim=4] np_prob1 not None,
                   np.ndarray[np.float32_t, ndim=4] np_boxes not None):
        """Run a 3-binding net; outputs are written into the given arrays.

        All arrays must be C-contiguous (enforced by the ``::1`` views).
        """
        cdef float[:, :, :, ::1] v_imgs = np_imgs
        cdef float[:, :, :, ::1] v_probs = np_prob1
        cdef float[:, :, :, ::1] v_boxes = np_boxes
        self.c_trtnet.forward(&v_imgs[0, 0, 0, 0],
                              &v_probs[0, 0, 0, 0],
                              &v_boxes[0, 0, 0, 0])
        return {'prob1': np_prob1, 'boxes': np_boxes}

    def _forward_4(PyTrtMtcnn self,
                   np.ndarray[np.float32_t, ndim=4] np_imgs not None,
                   np.ndarray[np.float32_t, ndim=4] np_prob1 not None,
                   np.ndarray[np.float32_t, ndim=4] np_boxes not None,
                   np.ndarray[np.float32_t, ndim=4] np_marks not None):
        """Run a 4-binding net; outputs are written into the given arrays.

        All arrays must be C-contiguous (enforced by the ``::1`` views).
        """
        cdef float[:, :, :, ::1] v_imgs = np_imgs
        cdef float[:, :, :, ::1] v_probs = np_prob1
        cdef float[:, :, :, ::1] v_boxes = np_boxes
        cdef float[:, :, :, ::1] v_marks = np_marks
        self.c_trtnet.forward(&v_imgs[0, 0, 0, 0],
                              &v_probs[0, 0, 0, 0],
                              &v_boxes[0, 0, 0, 0],
                              &v_marks[0, 0, 0, 0])
        return {'prob1': np_prob1, 'boxes': np_boxes, 'landmarks': np_marks}

    def forward(PyTrtMtcnn self,
                np.ndarray[np.float32_t, ndim=4] np_imgs not None):
        """Do a forward() computation on the input batch of imgs.

        Args:
            np_imgs: float32 array of shape ``(batch_size,) + data_dims``.

        Returns:
            dict with 'prob1' and 'boxes' arrays, plus 'landmarks' when the
            net was created with 4 bindings; each has leading dimension
            ``batch_size``.

        Raises:
            RuntimeError: if the C++ runner was never created.
            AssertionError: if the batch size differs from set_batchsize().
            ValueError: if the per-image shape differs from ``data_dims``.
        """
        if self.c_trtnet == NULL:
            raise RuntimeError('engine is not initialized')
        assert np_imgs.shape[0] == self.batch_size
        # C++ forward() receives only a raw pointer and cannot check the
        # buffer size itself, so reject inputs whose shape differs from
        # the one the engine was initialized with.
        img_dims = (np_imgs.shape[1], np_imgs.shape[2], np_imgs.shape[3])
        if img_dims != self.data_dims:
            raise ValueError('input shape %s does not match engine input %s'
                             % (img_dims, self.data_dims))
        if not np_imgs.flags['C_CONTIGUOUS']:
            np_imgs = np.ascontiguousarray(np_imgs)
        batch = (self.batch_size,)
        # np.zeros already returns C-contiguous arrays.
        np_prob1 = np.zeros(batch + self.prob1_dims, dtype=np.float32)
        np_boxes = np.zeros(batch + self.boxes_dims, dtype=np.float32)
        if self.num_bindings == 3:
            return self._forward_3(np_imgs, np_prob1, np_boxes)
        # self.num_bindings == 4: only here is the landmarks buffer needed.
        np_marks = np.zeros(batch + self.marks_dims, dtype=np.float32)
        return self._forward_4(np_imgs, np_prob1, np_boxes, np_marks)

    def destroy(PyTrtMtcnn self):
        """Ask the C++ runner to release its resources."""
        # Calling through a NULL pointer would crash the interpreter.
        if self.c_trtnet != NULL:
            self.c_trtnet.destroy()