Skip to content

Commit

Permalink
Revert "Send cluster name and discovery node in handshake (#48916)" (#…
Browse files Browse the repository at this point in the history
…50944)

This reverts commit 0645ee8.
  • Loading branch information
Tim-Brooks committed Jan 14, 2020
1 parent 4974f56 commit d8510be
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Booleans;
import org.elasticsearch.common.Strings;
Expand Down Expand Up @@ -152,7 +151,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
BigArrays bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS);

this.outboundHandler = new OutboundHandler(nodeName, version, features, threadPool, bigArrays);
this.handshaker = new TransportHandshaker(ClusterName.CLUSTER_NAME_SETTING.get(settings), version, threadPool,
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true),
Expand All @@ -168,11 +167,6 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
protected void doStart() {
}

@Override
public void setLocalNode(DiscoveryNode localNode) {
handshaker.setLocalNode(localNode);
}

@Override
public synchronized void setMessageListener(TransportMessageListener listener) {
outboundHandler.setMessageListener(listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ public interface Transport extends LifecycleComponent {

void setMessageListener(TransportMessageListener listener);

void setLocalNode(DiscoveryNode localNode);

/**
* The address the transport is bound on.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
Expand All @@ -47,26 +46,19 @@ final class TransportHandshaker {
private final ConcurrentMap<Long, HandshakeResponseHandler> pendingHandshakes = new ConcurrentHashMap<>();
private final CounterMetric numHandshakes = new CounterMetric();

private final ClusterName clusterName;
private final Version version;
private final ThreadPool threadPool;
private final HandshakeRequestSender handshakeRequestSender;
private final HandshakeResponseSender handshakeResponseSender;
private volatile DiscoveryNode localNode;

TransportHandshaker(ClusterName clusterName, Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
HandshakeResponseSender handshakeResponseSender) {
this.clusterName = clusterName;
this.version = version;
this.threadPool = threadPool;
this.handshakeRequestSender = handshakeRequestSender;
this.handshakeResponseSender = handshakeResponseSender;
}

void setLocalNode(DiscoveryNode localNode) {
this.localNode = localNode;
}

void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
numHandshakes.inc();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, version, listener);
Expand Down Expand Up @@ -97,17 +89,14 @@ void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeV
}

void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
// The TransportService blocks incoming requests until this has been set.
assert localNode != null : "Local node must be set before handshake is handled";

// Must read the handshake request to exhaust the stream
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
final int nextByte = stream.read();
if (nextByte != -1) {
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
}
HandshakeResponse response = new HandshakeResponse(handshakeRequest.version, this.version, this.clusterName, this.localNode);
HandshakeResponse response = new HandshakeResponse(this.version);
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
}

Expand Down Expand Up @@ -138,13 +127,13 @@ private HandshakeResponseHandler(long requestId, Version currentVersion, ActionL

@Override
public HandshakeResponse read(StreamInput in) throws IOException {
return new HandshakeResponse(this.currentVersion, in);
return new HandshakeResponse(in);
}

@Override
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
Version version = response.version;
Version version = response.responseVersion;
if (currentVersion.isCompatible(version) == false) {
listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]"));
Expand Down Expand Up @@ -212,58 +201,25 @@ public void writeTo(StreamOutput streamOutput) throws IOException {

static final class HandshakeResponse extends TransportResponse {

private final Version requestVersion;
private final Version version;
private final ClusterName clusterName;
private final DiscoveryNode discoveryNode;

HandshakeResponse(Version requestVersion, Version responseVersion, ClusterName clusterName, DiscoveryNode discoveryNode) {
this.requestVersion = requestVersion;
this.version = responseVersion;
this.clusterName = clusterName;
this.discoveryNode = discoveryNode;
private final Version responseVersion;

HandshakeResponse(Version responseVersion) {
this.responseVersion = responseVersion;
}

private HandshakeResponse(Version requestVersion, StreamInput in) throws IOException {
private HandshakeResponse(StreamInput in) throws IOException {
super(in);
this.requestVersion = requestVersion;
version = Version.readVersion(in);
// During the handshake process, nodes set their stream version to the minimum compatibility
// version they support. When deserializing the response, we use the version the other node
// told us that it actually is in the handshake response (`version`).
if (requestVersion.onOrAfter(Version.V_7_6_0) && version.onOrAfter(Version.V_7_6_0)) {
clusterName = new ClusterName(in);
discoveryNode = new DiscoveryNode(in);
} else {
clusterName = null;
discoveryNode = null;
}
responseVersion = Version.readVersion(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
assert version != null;
Version.writeVersion(version, out);
// During the handshake process, nodes set their stream version to the minimum compatibility
// version they support. When deciding what response to send, we use the version the other node
// told us that it actually is in the handshake request (`requestVersion`). If it did not tell
// us a `requestVersion`, it is at least a pre-7.6 node.
if (requestVersion != null && requestVersion.onOrAfter(Version.V_7_6_0) && version.onOrAfter(Version.V_7_6_0)) {
clusterName.writeTo(out);
discoveryNode.writeTo(out);
}
}

Version getVersion() {
return version;
}

ClusterName getClusterName() {
return clusterName;
assert responseVersion != null;
Version.writeVersion(responseVersion, out);
}

DiscoveryNode getDiscoveryNode() {
return discoveryNode;
Version getResponseVersion() {
return responseVersion;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ protected void doStart() {
}
}
localNode = localNodeFactory.apply(transport.boundAddress());
transport.setLocalNode(localNode);

if (connectToRemoteCluster) {
// here we start to connect to the remote clusters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,4 @@ public RequestHandlerRegistry getRequestHandler(String action) {
public void setMessageListener(TransportMessageListener listener) {
this.listener = listener;
}

@Override
public void setLocalNode(DiscoveryNode localNode) {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,6 @@ public RequestHandlerRegistry getRequestHandler(String action) {
public void setMessageListener(TransportMessageListener listener) {
}

@Override
public void setLocalNode(DiscoveryNode localNode) {
}

@Override
public BoundTransportAddress boundAddress() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
Expand Down Expand Up @@ -59,8 +58,9 @@ public void setUp() throws Exception {
channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
InboundMessage.Reader reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
TransportHandshaker handshaker = new TransportHandshaker(new ClusterName("cluster-name"), version, threadPool, (n, c, r, v) -> {
}, (v, f, c, r, r_id) -> {});
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {
}, (v, f, c, r, r_id) -> {
});
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
OutboundHandler outboundHandler =
new OutboundHandler("node", version, new String[0], threadPool, BigArrays.NON_RECYCLING_INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@

import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

Expand All @@ -46,29 +42,23 @@
public class TransportHandshakerTests extends ESTestCase {

private TransportHandshaker handshaker;
private DiscoveryNode remoteNode;
private DiscoveryNode node;
private TcpChannel channel;
private TestThreadPool threadPool;
private TransportHandshaker.HandshakeRequestSender requestSender;
private TransportHandshaker.HandshakeResponseSender responseSender;
private ClusterName clusterName;
private DiscoveryNode localNode;

@Override
@SuppressForbidden(reason = "Allow accessing localhost")
public void setUp() throws Exception {
super.setUp();
String nodeId = "remote-node-id";
String nodeId = "node-id";
channel = mock(TcpChannel.class);
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
remoteNode = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(),
Collections.emptyMap(), Collections.emptySet(), Version.CURRENT);
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
Collections.emptySet(), Version.CURRENT);
threadPool = new TestThreadPool("thread-poll");
clusterName = new ClusterName("cluster");
localNode = new DiscoveryNode("local-node-id", new TransportAddress(InetAddress.getLocalHost(), 0), Version.CURRENT);
handshaker = new TransportHandshaker(clusterName, Version.CURRENT, threadPool, requestSender, responseSender);
handshaker.setLocalNode(localNode);
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
}

@Override
Expand All @@ -80,9 +70,9 @@ public void tearDown() throws Exception {
public void testHandshakeRequestAndResponse() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);

verify(requestSender).sendRequest(remoteNode, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());

assertFalse(versionFuture.isDone());

Expand All @@ -98,39 +88,18 @@ public void testHandshakeRequestAndResponse() throws IOException {
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));


TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());

assertTrue(versionFuture.isDone());
assertEquals(Version.CURRENT, versionFuture.actionGet());
TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
assertEquals(Version.CURRENT, response.getVersion());
assertEquals(clusterName, response.getClusterName());
assertEquals(localNode, response.getDiscoveryNode());
}

public void testHandshakeRequestAndResponsePreV7_6() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);

TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
try (BytesStreamOutput out = new BytesStreamOutput()) {
new TransportHandshaker.HandshakeResponse(Version.V_7_5_0, Version.V_7_5_0, clusterName, localNode).writeTo(out);
TransportHandshaker.HandshakeResponse response = handler.read(out.bytes().streamInput());
assertEquals(Version.V_7_5_0, response.getVersion());
// When writing or reading a 6.6 stream, these are not serialized
assertNull(response.getDiscoveryNode());
assertNull(response.getClusterName());
}
}

public void testHandshakeRequestFutureVersionsCompatibility() throws IOException {
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(30, TimeUnit.SECONDS), PlainActionFuture.newFuture());
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), PlainActionFuture.newFuture());

verify(requestSender).sendRequest(remoteNode, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());

TcpChannel mockChannel = mock(TcpChannel.class);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
Expand Down Expand Up @@ -162,15 +131,15 @@ public void testHandshakeRequestFutureVersionsCompatibility() throws IOException

TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();

assertEquals(Version.CURRENT, response.getVersion());
assertEquals(Version.CURRENT, response.getResponseVersion());
}

public void testHandshakeError() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);

verify(requestSender).sendRequest(remoteNode, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());

assertFalse(versionFuture.isDone());

Expand All @@ -186,9 +155,9 @@ public void testSendRequestThrowsException() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
Version compatibilityVersion = Version.CURRENT.minimumCompatibilityVersion();
doThrow(new IOException("boom")).when(requestSender).sendRequest(remoteNode, channel, reqId, compatibilityVersion);
doThrow(new IOException("boom")).when(requestSender).sendRequest(node, channel, reqId, compatibilityVersion);

handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);

assertTrue(versionFuture.isDone());
ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet);
Expand All @@ -199,9 +168,9 @@ public void testSendRequestThrowsException() throws IOException {
public void testHandshakeTimeout() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, remoteNode, channel, new TimeValue(100, TimeUnit.MILLISECONDS), versionFuture);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(100, TimeUnit.MILLISECONDS), versionFuture);

verify(requestSender).sendRequest(remoteNode, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());

ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet);
assertThat(cte.getMessage(), containsString("handshake_timeout"));
Expand Down

0 comments on commit d8510be

Please sign in to comment.