# === bindings/python/pymongoarrow/builders.pyi (new file) ===
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Cython compiler directives
# distutils: language=c++
# cython: language_level=3


cdef inline bint _is_null(object value):
    """Return True when *value* has to be stored as a null entry."""
    # NaN is the only float that is unequal to itself.  Comparing by value
    # catches np.nan, math.nan and float('nan') alike; an identity test
    # against np.nan only recognises that one singleton object.
    return value is None or (isinstance(value, float) and value != value)


cdef class _BuilderBase:
    """Behaviour shared by all builders that needs no access to C state.

    Each subclass declares its own C-level ``builder`` attribute.  Such
    attributes are not reachable through Python attribute lookup, so
    anything that touches ``self.builder`` must be defined on the subclass
    that declares it, not here.
    """

    def append_values(self, values):
        """Append every item of the iterable *values*, in order."""
        for value in values:
            self.append(value)


cdef class Int32Builder(_BuilderBase):
    """Builder that accumulates Python ints into a CInt32Builder."""
    cdef:
        shared_ptr[CInt32Builder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CInt32Builder(pool))

    def append(self, value):
        """Append one int, or a null when *value* is None or NaN.

        Raises TypeError for any other kind of object.
        """
        if _is_null(value):
            self.builder.get().AppendNull()
        elif isinstance(value, int):
            self.builder.get().Append(value)
        else:
            raise TypeError('Int32Builder only accepts integer objects')

    def finish(self):
        """Finish the underlying builder and return the wrapped array."""
        cdef shared_ptr[CArray] out
        # NOTE(review): whatever Finish() returns is discarded here --
        # confirm whether a status should be checked.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    @property
    def null_count(self):
        """Number of nulls reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        return self.builder.get().length()

    cdef shared_ptr[CInt32Builder] unwrap(self):
        return self.builder


cdef class Int64Builder(_BuilderBase):
    """Builder that accumulates Python ints into a CInt64Builder."""
    cdef:
        shared_ptr[CInt64Builder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CInt64Builder(pool))

    def append(self, value):
        """Append one int, or a null when *value* is None or NaN.

        Raises TypeError for any other kind of object.
        """
        if _is_null(value):
            self.builder.get().AppendNull()
        elif isinstance(value, int):
            self.builder.get().Append(value)
        else:
            raise TypeError('Int64Builder only accepts integer objects')

    def finish(self):
        """Finish the underlying builder and return the wrapped array."""
        cdef shared_ptr[CArray] out
        # NOTE(review): whatever Finish() returns is discarded here --
        # confirm whether a status should be checked.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    @property
    def null_count(self):
        """Number of nulls reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        return self.builder.get().length()

    cdef shared_ptr[CInt64Builder] unwrap(self):
        return self.builder


cdef class DoubleBuilder(_BuilderBase):
    """Builder that accumulates floats and ints into a CDoubleBuilder."""
    cdef:
        shared_ptr[CDoubleBuilder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CDoubleBuilder(pool))

    def append(self, value):
        """Append one number, or a null when *value* is None or NaN.

        Raises TypeError for any other kind of object.
        """
        if _is_null(value):
            self.builder.get().AppendNull()
        elif isinstance(value, (int, float)):
            self.builder.get().Append(value)
        else:
            raise TypeError('DoubleBuilder only accepts floats and ints')

    def finish(self):
        """Finish the underlying builder and return the wrapped array."""
        cdef shared_ptr[CArray] out
        # NOTE(review): whatever Finish() returns is discarded here --
        # confirm whether a status should be checked.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    @property
    def null_count(self):
        """Number of nulls reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        return self.builder.get().length()

    cdef shared_ptr[CDoubleBuilder] unwrap(self):
        return self.builder


cdef class DatetimeBuilder(_BuilderBase):
    """Builder that accumulates datetimes into a CTimestampBuilder.

    Only the timestamp types that pass the check in ``__cinit__`` are
    accepted; ``timestamp('us')`` and ``timestamp('ns')`` raise ValueError.
    """
    cdef:
        shared_ptr[CTimestampBuilder] builder
        TimestampType dtype

    def __cinit__(self, TimestampType dtype=timestamp('ms'),
                  MemoryPool memory_pool=None):
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        if dtype in (timestamp('us'), timestamp('ns')):
            raise ValueError("Microsecond resolution temporal type is not "
                             "suitable for use with MongoDB's UTC datetime "
                             "type which has resolution of milliseconds.")
        self.dtype = dtype
        self.builder.reset(new CTimestampBuilder(
            pyarrow_unwrap_data_type(self.dtype), pool))

    def append(self, value):
        """Append one datetime, or a null when *value* is None or NaN.

        The datetime is converted with ``datetime_to_int64`` using this
        builder's dtype.  Raises TypeError for any other kind of object.
        """
        if _is_null(value):
            self.builder.get().AppendNull()
        elif isinstance(value, datetime.datetime):
            self.builder.get().Append(
                datetime_to_int64(value, self.dtype))
        else:
            # The message names this class, like the other builders do.
            raise TypeError('DatetimeBuilder only accepts datetime objects')

    def finish(self):
        """Finish the underlying builder and return the wrapped array."""
        cdef shared_ptr[CArray] out
        # NOTE(review): whatever Finish() returns is discarded here --
        # confirm whether a status should be checked.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    @property
    def null_count(self):
        """Number of nulls reported by the underlying builder."""
        return self.builder.get().null_count()

    def __len__(self):
        return self.builder.get().length()

    @property
    def unit(self):
        # Returns the whole timestamp type object, not just its unit string.
        return self.dtype

    cdef shared_ptr[CTimestampBuilder] unwrap(self):
        return self.builder


# === bindings/python/pymongoarrow/lib.pyx (new file) ===
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Cython compiler directives
# distutils: language=c++
# cython: language_level=3

# Stdlib imports
import datetime

# Python imports
import numpy as np
from pyarrow import timestamp

# Cython imports
from pyarrow.lib cimport *


# Utilities
include "utils.pyi"

# Builders
include "builders.pyi"


# === bindings/python/pymongoarrow/utils.pyi (new file) ===
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Number of timestamp ticks in one second, keyed by unit string.
_TICKS_PER_SECOND = {'s': 1, 'ms': 10 ** 3, 'us': 10 ** 6, 'ns': 10 ** 9}


def datetime_to_int64(dtm, data_type):
    """Return *dtm* as an integer offset from 1970-01-01 00:00:00.

    The offset is counted in the unit named by ``data_type.unit``, which
    must be one of ``'s'``, ``'ms'``, ``'us'`` or ``'ns'``.  When the unit is
    coarser than the datetime's precision the result is floored, so a
    pre-epoch value such as 1969-12-31 23:59:59.5 maps to -1 second.

    Raises ValueError for any other unit.  *dtm* is subtracted from a naive
    epoch, so it is expected to be a naive datetime.
    """
    # TODO: rewrite as a cdef which directly accesses data_type as a CTimestampType instance
    # TODO: make this function aware of datatype.timezone()
    delta = dtm - datetime.datetime(1970, 1, 1)

    # Build the microsecond count from the timedelta's integer fields.
    # timedelta keeps seconds and microseconds non-negative and puts the
    # sign in days, so this sum is exact on both sides of the epoch.
    # Truncating total_seconds() and then adding dtm.microsecond is not: it
    # flips the sign of the fractional part for pre-epoch datetimes.
    total_microseconds = (
        (delta.days * 86400 + delta.seconds) * 10 ** 6 + delta.microseconds)

    unit = data_type.unit
    if unit not in _TICKS_PER_SECOND:
        raise ValueError('Unsupported timestamp unit {}'.format(unit))

    # Integer-only scaling; a float intermediate can drop low-order digits
    # once the product grows past 2**53.
    return total_microseconds * _TICKS_PER_SECOND[unit] // 10 ** 6


# === bindings/python/setup.py (modified; only the changed regions) ===
# This change also adds two module-level imports to setup.py:
#     import numpy as np
#     import pyarrow as pa

def get_extension_modules():
    """Cythonize the package and return all extension modules.

    The libbson modules are linked against ``bson-1.0``.  The remaining
    modules get the numpy and pyarrow include directories plus pyarrow's
    libraries and library directories.  Relies on ``cythonize``, ``os``,
    ``np`` and ``pa`` from setup.py's module scope.
    """
    arrow_modules = cythonize(['pymongoarrow/*.pyx'])
    libbson_modules = cythonize(['pymongoarrow/libbson/*.pyx'])

    for module in libbson_modules:
        module.libraries.append('bson-1.0')

    for module in arrow_modules:
        module.include_dirs.append(np.get_include())
        module.include_dirs.append(pa.get_include())
        module.libraries.extend(pa.get_libraries())
        module.library_dirs.extend(pa.get_library_dirs())

        # https://arrow.apache.org/docs/python/extending.html#example
        if os.name == 'posix':
            module.extra_compile_args.append('-std=c++11')

    return arrow_modules + libbson_modules


# In the setup(...) call, the former
#     setup_requires=['cython >= 0.29']
# is replaced by
#     install_requires=['pyarrow >= 3', 'pymongo >= 3.11,<4'],
#     setup_requires=['cython >= 0.29', 'pyarrow >= 3', 'numpy >= 1.16.6']


# === bindings/python/test/__init__.py (new file) ===
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/bindings/python/test/test_builders.py b/bindings/python/test/test_builders.py new file mode 100644 index 00000000..43c16af5 --- /dev/null +++ b/bindings/python/test/test_builders.py @@ -0,0 +1,98 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from datetime import datetime, timedelta
from unittest import TestCase

from pyarrow import Array, timestamp, int32, int64

from pymongoarrow.lib import (
    DatetimeBuilder, DoubleBuilder, Int32Builder, Int64Builder)


class TestIntBuildersMixin:
    """Shared scenario for the integer builders.

    Concrete test cases provide ``builder_cls`` and ``data_type``.
    """

    def test_simple(self):
        """Single append, bulk append and a null end up in one array."""
        int_builder = self.builder_cls()
        int_builder.append(0)
        int_builder.append_values([1, 2, 3, 4])
        int_builder.append(None)
        result = int_builder.finish()

        expected = [0, 1, 2, 3, 4, None]
        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), 6)
        self.assertEqual(result.to_pylist(), expected)
        self.assertEqual(result.type, self.data_type)


class TestInt32Builder(TestCase, TestIntBuildersMixin):
    """Runs the integer scenario against Int32Builder."""

    def setUp(self):
        self.builder_cls = Int32Builder
        self.data_type = int32()


class TestInt64Builder(TestCase, TestIntBuildersMixin):
    """Runs the integer scenario against Int64Builder."""

    def setUp(self):
        self.builder_cls = Int64Builder
        self.data_type = int64()


class TestDate64Builder(TestCase):
    """Tests for DatetimeBuilder."""

    def test_default_unit(self):
        """A builder created without arguments reports timestamp('ms')."""
        self.assertEqual(DatetimeBuilder().unit, timestamp('ms'))

    def _test_simple(self, tstamp_units, kwarg_name):
        """Round-trip five datetimes and a null for one timestamp unit.

        ``kwarg_name`` is the timedelta keyword used to space the samples
        100 steps apart, starting at 1970-01-01.
        """
        arrow_type = timestamp(tstamp_units)
        dt_builder = DatetimeBuilder(dtype=arrow_type)

        start = datetime(1970, 1, 1)
        samples = [start + timedelta(**{kwarg_name: step})
                   for step in range(0, 500, 100)]

        first, rest = samples[0], samples[1:]
        dt_builder.append(first)
        dt_builder.append_values(rest)
        dt_builder.append(None)
        result = dt_builder.finish()

        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), len(samples) + 1)
        self.assertEqual(result.to_pylist(), samples + [None])
        self.assertEqual(result.type, timestamp(tstamp_units))

    def test_simple(self):
        """Exercise the round trip for milliseconds, then seconds."""
        cases = (('ms', 'milliseconds'), ('s', 'seconds'))
        for units, kwarg in cases:
            self._test_simple(units, kwarg)

    def test_unsupported_units(self):
        """Microsecond and nanosecond timestamp types are rejected."""
        with self.assertRaises(ValueError):
            DatetimeBuilder(dtype=timestamp('us'))

        with self.assertRaises(ValueError):
            DatetimeBuilder(dtype=timestamp('ns'))


class TestDoubleBuilder(TestCase):
    """Tests for DoubleBuilder."""

    def test_simple(self):
        """Single append, bulk append and a null end up in one array."""
        dbl_builder = DoubleBuilder()
        dbl_builder.append(0.123)
        dbl_builder.append_values([1.234, 2.345, 3.456, 4.567])
        dbl_builder.append(None)
        result = dbl_builder.finish()

        expected = [0.123, 1.234, 2.345, 3.456, 4.567, None]
        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), 6)
        self.assertEqual(result.to_pylist(), expected)