-
Notifications
You must be signed in to change notification settings - Fork 307
/
InferSchema.java
378 lines (342 loc) · 15.1 KB
/
InferSchema.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.mongodb.spark.sql.connector.schema;
import static java.lang.String.format;
import static java.util.stream.Collectors.groupingBy;
import com.mongodb.client.model.Aggregates;
import com.mongodb.spark.sql.connector.assertions.Assertions;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import com.mongodb.spark.sql.connector.config.ReadConfig;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.catalyst.analysis.TypeCoercion;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.bson.BsonDocument;
import org.bson.BsonValue;
import org.bson.conversions.Bson;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;
/**
* A helper that determines the {@code StructType} for a {@code BsonDocument} and finds the common
* {@code StructType} for a list of BsonDocuments.
*
* <p>All Bson types are considered convertible to Spark types. For any Bson types that doesn't have
* a direct conversion to a Spark type then a String type will be used.
*/
@NotNull
public final class InferSchema {
  /** Inferred schema metadata */
  public static final Metadata INFERRED_METADATA = Metadata.fromJson("{\"inferred\": true}");

  /**
   * Infer the schema for the collection
   *
   * @param options the configuration options to determine the namespace to determine the schema for
   * @return the schema
   */
  public static StructType inferSchema(final CaseInsensitiveStringMap options) {
    ReadConfig readConfig = MongoConfig.readConfig(options.asCaseSensitiveMap())
        .withOptions(options.asCaseSensitiveMap());
    // Append a $sample stage after any user-configured pipeline so inference reads at most
    // `inferSchemaSampleSize` documents instead of the whole collection.
    ArrayList<Bson> samplePipeline = new ArrayList<>(readConfig.getAggregationPipeline());
    samplePipeline.add(Aggregates.sample(readConfig.getInferSchemaSampleSize()));
    return inferSchema(
        readConfig.withCollection(coll -> coll.aggregate(samplePipeline)
            .allowDiskUse(readConfig.getAggregationAllowDiskUse())
            .comment(readConfig.getComment())
            .into(new ArrayList<>())),
        readConfig);
  }

  /**
   * @param schema the schema
   * @return true if the schema has been inferred.
   */
  public static boolean isInferred(final StructType schema) {
    // A schema counts as inferred only when every top-level field carries the
    // INFERRED_METADATA marker. (Note: vacuously true for an empty schema.)
    return Arrays.stream(schema.fields()).allMatch(f -> f.metadata().equals(INFERRED_METADATA));
  }

  /**
   * Computes the single {@link StructType} that is compatible with every sampled document.
   *
   * <p>Each document is converted to its own struct, then the structs are folded together with
   * {@link #compatibleStructType}. Any field whose type is still the empty-array placeholder
   * (i.e. the field only ever contained empty arrays) falls back to {@code ArrayType(StringType)}.
   *
   * @param bsonDocuments the sampled documents
   * @param readConfig the read configuration
   * @return the merged schema for all sampled documents
   */
  @VisibleForTesting
  static StructType inferSchema(
      final List<BsonDocument> bsonDocuments, final ReadConfig readConfig) {
    StructType structType = bsonDocuments.stream()
        .map(d -> getStructType(d, readConfig))
        .reduce(PLACE_HOLDER_STRUCT_TYPE, (dt1, dt2) -> compatibleStructType(dt1, dt2, readConfig));
    return DataTypes.createStructType(Arrays.stream(structType.fields())
        .map(f -> {
          if (f.dataType().sameType(PLACE_HOLDER_ARRAY_TYPE)) {
            return DataTypes.createStructField(
                f.name(),
                DataTypes.createArrayType(DataTypes.StringType, true),
                f.nullable(),
                INFERRED_METADATA);
          }
          return f;
        })
        .collect(Collectors.toList()));
  }

  /**
   * Converts a single document to a struct type.
   *
   * <p>The cast is safe because {@link #getDataType} always returns a {@code StructType} (or a
   * MapType, see the hedge below) for a DOCUMENT input.
   *
   * <p>NOTE(review): if {@code inferSchemaMapType} is enabled,
   * {@link #dataTypeCheckStructTypeToMapType} may return a {@code MapType} for a large top-level
   * document, which would make this cast throw — presumably top-level documents are exempt in
   * practice; confirm against the config defaults.
   */
  @NotNull
  private static StructType getStructType(
      final BsonDocument bsonDocument, final ReadConfig readConfig) {
    return (StructType) getDataType(bsonDocument, readConfig);
  }

  /**
   * Maps a single {@link BsonValue} to the Spark {@link DataType} used for inference.
   *
   * <p>Bson types without a direct Spark equivalent (SYMBOL, OBJECT_ID, regex, JavaScript, etc.)
   * map to {@code StringType}, the lowest common type.
   *
   * @param bsonValue the value to map
   * @param readConfig the read configuration (drives the struct-to-map heuristic)
   * @return the corresponding Spark data type
   */
  @VisibleForTesting
  static DataType getDataType(final BsonValue bsonValue, final ReadConfig readConfig) {
    switch (bsonValue.getBsonType()) {
      case DOCUMENT:
        // Recurse into each field; every inferred field is nullable and tagged as inferred.
        List<StructField> fields = new ArrayList<>();
        bsonValue
            .asDocument()
            .forEach((k, v) -> fields.add(
                new StructField(k, getDataType(v, readConfig), true, INFERRED_METADATA)));
        return dataTypeCheckStructTypeToMapType(DataTypes.createStructType(fields), readConfig);
      case ARRAY:
        // Fold the element types into one compatible type. An empty array yields the
        // placeholder so later samples of the same field can still supply a real type.
        DataType elementType = bsonValue.asArray().stream()
            .map(v -> getDataType(v, readConfig))
            .distinct()
            .reduce((d1, d2) -> compatibleType(d1, d2, readConfig))
            .orElse(PLACE_HOLDER_DATA_TYPE);
        if (elementType.sameType(PLACE_HOLDER_DATA_TYPE)) {
          return PLACE_HOLDER_ARRAY_TYPE;
        }
        return DataTypes.createArrayType(elementType, true);
      case SYMBOL:
      case STRING:
      case OBJECT_ID:
        return DataTypes.StringType;
      case BINARY:
        return DataTypes.BinaryType;
      case BOOLEAN:
        return DataTypes.BooleanType;
      case TIMESTAMP:
      case DATE_TIME:
        return DataTypes.TimestampType;
      case NULL:
        return DataTypes.NullType;
      case DOUBLE:
        return DataTypes.DoubleType;
      case INT32:
        return DataTypes.IntegerType;
      case INT64:
        return DataTypes.LongType;
      case DECIMAL128:
        // Spark requires precision >= scale, but a BigDecimal's precision can be smaller than
        // its scale (e.g. 0.001 has precision 1, scale 3) — hence the max.
        BigDecimal bigDecimal = bsonValue.asDecimal128().decimal128Value().bigDecimalValue();
        return DataTypes.createDecimalType(
            Math.max(bigDecimal.precision(), bigDecimal.scale()), bigDecimal.scale());
      default:
        return DataTypes.StringType;
    }
  }

  /**
   * Determines the compatible {@link StructType} for the two StructTypes with the supplied {@link
   * ReadConfig}.
   *
   * <p>Uses {@link #compatibleType(DataType, DataType, ReadConfig)} to determine the compatible
   * data types for fields.
   *
   * @param structType1 the first StructType
   * @param structType2 the second StructType
   * @param readConfig the read configuration
   * @return the compatible {@link StructType} with the fields sorted by name or errors.
   */
  private static StructType compatibleStructType(
      final StructType structType1, final StructType structType2, final ReadConfig readConfig) {
    // Identity (==) check against the reduce seed: the sentinel is a specific instance,
    // deliberately not compared with equals()/sameType().
    if (structType1 == PLACE_HOLDER_STRUCT_TYPE) {
      return structType2;
    }
    // Group both structs' fields by name, then fold each group to one common type.
    Map<String, List<StructField>> fieldNameToStructFieldMap = Stream.of(
            structType1.fields(), structType2.fields())
        .flatMap(Stream::of)
        .collect(groupingBy(StructField::name));
    List<StructField> structFields = new ArrayList<>();
    fieldNameToStructFieldMap.forEach((fieldName, groupedStructFields) -> {
      DataType fieldCommonDataType = groupedStructFields.stream()
          .map(StructField::dataType)
          .reduce(PLACE_HOLDER_DATA_TYPE, (dt1, dt2) -> compatibleType(dt1, dt2, readConfig));
      structFields.add(
          DataTypes.createStructField(fieldName, fieldCommonDataType, true, INFERRED_METADATA));
    });
    // Sort for a deterministic field order regardless of document/sample order.
    structFields.sort(Comparator.comparing(StructField::name));
    return DataTypes.createStructType(structFields);
  }

  /**
   * Returns the compatible type between two data types. All types are can be converted to be
   * compatible because {@link DataTypes#StringType} is the lowest common type for all Bson types.
   *
   * <p>Uses the {@link TypeCoercion#findTightestCommonType()} function to find the most compatible
   * type. If the types are deemed incompatible because they are:
   *
   * <ul>
   *   <li>both Structs: then the Struct fields are merged and the types expanded until they are
   *       compatible. See {@link #compatibleStructType}.
   *   <li>both Arrays: the Array value types are expanded until they are compatible. See {@link
   *       #compatibleArrayType}.
   *   <li>both Decimals: creates a compatible type for the decimals. See {@link
   *       #compatibleDecimalType}.
   *   <li>A Map and Struct type: creates a new Map type with a compatible value type. See {@link
   *       #appendStructToMap}.
   * </ul>
   *
   * @param dataType1 the first data type
   * @param dataType2 the second data type
   * @param readConfig the read configuration
   * @return the compatible type
   */
  private static DataType compatibleType(
      @Nullable final DataType dataType1, final DataType dataType2, final ReadConfig readConfig) {
    // NOTE(review): annotated @Nullable, but an actual null would NPE below; every caller in this
    // file seeds reduces with the PLACE_HOLDER_DATA_TYPE sentinel rather than null — confirm no
    // external caller passes null.
    if (dataType1 == PLACE_HOLDER_DATA_TYPE) {
      return dataType2;
    }
    DataType dataType = TypeCoercion.findTightestCommonType()
        .apply(dataType1, dataType2)
        .getOrElse(
            // dataType1 or dataType2 is a StructType, MapType, ArrayType or a decimal type.
            () -> {
              if (dataType1 instanceof StructType && dataType2 instanceof StructType) {
                return compatibleStructType(
                    (StructType) dataType1, (StructType) dataType2, readConfig);
              }
              if (dataType1 instanceof ArrayType && dataType2 instanceof ArrayType) {
                return compatibleArrayType(
                    (ArrayType) dataType1, (ArrayType) dataType2, readConfig);
              }
              // The case that given `DecimalType` is capable of given `IntegralType` is handled
              // in `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
              // the given `DecimalType` is not capable of the given `IntegralType`.
              if (dataType1 instanceof DecimalType || dataType2 instanceof DecimalType) {
                return compatibleDecimalType(dataType1, dataType2);
              }
              // If working with MapTypes and StructTypes append the struct data
              if ((dataType1 instanceof MapType && dataType2 instanceof StructType)
                  || (dataType1 instanceof StructType && dataType2 instanceof MapType)) {
                return appendStructToMap(dataType1, dataType2, readConfig);
              }
              return DataTypes.StringType; // Lowest common type
            });
    // The merged struct may itself now qualify for the struct-to-map heuristic.
    return dataTypeCheckStructTypeToMapType(dataType, readConfig);
  }

  /**
   * Merges two array types into one whose element type is compatible with both.
   *
   * <p>The placeholder element type (produced by an empty Bson array) is treated as "no
   * information": it defers to the other side's element type, and two placeholders fall back to
   * {@code StringType}. Container nullability is OR-ed.
   */
  private static DataType compatibleArrayType(
      final ArrayType arrayType1, final ArrayType arrayType2, final ReadConfig readConfig) {
    DataType arrayElementType1 = arrayType1.elementType();
    DataType arrayElementType2 = arrayType2.elementType();
    // Identity (==) checks: the placeholder is a sentinel instance, not a value-equal type.
    if (arrayElementType1 != PLACE_HOLDER_DATA_TYPE
        && arrayElementType2 != PLACE_HOLDER_DATA_TYPE) {
      return DataTypes.createArrayType(
          compatibleType(arrayElementType1, arrayElementType2, readConfig),
          arrayType1.containsNull() || arrayType2.containsNull());
    } else if (arrayElementType1 == PLACE_HOLDER_DATA_TYPE
        && arrayElementType2 == PLACE_HOLDER_DATA_TYPE) {
      return DataTypes.createArrayType(
          DataTypes.StringType, arrayType1.containsNull() || arrayType2.containsNull());
    } else if (arrayElementType1 != PLACE_HOLDER_DATA_TYPE) {
      return arrayType1;
    } else {
      return arrayType2;
    }
  }

  /**
   * Folds a struct into an existing map type: the result keeps the map's key type and nullability,
   * with a value type compatible with the map's value type and every struct field type.
   *
   * <p>This arises when the struct-to-map heuristic converted one sample of a field to a map while
   * another sample of the same field stayed a struct.
   *
   * @throws IllegalArgumentException (via Assertions) if not given exactly one StructType and one
   *     MapType, in either order
   */
  private static DataType appendStructToMap(
      final DataType dataType1, final DataType dataType2, final ReadConfig readConfig) {
    Assertions.ensureArgument(
        () -> (dataType1 instanceof StructType && dataType2 instanceof MapType)
            || (dataType1 instanceof MapType && dataType2 instanceof StructType),
        () -> format(
            "Requires a StructType and a MapType. Got: %s, %s",
            dataType1.typeName(), dataType2.typeName()));
    StructType structType =
        dataType1 instanceof StructType ? (StructType) dataType1 : (StructType) dataType2;
    MapType mapType = dataType1 instanceof StructType ? (MapType) dataType2 : (MapType) dataType1;
    DataType valueType = Stream.concat(
            Stream.of(mapType.valueType()),
            Arrays.stream(structType.fields()).map(StructField::dataType))
        .reduce(PLACE_HOLDER_DATA_TYPE, (dt1, dt2) -> compatibleType(dt1, dt2, readConfig));
    return DataTypes.createMapType(mapType.keyType(), valueType, mapType.valueContainsNull());
  }

  /**
   * Widens a decimal type so it can hold both inputs, where at least one input is a
   * {@link DecimalType}.
   *
   * <p>Two decimals are combined by taking the max scale and max integer-digit range; integral and
   * double types get fixed widenings (int → (10,0), long → (20,0), double → (30,15)). Anything
   * else falls back to {@code StringType}.
   *
   * @throws IllegalArgumentException (via Assertions) if neither input is a DecimalType
   */
  private static DataType compatibleDecimalType(
      final DataType dataType1, final DataType dataType2) {
    Assertions.ensureArgument(
        () -> dataType1 instanceof DecimalType || dataType2 instanceof DecimalType,
        () -> format(
            "Neither datatype is an instance of DecimalType. Got: %s, %s",
            dataType1.typeName(), dataType2.typeName()));
    DecimalType decimalType =
        dataType1 instanceof DecimalType ? (DecimalType) dataType1 : (DecimalType) dataType2;
    DataType dataType = dataType1 instanceof DecimalType ? dataType2 : dataType1;
    if (dataType instanceof DecimalType) {
      DecimalType decimalType2 = (DecimalType) dataType;
      int scale = Math.max(decimalType.scale(), decimalType2.scale());
      // range = max number of digits left of the decimal point across both types.
      int range = Math.max(
          decimalType.precision() - decimalType.scale(),
          decimalType2.precision() - decimalType2.scale());
      if (range + scale > 38) {
        // DecimalType can't support precision > 38
        return DataTypes.DoubleType;
      } else {
        return DataTypes.createDecimalType(range + scale, scale);
      }
    } else if (dataType instanceof IntegerType) {
      return DataTypes.createDecimalType(10, 0);
    } else if (dataType instanceof LongType) {
      return DataTypes.createDecimalType(20, 0);
    } else if (dataType instanceof DoubleType) {
      return DataTypes.createDecimalType(30, 15);
    }
    return DataTypes.StringType;
  }

  /**
   * Optionally converts a struct into a {@code MapType(StringType, commonValueType)}.
   *
   * <p>Applies only when the read config enables map-type inference and the struct has at least
   * the configured minimum number of fields — a heuristic for documents whose keys are data
   * (e.g. ids) rather than a fixed schema. The map's value type is the fold of all field types.
   *
   * @param dataType the candidate type; non-struct types are returned unchanged
   * @param readConfig the read configuration
   * @return the original type or its map-type replacement
   */
  private static DataType dataTypeCheckStructTypeToMapType(
      final DataType dataType, final ReadConfig readConfig) {
    if (dataType instanceof StructType) {
      StructType structType = (StructType) dataType;
      if (readConfig.inferSchemaMapType()
          && structType.fields().length >= readConfig.getInferSchemaMapTypeMinimumKeySize()) {
        DataType valueType = Arrays.stream(structType.fields())
            .map(StructField::dataType)
            .reduce(PLACE_HOLDER_DATA_TYPE, (dt1, dt2) -> compatibleType(dt1, dt2, readConfig));
        return DataTypes.createMapType(DataTypes.StringType, valueType, true);
      }
    }
    return dataType;
  }

  // Sentinel seed for struct reduces; recognised by reference identity (==) only.
  private static final StructType PLACE_HOLDER_STRUCT_TYPE =
      DataTypes.createStructType(new StructField[0]);

  // Sentinel "unknown" data type (e.g. element type of an empty array). A unique anonymous
  // instance so identity comparisons can't collide with a real Spark type.
  private static final DataType PLACE_HOLDER_DATA_TYPE = new DataType() {
    @Override
    public int defaultSize() {
      return 0;
    }

    @Override
    public DataType asNullable() {
      return PLACE_HOLDER_DATA_TYPE;
    }
  };

  // Marks an array whose element type is still unknown; resolved to ArrayType(StringType)
  // at the end of inference if no sample supplied a real element type.
  static final ArrayType PLACE_HOLDER_ARRAY_TYPE =
      DataTypes.createArrayType(PLACE_HOLDER_DATA_TYPE);

  /** Static utility class; not instantiable. */
  private InferSchema() {}
}