Skip to content

Commit

Permalink
Reuse Bootstrap to create SocketChannel for DNS TCP (#13453)
Browse files Browse the repository at this point in the history
Motivation:

There is no need to create a new Bootstrap each time we open a new
connection.

Modifications:

- Create the Bootstrap only once
- Tighten up some visibility

Result:

Produce less GC
  • Loading branch information
normanmaurer committed Jun 16, 2023
1 parent dd2c2bb commit 1bb825b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.List;

import io.netty.channel.Channel;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
Expand All @@ -33,25 +34,26 @@ final class DnsAddressResolveContext extends DnsResolveContext<InetAddress> {
private final AuthoritativeDnsServerCache authoritativeDnsServerCache;
private final boolean completeEarlyIfPossible;

DnsAddressResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
DnsAddressResolveContext(DnsNameResolver parent, Channel channel, Promise<?> originalPromise,
String hostname, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs, int allowedQueries, DnsCache resolveCache,
AuthoritativeDnsServerCache authoritativeDnsServerCache,
boolean completeEarlyIfPossible) {
super(parent, originalPromise, hostname, DnsRecord.CLASS_IN,
super(parent, channel, originalPromise, hostname, DnsRecord.CLASS_IN,
parent.resolveRecordTypes(), additionals, nameServerAddrs, allowedQueries);
this.resolveCache = resolveCache;
this.authoritativeDnsServerCache = authoritativeDnsServerCache;
this.completeEarlyIfPossible = completeEarlyIfPossible;
}

@Override
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, Channel channel,
Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs, int allowedQueries) {
return new DnsAddressResolveContext(parent, originalPromise, hostname, additionals, nameServerAddrs,
return new DnsAddressResolveContext(parent, channel, originalPromise, hostname, additionals, nameServerAddrs,
allowedQueries, resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible);
}

Expand Down Expand Up @@ -80,12 +82,12 @@ boolean isDuplicateAllowed() {
@Override
void cache(String hostname, DnsRecord[] additionals,
DnsRecord result, InetAddress convertedResult) {
resolveCache.cache(hostname, additionals, convertedResult, result.timeToLive(), parent.ch.eventLoop());
resolveCache.cache(hostname, additionals, convertedResult, result.timeToLive(), channel().eventLoop());
}

@Override
void cache(String hostname, DnsRecord[] additionals, UnknownHostException cause) {
resolveCache.cache(hostname, additionals, cause, parent.ch.eventLoop());
resolveCache.cache(hostname, additionals, cause, channel().eventLoop());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ public class DnsNameResolver extends InetNameResolver {
}
WINDOWS_HOST_NAME = hostName;
logger.debug("Windows hostname: {}", WINDOWS_HOST_NAME);
}

static {
String[] searchDomains;
try {
List<String> list = PlatformDependent.isWindows()
Expand Down Expand Up @@ -225,15 +223,15 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p
private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder();
private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();

final Promise<Channel> channelReadyPromise;
final Channel ch;
private final Promise<Channel> channelReadyPromise;
private final Channel ch;

// Comparator that ensures we will try first to use the nameservers that use our preferred address type.
private final Comparator<InetSocketAddress> nameServerComparator;
/**
* Manages the {@link DnsQueryContext}s in progress and their query IDs.
*/
final DnsQueryContextManager queryContextManager = new DnsQueryContextManager();
private final DnsQueryContextManager queryContextManager = new DnsQueryContextManager();

/**
* Cache for {@link #doResolve(String, Promise)} and {@link #doResolveAll(String, Promise)}.
Expand Down Expand Up @@ -268,7 +266,7 @@ protected DnsServerAddressStream initialValue() {
private final boolean decodeIdn;
private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory;
private final boolean completeOncePreferredResolved;
private final ChannelFactory<? extends SocketChannel> socketChannelFactory;
private final Bootstrap socketBootstrap;

private final int maxNumConsolidation;
private final Map<String, Future<List<InetAddress>>> inflightLookups;
Expand Down Expand Up @@ -451,7 +449,15 @@ public DnsNameResolver(
this.ndots = ndots >= 0 ? ndots : DEFAULT_OPTIONS.ndots();
this.decodeIdn = decodeIdn;
this.completeOncePreferredResolved = completeOncePreferredResolved;
this.socketChannelFactory = socketChannelFactory;
if (socketChannelFactory == null) {
socketBootstrap = null;
} else {
socketBootstrap = new Bootstrap();
socketBootstrap.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
.handler(TCP_ENCODER);
}
switch (this.resolvedAddressTypes) {
case IPV4_ONLY:
supportsAAAARecords = false;
Expand Down Expand Up @@ -919,8 +925,8 @@ private Future<List<DnsRecord>> resolveAll(DnsQuestion question, DnsRecord[] add
// It was not A/AAAA question or there was no entry in /etc/hosts.
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsRecordResolveContext(this, promise, question, additionals, nameServerAddrs, maxQueriesPerResolve)
.resolve(promise);
new DnsRecordResolveContext(this, ch, promise, question, additionals,
nameServerAddrs, maxQueriesPerResolve).resolve(promise);
return promise;
}

Expand Down Expand Up @@ -1213,8 +1219,8 @@ private void resolveNow(final String hostname,
final boolean completeEarlyIfPossible) {
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
DnsAddressResolveContext ctx = new DnsAddressResolveContext(this, originalPromise, hostname, additionals,
nameServerAddrs, maxQueriesPerResolve, resolveCache,
DnsAddressResolveContext ctx = new DnsAddressResolveContext(this, ch, originalPromise, hostname,
additionals, nameServerAddrs, maxQueriesPerResolve, resolveCache,
authoritativeDnsServerCache, completeEarlyIfPossible);
ctx.resolve(promise);
}
Expand Down Expand Up @@ -1380,17 +1386,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
}

// Check if the response was truncated and if we can fallback to TCP to retry.
if (!res.isTruncated() || socketChannelFactory == null) {
if (!res.isTruncated() || socketBootstrap == null) {
qCtx.finishSuccess(res);
return;
}

Bootstrap bs = new Bootstrap();
bs.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
.handler(TCP_ENCODER);
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
socketBootstrap.connect(res.sender()).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.net.UnknownHostException;
import java.util.List;

import io.netty.channel.Channel;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
Expand All @@ -27,29 +28,30 @@

final class DnsRecordResolveContext extends DnsResolveContext<DnsRecord> {

DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, DnsQuestion question,
DnsRecordResolveContext(DnsNameResolver parent, Channel channel, Promise<?> originalPromise, DnsQuestion question,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs, int allowedQueries) {
this(parent, originalPromise, question.name(), question.dnsClass(),
this(parent, channel, originalPromise, question.name(), question.dnsClass(),
new DnsRecordType[] { question.type() },
additionals, nameServerAddrs, allowedQueries);
}

private DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
private DnsRecordResolveContext(DnsNameResolver parent, Channel channel, Promise<?> originalPromise,
String hostname, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs,
int allowedQueries) {
super(parent, originalPromise, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs, allowedQueries);
super(parent, channel, originalPromise, hostname, dnsClass, expectedTypes,
additionals, nameServerAddrs, allowedQueries);
}

@Override
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, Channel channel, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs,
int allowedQueries) {
return new DnsRecordResolveContext(parent, originalPromise, hostname, dnsClass,
return new DnsRecordResolveContext(parent, channel, originalPromise, hostname, dnsClass,
expectedTypes, additionals, nameServerAddrs, allowedQueries);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
Expand Down Expand Up @@ -82,6 +83,7 @@ abstract class DnsResolveContext<T> {
DnsResolveContext.class, "tryToFinishResolve(..)");

final DnsNameResolver parent;
private final Channel channel;
private final Promise<?> originalPromise;
private final DnsServerAddressStream nameServerAddrs;
private final String hostname;
Expand All @@ -98,12 +100,13 @@ abstract class DnsResolveContext<T> {
private boolean triedCNAME;
private boolean completeEarly;

DnsResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
DnsResolveContext(DnsNameResolver parent, Channel channel, Promise<?> originalPromise,
String hostname, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs, int allowedQueries) {
assert expectedTypes.length > 0;

this.parent = parent;
this.channel = channel;
this.originalPromise = originalPromise;
this.hostname = hostname;
this.dnsClass = dnsClass;
Expand Down Expand Up @@ -147,6 +150,13 @@ static DnsResolveContextException newStatic(String message, Class<?> clazz, Stri
}
}

/**
* The {@link Channel} used.
*/
Channel channel() {
return channel;
}

/**
* The {@link DnsCache} to use while resolving.
*/
Expand All @@ -171,7 +181,8 @@ AuthoritativeDnsServerCache authoritativeDnsServerCache() {
/**
* Creates a new context with the given parameters.
*/
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, Channel channel,
Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
Expand Down Expand Up @@ -281,7 +292,7 @@ public Throwable fillInStackTrace() {
}

void doSearchDomainQuery(String hostname, Promise<List<T>> nextPromise) {
DnsResolveContext<T> nextContext = newResolverContext(parent, originalPromise, hostname, dnsClass,
DnsResolveContext<T> nextContext = newResolverContext(parent, channel, originalPromise, hostname, dnsClass,
expectedTypes, additionals, nameServerAddrs,
parent.maxQueriesPerResolve());
nextContext.internalResolve(hostname, nextPromise);
Expand Down Expand Up @@ -432,7 +443,7 @@ private void query(final DnsServerAddressStream nameServerAddrStream,
return;
}
final Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> queryPromise =
parent.ch.eventLoop().newPromise();
channel.eventLoop().newPromise();

final long queryStartTimeNanos;
final boolean isFeedbackAddressStream;
Expand Down Expand Up @@ -540,7 +551,7 @@ public void operationComplete(final Future<List<InetAddress>> future) {
if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache,
parent.resolvedInternetProtocolFamiliesUnsafe())) {

new DnsAddressResolveContext(parent, originalPromise, nameServerName, additionals,
new DnsAddressResolveContext(parent, channel, originalPromise, nameServerName, additionals,
parent.newNameServerAddressStream(nameServerName),
// Resolving the unresolved nameserver must be limited by allowedQueries
// so we eventually fail
Expand Down Expand Up @@ -843,7 +854,7 @@ private void onExpectedResponse(
if (logger.isDebugEnabled()) {
logger.debug("{} Ignoring record {} for [{}: {}] as it contains a different name than " +
"the question name [{}]. Cnames: {}, Search domains: {}",
parent.ch, r.toString(), response.id(), envelope.sender(), questionName, cnames,
channel, r.toString(), response.id(), envelope.sender(), questionName, cnames,
parent.searchDomains());
}
continue;
Expand All @@ -856,7 +867,7 @@ private void onExpectedResponse(
if (logger.isDebugEnabled()) {
logger.debug("{} Ignoring record {} for [{}: {}] as the converted record is null. "
+ "Hostname [{}], Additionals: {}",
parent.ch, r.toString(), response.id(), envelope.sender(), hostname, additionals);
channel, r.toString(), response.id(), envelope.sender(), hostname, additionals);
}
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3252,7 +3252,14 @@ protected DnsMessage filterMessage(DnsMessage message) {
DnsNameResolver resolver = null;
try {
DnsNameResolverBuilder builder = newResolver();

final DatagramChannel datagramChannel = new NioDatagramChannel();
ChannelFactory<DatagramChannel> channelFactory = new ChannelFactory<DatagramChannel>() {
@Override
public DatagramChannel newChannel() {
return datagramChannel;
}
};
builder.channelFactory(channelFactory);
if (tcpFallback) {
dnsServer2.start(null, (InetSocketAddress) serverSocket.getLocalSocketAddress());

Expand All @@ -3267,7 +3274,7 @@ protected DnsMessage filterMessage(DnsMessage message) {
.nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress()));
resolver = builder.build();
if (truncatedBecauseOfMtu) {
resolver.ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() {
datagramChannel.pipeline().addFirst(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof DatagramPacket) {
Expand Down

0 comments on commit 1bb825b

Please sign in to comment.