Skip to content

Commit

Permalink
[Proxy/Client] Fix DNS server denial-of-service issue when DNS entry …
Browse files Browse the repository at this point in the history
…expires (apache#15403)

- DnsNameResolver doesn't coordinate concurrency and this leads to DNS server DoS
  under high load
- In Netty, DnsAddressResolverGroup internally uses internal InflightNameResolver
  class to address the problem
  - The solution is to use DnsAddressResolverGroup instead of instantiating DnsNameResolver
    directly
  • Loading branch information
lhotari committed May 1, 2022
1 parent c0bf8f0 commit fe78908
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import com.google.common.collect.Lists;
import io.netty.channel.EventLoopGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import org.apache.pulsar.broker.auth.MockedPulsarServiceBaseTest;
import org.apache.pulsar.client.impl.conf.ClientConfigurationData;
import org.apache.pulsar.common.util.netty.EventLoopUtil;
Expand All @@ -30,21 +34,17 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;

public class ConnectionPoolTest extends MockedPulsarServiceBaseTest {

String serviceUrl;
int brokerPort;

@BeforeClass
@Override
protected void setup() throws Exception {
super.internalSetup();
serviceUrl = "pulsar://non-existing-dns-name:" + pulsar.getBrokerListenPort().get();
brokerPort = pulsar.getBrokerListenPort().get();
serviceUrl = "pulsar://non-existing-dns-name:" + brokerPort;
}

@AfterClass(alwaysRun = true)
Expand All @@ -61,9 +61,11 @@ public void testSingleIpAddress() throws Exception {
conf.setServiceUrl(serviceUrl);
PulsarClientImpl client = new PulsarClientImpl(conf, eventLoop, pool);

List<InetAddress> result = Lists.newArrayList();
result.add(InetAddress.getByName("127.0.0.1"));
Mockito.when(pool.resolveName("non-existing-dns-name")).thenReturn(CompletableFuture.completedFuture(result));
List<InetSocketAddress> result = Lists.newArrayList();
result.add(new InetSocketAddress("127.0.0.1", brokerPort));
Mockito.when(pool.resolveName(InetSocketAddress.createUnresolved("non-existing-dns-name",
brokerPort)))
.thenReturn(CompletableFuture.completedFuture(result));

client.newProducer().topic("persistent://sample/standalone/ns/my-topic").create();

Expand All @@ -73,20 +75,20 @@ public void testSingleIpAddress() throws Exception {

@Test
public void testDoubleIpAddress() throws Exception {
String serviceUrl = "pulsar://non-existing-dns-name:" + pulsar.getBrokerListenPort().get();

ClientConfigurationData conf = new ClientConfigurationData();
EventLoopGroup eventLoop = EventLoopUtil.newEventLoopGroup(1, new DefaultThreadFactory("test"));
ConnectionPool pool = Mockito.spy(new ConnectionPool(conf, eventLoop));
conf.setServiceUrl(serviceUrl);
PulsarClientImpl client = new PulsarClientImpl(conf, eventLoop, pool);

List<InetAddress> result = Lists.newArrayList();
List<InetSocketAddress> result = Lists.newArrayList();

// Add a non existent IP to the response to check that we're trying the 2nd address as well
result.add(InetAddress.getByName("127.0.0.99"));
result.add(InetAddress.getByName("127.0.0.1"));
Mockito.when(pool.resolveName("non-existing-dns-name")).thenReturn(CompletableFuture.completedFuture(result));
result.add(new InetSocketAddress("127.0.0.99", brokerPort));
result.add(new InetSocketAddress("127.0.0.1", brokerPort));
Mockito.when(pool.resolveName(InetSocketAddress.createUnresolved("non-existing-dns-name",
brokerPort)))
.thenReturn(CompletableFuture.completedFuture(result));

// Create producer should succeed by trying the 2nd IP
client.newProducer().topic("persistent://sample/standalone/ns/my-topic").create();
Expand All @@ -103,7 +105,7 @@ public void testNoConnectionPool() throws Exception {
ConnectionPool pool = Mockito.spy(new ConnectionPool(conf, eventLoop));

InetSocketAddress brokerAddress =
InetSocketAddress.createUnresolved("127.0.0.1", pulsar.getBrokerListenPort().get());
InetSocketAddress.createUnresolved("127.0.0.1", brokerPort);
IntStream.range(1, 5).forEach(i -> {
pool.getConnection(brokerAddress).thenAccept(cnx -> {
Assert.assertTrue(cnx.channel().isActive());
Expand All @@ -125,7 +127,7 @@ public void testEnableConnectionPool() throws Exception {
ConnectionPool pool = Mockito.spy(new ConnectionPool(conf, eventLoop));

InetSocketAddress brokerAddress =
InetSocketAddress.createUnresolved("127.0.0.1", pulsar.getBrokerListenPort().get());
InetSocketAddress.createUnresolved("127.0.0.1", brokerPort);
IntStream.range(1, 10).forEach(i -> {
pool.getConnection(brokerAddress).thenAccept(cnx -> {
Assert.assertTrue(cnx.channel().isActive());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.AddressResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.Future;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -62,7 +62,7 @@ public class ConnectionPool implements AutoCloseable {
private final int maxConnectionsPerHosts;
private final boolean isSniProxy;

protected final DnsNameResolver dnsResolver;
protected final AddressResolver<InetSocketAddress> addressResolver;
private final boolean shouldCloseDnsResolver;

public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGroup) throws PulsarClientException {
Expand All @@ -75,7 +75,8 @@ public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGrou
}

public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGroup,
Supplier<ClientCnx> clientCnxSupplier, Optional<DnsNameResolver> dnsNameResolver)
Supplier<ClientCnx> clientCnxSupplier,
Optional<AddressResolver<InetSocketAddress>> addressResolver)
throws PulsarClientException {
this.eventLoopGroup = eventLoopGroup;
this.clientConfig = conf;
Expand All @@ -100,15 +101,19 @@ public ConnectionPool(ClientConfigurationData conf, EventLoopGroup eventLoopGrou
throw new PulsarClientException(e);
}

this.shouldCloseDnsResolver = !dnsNameResolver.isPresent();
this.dnsResolver = dnsNameResolver.orElseGet(() -> createDnsNameResolver(conf, eventLoopGroup));
this.shouldCloseDnsResolver = !addressResolver.isPresent();
this.addressResolver = addressResolver.orElseGet(() -> createAddressResolver(conf, eventLoopGroup));
}

private static DnsNameResolver createDnsNameResolver(ClientConfigurationData conf, EventLoopGroup eventLoopGroup) {
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(eventLoopGroup.next())
private static AddressResolver<InetSocketAddress> createAddressResolver(ClientConfigurationData conf,
EventLoopGroup eventLoopGroup) {
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder()
.traceEnabled(true).channelType(EventLoopUtil.getDatagramChannelClass(eventLoopGroup));
DnsResolverUtil.applyJdkDnsCacheSettings(dnsNameResolverBuilder);
return dnsNameResolverBuilder.build();
// use DnsAddressResolverGroup to create the AddressResolver since it contains a solution
// to prevent cache stampede / thundering herds problem when a DNS entry expires while the system
// is under high load
return new DnsAddressResolverGroup(dnsNameResolverBuilder).getResolver(eventLoopGroup.next());
}

private static final Random random = new Random();
Expand Down Expand Up @@ -235,19 +240,17 @@ private CompletableFuture<ClientCnx> createConnection(InetSocketAddress logicalA
* Resolve DNS asynchronously and attempt to connect to any IP address returned by DNS server.
*/
private CompletableFuture<Channel> createConnection(InetSocketAddress unresolvedAddress) {
int port;
CompletableFuture<List<InetAddress>> resolvedAddress = null;
CompletableFuture<List<InetSocketAddress>> resolvedAddress;
try {
if (isSniProxy) {
URI proxyURI = new URI(clientConfig.getProxyServiceUrl());
port = proxyURI.getPort();
resolvedAddress = resolveName(proxyURI.getHost());
resolvedAddress =
resolveName(InetSocketAddress.createUnresolved(proxyURI.getHost(), proxyURI.getPort()));
} else {
port = unresolvedAddress.getPort();
resolvedAddress = resolveName(unresolvedAddress.getHostString());
resolvedAddress = resolveName(unresolvedAddress);
}
return resolvedAddress.thenCompose(
inetAddresses -> connectToResolvedAddresses(inetAddresses.iterator(), port,
inetAddresses -> connectToResolvedAddresses(inetAddresses.iterator(),
isSniProxy ? unresolvedAddress : null));
} catch (URISyntaxException e) {
log.error("Invalid Proxy url {}", clientConfig.getProxyServiceUrl(), e);
Expand All @@ -260,18 +263,17 @@ private CompletableFuture<Channel> createConnection(InetSocketAddress unresolved
* Try to connect to a sequence of IP addresses until a successful connection can be made, or fail if no
* address is working.
*/
private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetAddress> unresolvedAddresses,
int port,
private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetSocketAddress> unresolvedAddresses,
InetSocketAddress sniHost) {
CompletableFuture<Channel> future = new CompletableFuture<>();

// Successfully connected to server
connectToAddress(unresolvedAddresses.next(), port, sniHost)
connectToAddress(unresolvedAddresses.next(), sniHost)
.thenAccept(future::complete)
.exceptionally(exception -> {
if (unresolvedAddresses.hasNext()) {
// Try next IP address
connectToResolvedAddresses(unresolvedAddresses, port, sniHost).thenAccept(future::complete)
connectToResolvedAddresses(unresolvedAddresses, sniHost).thenAccept(future::complete)
.exceptionally(ex -> {
// This is already unwinding the recursive call
future.completeExceptionally(ex);
Expand All @@ -287,10 +289,9 @@ private CompletableFuture<Channel> connectToResolvedAddresses(Iterator<InetAddre
return future;
}

@VisibleForTesting
CompletableFuture<List<InetAddress>> resolveName(String hostname) {
CompletableFuture<List<InetAddress>> future = new CompletableFuture<>();
dnsResolver.resolveAll(hostname).addListener((Future<List<InetAddress>> resolveFuture) -> {
CompletableFuture<List<InetSocketAddress>> resolveName(InetSocketAddress unresolvedAddress) {
CompletableFuture<List<InetSocketAddress>> future = new CompletableFuture<>();
addressResolver.resolveAll(unresolvedAddress).addListener((Future<List<InetSocketAddress>> resolveFuture) -> {
if (resolveFuture.isSuccess()) {
future.complete(resolveFuture.get());
} else {
Expand All @@ -303,8 +304,7 @@ CompletableFuture<List<InetAddress>> resolveName(String hostname) {
/**
* Attempt to establish a TCP connection to an already resolved single IP address.
*/
private CompletableFuture<Channel> connectToAddress(InetAddress ipAddress, int port, InetSocketAddress sniHost) {
InetSocketAddress remoteAddress = new InetSocketAddress(ipAddress, port);
private CompletableFuture<Channel> connectToAddress(InetSocketAddress remoteAddress, InetSocketAddress sniHost) {
if (clientConfig.isUseTls()) {
return toCompletableFuture(bootstrap.register())
.thenCompose(channel -> channelInitializerHandler
Expand All @@ -331,7 +331,7 @@ public void releaseConnection(ClientCnx cnx) {
public void close() throws Exception {
closeAllConnections();
if (shouldCloseDnsResolver) {
dnsResolver.close();
addressResolver.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import static com.google.common.base.Preconditions.checkState;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collections;
Expand Down Expand Up @@ -82,7 +82,7 @@ public class ProxyConnection extends PulsarHandler {
private final AtomicLong requestIdGenerator =
new AtomicLong(ThreadLocalRandom.current().nextLong(0, Long.MAX_VALUE / 2));
private final ProxyService service;
private final DnsNameResolver dnsNameResolver;
private final DnsAddressResolverGroup dnsAddressResolverGroup;
private Authentication clientAuthentication;
AuthenticationDataSource authenticationData;
private State state;
Expand Down Expand Up @@ -134,10 +134,10 @@ ConnectionPool getConnectionPool() {
}

public ProxyConnection(ProxyService proxyService, Supplier<SslHandler> sslHandlerSupplier,
DnsNameResolver dnsNameResolver) {
DnsAddressResolverGroup dnsAddressResolverGroup) {
super(30, TimeUnit.SECONDS);
this.service = proxyService;
this.dnsNameResolver = dnsNameResolver;
this.dnsAddressResolverGroup = dnsAddressResolverGroup;
this.state = State.Init;
this.sslHandlerSupplier = sslHandlerSupplier;
this.brokerProxyValidator = service.getBrokerProxyValidator();
Expand Down Expand Up @@ -280,7 +280,8 @@ private synchronized void completeConnect(AuthData clientData) throws PulsarClie

if (this.connectionPool == null) {
this.connectionPool = new ConnectionPool(clientConf, service.getWorkerGroup(),
clientCnxSupplier, Optional.of(dnsNameResolver));
clientCnxSupplier,
Optional.of(dnsAddressResolverGroup.getResolver(service.getWorkerGroup().next())));
} else {
LOG.error("BUG! Connection Pool has already been created for proxy connection to {} state {} role {}",
remoteAddress, state, clientAuthRole);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsAddressResolverGroup;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.prometheus.client.Counter;
Expand Down Expand Up @@ -71,7 +71,7 @@ public class ProxyService implements Closeable {

private final ProxyConfiguration proxyConfig;
@Getter
private final DnsNameResolver dnsNameResolver;
private final DnsAddressResolverGroup dnsAddressResolverGroup;
@Getter
private final BrokerProxyValidator brokerProxyValidator;
private String serviceUrl;
Expand Down Expand Up @@ -146,13 +146,13 @@ public ProxyService(ProxyConfiguration proxyConfig,
workersThreadFactory);
this.authenticationService = authenticationService;

DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder(workerGroup.next())
DnsNameResolverBuilder dnsNameResolverBuilder = new DnsNameResolverBuilder()
.channelType(EventLoopUtil.getDatagramChannelClass(workerGroup));
DnsResolverUtil.applyJdkDnsCacheSettings(dnsNameResolverBuilder);

dnsNameResolver = dnsNameResolverBuilder.build();
dnsAddressResolverGroup = new DnsAddressResolverGroup(dnsNameResolverBuilder);

brokerProxyValidator = new BrokerProxyValidator(dnsNameResolver.asAddressResolver(),
brokerProxyValidator = new BrokerProxyValidator(dnsAddressResolverGroup.getResolver(workerGroup.next()),
proxyConfig.getBrokerProxyAllowedHostNames(),
proxyConfig.getBrokerProxyAllowedIPAddresses(),
proxyConfig.getBrokerProxyAllowedTargetPorts());
Expand Down Expand Up @@ -238,7 +238,7 @@ public BrokerDiscoveryProvider getDiscoveryProvider() {
}

public void close() throws IOException {
dnsNameResolver.close();
dnsAddressResolverGroup.close();

if (discoveryProvider != null) {
discoveryProvider.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public SslHandler get() {
}

ch.pipeline().addLast("handler",
new ProxyConnection(proxyService, sslHandlerSupplier, proxyService.getDnsNameResolver()));
new ProxyConnection(proxyService, sslHandlerSupplier, proxyService.getDnsAddressResolverGroup()));

}
}

0 comments on commit fe78908

Please sign in to comment.