diff --git a/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsink.md b/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsink.md index 079d5df32..87f85e0b0 100644 --- a/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsink.md +++ b/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsink.md @@ -148,6 +148,7 @@ Please, refer to PostgreSQL data types documentation to figure out proper format | double precision | double | | | integer | int | | | numeric(precision, scale)/decimal(precision, scale) | decimal | | +| numeric(with 0 precision) | string | | | real | float | | | smallint | int | | | text | string | | diff --git a/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsource.md b/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsource.md index 3c3bd989e..e3175f79e 100644 --- a/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsource.md +++ b/cloudsql-postgresql-plugin/docs/CloudSQLPostgreSQL-batchsource.md @@ -172,6 +172,7 @@ Please, refer to PostgreSQL data types documentation to figure out proper format | double precision | double | | | integer | int | | | numeric(precision, scale)/decimal(precision, scale) | decimal | | +| numeric(with 0 precision) | string | | | real | float | | | smallint | int | | | smallserial | int | | diff --git a/postgresql-plugin/docs/Postgres-batchsink.md b/postgresql-plugin/docs/Postgres-batchsink.md index 9e1e8404f..b8a996463 100644 --- a/postgresql-plugin/docs/Postgres-batchsink.md +++ b/postgresql-plugin/docs/Postgres-batchsink.md @@ -79,6 +79,7 @@ Please, refer to PostgreSQL data types documentation to figure out proper format | double precision | double | | | integer | int | | | numeric(precision, scale)/decimal(precision, scale) | decimal | | +| numeric(with 0 precision) | string | | | real | float | | | smallint | int | | | text | string | | diff --git a/postgresql-plugin/docs/Postgres-batchsource.md b/postgresql-plugin/docs/Postgres-batchsource.md index 8bd018baf..af359022d 100644 --- a/postgresql-plugin/docs/Postgres-batchsource.md +++ b/postgresql-plugin/docs/Postgres-batchsource.md @@ -110,6 +110,7 @@ Please, refer to PostgreSQL data types documentation to figure out proper format | double precision | double | | | integer | int | | | numeric(precision, scale)/decimal(precision, scale) | decimal | | +| numeric(with 0 precision) | string | | | real | float | | | smallint | int | | | smallserial | int | | diff --git a/postgresql-plugin/pom.xml b/postgresql-plugin/pom.xml index 49f6893f8..f1abfe858 100644 --- a/postgresql-plugin/pom.xml +++ b/postgresql-plugin/pom.xml @@ -53,10 +53,6 @@ 42.2.20 test - - org.mockito - mockito-core - io.cdap.plugin database-commons @@ -72,9 +68,15 @@ io.cdap.cdap cdap-data-pipeline3_2.12 + + org.mockito + mockito-core + test + junit junit + test io.cdap.cdap diff --git a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresDBRecord.java b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresDBRecord.java index dde6bdcdd..629eefad5 100644 --- a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresDBRecord.java +++ b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresDBRecord.java @@ -24,10 +24,13 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Types; import java.util.List; /** @@ -49,24 +52,41 @@ public PostgresDBRecord() { @Override protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field, int columnIndex, int sqlType, int sqlPrecision, int sqlScale) throws SQLException { + ResultSetMetaData metadata = resultSet.getMetaData(); if (isUseSchema(resultSet.getMetaData(), columnIndex)) { setFieldAccordingToSchema(resultSet, recordBuilder, field, columnIndex); - } else { - setField(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); + return; } + int columnType = metadata.getColumnType(columnIndex); + if (columnType == Types.NUMERIC) { + Schema nonNullableSchema = field.getSchema().isNullable() ? + field.getSchema().getNonNullable() : field.getSchema(); + int precision = metadata.getPrecision(columnIndex); + if (precision == 0 && Schema.Type.STRING.equals(nonNullableSchema.getType())) { + // When output schema is set to String for precision less numbers + recordBuilder.set(field.getName(), resultSet.getString(columnIndex)); + return; + } + BigDecimal orgValue = resultSet.getBigDecimal(columnIndex); + if (Schema.LogicalType.DECIMAL.equals(nonNullableSchema.getLogicalType()) && orgValue != null) { + BigDecimal decimalValue = new BigDecimal(orgValue.toPlainString()) + .setScale(nonNullableSchema.getScale(), RoundingMode.HALF_EVEN); + recordBuilder.setDecimal(field.getName(), decimalValue); + return; + } + } + setField(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); } private static boolean isUseSchema(ResultSetMetaData metadata, int columnIndex) throws SQLException { - switch (metadata.getColumnTypeName(columnIndex)) { - case "bit": - case "timetz": - case "money": - return true; - default: - return PostgresSchemaReader.STRING_MAPPED_POSTGRES_TYPES.contains(metadata.getColumnType(columnIndex)); - } + String columnTypeName = metadata.getColumnTypeName(columnIndex); + // If the column Type Name is present in the String mapped PostgreSQL types then return true. + return (PostgresSchemaReader.STRING_MAPPED_POSTGRES_TYPES_NAMES.contains(columnTypeName) + || PostgresSchemaReader.STRING_MAPPED_POSTGRES_TYPES.contains(metadata.getColumnType(columnIndex))); + } + private Object createPGobject(String type, String value, ClassLoader classLoader) throws SQLException { try { Class pGObjectClass = classLoader.loadClass("org.postgresql.util.PGobject"); @@ -89,11 +109,17 @@ protected void writeToDB(PreparedStatement stmt, Schema.Field field, int fieldIn if (PostgresSchemaReader.STRING_MAPPED_POSTGRES_TYPES_NAMES.contains(columnType.getTypeName()) || PostgresSchemaReader.STRING_MAPPED_POSTGRES_TYPES.contains(columnType.getType())) { stmt.setObject(sqlIndex, createPGobject(columnType.getTypeName(), - record.get(field.getName()), - stmt.getClass().getClassLoader())); - } else { - super.writeToDB(stmt, field, fieldIndex); + record.get(field.getName()), + stmt.getClass().getClassLoader())); + return; + } + if (columnType.getType() == Types.NUMERIC && record.get(field.getName()) != null && + field.getSchema().getType() == Schema.Type.STRING) { + stmt.setBigDecimal(sqlIndex, new BigDecimal((String) record.get(field.getName()))); + return; } + + super.writeToDB(stmt, field, fieldIndex); } @Override diff --git a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresFieldsValidator.java b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresFieldsValidator.java index b3b8cac62..8b81a7937 100644 --- a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresFieldsValidator.java +++ b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresFieldsValidator.java @@ -21,6 +21,7 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Types; import java.util.Objects; /** @@ -45,6 +46,13 @@ public boolean isFieldCompatible(Schema.Field field, ResultSetMetaData metadata, return false; } } + // Since Numeric types without precision and scale are getting converted into CDAP String type at the Source + // plugin, hence making the String type compatible with the Numeric type at the Sink as well. + if (fieldType.equals(Schema.Type.STRING)) { + if (Types.NUMERIC == columnType) { + return true; + } + } return super.isFieldCompatible(field, metadata, index); } diff --git a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSchemaReader.java b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSchemaReader.java index 685f4ffc6..ca69057ae 100644 --- a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSchemaReader.java +++ b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSchemaReader.java @@ -20,6 +20,9 @@ import io.cdap.cdap.api.data.schema.Schema; import io.cdap.plugin.db.CommonSchemaReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; @@ -30,6 +33,8 @@ */ public class PostgresSchemaReader extends CommonSchemaReader { + private static final Logger LOG = LoggerFactory.getLogger(PostgresSchemaReader.class); + public static final Set STRING_MAPPED_POSTGRES_TYPES = ImmutableSet.of( Types.OTHER, Types.ARRAY, Types.SQLXML ); @@ -57,6 +62,17 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti if (STRING_MAPPED_POSTGRES_TYPES_NAMES.contains(typeName) || STRING_MAPPED_POSTGRES_TYPES.contains(columnType)) { return Schema.of(Schema.Type.STRING); } + // If it is a numeric type without precision then use the Schema of String to avoid any precision loss + if (Types.NUMERIC == columnType) { + int precision = metadata.getPrecision(index); + if (precision == 0) { + LOG.warn(String.format("Field '%s' is a %s type without precision and scale, " + + "converting into STRING type to avoid any precision loss.", + metadata.getColumnName(index), + metadata.getColumnTypeName(index))); + return Schema.of(Schema.Type.STRING); + } + } return super.getSchema(metadata, index); } diff --git a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSource.java b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSource.java index 3be7cdb30..73e738330 100644 --- a/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSource.java +++ b/postgresql-plugin/src/main/java/io/cdap/plugin/postgres/PostgresSource.java @@ -23,6 +23,7 @@ import io.cdap.cdap.api.annotation.MetadataProperty; import io.cdap.cdap.api.annotation.Name; import io.cdap.cdap.api.annotation.Plugin; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.cdap.etl.api.FailureCollector; import io.cdap.cdap.etl.api.batch.BatchSource; import io.cdap.cdap.etl.api.batch.BatchSourceContext; @@ -137,5 +138,22 @@ public void validate(FailureCollector collector) { ConfigUtil.validateConnection(this, useConnection, connection, collector); super.validate(collector); } + + @Override + protected void validateField(FailureCollector collector, Schema.Field field, Schema actualFieldSchema, + Schema expectedFieldSchema) { + + // This change is needed to make sure that the pipeline upgrade continues to work post upgrade. + // Since the older handling of the precision less used to convert to the decimal type, + // and the new version would try to convert to the String type. In that case the output schema would + // contain Decimal(38, 0) (or something similar), and the code internally would try to identify + // the schema of the field(without precision and scale) as String. + if (Schema.LogicalType.DECIMAL.equals(expectedFieldSchema.getLogicalType()) && + actualFieldSchema.getType().equals(Schema.Type.STRING)) { + return; + } + super.validateField(collector, field, actualFieldSchema, expectedFieldSchema); + } } } + diff --git a/postgresql-plugin/src/test/java/io/cdap/plugin/postgres/PostgresDBRecordUnitTest.java b/postgresql-plugin/src/test/java/io/cdap/plugin/postgres/PostgresDBRecordUnitTest.java new file mode 100644 index 000000000..53a8795b3 --- /dev/null +++ b/postgresql-plugin/src/test/java/io/cdap/plugin/postgres/PostgresDBRecordUnitTest.java @@ -0,0 +1,81 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.cdap.plugin.postgres; + +import io.cdap.cdap.api.data.format.StructuredRecord; +import io.cdap.cdap.api.data.schema.Schema; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.Types; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + + +@RunWith(MockitoJUnitRunner.class) + public class PostgresDBRecordUnitTest { + + private static final int DEFAULT_PRECISION = 38; + + /** + * Validate the precision less Numbers handling against following use cases. + * 1. Ensure that the numeric type with [p,s] set as [38,4] detect as BigDecimal(38,4) in cdap. + * 2. Ensure that the numeric type without [p,s] detect as String type in cdap. + * @throws Exception + */ + @Test + public void validatePrecisionLessDecimalParsing() throws Exception { + Schema.Field field1 = Schema.Field.of("ID1", Schema.decimalOf(DEFAULT_PRECISION, 4)); + Schema.Field field2 = Schema.Field.of("ID2", Schema.of(Schema.Type.STRING)); + + Schema schema = Schema.recordOf( + "dbRecord", + field1, + field2 + ); + + ResultSetMetaData resultSetMetaData = Mockito.mock(ResultSetMetaData.class); + when(resultSetMetaData.getColumnType(eq(1))).thenReturn(Types.NUMERIC); + when(resultSetMetaData.getPrecision(eq(1))).thenReturn(DEFAULT_PRECISION); + when(resultSetMetaData.getColumnType(eq(2))).thenReturn(Types.NUMERIC); + when(resultSetMetaData.getPrecision(eq(2))).thenReturn(0); + + ResultSet resultSet = Mockito.mock(ResultSet.class); + + when(resultSet.getMetaData()).thenReturn(resultSetMetaData); + when(resultSet.getBigDecimal(eq(1))).thenReturn(BigDecimal.valueOf(123.4568)); + when(resultSet.getString(eq(2))).thenReturn("123.4568"); + + StructuredRecord.Builder builder = StructuredRecord.builder(schema); + PostgresDBRecord dbRecord = new PostgresDBRecord(null, null); + dbRecord.handleField(resultSet, builder, field1, 1, Types.NUMERIC, DEFAULT_PRECISION, 4); + dbRecord.handleField(resultSet, builder, field2, 2, Types.NUMERIC, 0, -127); + + StructuredRecord record = builder.build(); + Assert.assertTrue(record.getDecimal("ID1") instanceof BigDecimal); + Assert.assertEquals(record.getDecimal("ID1"), BigDecimal.valueOf(123.4568)); + Assert.assertTrue(record.get("ID2") instanceof String); + Assert.assertEquals(record.get("ID2"), "123.4568"); + } +}