Skip to content

Commit

Permalink
PoC: Client channel
Browse files Browse the repository at this point in the history
  • Loading branch information
bsideup committed Mar 15, 2024
1 parent 38f968f commit cfd19fd
Show file tree
Hide file tree
Showing 5 changed files with 469 additions and 0 deletions.
95 changes: 95 additions & 0 deletions netty/src/test/java/io/grpc/clientchannel/ByteBufMarshaller.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package io.grpc.clientchannel;

import io.grpc.*;
import io.netty.buffer.*;

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

class ByteBufMarshaller implements MethodDescriptor.Marshaller<ByteBuf> {

static ByteBufMarshaller INSTANCE = new ByteBufMarshaller();

@Override
public InputStream stream(ByteBuf value) {
return new DrainableInputStream(value);
}

@Override
public ByteBuf parse(InputStream stream) {
try {
// See https://github.com/GoogleCloudPlatform/grpc-gcp-java/pull/77
if (stream instanceof KnownLength) {
int size = ((KnownLength) stream).available();

if (size == 0) {
return Unpooled.EMPTY_BUFFER;
}

if (stream instanceof Detachable) {
Detachable detachable = (Detachable) stream;
stream = detachable.detach();
}

if (stream instanceof HasByteBuffer && ((HasByteBuffer) stream).byteBufferSupported()) {
HasByteBuffer hasByteBuffer = (HasByteBuffer) stream;
stream.mark(size);

ByteBuf firstBuffer = Unpooled.wrappedBuffer((hasByteBuffer.getByteBuffer()));
stream.skip(firstBuffer.readableBytes());

try {
// Skip composite buffer if the result fits into a single buffer
if (stream.available() <= 0) {
return firstBuffer;
}

CompositeByteBuf compositeBuffer = Unpooled.compositeBuffer(32);
compositeBuffer.addComponent(true, firstBuffer);

while (stream.available() != 0) {
ByteBuffer buffer = ((HasByteBuffer) stream).getByteBuffer();
ByteBuf byteBuf = Unpooled.wrappedBuffer(buffer);
compositeBuffer.addComponent(true, byteBuf);
stream.skip(buffer.remaining());
}

return compositeBuffer;
} finally {
stream.reset();
}
}
}

ByteBuf buf = Unpooled.buffer(stream.available());
buf.writeBytes(stream, stream.available());

return buf;
} catch (Exception e) {
throw new RuntimeException(e);
}
}

class DrainableInputStream extends ByteBufInputStream implements Drainable {

private final ByteBuf buffer;

public DrainableInputStream(ByteBuf buffer) {
super(buffer);
this.buffer = buffer;
}

@Override
public int drainTo(OutputStream target) {
int capacity = buffer.readableBytes();
try {
buffer.getBytes(buffer.readerIndex(), target, capacity);
buffer.release();
} catch (Exception e) {
throw new RuntimeException(e);
}
return capacity;
}
}
}
159 changes: 159 additions & 0 deletions netty/src/test/java/io/grpc/clientchannel/ClientChannelService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package io.grpc.clientchannel;

import io.grpc.*;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelOption;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalServerChannel;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

public abstract class ClientChannelService implements BindableService {

static String TUNNEL_SERVICE = "io.grpc.Tunnel";

static MethodDescriptor<ByteBuf, ByteBuf> NEW_TUNNEL_METHOD = MethodDescriptor
.newBuilder(ByteBufMarshaller.INSTANCE, ByteBufMarshaller.INSTANCE)
.setFullMethodName(TUNNEL_SERVICE + "/new")
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.build();

public static void registerServer(
ManagedChannel networkChannel,
Metadata headers,
Consumer<ServerBuilder<?>> serverBuilderConsumer
) {
ClientCall<ByteBuf, ByteBuf> serverCall = networkChannel.newCall(NEW_TUNNEL_METHOD, CallOptions.DEFAULT);

DefaultEventLoopGroup eventLoopGroup = new DefaultEventLoopGroup();

TunnelChannel channel = new TunnelChannel(serverCall::sendMessage);

NettyServerBuilder nettyServerBuilder = NettyServerBuilder
.forAddress(new LocalAddress("clientchannel-" + System.nanoTime()))
.workerEventLoopGroup(eventLoopGroup)
.bossEventLoopGroup(eventLoopGroup);
serverBuilderConsumer.accept(nettyServerBuilder);

Server server = nettyServerBuilder
.withOption(ChannelOption.SO_KEEPALIVE, null)
.withOption(ChannelOption.AUTO_READ, true)
.withOption(ChannelOption.AUTO_CLOSE, false)
.channelFactory(() -> {
return new LocalServerChannel() {
@Override
protected void doBeginRead() {
pipeline().fireChannelRead(channel);
}
};
})
.build();

try {
server.start();
} catch (IOException e) {
throw new RuntimeException(e);
}

serverCall.start(
new ClientCall.Listener<ByteBuf>() {

@Override
public void onReady() {
serverCall.request(Integer.MAX_VALUE);
}

@Override
public void onMessage(ByteBuf bytes) {
if (bytes.readableBytes() > 0) {
channel.pipeline().fireChannelRead(bytes);
}
}

@Override
public void onClose(Status status, Metadata trailers) {
server.shutdown();
}
},
headers
);
}

abstract protected void onChannel(ManagedChannel channel, Metadata headers);

@Override
public ServerServiceDefinition bindService() {
return ServerServiceDefinition
.builder(TUNNEL_SERVICE)
.addMethod(NEW_TUNNEL_METHOD, new TunnelHandler(this))
.build();
}

static class TunnelHandler implements ServerCallHandler<ByteBuf, ByteBuf> {

private final ClientChannelService tunnelClientChannelService;

private final AtomicLong id = new AtomicLong();

public TunnelHandler(ClientChannelService tunnelClientChannelService) {
this.tunnelClientChannelService = tunnelClientChannelService;
}

@Override
public ServerCall.Listener<ByteBuf> startCall(ServerCall<ByteBuf, ByteBuf> call, Metadata headers) {
try {
call.sendHeaders(new Metadata());
call.request(Integer.MAX_VALUE);
} catch (Exception e) {
throw new RuntimeException(e);
}

TunnelChannel nettyChannel = new TunnelChannel(call::sendMessage);
DefaultEventLoopGroup eventLoopGroup = new DefaultEventLoopGroup();

ManagedChannel grpcChannel = NettyChannelBuilder
.forAddress(new LocalAddress("tunnel-" + id.incrementAndGet()))
.eventLoopGroup(eventLoopGroup)
.directExecutor()
.channelFactory(() -> nettyChannel)
.withOption(ChannelOption.SO_KEEPALIVE, null)
.withOption(ChannelOption.AUTO_READ, false)
.withOption(ChannelOption.AUTO_CLOSE, false)
.usePlaintext()
.build();

tunnelClientChannelService.onChannel(grpcChannel, headers);

return new ServerCall.Listener<ByteBuf>() {

@Override
public void onMessage(ByteBuf byteBuf) {
if (byteBuf.readableBytes() > 0) {
nettyChannel.pipeline().fireChannelRead(byteBuf);
}
}

@Override
public void onHalfClose() {
onCancel();
}

@Override
public void onComplete() {
onCancel();
}

@Override
public void onCancel() {
grpcChannel.shutdown();
eventLoopGroup.shutdownGracefully();
}
};
}
}
}
94 changes: 94 additions & 0 deletions netty/src/test/java/io/grpc/clientchannel/TunnelChannel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.grpc.clientchannel;

import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.util.ReferenceCountUtil;

import java.net.SocketAddress;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

class TunnelChannel extends AbstractChannel {

private final AtomicBoolean closed = new AtomicBoolean(false);

private final Consumer<ByteBuf> call;

public TunnelChannel(Consumer<ByteBuf> call) {
super(null);
this.call = call;
}

@Override
protected void doWrite(ChannelOutboundBuffer in) {
Object msg;
while ((msg = in.current()) != null) {
ReferenceCountUtil.retain(msg);
call.accept(((ByteBuf) msg).touch());
in.remove();
}
}

@Override
protected AbstractUnsafe newUnsafe() {
return new AbstractUnsafe() {
@Override
public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
promise.setSuccess();
}
};
}

@Override
public boolean isOpen() {
return !closed.get();
}

@Override
public boolean isActive() {
return isOpen();
}

@Override
protected void doClose() {
closed.set(true);
}

//////////////////////////////////

@Override
protected void doBeginRead() {

}

@Override
protected boolean isCompatible(EventLoop loop) {
return true;
}

@Override
public ChannelConfig config() {
return new DefaultChannelConfig(this);
}

@Override
public ChannelMetadata metadata() {
return new ChannelMetadata(false);
}

@Override
protected void doBind(SocketAddress localAddress) {}

@Override
protected void doDisconnect() {}

@Override
protected SocketAddress localAddress0() {
return null;
}

@Override
protected SocketAddress remoteAddress0() {
return null;
}
}
Loading

0 comments on commit cfd19fd

Please sign in to comment.