Skip to content

Commit

Permalink
feat: Add better roundtrip support (#142)
Browse files: browse the repository at this point in the history
Follow-up for #141

See also apache/arrow#38891
  • Loading branch information
candiduslynx committed Nov 27, 2023
1 parent 8c09cbb commit d59f91a
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 106 deletions.
167 changes: 107 additions & 60 deletions lib/src/main/java/io/cloudquery/helper/ArrowHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,30 @@
import io.cloudquery.schema.Resource;
import io.cloudquery.schema.Table;
import io.cloudquery.schema.Table.TableBuilder;
import io.cloudquery.types.JSONType;
import io.cloudquery.types.JSONType.JSONVector;
import io.cloudquery.types.UUIDType;
import io.cloudquery.types.UUIDType.UUIDVector;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.time.Duration;
import java.util.*;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.LargeVarBinaryVector;
import org.apache.arrow.vector.LargeVarCharVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.joou.UByte;
import org.joou.UInteger;
import org.joou.ULong;
import org.joou.UShort;

public class ArrowHelper {
public static final String CQ_EXTENSION_INCREMENTAL = "cq:extension:incremental";
Expand All @@ -72,6 +57,32 @@ private static void setVectorData(FieldVector vector, Object data) {
bitVector.set(0, (boolean) data ? 1 : 0);
return;
}
if (vector instanceof DateDayVector dayDateVector) {
dayDateVector.set(0, (int) data);
return;
}
if (vector instanceof DateMilliVector dateMilliVector) {
dateMilliVector.set(0, (long) data);
return;
}
if (vector instanceof DurationVector durationVector) {
Duration duration = (Duration) data;
switch (durationVector.getUnit()) {
case SECOND -> {
durationVector.set(0, duration.toSeconds());
}
case MILLISECOND -> {
durationVector.set(0, duration.toMillis());
}
case MICROSECOND -> {
durationVector.set(0, duration.toNanos() / 1000);
}
case NANOSECOND -> {
durationVector.set(0, duration.toNanos());
}
}
return;
}
if (vector instanceof FixedSizeBinaryVector fixedSizeBinaryVector) {
fixedSizeBinaryVector.set(0, (byte[]) data);
return;
Expand Down Expand Up @@ -100,6 +111,22 @@ private static void setVectorData(FieldVector vector, Object data) {
smallIntVector.set(0, (short) data);
return;
}
if (vector instanceof TimeMicroVector timeMicroVector) {
timeMicroVector.set(0, (long) data);
return;
}
if (vector instanceof TimeMilliVector timeMilliVector) {
timeMilliVector.set(0, (int) data);
return;
}
if (vector instanceof TimeNanoVector timeNanoVector) {
timeNanoVector.set(0, (long) data);
return;
}
if (vector instanceof TimeSecVector timeSecVector) {
timeSecVector.set(0, (int) data);
return;
}
if (vector instanceof TimeStampVector timeStampVector) {
timeStampVector.set(0, (long) data);
return;
Expand All @@ -109,19 +136,19 @@ private static void setVectorData(FieldVector vector, Object data) {
return;
}
if (vector instanceof UInt1Vector uInt1Vector) {
uInt1Vector.set(0, (byte) data);
uInt1Vector.set(0, ((UByte) data).shortValue());
return;
}
if (vector instanceof UInt2Vector uInt2Vector) {
uInt2Vector.set(0, (short) data);
uInt2Vector.set(0, ((UShort) data).intValue());
return;
}
if (vector instanceof UInt4Vector uInt4Vector) {
uInt4Vector.set(0, (int) data);
uInt4Vector.set(0, ((UInteger) data).intValue());
return;
}
if (vector instanceof UInt8Vector uInt8Vector) {
uInt8Vector.set(0, (long) data);
uInt8Vector.set(0, ((ULong) data).longValue());
return;
}
if (vector instanceof VarBinaryVector varBinaryVector) {
Expand All @@ -132,16 +159,14 @@ private static void setVectorData(FieldVector vector, Object data) {
vectorCharVector.set(0, (Text) data);
return;
}
if (vector instanceof UUIDVector uuidVector) {
uuidVector.set(0, (java.util.UUID) data);
return;
}
// CloudQuery-specific
if (vector instanceof JSONVector jsonVector) {
jsonVector.setSafe(0, (byte[]) data);
return;
}
if (vector instanceof DateDayVector dayDateVector) {
dayDateVector.set(0, (int) data);
// CloudQuery-specific
if (vector instanceof UUIDVector uuidVector) {
uuidVector.set(0, (java.util.UUID) data);
return;
}

Expand Down Expand Up @@ -177,17 +202,7 @@ public static Schema toArrowSchema(Table table) {
List<Column> columns = table.getColumns();
Field[] fields = new Field[columns.size()];
for (int i = 0; i < columns.size(); i++) {
Column column = columns.get(i);
Map<String, String> metadata = new HashMap<>();
metadata.put(CQ_EXTENSION_UNIQUE, Boolean.toString(column.isUnique()));
metadata.put(CQ_EXTENSION_PRIMARY_KEY, Boolean.toString(column.isPrimaryKey()));
metadata.put(CQ_EXTENSION_INCREMENTAL, Boolean.toString(column.isIncrementalKey()));
Field field =
new Field(
column.getName(),
new FieldType(!column.isNotNull(), column.getType(), null, metadata),
null);
fields[i] = field;
fields[i] = getField(columns.get(i));
}
Map<String, String> metadata = new HashMap<>();
metadata.put(CQ_TABLE_NAME, table.getName());
Expand All @@ -204,23 +219,21 @@ public static Schema toArrowSchema(Table table) {
return new Schema(asList(fields), metadata);
}

/**
 * Converts a column into an Arrow field.
 *
 * <p>The column's unique, primary-key and incremental-key flags are written into the field
 * metadata as the strings {@code "true"} / {@code "false"}. The field is nullable exactly when the
 * column is not marked not-null.
 *
 * @param column the column to convert
 * @return a field carrying the column's name, type, nullability and flag metadata
 */
private static Field getField(Column column) {
  Map<String, String> flags = new HashMap<>();
  flags.put(CQ_EXTENSION_UNIQUE, Boolean.toString(column.isUnique()));
  flags.put(CQ_EXTENSION_PRIMARY_KEY, Boolean.toString(column.isPrimaryKey()));
  flags.put(CQ_EXTENSION_INCREMENTAL, Boolean.toString(column.isIncrementalKey()));

  boolean nullable = !column.isNotNull();
  // No dictionary encoding (third FieldType argument) and no child fields.
  FieldType arrowFieldType = new FieldType(nullable, column.getType(), null, flags);
  return new Field(column.getName(), arrowFieldType, null);
}

public static Table fromArrowSchema(Schema schema) {
List<Column> columns = new ArrayList<>();
for (Field field : schema.getFields()) {
boolean isUnique = Objects.equals(field.getMetadata().get(CQ_EXTENSION_UNIQUE), "true");
boolean isPrimaryKey =
Objects.equals(field.getMetadata().get(CQ_EXTENSION_PRIMARY_KEY), "true");
boolean isIncrementalKey =
Objects.equals(field.getMetadata().get(CQ_EXTENSION_INCREMENTAL), "true");

columns.add(
Column.builder()
.name(field.getName())
.unique(isUnique)
.primaryKey(isPrimaryKey)
.incrementalKey(isIncrementalKey)
.type(field.getType())
.build());
columns.add(getColumn(field));
}

Map<String, String> metaData = schema.getCustomMetadata();
Expand All @@ -244,6 +257,40 @@ public static Table fromArrowSchema(Schema schema) {
return tableBuilder.build();
}

/**
 * Converts an Arrow field back into a column.
 *
 * <p>The unique, primary-key and incremental-key flags are read from the field metadata; a flag is
 * set only when its metadata value is exactly {@code "true"}. If the metadata names one of the
 * CloudQuery extension types (JSON or UUID) and the field's type equals that extension's storage
 * type, the column gets the extension type instead of the raw storage type.
 *
 * <p>NOTE(review): {@code getField} derives the field's nullability from {@code
 * column.isNotNull()}, but nothing here reads the nullability back, so that flag does not appear
 * to survive a round trip — confirm whether it should be restored.
 *
 * @param field the Arrow field to convert
 * @return a column with the field's name, resolved type and flags
 */
private static Column getColumn(Field field) {
  Map<String, String> fieldMetadata = field.getMetadata();
  String extName = fieldMetadata.get(ArrowType.ExtensionType.EXTENSION_METADATA_KEY_NAME);
  String extMetadata = fieldMetadata.get(ArrowType.ExtensionType.EXTENSION_METADATA_KEY_METADATA);
  ArrowType storageType = field.getType();

  // Extension types are matched by hand (name, serialized metadata and storage type) because of
  // https://github.com/apache/arrow/issues/38891
  ArrowType columnType = storageType;
  if (JSONType.EXTENSION_NAME.equals(extName)
      && JSONType.INSTANCE.serialize().equals(extMetadata)
      && JSONType.INSTANCE.storageType().equals(storageType)) {
    columnType = JSONType.INSTANCE;
  } else if (UUIDType.EXTENSION_NAME.equals(extName)
      && UUIDType.INSTANCE.serialize().equals(extMetadata)
      && UUIDType.INSTANCE.storageType().equals(storageType)) {
    columnType = UUIDType.INSTANCE;
  }

  return Column.builder()
      .name(field.getName())
      .unique("true".equals(fieldMetadata.get(CQ_EXTENSION_UNIQUE)))
      .primaryKey("true".equals(fieldMetadata.get(CQ_EXTENSION_PRIMARY_KEY)))
      .incrementalKey("true".equals(fieldMetadata.get(CQ_EXTENSION_INCREMENTAL)))
      .type(columnType)
      .build();
}

public static ByteString encode(Resource resource) throws IOException {
try (BufferAllocator bufferAllocator = new RootAllocator()) {
Table table = resource.getTable();
Expand Down
7 changes: 7 additions & 0 deletions lib/src/main/java/io/cloudquery/scalar/DateMilli.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.cloudquery.scalar;

import java.time.LocalDateTime;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;

Expand Down Expand Up @@ -34,6 +35,12 @@ public void setValue(Object value) throws ValidationException {
return;
}

if (value instanceof LocalDateTime localDateTime) {
// we actually store only date
this.value = localDateTime.toLocalDate().toEpochDay();
return;
}

throw new ValidationException(
ValidationException.NO_CONVERSION_AVAILABLE, this.dataType(), value);
}
Expand Down
5 changes: 5 additions & 0 deletions lib/src/main/java/io/cloudquery/scalar/Number.java
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ protected void setValue(Object value) throws ValidationException {
return;
}

if (value instanceof Character character) {
this.value = UShort.valueOf(character);
return;
}

throw new ValidationException(
ValidationException.NO_CONVERSION_AVAILABLE, this.dataType(), value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private static ArrowType transformArrowType(String name, Class<?> type)
return Timestamp.dt;
}
case "java.util.UUID" -> {
return new UUIDType();
return UUIDType.INSTANCE;
}
default -> {
if (type.isArray()) {
Expand Down
1 change: 1 addition & 0 deletions lib/src/main/java/io/cloudquery/types/UUIDType.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.arrow.vector.types.pojo.FieldType;

public class UUIDType extends ExtensionType {
public static final UUIDType INSTANCE = new UUIDType();
public static final int BYTE_WIDTH = 16;
public static final String EXTENSION_NAME = "uuid";

Expand Down

0 comments on commit d59f91a

Please sign in to comment.