Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: separate flags for request/response E2E checksum and enable request checksum by default #1251

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions datastore-v1-proto-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.google.testparameterinjector</groupId>
<artifactId>test-parameter-injector</artifactId>
<version>1.14</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.google.truth</groupId>
<artifactId>truth</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ class RemoteRpc {
private final HttpRequestInitializer initializer;
private final String url;
private final AtomicInteger rpcCount = new AtomicInteger(0);
// Not final - so it can be set/reset in Unittests
private static boolean enableE2EChecksum =
Boolean.parseBoolean(System.getenv("GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_CHECKSUM"));
kolea2 marked this conversation as resolved.
Show resolved Hide resolved
private static final String E2E_REQUEST_CHECKSUM_FLAG =
"GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_REQUEST_CHECKSUM";
private static final String E2E_RESPONSE_CHECKSUM_FLAG =
"GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_RESPONSE_CHECKSUM";
// By default request checksum is enabled.
// Not final - so it can be set/reset in Unittests.
private static boolean enableE2ERequestChecksum =
System.getenv(E2E_REQUEST_CHECKSUM_FLAG) == null
|| Boolean.parseBoolean(System.getenv(E2E_REQUEST_CHECKSUM_FLAG));

private static boolean enableE2EResponseChecksum =
Boolean.parseBoolean(System.getenv(E2E_RESPONSE_CHECKSUM_FLAG));

RemoteRpc(HttpRequestFactory client, HttpRequestInitializer initializer, String url) {
this.client = client;
Expand Down Expand Up @@ -113,7 +122,7 @@ public InputStream call(
}
}
InputStream inputStream = httpResponse.getContent();
return enableE2EChecksum && EndToEndChecksumHandler.hasChecksumHeader(httpResponse)
return enableE2EResponseChecksum && EndToEndChecksumHandler.hasChecksumHeader(httpResponse)
? new ChecksumEnforcingInputStream(inputStream, httpResponse)
: inputStream;
} catch (SocketTimeoutException e) {
Expand All @@ -138,7 +147,7 @@ void setHeaders(
builder.append(databaseId);
}
httpRequest.getHeaders().put(X_GOOG_REQUEST_PARAMS_HEADER, builder.toString());
if (enableE2EChecksum && request != null) {
if (enableE2ERequestChecksum && request != null) {
String checksum = EndToEndChecksumHandler.computeChecksum(request.toByteArray());
if (checksum != null) {
httpRequest
Expand All @@ -154,8 +163,10 @@ HttpRequestFactory getClient() {
}

@VisibleForTesting
static void setSystemEnvE2EChecksum(boolean enableE2EChecksum) {
RemoteRpc.enableE2EChecksum = enableE2EChecksum;
static void setSystemEnvE2EChecksum(
boolean enableE2ERequestChecksum, boolean enableE2EResponseChecksum) {
RemoteRpc.enableE2ERequestChecksum = enableE2ERequestChecksum;
RemoteRpc.enableE2EResponseChecksum = enableE2EResponseChecksum;
}

void resetRpcCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
package com.google.datastore.v1.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;

import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpTransport;
Expand All @@ -31,17 +30,23 @@
import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import com.google.rpc.Status;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.zip.GZIPOutputStream;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test for {@link RemoteRpc}. */
@RunWith(JUnit4.class)
@RunWith(TestParameterInjector.class)
public class RemoteRpcTest {

private static final String METHOD_NAME = "methodName";
Expand Down Expand Up @@ -157,48 +162,88 @@ public void testGzip() throws IOException, DatastoreException {
}

@Test
public void testHttpHeaders_expectE2eChecksumHeader() throws IOException {
// Enable E2E-Checksum system env variable
RemoteRpc.setSystemEnvE2EChecksum(true);
public void testE2EChecksum(@TestParameter boolean reqEnabled, @TestParameter boolean respEnabled)
throws IOException, DatastoreException {
RemoteRpc.setSystemEnvE2EChecksum(reqEnabled, respEnabled);
String projectId = "project-id";
MessageLite request =
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build();
RemoteRpc rpc =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest =
rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request));
rpc.setHeaders(request, httpRequest, projectId, "");
assertNotNull(
httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER));
// Expect to find e2e-checksum header
String header =
httpRequest
.getHeaders()
.getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER);
assertEquals(9, header.length());

// Always return invalid response checksum to check that it will raise an exception only when
// response checksum verification is enabled.
List<MyHeader> respHeaders =
Collections.singletonList(
new MyHeader(
EndToEndChecksumHandler.HTTP_RESPONSE_CHECKSUM_HEADER, "invalid_checksum"));

Set<MyHeader> expectedRequestHeaders = new HashSet<>();
expectedRequestHeaders.add(MyHeader.AnyValue(RemoteRpc.API_FORMAT_VERSION_HEADER));

if (reqEnabled) {
expectedRequestHeaders.add(
new MyHeader(
EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER,
EndToEndChecksumHandler.computeChecksum(request.toByteArray())));
} else {
expectedRequestHeaders.add(
MyHeader.AnyValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER).mustNotExist());
}

InjectedTestValues testVals =
new InjectedTestValues(
gzip(newBeginTransactionResponse()),
new byte[1],
true,
respHeaders,
expectedRequestHeaders);
RemoteRpc rpc = newRemoteRpc(testVals);

InputStream stream = rpc.call("someMethod", request, projectId, "");
byte[] buf = new byte[1000];
if (respEnabled) {
// Must throw an IOException when verifying response checksum because we provided an invalid
// checksum in the response header.
assertThrows(
IOException.class,
() -> {
while (stream.read(buf, 0, 1000) != -1) {
// Do nothing with the bytes read.
}
});
} else {
// Must not raise an exception even with invalid response checksum because we did not enable
// response checksum verification.
while (stream.read(buf, 0, 1000) != -1) {
// Do nothing with the bytes read.
}
}
}

@Test
public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException {
// disable E2E-Checksum system env variable
RemoteRpc.setSystemEnvE2EChecksum(false);
public void testE2EChecksum_validResponseChecksum() throws IOException, DatastoreException {
RemoteRpc.setSystemEnvE2EChecksum(false, true);
String projectId = "project-id";
MessageLite request =
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build();
RemoteRpc rpc =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest =
rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request));
rpc.setHeaders(request, httpRequest, projectId, "");
assertNotNull(
httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER));
// Do not expect to find e2e-checksum header
assertNull(
httpRequest
.getHeaders()
.getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER));

BeginTransactionResponse response = newBeginTransactionResponse();

List<MyHeader> respHeaders =
Collections.singletonList(
new MyHeader(
EndToEndChecksumHandler.HTTP_RESPONSE_CHECKSUM_HEADER,
EndToEndChecksumHandler.computeChecksum(response.toByteArray())));

InjectedTestValues testVals =
new InjectedTestValues(gzip(response), new byte[1], true, respHeaders);
RemoteRpc rpc = newRemoteRpc(testVals);

InputStream stream = rpc.call("someMethod", request, projectId, "");
byte[] buf = new byte[1000];
// Must not raise an exception.
while (stream.read(buf, 0, 1000) != -1) {
// Do nothing with the bytes read.
}
}

@Test
Expand Down Expand Up @@ -258,12 +303,38 @@ private static class InjectedTestValues {
private final InputStream inputStream;
private final int contentLength;
private final boolean isGzip;
private final List<MyHeader> responseHeaders;
private final Set<MyHeader> expectedRequestHeaders;

public InjectedTestValues(byte[] messageBytes, byte[] additionalBytes, boolean isGzip) {
this(
messageBytes,
additionalBytes,
isGzip,
new ArrayList<MyHeader>(),
new HashSet<MyHeader>());
}

public InjectedTestValues(
byte[] messageBytes,
byte[] additionalBytes,
boolean isGzip,
List<MyHeader> responseHeaders) {
this(messageBytes, additionalBytes, isGzip, responseHeaders, new HashSet<MyHeader>());
}

public InjectedTestValues(
byte[] messageBytes,
byte[] additionalBytes,
boolean isGzip,
List<MyHeader> responseHeaders,
Set<MyHeader> expectedRequestHeaders) {
byte[] allBytes = concat(messageBytes, additionalBytes);
this.inputStream = new ByteArrayInputStream(allBytes);
this.contentLength = allBytes.length;
this.isGzip = isGzip;
this.responseHeaders = responseHeaders;
this.expectedRequestHeaders = expectedRequestHeaders;
}

private static byte[] concat(byte[] a, byte[] b) {
Expand All @@ -289,24 +360,103 @@ protected LowLevelHttpRequest buildRequest(String method, String url) throws IOE
}
}

private static class MyHeader {
private final String key;
private final String value;
private final boolean ignoreValue;
private boolean mustExist;

public static MyHeader AnyValue(String key) {
nimf marked this conversation as resolved.
Show resolved Hide resolved
return new MyHeader(key, "", true);
}

public MyHeader(String key, String value) {
this(key, value, false);
}

private MyHeader(String key, String value, boolean ignoreValue) {
this.key = key.toLowerCase();
this.value = value;
this.ignoreValue = ignoreValue;
this.mustExist = true;
}

public MyHeader mustNotExist() {
mustExist = false;
return this;
}

public boolean matches(MyHeader h) {
return key.equals(h.key) && (h.ignoreValue || ignoreValue || value.equals(h.value));
}

public String toString() {
String mustExistString = mustExist ? "" : "must not exist: ";
if (ignoreValue) {
return String.format("%s\"%s\": ANY", mustExistString, key);
}
return String.format("%s\"%s\": \"%s\"", mustExistString, key, value);
}
}

/**
* {@link LowLevelHttpRequest} that allows injection of the returned {@link LowLevelHttpResponse}.
*/
private static class MyLowLevelHttpRequest extends LowLevelHttpRequest {

private final InjectedTestValues injectedTestValues;

private final List<MyHeader> requestHeaders = new ArrayList<>();

public MyLowLevelHttpRequest(InjectedTestValues injectedTestValues) {
this.injectedTestValues = injectedTestValues;
}

@Override
public void addHeader(String name, String value) throws IOException {
// Do nothing.
requestHeaders.add(new MyHeader(name, value));
}

private void assertHeaders() {
if (injectedTestValues.expectedRequestHeaders.isEmpty()) {
return;
}

Set<MyHeader> mustExist = new HashSet<>();
List<MyHeader> mustNotExist = new ArrayList<>();
for (MyHeader header : injectedTestValues.expectedRequestHeaders) {
if (header.mustExist) {
mustExist.add(header);
} else {
mustNotExist.add(header);
}
}

for (MyHeader h : requestHeaders) {
mustExist.removeIf(expected -> expected.matches(h));
}

if (!mustExist.isEmpty()) {
throw new RuntimeException(
"These request headers were expected but missing:\n"
+ mustExist
+ "\nThese headers were seen:\n"
+ requestHeaders);
}

for (MyHeader notExpected : mustNotExist) {
for (MyHeader h : requestHeaders) {
if (h.matches(notExpected)) {
throw new RuntimeException(
"Expected header " + notExpected.toString() + " but found: " + h.toString());
}
}
}
}

@Override
public LowLevelHttpResponse execute() throws IOException {
assertHeaders();
return new MyLowLevelHttpResponse(injectedTestValues);
}
}
Expand Down Expand Up @@ -357,17 +507,17 @@ public String getReasonPhrase() throws IOException {

@Override
public int getHeaderCount() throws IOException {
return 0;
return injectedTestValues.responseHeaders.size();
}

@Override
public String getHeaderName(int index) throws IOException {
return null;
return injectedTestValues.responseHeaders.get(index).key;
}

@Override
public String getHeaderValue(int index) throws IOException {
return null;
return injectedTestValues.responseHeaders.get(index).value;
}
}
}