From f2b9a139fe31c52ca88104fa6c336cda02cff771 Mon Sep 17 00:00:00 2001 From: Aydar Zaynutdinov Date: Thu, 2 Sep 2021 15:43:41 +0300 Subject: [PATCH] [BEAM-11873] Add support for writes with returning values in JdbcIO --- .../sdk/io/common/DatabaseTestHelper.java | 26 ++ .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 275 +++++++++++++++++- .../beam/sdk/io/jdbc/JdbcWriteResult.java | 28 ++ .../org/apache/beam/sdk/io/jdbc/JdbcIOIT.java | 56 ++++ .../apache/beam/sdk/io/jdbc/JdbcIOTest.java | 67 +++-- .../beam/sdk/io/jdbc/JdbcTestHelper.java | 57 ++++ 6 files changed, 477 insertions(+), 32 deletions(-) create mode 100644 sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteResult.java diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java index 56b72303fc6b6..46d9aa97017b7 100644 --- a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java +++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java @@ -17,14 +17,18 @@ */ package org.apache.beam.sdk.io.common; +import static org.junit.Assert.assertEquals; + import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.text.SimpleDateFormat; +import java.util.ArrayList; import java.util.Date; import java.util.Optional; import javax.sql.DataSource; +import org.apache.beam.sdk.values.KV; import org.postgresql.ds.PGSimpleDataSource; /** This class contains helper methods to ease database usage in tests. */ @@ -104,4 +108,26 @@ public static void createTableWithStatement(DataSource dataSource, String stmt) } } } + + public static ArrayList> getTestDataToWrite(long rowsToAdd) { + ArrayList> data = new ArrayList<>(); + for (int i = 0; i < rowsToAdd; i++) { + KV kv = KV.of(i, "Test"); + data.add(kv); + } + return data; + } + + public static void assertRowCount(DataSource dataSource, String tableName, int expectedRowCount) + throws SQLException { + try (Connection connection = dataSource.getConnection()) { + try (Statement statement = connection.createStatement()) { + try (ResultSet resultSet = statement.executeQuery("select count(*) from " + tableName)) { + resultSet.next(); + int count = resultSet.getInt(1); + assertEquals(expectedRowCount, count); + } + } + } + } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index b3c10df7ecf61..2cab8b41f69fd 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -342,7 +342,7 @@ public static ReadWithPartitions readWithPartitions() { * @param Type of the data to be written. */ public static Write write() { - return new Write(); + return new Write<>(); } public static WriteVoid writeVoid() { @@ -1283,43 +1283,43 @@ public static class Write extends PTransform, PDone> { /** See {@link WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */ public Write withDataSourceConfiguration(DataSourceConfiguration config) { - return new Write(inner.withDataSourceConfiguration(config)); + return new Write<>(inner.withDataSourceConfiguration(config)); } /** See {@link WriteVoid#withDataSourceProviderFn(SerializableFunction)}. */ public Write withDataSourceProviderFn( SerializableFunction dataSourceProviderFn) { - return new Write(inner.withDataSourceProviderFn(dataSourceProviderFn)); + return new Write<>(inner.withDataSourceProviderFn(dataSourceProviderFn)); } /** See {@link WriteVoid#withStatement(String)}. */ public Write withStatement(String statement) { - return new Write(inner.withStatement(statement)); + return new Write<>(inner.withStatement(statement)); } /** See {@link WriteVoid#withPreparedStatementSetter(PreparedStatementSetter)}. */ public Write withPreparedStatementSetter(PreparedStatementSetter setter) { - return new Write(inner.withPreparedStatementSetter(setter)); + return new Write<>(inner.withPreparedStatementSetter(setter)); } /** See {@link WriteVoid#withBatchSize(long)}. */ public Write withBatchSize(long batchSize) { - return new Write(inner.withBatchSize(batchSize)); + return new Write<>(inner.withBatchSize(batchSize)); } /** See {@link WriteVoid#withRetryStrategy(RetryStrategy)}. */ public Write withRetryStrategy(RetryStrategy retryStrategy) { - return new Write(inner.withRetryStrategy(retryStrategy)); + return new Write<>(inner.withRetryStrategy(retryStrategy)); } /** See {@link WriteVoid#withRetryConfiguration(RetryConfiguration)}. */ public Write withRetryConfiguration(RetryConfiguration retryConfiguration) { - return new Write(inner.withRetryConfiguration(retryConfiguration)); + return new Write<>(inner.withRetryConfiguration(retryConfiguration)); } /** See {@link WriteVoid#withTable(String)}. */ public Write withTable(String table) { - return new Write(inner.withTable(table)); + return new Write<>(inner.withTable(table)); } /** @@ -1341,6 +1341,24 @@ public WriteVoid withResults() { return inner; } + /** + * Returns {@link WriteWithResults} transform that could return a specific result. + * + *

See {@link WriteWithResults} + */ + public WriteWithResults withWriteResults( + RowMapper rowMapper) { + return new AutoValue_JdbcIO_WriteWithResults.Builder() + .setRowMapper(rowMapper) + .setRetryStrategy(inner.getRetryStrategy()) + .setRetryConfiguration(inner.getRetryConfiguration()) + .setDataSourceProviderFn(inner.getDataSourceProviderFn()) + .setPreparedStatementSetter(inner.getPreparedStatementSetter()) + .setStatement(inner.getStatement()) + .setTable(inner.getTable()) + .build(); + } + @Override public void populateDisplayData(DisplayData.Builder builder) { inner.populateDisplayData(builder); @@ -1364,7 +1382,244 @@ void set( throws SQLException; } - /** A {@link PTransform} to write to a JDBC datasource. */ + /** + * A {@link PTransform} to write to a JDBC datasource. Executes statements one by one. + * + *

The INSERT, UPDATE, and DELETE commands sometimes have an optional RETURNING clause that + * supports obtaining data from modified rows while they are being manipulated. Output {@link + * PCollection} of this transform is a collection of such returning results mapped by {@link + * RowMapper}. + */ + @AutoValue + public abstract static class WriteWithResults + extends PTransform, PCollection> { + abstract @Nullable SerializableFunction getDataSourceProviderFn(); + + abstract @Nullable ValueProvider getStatement(); + + abstract @Nullable PreparedStatementSetter getPreparedStatementSetter(); + + abstract @Nullable RetryStrategy getRetryStrategy(); + + abstract @Nullable RetryConfiguration getRetryConfiguration(); + + abstract @Nullable String getTable(); + + abstract @Nullable RowMapper getRowMapper(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setDataSourceProviderFn( + SerializableFunction dataSourceProviderFn); + + abstract Builder setStatement(ValueProvider statement); + + abstract Builder setPreparedStatementSetter(PreparedStatementSetter setter); + + abstract Builder setRetryStrategy(RetryStrategy deadlockPredicate); + + abstract Builder setRetryConfiguration(RetryConfiguration retryConfiguration); + + abstract Builder setTable(String table); + + abstract Builder setRowMapper(RowMapper rowMapper); + + abstract WriteWithResults build(); + } + + public WriteWithResults withDataSourceConfiguration(DataSourceConfiguration config) { + return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config)); + } + + public WriteWithResults withDataSourceProviderFn( + SerializableFunction dataSourceProviderFn) { + return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build(); + } + + public WriteWithResults withStatement(String statement) { + return withStatement(ValueProvider.StaticValueProvider.of(statement)); + } + + public WriteWithResults withStatement(ValueProvider statement) { + return toBuilder().setStatement(statement).build(); + } + + public WriteWithResults withPreparedStatementSetter(PreparedStatementSetter setter) { + return toBuilder().setPreparedStatementSetter(setter).build(); + } + + /** + * When a SQL exception occurs, {@link Write} uses this {@link RetryStrategy} to determine if it + * will retry the statements. If {@link RetryStrategy#apply(SQLException)} returns {@code true}, + * then {@link Write} retries the statements. + */ + public WriteWithResults withRetryStrategy(RetryStrategy retryStrategy) { + checkArgument(retryStrategy != null, "retryStrategy can not be null"); + return toBuilder().setRetryStrategy(retryStrategy).build(); + } + + /** + * When a SQL exception occurs, {@link Write} uses this {@link RetryConfiguration} to + * exponentially back off and retry the statements based on the {@link RetryConfiguration} + * mentioned. + * + *

Usage of RetryConfiguration - + * + *

{@code
+     * pipeline.apply(JdbcIO.write())
+     *    .withReturningResults(...)
+     *    .withDataSourceConfiguration(...)
+     *    .withRetryStrategy(...)
+     *    .withRetryConfiguration(JdbcIO.RetryConfiguration.
+     *        create(5, Duration.standardSeconds(5), Duration.standardSeconds(1))
+     *
+     * }
+ * + * maxDuration and initialDuration are Nullable + * + *
{@code
+     * pipeline.apply(JdbcIO.write())
+     *    .withReturningResults(...)
+     *    .withDataSourceConfiguration(...)
+     *    .withRetryStrategy(...)
+     *    .withRetryConfiguration(JdbcIO.RetryConfiguration.
+     *        create(5, null, null)
+     *
+     * }
+ */ + public WriteWithResults withRetryConfiguration(RetryConfiguration retryConfiguration) { + checkArgument(retryConfiguration != null, "retryConfiguration can not be null"); + return toBuilder().setRetryConfiguration(retryConfiguration).build(); + } + + public WriteWithResults withTable(String table) { + checkArgument(table != null, "table name can not be null"); + return toBuilder().setTable(table).build(); + } + + public WriteWithResults withRowMapper(RowMapper rowMapper) { + checkArgument(rowMapper != null, "result set getter can not be null"); + return toBuilder().setRowMapper(rowMapper).build(); + } + + @Override + public PCollection expand(PCollection input) { + checkArgument(getStatement() != null, "withStatement() is required"); + checkArgument( + getPreparedStatementSetter() != null, "withPreparedStatementSetter() is required"); + checkArgument( + (getDataSourceProviderFn() != null), + "withDataSourceConfiguration() or withDataSourceProviderFn() is required"); + + return input.apply(ParDo.of(new WriteWithResultsFn<>(this))); + } + + private static class WriteWithResultsFn extends DoFn { + + private final WriteWithResults spec; + private DataSource dataSource; + private Connection connection; + private PreparedStatement preparedStatement; + private static FluentBackoff retryBackOff; + + public WriteWithResultsFn(WriteWithResults spec) { + this.spec = spec; + } + + @Setup + public void setup() { + dataSource = spec.getDataSourceProviderFn().apply(null); + RetryConfiguration retryConfiguration = spec.getRetryConfiguration(); + + retryBackOff = + FluentBackoff.DEFAULT + .withInitialBackoff(retryConfiguration.getInitialDuration()) + .withMaxCumulativeBackoff(retryConfiguration.getMaxDuration()) + .withMaxRetries(retryConfiguration.getMaxAttempts()); + } + + @ProcessElement + public void processElement(ProcessContext context) throws Exception { + T record = context.element(); + + // Only acquire the connection if there is something to write. + if (connection == null) { + connection = dataSource.getConnection(); + connection.setAutoCommit(false); + preparedStatement = connection.prepareStatement(spec.getStatement().get()); + } + Sleeper sleeper = Sleeper.DEFAULT; + BackOff backoff = retryBackOff.backoff(); + while (true) { + try (PreparedStatement preparedStatement = + connection.prepareStatement(spec.getStatement().get())) { + try { + + try { + spec.getPreparedStatementSetter().setParameters(record, preparedStatement); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // execute the statement + preparedStatement.execute(); + // commit the changes + connection.commit(); + context.output(spec.getRowMapper().mapRow(preparedStatement.getResultSet())); + return; + } catch (SQLException exception) { + if (!spec.getRetryStrategy().apply(exception)) { + throw exception; + } + LOG.warn("Deadlock detected, retrying", exception); + connection.rollback(); + if (!BackOffUtils.next(sleeper, backoff)) { + // we tried the max number of times + throw exception; + } + } + } + } + } + + @FinishBundle + public void finishBundle() throws Exception { + cleanUpStatementAndConnection(); + } + + @Override + protected void finalize() throws Throwable { + cleanUpStatementAndConnection(); + } + + private void cleanUpStatementAndConnection() throws Exception { + try { + if (preparedStatement != null) { + try { + preparedStatement.close(); + } finally { + preparedStatement = null; + } + } + } finally { + if (connection != null) { + try { + connection.close(); + } finally { + connection = null; + } + } + } + } + } + } + + /** + * A {@link PTransform} to write to a JDBC datasource. Executes statements in a batch, and returns + * a trivial result. + */ @AutoValue public abstract static class WriteVoid extends PTransform, PCollection> { diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteResult.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteResult.java new file mode 100644 index 0000000000000..3117c2459ba27 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteResult.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import com.google.auto.value.AutoValue; + +/** The result of writing a row to JDBC datasource. */ +@AutoValue +public abstract class JdbcWriteResult { + public static JdbcWriteResult create() { + return new AutoValue_JdbcWriteResult(); + } +} diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java index 6c800029f37ff..45eb09418a025 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java @@ -17,11 +17,14 @@ */ package org.apache.beam.sdk.io.jdbc; +import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount; +import static org.apache.beam.sdk.io.common.DatabaseTestHelper.getTestDataToWrite; import static org.apache.beam.sdk.io.common.IOITHelper.executeWithRetry; import static org.apache.beam.sdk.io.common.IOITHelper.readIOTestPipelineOptions; import com.google.cloud.Timestamp; import java.sql.SQLException; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -44,8 +47,10 @@ import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Top; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -79,6 +84,7 @@ @RunWith(JUnit4.class) public class JdbcIOIT { + private static final int EXPECTED_ROW_COUNT = 1000; private static final String NAMESPACE = JdbcIOIT.class.getName(); private static int numberOfRows; private static PGSimpleDataSource dataSource; @@ -255,4 +261,54 @@ private PipelineResult runRead() { return pipelineRead.run(); } + + @Test + public void testWriteWithWriteResults() throws Exception { + String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(dataSource, firstTableName); + try { + ArrayList> data = getTestDataToWrite(EXPECTED_ROW_COUNT); + + PCollection> dataCollection = pipelineWrite.apply(Create.of(data)); + PCollection resultSetCollection = + dataCollection.apply( + getJdbcWriteWithReturning(firstTableName) + .withWriteResults( + (resultSet -> { + if (resultSet != null && resultSet.next()) { + return new JdbcTestHelper.TestDto(resultSet.getInt(1)); + } + return new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT); + }))); + resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER); + + List expectedResult = new ArrayList<>(); + for (int id = 0; id < EXPECTED_ROW_COUNT; id++) { + expectedResult.add(new JdbcTestHelper.TestDto(id)); + } + + PAssert.that(resultSetCollection).containsInAnyOrder(expectedResult); + + pipelineWrite.run(); + + assertRowCount(dataSource, firstTableName, EXPECTED_ROW_COUNT); + } finally { + DatabaseTestHelper.deleteTable(dataSource, firstTableName); + } + } + + /** + * @return {@link JdbcIO.Write} transform that writes to {@param tableName} Postgres table and + * returns all fields of modified rows. + */ + private static JdbcIO.Write> getJdbcWriteWithReturning(String tableName) { + return JdbcIO.>write() + .withDataSourceProviderFn(voidInput -> dataSource) + .withStatement(String.format("insert into %s values(?, ?) returning *", tableName)) + .withPreparedStatementSetter( + (element, statement) -> { + statement.setInt(1, element.getKey()); + statement.setString(2, element.getValue()); + }); + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 30b6c484a375c..31ec663573c13 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.jdbc; import static java.sql.JDBCType.NUMERIC; +import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -42,7 +43,6 @@ import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.sql.Time; @@ -468,12 +468,47 @@ public void testWrite() throws Exception { pipeline.run(); - assertRowCount(tableName, EXPECTED_ROW_COUNT); + assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); } } + @Test + public void testWriteWithWriteResults() throws Exception { + String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(DATA_SOURCE, firstTableName); + try { + ArrayList> data = getDataToWrite(EXPECTED_ROW_COUNT); + + PCollection> dataCollection = pipeline.apply(Create.of(data)); + PCollection resultSetCollection = + dataCollection.apply( + getJdbcWrite(firstTableName) + .withWriteResults( + (resultSet -> { + if (resultSet != null && resultSet.next()) { + return new JdbcTestHelper.TestDto(resultSet.getInt(1)); + } + return new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT); + }))); + resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER); + + List expectedResult = new ArrayList<>(); + for (int i = 0; i < EXPECTED_ROW_COUNT; i++) { + expectedResult.add(new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT)); + } + + PAssert.that(resultSetCollection).containsInAnyOrder(expectedResult); + + pipeline.run(); + + assertRowCount(DATA_SOURCE, firstTableName, EXPECTED_ROW_COUNT); + } finally { + DatabaseTestHelper.deleteTable(DATA_SOURCE, firstTableName); + } + } + @Test public void testWriteWithResultsAndWaitOn() throws Exception { String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); @@ -490,8 +525,8 @@ public void testWriteWithResultsAndWaitOn() throws Exception { pipeline.run(); - assertRowCount(firstTableName, EXPECTED_ROW_COUNT); - assertRowCount(secondTableName, EXPECTED_ROW_COUNT); + assertRowCount(DATA_SOURCE, firstTableName, EXPECTED_ROW_COUNT); + assertRowCount(DATA_SOURCE, secondTableName, EXPECTED_ROW_COUNT); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, firstTableName); } @@ -525,18 +560,6 @@ private static ArrayList> getDataToWrite(long rowsToAdd) { return data; } - private static void assertRowCount(String tableName, int expectedRowCount) throws SQLException { - try (Connection connection = DATA_SOURCE.getConnection()) { - try (Statement statement = connection.createStatement()) { - try (ResultSet resultSet = statement.executeQuery("select count(*) from " + tableName)) { - resultSet.next(); - int count = resultSet.getInt(1); - assertEquals(expectedRowCount, count); - } - } - } - } - @Test public void testWriteWithBackoff() throws Exception { String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_BACKOFF"); @@ -593,7 +616,7 @@ public void testWriteWithBackoff() throws Exception { // we verify that the backoff has been called thanks to the log message expectedLogs.verifyWarn("Deadlock detected, retrying"); - assertRowCount(tableName, 2); + assertRowCount(DATA_SOURCE, tableName, 2); } @Test @@ -645,7 +668,7 @@ public void testWriteWithoutPreparedStatement() throws Exception { .withBatchSize(10L) .withTable(tableName)); pipeline.run(); - assertRowCount(tableName, rowsToAdd); + assertRowCount(DATA_SOURCE, tableName, rowsToAdd); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); } @@ -728,7 +751,7 @@ public void testWriteWithoutPreparedStatementAndNonRowType() throws Exception { .withBatchSize(10L) .withTable(tableName)); pipeline.run(); - assertRowCount(tableName, rowsToAdd); + assertRowCount(DATA_SOURCE, tableName, rowsToAdd); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); } @@ -1027,7 +1050,7 @@ protected boolean matchesSafely(Iterable logRecords) { }); // Since the pipeline was unable to write, only the row from insertStatement was written. - assertRowCount(tableName, 1); + assertRowCount(DATA_SOURCE, tableName, 1); } @Test @@ -1064,8 +1087,8 @@ public void testWriteRowsResultsAndWaitOn() throws Exception { pipeline.run(); - assertRowCount(firstTableName, EXPECTED_ROW_COUNT); - assertRowCount(secondTableName, EXPECTED_ROW_COUNT); + assertRowCount(DATA_SOURCE, firstTableName, EXPECTED_ROW_COUNT); + assertRowCount(DATA_SOURCE, secondTableName, EXPECTED_ROW_COUNT); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, firstTableName); DatabaseTestHelper.deleteTable(DATA_SOURCE, secondTableName); diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java index e929a5b9dd17d..081f8af68b489 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcTestHelper.java @@ -17,9 +17,18 @@ */ package org.apache.beam.sdk.io.jdbc; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.Objects; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.io.common.TestRow; /** @@ -28,6 +37,54 @@ */ class JdbcTestHelper { + public static class TestDto extends JdbcWriteResult implements Serializable { + + public static final int EMPTY_RESULT = 0; + + private int simpleField; + + public TestDto(int simpleField) { + this.simpleField = simpleField; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestDto testDto = (TestDto) o; + return simpleField == testDto.simpleField; + } + + @Override + public int hashCode() { + return Objects.hash(simpleField); + } + } + + public static final Coder TEST_DTO_CODER = + new CustomCoder() { + @Override + public void encode(TestDto value, OutputStream outStream) + throws CoderException, IOException { + BigEndianIntegerCoder.of().encode(value.simpleField, outStream); + } + + @Override + public TestDto decode(InputStream inStream) throws CoderException, IOException { + int simpleField = BigEndianIntegerCoder.of().decode(inStream); + return new TestDto(simpleField); + } + + @Override + public Object structuralValue(TestDto v) { + return v; + } + }; + static class CreateTestRowOfNameAndId implements JdbcIO.RowMapper { @Override public TestRow mapRow(ResultSet resultSet) throws Exception {