Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ISPN-14579 Various RESP commands are requesting wrong size for buffer #10704

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import io.netty.channel.socket.nio.NioServerSocketChannel;

// This is a separate class for easier replacement within Quarkus
final class NativeTransport {
public final class NativeTransport {
private static final boolean IS_LINUX = System.getProperty("os.name").toLowerCase().startsWith("linux");
private static final String USE_EPOLL_PROPERTY = "infinispan.server.channel.epoll";
private static final String USE_IOURING_PROPERTY = "infinispan.server.channel.iouring";
Expand All @@ -22,8 +22,8 @@ final class NativeTransport {
private static final boolean IOURING_DISABLED = System.getProperty(USE_IOURING_PROPERTY, "true").equalsIgnoreCase("false");

// Has to be after other static variables to ensure they are initialized
static final boolean USE_NATIVE_EPOLL = useNativeEpoll();
static final boolean USE_NATIVE_IOURING = useNativeIOUring();
public static final boolean USE_NATIVE_EPOLL = useNativeEpoll();
public static final boolean USE_NATIVE_IOURING = useNativeIOUring();

private static boolean useNativeEpoll() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c
byte[] respProtocolBytes = arguments.get(0);
String version = new String(respProtocolBytes, CharsetUtil.UTF_8);
if (!version.equals("3")) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-NOPROTO sorry this protocol version is not supported\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-NOPROTO sorry this protocol version is not supported\r\n", ctx.alloc()), ctx.voidPromise());
break;
}

Expand Down Expand Up @@ -78,18 +78,18 @@ private CompletionStage<Boolean> performAuth(ChannelHandlerContext ctx, String u
private boolean handleAuthResponse(ChannelHandlerContext ctx, Subject subject) {
assert ctx.channel().eventLoop().inEventLoop();
if (subject == null) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Client sent AUTH, but no password is set\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Client sent AUTH, but no password is set\r\n", ctx.alloc()), ctx.voidPromise());
return false;
}

cache = cache.withSubject(subject);
ctx.writeAndFlush(statusOK());
ctx.writeAndFlush(statusOK(), ctx.voidPromise());
return true;
}

private void handleUnauthorized(ChannelHandlerContext ctx) {
assert ctx.channel().eventLoop().inEventLoop();
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-WRONGPASS invalid username-password pair or user is disabled.\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-WRONGPASS invalid username-password pair or user is disabled.\r\n", ctx.alloc()), ctx.voidPromise());
}

private boolean isAuthorized() {
Expand All @@ -105,6 +105,6 @@ private static void helloResponse(ChannelHandlerContext ctx) {
"$2\r\nid\r\n:184\r\n" +
"$4\r\nmode\r\n$7\r\ncluster\r\n" +
"$4\r\nrole\r\n$6\r\nmaster\r\n" +
"$7\r\nmodules\r\n*0\r\n", ctx.alloc()));
"$7\r\nmodules\r\n*0\r\n", ctx.alloc()), ctx.voidPromise());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,44 @@
import org.infinispan.commons.CacheException;
import org.infinispan.commons.logging.LogFactory;
import org.infinispan.server.core.logging.Log;
import org.infinispan.server.core.transport.NativeTransport;
import org.infinispan.util.concurrent.AggregateCompletionStage;
import org.infinispan.util.concurrent.CompletionStages;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.CharsetUtil;

public class Resp3Handler extends Resp3AuthHandler {
private static final Log log = LogFactory.getLog(MethodHandles.lookup().lookupClass(), Log.class);
private static final ByteBuf OK = RespRequestHandler.stringToByteBuf("+OK\r\n", ByteBufAllocator.DEFAULT);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We couldn't use DEFAULT byte buffer as we want to use a byte buf that is the same type as our underlying protocol. Since netty supports both native and JVM based sockets, any ByteBuf passed down that doesn't match the same type will be first copied into an unpooled copy in the appropriate version. Thus we need to make sure this is the same type as we would configure in our server.

private static final ByteBuf OK;

static {
if (NativeTransport.USE_NATIVE_EPOLL || NativeTransport.USE_NATIVE_IOURING) {
OK = Unpooled.unreleasableBuffer(Unpooled.directBuffer(5, 5));
} else {
OK = Unpooled.unreleasableBuffer(Unpooled.buffer(5, 5));
}
OK.writeCharSequence("+OK\r\n", CharsetUtil.US_ASCII);
}

Resp3Handler(RespServer respServer) {
super(respServer);
}

// Returns a cached OK status that is retained for multiple uses
static ByteBuf statusOK() {
return OK.retain();
return OK.duplicate();
}

@Override
public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext ctx, String type, List<byte[]> arguments) {
switch (type) {
case "PING":
if (arguments.size() == 0) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("$4\r\nPONG\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("$4\r\nPONG\r\n", ctx.alloc()), ctx.voidPromise());
break;
}
// falls-through
Expand All @@ -50,7 +60,7 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c
ByteBuf bufferToWrite = RespRequestHandler.stringToByteBufWithExtra("$" + argument.length + "\r\n", ctx.alloc(), argument.length + 2);
bufferToWrite.writeBytes(argument);
bufferToWrite.writeByte('\r').writeByte('\n');
ctx.writeAndFlush(bufferToWrite);
ctx.writeAndFlush(bufferToWrite, ctx.voidPromise());
break;
case "SET":
return performSet(ctx, cache, arguments.get(0), arguments.get(1), -1, type, statusOK());
Expand All @@ -66,9 +76,9 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c
ByteBuf buf = RespRequestHandler.stringToByteBufWithExtra("$" + length + "\r\n", ctx.alloc(), length + 2);
buf.writeBytes(innerValueBytes);
buf.writeByte('\r').writeByte('\n');
ctx.writeAndFlush(buf);
ctx.writeAndFlush(buf, ctx.voidPromise());
} else {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("$-1\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("$-1\r\n", ctx.alloc()), ctx.voidPromise());
}
});
case "DEL":
Expand Down Expand Up @@ -99,20 +109,20 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c

if ("GET".equalsIgnoreCase(getOrSet)) {
if ("appendonly".equalsIgnoreCase(name)) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*2\r\n+" + name + "\r\n+no\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*2\r\n+" + name + "\r\n+no\r\n", ctx.alloc()), ctx.voidPromise());
} else if (name.indexOf('*') != -1 || name.indexOf('?') != -1) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR CONFIG blob pattern matching not implemented\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR CONFIG blob pattern matching not implemented\r\n", ctx.alloc()), ctx.voidPromise());
} else {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*2\r\n+" + name + "\r\n+\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*2\r\n+" + name + "\r\n+\r\n", ctx.alloc()), ctx.voidPromise());
}
} else if ("SET".equalsIgnoreCase(getOrSet)) {
ctx.writeAndFlush(statusOK());
ctx.writeAndFlush(statusOK(), ctx.voidPromise());
} else {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR CONFIG " + getOrSet + " not implemented\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR CONFIG " + getOrSet + " not implemented\r\n", ctx.alloc()), ctx.voidPromise());
}
break;
case "INFO":
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR not implemented yet\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR not implemented yet\r\n", ctx.alloc()), ctx.voidPromise());
break;
case "PUBLISH":
// TODO: should we return the # of subscribers on this node?
Expand All @@ -124,22 +134,22 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c
SubscriberHandler subscriberHandler = new SubscriberHandler(respServer, this);
return subscriberHandler.handleRequest(ctx, type, arguments);
case "SELECT":
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Select not supported in cluster mode\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Select not supported in cluster mode\r\n", ctx.alloc()), ctx.voidPromise());
break;
case "READWRITE":
case "READONLY":
// We are always in read write allowing read from backups
ctx.writeAndFlush(statusOK());
ctx.writeAndFlush(statusOK(), ctx.voidPromise());
break;
case "RESET":
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("+RESET\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("+RESET\r\n", ctx.alloc()), ctx.voidPromise());
if (respServer.getConfiguration().authentication().enabled()) {
return CompletableFuture.completedFuture(new Resp3AuthHandler(respServer));
}
break;
case "COMMAND":
if (!arguments.isEmpty()) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR COMMAND does not currently support arguments\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR COMMAND does not currently support arguments\r\n", ctx.alloc()), ctx.voidPromise());
break;
}
StringBuilder commandBuilder = new StringBuilder();
Expand All @@ -165,7 +175,7 @@ public CompletionStage<RespRequestHandler> handleRequest(ChannelHandlerContext c
addCommand(commandBuilder, "RESET", 1, 0, 0, 0);
addCommand(commandBuilder, "QUIT", 1, 0, 0, 0);
addCommand(commandBuilder, "COMMAND", -1, 0, 0, 0);
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(commandBuilder.toString(), ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(commandBuilder.toString(), ctx.alloc()), ctx.voidPromise());
break;
default:
return super.handleRequest(ctx, type, arguments);
Expand All @@ -191,26 +201,27 @@ private static void addCommand(StringBuilder builder, String name, int arity, in
}

private static void handleLongResult(ChannelHandlerContext ctx, Long result) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":" + result + "\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":" + result + "\r\n", ctx.alloc()), ctx.voidPromise());
}

static void handleThrowable(ChannelHandlerContext ctx, Throwable t) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR" + t.getMessage() + "\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR" + t.getMessage() + "\r\n", ctx.alloc()), ctx.voidPromise());
}

private static CompletionStage<Long> counterIncOrDec(Cache<byte[], byte[]> cache, byte[] key, boolean increment) {
return cache.getAsync(key)
.thenCompose(currentValueBytes -> {
if (currentValueBytes != null) {
String prevValue = new String(currentValueBytes, CharsetUtil.UTF_8);
// Numbers are always ASCII
String prevValue = new String(currentValueBytes, CharsetUtil.US_ASCII);
long prevIntValue;
try {
prevIntValue = Long.parseLong(prevValue) + (increment ? 1 : -1);
} catch (NumberFormatException e) {
throw new CacheException("value is not an integer or out of range");
}
String newValueString = String.valueOf(prevIntValue);
byte[] newValueBytes = newValueString.getBytes(CharsetUtil.UTF_8);
byte[] newValueBytes = newValueString.getBytes(CharsetUtil.US_ASCII);
return cache.replaceAsync(key, currentValueBytes, newValueBytes)
.thenCompose(replaced -> {
if (replaced) {
Expand All @@ -220,7 +231,7 @@ private static CompletionStage<Long> counterIncOrDec(Cache<byte[], byte[]> cache
});
}
long longValue = increment ? 1 : -1;
byte[] valueToPut = String.valueOf(longValue).getBytes(CharsetUtil.UTF_8);
byte[] valueToPut = String.valueOf(longValue).getBytes(CharsetUtil.US_ASCII);
return cache.putIfAbsentAsync(key, valueToPut)
.thenCompose(prev -> {
if (prev != null) {
Expand All @@ -238,7 +249,7 @@ private CompletionStage<RespRequestHandler> performSet(ChannelHandlerContext ctx
log.trace("Exception encountered while performing " + type, t);
handleThrowable(ctx, t);
} else {
ctx.writeAndFlush(messageOnSuccess);
ctx.writeAndFlush(messageOnSuccess, ctx.voidPromise());
}
});
}
Expand All @@ -254,11 +265,11 @@ private CompletionStage<RespRequestHandler> performDelete(ChannelHandlerContext
return;
}
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":" + (prev == null ? "0" : "1") +
"\r\n", ctx.alloc()));
"\r\n", ctx.alloc()), ctx.voidPromise());
});
} else if (keysToRemove == 0) {
// TODO: is this an error?
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":0\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":0\r\n", ctx.alloc()), ctx.voidPromise());
return myStage;
} else {
AtomicInteger removes = new AtomicInteger();
Expand All @@ -277,15 +288,15 @@ private CompletionStage<RespRequestHandler> performDelete(ChannelHandlerContext
handleThrowable(ctx, t);
return;
}
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":" + removals.get() + "\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf(":" + removals.get() + "\r\n", ctx.alloc()), ctx.voidPromise());
});
}
}

private CompletionStage<RespRequestHandler> performMget(ChannelHandlerContext ctx, Cache<byte[], byte[]> cache, List<byte[]> arguments) {
int keysToRetrieve = arguments.size();
if (keysToRetrieve == 0) {
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*0\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("*0\r\n", ctx.alloc()), ctx.voidPromise());
return myStage;
}
List<byte[]> results = Collections.synchronizedList(Arrays.asList(
Expand All @@ -301,8 +312,8 @@ private CompletionStage<RespRequestHandler> performMget(ChannelHandlerContext ct
results.set(innerCount, returnValue);
int length = returnValue.length;
if (length > 0) {
// byte length + digit length (log10 + 1) + $
resultBytesSize.addAndGet(returnValue.length + (int) Math.log10(length) + 1 + 1);
// $ + digit length (log10 + 1) + /r/n + byte length
resultBytesSize.addAndGet(1 + (int) Math.log10(length) + 1 + 2 + returnValue.length);
} else {
// $0
resultBytesSize.addAndGet(2);
Expand All @@ -322,33 +333,34 @@ private CompletionStage<RespRequestHandler> performMget(ChannelHandlerContext ct
return;
}
int elements = results.size();
// * + digit length (log10 + 1) + \r\n
ByteBuf byteBuf = ctx.alloc().buffer(resultBytesSize.addAndGet(1 + (int) Math.log10(elements)
+ 1 + 2));
byteBuf.writeCharSequence("*" + results.size(), CharsetUtil.UTF_8);
// * + digit length (log10 + 1) + \r\n + accumulated bytes
int byteAmount = 1 + (int) Math.log10(elements) + 1 + 2 + resultBytesSize.get();
ByteBuf byteBuf = ctx.alloc().buffer(byteAmount, byteAmount);
byteBuf.writeCharSequence("*" + results.size(), CharsetUtil.US_ASCII);
byteBuf.writeByte('\r');
byteBuf.writeByte('\n');
for (byte[] value : results) {
if (value == null) {
byteBuf.writeCharSequence("$-1", CharsetUtil.UTF_8);
byteBuf.writeCharSequence("$-1", CharsetUtil.US_ASCII);
} else {
byteBuf.writeCharSequence("$" + value.length, CharsetUtil.UTF_8);
byteBuf.writeCharSequence("$" + value.length, CharsetUtil.US_ASCII);
byteBuf.writeByte('\r');
byteBuf.writeByte('\n');
byteBuf.writeBytes(value);
}
byteBuf.writeByte('\r');
byteBuf.writeByte('\n');
}
ctx.writeAndFlush(byteBuf);
assert byteBuf.writerIndex() == byteAmount;
ctx.writeAndFlush(byteBuf, ctx.voidPromise());
});
}

private CompletionStage<RespRequestHandler> performMset(ChannelHandlerContext ctx, Cache<byte[], byte[]> cache, List<byte[]> arguments) {
int keyValuePairCount = arguments.size();
if ((keyValuePairCount & 1) == 1) {
log.tracef("Received: %s count for keys and values combined, should be even for MSET", keyValuePairCount);
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Missing a value for a key" + "\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Missing a value for a key" + "\r\n", ctx.alloc()), ctx.voidPromise());
return myStage;
}
AggregateCompletionStage<Void> setStage = CompletionStages.aggregateCompletionStage();
Expand All @@ -362,7 +374,7 @@ private CompletionStage<RespRequestHandler> performMset(ChannelHandlerContext ct
log.trace("Exception encountered while performing MSET", t);
handleThrowable(ctx, t);
} else {
ctx.writeAndFlush(statusOK());
ctx.writeAndFlush(statusOK(), ctx.voidPromise());
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.unexpectedException(cause);
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Server Error Encountered: " + cause.getMessage() + "\r\n", ctx.alloc()));
ctx.writeAndFlush(RespRequestHandler.stringToByteBuf("-ERR Server Error Encountered: " + cause.getMessage() + "\r\n", ctx.alloc()), ctx.voidPromise());
ctx.close();
}
}