-
Notifications
You must be signed in to change notification settings - Fork 13
/
common.py
173 lines (128 loc) · 5.78 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from py4j import java_gateway
from pyboof import gateway
from six import string_types
import pyboof
import struct
import numpy as np
def exception_use_mmap():
    """
    Raise the error reported when an mmap-based transfer is attempted before mmap was enabled.

    :raises Exception: Always, with instructions to call pb.init_memmap() first.
    """
    message = "Need to turn on mmap. Add pb.init_memmap() to your code before any other calls to PyBoof"
    raise Exception(message)
def is_java_class(java_class, string_path):
    """
    Ask the Java side whether ``java_class`` is the class named by ``string_path``.

    :param java_class: Java object handed to PyBoofEntryPoint.isClass.
    :param string_path: Fully qualified class name to compare against.
    :return: The result of PyBoofEntryPoint.isClass (True if the passed in object is that Class).
    """
    entry_point = gateway.jvm.pyboof.PyBoofEntryPoint
    return entry_point.isClass(java_class, string_path)
def ejml_matrix_d_to_f(D):
    """
    Copy an EJML matrix into a newly allocated FMatrixRMaj with the same number of rows and columns.

    :param D: Java EJML matrix exposing getNumRows()/getNumCols(); presumably a double-precision
        matrix given the function name — confirm against callers.
    :return: The new FMatrixRMaj filled by ConvertMatrixData.convert(D, F).
    """
    ejml = gateway.jvm.org.ejml
    converted = ejml.data.FMatrixRMaj(D.getNumRows(), D.getNumCols())
    ejml.ops.ConvertMatrixData.convert(D, converted)
    return converted
def boof_fixed_length(length):
    """
    Build a BoofCV ConfigLength from ``length`` (as a float) and -1.0.

    NOTE(review): presumably the -1 second argument disables the relative/fractional part so the
    length is used as-is — confirm against BoofCV's ConfigLength.
    """
    fixed = float(length)
    return gateway.jvm.boofcv.struct.ConfigLength(fixed, -1.0)
class JavaWrapper:
    """
    Thin Python wrapper around a py4j Java object.

    Public Java fields (as reported by PyBoofEntryPoint.getPublicFields for the object's class) are
    exposed as Python attributes: reading one calls py4j's ``java_gateway.get_field`` and assigning
    one calls ``java_gateway.set_field``. All other attributes live on the Python object as usual.
    The field list is discovered once, at construction time.
    """

    def __init__(self, java_object=None):
        # java_obj is assigned before java_fields exists, so __setattr__ stores it in __dict__.
        self.java_obj = java_object
        if java_object is None:
            # Nothing to introspect; previously this crashed calling getClass() on None.
            self.java_fields = []
        else:
            self.java_fields = list(
                gateway.jvm.pyboof.PyBoofEntryPoint.getPublicFields(java_object.getClass()))

    def __getattr__(self, item):
        # Only called when normal lookup fails. __dict__ is read directly so a missing
        # java_fields (e.g. during construction) cannot recurse back into __getattr__.
        if "java_fields" in self.__dict__ and item in self.__dict__["java_fields"]:
            return java_gateway.get_field(self.java_obj, item)
        return object.__getattribute__(self, item)

    def __setattr__(self, key, value):
        # Public Java fields are written through to the Java object; anything else is a Python attribute.
        if "java_fields" in self.__dict__ and key in self.__dict__["java_fields"]:
            java_gateway.set_field(self.java_obj, key, value)
        else:
            self.__dict__[key] = value

    def __dir__(self):
        # dict_keys cannot be concatenated with a list in Python 3, so merge sets instead:
        # the default listing (instance + class attributes) plus the public Java fields.
        return sorted(set(super().__dir__()) | set(self.java_fields))

    def set_java_object(self, obj):
        """Replace the wrapped Java object (the cached java_fields list is not refreshed)."""
        self.java_obj = obj

    def get_java_object(self):
        """Return the wrapped Java object."""
        return self.java_obj

    def __str__(self):
        return "Wrapped Java:\n" + self.java_obj.toString()
class JavaConfig(JavaWrapper):
    """
    Provides a nice python wrapper around Java classes. Public variables are automatically turned into Python
    attributes.

    Reading a public field whose value PyBoofEntryPoint.isConfigClass() accepts returns it wrapped in a new
    JavaConfig. Assigning any JavaWrapper (JavaConfig included) to a public field stores its wrapped Java object.
    """
    # TODO variables which are java classes are a little messed up
    def __init__(self, java_class_path):
        """
        :param java_class_path: Either a fully qualified Java class name ('$' or '.' may separate nested
            classes), which is looked up through the py4j gateway and called with no arguments, or an
            existing Java object to wrap directly.
        """
        if isinstance(java_class_path, string_types):
            self.java_class_path = java_class_path
            words = java_class_path.replace('$', ".").split(".")
            # __getattr__ is called explicitly so each name goes straight to py4j's dynamic lookup.
            java_class = gateway.jvm.__getattr__(words[0])
            for word in words[1:]:
                java_class = java_class.__getattr__(word)
            self.java_obj = java_class()
        else:
            self.java_obj = java_class_path
        JavaWrapper.__init__(self, self.java_obj)

    def __getattr__(self, item):
        # Only called when normal lookup fails; public Java fields are fetched from the Java object.
        if "java_fields" in self.__dict__ and item in self.__dict__["java_fields"]:
            a = java_gateway.get_field(self.java_obj, item)
            if gateway.jvm.pyboof.PyBoofEntryPoint.isConfigClass(a):
                return JavaConfig(a)
            return a
        return object.__getattribute__(self, item)

    def __setattr__(self, key, value):
        if "java_fields" in self.__dict__ and key in self.__dict__["java_fields"]:
            # py4j cannot marshal Python wrapper objects, so pass Java the object being wrapped.
            # Checking the JavaWrapper base covers JavaConfig as well as e.g. JavaList.
            if isinstance(value, JavaWrapper):
                value = value.java_obj
            java_gateway.set_field(self.java_obj, key, value)
        else:
            self.__dict__[key] = value
class JavaList(JavaWrapper):
    """
    Wraps a Java list object together with a ``java_type`` value.

    ``java_type`` is forwarded to pyboof.FileIO.saveList when saving (presumably the element type —
    confirm against FileIO).
    """

    def __init__(self, java_list, java_type):
        super().__init__(java_list)
        self.java_type = java_type

    def size(self):
        """Return the wrapped Java list's size()."""
        wrapped = self.java_obj
        return wrapped.size()

    def save_to_disk(self, file_name):
        """Save the wrapped list to ``file_name`` via pyboof.FileIO.saveList."""
        file_io = gateway.jvm.pyboof.FileIO
        file_io.saveList(self.java_obj, self.java_type, file_name)
def JavaList_to_fastarray(java_list, java_class_type):
    """
    Convert a Java list into a BoofCV FastArray on the Java side.

    :param java_list: Java list to convert.
    :param java_class_type: Passed through to PyBoofEntryPoint.listToFastArray.
    :return: Whatever PyBoofEntryPoint.listToFastArray returns.
    """
    entry_point = gateway.jvm.pyboof.PyBoofEntryPoint
    return entry_point.listToFastArray(java_list, java_class_type)
def create_java_file_writer(path: str):
    """
    Open a java.io.FileWriter on ``path`` inside the JVM.

    :param path: File path, wrapped in a java.io.File before constructing the writer.
    :return: The Java FileWriter proxy; closing it is the caller's responsibility.
    """
    java_io = gateway.jvm.java.io
    return java_io.FileWriter(java_io.File(path))
def mmap_array_python_to_java(pylist, mmap_type: pyboof.MmapType):
    """
    Converts a python primitive list into a java primitive array by way of the shared mmap file.

    A big-endian header (unsigned short ``mmap_type``, unsigned int element count) followed by every
    element is written at offset 0 of ``pyboof.mmap_file``; the Java side is then asked to read it back.

    :param pylist: Values to send; first coerced with ``pyboof.mmap_force_array_type``.
    :param mmap_type: Primitive type tag; selects the element byte size and struct format.
    :return: Whatever ``PyBoofEntryPoint.mmap.read_primitive_array`` returns for this type.
    :raises Exception: If the elements do not fit in the mmap file in a single chunk.
    """
    # Ensure the data type is correct
    pylist = pyboof.mmap_force_array_type(pylist, mmap_type)
    num_elements = len(pylist)
    mm = pyboof.mmap_file
    num_element_bytes = pyboof.mmap_primitive_len(mmap_type)
    element_struct = struct.Struct(pyboof.mmap_primitive_format(mmap_type))

    # Max number of elements that fit at once; 100 bytes are held back, which covers the 6-byte header.
    max_elements = (pyboof.mmap_size - 100) // num_element_bytes
    # The whole list must be written in a single chunk
    if max_elements < num_elements:
        raise Exception("max_elements is too small")

    # Encode every element up front so the file is written with one call instead of one per element.
    pack = element_struct.pack
    payload = b"".join(pack(value) for value in pylist)
    mm.seek(0)
    mm.write(struct.pack('>HI', mmap_type, num_elements))
    mm.write(payload)
    # Now tell the java end to read what it just wrote
    return gateway.jvm.pyboof.PyBoofEntryPoint.mmap.read_primitive_array(mmap_type)
def mmap_array_java_to_python(java_array, mmap_type: pyboof.MmapType):
    """
    Converts a java primitive array into a python primitive list by way of the shared mmap file.

    The Java side is asked to write the array into ``pyboof.mmap_file``; this function then reads, from
    offset 0, a big-endian header (unsigned short type tag, unsigned int count) and the elements.

    :param java_array: Java primitive array; its length is the expected element count.
    :param mmap_type: Primitive type tag; selects the element byte size and struct format.
    :return: The decoded scalar values passed through ``pyboof.mmap_force_array_type``.
    :raises Exception: If the type tag or element count read back does not match the request.
    """
    num_elements = len(java_array)
    mm = pyboof.mmap_file
    num_element_bytes = pyboof.mmap_primitive_len(mmap_type)
    unpack = struct.Struct(pyboof.mmap_primitive_format(mmap_type)).unpack

    gateway.jvm.pyboof.PyBoofEntryPoint.mmap.write_primitive_array(java_array, mmap_type, 0)

    mm.seek(0)
    data_type, num_found = struct.unpack(">HI", mm.read(2 + 4))
    if data_type != mmap_type:
        raise Exception("Unexpected data type in mmap file. {%d}" % data_type)
    if num_found != num_elements:
        raise Exception("Unexpected number of elements returned. " + str(num_found))

    # unpack() always returns a tuple; keep its single value so the list holds scalars, not 1-tuples.
    python_list = [unpack(mm.read(num_element_bytes))[0] for _ in range(num_found)]
    return pyboof.mmap_force_array_type(python_list, mmap_type)