From 5da277075230e8a131edfefadfb119c0556fa904 Mon Sep 17 00:00:00 2001 From: Chris Ruffalo Date: Mon, 30 Oct 2023 17:21:14 -0400 Subject: [PATCH] Pluggable I/O for SimpleResolver Closes #253 --- .../java/org/xbill/DNS/DefaultIoClient.java | 46 ++++++++++ src/main/java/org/xbill/DNS/Lookup.java | 3 +- src/main/java/org/xbill/DNS/NioTcpClient.java | 30 +++---- src/main/java/org/xbill/DNS/NioUdpClient.java | 42 +++++----- .../java/org/xbill/DNS/SimpleResolver.java | 21 ++++- .../xbill/DNS/io/DefaultIoClientFactory.java | 30 +++++++ .../org/xbill/DNS/io/IoClientFactory.java | 30 +++++++ .../java/org/xbill/DNS/io/TcpIoClient.java | 35 ++++++++ .../java/org/xbill/DNS/io/UdpIoClient.java | 37 ++++++++ .../java/org/xbill/DNS/NioTcpClientTest.java | 8 +- .../xbill/DNS/SimpleResolverDeniedTest.java | 82 ++++++++++-------- src/test/java/org/xbill/DNS/TSIGTest.java | 84 +++++++++++-------- 12 files changed, 335 insertions(+), 113 deletions(-) create mode 100644 src/main/java/org/xbill/DNS/DefaultIoClient.java create mode 100644 src/main/java/org/xbill/DNS/io/DefaultIoClientFactory.java create mode 100644 src/main/java/org/xbill/DNS/io/IoClientFactory.java create mode 100644 src/main/java/org/xbill/DNS/io/TcpIoClient.java create mode 100644 src/main/java/org/xbill/DNS/io/UdpIoClient.java diff --git a/src/main/java/org/xbill/DNS/DefaultIoClient.java b/src/main/java/org/xbill/DNS/DefaultIoClient.java new file mode 100644 index 00000000..0f6cd407 --- /dev/null +++ b/src/main/java/org/xbill/DNS/DefaultIoClient.java @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import org.xbill.DNS.io.TcpIoClient; +import org.xbill.DNS.io.UdpIoClient; + +/** + * An implementation of the IO clients that use the internal NIO-based clients. + * + * @see NioUdpClient + * @see NioTcpClient + * @since 3.6 + */ +public class DefaultIoClient implements TcpIoClient, UdpIoClient { + private final TcpIoClient tcpIoClient; + private final UdpIoClient udpIoClient; + + public DefaultIoClient() { + tcpIoClient = new NioTcpClient(); + udpIoClient = new NioUdpClient(); + } + + @Override + public CompletableFuture sendAndReceiveTcp( + InetSocketAddress local, + InetSocketAddress remote, + Message query, + byte[] data, + Duration timeout) { + return tcpIoClient.sendAndReceiveTcp(local, remote, query, data, timeout); + } + + @Override + public CompletableFuture sendAndReceiveUdp( + InetSocketAddress local, + InetSocketAddress remote, + Message query, + byte[] data, + int max, + Duration timeout) { + return udpIoClient.sendAndReceiveUdp(local, remote, query, data, max, timeout); + } +} diff --git a/src/main/java/org/xbill/DNS/Lookup.java b/src/main/java/org/xbill/DNS/Lookup.java index b8b27cdf..d895cae1 100644 --- a/src/main/java/org/xbill/DNS/Lookup.java +++ b/src/main/java/org/xbill/DNS/Lookup.java @@ -248,7 +248,8 @@ private static List convertSearchPathDomainList(List domains) { } /** - * Sets a custom logger that will be used to log the sent and received packets. + * Sets a custom logger that will be used to log the sent and received packets. This is only + * applicable to the default I/O implementations. * * @param logger The logger */ diff --git a/src/main/java/org/xbill/DNS/NioTcpClient.java b/src/main/java/org/xbill/DNS/NioTcpClient.java index aacc1899..759a2a29 100644 --- a/src/main/java/org/xbill/DNS/NioTcpClient.java +++ b/src/main/java/org/xbill/DNS/NioTcpClient.java @@ -18,22 +18,21 @@ import java.util.concurrent.ConcurrentLinkedQueue; import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; -import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.io.TcpIoClient; @Slf4j -@UtilityClass -final class NioTcpClient extends NioClient { - private static final Queue registrationQueue = new ConcurrentLinkedQueue<>(); - private static final Map channelMap = new ConcurrentHashMap<>(); +final class NioTcpClient extends NioClient implements TcpIoClient { + private final Queue registrationQueue = new ConcurrentLinkedQueue<>(); + private final Map channelMap = new ConcurrentHashMap<>(); - static { - setRegistrationsTask(NioTcpClient::processPendingRegistrations, true); - setTimeoutTask(NioTcpClient::checkTransactionTimeouts, true); - setCloseTask(NioTcpClient::closeTcp, true); + NioTcpClient() { + setRegistrationsTask(this::processPendingRegistrations, true); + setTimeoutTask(this::checkTransactionTimeouts, true); + setCloseTask(this::closeTcp, true); } - private static void processPendingRegistrations() { + private void processPendingRegistrations() { while (!registrationQueue.isEmpty()) { ChannelState state = registrationQueue.remove(); try { @@ -49,7 +48,7 @@ private static void processPendingRegistrations() { } } - private static void checkTransactionTimeouts() { + private void checkTransactionTimeouts() { for (ChannelState state : channelMap.values()) { for (Iterator it = state.pendingTransactions.iterator(); it.hasNext(); ) { Transaction t = it.next(); @@ -61,7 +60,7 @@ private static void checkTransactionTimeouts() { } } - private static void closeTcp() { + private void closeTcp() { registrationQueue.clear(); EOFException closing = new EOFException("Client is closing"); channelMap.forEach((key, state) -> state.handleTransactionException(closing)); @@ -112,8 +111,8 @@ void send() throws IOException { } @RequiredArgsConstructor - private static class ChannelState implements KeyProcessor { - final SocketChannel channel; + private class ChannelState implements KeyProcessor { + private final SocketChannel channel; final Queue pendingTransactions = new ConcurrentLinkedQueue<>(); ByteBuffer responseLengthData = ByteBuffer.allocate(2); ByteBuffer responseData = ByteBuffer.allocate(Message.MAXLENGTH); @@ -259,7 +258,8 @@ private static class ChannelKey { final InetSocketAddress remote; } - static CompletableFuture sendrecv( + @Override + public CompletableFuture sendAndReceiveTcp( InetSocketAddress local, InetSocketAddress remote, Message query, diff --git a/src/main/java/org/xbill/DNS/NioUdpClient.java b/src/main/java/org/xbill/DNS/NioUdpClient.java index 340c5df3..2605dc93 100644 --- a/src/main/java/org/xbill/DNS/NioUdpClient.java +++ b/src/main/java/org/xbill/DNS/NioUdpClient.java @@ -17,20 +17,19 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import lombok.RequiredArgsConstructor; -import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.io.UdpIoClient; @Slf4j -@UtilityClass -final class NioUdpClient extends NioClient { - private static final int EPHEMERAL_START; - private static final int EPHEMERAL_RANGE; +final class NioUdpClient extends NioClient implements UdpIoClient { + private final int ephemeralStart; + private final int ephemeralRange; - private static final SecureRandom prng; - private static final Queue registrationQueue = new ConcurrentLinkedQueue<>(); - private static final Queue pendingTransactions = new ConcurrentLinkedQueue<>(); + private final SecureRandom prng; + private final Queue registrationQueue = new ConcurrentLinkedQueue<>(); + private final Queue pendingTransactions = new ConcurrentLinkedQueue<>(); - static { + NioUdpClient() { // https://tools.ietf.org/html/rfc6335#section-6 int ephemeralStartDefault = 49152; int ephemeralEndDefault = 65535; @@ -41,21 +40,21 @@ final class NioUdpClient extends NioClient { ephemeralEndDefault = 60999; } - EPHEMERAL_START = Integer.getInteger("dnsjava.udp.ephemeral.start", ephemeralStartDefault); + ephemeralStart = Integer.getInteger("dnsjava.udp.ephemeral.start", ephemeralStartDefault); int ephemeralEnd = Integer.getInteger("dnsjava.udp.ephemeral.end", ephemeralEndDefault); - EPHEMERAL_RANGE = ephemeralEnd - EPHEMERAL_START; + ephemeralRange = ephemeralEnd - ephemeralStart; if (Boolean.getBoolean("dnsjava.udp.ephemeral.use_ephemeral_port")) { prng = null; } else { prng = new SecureRandom(); } - setRegistrationsTask(NioUdpClient::processPendingRegistrations, false); - setTimeoutTask(NioUdpClient::checkTransactionTimeouts, false); - setCloseTask(NioUdpClient::closeUdp, false); + setRegistrationsTask(this::processPendingRegistrations, false); + setTimeoutTask(this::checkTransactionTimeouts, false); + setCloseTask(this::closeUdp, false); } - private static void processPendingRegistrations() { + private void processPendingRegistrations() { while (!registrationQueue.isEmpty()) { Transaction t = registrationQueue.remove(); try { @@ -68,7 +67,7 @@ private static void processPendingRegistrations() { } } - private static void checkTransactionTimeouts() { + private void checkTransactionTimeouts() { for (Iterator it = pendingTransactions.iterator(); it.hasNext(); ) { Transaction t = it.next(); if (t.endTime - System.nanoTime() < 0) { @@ -79,7 +78,7 @@ private static void checkTransactionTimeouts() { } @RequiredArgsConstructor - private static class Transaction implements KeyProcessor { + private class Transaction implements KeyProcessor { private final int id; private final byte[] data; private final int max; @@ -159,7 +158,8 @@ private void silentCloseChannel() { } } - static CompletableFuture sendrecv( + @Override + public CompletableFuture sendAndReceiveUdp( InetSocketAddress local, InetSocketAddress remote, Message query, @@ -182,12 +182,12 @@ static CompletableFuture sendrecv( InetSocketAddress addr = null; if (local == null) { if (prng != null) { - addr = new InetSocketAddress(prng.nextInt(EPHEMERAL_RANGE) + EPHEMERAL_START); + addr = new InetSocketAddress(prng.nextInt(ephemeralRange) + ephemeralStart); } } else { int port = local.getPort(); if (port == 0 && prng != null) { - port = prng.nextInt(EPHEMERAL_RANGE) + EPHEMERAL_START; + port = prng.nextInt(ephemeralRange) + ephemeralStart; } addr = new InetSocketAddress(local.getAddress(), port); @@ -225,7 +225,7 @@ static CompletableFuture sendrecv( return f; } - private static void closeUdp() { + private void closeUdp() { registrationQueue.clear(); EOFException closing = new EOFException("Client is closing"); pendingTransactions.forEach(t -> t.completeExceptionally(closing)); diff --git a/src/main/java/org/xbill/DNS/SimpleResolver.java b/src/main/java/org/xbill/DNS/SimpleResolver.java index 0dbac400..e33c660a 100644 --- a/src/main/java/org/xbill/DNS/SimpleResolver.java +++ b/src/main/java/org/xbill/DNS/SimpleResolver.java @@ -14,7 +14,11 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.ForkJoinPool; +import lombok.Getter; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.io.DefaultIoClientFactory; +import org.xbill.DNS.io.IoClientFactory; /** * An implementation of Resolver that sends one query to one server. SimpleResolver handles TCP @@ -44,6 +48,13 @@ public class SimpleResolver implements Resolver { private static final short DEFAULT_UDPSIZE = 512; + /** + * Gets or sets the factory that creates clients for sending messages to the wire. + * + * @since 3.6 + */ + @Getter @Setter private IoClientFactory ioClientFactory = new DefaultIoClientFactory(); + private static InetSocketAddress defaultResolver = new InetSocketAddress(InetAddress.getLoopbackAddress(), DEFAULT_PORT); @@ -368,9 +379,15 @@ CompletableFuture sendAsync(Message query, boolean forceTcp, Executor e CompletableFuture result; if (tcp) { - result = NioTcpClient.sendrecv(localAddress, address, query, out, timeoutValue); + result = + ioClientFactory + .createOrGetTcpClient() + .sendAndReceiveTcp(localAddress, address, query, out, timeoutValue); } else { - result = NioUdpClient.sendrecv(localAddress, address, query, out, udpSize, timeoutValue); + result = + ioClientFactory + .createOrGetUdpClient() + .sendAndReceiveUdp(localAddress, address, query, out, udpSize, timeoutValue); } return result.thenComposeAsync( diff --git a/src/main/java/org/xbill/DNS/io/DefaultIoClientFactory.java b/src/main/java/org/xbill/DNS/io/DefaultIoClientFactory.java new file mode 100644 index 00000000..04d173f2 --- /dev/null +++ b/src/main/java/org/xbill/DNS/io/DefaultIoClientFactory.java @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS.io; + +import org.xbill.DNS.DefaultIoClient; +import org.xbill.DNS.SimpleResolver; + +/** + * Serves as a default implementation that is used by the {@link SimpleResolver}, unless otherwise + * configured. This preserves the default behavior (to use the built-in NIO clients) while allowing + * flexibility at the point of use. + * + * @since 3.6 + */ +public class DefaultIoClientFactory implements IoClientFactory { + /** + * Shared instance because it only serves as a bridge to the static NIO classes and does not need + * to be different per class. + */ + private static final DefaultIoClient RESOLVER_CLIENT = new DefaultIoClient(); + + @Override + public TcpIoClient createOrGetTcpClient() { + return RESOLVER_CLIENT; + } + + @Override + public UdpIoClient createOrGetUdpClient() { + return RESOLVER_CLIENT; + } +} diff --git a/src/main/java/org/xbill/DNS/io/IoClientFactory.java b/src/main/java/org/xbill/DNS/io/IoClientFactory.java new file mode 100644 index 00000000..70154ff2 --- /dev/null +++ b/src/main/java/org/xbill/DNS/io/IoClientFactory.java @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS.io; + +import org.xbill.DNS.SimpleResolver; + +/** + * Interface for creating the TCP/UDP factories necessary for the {@link SimpleResolver}. + * + * @since 3.6 + */ +public interface IoClientFactory { + /** + * Create or return a cached/reused instance of the TCP resolver that should be used to send DNS + * data over the wire to the remote target.
+ * It is the responsibility of this method to manage pooling or connection reuse. This method is + * called right before the connection is made every time the {@link SimpleResolver} is called. The + * implementer of this method should be aware and choose how to pool or reuse connections. + * + * @return an instance of the tcp resolver client + */ + TcpIoClient createOrGetTcpClient(); + + /** + * Create or return a cached/reused instance of the UDP resolver that should be used to send DNS + * data over the wire to the remote target. + * + * @return an instance of the udp resolver client + */ + UdpIoClient createOrGetUdpClient(); +} diff --git a/src/main/java/org/xbill/DNS/io/TcpIoClient.java b/src/main/java/org/xbill/DNS/io/TcpIoClient.java new file mode 100644 index 00000000..e8febbef --- /dev/null +++ b/src/main/java/org/xbill/DNS/io/TcpIoClient.java @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS.io; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import org.xbill.DNS.Message; +import org.xbill.DNS.Resolver; + +/** + * Serves as an interface from a {@link Resolver} to the underlying mechanism for sending bytes over + * the wire as a TCP message. + * + * @since 3.6 + */ +public interface TcpIoClient { + /** + * Sends a query to a remote server and returns the answer. + * + * @param local Address from which the connection is coming. may be {@code null} and the + * implementation must decide on the local address. + * @param remote Address that the connection should send the data to. + * @param query DNS message representation of the outbound query. + * @param data Raw byte representation of the outbound query. + * @param timeout Duration before the connection will time out and be closed. + * @return A {@link CompletableFuture} that will be completed with the byte value of the response. + * @since 3.6 + */ + CompletableFuture sendAndReceiveTcp( + InetSocketAddress local, + InetSocketAddress remote, + Message query, + byte[] data, + Duration timeout); +} diff --git a/src/main/java/org/xbill/DNS/io/UdpIoClient.java b/src/main/java/org/xbill/DNS/io/UdpIoClient.java new file mode 100644 index 00000000..394ccc4b --- /dev/null +++ b/src/main/java/org/xbill/DNS/io/UdpIoClient.java @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS.io; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import org.xbill.DNS.Message; +import org.xbill.DNS.Resolver; + +/** + * Serves as an interface from a {@link Resolver} to the underlying mechanism for sending bytes over + * the wire as a UDP message. + * + * @since 3.6 + */ +public interface UdpIoClient { + /** + * Sends a query to a remote server and returns the answer. + * + * @param local Address from which the connection is coming. may be {@code null} and the + * implementation must decide on the local address. + * @param remote Address that the connection should send the data to. + * @param query DNS message representation of the outbound query. + * @param data Raw byte representation of the outbound query. + * @param max Size of the response buffer. + * @param timeout Duration before the connection will time out and be closed. + * @return A {@link CompletableFuture} that will be completed with the byte value of the response. + * @since 3.6 + */ + CompletableFuture sendAndReceiveUdp( + InetSocketAddress local, + InetSocketAddress remote, + Message query, + byte[] data, + int max, + Duration timeout); +} diff --git a/src/test/java/org/xbill/DNS/NioTcpClientTest.java b/src/test/java/org/xbill/DNS/NioTcpClientTest.java index 7cc9b9fd..05558178 100644 --- a/src/test/java/org/xbill/DNS/NioTcpClientTest.java +++ b/src/test/java/org/xbill/DNS/NioTcpClientTest.java @@ -46,6 +46,7 @@ void testResponseStream() throws InterruptedException, IOException { try { // start the selector thread early NioClient.selector(); + NioTcpClient nioTcpClient = new NioTcpClient(); Record qr = Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN); Message[] q = new Message[] {Message.newQuery(qr), Message.newQuery(qr)}; @@ -112,7 +113,8 @@ void testResponseStream() throws InterruptedException, IOException { for (int j = 0; j < q.length; j++) { int jj = j; - NioTcpClient.sendrecv( + nioTcpClient + .sendAndReceiveTcp( null, (InetSocketAddress) ss.getLocalSocketAddress(), q[j], @@ -160,6 +162,7 @@ void testTooShortResponseStream(String base16ResponseBytes) try { // start the selector thread early NioClient.selector(); + NioTcpClient nioTcpClient = new NioTcpClient(); Record qr = Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN); Message q = Message.newQuery(qr); @@ -208,7 +211,8 @@ void testTooShortResponseStream(String base16ResponseBytes) fail("timed out waiting for server thread to start"); } - NioTcpClient.sendrecv( + nioTcpClient + .sendAndReceiveTcp( null, (InetSocketAddress) ss.getLocalSocketAddress(), q, diff --git a/src/test/java/org/xbill/DNS/SimpleResolverDeniedTest.java b/src/test/java/org/xbill/DNS/SimpleResolverDeniedTest.java index 443c967a..1fdd4760 100644 --- a/src/test/java/org/xbill/DNS/SimpleResolverDeniedTest.java +++ b/src/test/java/org/xbill/DNS/SimpleResolverDeniedTest.java @@ -5,14 +5,17 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.IOException; import java.net.InetSocketAddress; import java.time.Duration; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; +import org.xbill.DNS.io.IoClientFactory; +import org.xbill.DNS.io.TcpIoClient; +import org.xbill.DNS.io.UdpIoClient; class SimpleResolverDeniedTest { @@ -25,39 +28,46 @@ void emptyResponseShouldThrowWireParseException() throws IOException { new CNAMERecord(Name.fromString("www", zone), DClass.IN, 300, Name.fromString("example.")); query.addRecord(record, Section.UPDATE); - try (MockedStatic udpClient = Mockito.mockStatic(NioUdpClient.class)) { - udpClient - .when( - () -> - NioUdpClient.sendrecv( - any(), - any(InetSocketAddress.class), - any(Message.class), - any(byte[].class), - anyInt(), - any(Duration.class))) - .thenAnswer( - a -> { - Message qparsed = new Message(a.getArgument(3)); - - int id = qparsed.getHeader().getID(); - Message response = new Message(id); - response.getHeader().setRcode(Rcode.REFUSED); - byte[] rbytes = response.toWire(Message.MAXLENGTH); - - // This was the exact format returned by denying server - assertArrayEquals( - rbytes, - new byte[] {(byte) (id >>> 8), (byte) id, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}); - - CompletableFuture f = new CompletableFuture<>(); - f.complete(rbytes); - return f; - }); - - SimpleResolver simpleResolver = new SimpleResolver("127.0.0.1"); - - assertThrows(WireParseException.class, () -> simpleResolver.send(query)); - } + SimpleResolver simpleResolver = new SimpleResolver("127.0.0.1"); + simpleResolver.setIoClientFactory( + new IoClientFactory() { + @Override + public TcpIoClient createOrGetTcpClient() { + return null; + } + + @Override + public UdpIoClient createOrGetUdpClient() { + UdpIoClient udpMock = mock(NioUdpClient.class); + when(udpMock.sendAndReceiveUdp( + any(), + any(InetSocketAddress.class), + any(Message.class), + any(byte[].class), + anyInt(), + any(Duration.class))) + .thenAnswer( + a -> { + Message qparsed = new Message(a.getArgument(3)); + + int id = qparsed.getHeader().getID(); + Message response = new Message(id); + response.getHeader().setRcode(Rcode.REFUSED); + byte[] rbytes = response.toWire(Message.MAXLENGTH); + + // This was the exact format returned by denying server + assertArrayEquals( + rbytes, + new byte[] {(byte) (id >>> 8), (byte) id, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}); + + CompletableFuture f = new CompletableFuture<>(); + f.complete(rbytes); + return f; + }); + return udpMock; + } + }); + + assertThrows(WireParseException.class, () -> simpleResolver.send(query)); } } diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java index 66da6dd9..21956221 100644 --- a/src/test/java/org/xbill/DNS/TSIGTest.java +++ b/src/test/java/org/xbill/DNS/TSIGTest.java @@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.IOException; import java.lang.reflect.Field; @@ -36,9 +38,10 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.xbill.DNS.TSIG.StreamGenerator; +import org.xbill.DNS.io.IoClientFactory; +import org.xbill.DNS.io.TcpIoClient; +import org.xbill.DNS.io.UdpIoClient; import org.xbill.DNS.utils.base64; class TSIGTest { @@ -221,40 +224,49 @@ void signedQuerySignedResponseViaResolver() throws IOException { Record question = Record.newRecord(qname, Type.A, DClass.IN); Message query = Message.newQuery(question); - try (MockedStatic udpClient = Mockito.mockStatic(NioUdpClient.class)) { - udpClient - .when( - () -> - NioUdpClient.sendrecv( - any(), - any(InetSocketAddress.class), - any(), - any(byte[].class), - anyInt(), - any(Duration.class))) - .thenAnswer( - a -> { - Message qparsed = new Message(a.getArgument(3, byte[].class)); - - Message response = new Message(qparsed.getHeader().getID()); - response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG()); - response.getHeader().setFlag(Flags.QR); - response.addRecord(question, Section.QUESTION); - Record answer = Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); - response.addRecord(answer, Section.ANSWER); - byte[] rbytes = response.toWire(Message.MAXLENGTH); - - CompletableFuture f = new CompletableFuture<>(); - f.complete(rbytes); - return f; - }); - SimpleResolver res = new SimpleResolver("127.0.0.1"); - res.setTSIGKey(defaultKey); - - Message responseFromResolver = res.send(query); - assertTrue(responseFromResolver.isSigned()); - assertTrue(responseFromResolver.isVerified()); - } + SimpleResolver res = new SimpleResolver("127.0.0.1"); + res.setIoClientFactory( + new IoClientFactory() { + @Override + public TcpIoClient createOrGetTcpClient() { + return null; + } + + @Override + public UdpIoClient createOrGetUdpClient() { + UdpIoClient udpClient = mock(UdpIoClient.class); + when(udpClient.sendAndReceiveUdp( + any(), + any(InetSocketAddress.class), + any(), + any(byte[].class), + anyInt(), + any(Duration.class))) + .thenAnswer( + a -> { + Message qparsed = new Message(a.getArgument(3, byte[].class)); + + Message response = new Message(qparsed.getHeader().getID()); + response.setTSIG(defaultKey, Rcode.NOERROR, qparsed.getTSIG()); + response.getHeader().setFlag(Flags.QR); + response.addRecord(question, Section.QUESTION); + Record answer = + Record.fromString(qname, Type.A, DClass.IN, 300, "1.2.3.4", null); + response.addRecord(answer, Section.ANSWER); + byte[] rbytes = response.toWire(Message.MAXLENGTH); + + CompletableFuture f = new CompletableFuture<>(); + f.complete(rbytes); + return f; + }); + return udpClient; + } + }); + res.setTSIGKey(defaultKey); + + Message responseFromResolver = res.send(query); + assertTrue(responseFromResolver.isSigned()); + assertTrue(responseFromResolver.isVerified()); } @Test