
fixed nested structs, arrays and maps in NiFiRecordSerde, added unit tests and fixed broken tests
gkkorir authored and gkkorir committed May 22, 2019
1 parent 535d653 commit d04fd4c
Showing 3 changed files with 461 additions and 103 deletions.
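For context, a minimal sketch (not part of the commit) of the kind of nested record this change targets, built with NiFi's record API (SimpleRecordSchema, MapRecord); the field names and values are illustrative:

import org.apache.nifi.serialization.SimpleRecordSchema;
import org.apache.nifi.serialization.record.MapRecord;
import org.apache.nifi.serialization.record.Record;
import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.RecordSchema;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class NestedRecordExample {

    public static Record buildNestedRecord() {
        // Inner struct: address{street: string, zip: string}
        RecordSchema addressSchema = new SimpleRecordSchema(Arrays.asList(
                new RecordField("street", RecordFieldType.STRING.getDataType()),
                new RecordField("zip", RecordFieldType.STRING.getDataType())));

        // Outer schema with a nested record, an array and a map, the three
        // shapes this commit fixes in NiFiRecordSerDe.
        RecordSchema schema = new SimpleRecordSchema(Arrays.asList(
                new RecordField("id", RecordFieldType.INT.getDataType()),
                new RecordField("address", RecordFieldType.RECORD.getRecordDataType(addressSchema)),
                new RecordField("tags", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType())),
                new RecordField("attrs", RecordFieldType.MAP.getMapDataType(RecordFieldType.STRING.getDataType()))));

        Map<String, Object> address = new HashMap<>();
        address.put("street", "1 Main St");
        address.put("zip", "12345");

        Map<String, Object> attrs = new HashMap<>();
        attrs.put("env", "prod");

        Map<String, Object> values = new HashMap<>();
        values.put("id", 1);
        values.put("address", new MapRecord(addressSchema, address));
        values.put("tags", new Object[]{"a", "b"});
        values.put("attrs", attrs);
        return new MapRecord(schema, values);
    }
}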
@@ -27,11 +27,7 @@
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.serde2.typeinfo.*;
import org.apache.hadoop.io.ObjectWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hive.common.util.HiveStringUtils;
@@ -42,10 +38,14 @@
import org.apache.nifi.serialization.RecordReader;
import org.apache.nifi.serialization.record.Record;
import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.RecordSchema;
import org.apache.nifi.serialization.record.type.MapDataType;
import org.apache.nifi.serialization.record.util.DataTypeUtils;

import java.io.IOException;
import java.math.BigDecimal;
import java.sql.Array;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
@@ -57,6 +57,7 @@
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class NiFiRecordSerDe extends AbstractSerDe {

@@ -71,8 +72,6 @@ public class NiFiRecordSerDe extends AbstractSerDe {

private final static Pattern INTERNAL_PATTERN = Pattern.compile("_col([0-9]+)");

private Map<String, Integer> fieldPositionMap;

public NiFiRecordSerDe(RecordReader recordReader, ComponentLog log) {
this.recordReader = recordReader;
this.log = log;
@@ -114,12 +113,6 @@ public void initialize(Configuration conf, Properties tbl) throws SerDeException
log.debug("schema : {}", new Object[]{schema});
cachedObjectInspector = (StandardStructObjectInspector) TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(rowTypeInfo);
tsParser = new TimestampParser(HiveStringUtils.splitAndUnEscape(tbl.getProperty(serdeConstants.TIMESTAMP_FORMATS)));
// Populate mapping of field names to column positions
try {
populateFieldPositionMap();
} catch (MalformedRecordException | IOException e) {
throw new SerDeException(e);
}
stats = new SerDeStats();
}

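As a usage sketch (assumptions: the stock Hive serde table properties, and a RecordReader and ComponentLog supplied by the caller), initialization looks roughly like this:

// Minimal initialization sketch; recordReader and log are assumed to be
// provided by the calling processor, and the column properties use the
// standard serdeConstants keys.
Properties tbl = new Properties();
tbl.setProperty(serdeConstants.LIST_COLUMNS, "id,address,tags,attrs");
tbl.setProperty(serdeConstants.LIST_COLUMN_TYPES,
        "int,struct<street:string,zip:string>,array<string>,map<string,string>");

NiFiRecordSerDe serde = new NiFiRecordSerDe(recordReader, log);
serde.initialize(new Configuration(), tbl);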
@@ -142,7 +135,28 @@ public SerDeStats getSerDeStats() {
public Object deserialize(Writable writable) throws SerDeException {
ObjectWritable t = (ObjectWritable) writable;
Record record = (Record) t.get();
List<Object> r = new ArrayList<>(Collections.nCopies(columnNames.size(), null));
return deserialize(record, schema, true);
}

/**
* Deserializes a record object into a Hive struct.
* @param record the record to deserialize; may be null
* @param structTypeInfo the Hive struct type info corresponding to the table's columns
* @param isParentStruct whether this struct is the top-level one contained by the Writable
*/
Object deserialize(Record record, StructTypeInfo structTypeInfo, boolean isParentStruct) throws SerDeException
{
if(record == null){
return null;
}
Map<String, Integer> fieldPositionMap = null;
try {
fieldPositionMap = populateFieldPositionMap(record.getSchema(), structTypeInfo, log);
} catch (IOException ex) {
throw new SerDeException(ex);
}

List<Object> r = new ArrayList<>(Collections.nCopies(structTypeInfo.getAllStructFieldNames().size(), null));
try {
RecordSchema recordSchema = record.getSchema();
for (RecordField field : recordSchema.getFields()) {
@@ -155,30 +169,29 @@ public Object deserialize(Writable writable) throws SerDeException {
// This is either a partition column or not a column in the target table, ignore either way
continue;
}
Object currField = extractCurrentField(record, field, schema.getStructFieldTypeInfo(normalizedFieldName));
Object fieldValue = record.getValue(fieldName);
Object currField = convertFieldValue(fieldValue, field, structTypeInfo.getStructFieldTypeInfo(normalizedFieldName));
r.set(fpos, currField);
}
stats.setRowCount(stats.getRowCount() + 1);
if(isParentStruct) {
stats.setRowCount(stats.getRowCount() + 1);
}

} catch (Exception e) {
log.warn("Error [{}] parsing Record [{}].", new Object[]{e.toString(), t}, e);
log.warn("Error [{}] parsing Record [{}].", new Object[]{e.toString(), record}, e);
throw new SerDeException(e);
}

return r;
}
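As a hedged usage sketch (the serde initialized as in the sketch above; names illustrative), the entry point is driven like this:

// Sketch: the caller wraps each NiFi Record in an ObjectWritable before
// handing it to the serde, matching the cast in deserialize(Writable).
Record record = NestedRecordExample.buildNestedRecord();
ObjectWritable writable = new ObjectWritable(record);

List<Object> row = (List<Object>) serde.deserialize(writable);
// Nested structs arrive as nested List<Object> values, arrays as Lists,
// and maps as Map<Object, Object>.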

/**
* Utility method to convert a raw record field value to the corresponding Hive type.
*/
@SuppressWarnings("unchecked")
private Object extractCurrentField(Record record, RecordField field, TypeInfo fieldTypeInfo) throws SerDeException {
Object val;
if (field == null) {
private Object convertFieldValue(final Object fieldValue, final RecordField field, final TypeInfo fieldTypeInfo) throws SerDeException {
if(fieldValue == null){
return null;
}
String fieldName = field.getFieldName();

// from here on, fieldValue is never null, so no null checks are needed
Object val;
switch (fieldTypeInfo.getCategory()) {
case PRIMITIVE:
PrimitiveObjectInspector.PrimitiveCategory primitiveCategory = PrimitiveObjectInspector.PrimitiveCategory.UNKNOWN;
@@ -187,132 +200,127 @@ private Object extractCurrentField(Record record, RecordField field, TypeInfo fieldTypeInfo) throws SerDeException {
}
switch (primitiveCategory) {
case BYTE:
Integer bIntValue = record.getAsInt(fieldName);
val = bIntValue == null ? null : bIntValue.byteValue();
Integer bIntValue = DataTypeUtils.toInteger(fieldValue, field.getDataType().getFormat());
val = bIntValue.byteValue();
break;
case SHORT:
Integer sIntValue = record.getAsInt(fieldName);
val = sIntValue == null ? null : sIntValue.shortValue();
Integer sIntValue = DataTypeUtils.toInteger(fieldValue, field.getDataType().getFormat());
val = sIntValue.shortValue();
break;
case INT:
val = record.getAsInt(fieldName);
val = DataTypeUtils.toInteger(fieldValue, field.getDataType().getFormat());
break;
case LONG:
val = record.getAsLong(fieldName);
val = DataTypeUtils.toLong(fieldValue, field.getDataType().getFormat());
break;
case BOOLEAN:
val = record.getAsBoolean(fieldName);
val = DataTypeUtils.toBoolean(fieldValue, field.getDataType().getFormat());
break;
case FLOAT:
val = record.getAsFloat(fieldName);
val = DataTypeUtils.toFloat(fieldValue, field.getDataType().getFormat());
break;
case DOUBLE:
val = record.getAsDouble(fieldName);
val = DataTypeUtils.toDouble(fieldValue, field.getDataType().getFormat());
break;
case STRING:
case VARCHAR:
case CHAR:
val = record.getAsString(fieldName);
val = DataTypeUtils.toString(fieldValue, field.getDataType().getFormat());
break;
case BINARY:
Object[] array = record.getAsArray(fieldName);
if (array == null) {
return null;
}
Object[] array = DataTypeUtils.toArray(fieldValue, field.getFieldName(), field.getDataType());
val = AvroTypeUtil.convertByteArray(array).array();
break;
case DATE:
Date d = record.getAsDate(fieldName, field.getDataType().getFormat());
if(d != null) {
org.apache.hadoop.hive.common.type.Date hiveDate = new org.apache.hadoop.hive.common.type.Date();
hiveDate.setTimeInMillis(d.getTime());
val = hiveDate;
} else {
val = null;
}
Date d = DataTypeUtils.toDate(fieldValue, () -> DataTypeUtils.getDateFormat(field.getDataType().getFormat()), field.getFieldName());
org.apache.hadoop.hive.common.type.Date hiveDate = new org.apache.hadoop.hive.common.type.Date();
hiveDate.setTimeInMillis(d.getTime());
val = hiveDate;
break;
// ORC doesn't currently handle TIMESTAMPLOCALTZ
case TIMESTAMP:
Timestamp ts = DataTypeUtils.toTimestamp(record.getValue(fieldName), () -> DataTypeUtils.getDateFormat(field.getDataType().getFormat()), fieldName);
if(ts != null) {
// Convert to Hive's Timestamp type
org.apache.hadoop.hive.common.type.Timestamp hivetimestamp = new org.apache.hadoop.hive.common.type.Timestamp();
hivetimestamp.setTimeInMillis(ts.getTime(), ts.getNanos());
val = hivetimestamp;
} else {
val = null;
}
Timestamp ts = DataTypeUtils.toTimestamp(fieldValue, () -> DataTypeUtils.getDateFormat(field.getDataType().getFormat()), field.getFieldName());
// Convert to Hive's Timestamp type
org.apache.hadoop.hive.common.type.Timestamp hivetimestamp = new org.apache.hadoop.hive.common.type.Timestamp();
hivetimestamp.setTimeInMillis(ts.getTime(), ts.getNanos());
val = hivetimestamp;
break;
case DECIMAL:
Double value = record.getAsDouble(fieldName);
val = value == null ? null : HiveDecimal.create(value);
if(fieldValue instanceof BigDecimal){
val = HiveDecimal.create((BigDecimal) fieldValue);
} else if (fieldValue instanceof Double){
val = HiveDecimal.create((Double)fieldValue);
} else if (fieldValue instanceof Number) {
val = HiveDecimal.create(((Number)fieldValue).doubleValue());
} else {
val = HiveDecimal.create(DataTypeUtils.toDouble(fieldValue, field.getDataType().getFormat()));
}
break;
default:
throw new IllegalArgumentException("Field " + fieldName + " cannot be converted to type: " + primitiveCategory.name());
throw new IllegalArgumentException("Field " + field.getFieldName() + " cannot be converted to type: " + primitiveCategory.name());
}
break;
case LIST:
Object[] value = record.getAsArray(fieldName);
val = value == null ? null : Arrays.asList(value);
Object[] value = (Object[])fieldValue;
ListTypeInfo listTypeInfo = (ListTypeInfo)fieldTypeInfo;
TypeInfo nestedType = listTypeInfo.getListElementTypeInfo();
List<Object> converted = new ArrayList<>(value.length);
for(int i=0; i<value.length; i++){
converted.add(convertFieldValue(value[i], field, nestedType));
}
val = converted;
break;
case MAP:
val = record.getValue(fieldName);
// in NiFi all maps are Map<String, ?>, so use that here
Map<String, Object> valueMap = (Map<String,Object>)fieldValue;
MapTypeInfo mapTypeInfo = (MapTypeInfo)fieldTypeInfo;
Map<Object, Object> convertedMap = new HashMap<>(valueMap.size());
// NiFi map keys are always strings; synthesize record fields for the map's key and value.
RecordField keyField = new RecordField(field.getFieldName() + ".key", RecordFieldType.STRING.getDataType());
RecordField valueField = new RecordField(field.getFieldName() + ".value", ((MapDataType)field.getDataType()).getValueType());
for (Map.Entry<String, Object> entry: valueMap.entrySet()) {
convertedMap.put(
convertFieldValue(entry.getKey(), keyField, mapTypeInfo.getMapKeyTypeInfo()),
convertFieldValue(entry.getValue(), valueField, mapTypeInfo.getMapValueTypeInfo())
);
}
val = convertedMap;
break;
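// Example: a NiFi field "attrs" declared map<string,int> in Hive yields the
// synthesized fields "attrs.key" (STRING, since NiFi map keys are always
// strings) and "attrs.value" (INT); each entry is converted recursively
// against the MapTypeInfo's key and value TypeInfos before insertion.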
case STRUCT:
// The Hive StandardStructObjectInspector expects the object corresponding to a "struct" to be an array or List rather than a Map.
// Do the conversion here, calling extractCurrentField recursively to traverse any nested structs.
Record nestedRecord = (Record) record.getValue(fieldName);
if (nestedRecord == null) {
return null;
}
try {
RecordSchema recordSchema = nestedRecord.getSchema();
List<RecordField> recordFields = recordSchema.getFields();
if (recordFields == null || recordFields.isEmpty()) {
return Collections.emptyList();
}
// This List will hold the values of the entries in the Map
List<Object> structList = new ArrayList<>(recordFields.size());
StructTypeInfo typeInfo = (StructTypeInfo) schema.getStructFieldTypeInfo(fieldName);
for (RecordField nestedRecordField : recordFields) {
String fName = nestedRecordField.getFieldName();
String normalizedFieldName = fName.toLowerCase();
structList.add(extractCurrentField(nestedRecord, nestedRecordField, typeInfo.getStructFieldTypeInfo(normalizedFieldName)));
}
return structList;
} catch (Exception e) {
log.warn("Error [{}] parsing Record [{}].", new Object[]{e.toString(), nestedRecord}, e);
throw new SerDeException(e);
}
// break unreachable
Record nestedRecord = (Record) fieldValue;
val = deserialize(nestedRecord, (StructTypeInfo)fieldTypeInfo, false);
break;
default:
log.error("Unknown type found: " + fieldTypeInfo + " for field of type: " + field.getDataType().toString());
return null;
}
return val;
}
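The commit message mentions added unit tests; a minimal test in that spirit might look like the sketch below. The test name and assertions are illustrative rather than the commit's actual tests; it assumes a serde field initialized as in the earlier sketch, and it would live in the same package so the package-private deserialize overload is visible.

import org.junit.Test;
import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

public class TestNiFiRecordSerDeNestedTypes {

    // Assumed to be built and initialized as in the earlier initialization sketch.
    private NiFiRecordSerDe serde;

    @Test
    public void nestedStructBecomesPositionalList() throws Exception {
        StructTypeInfo rowType = (StructTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(
                "struct<id:int,address:struct<street:string,zip:string>,"
                        + "tags:array<string>,attrs:map<string,string>>");

        List<Object> row = (List<Object>) serde.deserialize(
                NestedRecordExample.buildNestedRecord(), rowType, true);

        List<Object> address = (List<Object>) row.get(1); // structs arrive as Lists
        assertEquals("1 Main St", address.get(0));
        assertEquals(Arrays.asList("a", "b"), row.get(2)); // arrays arrive as Lists
    }
}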



@Override
public ObjectInspector getObjectInspector() {
return cachedObjectInspector;
}

private void populateFieldPositionMap() throws MalformedRecordException, IOException {
// Populate the mapping of field names to column positions only once
fieldPositionMap = new HashMap<>(columnNames.size());

RecordSchema recordSchema = recordReader.getSchema();

private static HashMap<String, Integer> populateFieldPositionMap(RecordSchema recordSchema, StructTypeInfo typeInfo, ComponentLog log) throws IOException {
// Populate the mapping of field names to column positions only once
HashMap<String, Integer> fieldPosition = new HashMap<>(typeInfo.getAllStructFieldNames().size());
for (RecordField field : recordSchema.getFields()) {
String fieldName = field.getFieldName();
String normalizedFieldName = fieldName.toLowerCase();

int fpos = schema.getAllStructFieldNames().indexOf(fieldName.toLowerCase());
int fpos = typeInfo.getAllStructFieldNames().indexOf(fieldName.toLowerCase());
if (fpos == -1) {
Matcher m = INTERNAL_PATTERN.matcher(fieldName);
fpos = m.matches() ? Integer.parseInt(m.group(1)) : -1;

log.debug("NPE finding position for field [{}] in schema [{}],"
+ " attempting to check if it is an internal column name like _col0", new Object[]{fieldName, schema});
+ " attempting to check if it is an internal column name like _col0", new Object[]{fieldName, typeInfo});
if (fpos == -1) {
// unknown field: skip it and continue from the next field onwards. Log at debug level because partition columns will be "unknown fields"
log.debug("Field {} is not found in the target table, ignoring...", new Object[]{field.getFieldName()});
@@ -332,7 +340,8 @@ private void populateFieldPositionMap() throws MalformedRecordException, IOException {
// If we reached here, then we were successful at finding an alternate internal
// column mapping, and we're about to proceed.
}
fieldPositionMap.put(normalizedFieldName, fpos);
fieldPosition.put(normalizedFieldName, fpos);
}
return fieldPosition;
}
}
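One more standalone illustration (regex behavior only): the _col<n> fallback in populateFieldPositionMap maps Hive's internal column names to positions.

import java.util.regex.Matcher;
import java.util.regex.Pattern;

class InternalColumnNameDemo {
    private static final Pattern INTERNAL_PATTERN = Pattern.compile("_col([0-9]+)");

    public static void main(String[] args) {
        // A field absent from the target schema but named like an internal
        // Hive column resolves to the position captured by the pattern.
        Matcher m = INTERNAL_PATTERN.matcher("_col3");
        System.out.println(m.matches() ? Integer.parseInt(m.group(1)) : -1); // prints 3
    }
}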