Skip to content

Commit

Permalink
增加压测客户端
Browse files Browse the repository at this point in the history
  • Loading branch information
夜色 committed Dec 8, 2016
1 parent 79f9658 commit a444897
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 46 deletions.
Expand Up @@ -95,4 +95,11 @@ public void setUserId(String userId) {
this.userId = userId;
}

@Override
public String toString() {
return "{" +
"deviceId='" + deviceId + '\'' +
", userId='" + userId + '\'' +
'}';
}
}
Expand Up @@ -28,6 +28,7 @@
import com.mpush.api.protocol.Packet;
import com.mpush.cache.redis.RedisKey;
import com.mpush.cache.redis.manager.RedisManager;
import com.mpush.common.ErrorCode;
import com.mpush.common.message.*;
import com.mpush.common.security.AesCipher;
import com.mpush.common.security.CipherBox;
Expand All @@ -36,10 +37,7 @@
import com.mpush.tools.thread.NamedPoolThreadFactory;
import com.mpush.tools.thread.ThreadNames;
import io.netty.channel.*;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.Timer;
import io.netty.util.TimerTask;
import io.netty.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -56,9 +54,16 @@
public final class ConnClientChannelHandler extends ChannelInboundHandlerAdapter {
private static final Logger LOGGER = LoggerFactory.getLogger(ConnClientChannelHandler.class);
private static final Timer HASHED_WHEEL_TIMER = new HashedWheelTimer(new NamedPoolThreadFactory(ThreadNames.T_CONN_TIMER));
public static final AttributeKey<ClientConfig> CONFIG_KEY = AttributeKey.newInstance("clientConfig");
public static final TestStatistics STATISTICS = new TestStatistics();

private final Connection connection = new NettyConnection();
private final ClientConfig clientConfig;
private ClientConfig clientConfig;
private boolean stressingTest;

public ConnClientChannelHandler() {
stressingTest = true;
}

public ConnClientChannelHandler(ClientConfig clientConfig) {
this.clientConfig = clientConfig;
Expand All @@ -75,15 +80,19 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
Packet packet = (Packet) msg;
Command command = Command.toCMD(packet.cmd);
if (command == Command.HANDSHAKE) {
int connectedNum = STATISTICS.connectedNum.incrementAndGet();
connection.getSessionContext().changeCipher(new AesCipher(clientConfig.getClientKey(), clientConfig.getIv()));
HandshakeOkMessage message = new HandshakeOkMessage(packet, connection);
byte[] sessionKey = CipherBox.I.mixKey(clientConfig.getClientKey(), message.serverKey);
connection.getSessionContext().changeCipher(new AesCipher(sessionKey, clientConfig.getIv()));
startHeartBeat(message.heartbeat);
LOGGER.warn(">>> handshake success, message={}, sessionKey={}", message, sessionKey);
LOGGER.info(">>> handshake success, clientConfig={}, connectedNum={}", clientConfig, connectedNum);
bindUser(clientConfig);
saveToRedisForFastConnection(clientConfig, message.sessionId, message.expireTime, sessionKey);
if (!stressingTest) {
saveToRedisForFastConnection(clientConfig, message.sessionId, message.expireTime, sessionKey);
}
} else if (command == Command.FAST_CONNECT) {
int connectedNum = STATISTICS.connectedNum.incrementAndGet();
String cipherStr = clientConfig.getCipher();
String[] cs = cipherStr.split(",");
byte[] key = AesCipher.toArray(cs[0]);
Expand All @@ -93,39 +102,39 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
FastConnectOkMessage message = new FastConnectOkMessage(packet, connection);
startHeartBeat(message.heartbeat);
bindUser(clientConfig);
LOGGER.warn(">>> fast connect success, message=" + message);
LOGGER.info(">>> fast connect success, clientConfig={}, connectedNum={}", clientConfig, connectedNum);
} else if (command == Command.KICK) {
KickUserMessage message = new KickUserMessage(packet, connection);
LOGGER.error(">>> receive kick user userId={}, deviceId={}, message={},", clientConfig.getUserId(), clientConfig.getDeviceId(), message);
ctx.close();
} else if (command == Command.ERROR) {
ErrorMessage errorMessage = new ErrorMessage(packet, connection);
LOGGER.error(">>> receive an error packet=" + errorMessage);
} else if (command == Command.BIND) {

} else if (command == Command.PUSH) {
int receivePushNum = STATISTICS.receivePushNum.incrementAndGet();

PushMessage message = new PushMessage(packet, connection);
LOGGER.warn(">>> receive an push message, content=" + new String(message.content, Constants.UTF_8));
LOGGER.info(">>> receive an push message, content={}, receivePushNum={}", new String(message.content, Constants.UTF_8), receivePushNum);

if (message.needAck()) {
AckMessage.from(message).sendRaw();
LOGGER.warn(">>> send ack success for sessionId={}", message.getSessionId());
LOGGER.info(">>> send ack success for sessionId={}", message.getSessionId());
}

} else if (command == Command.HEARTBEAT) {
LOGGER.warn(">>> receive a heartbeat pong...");
LOGGER.info(">>> receive a heartbeat pong...");
} else if (command == Command.OK) {
OkMessage okMessage = new OkMessage(packet, connection);
LOGGER.warn(">>> receive an success packet=" + okMessage);
Map<String, String> headers = new HashMap<>();
headers.put(Constants.HTTP_HEAD_READ_TIMEOUT, "10000");
HttpRequestMessage message = new HttpRequestMessage(connection);
message.headers = headers;
message.uri = "http://baidu.com";
message.send();
int bindUserNum = STATISTICS.bindUserNum.get();
if (okMessage.cmd == Command.BIND.cmd) {
bindUserNum = STATISTICS.bindUserNum.incrementAndGet();
}

LOGGER.info(">>> receive an success message={}, bindUserNum={}", okMessage, bindUserNum);

} else if (command == Command.HTTP_PROXY) {
HttpResponseMessage message = new HttpResponseMessage(packet, connection);
LOGGER.warn(">>> receive a http response, message={}, body={}",
LOGGER.info(">>> receive a http response, message={}, body={}",
message, message.body == null ? null : new String(message.body, Constants.UTF_8));
}
}
Expand All @@ -142,36 +151,45 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
LOGGER.info("client connect channel={}", ctx.channel());
int clientNum = STATISTICS.clientNum.incrementAndGet();
LOGGER.info("client connect channel={}, clientNum={}", ctx.channel(), clientNum);
if (clientConfig == null) {
clientConfig = ctx.channel().attr(CONFIG_KEY).getAndRemove();
}
connection.init(ctx.channel(), true);
tryFastConnect();
if (stressingTest) {
handshake();
} else {
tryFastConnect();
}
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
int clientNum = STATISTICS.clientNum.decrementAndGet();
connection.close();
EventBus.I.post(new ConnectionCloseEvent(connection));
LOGGER.info("client disconnect connection={}", connection);
LOGGER.info("client disconnect channel={}, clientNum={}", connection, clientNum);
}

private void tryFastConnect() {

Map<String, String> sessionTickets = getFastConnectionInfo(clientConfig.getDeviceId());

if (sessionTickets == null) {
handshake(clientConfig);
handshake();
return;
}
String sessionId = sessionTickets.get("sessionId");
if (sessionId == null) {
handshake(clientConfig);
handshake();
return;
}
String expireTime = sessionTickets.get("expireTime");
if (expireTime != null) {
long exp = Long.parseLong(expireTime);
if (exp < System.currentTimeMillis()) {
handshake(clientConfig);
handshake();
return;
}
}
Expand All @@ -186,7 +204,7 @@ private void tryFastConnect() {
if (channelFuture.isSuccess()) {
clientConfig.setCipher(cipher);
} else {
handshake(clientConfig);
handshake();
}
});
LOGGER.debug("<<< send fast connect message={}", message);
Expand Down Expand Up @@ -215,14 +233,14 @@ private Map<String, String> getFastConnectionInfo(String deviceId) {
return RedisManager.I.get(key, Map.class);
}

private void handshake(ClientConfig client) {
private void handshake() {
HandshakeMessage message = new HandshakeMessage(connection);
message.clientKey = client.getClientKey();
message.iv = client.getIv();
message.clientVersion = client.getClientVersion();
message.deviceId = client.getDeviceId();
message.osName = client.getOsName();
message.osVersion = client.getOsVersion();
message.clientKey = clientConfig.getClientKey();
message.iv = clientConfig.getIv();
message.clientVersion = clientConfig.getClientVersion();
message.deviceId = clientConfig.getDeviceId();
message.osName = clientConfig.getOsName();
message.osVersion = clientConfig.getOsVersion();
message.timestamp = System.currentTimeMillis();
message.send();
LOGGER.debug("<<< send handshake message={}", message);
Expand Down
@@ -0,0 +1,44 @@
/*
* (C) Copyright 2015-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributors:
* ohun@live.cn (夜色)
*/

package com.mpush.client.connect;

import java.util.concurrent.atomic.AtomicInteger;

/**
* Created by ohun on 2016/12/8.
*
* @author ohun@live.cn (夜色)
*/
public final class TestStatistics {
public AtomicInteger clientNum = new AtomicInteger();
public AtomicInteger connectedNum = new AtomicInteger();
public AtomicInteger bindUserNum = new AtomicInteger();
public AtomicInteger receivePushNum = new AtomicInteger();

@Override
public String toString() {
return "TestStatistics{" +
"clientNum=" + clientNum +
", connectedNum=" + connectedNum +
", bindUserNum=" + bindUserNum +
", receivePushNum=" + receivePushNum +
'}';
}
}
70 changes: 66 additions & 4 deletions mpush-test/src/test/java/com/mpush/test/client/ConnClientBoot.java
Expand Up @@ -22,20 +22,36 @@
import com.google.common.collect.Lists;
import com.mpush.api.service.BaseService;
import com.mpush.api.service.Listener;
import com.mpush.api.service.ServiceException;
import com.mpush.cache.redis.manager.RedisManager;
import com.mpush.client.connect.ClientConfig;
import com.mpush.client.connect.ConnClientChannelHandler;
import com.mpush.netty.codec.PacketDecoder;
import com.mpush.netty.codec.PacketEncoder;
import com.mpush.zk.ZKClient;
import com.mpush.zk.listener.ZKServerNodeWatcher;
import com.mpush.zk.node.ZKServerNode;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.AttributeKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.List;

public final class ConnClientBoot extends BaseService {
private static final Logger LOGGER = LoggerFactory.getLogger(ConnClientBoot.class);

private final ZKServerNodeWatcher watcher = ZKServerNodeWatcher.buildConnect();
private Bootstrap bootstrap;
private NioEventLoopGroup workerGroup;

public List<ZKServerNode> getServers() {
return Lists.newArrayList(watcher.getCache().values());
}

@Override
protected void doStart(Listener listener) throws Throwable {
Expand All @@ -52,12 +68,58 @@ public void onFailure(Throwable cause) {
listener.onFailure(cause);
}
});

this.workerGroup = new NioEventLoopGroup();
this.bootstrap = new Bootstrap();
bootstrap.group(workerGroup)//
.option(ChannelOption.TCP_NODELAY, true)//
.option(ChannelOption.SO_REUSEADDR, true)//
.option(ChannelOption.SO_KEEPALIVE, true)//
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)//
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4000)
.channel(NioSocketChannel.class);

bootstrap.handler(new ChannelInitializer<SocketChannel>() { // (4)
@Override
public void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast("decoder", new PacketDecoder());
ch.pipeline().addLast("encoder", PacketEncoder.INSTANCE);
ch.pipeline().addLast("handler", new ConnClientChannelHandler());
}
});
}

@Override
protected void doStop(Listener listener) throws Throwable {
if (workerGroup != null) workerGroup.shutdownGracefully();
ZKClient.I.syncStop();
RedisManager.I.destroy();
listener.onSuccess();
}

public List<ZKServerNode> getServers() {
return Lists.newArrayList(watcher.getCache().values());
}


public void connect(String host, int port, ClientConfig clientConfig) {
ChannelFuture future = bootstrap.connect(new InetSocketAddress(host, port));
future.channel().attr(ConnClientChannelHandler.CONFIG_KEY).set(clientConfig);
future.addListener(f -> {
if (f.isSuccess()) {
LOGGER.info("start netty client success, host={}, port={}", host, port);
} else {
LOGGER.error("start netty client failure, host={}, port={}", host, port, f.cause());
}
});
future.syncUninterruptibly();
}

public Bootstrap getBootstrap() {
return bootstrap;
}

public NioEventLoopGroup getWorkerGroup() {
return workerGroup;
}
}

0 comments on commit a444897

Please sign in to comment.