From ab9b8e65894d74527e7ae6c9f60ea791483e8df1 Mon Sep 17 00:00:00 2001 From: Rahul Sharma Date: Thu, 23 Feb 2023 17:48:20 +0530 Subject: [PATCH] Fixing the ResultSet look up logic to look against the column name instead of field index from Schema. Removed the buggy logic of index in Oracle plugin --- .../main/java/io/cdap/plugin/db/DBRecord.java | 3 +- .../plugin/db/batch/sink/AbstractDBSink.java | 39 ++++-- .../db/batch/sink/AbstractDBSinkTest.java | 127 +++++++++++++++++- .../plugin/oracle/OracleSourceDBRecord.java | 22 +-- 4 files changed, 163 insertions(+), 28 deletions(-) diff --git a/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java b/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java index 7592a5fdf..edaeb9d1c 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java @@ -104,7 +104,8 @@ public void readFields(ResultSet resultSet) throws SQLException { StructuredRecord.Builder recordBuilder = StructuredRecord.builder(schema); for (int i = 0; i < schema.getFields().size(); i++) { Schema.Field field = schema.getFields().get(i); - int columnIndex = i + 1; + // Find the field index in the resultSet having the same name + int columnIndex = resultSet.findColumn(field.getName()); int sqlType = metadata.getColumnType(columnIndex); int sqlPrecision = metadata.getPrecision(columnIndex); int sqlScale = metadata.getScale(columnIndex); diff --git a/database-commons/src/main/java/io/cdap/plugin/db/batch/sink/AbstractDBSink.java b/database-commons/src/main/java/io/cdap/plugin/db/batch/sink/AbstractDBSink.java index fadd78b78..5322cc093 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/batch/sink/AbstractDBSink.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/batch/sink/AbstractDBSink.java @@ -64,7 +64,9 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; 
+import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Properties; @@ -276,8 +278,7 @@ private void setResultSetMetadata() throws Exception { ResultSet rs = statement.executeQuery(String.format("SELECT %s FROM %s WHERE 1 = 0", dbColumns, fullyQualifiedTableName)) ) { - ResultSetMetaData resultSetMetadata = rs.getMetaData(); - columnTypes.addAll(getMatchedColumnTypeList(resultSetMetadata, columns)); + columnTypes.addAll(getMatchedColumnTypeList(rs, columns)); } } @@ -287,22 +288,36 @@ private void setResultSetMetadata() throws Exception { /** * Compare columns from schema with columns in table and returns list of matched columns in {@link ColumnType} format. * - * @param resultSetMetadata result set metadata from table. + * @param resultSet result set from table. * @param columns list of columns from schema. * @return list of matched columns. */ - static List getMatchedColumnTypeList(ResultSetMetaData resultSetMetadata, List columns) + static List getMatchedColumnTypeList(ResultSet resultSet, List columns) throws SQLException { List columnTypes = new ArrayList<>(columns.size()); - // JDBC driver column indices start with 1 - for (int i = 0; i < resultSetMetadata.getColumnCount(); i++) { - String name = resultSetMetadata.getColumnName(i + 1); - String columnTypeName = resultSetMetadata.getColumnTypeName(i + 1); - int type = resultSetMetadata.getColumnType(i + 1); + ResultSetMetaData resultSetMetadata = resultSet.getMetaData(); + Map resultSetColumnNames = new HashMap<>(resultSetMetadata.getColumnCount()); + + // Map each lower-cased ResultSet column name to its original name + // JDBC driver column indices start with index 1 + for (int i = 1; i <= resultSetMetadata.getColumnCount(); i++) { + resultSetColumnNames.put(resultSetMetadata.getColumnName(i).toLowerCase(), resultSetMetadata.getColumnName(i)); + } + + // Iterate over all the columns present in the output schema and + // check if the resultSet contains a column with the 
same name. + for (int i = 0; i < columns.size(); i++) { String schemaColumnName = columns.get(i); - Preconditions.checkArgument(schemaColumnName.toLowerCase().equals(name.toLowerCase()), - "Missing column '%s' in SQL table", schemaColumnName); - columnTypes.add(new ColumnType(schemaColumnName, columnTypeName, type)); + String schemaColName = schemaColumnName.toLowerCase(); + Preconditions.checkArgument(resultSetColumnNames.keySet().contains(schemaColName), + "Missing column '%s' in SQL table", schemaColumnName); + + // Find the column in the resultSet, as the index in the schema might not match with the resultSet. + int columnIndex = resultSet.findColumn(resultSetColumnNames.get(schemaColName)); + String name = resultSetMetadata.getColumnName(columnIndex); + String columnTypeName = resultSetMetadata.getColumnTypeName(columnIndex); + int type = resultSetMetadata.getColumnType(columnIndex); + columnTypes.add(new ColumnType(name, columnTypeName, type)); } return columnTypes; } diff --git a/database-commons/src/test/java/io/cdap/plugin/db/batch/sink/AbstractDBSinkTest.java b/database-commons/src/test/java/io/cdap/plugin/db/batch/sink/AbstractDBSinkTest.java index 1b78fe4e3..a3be3f777 100644 --- a/database-commons/src/test/java/io/cdap/plugin/db/batch/sink/AbstractDBSinkTest.java +++ b/database-commons/src/test/java/io/cdap/plugin/db/batch/sink/AbstractDBSinkTest.java @@ -17,6 +17,7 @@ package io.cdap.plugin.db.batch.sink; import com.google.common.collect.ImmutableList; +import com.mockrunner.mock.jdbc.MockResultSet; import com.mockrunner.mock.jdbc.MockResultSetMetaData; import io.cdap.plugin.db.ColumnType; import org.junit.Assert; @@ -24,7 +25,9 @@ import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; /** * Test class for abstract sink. 
@@ -32,7 +35,7 @@ public class AbstractDBSinkTest { @Test - public void testGetMatchedColumnTypeList() throws SQLException { + public void testGetMatchedColumnTypeList() throws Exception { List columns = ImmutableList.of( "ID", "NAME", @@ -45,20 +48,25 @@ public void testGetMatchedColumnTypeList() throws SQLException { resultSetMetaData.setColumnCount(columns.size()); for (int i = 0; i < columns.size(); i++) { - String name = columns.get(i); resultSetMetaData.setColumnName(i + 1, columns.get(i)); resultSetMetaData.setColumnTypeName(i + 1, "STRING"); resultSetMetaData.setColumnType(i + 1, i); - expectedColumns.add(new ColumnType(name, "STRING", i)); + expectedColumns.add(new ColumnType(columns.get(i), "STRING", i)); } - List result = AbstractDBSink.getMatchedColumnTypeList(resultSetMetaData, columns); + MockResultSet resultSet = new MockResultSet("data"); + Set columnNamesSet = new HashSet(); + columnNamesSet.addAll(columns); + resultSet.addColumns(columnNamesSet); + resultSet.setResultSetMetaData(resultSetMetaData); + + List result = AbstractDBSink.getMatchedColumnTypeList(resultSet, columns); Assert.assertEquals(expectedColumns, result); } @Test - public void testGetMismatchColumnTypeList() throws SQLException { + public void testGetMismatchColumnTypeList() throws Exception { List wrongColumns = ImmutableList.of( "MY_ID", "NAME", @@ -80,12 +88,119 @@ public void testGetMismatchColumnTypeList() throws SQLException { resultSetMetaData.setColumnType(i + 1, i); } + MockResultSet resultSet = new MockResultSet("data"); + Set columnNamesSet = new HashSet(); + columnNamesSet.addAll(columns); + resultSet.addColumns(columnNamesSet); + resultSet.setResultSetMetaData(resultSetMetaData); + try { - AbstractDBSink.getMatchedColumnTypeList(resultSetMetaData, wrongColumns); + AbstractDBSink.getMatchedColumnTypeList(resultSet, wrongColumns); Assert.fail(String.format("Expected to throw %s", IllegalArgumentException.class.getName())); } catch (IllegalArgumentException e) { String 
errorMessage = "Missing column 'MY_ID' in SQL table"; Assert.assertEquals(errorMessage, e.getMessage()); } } + + @Test + public void testDifferentOrderOfFieldsInResultSet() throws Exception { + List diffOrdCol = ImmutableList.of( + "Name", + "SCORE", + "ID" + ); + + List columns = ImmutableList.of( + "ID", + "NAME", + "SCORE" + ); + + List typeName = ImmutableList.of( + "INT", + "STRING", + "DOUBLE" + ); + + List typeValue = ImmutableList.of( + 1, + 2, + 3 + ); + + List expectedColumns = new ArrayList<>(); + MockResultSetMetaData resultSetMetaData = new MockResultSetMetaData(); + resultSetMetaData.setColumnCount(columns.size()); + + for (int i = 0; i < columns.size(); i++) { + resultSetMetaData.setColumnName(i + 1, columns.get(i)); + resultSetMetaData.setColumnTypeName(i + 1, typeName.get(i)); + resultSetMetaData.setColumnType(i + 1, typeValue.get(i)); + expectedColumns.add(new ColumnType(columns.get(i), typeName.get(i), typeValue.get(i))); + } + + MockResultSet resultSet = new MockResultSet("data"); + Set columnNamesSet = new HashSet(); + columnNamesSet.addAll(columns); + resultSet.addColumns(columnNamesSet); + resultSet.setResultSetMetaData(resultSetMetaData); + + List actualColumns = AbstractDBSink.getMatchedColumnTypeList(resultSet, diffOrdCol); + + // Assert that all expected fields are present in the actual fields + for (ColumnType exColType : expectedColumns) { + Assert.assertTrue(actualColumns.contains(exColType)); + } + } + + @Test + public void testSubsetColumnsInResultSet() throws Exception { + List subsetCol = ImmutableList.of( + "SCORE", + "ID" + ); + + List columns = ImmutableList.of( + "ID", + "NAME", + "SCORE" + ); + + List typeName = ImmutableList.of( + "INT", + "STRING", + "DOUBLE" + ); + + List typeValue = ImmutableList.of( + 1, + 2, + 3 + ); + + List expectedColumns = new ArrayList<>(); + MockResultSetMetaData resultSetMetaData = new MockResultSetMetaData(); + resultSetMetaData.setColumnCount(columns.size()); + + for (int i = 0; i < 
columns.size(); i++) { + resultSetMetaData.setColumnName(i + 1, columns.get(i)); + resultSetMetaData.setColumnTypeName(i + 1, typeName.get(i)); + resultSetMetaData.setColumnType(i + 1, typeValue.get(i)); + expectedColumns.add(new ColumnType(columns.get(i), typeName.get(i), typeValue.get(i))); + } + + MockResultSet resultSet = new MockResultSet("data"); + Set columnNamesSet = new HashSet(); + columnNamesSet.addAll(columns); + resultSet.addColumns(columnNamesSet); + resultSet.setResultSetMetaData(resultSetMetaData); + + List actualColumns = AbstractDBSink.getMatchedColumnTypeList(resultSet, subsetCol); + + // Assert that all actual fields are present in the expected fields + for (ColumnType acColType : actualColumns) { + Assert.assertTrue(expectedColumns.contains(acColType)); + } + } } diff --git a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java index 942f879ac..89807e21e 100644 --- a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java +++ b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java @@ -80,16 +80,22 @@ public void readFields(ResultSet resultSet) throws SQLException { // All LONG or LONG RAW columns have to be retrieved from the ResultSet prior to all the other columns. 
// Otherwise, we will face java.sql.SQLException: Stream has already been closed - for (int i = 0; i < schema.getFields().size(); i++) { - if (isLongOrLongRaw(metadata.getColumnType(i + 1))) { - readField(i, metadata, resultSet, schema, recordBuilder); + for (Schema.Field field : schema.getFields()) { + // Index of a field in the schema may not be same in the ResultSet, + // hence find the field by name in the given resultSet + int columnIndex = resultSet.findColumn(field.getName()); + if (isLongOrLongRaw(metadata.getColumnType(columnIndex))) { + readField(columnIndex, metadata, resultSet, field, recordBuilder); } } // Read fields of other types - for (int i = 0; i < schema.getFields().size(); i++) { - if (!isLongOrLongRaw(metadata.getColumnType(i + 1))) { - readField(i, metadata, resultSet, schema, recordBuilder); + for (Schema.Field field : schema.getFields()) { + // Index of a field in the schema may not be same in the ResultSet, + // hence find the field by name in the given resultSet + int columnIndex = resultSet.findColumn(field.getName()); + if (!isLongOrLongRaw(metadata.getColumnType(columnIndex))) { + readField(columnIndex, metadata, resultSet, field, recordBuilder); } } @@ -242,10 +248,8 @@ private boolean isLongOrLongRaw(int columnType) { return columnType == OracleSourceSchemaReader.LONG || columnType == OracleSourceSchemaReader.LONG_RAW; } - private void readField(int index, ResultSetMetaData metadata, ResultSet resultSet, Schema schema, + private void readField(int columnIndex, ResultSetMetaData metadata, ResultSet resultSet, Schema.Field field, StructuredRecord.Builder recordBuilder) throws SQLException { - Schema.Field field = schema.getFields().get(index); - int columnIndex = index + 1; int sqlType = metadata.getColumnType(columnIndex); int sqlPrecision = metadata.getPrecision(columnIndex); int sqlScale = metadata.getScale(columnIndex);