diff --git a/lakesoul-spark/src/main/java/org/apache/spark/sql/arrow/ArrowColumnVector.java b/lakesoul-spark/src/main/java/org/apache/spark/sql/arrow/ArrowColumnVector.java index 068fd85de..816ffe581 100644 --- a/lakesoul-spark/src/main/java/org/apache/spark/sql/arrow/ArrowColumnVector.java +++ b/lakesoul-spark/src/main/java/org/apache/spark/sql/arrow/ArrowColumnVector.java @@ -21,541 +21,561 @@ @DeveloperApi public class ArrowColumnVector extends ColumnVector { - ArrowVectorAccessor accessor; - ArrowColumnVector[] childColumns; - - public ValueVector getValueVector() { return accessor.vector; } - - @Override - public boolean hasNull() { - return accessor.getNullCount() > 0; - } - - @Override - public int numNulls() { - return accessor.getNullCount(); - } - - @Override - public void close() { - if (childColumns != null) { - for (int i = 0; i < childColumns.length; i++) { - childColumns[i].close(); - childColumns[i] = null; - } - childColumns = null; - } - accessor.close(); - } - - @Override - public boolean isNullAt(int rowId) { - return accessor.isNullAt(rowId); - } - - @Override - public boolean getBoolean(int rowId) { - return accessor.getBoolean(rowId); - } - - @Override - public byte getByte(int rowId) { - return accessor.getByte(rowId); - } - - @Override - public short getShort(int rowId) { - return accessor.getShort(rowId); - } - - @Override - public int getInt(int rowId) { - return accessor.getInt(rowId); - } - - @Override - public long getLong(int rowId) { - return accessor.getLong(rowId); - } - - @Override - public float getFloat(int rowId) { - return accessor.getFloat(rowId); - } - - @Override - public double getDouble(int rowId) { - return accessor.getDouble(rowId); - } - - @Override - public Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - return accessor.getDecimal(rowId, precision, scale); - } - - @Override - public UTF8String getUTF8String(int rowId) { - if (isNullAt(rowId)) return null; - return accessor.getUTF8String(rowId); - } - - @Override - public byte[] getBinary(int rowId) { - if (isNullAt(rowId)) return null; - return accessor.getBinary(rowId); - } - - @Override - public ColumnarArray getArray(int rowId) { - if (isNullAt(rowId)) return null; - return accessor.getArray(rowId); - } - - @Override - public ColumnarMap getMap(int rowId) { - if (isNullAt(rowId)) return null; - return accessor.getMap(rowId); - } - - @Override - public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } - - ArrowColumnVector(DataType type) { - super(type); - } - - public ArrowColumnVector(ValueVector vector) { - this(ArrowUtils.fromArrowField(vector.getField())); - initAccessor(vector); - } - - void initAccessor(ValueVector vector) { - if (vector instanceof BitVector) { - accessor = new BooleanAccessor((BitVector) vector); - } else if (vector instanceof TinyIntVector) { - accessor = new ByteAccessor((TinyIntVector) vector); - } else if (vector instanceof SmallIntVector) { - accessor = new ShortAccessor((SmallIntVector) vector); - } else if (vector instanceof IntVector) { - accessor = new IntAccessor((IntVector) vector); - } else if (vector instanceof BigIntVector) { - accessor = new LongAccessor((BigIntVector) vector); - } else if (vector instanceof Float4Vector) { - accessor = new FloatAccessor((Float4Vector) vector); - } else if (vector instanceof Float8Vector) { - accessor = new DoubleAccessor((Float8Vector) vector); - } else if (vector instanceof DecimalVector) { - accessor = new DecimalAccessor((DecimalVector) 
vector); - } else if (vector instanceof VarCharVector) { - accessor = new StringAccessor((VarCharVector) vector); - } else if (vector instanceof VarBinaryVector) { - accessor = new BinaryAccessor((VarBinaryVector) vector); - } else if (vector instanceof DateDayVector) { - accessor = new DateAccessor((DateDayVector) vector); - } else if (vector instanceof TimeStampMicroTZVector) { - accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); - } else if (vector instanceof TimeStampMicroVector) { - accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; - accessor = new MapAccessor(mapVector); - } else if (vector instanceof ListVector) { - ListVector listVector = (ListVector) vector; - accessor = new ArrayAccessor(listVector); - } else if (vector instanceof StructVector) { - StructVector structVector = (StructVector) vector; - accessor = new StructAccessor(structVector); + ArrowVectorAccessor accessor; + ArrowColumnVector[] childColumns; - childColumns = new ArrowColumnVector[structVector.size()]; - for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i)); - } - } else if (vector instanceof NullVector) { - accessor = new NullAccessor((NullVector) vector); - } else if (vector instanceof IntervalYearVector) { - accessor = new IntervalYearAccessor((IntervalYearVector) vector); - } else if (vector instanceof DurationVector) { - accessor = new DurationAccessor((DurationVector) vector); - } else { - throw new UnsupportedOperationException(); + public ValueVector getValueVector() { + return accessor.vector; } - } - abstract static class ArrowVectorAccessor { - - final ValueVector vector; - - ArrowVectorAccessor(ValueVector vector) { - this.vector = vector; + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; } - // TODO: should be final after removing ArrayAccessor workaround - boolean isNullAt(int rowId) { - return vector.isNull(rowId); + @Override + public int numNulls() { + return accessor.getNullCount(); } - final int getNullCount() { - return vector.getNullCount(); + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + childColumns[i] = null; + } + childColumns = null; + } + accessor.close(); } - final void close() { - vector.close(); + @Override + public boolean isNullAt(int rowId) { + return accessor.isNullAt(rowId); } - boolean getBoolean(int rowId) { - throw new UnsupportedOperationException(); + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); } - byte getByte(int rowId) { - throw new UnsupportedOperationException(); + @Override + public byte getByte(int rowId) { + return accessor.getByte(rowId); } - short getShort(int rowId) { - throw new UnsupportedOperationException(); + @Override + public short getShort(int rowId) { + return accessor.getShort(rowId); } - int getInt(int rowId) { - throw new UnsupportedOperationException(); + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); } - long getLong(int rowId) { - throw new UnsupportedOperationException(); + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); } - float getFloat(int rowId) { - throw new UnsupportedOperationException(); + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); } - double getDouble(int rowId) { - throw new 
UnsupportedOperationException(); + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); } - Decimal getDecimal(int rowId, int precision, int scale) { - throw new UnsupportedOperationException(); + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return accessor.getDecimal(rowId, precision, scale); } - UTF8String getUTF8String(int rowId) { - throw new UnsupportedOperationException(); + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getUTF8String(rowId); } - byte[] getBinary(int rowId) { - throw new UnsupportedOperationException(); + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getBinary(rowId); } - ColumnarArray getArray(int rowId) { - throw new UnsupportedOperationException(); + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getArray(rowId); } - ColumnarMap getMap(int rowId) { - throw new UnsupportedOperationException(); + @Override + public ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getMap(rowId); } - } - static class BooleanAccessor extends ArrowVectorAccessor { + @Override + public ArrowColumnVector getChild(int ordinal) { + return childColumns[ordinal]; + } + + ArrowColumnVector(DataType type) { + super(type); + } + + public ArrowColumnVector(ValueVector vector) { + this(ArrowUtils.sparkTypeFromArrowField(vector.getField())); + initAccessor(vector); + } + + void initAccessor(ValueVector vector) { + if (vector instanceof BitVector) { + accessor = new BooleanAccessor((BitVector) vector); + } else if (vector instanceof TinyIntVector) { + accessor = new ByteAccessor((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + accessor = new ShortAccessor((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + accessor = new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + accessor = new LongAccessor((BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + accessor = new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + accessor = new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof DecimalVector) { + accessor = new DecimalAccessor((DecimalVector) vector); + } else if (vector instanceof VarCharVector) { + accessor = new StringAccessor((VarCharVector) vector); + } else if (vector instanceof VarBinaryVector) { + accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMicroVector) { + accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector); + } else if (vector instanceof MapVector) { + MapVector mapVector = (MapVector) vector; + accessor = new MapAccessor(mapVector); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + accessor = new ArrayAccessor(listVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + accessor = new StructAccessor(structVector); + + childColumns = new ArrowColumnVector[structVector.size()]; + for (int i = 0; i < childColumns.length; ++i) { + 
childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i)); + } + } else if (vector instanceof NullVector) { + accessor = new NullAccessor((NullVector) vector); + } else if (vector instanceof IntervalYearVector) { + accessor = new IntervalYearAccessor((IntervalYearVector) vector); + } else if (vector instanceof DurationVector) { + accessor = new DurationAccessor((DurationVector) vector); + } else if (vector instanceof TimeSecVector) { + accessor = new TimeSecAccessor((TimeSecVector) vector); + } else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + + abstract static class ArrowVectorAccessor { + + final ValueVector vector; + + ArrowVectorAccessor(ValueVector vector) { + this.vector = vector; + } + + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { + return vector.isNull(rowId); + } + + final int getNullCount() { + return vector.getNullCount(); + } + + final void close() { + vector.close(); + } + + boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } - private final BitVector accessor; + byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } - BooleanAccessor(BitVector vector) { - super(vector); - this.accessor = vector; - } + ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } - @Override - final boolean getBoolean(int rowId) { - return accessor.get(rowId) == 1; + ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } } - } - static class ByteAccessor extends ArrowVectorAccessor { + static class BooleanAccessor extends ArrowVectorAccessor { - private final TinyIntVector accessor; + private final BitVector accessor; - ByteAccessor(TinyIntVector vector) { - super(vector); - this.accessor = vector; - } + BooleanAccessor(BitVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final byte getByte(int rowId) { - return accessor.get(rowId); + @Override + final boolean getBoolean(int rowId) { + return accessor.get(rowId) == 1; + } } - } - static class ShortAccessor extends ArrowVectorAccessor { + static class ByteAccessor extends ArrowVectorAccessor { - private final SmallIntVector accessor; + private final TinyIntVector accessor; - ShortAccessor(SmallIntVector vector) { - super(vector); - this.accessor = vector; - } + ByteAccessor(TinyIntVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final short getShort(int rowId) { - return accessor.get(rowId); + @Override + final byte getByte(int rowId) { + return accessor.get(rowId); + } } - } - static class IntAccessor extends ArrowVectorAccessor { + static class ShortAccessor extends ArrowVectorAccessor { - private final IntVector accessor; + private final SmallIntVector accessor; - IntAccessor(IntVector vector) 
{ - super(vector); - this.accessor = vector; - } + ShortAccessor(SmallIntVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final int getInt(int rowId) { - return accessor.get(rowId); + @Override + final short getShort(int rowId) { + return accessor.get(rowId); + } } - } - static class LongAccessor extends ArrowVectorAccessor { + static class IntAccessor extends ArrowVectorAccessor { - private final BigIntVector accessor; + private final IntVector accessor; - LongAccessor(BigIntVector vector) { - super(vector); - this.accessor = vector; - } + IntAccessor(IntVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final long getLong(int rowId) { - return accessor.get(rowId); + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } } - } - static class FloatAccessor extends ArrowVectorAccessor { + static class LongAccessor extends ArrowVectorAccessor { - private final Float4Vector accessor; + private final BigIntVector accessor; - FloatAccessor(Float4Vector vector) { - super(vector); - this.accessor = vector; - } + LongAccessor(BigIntVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final float getFloat(int rowId) { - return accessor.get(rowId); + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } } - } - static class DoubleAccessor extends ArrowVectorAccessor { + static class FloatAccessor extends ArrowVectorAccessor { - private final Float8Vector accessor; + private final Float4Vector accessor; - DoubleAccessor(Float8Vector vector) { - super(vector); - this.accessor = vector; - } + FloatAccessor(Float4Vector vector) { + super(vector); + this.accessor = vector; + } - @Override - final double getDouble(int rowId) { - return accessor.get(rowId); + @Override + final float getFloat(int rowId) { + return accessor.get(rowId); + } } - } - static class DecimalAccessor extends ArrowVectorAccessor { + static class DoubleAccessor extends ArrowVectorAccessor { - private final DecimalVector accessor; + private final Float8Vector accessor; - DecimalAccessor(DecimalVector vector) { - super(vector); - this.accessor = vector; - } + DoubleAccessor(Float8Vector vector) { + super(vector); + this.accessor = vector; + } - @Override - final Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - return Decimal.apply(accessor.getObject(rowId), precision, scale); + @Override + final double getDouble(int rowId) { + return accessor.get(rowId); + } } - } - static class StringAccessor extends ArrowVectorAccessor { + static class DecimalAccessor extends ArrowVectorAccessor { - private final VarCharVector accessor; - private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + private final DecimalVector accessor; - StringAccessor(VarCharVector vector) { - super(vector); - this.accessor = vector; - } + DecimalAccessor(DecimalVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final UTF8String getUTF8String(int rowId) { - accessor.get(rowId, stringResult); - if (stringResult.isSet == 0) { - return null; - } else { - return UTF8String.fromAddress(null, - stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); - } + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(accessor.getObject(rowId), precision, scale); + } } - } - static class BinaryAccessor extends ArrowVectorAccessor { + static 
class StringAccessor extends ArrowVectorAccessor { - private final VarBinaryVector accessor; + private final VarCharVector accessor; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); - BinaryAccessor(VarBinaryVector vector) { - super(vector); - this.accessor = vector; - } + StringAccessor(VarCharVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final byte[] getBinary(int rowId) { - return accessor.getObject(rowId); + @Override + final UTF8String getUTF8String(int rowId) { + accessor.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } } - } - static class DateAccessor extends ArrowVectorAccessor { + static class BinaryAccessor extends ArrowVectorAccessor { - private final DateDayVector accessor; + private final VarBinaryVector accessor; - DateAccessor(DateDayVector vector) { - super(vector); - this.accessor = vector; - } + BinaryAccessor(VarBinaryVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final int getInt(int rowId) { - return accessor.get(rowId); + @Override + final byte[] getBinary(int rowId) { + return accessor.getObject(rowId); + } } - } - static class TimestampAccessor extends ArrowVectorAccessor { + static class DateAccessor extends ArrowVectorAccessor { - private final TimeStampMicroTZVector accessor; + private final DateDayVector accessor; - TimestampAccessor(TimeStampMicroTZVector vector) { - super(vector); - this.accessor = vector; - } + DateAccessor(DateDayVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final long getLong(int rowId) { - return accessor.get(rowId); + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } } - } - static class TimestampNTZAccessor extends ArrowVectorAccessor { + static class TimestampAccessor extends ArrowVectorAccessor { - private final TimeStampMicroVector accessor; + private final TimeStampMicroTZVector accessor; - TimestampNTZAccessor(TimeStampMicroVector vector) { - super(vector); - this.accessor = vector; - } + TimestampAccessor(TimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final long getLong(int rowId) { - return accessor.get(rowId); + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } } - } - static class ArrayAccessor extends ArrowVectorAccessor { + static class TimestampNTZAccessor extends ArrowVectorAccessor { - private final ListVector accessor; - private final ArrowColumnVector arrayData; + private final TimeStampMicroVector accessor; - ArrayAccessor(ListVector vector) { - super(vector); - this.accessor = vector; - this.arrayData = new ArrowColumnVector(vector.getDataVector()); - } + TimestampNTZAccessor(TimeStampMicroVector vector) { + super(vector); + this.accessor = vector; + } - @Override - final boolean isNullAt(int rowId) { - // TODO: Workaround if vector has all non-null values, see ARROW-1948 - if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { - return false; - } else { - return super.isNullAt(rowId); - } + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } } - @Override - final ColumnarArray getArray(int rowId) { - int start = accessor.getElementStartIndex(rowId); - int end = accessor.getElementEndIndex(rowId); - return new ColumnarArray(arrayData, start, end - start); - } 
- } + static class ArrayAccessor extends ArrowVectorAccessor { + + private final ListVector accessor; + private final ArrowColumnVector arrayData; - /** - * Any call to "get" method will throw UnsupportedOperationException. - * - * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses - * getStruct() method defined in the parent class. Any call to "get" method in this class is a - * bug in the code. - * - */ - static class StructAccessor extends ArrowVectorAccessor { + ArrayAccessor(ListVector vector) { + super(vector); + this.accessor = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); + } - StructAccessor(StructVector vector) { - super(vector); + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + + @Override + final ColumnarArray getArray(int rowId) { + int start = accessor.getElementStartIndex(rowId); + int end = accessor.getElementEndIndex(rowId); + return new ColumnarArray(arrayData, start, end - start); + } } - } - static class MapAccessor extends ArrowVectorAccessor { - private final MapVector accessor; - private final ArrowColumnVector keys; - private final ArrowColumnVector values; + /** + * Any call to "get" method will throw UnsupportedOperationException. + *
+ * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. + */ + static class StructAccessor extends ArrowVectorAccessor { - MapAccessor(MapVector vector) { - super(vector); - this.accessor = vector; - StructVector entries = (StructVector) vector.getDataVector(); - this.keys = new ArrowColumnVector(entries.getChild(MapVector.KEY_NAME)); - this.values = new ArrowColumnVector(entries.getChild(MapVector.VALUE_NAME)); + StructAccessor(StructVector vector) { + super(vector); + } } - @Override - final ColumnarMap getMap(int rowId) { - int index = rowId * MapVector.OFFSET_WIDTH; - int offset = accessor.getOffsetBuffer().getInt(index); - int length = accessor.getInnerValueCountAt(rowId); - return new ColumnarMap(keys, values, offset, length); + static class MapAccessor extends ArrowVectorAccessor { + private final MapVector accessor; + private final ArrowColumnVector keys; + private final ArrowColumnVector values; + + MapAccessor(MapVector vector) { + super(vector); + this.accessor = vector; + StructVector entries = (StructVector) vector.getDataVector(); + this.keys = new ArrowColumnVector(entries.getChild(MapVector.KEY_NAME)); + this.values = new ArrowColumnVector(entries.getChild(MapVector.VALUE_NAME)); + } + + @Override + final ColumnarMap getMap(int rowId) { + int index = rowId * MapVector.OFFSET_WIDTH; + int offset = accessor.getOffsetBuffer().getInt(index); + int length = accessor.getInnerValueCountAt(rowId); + return new ColumnarMap(keys, values, offset, length); + } } - } - static class NullAccessor extends ArrowVectorAccessor { + static class NullAccessor extends ArrowVectorAccessor { - NullAccessor(NullVector vector) { - super(vector); + NullAccessor(NullVector vector) { + super(vector); + } } - } - static class IntervalYearAccessor extends ArrowVectorAccessor { + static class IntervalYearAccessor extends ArrowVectorAccessor { - private final IntervalYearVector accessor; + private final IntervalYearVector accessor; - IntervalYearAccessor(IntervalYearVector vector) { - super(vector); - this.accessor = vector; - } + IntervalYearAccessor(IntervalYearVector vector) { + super(vector); + this.accessor = vector; + } - @Override - int getInt(int rowId) { - return accessor.get(rowId); + @Override + int getInt(int rowId) { + return accessor.get(rowId); + } } - } - static class DurationAccessor extends ArrowVectorAccessor { + static class DurationAccessor extends ArrowVectorAccessor { - private final DurationVector accessor; + private final DurationVector accessor; - DurationAccessor(DurationVector vector) { - super(vector); - this.accessor = vector; + DurationAccessor(DurationVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return DurationVector.get(accessor.getDataBuffer(), rowId); + } } - @Override - final long getLong(int rowId) { - return DurationVector.get(accessor.getDataBuffer(), rowId); + static class TimeSecAccessor extends ArrowVectorAccessor { + + private final TimeSecVector accessor; + + TimeSecAccessor(TimeSecVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + return TimeSecVector.get(accessor.getDataBuffer(), rowId); + } } - } } diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/arrow/ArrowWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/arrow/ArrowWriter.scala index 
4a001a5f5..047f5ca8b 100644 --- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/arrow/ArrowWriter.scala +++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/arrow/ArrowWriter.scala @@ -31,7 +31,7 @@ object ArrowWriter { private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() - (ArrowUtils.fromArrowField(field), vector) match { + (ArrowUtils.sparkTypeFromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) @@ -44,7 +44,8 @@ object ArrowWriter { case (StringType, vector: VarCharVector) => new StringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: DateDayVector) => new DateWriter(vector) - case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) + case (TimestampType, vector: TimeStampVector) => new TimestampWriter(vector) + case (_, vector: TimeSecVector) => new TimeSecWriter(vector) case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) @@ -63,7 +64,7 @@ object ArrowWriter { case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) case (dt, _) => - throw QueryExecutionErrors.unsupportedDataTypeError(dt.catalogString) + throw QueryExecutionErrors.unsupportedDataTypeError(s"${dt.catalogString} for vector ${vector.getClass}") } } } @@ -100,10 +101,13 @@ private[arrow] abstract class ArrowFieldWriter { def valueVector: ValueVector def name: String = valueVector.getField().getName() - def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) + + def dataType: DataType = ArrowUtils.sparkTypeFromArrowField(valueVector.getField()) + def nullable: Boolean = valueVector.getField().isNullable() def setNull(): Unit + def setValue(input: SpecializedGetters, ordinal: Int): Unit private[arrow] var count: Int = 0 @@ -205,9 +209,9 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi } private[arrow] class DecimalWriter( - val valueVector: DecimalVector, - precision: Int, - scale: Int) extends ArrowFieldWriter { + val valueVector: DecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { override def setNull(): Unit = { valueVector.setNull(count) @@ -238,7 +242,7 @@ private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowF } private[arrow] class BinaryWriter( - val valueVector: VarBinaryVector) extends ArrowFieldWriter { + val valueVector: VarBinaryVector) extends ArrowFieldWriter { override def setNull(): Unit = { valueVector.setNull(count) @@ -262,7 +266,7 @@ private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFie } private[arrow] class TimestampWriter( - val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter { + val valueVector: TimeStampVector) extends ArrowFieldWriter { override def setNull(): Unit = { valueVector.setNull(count) @@ -274,7 +278,7 @@ private[arrow] class TimestampWriter( } private[arrow] class TimestampNTZWriter( - val valueVector: TimeStampMicroVector) extends ArrowFieldWriter { + val valueVector: TimeStampMicroVector) extends ArrowFieldWriter { override def setNull(): Unit = { 
valueVector.setNull(count) @@ -285,9 +289,20 @@ private[arrow] class TimestampNTZWriter( } } +private[arrow] class TimeSecWriter(val valueVector: TimeSecVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getInt(ordinal)) + } +} + private[arrow] class ArrayWriter( - val valueVector: ListVector, - val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { + val valueVector: ListVector, + val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { override def setNull(): Unit = { } @@ -315,8 +330,8 @@ private[arrow] class ArrayWriter( } private[arrow] class StructWriter( - val valueVector: StructVector, - children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { + val valueVector: StructVector, + children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { override def setNull(): Unit = { var i = 0 @@ -350,10 +365,10 @@ private[arrow] class StructWriter( } private[arrow] class MapWriter( - val valueVector: MapVector, - val structVector: StructVector, - val keyWriter: ArrowFieldWriter, - val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter { + val valueVector: MapVector, + val structVector: StructVector, + val keyWriter: ArrowFieldWriter, + val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter { override def setNull(): Unit = {} @@ -363,7 +378,7 @@ private[arrow] class MapWriter( val keys = map.keyArray() val values = map.valueArray() var i = 0 - while (i < map.numElements()) { + while (i < map.numElements()) { structVector.setIndexDefined(keyWriter.count) keyWriter.write(keys, i) valueWriter.write(values, i) diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/LakeSoulFileWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/LakeSoulFileWriter.scala index ae166e25e..c6d305a50 100644 --- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/LakeSoulFileWriter.scala +++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/LakeSoulFileWriter.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.lakesoul.rules.withPartitionAndOrdering import org.apache.spark.sql.lakesoul.sources.LakeSoulSQLConf import org.apache.spark.sql.vectorized.ArrowFakeRowAdaptor import org.apache.spark.util.{SerializableConfiguration, Utils} - -import com.dmetasoul.lakesoul.meta.DBConfig.LAKESOUL_NON_PARTITION_TABLE_PART_DESC +import com.dmetasoul.lakesoul.meta.DBConfig.{LAKESOUL_NON_PARTITION_TABLE_PART_DESC, LAKESOUL_RANGE_PARTITION_SPLITTER} import java.util.{Date, UUID} @@ -395,6 +394,7 @@ object LakeSoulFileWriter extends Logging { private var fileCounter: Int = _ private var recordsInFile: Long = _ private val partValue: Option[String] = options.get("partValue").filter(_ != LAKESOUL_NON_PARTITION_TABLE_PART_DESC) + .map(_.replace(LAKESOUL_RANGE_PARTITION_SPLITTER, "/")) private def newOutputWriter(): Unit = { recordsInFile = 0 diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CompactionCommand.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CompactionCommand.scala index bafcc3f49..276293d1a 100644 --- 
a/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CompactionCommand.scala +++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/lakesoul/commands/CompactionCommand.scala @@ -60,8 +60,8 @@ case class CompactionCommand(snapshotManagement: SnapshotManagement, val option = new CaseInsensitiveStringMap( Map("basePath" -> tc.tableInfo.table_path_s.get, "isCompaction" -> "true")) - val partitionNames = readPartitionInfo.map(p => { - p.range_value.split("=").head + val partitionNames = readPartitionInfo.head.range_value.split(',').map(p => { + p.split('=').head }) val scan = table.newScanBuilder(option).build() diff --git a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/CompactionSuite.scala b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/CompactionSuite.scala index 48a46183d..627836bb5 100644 --- a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/CompactionSuite.scala +++ b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/CompactionSuite.scala @@ -203,6 +203,42 @@ class CompactionSuite extends QueryTest }) } + test("simple compaction - multiple partition") { + withTempDir(file => { + val tableName = file.getCanonicalPath + + val df1 = Seq(("", "", 1, 1), (null, "", 1, 1), ("3", null, 1, 1), ("1", "2", 2, 2), ("1", "3", 3, 3)) + .toDF("range1", "range2", "hash", "value") + df1.write + .option("rangePartitions", "range1,range2") + .option("hashPartitions", "hash") + .option("hashBucketNum", "2") + .format("lakesoul") + .save(tableName) + + val sm = SnapshotManagement(SparkUtil.makeQualifiedTablePath(new Path(tableName)).toString) + var rangeGroup = SparkUtil.allDataInfo(sm.updateSnapshot()).groupBy(_.range_partitions) + assert(rangeGroup.forall(_._2.groupBy(_.file_bucket_id).forall(_._2.length == 1))) + + + val df2 = Seq(("", "", 1, 2), (null, "", 1, 2), ("3", null, 1, 2), ("1", "2", 2, 3), ("1", "3", 3, 4)) + .toDF("range1", "range2", "hash", "name") + + withSQLConf("spark.dmetasoul.lakesoul.schema.autoMerge.enabled" -> "true") { + LakeSoulTable.forPath(tableName).upsert(df2) + } + + rangeGroup = SparkUtil.allDataInfo(sm.updateSnapshot()).groupBy(_.range_partitions) + assert(!rangeGroup.forall(_._2.groupBy(_.file_bucket_id).forall(_._2.length == 1))) + + + LakeSoulTable.forPath(tableName).compaction(true) + rangeGroup = SparkUtil.allDataInfo(sm.updateSnapshot()).groupBy(_.range_partitions) + assert(rangeGroup.forall(_._2.groupBy(_.file_bucket_id).forall(_._2.length == 1))) + + }) + } + test("compaction with condition - simple") { withTempDir(file => { val tableName = file.getCanonicalPath diff --git a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala index 8fd6f0ab0..b75e375ec 100644 --- a/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala +++ b/native-io/lakesoul-io-java/src/main/scala/org/apache/spark/sql/arrow/ArrowUtils.scala @@ -4,20 +4,23 @@ package org.apache.spark.sql.arrow -import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.complex.MapVector -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.types.{DateUnit,
FloatingPointPrecision, IntervalUnit, TimeUnit} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{TimestampType, _} +import org.json4s.jackson.JsonMethods.mapper import java.util +import scala.collection.JavaConverters._ object ArrowUtils { val rootAllocator = new RootAllocator(Long.MaxValue) + private val writer = mapper.writerWithDefaultPrettyPrinter + private val reader = mapper.readerFor(classOf[Field]) // todo: support more types. @@ -60,8 +63,6 @@ object ArrowUtils { case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType - case ts: ArrowType.Timestamp - if ts.getTimezone == null => TimestampNTZType case ts: ArrowType.Timestamp => TimestampType case ArrowType.Null.INSTANCE => NullType case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() @@ -72,13 +73,17 @@ object ArrowUtils { } /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ - def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String, metadata: util.Map[String, String] = null): Field = { + def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String, sparkFieldMetadata: Metadata, metadata: util.Map[String, String] = null): Field = { + + if (sparkFieldMetadata.contains("__lakesoul_arrow_field__")) { + return reader.readValue(sparkFieldMetadata.getString("__lakesoul_arrow_field__")) + } dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null, metadata) new Field(name, fieldType, - Seq(toArrowField("element", elementType, containsNull, timeZoneId, metadata)).asJava) + Seq(toArrowField("element", elementType, containsNull, timeZoneId, sparkFieldMetadata, metadata)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null, metadata) new Field(name, fieldType, @@ -89,7 +94,7 @@ object ArrowUtils { map.put("spark_comment", comment.get) map } else null - toArrowField(field.name, field.dataType, field.nullable, timeZoneId, child_metadata) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, sparkFieldMetadata, child_metadata) }.toSeq.asJava) case MapType(keyType, valueType, valueContainsNull) => val mapType = new FieldType(nullable, new ArrowType.Map(false), null, metadata) @@ -100,27 +105,29 @@ object ArrowUtils { .add(MapVector.KEY_NAME, keyType, nullable = false) .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), nullable = false, - timeZoneId)).asJava) + timeZoneId, + sparkFieldMetadata + )).asJava) case dataType => val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null, metadata) new Field(name, fieldType, Seq.empty[Field].asJava) } } - def fromArrowField(field: Field): DataType = { + def sparkTypeFromArrowField(field: Field): DataType = { field.getType match { case _: ArrowType.Map => val elementField = field.getChildren.get(0) - val keyType = fromArrowField(elementField.getChildren.get(0)) - val valueType = fromArrowField(elementField.getChildren.get(1)) + val keyType = sparkTypeFromArrowField(elementField.getChildren.get(0)) + val valueType = sparkTypeFromArrowField(elementField.getChildren.get(1)) MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) case 
ArrowType.List.INSTANCE => val elementField = field.getChildren().get(0) - val elementType = fromArrowField(elementField) + val elementType = sparkTypeFromArrowField(elementField) ArrayType(elementType, containsNull = elementField.isNullable) case ArrowType.Struct.INSTANCE => val fields = field.getChildren().asScala.map { child => - val dt = fromArrowField(child) + val dt = sparkTypeFromArrowField(child) val comment = child.getMetadata.get("spark_comment") if (comment == null) StructField(child.getName, dt, child.isNullable) @@ -132,6 +139,30 @@ object ArrowUtils { } } + def fromArrowField(field: Field): StructField = { + val dt = sparkTypeFromArrowField(field) + val metadata = field.getMetadata + val comment = metadata.get("spark_comment") + val sparkField = + if (comment == null) + StructField(field.getName, dt, field.isNullable) + else + StructField(field.getName, dt, field.isNullable).withComment(comment) + val newMetadata = new MetadataBuilder() + newMetadata.withMetadata(sparkField.metadata) + metadata.forEach((key, value) => if (key != "spark_comment") { + newMetadata.putString(key, value) + }) + field.getType match { + case ti: ArrowType.Time if ti.getBitWidth == 32 => + newMetadata.putString("__lakesoul_arrow_field__", writer.writeValueAsString(field)) + case ts: ArrowType.Timestamp if ts.getTimezone == null => + newMetadata.putString("__lakesoul_arrow_field__", writer.writeValueAsString(field)) + case _ => + } + sparkField.copy(metadata = newMetadata.build()) + } + /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ def toArrowSchema(schema: StructType, timeZoneId: String = "UTC"): Schema = { new Schema(schema.map { field => @@ -141,19 +172,12 @@ object ArrowUtils { map.put("spark_comment", comment.get) map } else null - toArrowField(field.name, field.dataType, field.nullable, timeZoneId, metadata) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, field.metadata, metadata) }.asJava) } def fromArrowSchema(schema: Schema): StructType = { - StructType(schema.getFields.asScala.map { field => - val dt = fromArrowField(field) - val comment = field.getMetadata.get("spark_comment") - if (comment == null) - StructField(field.getName, dt, field.isNullable) - else - StructField(field.getName, dt, field.isNullable).withComment(comment) - }.toSeq) + StructType(schema.getFields.asScala.map(fromArrowField)) } /** Return Map with conf settings to be used in ArrowPythonRunner */
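
Note on the `__lakesoul_arrow_field__` round trip introduced above: the patch removes the `TimestampNTZType` case from the Arrow-to-Spark type mapping, so a timezone-less Arrow timestamp now surfaces in Spark as plain `TimestampType`. The new `fromArrowField` compensates by stashing the original Arrow field as JSON under the `__lakesoul_arrow_field__` metadata key (it does the same for 32-bit TIME fields), and `toArrowField` short-circuits on that key to rebuild the exact field. A minimal sketch of that round trip, assuming the patched `ArrowUtils` is on the classpath; the object name `ArrowFieldRoundTrip` and the field name `event_time` are illustrative, not part of the patch:

```scala
import java.util.Collections

import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}
import org.apache.spark.sql.arrow.ArrowUtils
import org.apache.spark.sql.types.TimestampType

object ArrowFieldRoundTrip extends App {
  // A hypothetical timezone-less microsecond timestamp field.
  val ntzField = new Field(
    "event_time",
    FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)),
    Collections.emptyList[Field]())

  // fromArrowField maps it to TimestampType (the NTZ case is gone from the
  // type mapping) and records the original Arrow field as JSON so the
  // distinction from a zoned timestamp is not lost.
  val sparkField = ArrowUtils.fromArrowField(ntzField)
  assert(sparkField.dataType == TimestampType)
  assert(sparkField.metadata.contains("__lakesoul_arrow_field__"))

  // toArrowField sees the metadata key and deserializes the stored field,
  // so the missing timezone survives a Spark-schema round trip.
  val restored = ArrowUtils.toArrowField(
    sparkField.name, sparkField.dataType, sparkField.nullable,
    "UTC", sparkField.metadata)
  assert(restored.getType == ntzField.getType)
}
```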