Skip to content

Commit

Permalink
Simplify DnsQueryContext usage again (#13450)
Browse files Browse the repository at this point in the history
Motivation:

22bf43b made some changes to simplify
DnsQueryContext but did introduce one change which might be a bit
error-prone.

Modifications:

Don't pass the nameserveraddress and the channel to the finish* and
writeQuery methods.

Result:

Always log / use the correct nameserveraddress / channel
  • Loading branch information
normanmaurer committed Jun 15, 2023
1 parent 22bf43b commit bf8e779
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
final class DatagramDnsQueryContext extends DnsQueryContext {

DatagramDnsQueryContext(Channel channel, Future<? extends Channel> channelReadyFuture,
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(channel, channelReadyFuture, queryContextManager, maxPayLoadSize, recursionDesired,
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
question, additionals, promise);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1332,9 +1332,9 @@ final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
checkNotNull(promise, "promise"));
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
try {
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, queryContextManager,
payloadSize, isRecursionDesired(), question, additionals, castPromise);
ChannelFuture future = queryContext.writeQuery(nameServerAddr, queryTimeoutMillis(), flush);
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, nameServerAddr,
queryContextManager, payloadSize, isRecursionDesired(), question, additionals, castPromise);
ChannelFuture future = queryContext.writeQuery(queryTimeoutMillis(), flush);
queryLifecycleObserver.queryWritten(nameServerAddr, future);
return castPromise;
} catch (Exception e) {
Expand Down Expand Up @@ -1376,7 +1376,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

// Check if the response was truncated and if we can fallback to TCP to retry.
if (!res.isTruncated() || socketChannelFactory == null) {
qCtx.finishSuccess(qCh, res);
qCtx.finishSuccess(res);
return;
}

Expand All @@ -1393,7 +1393,7 @@ public void operationComplete(ChannelFuture future) {
ch, queryId, res.sender(), future.cause());

// TCP fallback failed, just use the truncated response.
qCtx.finishSuccess(qCh, res);
qCtx.finishSuccess(res);
return;
}
final Channel tcpCh = future.channel();
Expand All @@ -1402,8 +1402,8 @@ public void operationComplete(ChannelFuture future) {
tcpCh.eventLoop().newPromise();
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyPromise,
queryContextManager, payloadSize, isRecursionDesired(), qCtx.question(),
EMPTY_ADDITIONALS, promise);
(InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, payloadSize,
isRecursionDesired(), qCtx.question(), EMPTY_ADDITIONALS, promise);

tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
Expand All @@ -1420,14 +1420,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
if (foundCtx == tcpCtx) {
tcpCtx.finishSuccess(tcpCh, new AddressedEnvelopeAdapter(
tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
} else {
response.release();
tcpCtx.finishFailure((InetSocketAddress) tcpCh.remoteAddress(),
"Received TCP DNS response with unexpected ID", null, false);
tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
if (logger.isDebugEnabled()) {
logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
tcpCh, queryId, tcpCh.remoteAddress());
Expand All @@ -1437,7 +1436,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (tcpCtx.finishFailure((InetSocketAddress) ctx.channel().remoteAddress(),
if (tcpCtx.finishFailure(
"TCP fallback error", cause, false) && logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
Expand All @@ -1452,17 +1451,16 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
if (future.isSuccess()) {
qCtx.finishSuccess(qCh, future.getNow());
qCtx.finishSuccess(future.getNow());
res.release();
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finishSuccess(qCh, res);
qCtx.finishSuccess(res);
}
tcpCh.close();
}
});
tcpCtx.writeQuery((InetSocketAddress) tcpCh.remoteAddress(), queryTimeoutMillis(),
true);
tcpCtx.writeQuery(queryTimeoutMillis(), true);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ abstract class DnsQueryContext {

private final Future<? extends Channel> channelReadyFuture;
private final Channel channel;
private final InetSocketAddress nameServerAddr;
private final DnsQueryContextManager queryContextManager;
private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;

Expand All @@ -60,6 +61,7 @@ abstract class DnsQueryContext {

DnsQueryContext(Channel channel,
Future<? extends Channel> channelReadyFuture,
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize,
boolean recursionDesired,
Expand All @@ -69,6 +71,7 @@ abstract class DnsQueryContext {
this.channel = checkNotNull(channel, "channel");
this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
this.channelReadyFuture = checkNotNull(channelReadyFuture, "channelReadyFuture");
this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
this.question = checkNotNull(question, "question");
this.additionals = checkNotNull(additionals, "additionals");
this.promise = checkNotNull(promise, "promise");
Expand Down Expand Up @@ -126,15 +129,13 @@ final DnsQuestion question() {
/**
* Write the query and return the {@link ChannelFuture} that is completed once the write completes.
*
* @param nameServerAddr the nameserver to write the query to.
* @param queryTimeoutMillis the timeout after which the query is considered timeout and the original
* {@link Promise} will be failed.
* @param flush {@code true} if {@link Channel#flush()} should be called as well.
* @return
* @return the {@link ChannelFuture} that is notified once once the write completes.
*/
final ChannelFuture writeQuery(final InetSocketAddress nameServerAddr, long queryTimeoutMillis,
boolean flush) {
assert id == -1 : this.getClass().getSimpleName() + ".writeQuery(...) + can only be executed once.";
final ChannelFuture writeQuery(long queryTimeoutMillis, boolean flush) {
assert id == -1 : this.getClass().getSimpleName() + ".writeQuery(...) can only be executed once.";
id = queryContextManager.add(nameServerAddr, this);

// Ensure we remove the id from the QueryContextManager once the query completes.
Expand Down Expand Up @@ -223,36 +224,35 @@ private void writeQuery(final InetSocketAddress nameServerAddr, final DnsQuery q
final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
channel.write(query, promise);
if (writeFuture.isDone()) {
onQueryWriteCompletion(nameServerAddr, queryTimeoutMillis, writeFuture);
onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
} else {
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
onQueryWriteCompletion(nameServerAddr, queryTimeoutMillis, writeFuture);
onQueryWriteCompletion(queryTimeoutMillis, writeFuture);
}
});
}
}

private void onQueryWriteCompletion(final InetSocketAddress nameServerAddr, final long queryTimeoutMillis,
private void onQueryWriteCompletion(final long queryTimeoutMillis,
ChannelFuture writeFuture) {
if (!writeFuture.isSuccess()) {
finishFailure(nameServerAddr,
"failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
finishFailure("failed to send a query '" + id + "' via " + protocol(), writeFuture.cause(), false);
return;
}

// Schedule a query timeout task if necessary.
if (queryTimeoutMillis > 0) {
timeoutFuture = writeFuture.channel().eventLoop().schedule(new Runnable() {
timeoutFuture = channel.eventLoop().schedule(new Runnable() {
@Override
public void run() {
if (promise.isDone()) {
// Received a response before the query times out.
return;
}

finishFailure(nameServerAddr, "query '" + id + "' via " + protocol() + " timed out after " +
finishFailure("query '" + id + "' via " + protocol() + " timed out after " +
queryTimeoutMillis + " milliseconds", null, true);
}
}, queryTimeoutMillis, TimeUnit.MILLISECONDS);
Expand All @@ -263,7 +263,7 @@ public void run() {
* Notifies the original {@link Promise} that the response for the query was received.
* This method takes ownership of passed {@link AddressedEnvelope}.
*/
void finishSuccess(Channel channel, AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
final DnsResponse res = envelope.content();
if (res.count(DnsSection.QUESTION) != 1) {
logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
Expand All @@ -285,7 +285,7 @@ private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAd
/**
* Notifies the original {@link Promise} that the query completes because of an failure.
*/
final boolean finishFailure(InetSocketAddress nameServerAddr, String message, Throwable cause, boolean timeout) {
final boolean finishFailure(String message, Throwable cause, boolean timeout) {
if (promise.isDone()) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
final class TcpDnsQueryContext extends DnsQueryContext {

TcpDnsQueryContext(Channel channel, Future<? extends Channel> channelReadyFuture,
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(channel, channelReadyFuture, queryContextManager, maxPayLoadSize, recursionDesired,
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
question, additionals, promise);
}

Expand Down

0 comments on commit bf8e779

Please sign in to comment.