diff --git a/build.gradle.kts b/build.gradle.kts index d8f454c7..258b9d86 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -98,8 +98,6 @@ tasks.withType().configureEach { options.errorprone { disableWarningsInGeneratedCode.set(true) option("NullAway:AnnotatedPackages", "com.mongodb.hibernate") - option("NullAway:ExcludedFieldAnnotations", "org.mockito.Mock") - option("NullAway:ExcludedFieldAnnotations", "org.mockito.InjectMocks") } } tasks.compileJava { @@ -107,6 +105,10 @@ tasks.compileJava { options.errorprone.error("NullAway") } +tasks.compileTestJava { + options.errorprone.isEnabled.set(false) +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Build Config diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 2a995711..d73135d4 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -10,7 +10,7 @@ google-errorprone-core = "2.36.0" nullaway = "0.12.2" jspecify = "1.0.0" hibernate-core = "6.6.4.Final" -mongo-java-driver-sync = "5.2.1" +mongo-java-driver-sync = "5.3.0" slf4j-api = "2.0.16" logback-classic = "1.5.15" mockito = "5.14.2" diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java new file mode 100644 index 00000000..29408061 --- /dev/null +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementIntegrationTests.java @@ -0,0 +1,215 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static com.mongodb.hibernate.internal.MongoAssertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; +import org.bson.BsonDocument; +import org.hibernate.Session; +import org.hibernate.SessionFactory; +import org.hibernate.cfg.Configuration; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class MongoPreparedStatementIntegrationTests { + + private static @Nullable SessionFactory sessionFactory; + + private @Nullable Session session; + + @BeforeAll + static void beforeAll() { + sessionFactory = new Configuration().buildSessionFactory(); + } + + @AfterAll + static void afterAll() { + if (sessionFactory != null) { + sessionFactory.close(); + } + } + + @BeforeEach + void setUp() { + session = assertNotNull(sessionFactory).openSession(); + } + + @AfterEach + void tearDown() { + if (session != null) { + session.close(); + } + } + + @Nested + class ExecuteUpdateTests { + + @BeforeEach + void setUp() { + assertNotNull(session).doWork(conn -> { + conn.createStatement() + .executeUpdate( + """ + { + delete: "books", + deletes: [ + { q: {}, limit: 0 } + ] + }"""); + }); + } + + private static final String INSERT_MQL = + """ + { + insert: "books", + documents: [ + { + _id: 1, + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }, + { + _id: 2, + title: "Anna Karenina", + author: "Leo Tolstoy", + outOfStock: false, + tags: [ "classic", "tolstoy" ] + }, + { + _id: 3, + title: "Crime and Punishment", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "Dostoevsky", "literature" ] + } + ] + }"""; + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testUpdate(boolean autoCommit) { + // given + prepareData(); + + // when && then + var expectedDocs = Set.of( + BsonDocument.parse( + """ + { + _id: 1, + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: true, + tags: [ "classic", "tolstoy", "literature" ] + }"""), + BsonDocument.parse( + """ + { + _id: 2, + title: "Anna Karenina", + author: "Leo Tolstoy", + outOfStock: true, + tags: [ "classic", "tolstoy", "literature" ] + }"""), + BsonDocument.parse( + """ + { + _id: 3, + title: "Crime and Punishment", + author: "Fyodor Dostoevsky", + outOfStock: false, + tags: [ "classic", "Dostoevsky", "literature" ] + }""")); + Function pstmtProvider = connection -> { + try { + var pstmt = (MongoPreparedStatement) + connection.prepareStatement( + """ + { + update: "books", + updates: [ + { + q: { author: { $undefined: true } }, + u: { + $set: { + outOfStock: { $undefined: true } + }, + $push: { tags: { $undefined: true } } + }, + multi: true + } + ] + }"""); + pstmt.setString(1, "Leo Tolstoy"); + pstmt.setBoolean(2, true); + pstmt.setString(3, "literature"); + return pstmt; + } catch (SQLException e) { + throw new RuntimeException(e); + } + }; + assertExecuteUpdate(pstmtProvider, autoCommit, 2, expectedDocs); + } + + private void prepareData() { + assertNotNull(session).doWork(connection -> { + connection.setAutoCommit(true); + var statement = connection.createStatement(); + statement.executeUpdate(INSERT_MQL); + }); + } + + private void assertExecuteUpdate( + Function pstmtProvider, + boolean autoCommit, + int expectedUpdatedRowCount, + Set expectedDocuments) { + assertNotNull(session).doWork(connection -> { + connection.setAutoCommit(autoCommit); + try (var pstmt = pstmtProvider.apply(connection)) { + try { + assertEquals(expectedUpdatedRowCount, pstmt.executeUpdate()); + } finally { + if (!autoCommit) { + connection.commit(); + } + } + var realDocuments = pstmt.getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .into(new HashSet<>()); + assertEquals(expectedDocuments, realDocuments); + } + }); + } + } +} diff --git a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java index 0da7fbf7..c3e870c0 100644 --- a/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java +++ b/src/integrationTest/java/com/mongodb/hibernate/jdbc/MongoStatementIntegrationTests.java @@ -238,20 +238,20 @@ private void assertExecuteUpdate( String mql, boolean autoCommit, int expectedRowCount, Set expectedDocuments) { assertNotNull(session).doWork(connection -> { connection.setAutoCommit(autoCommit); - var statement = (MongoStatement) connection.createStatement(); - try { - assertEquals(expectedRowCount, statement.executeUpdate(mql)); - } finally { - if (!autoCommit) { - connection.commit(); + try (var stmt = (MongoStatement) connection.createStatement()) { + try { + assertEquals(expectedRowCount, stmt.executeUpdate(mql)); + } finally { + if (!autoCommit) { + connection.commit(); + } } + var realDocuments = stmt.getMongoDatabase() + .getCollection("books", BsonDocument.class) + .find() + .into(new HashSet<>()); + assertEquals(expectedDocuments, realDocuments); } - var realDocuments = statement - .getMongoDatabase() - .getCollection("books", BsonDocument.class) - .find() - .into(new HashSet<>()); - assertEquals(expectedDocuments, realDocuments); }); } } diff --git a/src/main/java/com/mongodb/hibernate/internal/MongoAssertions.java b/src/main/java/com/mongodb/hibernate/internal/MongoAssertions.java index d50173ab..e4235225 100644 --- a/src/main/java/com/mongodb/hibernate/internal/MongoAssertions.java +++ b/src/main/java/com/mongodb/hibernate/internal/MongoAssertions.java @@ -41,4 +41,16 @@ public static T assertNotNull(@Nullable T value) throws AssertionError { } return value; } + + /** + * Asserts that failure happens invariably. + * + * @param msg The failure message. + * @return Never completes normally. The return type is {@link AssertionError} to allow writing {@code throw + * fail("failure message")}. This may be helpful in non-{@code void} methods. + * @throws AssertionError Always + */ + public static AssertionError fail(String msg) throws AssertionError { + throw new AssertionError(assertNotNull(msg)); + } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/ConnectionAdapter.java b/src/main/java/com/mongodb/hibernate/jdbc/ConnectionAdapter.java index ada9a776..e23751bb 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/ConnectionAdapter.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/ConnectionAdapter.java @@ -39,283 +39,283 @@ import org.jspecify.annotations.Nullable; /** - * A {@link java.sql.Connection} implementation class that throws exceptions for all its API methods. + * A {@link java.sql.Connection} adapter interface that throws exceptions for all its API methods. * * @see MongoConnection */ -abstract class ConnectionAdapter implements Connection { +interface ConnectionAdapter extends Connection { @Override - public Statement createStatement() throws SQLException { + default Statement createStatement() throws SQLException { throw new SQLFeatureNotSupportedException("createStatement not implemented"); } @Override - public PreparedStatement prepareStatement(String sql) throws SQLException { + default PreparedStatement prepareStatement(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public CallableStatement prepareCall(String sql) throws SQLException { + default CallableStatement prepareCall(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("prepareCall not implemented"); } @Override - public String nativeSQL(String sql) throws SQLException { + default String nativeSQL(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("nativeSQL not implemented"); } @Override - public void setAutoCommit(boolean autoCommit) throws SQLException { + default void setAutoCommit(boolean autoCommit) throws SQLException { throw new SQLFeatureNotSupportedException("setAutoCommit not implemented"); } @Override - public boolean getAutoCommit() throws SQLException { + default boolean getAutoCommit() throws SQLException { throw new SQLFeatureNotSupportedException("getAutoCommit not implemented"); } @Override - public void commit() throws SQLException { + default void commit() throws SQLException { throw new SQLFeatureNotSupportedException("commit not implemented"); } @Override - public void rollback() throws SQLException { + default void rollback() throws SQLException { throw new SQLFeatureNotSupportedException("rollback not implemented"); } @Override - public void close() throws SQLException { + default void close() throws SQLException { throw new SQLFeatureNotSupportedException("close not implemented"); } @Override - public boolean isClosed() throws SQLException { + default boolean isClosed() throws SQLException { throw new SQLFeatureNotSupportedException("isClosed not implemented"); } @Override - public DatabaseMetaData getMetaData() throws SQLException { + default DatabaseMetaData getMetaData() throws SQLException { throw new SQLFeatureNotSupportedException("geetMetaData not implemented"); } @Override - public void setReadOnly(boolean readOnly) throws SQLException { + default void setReadOnly(boolean readOnly) throws SQLException { throw new SQLFeatureNotSupportedException("setReadOnly not implemented"); } @Override - public boolean isReadOnly() throws SQLException { + default boolean isReadOnly() throws SQLException { throw new SQLFeatureNotSupportedException("isReadOnly not implemented"); } @Override - public void setCatalog(String catalog) throws SQLException { + default void setCatalog(String catalog) throws SQLException { throw new SQLFeatureNotSupportedException("setCatalog not implemented"); } @Override - public @Nullable String getCatalog() throws SQLException { + default @Nullable String getCatalog() throws SQLException { throw new SQLFeatureNotSupportedException("getCatalog not implemented"); } @Override - public void setTransactionIsolation(int level) throws SQLException { + default void setTransactionIsolation(int level) throws SQLException { throw new SQLFeatureNotSupportedException("setTransactionIsolation not implemented"); } @Override - public int getTransactionIsolation() throws SQLException { + default int getTransactionIsolation() throws SQLException { throw new SQLFeatureNotSupportedException("getTransactionIsolation not implemented"); } @Override - public @Nullable SQLWarning getWarnings() throws SQLException { + default @Nullable SQLWarning getWarnings() throws SQLException { throw new SQLFeatureNotSupportedException("getWarnings not implemented"); } @Override - public void clearWarnings() throws SQLException { + default void clearWarnings() throws SQLException { throw new SQLFeatureNotSupportedException("clearWarnings not implemented"); } @Override - public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + default Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { throw new SQLFeatureNotSupportedException("createStatement not implemented"); } @Override - public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + default PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + default CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { throw new SQLFeatureNotSupportedException("prepareCall not implemented"); } @Override - public Map> getTypeMap() throws SQLException { + default Map> getTypeMap() throws SQLException { throw new SQLFeatureNotSupportedException("getTypeMap not implemented"); } @Override - public void setTypeMap(Map> map) throws SQLException { + default void setTypeMap(Map> map) throws SQLException { throw new SQLFeatureNotSupportedException("setTypeMap not implemented"); } @Override - public void setHoldability(int holdability) throws SQLException { + default void setHoldability(int holdability) throws SQLException { throw new SQLFeatureNotSupportedException("setHoldability not implemented"); } @Override - public int getHoldability() throws SQLException { + default int getHoldability() throws SQLException { throw new SQLFeatureNotSupportedException("getHoldability not implemented"); } @Override - public Savepoint setSavepoint() throws SQLException { + default Savepoint setSavepoint() throws SQLException { throw new SQLFeatureNotSupportedException("setSavepoint not implemented"); } @Override - public Savepoint setSavepoint(String name) throws SQLException { + default Savepoint setSavepoint(String name) throws SQLException { throw new SQLFeatureNotSupportedException("setSavepoint not implemented"); } @Override - public void rollback(Savepoint savepoint) throws SQLException { + default void rollback(Savepoint savepoint) throws SQLException { throw new SQLFeatureNotSupportedException("rollback not implemented"); } @Override - public void releaseSavepoint(Savepoint savepoint) throws SQLException { + default void releaseSavepoint(Savepoint savepoint) throws SQLException { throw new SQLFeatureNotSupportedException("releaseSavepoint not implemented"); } @Override - public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) + default Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { throw new SQLFeatureNotSupportedException("createStatement not implemented"); } @Override - public PreparedStatement prepareStatement( + default PreparedStatement prepareStatement( String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public CallableStatement prepareCall( + default CallableStatement prepareCall( String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { throw new SQLFeatureNotSupportedException("prepareCall not implemented"); } @Override - public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + default PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + default PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + default PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { throw new SQLFeatureNotSupportedException("prepareStatement not implemented"); } @Override - public Clob createClob() throws SQLException { + default Clob createClob() throws SQLException { throw new SQLFeatureNotSupportedException("createClob not implemented"); } @Override - public Blob createBlob() throws SQLException { + default Blob createBlob() throws SQLException { throw new SQLFeatureNotSupportedException("createBlob not implemented"); } @Override - public NClob createNClob() throws SQLException { + default NClob createNClob() throws SQLException { throw new SQLFeatureNotSupportedException("createNClob not implemented"); } @Override - public SQLXML createSQLXML() throws SQLException { + default SQLXML createSQLXML() throws SQLException { throw new SQLFeatureNotSupportedException("createSQLXML not implemented"); } @Override - public boolean isValid(int timeout) throws SQLException { + default boolean isValid(int timeout) throws SQLException { throw new SQLFeatureNotSupportedException("isValid not implemented"); } @Override - public void setClientInfo(String name, String value) throws SQLClientInfoException { + default void setClientInfo(String name, String value) throws SQLClientInfoException { throw new SQLClientInfoException("setClientInfo not implemented", Collections.emptyMap()); } @Override - public void setClientInfo(Properties properties) throws SQLClientInfoException { + default void setClientInfo(Properties properties) throws SQLClientInfoException { throw new SQLClientInfoException("setClientInfo not implemented", Collections.emptyMap()); } @Override - public String getClientInfo(String name) throws SQLException { + default String getClientInfo(String name) throws SQLException { throw new SQLFeatureNotSupportedException("getClientInfo not implemented"); } @Override - public Properties getClientInfo() throws SQLException { + default Properties getClientInfo() throws SQLException { throw new SQLFeatureNotSupportedException("getClientInfo not implemented"); } @Override - public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + default Array createArrayOf(String typeName, Object[] elements) throws SQLException { throw new SQLFeatureNotSupportedException("createArrayOf not implemented"); } @Override - public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + default Struct createStruct(String typeName, Object[] attributes) throws SQLException { throw new SQLFeatureNotSupportedException("createStruct not implemented"); } @Override - public void setSchema(String schema) throws SQLException { + default void setSchema(String schema) throws SQLException { throw new SQLFeatureNotSupportedException("setSchema not implemented"); } @Override - public @Nullable String getSchema() throws SQLException { + default @Nullable String getSchema() throws SQLException { throw new SQLFeatureNotSupportedException("getSchema not implemented"); } @Override - public void abort(Executor executor) throws SQLException { + default void abort(Executor executor) throws SQLException { throw new SQLFeatureNotSupportedException("abort not implemented"); } @Override - public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + default void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { throw new SQLFeatureNotSupportedException("setNetworkTimeout not implemented"); } @Override - public int getNetworkTimeout() throws SQLException { + default int getNetworkTimeout() throws SQLException { throw new SQLFeatureNotSupportedException("getNetworkTimeout not implemented"); } @Override - public T unwrap(Class iface) throws SQLException { + default T unwrap(Class iface) throws SQLException { throw new SQLFeatureNotSupportedException("unwrap not implemented"); } @Override - public boolean isWrapperFor(Class iface) throws SQLException { + default boolean isWrapperFor(Class iface) throws SQLException { throw new SQLFeatureNotSupportedException("isWrapperFor not implemented"); } } diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java index 7379e48e..7027e307 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoConnection.java @@ -20,14 +20,10 @@ import com.mongodb.client.MongoClient; import com.mongodb.hibernate.internal.NotYetImplementedException; import java.sql.Array; -import java.sql.Blob; -import java.sql.Clob; import java.sql.DatabaseMetaData; -import java.sql.NClob; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.SQLWarning; -import java.sql.SQLXML; import java.sql.Statement; import java.sql.Struct; import org.jspecify.annotations.Nullable; @@ -36,9 +32,9 @@ * MongoDB Dialect's JDBC {@linkplain java.sql.Connection connection} implementation class. * *

It only focuses on API methods Mongo Dialect will support. All the other methods are implemented by throwing - * exceptions in its parent class. + * exceptions in its parent {@linkplain ConnectionAdapter adapter interface}. */ -final class MongoConnection extends ConnectionAdapter { +final class MongoConnection implements ConnectionAdapter { // temporary hard-coded database prior to the db config tech design finalizing public static final String DATABASE = "mongo-hibernate-test"; @@ -143,8 +139,7 @@ public Statement createStatement() throws SQLException { @Override public PreparedStatement prepareStatement(String mql) throws SQLException { checkClosed(); - throw new NotYetImplementedException( - "To be implemented in scope of https://jira.mongodb.org/browse/HIBERNATE-13"); + return new MongoPreparedStatement(mongoClient, clientSession, this, mql); } @Override @@ -152,36 +147,12 @@ public PreparedStatement prepareStatement(String mql, int resultSetType, int res throws SQLException { checkClosed(); throw new NotYetImplementedException( - "To be implemented in scope of https://jira.mongodb.org/browse/HIBERNATE-13"); + "To be implemented in scope of https://jira.mongodb.org/browse/HIBERNATE-21"); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // SQL99 data types - @Override - public Clob createClob() throws SQLException { - checkClosed(); - throw new NotYetImplementedException(); - } - - @Override - public Blob createBlob() throws SQLException { - checkClosed(); - throw new NotYetImplementedException(); - } - - @Override - public NClob createNClob() throws SQLException { - checkClosed(); - throw new NotYetImplementedException(); - } - - @Override - public SQLXML createSQLXML() throws SQLException { - checkClosed(); - throw new NotYetImplementedException(); - } - @Override public Array createArrayOf(String typeName, Object[] elements) throws SQLException { checkClosed(); @@ -246,6 +217,11 @@ public void clearWarnings() throws SQLException { checkClosed(); } + @Override + public boolean isWrapperFor(Class iface) { + return false; + } + private void checkClosed() throws SQLException { if (closed) { throw new SQLException("Connection has been closed"); diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java new file mode 100644 index 00000000..a6d4cecc --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoPreparedStatement.java @@ -0,0 +1,334 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static com.mongodb.hibernate.internal.MongoAssertions.fail; +import static java.lang.String.format; + +import com.mongodb.client.ClientSession; +import com.mongodb.client.MongoClient; +import com.mongodb.hibernate.internal.NotYetImplementedException; +import java.io.InputStream; +import java.math.BigDecimal; +import java.sql.Array; +import java.sql.Date; +import java.sql.JDBCType; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.function.Consumer; +import org.bson.BsonArray; +import org.bson.BsonBinary; +import org.bson.BsonBoolean; +import org.bson.BsonDecimal128; +import org.bson.BsonDocument; +import org.bson.BsonDouble; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonNull; +import org.bson.BsonString; +import org.bson.BsonType; +import org.bson.BsonValue; +import org.bson.types.Decimal128; +import org.jspecify.annotations.Nullable; + +/** + * MongoDB Dialect's JDBC {@link java.sql.PreparedStatement} implementation class. + * + *

It only focuses on API methods MongoDB Dialect will support. All the other methods are implemented by throwing + * exceptions in its parent {@link PreparedStatementAdapter adapter interface}. + */ +final class MongoPreparedStatement extends MongoStatement implements PreparedStatementAdapter { + + private final BsonDocument command; + + private final List> parameterValueSetters; + + public MongoPreparedStatement( + MongoClient mongoClient, ClientSession clientSession, MongoConnection mongoConnection, String mql) { + super(mongoClient, clientSession, mongoConnection); + this.command = BsonDocument.parse(mql); + this.parameterValueSetters = new ArrayList<>(); + parseParameters(command, parameterValueSetters); + } + + @Override + public ResultSet executeQuery() throws SQLException { + checkClosed(); + throw new NotYetImplementedException(); + } + + @Override + public int executeUpdate() throws SQLException { + checkClosed(); + return executeUpdateCommand(command); + } + + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + switch (sqlType) { + case Types.ARRAY: + case Types.BLOB: + case Types.CLOB: + case Types.DATALINK: + case Types.JAVA_OBJECT: + case Types.NCHAR: + case Types.NCLOB: + case Types.NVARCHAR: + case Types.LONGNVARCHAR: + case Types.REF: + case Types.ROWID: + case Types.SQLXML: + case Types.STRUCT: + throw new SQLFeatureNotSupportedException( + "Unsupported sql type: " + JDBCType.valueOf(sqlType).getName()); + } + setParameter(parameterIndex, BsonNull.VALUE); + } + + @Override + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setParameter(parameterIndex, BsonBoolean.valueOf(x)); + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setInt(parameterIndex, x); + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setInt(parameterIndex, x); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setParameter(parameterIndex, new BsonInt32(x)); + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setParameter(parameterIndex, new BsonInt64(x)); + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setDouble(parameterIndex, x); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setParameter(parameterIndex, new BsonDouble(x)); + } + + @Override + public void setBigDecimal(int parameterIndex, @Nullable BigDecimal x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + if (x == null) { + setNull(parameterIndex, Types.NUMERIC); + } else { + setParameter(parameterIndex, new BsonDecimal128(new Decimal128(x))); + } + } + + @Override + public void setString(int parameterIndex, @Nullable String x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + if (x == null) { + setNull(parameterIndex, Types.VARCHAR); + } else { + setParameter(parameterIndex, new BsonString(x)); + } + } + + @Override + public void setBytes(int parameterIndex, byte @Nullable [] x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + if (x == null) { + setNull(parameterIndex, Types.VARBINARY); + } else { + setParameter(parameterIndex, new BsonBinary(x)); + } + } + + @Override + public void setDate(int parameterIndex, @Nullable Date x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setDate(parameterIndex, x, null); + } + + @Override + public void setTime(int parameterIndex, @Nullable Time x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setTime(parameterIndex, x, null); + } + + @Override + public void setTimestamp(int parameterIndex, @Nullable Timestamp x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + setTimestamp(parameterIndex, x, null); + } + + @Override + public void setBinaryStream(int parameterIndex, @Nullable InputStream x, int length) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException(); + } + + // ---------------------------------------------------------------------- + // Advanced features: + + @Override + public void setObject(int parameterIndex, @Nullable Object x, int targetSqlType) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To be implemented during Array / Struct tickets"); + } + + @Override + public void setObject(int parameterIndex, @Nullable Object x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To be implemented during Array / Struct tickets"); + } + + // --------------------------JDBC 2.0----------------------------- + + @Override + public void addBatch() throws SQLException { + checkClosed(); + throw new NotYetImplementedException( + "To be implemented in scope of https://jira.mongodb.org/browse/HIBERNATE-35"); + } + + @Override + public void setArray(int parameterIndex, @Nullable Array x) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException(); + } + + @Override + public void setDate(int parameterIndex, @Nullable Date x, @Nullable Calendar cal) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To implement in scope of https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public void setTime(int parameterIndex, @Nullable Time x, @Nullable Calendar cal) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To implement in scope of https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public void setTimestamp(int parameterIndex, @Nullable Timestamp x, @Nullable Calendar cal) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To implement in scope of https://jira.mongodb.org/browse/HIBERNATE-42"); + } + + @Override + public void setNull(int parameterIndex, int sqlType, @Nullable String typeName) throws SQLException { + checkClosed(); + checkParameterIndex(parameterIndex); + throw new NotYetImplementedException("To be implemented during Array / Struct tickets"); + } + + private void setParameter(int parameterIndex, BsonValue parameterValue) { + var parameterValueSetter = parameterValueSetters.get(parameterIndex - 1); + parameterValueSetter.accept(parameterValue); + } + + private static void parseParameters(BsonDocument command, List> parameterValueSetters) { + for (var entry : command.entrySet()) { + if (isParameterMarker(entry.getValue())) { + parameterValueSetters.add(entry::setValue); + } else if (entry.getValue().getBsonType().isContainer()) { + parseParameters(entry.getValue(), parameterValueSetters); + } + } + } + + private static void parseParameters(BsonArray array, List> parameterValueSetters) { + for (var i = 0; i < array.size(); i++) { + var value = array.get(i); + if (isParameterMarker(value)) { + var idx = i; + parameterValueSetters.add(v -> array.set(idx, v)); + } else if (value.getBsonType().isContainer()) { + parseParameters(value, parameterValueSetters); + } + } + } + + private static void parseParameters(BsonValue value, List> parameterValueSetters) { + if (value.isDocument()) { + parseParameters(value.asDocument(), parameterValueSetters); + } else if (value.isArray()) { + parseParameters(value.asArray(), parameterValueSetters); + } else { + fail("Only BSON container type (BsonDocument or BsonArray) is accepted; provided type: " + + value.getBsonType()); + } + } + + private static boolean isParameterMarker(BsonValue value) { + return value.getBsonType() == BsonType.UNDEFINED; + } + + private void checkParameterIndex(int parameterIndex) throws SQLException { + if (parameterValueSetters.isEmpty()) { + throw new SQLException("No parameter exists"); + } + if (parameterIndex < 1 || parameterIndex > parameterValueSetters.size()) { + throw new SQLException(format( + "Parameter index invalid: %d; should be within [1, %d]", + parameterIndex, parameterValueSetters.size())); + } + } +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java index 2f0f4f1a..f9be5d43 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/MongoStatement.java @@ -18,6 +18,7 @@ import static com.mongodb.hibernate.internal.VisibleForTesting.AccessModifier.PRIVATE; import static com.mongodb.hibernate.jdbc.MongoConnection.DATABASE; +import static java.lang.String.format; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoClient; @@ -36,9 +37,9 @@ * MongoDB Dialect's JDBC {@link java.sql.Statement} implementation class. * *

It only focuses on API methods Mongo Dialect will support. All the other methods are implemented by throwing - * exceptions in its parent class. + * exceptions in its parent {@link StatementAdapter adapter interface}. */ -final class MongoStatement extends StatementAdapter { +class MongoStatement implements StatementAdapter { private final MongoClient mongoClient; private final MongoConnection mongoConnection; @@ -63,6 +64,10 @@ public ResultSet executeQuery(String mql) throws SQLException { public int executeUpdate(String mql) throws SQLException { checkClosed(); var command = parse(mql); + return executeUpdateCommand(command); + } + + int executeUpdateCommand(BsonDocument command) throws SQLException { startTransactionIfNeeded(); try { return mongoClient @@ -70,7 +75,7 @@ public int executeUpdate(String mql) throws SQLException { .runCommand(clientSession, command) .getInteger("n"); } catch (Exception e) { - throw new SQLException("Failed to run #executeUpdate(String)", e); + throw new SQLException("Failed to execute update command", e); } } @@ -199,9 +204,14 @@ public boolean isClosed() { return closed; } - private void checkClosed() throws SQLException { + @Override + public boolean isWrapperFor(Class iface) { + return false; + } + + void checkClosed() throws SQLException { if (closed) { - throw new SQLException("Statement has been closed"); + throw new SQLException(format("%s has been closed", getClass().getSimpleName())); } } @@ -221,11 +231,6 @@ private static BsonDocument parse(String mql) throws SQLSyntaxErrorException { /** * Starts transaction for the first {@link java.sql.Statement} executing if * {@linkplain MongoConnection#getAutoCommit() auto-commit} is disabled. - * - * @see #executeQuery(String) - * @see #executeUpdate(String) - * @see #execute(String) - * @see #executeBatch() */ private void startTransactionIfNeeded() throws SQLException { if (!mongoConnection.getAutoCommit() && !clientSession.hasActiveTransaction()) { diff --git a/src/main/java/com/mongodb/hibernate/jdbc/PreparedStatementAdapter.java b/src/main/java/com/mongodb/hibernate/jdbc/PreparedStatementAdapter.java new file mode 100644 index 00000000..2716c7e9 --- /dev/null +++ b/src/main/java/com/mongodb/hibernate/jdbc/PreparedStatementAdapter.java @@ -0,0 +1,333 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; + +/** + * A {@link java.sql.PreparedStatement} adapter interface that throws exceptions for all its API methods. + * + * @see MongoPreparedStatement + */ +interface PreparedStatementAdapter extends StatementAdapter, PreparedStatement { + @Override + default ResultSet executeQuery() throws SQLException { + throw new SQLFeatureNotSupportedException("executeQuery not implemented"); + } + + @Override + default int executeUpdate() throws SQLException { + throw new SQLFeatureNotSupportedException("executeUpdate not implemented"); + } + + @Override + default void setNull(int parameterIndex, int sqlType) throws SQLException { + throw new SQLFeatureNotSupportedException("setNull not implemented"); + } + + @Override + default void setBoolean(int parameterIndex, boolean x) throws SQLException { + throw new SQLFeatureNotSupportedException("setBoolean not implemented"); + } + + @Override + default void setByte(int parameterIndex, byte x) throws SQLException { + throw new SQLFeatureNotSupportedException("setByte not implemented"); + } + + @Override + default void setShort(int parameterIndex, short x) throws SQLException { + throw new SQLFeatureNotSupportedException("setShort not implemented"); + } + + @Override + default void setInt(int parameterIndex, int x) throws SQLException { + throw new SQLFeatureNotSupportedException("setInt not implemented"); + } + + @Override + default void setLong(int parameterIndex, long x) throws SQLException { + throw new SQLFeatureNotSupportedException("setLong not implemented"); + } + + @Override + default void setFloat(int parameterIndex, float x) throws SQLException { + throw new SQLFeatureNotSupportedException("setFloat not implemented"); + } + + @Override + default void setDouble(int parameterIndex, double x) throws SQLException { + throw new SQLFeatureNotSupportedException("setDouble not implemented"); + } + + @Override + default void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + throw new SQLFeatureNotSupportedException("setBigDecimal not implemented"); + } + + @Override + default void setString(int parameterIndex, String x) throws SQLException { + throw new SQLFeatureNotSupportedException("setString not implemented"); + } + + @Override + default void setBytes(int parameterIndex, byte[] x) throws SQLException { + throw new SQLFeatureNotSupportedException("setBytes not implemented"); + } + + @Override + default void setDate(int parameterIndex, Date x) throws SQLException { + throw new SQLFeatureNotSupportedException("setDate not implemented"); + } + + @Override + default void setTime(int parameterIndex, Time x) throws SQLException { + throw new SQLFeatureNotSupportedException("setTime not implemented"); + } + + @Override + default void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + throw new SQLFeatureNotSupportedException("setTimestamp not implemented"); + } + + @Override + default void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("setAsciiStream not implemented"); + } + + @Override + default void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("setUnicodeStream not implemented"); + } + + @Override + default void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("setBinaryStream not implemented"); + } + + @Override + default void clearParameters() throws SQLException { + throw new SQLFeatureNotSupportedException("clearParameters not implemented"); + } + + // ---------------------------------------------------------------------- + // Advanced features: + + @Override + default void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + throw new SQLFeatureNotSupportedException("setObject not implemented"); + } + + @Override + default void setObject(int parameterIndex, Object x) throws SQLException { + throw new SQLFeatureNotSupportedException("setObject not implemented"); + } + + @Override + default boolean execute() throws SQLException { + throw new SQLFeatureNotSupportedException("execute not implemented"); + } + + // --------------------------JDBC 2.0----------------------------- + + @Override + default void addBatch() throws SQLException { + throw new SQLFeatureNotSupportedException("addBatch not implemented"); + } + + @Override + default void setCharacterStream(int parameterIndex, Reader reader, int length) throws SQLException { + throw new SQLFeatureNotSupportedException("setCharacterStream not implemented"); + } + + @Override + default void setRef(int parameterIndex, Ref x) throws SQLException { + throw new SQLFeatureNotSupportedException("setRef not implemented"); + } + + @Override + default void setBlob(int parameterIndex, Blob x) throws SQLException { + throw new SQLFeatureNotSupportedException("setBlob not implemented"); + } + + @Override + default void setClob(int parameterIndex, Clob x) throws SQLException { + throw new SQLFeatureNotSupportedException("setClob not implemented"); + } + + @Override + default void setArray(int parameterIndex, Array x) throws SQLException { + throw new SQLFeatureNotSupportedException("setArray not implemented"); + } + + @Override + default ResultSetMetaData getMetaData() throws SQLException { + throw new SQLFeatureNotSupportedException("getMetaData not implemented"); + } + + @Override + default void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("setDate not implemented"); + } + + @Override + default void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("setTime not implemented"); + } + + @Override + default void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + throw new SQLFeatureNotSupportedException("setTimestamp not implemented"); + } + + @Override + default void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + throw new SQLFeatureNotSupportedException("setNull not implemented"); + } + + // ------------------------- JDBC 3.0 ----------------------------------- + + @Override + default void setURL(int parameterIndex, URL x) throws SQLException { + throw new SQLFeatureNotSupportedException("setURL not implemented"); + } + + @Override + default ParameterMetaData getParameterMetaData() throws SQLException { + throw new SQLFeatureNotSupportedException("getParameterMetaData not implemented"); + } + + // ------------------------- JDBC 4.0 ----------------------------------- + + @Override + default void setRowId(int parameterIndex, RowId x) throws SQLException { + throw new SQLFeatureNotSupportedException("setRowId not implemented"); + } + + @Override + default void setNString(int parameterIndex, String value) throws SQLException { + throw new SQLFeatureNotSupportedException("setNString not implemented"); + } + + @Override + default void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setNCharacterStream not implemented"); + } + + @Override + default void setNClob(int parameterIndex, NClob value) throws SQLException { + throw new SQLFeatureNotSupportedException("setNClob not implemented"); + } + + @Override + default void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setClob not implemented"); + } + + @Override + default void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setBlob not implemented"); + } + + @Override + default void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setNClob not implemented"); + } + + @Override + default void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + throw new SQLFeatureNotSupportedException("setSQLXML not implemented"); + } + + @Override + default void setObject(int parameterIndex, Object x, int targetSqlType, int scaleOrLength) throws SQLException { + throw new SQLFeatureNotSupportedException("setObject not implemented"); + } + + @Override + default void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setAsciiStream not implemented"); + } + + @Override + default void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setBinaryStream not implemented"); + } + + @Override + default void setCharacterStream(int parameterIndex, Reader reader, long length) throws SQLException { + throw new SQLFeatureNotSupportedException("setCharacterStream not implemented"); + } + + @Override + default void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("setAsciiStream not implemented"); + } + + @Override + default void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + throw new SQLFeatureNotSupportedException("setBinaryStream not implemented"); + } + + @Override + default void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("setCharacterStream not implemented"); + } + + @Override + default void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + throw new SQLFeatureNotSupportedException("setNCharacterStream not implemented"); + } + + @Override + default void setClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("setClob not implemented"); + } + + @Override + default void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + throw new SQLFeatureNotSupportedException("setBlob not implemented"); + } + + @Override + default void setNClob(int parameterIndex, Reader reader) throws SQLException { + throw new SQLFeatureNotSupportedException("setNClob not implemented"); + } + + // ------------------------- JDBC 4.2 ----------------------------------- + +} diff --git a/src/main/java/com/mongodb/hibernate/jdbc/StatementAdapter.java b/src/main/java/com/mongodb/hibernate/jdbc/StatementAdapter.java index b634c6cb..f6759d06 100644 --- a/src/main/java/com/mongodb/hibernate/jdbc/StatementAdapter.java +++ b/src/main/java/com/mongodb/hibernate/jdbc/StatementAdapter.java @@ -25,228 +25,228 @@ import org.jspecify.annotations.Nullable; /** - * A {@link java.sql.Statement} implementation class that throws exceptions for all its API methods. + * A {@link java.sql.Statement} adapter interface that throws exceptions for all its API methods. * * @see MongoStatement */ -abstract class StatementAdapter implements Statement { +interface StatementAdapter extends Statement { @Override - public ResultSet executeQuery(String sql) throws SQLException { + default ResultSet executeQuery(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("executeQuery not implemented"); } @Override - public int executeUpdate(String sql) throws SQLException { + default int executeUpdate(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("executeUpdate not implemented"); } @Override - public void close() throws SQLException { + default void close() throws SQLException { throw new SQLFeatureNotSupportedException("close not implemented"); } @Override - public int getMaxFieldSize() throws SQLException { + default int getMaxFieldSize() throws SQLException { throw new SQLFeatureNotSupportedException("getMaxFieldSize not implemented"); } @Override - public void setMaxFieldSize(int max) throws SQLException { + default void setMaxFieldSize(int max) throws SQLException { throw new SQLFeatureNotSupportedException("setMaxFieldSize not implemented"); } @Override - public int getMaxRows() throws SQLException { + default int getMaxRows() throws SQLException { throw new SQLFeatureNotSupportedException("getMaxRows not implemented"); } @Override - public void setMaxRows(int max) throws SQLException { + default void setMaxRows(int max) throws SQLException { throw new SQLFeatureNotSupportedException("setMaxRows not implemented"); } @Override - public void setEscapeProcessing(boolean enable) throws SQLException { + default void setEscapeProcessing(boolean enable) throws SQLException { throw new SQLFeatureNotSupportedException("setEscapeProcessing not implemented"); } @Override - public int getQueryTimeout() throws SQLException { + default int getQueryTimeout() throws SQLException { throw new SQLFeatureNotSupportedException("getQueryTimeout not implemented"); } @Override - public void setQueryTimeout(int seconds) throws SQLException { + default void setQueryTimeout(int seconds) throws SQLException { throw new SQLFeatureNotSupportedException("setQueryTimeout not implemented"); } @Override - public void cancel() throws SQLException { + default void cancel() throws SQLException { throw new SQLFeatureNotSupportedException("cancel not implemented"); } @Override - public @Nullable SQLWarning getWarnings() throws SQLException { + default @Nullable SQLWarning getWarnings() throws SQLException { throw new SQLFeatureNotSupportedException("getWarnings not implemented"); } @Override - public void clearWarnings() throws SQLException { + default void clearWarnings() throws SQLException { throw new SQLFeatureNotSupportedException("clearWarnings not implemented"); } @Override - public void setCursorName(String name) throws SQLException { + default void setCursorName(String name) throws SQLException { throw new SQLFeatureNotSupportedException("setCursorName not implemented"); } @Override - public boolean execute(String sql) throws SQLException { + default boolean execute(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("execute not implemented"); } @Override - public @Nullable ResultSet getResultSet() throws SQLException { + default @Nullable ResultSet getResultSet() throws SQLException { throw new SQLFeatureNotSupportedException("getResultSet not implemented"); } @Override - public int getUpdateCount() throws SQLException { + default int getUpdateCount() throws SQLException { throw new SQLFeatureNotSupportedException("getUpdateCount not implemented"); } @Override - public boolean getMoreResults() throws SQLException { + default boolean getMoreResults() throws SQLException { throw new SQLFeatureNotSupportedException("getMoreResults not implemented"); } @Override - public void setFetchDirection(int direction) throws SQLException { + default void setFetchDirection(int direction) throws SQLException { throw new SQLFeatureNotSupportedException("setFetchDirection not implemented"); } @Override - public int getFetchDirection() throws SQLException { + default int getFetchDirection() throws SQLException { throw new SQLFeatureNotSupportedException("getFetchDirection not implemented"); } @Override - public void setFetchSize(int rows) throws SQLException { + default void setFetchSize(int rows) throws SQLException { throw new SQLFeatureNotSupportedException("setFetchSize not implemented"); } @Override - public int getFetchSize() throws SQLException { + default int getFetchSize() throws SQLException { throw new SQLFeatureNotSupportedException("getFetchSize not implemented"); } @Override - public int getResultSetConcurrency() throws SQLException { + default int getResultSetConcurrency() throws SQLException { throw new SQLFeatureNotSupportedException("getResultSetConcurrency not implemented"); } @Override - public int getResultSetType() throws SQLException { + default int getResultSetType() throws SQLException { throw new SQLFeatureNotSupportedException("getResultSetType not implemented"); } @Override - public void addBatch(String sql) throws SQLException { + default void addBatch(String sql) throws SQLException { throw new SQLFeatureNotSupportedException("addBatch not implemented"); } @Override - public void clearBatch() throws SQLException { + default void clearBatch() throws SQLException { throw new SQLFeatureNotSupportedException("clearBatch not implemented"); } @Override - public int[] executeBatch() throws SQLException { + default int[] executeBatch() throws SQLException { throw new SQLFeatureNotSupportedException("executeBatch not implemented"); } @Override - public Connection getConnection() throws SQLException { + default Connection getConnection() throws SQLException { throw new SQLFeatureNotSupportedException("getConnection not implemented"); } @Override - public boolean getMoreResults(int current) throws SQLException { + default boolean getMoreResults(int current) throws SQLException { throw new SQLFeatureNotSupportedException("getMoreResults not implemented"); } @Override - public ResultSet getGeneratedKeys() throws SQLException { + default ResultSet getGeneratedKeys() throws SQLException { throw new SQLFeatureNotSupportedException("getGeneratedKeys not implemented"); } @Override - public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + default int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { throw new SQLFeatureNotSupportedException("executeUpdate not implemented"); } @Override - public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + default int executeUpdate(String sql, int[] columnIndexes) throws SQLException { throw new SQLFeatureNotSupportedException("executeUpdate not implemented"); } @Override - public int executeUpdate(String sql, String[] columnNames) throws SQLException { + default int executeUpdate(String sql, String[] columnNames) throws SQLException { throw new SQLFeatureNotSupportedException("executeUpdate not implemented"); } @Override - public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + default boolean execute(String sql, int autoGeneratedKeys) throws SQLException { throw new SQLFeatureNotSupportedException("execute not implemented"); } @Override - public boolean execute(String sql, int[] columnIndexes) throws SQLException { + default boolean execute(String sql, int[] columnIndexes) throws SQLException { throw new SQLFeatureNotSupportedException("execute not implemented"); } @Override - public boolean execute(String sql, String[] columnNames) throws SQLException { + default boolean execute(String sql, String[] columnNames) throws SQLException { throw new SQLFeatureNotSupportedException("execute not implemented"); } @Override - public int getResultSetHoldability() throws SQLException { + default int getResultSetHoldability() throws SQLException { throw new SQLFeatureNotSupportedException("getResultSetHoldability not implemented"); } @Override - public boolean isClosed() throws SQLException { + default boolean isClosed() throws SQLException { throw new SQLFeatureNotSupportedException("isClosed not implemented"); } @Override - public void setPoolable(boolean poolable) throws SQLException { + default void setPoolable(boolean poolable) throws SQLException { throw new SQLFeatureNotSupportedException("setPoolable not implemented"); } @Override - public boolean isPoolable() throws SQLException { + default boolean isPoolable() throws SQLException { throw new SQLFeatureNotSupportedException("isPoolable not implemented"); } @Override - public void closeOnCompletion() throws SQLException { + default void closeOnCompletion() throws SQLException { throw new SQLFeatureNotSupportedException("closeOnCompletion not implemented"); } @Override - public boolean isCloseOnCompletion() throws SQLException { + default boolean isCloseOnCompletion() throws SQLException { throw new SQLFeatureNotSupportedException("isCloseOnCompletion not implemented"); } @Override - public T unwrap(Class iface) throws SQLException { + default T unwrap(Class iface) throws SQLException { throw new SQLFeatureNotSupportedException("unwrap not implemented"); } @Override - public boolean isWrapperFor(Class iface) throws SQLException { + default boolean isWrapperFor(Class iface) throws SQLException { throw new SQLFeatureNotSupportedException("isWrapperFor not implemented"); } } diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java index 9f930f14..2c5dda28 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoConnectionTests.java @@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; @@ -45,7 +44,7 @@ @ExtendWith(MockitoExtension.class) class MongoConnectionTests { - @Mock(answer = RETURNS_SMART_NULLS) + @Mock private ClientSession clientSession; @InjectMocks diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java new file mode 100644 index 00000000..3c794291 --- /dev/null +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoPreparedStatementTests.java @@ -0,0 +1,220 @@ +/* + * Copyright 2024-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.hibernate.jdbc; + +import static java.lang.String.format; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import com.mongodb.client.ClientSession; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoDatabase; +import java.sql.SQLException; +import java.sql.Types; +import java.util.Map; +import java.util.stream.Stream; +import org.bson.BsonDocument; +import org.bson.Document; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class MongoPreparedStatementTests { + + @Mock + private MongoClient mongoClient; + + @Mock + private ClientSession clientSession; + + @Mock + private MongoConnection mongoConnection; + + private MongoPreparedStatement createMongoPreparedStatement(String mql) { + return new MongoPreparedStatement(mongoClient, clientSession, mongoConnection, mql); + } + + @Nested + class ParameterValueSettingTests { + + private static final String EXAMPLE_MQL = + """ + { + insert: "books", + documents: [ + { + title: { $undefined: true }, + author: { $undefined: true }, + publishYear: { $undefined: true }, + outOfStock: { $undefined: true }, + tags: [ + { $undefined: true } + ] + } + ] + } + """; + + @Mock + private MongoDatabase mongoDatabase; + + @Captor + private ArgumentCaptor commandCaptor; + + @Test + @DisplayName("Happy path when all parameters are provided values") + void testSuccess() throws SQLException { + // given + doReturn(mongoDatabase).when(mongoClient).getDatabase(anyString()); + doReturn(Document.parse("{ok: 1.0, n: 1}")) + .when(mongoDatabase) + .runCommand(eq(clientSession), any(BsonDocument.class)); + + // when && then + try (var preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL)) { + + preparedStatement.setString(1, "War and Peace"); + preparedStatement.setString(2, "Leo Tolstoy"); + preparedStatement.setInt(3, 1869); + preparedStatement.setBoolean(4, false); + preparedStatement.setString(5, "classic"); + + preparedStatement.executeUpdate(); + + verify(mongoDatabase).runCommand(eq(clientSession), commandCaptor.capture()); + var command = commandCaptor.getValue(); + var expectedDoc = BsonDocument.parse( + """ + { + insert: "books", + documents: [ + { + title: "War and Peace", + author: "Leo Tolstoy", + publishYear: 1869, + outOfStock: false, + tags: [ + "classic" + ] + } + ] + } + """); + assertEquals(expectedDoc, command); + } + } + + @Test + @DisplayName("SQLException is thrown when parameter index is invalid") + void testParameterIndexInvalid() { + try (var preparedStatement = createMongoPreparedStatement(EXAMPLE_MQL)) { + var sqlException = + assertThrows(SQLException.class, () -> preparedStatement.setString(0, "War and Peace")); + assertEquals( + format("Parameter index invalid: %d; should be within [1, %d]", 0, 5), + sqlException.getMessage()); + verify(mongoClient, never()).getDatabase(anyString()); + } + } + } + + @Nested + class CloseTests { + + @FunctionalInterface + interface PreparedStatementMethodInvocation { + void run(MongoPreparedStatement pstmt) throws SQLException; + } + + @ParameterizedTest(name = "SQLException is thrown when \"{0}\" is called on a closed MongoPreparedStatement") + @MethodSource("getMongoPreparedStatementMethodInvocationsImpactedByClosing") + void testCheckClosed(String label, PreparedStatementMethodInvocation methodInvocation) { + // given + var mql = + """ + { + insert: "books", + documents: [ + { + title: "War and Peace", + author: "Leo Tolstoy", + outOfStock: false, + values: [ + { $undefined: true } + ] + } + ] + } + """; + + var preparedStatement = createMongoPreparedStatement(mql); + preparedStatement.close(); + + // when && then + var sqlException = assertThrows(SQLException.class, () -> methodInvocation.run(preparedStatement)); + assertEquals("MongoPreparedStatement has been closed", sqlException.getMessage()); + } + + private static Stream getMongoPreparedStatementMethodInvocationsImpactedByClosing() { + return Map.ofEntries( + Map.entry("executeQuery()", MongoPreparedStatement::executeQuery), + Map.entry("executeUpdate()", MongoPreparedStatement::executeUpdate), + Map.entry("setNull(int,int)", pstmt -> pstmt.setNull(1, Types.INTEGER)), + Map.entry("setBoolean(int,boolean)", pstmt -> pstmt.setBoolean(1, true)), + Map.entry("setByte(int,byte)", pstmt -> pstmt.setByte(1, (byte) 10)), + Map.entry("setShort(int,short)", pstmt -> pstmt.setShort(1, (short) 10)), + Map.entry("setInt(int,int)", pstmt -> pstmt.setInt(1, 1)), + Map.entry("setLong(int,long)", pstmt -> pstmt.setLong(1, 1L)), + Map.entry("setFloat(int,float)", pstmt -> pstmt.setFloat(1, 1.0F)), + Map.entry("setDouble(int,double)", pstmt -> pstmt.setDouble(1, 1.0)), + Map.entry("setBigDecimal(int,BigDecimal)", pstmt -> pstmt.setBigDecimal(1, null)), + Map.entry("setString(int,String)", pstmt -> pstmt.setString(1, null)), + Map.entry("setBytes(int,byte[])", pstmt -> pstmt.setBytes(1, null)), + Map.entry("setDate(int,Date)", pstmt -> pstmt.setDate(1, null)), + Map.entry("setTime(int,Time)", pstmt -> pstmt.setTime(1, null)), + Map.entry("setTimestamp(int,Timestamp)", pstmt -> pstmt.setTimestamp(1, null)), + Map.entry( + "setBinaryStream(int,InputStream,int)", pstmt -> pstmt.setBinaryStream(1, null, 0)), + Map.entry("setObject(int,Object,int)", pstmt -> pstmt.setObject(1, null, Types.OTHER)), + Map.entry("addBatch()", MongoPreparedStatement::addBatch), + Map.entry("setArray(int,Array)", pstmt -> pstmt.setArray(1, null)), + Map.entry("setDate(int,Date,Calendar)", pstmt -> pstmt.setDate(1, null, null)), + Map.entry("setTime(int,Time,Calendar)", pstmt -> pstmt.setTime(1, null, null)), + Map.entry( + "setTimestamp(int,Timestamp,Calendar)", pstmt -> pstmt.setTimestamp(1, null, null)), + Map.entry("setNull(int,Object,String)", pstmt -> pstmt.setNull(1, Types.STRUCT, "BOOK"))) + .entrySet() + .stream() + .map(entry -> Arguments.of(entry.getKey(), entry.getValue())); + } + } +} diff --git a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java index b88fa3af..90ba779b 100644 --- a/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java +++ b/src/test/java/com/mongodb/hibernate/jdbc/MongoStatementTests.java @@ -20,7 +20,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.same; @@ -49,13 +48,13 @@ @ExtendWith(MockitoExtension.class) class MongoStatementTests { - @Mock(answer = RETURNS_SMART_NULLS) + @Mock private MongoClient mongoClient; - @Mock(answer = RETURNS_SMART_NULLS) + @Mock private ClientSession clientSession; - @Mock(answer = RETURNS_SMART_NULLS) + @Mock private MongoConnection mongoConnection; @InjectMocks @@ -90,7 +89,7 @@ void testSQLExceptionThrownWhenCalledWithInvalidMql() { @Test @DisplayName("SQLException is thrown when database access error occurs") - void testSQLExceptionThrownWhenDBAccessFailed(@Mock(answer = RETURNS_SMART_NULLS) MongoDatabase mongoDatabase) { + void testSQLExceptionThrownWhenDBAccessFailed(@Mock MongoDatabase mongoDatabase) { // given doReturn(mongoDatabase).when(mongoClient).getDatabase(anyString()); var dbAccessException = new RuntimeException(); @@ -125,7 +124,7 @@ void testCheckClosed(String label, StatementMethodInvocation methodInvocation) { // when && then var exception = assertThrows(SQLException.class, () -> methodInvocation.runOn(mongoStatement)); - assertEquals("Statement has been closed", exception.getMessage()); + assertEquals("MongoStatement has been closed", exception.getMessage()); } private static Stream getMongoStatementMethodInvocationsImpactedByClosing() { diff --git a/src/test/java/org/mockito/configuration/MockitoConfiguration.java b/src/test/java/org/mockito/configuration/MockitoConfiguration.java new file mode 100644 index 00000000..cce680a2 --- /dev/null +++ b/src/test/java/org/mockito/configuration/MockitoConfiguration.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.mockito.configuration; + +import org.mockito.Answers; +import org.mockito.stubbing.Answer; + +/** + * Mockito's global configuration overriding mechanism. Before the issue is resolved, this seems the best way to configure + * {@link Answers#RETURNS_SMART_NULLS RETURNS_SMART_NULLS} as the default Mock {@link Answer}. + */ +public final class MockitoConfiguration extends DefaultMockitoConfiguration { + + public Answer getDefaultAnswer() { + return Answers.RETURNS_SMART_NULLS; + } +}