Skip to content

Commit

Permalink
[#issue] Packet Immutability enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
emeroad committed Feb 26, 2018
1 parent 2c03942 commit 4094cbc
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 80 deletions.
Expand Up @@ -58,7 +58,6 @@ public class DefaultPinpointClientHandler extends SimpleChannelHandler implement
private final Timer channelTimer;

private final ConnectionFactory connectionFactory;
private SocketAddress connectSocketAddress;
private volatile PinpointClient pinpointClient;

private final MessageListener messageListener;
Expand Down Expand Up @@ -133,8 +132,6 @@ public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) thr
}

logger.info("{} channelConnected() started. channel:{}", objectUniqName, channel);
this.connectSocketAddress = channel.getRemoteAddress();
logger.debug("{} connectedSocketAddress:() channel:{}", channel, connectSocketAddress);

SocketStateChangeResult stateChangeResult = state.toConnected();
if (!stateChangeResult.isChange()) {
Expand Down Expand Up @@ -263,7 +260,11 @@ public void response(int requestId, byte[] payload) {

@Override
public SocketAddress getRemoteAddress() {
return connectSocketAddress;
final Channel channel = this.channel;
if (channel == null) {
return null;
}
return channel.getRemoteAddress();
}

private void await(ChannelFuture channelFuture) {
Expand Down Expand Up @@ -314,15 +315,15 @@ public Future<ResponseMessage> request(byte[] bytes) {
throw new NullPointerException("bytes");
}

boolean isEnable = state.isEnableCommunication();
final boolean isEnable = state.isEnableCommunication();
if (!isEnable) {
DefaultFuture<ResponseMessage> closedException = new DefaultFuture<ResponseMessage>();
closedException.setFailure(new PinpointSocketException("invalid state:" + state.getCurrentStateCode() + " channel:" + channel));
return closedException;
}

RequestPacket request = new RequestPacket(bytes);
final ChannelWriteFailListenableFuture<ResponseMessage> messageFuture = this.requestManager.register(request, clientOption.getTimeoutMillis());
final int requestId = this.requestManager.nextRequestId();
final RequestPacket request = new RequestPacket(requestId, bytes);
final ChannelWriteFailListenableFuture<ResponseMessage> messageFuture = this.requestManager.register(request.getRequestId(), clientOption.getTimeoutMillis());

write0(request, messageFuture);
return messageFuture;
Expand Down Expand Up @@ -562,15 +563,15 @@ private void reconnect() {
}

private ChannelFuture write0(Object message) {
return write0(message, null);
return channel.write(message);
}

private ChannelFuture write0(Object message, ChannelFutureListener futureListener) {
ChannelFuture future = channel.write(message);
if (futureListener != null) {
future.addListener(futureListener);
if (futureListener == null) {
throw new NullPointerException("futureListener must not be null");
}

ChannelFuture future = channel.write(message);
future.addListener(futureListener);
return future;
}

Expand Down
Expand Up @@ -79,8 +79,7 @@ public boolean fireFailure() {
return failureEventHandler;
}


private void addTimeoutTask(long timeoutMillis, DefaultFuture future) {
private void addTimeoutTask(DefaultFuture future, long timeoutMillis) {
if (future == null) {
throw new NullPointerException("future");
}
Expand All @@ -93,7 +92,7 @@ private void addTimeoutTask(long timeoutMillis, DefaultFuture future) {
}
}

private int getNextRequestId() {
public int nextRequestId() {
return this.requestId.getAndIncrement();
}

Expand Down Expand Up @@ -135,30 +134,54 @@ public void messageReceived(RequestPacket requestPacket, Channel channel) {
logger.error("unexpectedMessage received:{} address:{}", requestPacket, channel.getRemoteAddress());
}

public ChannelWriteFailListenableFuture<ResponseMessage> register(RequestPacket requestPacket) {
return register(requestPacket, defaultTimeoutMillis);
public ChannelWriteFailListenableFuture<ResponseMessage> register(int requestId) {
return register(requestId, defaultTimeoutMillis);
}

public ChannelWriteFailListenableFuture<ResponseMessage> register(RequestPacket requestPacket, long timeoutMillis) {
public ChannelWriteFailListenableFuture<ResponseMessage> register(int requestId, long timeoutMillis) {
// shutdown check
final int requestId = getNextRequestId();
requestPacket.setRequestId(requestId);

final ChannelWriteFailListenableFuture<ResponseMessage> future = new ChannelWriteFailListenableFuture<ResponseMessage>(timeoutMillis);
final ChannelWriteFailListenableFuture<ResponseMessage> responseFuture = new ChannelWriteFailListenableFuture<ResponseMessage>(timeoutMillis);

final DefaultFuture old = this.requestMap.put(requestId, future);
final DefaultFuture old = this.requestMap.put(requestId, responseFuture);
if (old != null) {
throw new PinpointSocketException("unexpected error. old future exist:" + old + " id:" + requestId);
}

// when future fails, put a handle in order to remove a failed future in the requestMap.
FailureEventHandler removeTable = createFailureEventHandler(requestId);
future.setFailureEventHandler(removeTable);
responseFuture.setFailureEventHandler(removeTable);

addTimeoutTask(timeoutMillis, future);
return future;
addTimeoutTask(responseFuture, timeoutMillis);
return responseFuture;
}

// public ChannelWriteFailListenableFuture<ResponseMessage> register(final int requestId, final long timeoutMillis) {
// // shutdown check
// final ChannelWriteFailListenableFuture<ResponseMessage> responseFuture = new ChannelWriteFailListenableFuture<ResponseMessage>(timeoutMillis) {
// @Override
// public void operationComplete(ChannelFuture future) throws Exception {
// fireWriteComplete(requestId, future, this, timeoutMillis);
// }
// };
// return responseFuture;
// }
//
// private void fireWriteComplete(int requestId, ChannelFuture ioWriteFuture, DefaultFuture<ResponseMessage> responseFuture, long timeoutMillis) {
// if (ioWriteFuture.isSuccess()) {
// final DefaultFuture old = requestMap.put(requestId, responseFuture);
// if (old != null) {
// PinpointSocketException pinpointSocketException = new PinpointSocketException("unexpected error. old responseFuture exist:" + old + " id:" + requestId);
// responseFuture.setFailure(pinpointSocketException);
// return;
// } else {
// FailureEventHandler removeTable = createFailureEventHandler(requestId);
// responseFuture.setFailureEventHandler(removeTable);
// addTimeoutTask(responseFuture, timeoutMillis);
// }
// } else {
// responseFuture.setFailure(ioWriteFuture.getCause());
// }
// }

public void close() {
logger.debug("close()");
Expand All @@ -175,7 +198,7 @@ public void close() {
// }
int requestFailCount = 0;
for (Map.Entry<Integer, DefaultFuture<ResponseMessage>> entry : requestMap.entrySet()) {
if(entry.getValue().setFailure(closed)) {
if (entry.getValue().setFailure(closed)) {
requestFailCount++;
}
}
Expand Down
Expand Up @@ -24,14 +24,7 @@
*/
public class RequestPacket extends BasicPacket {

private int requestId;

public RequestPacket() {
}

public RequestPacket(byte[] payload) {
super(payload);
}
private final int requestId;

public RequestPacket(int requestId, byte[] payload) {
super(payload);
Expand All @@ -42,9 +35,6 @@ public int getRequestId() {
return requestId;
}

public void setRequestId(int requestId) {
this.requestId = requestId;
}

@Override
public short getPacketType() {
Expand Down Expand Up @@ -77,8 +67,7 @@ public static RequestPacket readBuffer(short packetType, ChannelBuffer buffer) {
if (payload == null) {
return null;
}
final RequestPacket requestPacket = new RequestPacket(payload.array());
requestPacket.setRequestId(messageId);
final RequestPacket requestPacket = new RequestPacket(messageId, payload.array());
return requestPacket;
}

Expand Down
Expand Up @@ -23,14 +23,7 @@
* @author emeroad
*/
public class ResponsePacket extends BasicPacket {
private int requestId;

public ResponsePacket() {
}

public ResponsePacket(byte[] payload) {
super(payload);
}
private final int requestId;

public ResponsePacket(int requestId, byte[] payload) {
super(payload);
Expand All @@ -41,10 +34,6 @@ public int getRequestId() {
return requestId;
}

public void setRequestId(int requestId) {
this.requestId = requestId;
}

@Override
public short getPacketType() {
return PacketType.APPLICATION_RESPONSE;
Expand Down Expand Up @@ -74,9 +63,7 @@ public static ResponsePacket readBuffer(short packetType, ChannelBuffer buffer)
if (payload == null) {
return null;
}
ResponsePacket responsePacket = new ResponsePacket(payload.array());
responsePacket.setRequestId(messageId);

ResponsePacket responsePacket = new ResponsePacket(messageId, payload.array());
return responsePacket;

}
Expand Down
Expand Up @@ -205,10 +205,11 @@ public Future<ResponseMessage> request(byte[] payload) {
throw new IllegalStateException("Request fail. Error: Illegal State. pinpointServer:" + toString());
}

RequestPacket requestPacket = new RequestPacket(payload);
ChannelWriteFailListenableFuture<ResponseMessage> messageFuture = this.requestManager.register(requestPacket);
write0(requestPacket, messageFuture);
return messageFuture;
final int requestId = this.requestManager.nextRequestId();
RequestPacket requestPacket = new RequestPacket(requestId, payload);
ChannelWriteFailListenableFuture<ResponseMessage> responseFuture = this.requestManager.register(requestPacket.getRequestId());
write0(requestPacket, responseFuture);
return responseFuture;
}

@Override
Expand Down
Expand Up @@ -20,9 +20,11 @@
import com.navercorp.pinpoint.rpc.Future;
import com.navercorp.pinpoint.rpc.TestAwaitTaskUtils;
import com.navercorp.pinpoint.rpc.TestAwaitUtils;
import com.navercorp.pinpoint.rpc.packet.RequestPacket;
import org.jboss.netty.util.HashedWheelTimer;
import org.jboss.netty.util.Timer;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -36,13 +38,27 @@ public class RequestManagerTest {

private final Logger logger = LoggerFactory.getLogger(this.getClass());

private Timer timer = getTimer();

@Before
public void setUp() throws Exception {
this.timer = getTimer();

}

@After
public void tearDown() throws Exception {
if (this.timer != null) {
this.timer.stop();
}
}

@Test
public void testRegisterRequest() throws Exception {
HashedWheelTimer timer = getTimer();
RequestManager requestManager = new RequestManager(timer, 3000);
try {
RequestPacket packet = new RequestPacket(new byte[0]);
final Future future = requestManager.register(packet, 50);
final int requestId = requestManager.nextRequestId();
final Future future = requestManager.register(requestId, 50);

TestAwaitUtils.await(new TestAwaitTaskUtils() {
@Override
Expand All @@ -57,25 +73,23 @@ public boolean checkCompleted() {
logger.debug(future.getCause().getMessage());
} finally {
requestManager.close();
timer.stop();
}
}


@Test
public void testRemoveMessageFuture() throws Exception {
HashedWheelTimer timer = getTimer();
RequestManager requestManager = new RequestManager(timer, 3000);
try {
RequestPacket packet = new RequestPacket(1, new byte[0]);
DefaultFuture future = requestManager.register(packet, 2000);
int requestId = requestManager.nextRequestId();

DefaultFuture future = requestManager.register(requestId, 2000);
future.setFailure(new RuntimeException());

Future nullFuture = requestManager.removeMessageFuture(packet.getRequestId());
Future nullFuture = requestManager.removeMessageFuture(requestId);
Assert.assertNull(nullFuture);
} finally {
requestManager.close();
timer.stop();
}

}
Expand All @@ -84,12 +98,7 @@ private HashedWheelTimer getTimer() {
return new HashedWheelTimer(10, TimeUnit.MICROSECONDS);
}

// @Test
public void testTimerStartTiming() throws InterruptedException {
HashedWheelTimer timer = new HashedWheelTimer(1000, TimeUnit.MILLISECONDS);
timer.start();
timer.stop();
}


@Test
public void testClose() throws Exception {
Expand Down
Expand Up @@ -179,8 +179,7 @@ private void sendRegisterPacket(OutputStream outputStream, Map<String, Object> p
}

private void sendSimpleRequestPacket(OutputStream outputStream) throws ProtocolException, IOException {
RequestPacket packet = new RequestPacket(new byte[0]);
packet.setRequestId(10);
RequestPacket packet = new RequestPacket(10, new byte[0]);

ByteBuffer bb = packet.toBuffer().toByteBuffer(0, packet.toBuffer().writerIndex());
IOUtils.write(outputStream, bb.array());
Expand Down
Expand Up @@ -123,8 +123,7 @@ private void sendRegisterPacket(OutputStream outputStream, Map<String, Object> p
}

private void sendSimpleRequestPacket(OutputStream outputStream) throws ProtocolException, IOException {
RequestPacket packet = new RequestPacket(new byte[0]);
packet.setRequestId(10);
RequestPacket packet = new RequestPacket(10, new byte[0]);

ByteBuffer bb = packet.toBuffer().toByteBuffer(0, packet.toBuffer().writerIndex());
IOUtils.write(outputStream, bb.array());
Expand Down

0 comments on commit 4094cbc

Please sign in to comment.