Skip to content

Commit

Permalink
Adjust resolver code to make it easier to test in the future (#13445)
Browse files Browse the repository at this point in the history
Motivation:

DnsQueryContext had a dependency on DnsNameResolver which made it hard
to write tests for things like DnsQueryContextManager and
DnsQueryContext in isolation.

Modifications:

- Remove dependecy on DnsNameResolver
- Cleanup code

Result:

Easier to write self-contained tests and cleanup
  • Loading branch information
normanmaurer committed Jun 14, 2023
1 parent c59a11c commit 22bf43b
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,25 @@
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;

import java.net.InetSocketAddress;

final class DatagramDnsQueryContext extends DnsQueryContext {

DatagramDnsQueryContext(DnsNameResolver parent, InetSocketAddress nameServerAddr, DnsQuestion question,
DnsRecord[] additionals,
DatagramDnsQueryContext(Channel channel, Future<? extends Channel> channelReadyFuture,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(parent, nameServerAddr, question, additionals, promise);
super(channel, channelReadyFuture, queryContextManager, maxPayLoadSize, recursionDesired,
question, additionals, promise);
}

@Override
protected DnsQuery newQuery(int id) {
return new DatagramDnsQuery(null, nameServerAddr(), id);
}

@Override
protected Channel channel() {
return parent().ch;
protected DnsQuery newQuery(int id, InetSocketAddress nameServerAddr) {
return new DatagramDnsQuery(null, nameServerAddr, id);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.socket.DatagramChannel;
Expand Down Expand Up @@ -1262,7 +1261,7 @@ private InetSocketAddress nextNameServerAddress() {
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question) {

return query0(nameServerAddr, question, EMPTY_ADDITIONALS, true, ch.newPromise(),
return query0(nameServerAddr, question, NoopDnsQueryLifecycleObserver.INSTANCE, EMPTY_ADDITIONALS, true,
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
}

Expand All @@ -1272,8 +1271,9 @@ public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question, Iterable<DnsRecord> additionals) {

return query0(nameServerAddr, question, toArray(additionals, false), true, ch.newPromise(),
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
return query0(nameServerAddr, question, NoopDnsQueryLifecycleObserver.INSTANCE,
toArray(additionals, false), true,
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
}

/**
Expand All @@ -1283,7 +1283,8 @@ public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {

return query0(nameServerAddr, question, EMPTY_ADDITIONALS, true, ch.newPromise(), promise);
return query0(nameServerAddr, question, NoopDnsQueryLifecycleObserver.INSTANCE,
EMPTY_ADDITIONALS, true, promise);
}

/**
Expand All @@ -1294,7 +1295,8 @@ public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
Iterable<DnsRecord> additionals,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {

return query0(nameServerAddr, question, toArray(additionals, false), true, ch.newPromise(), promise);
return query0(nameServerAddr, question, NoopDnsQueryLifecycleObserver.INSTANCE,
toArray(additionals, false), true, promise);
}

/**
Expand All @@ -1321,17 +1323,19 @@ final void flushQueries() {

final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
InetSocketAddress nameServerAddr, DnsQuestion question,
final DnsQueryLifecycleObserver queryLifecycleObserver,
DnsRecord[] additionals,
boolean flush,
ChannelPromise writePromise,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
assert !writePromise.isVoid();

final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> castPromise = cast(
checkNotNull(promise, "promise"));
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
try {
new DatagramDnsQueryContext(this, nameServerAddr, question, additionals, castPromise)
.query(flush, writePromise);
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, queryContextManager,
payloadSize, isRecursionDesired(), question, additionals, castPromise);
ChannelFuture future = queryContext.writeQuery(nameServerAddr, queryTimeoutMillis(), flush);
queryLifecycleObserver.queryWritten(nameServerAddr, future);
return castPromise;
} catch (Exception e) {
return castPromise.setFailure(e);
Expand All @@ -1357,21 +1361,22 @@ private final class DnsResponseHandler extends ChannelInboundHandlerAdapter {

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
final Channel qCh = ctx.channel();
final DatagramDnsResponse res = (DatagramDnsResponse) msg;
final int queryId = res.id();
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
logger.debug("{} RECEIVED: UDP [{}: {}], {}", qCh, queryId, res.sender(), res);

final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
if (qCtx == null) {
logger.debug("{} Received a DNS response with an unknown ID: UDP [{}: {}]",
ch, queryId, res.sender());
qCh, queryId, res.sender());
res.release();
return;
}

// Check if the response was truncated and if we can fallback to TCP to retry.
if (!res.isTruncated() || socketChannelFactory == null) {
qCtx.finish(res);
qCtx.finishSuccess(qCh, res);
return;
}

Expand All @@ -1388,49 +1393,52 @@ public void operationComplete(ChannelFuture future) {
ch, queryId, res.sender(), future.cause());

// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
qCtx.finishSuccess(qCh, res);
return;
}
final Channel channel = future.channel();
final Channel tcpCh = future.channel();

Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
channel.eventLoop().newPromise();
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(DnsNameResolver.this, channel,
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
tcpCh.eventLoop().newPromise();
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyPromise,
queryContextManager, payloadSize, isRecursionDesired(), qCtx.question(),
EMPTY_ADDITIONALS, promise);

channel.pipeline().addLast(new TcpDnsResponseDecoder());
channel.pipeline().addLast(new ChannelInboundHandlerAdapter() {
tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
Channel channel = ctx.channel();
Channel tcpCh = ctx.channel();
DnsResponse response = (DnsResponse) msg;
int queryId = response.id();

if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
channel.remoteAddress(), response);
logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
tcpCh.remoteAddress(), response);
}

DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
if (foundCtx == tcpCtx) {
tcpCtx.finish(new AddressedEnvelopeAdapter(
tcpCtx.finishSuccess(tcpCh, new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
} else {
response.release();
tcpCtx.tryFailure("Received TCP DNS response with unexpected ID", null, false);
tcpCtx.finishFailure((InetSocketAddress) tcpCh.remoteAddress(),
"Received TCP DNS response with unexpected ID", null, false);
if (logger.isDebugEnabled()) {
logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
channel, queryId, channel.remoteAddress());
tcpCh, queryId, tcpCh.remoteAddress());
}
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (tcpCtx.tryFailure("TCP fallback error", cause, false) && logger.isDebugEnabled()) {
if (tcpCtx.finishFailure((InetSocketAddress) ctx.channel().remoteAddress(),
"TCP fallback error", cause, false) && logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
ctx.channel().remoteAddress(), cause);
Expand All @@ -1443,18 +1451,18 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
@Override
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
channel.close();

if (future.isSuccess()) {
qCtx.finish(future.getNow());
qCtx.finishSuccess(qCh, future.getNow());
res.release();
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
qCtx.finishSuccess(qCh, res);
}
tcpCh.close();
}
});
tcpCtx.query(true, future.channel().newPromise());
tcpCtx.writeQuery((InetSocketAddress) tcpCh.remoteAddress(), queryTimeoutMillis(),
true);
}
});
}
Expand Down

0 comments on commit 22bf43b

Please sign in to comment.