Skip to content

Commit

Permalink
feat: support statement tags in hints (#1579)
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite committed Apr 21, 2024
1 parent 10cc93b commit 0c3aec1
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner.jdbc;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
import com.google.cloud.spanner.connection.AbstractMockServerTest;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.RequestOptions.Priority;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;

@RunWith(Parameterized.class)
public class ClientSideStatementHintsTest extends AbstractMockServerTest {

@Parameter public Dialect dialect;

private Dialect currentDialect;

@Parameters(name = "dialect = {0}")
public static Object[] data() {
return Dialect.values();
}

@Before
public void setupDialect() {
if (this.dialect != currentDialect) {
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
this.currentDialect = dialect;
}
}

@After
public void clearRequests() {
mockSpanner.clearRequests();
}

private String createUrl() {
return String.format(
"jdbc:cloudspanner://localhost:%d/projects/%s/instances/%s/databases/%s?usePlainText=true",
getPort(), "proj", "inst", "db" + (dialect == Dialect.POSTGRESQL ? "pg" : ""));
}

private Connection createConnection() throws SQLException {
return DriverManager.getConnection(createUrl());
}

@Test
public void testStatementTagInHint() throws SQLException {
try (Connection connection = createConnection()) {
try (ResultSet resultSet =
connection
.createStatement()
.executeQuery(
dialect == Dialect.POSTGRESQL
? "/*@statement_tag='test-tag'*/SELECT 1"
: "@{statement_tag='test-tag'}SELECT 1")) {
assertTrue(resultSet.next());
assertEquals(1L, resultSet.getLong(1));
assertFalse(resultSet.next());
}
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
assertEquals("test-tag", request.getRequestOptions().getRequestTag());
}

@Test
public void testRpcPriorityInHint() throws SQLException {
try (Connection connection = createConnection()) {
try (ResultSet resultSet =
connection
.createStatement()
.executeQuery(
dialect == Dialect.POSTGRESQL
? "/*@rpc_priority=PRIORITY_LOW*/SELECT 1"
: "@{rpc_priority=PRIORITY_LOW}SELECT 1")) {
assertTrue(resultSet.next());
assertEquals(1L, resultSet.getLong(1));
assertFalse(resultSet.next());
}
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
assertEquals(Priority.PRIORITY_LOW, request.getRequestOptions().getPriority());
}
}
67 changes: 46 additions & 21 deletions src/test/java/com/google/cloud/spanner/jdbc/TagMockServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import com.google.cloud.spanner.Dialect;
Expand All @@ -40,16 +41,25 @@

@RunWith(Parameterized.class)
public class TagMockServerTest extends AbstractMockServerTest {
private static final String SELECT_RANDOM_SQL = SELECT_RANDOM_STATEMENT.getSql();

private static final String INSERT_SQL = INSERT_STATEMENT.getSql();

@Parameter public Dialect dialect;

private Dialect currentDialect;

@Parameters(name = "dialect = {0}")
public static Object[] data() {
return Dialect.values();
}

@Before
public void setupDialect() {
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
if (this.dialect != currentDialect) {
mockSpanner.putStatementResult(StatementResult.detectDialectResult(this.dialect));
this.currentDialect = dialect;
}
}

@After
Expand Down Expand Up @@ -77,8 +87,7 @@ public void testStatementTagForQuery() throws SQLException {
connection
.createStatement()
.execute(String.format("set %sstatement_tag='my-tag'", getVariablePrefix()));
try (ResultSet resultSet =
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
assertTrue(resultSet.next());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
Expand All @@ -87,8 +96,7 @@ public void testStatementTagForQuery() throws SQLException {

// Verify that the tag is cleared after having been used.
mockSpanner.clearRequests();
try (ResultSet resultSet =
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
assertTrue(resultSet.next());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
Expand All @@ -103,8 +111,7 @@ public void testTransactionTagForQuery() throws SQLException {
connection
.createStatement()
.execute(String.format("set %stransaction_tag='my-tag'", getVariablePrefix()));
try (ResultSet resultSet =
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
assertTrue(resultSet.next());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
Expand All @@ -113,8 +120,7 @@ public void testTransactionTagForQuery() throws SQLException {

// Verify that the tag is used for the entire transaction.
mockSpanner.clearRequests();
try (ResultSet resultSet =
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
assertTrue(resultSet.next());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
Expand All @@ -125,8 +131,7 @@ public void testTransactionTagForQuery() throws SQLException {
connection.commit();

mockSpanner.clearRequests();
try (ResultSet resultSet =
connection.createStatement().executeQuery(SELECT_RANDOM_STATEMENT.getSql())) {
try (ResultSet resultSet = connection.createStatement().executeQuery(SELECT_RANDOM_SQL)) {
assertTrue(resultSet.next());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
Expand All @@ -143,8 +148,8 @@ public void testStatementTagForBatchDml() throws SQLException {
.execute(String.format("set %sstatement_tag='my-tag'", getVariablePrefix()));

try (Statement statement = connection.createStatement()) {
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_SQL);
statement.addBatch(INSERT_SQL);
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
Expand All @@ -155,8 +160,8 @@ public void testStatementTagForBatchDml() throws SQLException {
// Verify that the tag is cleared after having been used.
mockSpanner.clearRequests();
try (Statement statement = connection.createStatement()) {
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_SQL);
statement.addBatch(INSERT_SQL);
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
Expand All @@ -173,8 +178,8 @@ public void testTransactionTagForBatchDml() throws SQLException {
.execute(String.format("set %stransaction_tag='my-tag'", getVariablePrefix()));

try (Statement statement = connection.createStatement()) {
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_SQL);
statement.addBatch(INSERT_SQL);
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
Expand All @@ -185,8 +190,8 @@ public void testTransactionTagForBatchDml() throws SQLException {
// Verify that the tag is used for the entire transaction.
mockSpanner.clearRequests();
try (Statement statement = connection.createStatement()) {
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_SQL);
statement.addBatch(INSERT_SQL);
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
Expand All @@ -197,13 +202,33 @@ public void testTransactionTagForBatchDml() throws SQLException {
connection.commit();
mockSpanner.clearRequests();
try (Statement statement = connection.createStatement()) {
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_STATEMENT.getSql());
statement.addBatch(INSERT_SQL);
statement.addBatch(INSERT_SQL);
assertArrayEquals(new int[] {1, 1}, statement.executeBatch());
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
request = mockSpanner.getRequestsOfType(ExecuteBatchDmlRequest.class).get(0);
assertEquals("", request.getRequestOptions().getTransactionTag());
}
}

@Test
public void testStatementTagInHint() throws SQLException {
try (Connection connection = createConnection()) {
try (ResultSet resultSet =
connection
.createStatement()
.executeQuery(
dialect == Dialect.POSTGRESQL
? "/*@statement_tag='test-tag'*/SELECT 1"
: "@{statement_tag='test-tag'}SELECT 1")) {
assertTrue(resultSet.next());
assertEquals(1L, resultSet.getLong(1));
assertFalse(resultSet.next());
}
}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
ExecuteSqlRequest request = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
assertEquals("test-tag", request.getRequestOptions().getRequestTag());
}
}

0 comments on commit 0c3aec1

Please sign in to comment.