-
Notifications
You must be signed in to change notification settings - Fork 10
/
context.py
122 lines (110 loc) · 4.57 KB
/
context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from pyarrow import Table, timestamp
from pymongoarrow.types import _BsonArrowTypes, _get_internal_typemap
# Lookup table from each internal BSON/Arrow type marker to the builder class
# that accumulates values of that type. The builders come from the compiled
# ``pymongoarrow.lib`` extension. If that import fails, the ImportError is
# swallowed and ``_TYPE_TO_BUILDER_CLS`` (and the builder names) are simply
# not defined at module level. NOTE(review): presumably this lets the module
# be imported without the compiled extension; confirm intent.
try:
    from pymongoarrow.lib import (
        BinaryBuilder,
        BoolBuilder,
        CodeBuilder,
        Date32Builder,
        Date64Builder,
        DatetimeBuilder,
        Decimal128Builder,
        DocumentBuilder,
        DoubleBuilder,
        Int32Builder,
        Int64Builder,
        ListBuilder,
        ObjectIdBuilder,
        StringBuilder,
    )
except ImportError:
    pass
else:
    # Only reached when every builder above imported successfully.
    _TYPE_TO_BUILDER_CLS = {
        # numeric
        _BsonArrowTypes.int32: Int32Builder,
        _BsonArrowTypes.int64: Int64Builder,
        _BsonArrowTypes.double: DoubleBuilder,
        # timestamps / identifiers / decimals / text / booleans
        _BsonArrowTypes.datetime: DatetimeBuilder,
        _BsonArrowTypes.objectid: ObjectIdBuilder,
        _BsonArrowTypes.decimal128: Decimal128Builder,
        _BsonArrowTypes.string: StringBuilder,
        _BsonArrowTypes.bool: BoolBuilder,
        # nested and opaque values
        _BsonArrowTypes.document: DocumentBuilder,
        _BsonArrowTypes.array: ListBuilder,
        _BsonArrowTypes.binary: BinaryBuilder,
        _BsonArrowTypes.code: CodeBuilder,
        # calendar dates
        _BsonArrowTypes.date32: Date32Builder,
        _BsonArrowTypes.date64: Date64Builder,
    }
class PyMongoArrowContext:
    """A context for converting BSON-formatted data to an Arrow Table."""

    def __init__(self, schema, builder_map, codec_options=None):
        """Initialize the context.

        :Parameters:
          - `schema`: Instance of :class:`~pymongoarrow.schema.Schema`, or
            ``None`` when no schema is supplied.
          - `builder_map`: Mapping of utf-8-encoded field names to
            :class:`~pymongoarrow.builders._BuilderBase` instances.
          - `codec_options` (optional): An instance of
            :class:`~bson.codec_options.CodecOptions`. It is only consulted
            when `schema` is ``None``; in that case its ``tzinfo`` is stored
            as ``self.tzinfo``. Otherwise ``self.tzinfo`` is ``None``.
        """
        self.schema = schema
        self.builder_map = builder_map
        # from_schema passes tzinfo straight to the builders it creates, so
        # the context only records a tzinfo when there is no schema.
        if self.schema is None and codec_options is not None:
            self.tzinfo = codec_options.tzinfo
        else:
            self.tzinfo = None

    @classmethod
    def from_schema(cls, schema, codec_options=DEFAULT_CODEC_OPTIONS):
        """Initialize the context from a :class:`~pymongoarrow.schema.Schema`
        instance.

        :Parameters:
          - `schema`: Instance of :class:`~pymongoarrow.schema.Schema`, or
            ``None`` to start with an empty builder map.
          - `codec_options` (optional): An instance of
            :class:`~bson.codec_options.CodecOptions`. Its ``tzinfo`` is
            applied to timezone-naive timestamp fields and is passed to
            document and list builders. ``None`` is accepted and means
            "no tzinfo".

        :Returns:
          A new context whose ``builder_map`` holds one builder per schema
          field, keyed by the utf-8-encoded field name.
        """
        if schema is None:
            return cls(schema, {}, codec_options)

        # Accept an explicit ``None`` (as __init__ does) instead of failing
        # with AttributeError on ``None.tzinfo``.
        tzinfo = None if codec_options is None else codec_options.tzinfo

        builder_map = {}
        str_type_map = _get_internal_typemap(schema.typemap)
        for fname, ftype in str_type_map.items():
            builder_cls = _TYPE_TO_BUILDER_CLS[ftype]
            encoded_fname = fname.encode("utf-8")
            # Parameterized types need builders initialized with the field's
            # Arrow type (and the tzinfo, where applicable).
            if builder_cls is DatetimeBuilder:
                arrow_type = schema.typemap[fname]
                # Only naive timestamp types inherit the codec tzinfo; an
                # explicitly zoned type in the schema is kept as-is.
                if tzinfo is not None and arrow_type.tz is None:
                    arrow_type = timestamp(arrow_type.unit, tz=tzinfo)
                builder = DatetimeBuilder(dtype=arrow_type)
            elif builder_cls is DocumentBuilder:
                builder = DocumentBuilder(schema.typemap[fname], tzinfo)
            elif builder_cls is ListBuilder:
                builder = ListBuilder(schema.typemap[fname], tzinfo)
            elif builder_cls is BinaryBuilder:
                builder = BinaryBuilder(schema.typemap[fname].subtype)
            else:
                builder = builder_cls()
            builder_map[encoded_fname] = builder
        return cls(schema, builder_map, codec_options)

    def finish(self):
        """Finish every builder and assemble the results into a Table.

        :Returns:
          A :class:`pyarrow.Table` with one column per entry of
          ``builder_map``, in its iteration order. When a schema is set, the
          table is typed by ``self.schema.to_arrow()``. Otherwise the column
          names are the utf-8-decoded ``builder_map`` keys.
        """
        arrays = [builder.finish() for builder in self.builder_map.values()]
        if self.schema is not None:
            # NOTE(review): assumes builder_map iteration order matches the
            # field order of schema.to_arrow(). This holds for maps built by
            # from_schema if _get_internal_typemap preserves key order;
            # confirm.
            return Table.from_arrays(arrays=arrays, schema=self.schema.to_arrow())
        names = [fname.decode("utf-8") for fname in self.builder_map]
        return Table.from_arrays(arrays=arrays, names=names)