Skip to content

Commit

Permalink
Pulsar SQL support for Decimal data type (apache#15153)
Browse files Browse the repository at this point in the history
(cherry picked from commit 6b004ed)
(cherry picked from commit 642159c)
  • Loading branch information
shibd authored and nicoloboschi committed May 9, 2022
1 parent ebb0421 commit 2482228
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 4 deletions.
Expand Up @@ -40,6 +40,8 @@
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.DateType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.MapType;
Expand All @@ -53,6 +55,7 @@
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarbinaryType;
import io.prestosql.spi.type.VarcharType;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -139,7 +142,7 @@ private boolean isSupportedType(Type type) {
}

private boolean isSupportedPrimitive(Type type) {
return type instanceof VarcharType || SUPPORTED_PRIMITIVE_TYPES.contains(type);
return type instanceof VarcharType || type instanceof DecimalType || SUPPORTED_PRIMITIVE_TYPES.contains(type);
}

public FieldValueProvider decodeField(GenericRecord avroRecord) {
Expand Down Expand Up @@ -205,6 +208,13 @@ public long getLong() {
return floatToIntBits((Float) value);
}

if (columnType instanceof DecimalType) {
ByteBuffer buffer = (ByteBuffer) value;
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
return new BigInteger(bytes).longValue();
}

throw new PrestoException(DECODER_CONVERSION_NOT_SUPPORTED,
format("cannot decode object of '%s' as '%s' for column '%s'",
value.getClass(), columnType, columnName));
Expand Down Expand Up @@ -234,6 +244,13 @@ private static Slice getSlice(Object value, Type type, String columnName) {
}
}

// The returned Slice size must be equals to 18 Byte
if (type instanceof DecimalType) {
ByteBuffer buffer = (ByteBuffer) value;
BigInteger bigInteger = new BigInteger(buffer.array());
return Decimals.encodeUnscaledValue(bigInteger);
}

throw new PrestoException(DECODER_CONVERSION_NOT_SUPPORTED,
format("cannot decode object of '%s' as '%s' for column '%s'",
value.getClass(), type, columnName));
Expand Down
Expand Up @@ -33,6 +33,7 @@
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.RealType;
Expand Down Expand Up @@ -128,7 +129,14 @@ private Type parseAvroPrestoType(String fieldname, Schema schema) {
+ "please check the schema or report the bug.", fieldname));
case FIXED:
case BYTES:
//TODO: support decimal logicalType
// When the precision <= 0, throw Exception.
// When the precision > 0 and <= 18, use ShortDecimalType. and mapping Long
// When the precision > 18 and <= 36, use LongDecimalType. and mapping Slice
// When the precision > 36, throw Exception.
if (logicalType instanceof LogicalTypes.Decimal) {
LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType;
return DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale());
}
return VarbinaryType.VARBINARY;
case INT:
if (logicalType == LogicalTypes.timeMillis()) {
Expand Down
Expand Up @@ -128,6 +128,12 @@ private Type parseJsonPrestoType(String fieldname, Schema schema) {
+ "please check the schema or report the bug.", fieldname));
case FIXED:
case BYTES:
// In the current implementation, since JsonSchema is generated by Avro,
// there may exist LogicalTypes.Decimal.
// Mapping decimalType with varcharType in JsonSchema.
if (logicalType instanceof LogicalTypes.Decimal) {
return createUnboundedVarcharType();
}
return VarbinaryType.VARBINARY;
case INT:
if (logicalType == LogicalTypes.timeMillis()) {
Expand Down
Expand Up @@ -25,6 +25,7 @@
import io.prestosql.spi.connector.ConnectorContext;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.testing.TestingConnectorContext;
import java.math.BigDecimal;
import org.apache.bookkeeper.mledger.AsyncCallbacks;
import org.apache.bookkeeper.mledger.Entry;
import org.apache.bookkeeper.mledger.ManagedLedgerConfig;
Expand Down Expand Up @@ -166,6 +167,8 @@ public enum TestEnum {
public int time;
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"int\", \"logicalType\": \"date\" }")
public int date;
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 4, \"scale\": 2 }")
public BigDecimal decimal;
public TestPulsarConnector.Bar bar;
public TestEnum field7;
}
Expand Down Expand Up @@ -253,6 +256,7 @@ public static class Bar {
fooFieldNames.add("date");
fooFieldNames.add("bar");
fooFieldNames.add("field7");
fooFieldNames.add("decimal");


ConnectorContext prestoConnectorContext = new TestingConnectorContext();
Expand Down Expand Up @@ -313,6 +317,7 @@ public static class Bar {
LocalDate epoch = LocalDate.ofEpochDay(0);
return Math.toIntExact(ChronoUnit.DAYS.between(epoch, localDate));
});
fooFunctions.put("decimal", integer -> BigDecimal.valueOf(1234, 2));
fooFunctions.put("bar.field1", integer -> integer % 3 == 0 ? null : integer + 1);
fooFunctions.put("bar.field2", integer -> integer % 2 == 0 ? null : String.valueOf(integer + 2));
fooFunctions.put("bar.field3", integer -> integer + 3.0f);
Expand All @@ -331,7 +336,6 @@ public static class Bar {
* @param schemaInfo
* @param handleKeyValueType
* @param includeInternalColumn
* @param dispatchingRowDecoderFactory
* @return
*/
protected static List<PulsarColumnHandle> getColumnColumnHandles(TopicName topicName, SchemaInfo schemaInfo,
Expand Down Expand Up @@ -393,6 +397,7 @@ private static List<Entry> getTopicEntries(String topicSchemaName) {
LocalDate localDate = LocalDate.now();
LocalDate epoch = LocalDate.ofEpochDay(0);
foo.date = Math.toIntExact(ChronoUnit.DAYS.between(epoch, localDate));
foo.decimal= BigDecimal.valueOf(count, 2);

MessageMetadata messageMetadata = new MessageMetadata()
.setProducerName("test-producer").setSequenceId(i)
Expand Down Expand Up @@ -609,6 +614,7 @@ public void run() {
foo.timestamp = (long) fooFunctions.get("timestamp").apply(count);
foo.time = (int) fooFunctions.get("time").apply(count);
foo.date = (int) fooFunctions.get("date").apply(count);
foo.decimal = (BigDecimal) fooFunctions.get("decimal").apply(count);
foo.bar = bar;
foo.field7 = (Foo.TestEnum) fooFunctions.get("field7").apply(count);

Expand Down
Expand Up @@ -22,7 +22,11 @@
import io.airlift.log.Logger;
import io.netty.buffer.ByteBuf;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import java.math.BigDecimal;
import lombok.Data;
import org.apache.bookkeeper.mledger.AsyncCallbacks;
import org.apache.bookkeeper.mledger.Entry;
Expand Down Expand Up @@ -142,6 +146,17 @@ public void testTopics() throws Exception {
}else if (fooColumnHandles.get(i).getName().equals("field7")) {
assertEquals(pulsarRecordCursor.getSlice(i).getBytes(), fooFunctions.get("field7").apply(count).toString().getBytes());
columnsSeen.add(fooColumnHandles.get(i).getName());
}else if (fooColumnHandles.get(i).getName().equals("decimal")) {
Type type = fooColumnHandles.get(i).getType();
// In JsonDecoder, decimal trans to varcharType
if (type instanceof VarcharType) {
assertEquals(new String(pulsarRecordCursor.getSlice(i).getBytes()),
fooFunctions.get("decimal").apply(count).toString());
} else {
DecimalType decimalType = (DecimalType) fooColumnHandles.get(i).getType();
assertEquals(BigDecimal.valueOf(pulsarRecordCursor.getLong(i), decimalType.getScale()), fooFunctions.get("decimal").apply(count));
}
columnsSeen.add(fooColumnHandles.get(i).getName());
} else {
if (PulsarInternalColumn.getInternalFieldsMap().containsKey(fooColumnHandles.get(i).getName())) {
columnsSeen.add(fooColumnHandles.get(i).getName());
Expand Down
Expand Up @@ -26,6 +26,7 @@
import io.prestosql.spi.connector.ConnectorContext;
import io.prestosql.spi.type.Type;
import io.prestosql.testing.TestingConnectorContext;
import java.math.BigDecimal;
import org.apache.pulsar.common.naming.NamespaceName;
import org.apache.pulsar.common.naming.TopicName;
import org.apache.pulsar.common.schema.SchemaInfo;
Expand Down Expand Up @@ -102,6 +103,10 @@ protected void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRo
decoderTestUtil.checkValue(decodedRow, handle, value);
}

protected void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle, BigDecimal value) {
decoderTestUtil.checkValue(decodedRow, handle, value);
}

protected Block getBlock(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle) {
FieldValueProvider provider = decodedRow.get(handle);
assertNotNull(provider);
Expand Down
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.pulsar.sql.presto.decoder;

import java.math.BigDecimal;
import lombok.Data;

import java.util.List;
Expand Down Expand Up @@ -45,6 +46,10 @@ public static enum TestEnum {
public int dateField;
public TestRow rowField;
public TestEnum enumField;
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 4, \"scale\": 2 }")
public BigDecimal decimalField;
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 30, \"scale\": 2 }")
public BigDecimal longDecimalField;

public List<String> arrayField;
public Map<String, Long> mapField;
Expand All @@ -62,7 +67,6 @@ public static class NestedRow {
public long longField;
}


public static class CompositeRow {
public String stringField;
public List<NestedRow> arrayField;
Expand Down
Expand Up @@ -23,11 +23,16 @@
import io.prestosql.decoder.FieldValueProvider;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Map;

import static io.prestosql.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
import static io.prestosql.testing.TestingConnectorSession.SESSION;
import static org.testng.Assert.*;

Expand Down Expand Up @@ -113,6 +118,21 @@ public void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow,
assertEquals(provider.getBoolean(), value);
}

public void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle, BigDecimal value) {
FieldValueProvider provider = decodedRow.get(handle);
DecimalType decimalType = (DecimalType) handle.getType();
BigDecimal actualDecimal;
if (decimalType.getFixedSize() == UNSCALED_DECIMAL_128_SLICE_LENGTH) {
Slice slice = provider.getSlice();
BigInteger bigInteger = Decimals.decodeUnscaledValue(slice);
actualDecimal = new BigDecimal(bigInteger, decimalType.getScale());
} else {
actualDecimal = BigDecimal.valueOf(provider.getLong(), decimalType.getScale());
}
assertNotNull(provider);
assertEquals(actualDecimal, value);
}

public void checkIsNull(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle) {
FieldValueProvider provider = decodedRow.get(handle);
assertNotNull(provider);
Expand Down
Expand Up @@ -25,11 +25,13 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignatureParameter;
import io.prestosql.spi.type.VarcharType;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -87,6 +89,8 @@ public void testPrimitiveType() {
message.longField = 222L;
message.timestampField = System.currentTimeMillis();
message.enumField = DecoderTestMessage.TestEnum.TEST_ENUM_1;
message.decimalField = BigDecimal.valueOf(2233, 2);
message.longDecimalField = new BigDecimal("1234567891234567891234567891.23");

LocalTime now = LocalTime.now(ZoneId.systemDefault());
message.timeField = now.toSecondOfDay() * 1000;
Expand Down Expand Up @@ -127,6 +131,13 @@ public void testPrimitiveType() {
"enumField", VARCHAR, false, false, "enumField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
checkValue(decodedRow, enumFieldColumnHandle, message.enumField.toString());

PulsarColumnHandle decimalFieldColumnHandle = new PulsarColumnHandle(getPulsarConnectorId().toString(),
"decimalField", DecimalType.createDecimalType(4, 2), false, false, "decimalField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
checkValue(decodedRow, decimalFieldColumnHandle, message.decimalField);

PulsarColumnHandle longDecimalFieldColumnHandle = new PulsarColumnHandle(getPulsarConnectorId().toString(),
"longDecimalField", DecimalType.createDecimalType(30, 2), false, false, "longDecimalField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
checkValue(decodedRow, longDecimalFieldColumnHandle, message.longDecimalField);
}

@Test
Expand Down

0 comments on commit 2482228

Please sign in to comment.