Skip to content

Commit

Permalink
[#noissue] Refactor WebSocketHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
emeroad committed Nov 3, 2017
1 parent 6dad8a2 commit aea766b
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 194 deletions.
Expand Up @@ -17,15 +17,16 @@
package com.navercorp.pinpoint.web.config;


import com.navercorp.pinpoint.common.util.ArrayUtils;
import com.navercorp.pinpoint.web.websocket.PinpointWebSocketHandler;
import com.navercorp.pinpoint.web.websocket.PinpointWebSocketHandlerManager;
import com.navercorp.pinpoint.web.websocket.WebSocketSessionContextPrepareHandshakeInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistration;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

Expand All @@ -49,28 +50,23 @@ public class WebSocketConfig implements WebSocketConfigurer {

@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
String[] allowedOriginArray = getAllowedOriginArray(configProperties.getWebSocketAllowedOrigins());
final String[] allowedOriginArray = getAllowedOriginArray(configProperties.getWebSocketAllowedOrigins());

for (PinpointWebSocketHandler handler : handlerRepository.getWebSocketHandlerRepository()) {
registry.addHandler(handler, handler.getRequestMapping() + WEBSOCKET_SUFFIX).addInterceptors(new HttpSessionHandshakeInterceptor()).setAllowedOrigins(allowedOriginArray);
String path = handler.getRequestMapping() + WEBSOCKET_SUFFIX;
final WebSocketHandlerRegistration webSocketHandlerRegistration = registry.addHandler(handler, path);

webSocketHandlerRegistration.addInterceptors(new HttpSessionHandshakeInterceptor());
webSocketHandlerRegistration.addInterceptors(new WebSocketSessionContextPrepareHandshakeInterceptor());
webSocketHandlerRegistration.setAllowedOrigins(allowedOriginArray);
}
}

private String[] getAllowedOriginArray(String allowedOrigins) {
if (!StringUtils.hasText(allowedOrigins)) {
return DEFAULT_ALLOWED_ORIGIN;
}

String[] splitString = StringUtils.split(allowedOrigins, ",");
if (ArrayUtils.isEmpty(splitString)) {
return new String[]{StringUtils.trimAllWhitespace(allowedOrigins)};
} else {
String[] result = new String[splitString.length];
for (int i = 0; i < splitString.length; i++) {
result[i] = StringUtils.trimAllWhitespace(splitString[i]);
}
return result;
}
return StringUtils.tokenizeToStringArray(allowedOrigins, ",");
}

}
Expand Up @@ -41,6 +41,7 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -51,7 +52,6 @@
public class ActiveThreadCountHandler extends TextWebSocketHandler implements PinpointWebSocketHandler {

public static final String APPLICATION_NAME_KEY = "applicationName";
private static final String HEALTH_CHECK_WAIT_KEY = "pinpoint.healthCheck.wait";

static final String API_ACTIVE_THREAD_COUNT = "activeThreadCount";

Expand Down Expand Up @@ -107,9 +107,13 @@ public void start() {
PinpointThreadFactory flushThreadFactory = new PinpointThreadFactory(ClassUtils.simpleClassName(this) + "-Flush-Thread", true);
webSocketFlushExecutor = new SimpleOrderedThreadPool(CpuUtils.cpuCount(), 65535, flushThreadFactory);

flushTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-Flush-Timer", true);
healthCheckTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-HealthCheck-Timer", true);
reactiveTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-Reactive-Timer", true);
flushTimer = newJavaTimer(ClassUtils.simpleClassName(this) + "-Flush-Timer");
healthCheckTimer = newJavaTimer(ClassUtils.simpleClassName(this) + "-HealthCheck-Timer");
reactiveTimer = newJavaTimer(ClassUtils.simpleClassName(this) + "-Reactive-Timer");
}

public Timer newJavaTimer(String timerName) {
return new java.util.Timer(timerName, true);
}

@Override
Expand Down Expand Up @@ -143,12 +147,19 @@ public String getRequestMapping() {
return requestMapping;
}

private WebSocketSessionContext getSessionContext(WebSocketSession webSocketSession) {
final WebSocketSessionContext sessionContext = WebSocketSessionContext.getSessionContext(webSocketSession);
if (sessionContext == null) {
throw new IllegalStateException("WebSocketSessionContext not initialized");
}
return sessionContext;
}

@Override
public void afterConnectionEstablished(WebSocketSession newSession) throws Exception {
logger.info("ConnectionEstablished. session:{}", newSession);

synchronized (lock) {
newSession.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(false));
sessionRepository.add(newSession);
boolean turnOn = onTimerTask.compareAndSet(false, true);
if (turnOn) {
Expand All @@ -164,8 +175,9 @@ public void afterConnectionEstablished(WebSocketSession newSession) throws Excep
public void afterConnectionClosed(WebSocketSession closeSession, CloseStatus status) throws Exception {
logger.info("ConnectionClose. session:{}, caused:{}", closeSession, status);

final WebSocketSessionContext sessionContext = getSessionContext(closeSession);
synchronized (lock) {
unbindingResponseAggregator(closeSession);
unbindingResponseAggregator(closeSession, sessionContext);

sessionRepository.remove(closeSession);
if (sessionRepository.isEmpty()) {
Expand Down Expand Up @@ -203,27 +215,33 @@ private void handleRequestMessage0(WebSocketSession webSocketSession, RequestMes
return;
}

String command = requestMessage.getCommand();

final String command = requestMessage.getCommand();
if (API_ACTIVE_THREAD_COUNT.equals(command)) {
String applicationName = MapUtils.getString(requestMessage.getParams(), APPLICATION_NAME_KEY);
if (applicationName != null) {
synchronized (lock) {
if (StringUtils.equals(applicationName, (String) webSocketSession.getAttributes().get(APPLICATION_NAME_KEY))) {
return;
}
handleActiveThreadCount(webSocketSession, requestMessage);
} else {
logger.debug("unknown command:{}", command);
}
}

unbindingResponseAggregator(webSocketSession);
if (webSocketSession.isOpen()) {
bindingResponseAggregator(webSocketSession, applicationName);
} else {
logger.warn("WebSocketSession is not opened. skip binding.");
}
private void handleActiveThreadCount(WebSocketSession webSocketSession, RequestMessage requestMessage) {
final String applicationName = MapUtils.getString(requestMessage.getParameters(), APPLICATION_NAME_KEY);
if (applicationName != null) {
final WebSocketSessionContext sessionContext = getSessionContext(webSocketSession);
synchronized (lock) {
if (StringUtils.equals(applicationName, sessionContext.getApplicationName())) {
return;
}

unbindingResponseAggregator(webSocketSession, sessionContext);
if (webSocketSession.isOpen()) {
bindingResponseAggregator(webSocketSession, sessionContext, applicationName);
} else {
logger.warn("WebSocketSession is not opened. skip binding.");
}
}
}
}

private void closeSession(WebSocketSession session, CloseStatus status) {
try {
session.close(status);
Expand All @@ -233,10 +251,8 @@ private void closeSession(WebSocketSession session, CloseStatus status) {
}

private void handlePongMessage0(WebSocketSession webSocketSession, PongMessage pongMessage) {
Object healthCheckWait = webSocketSession.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
if (healthCheckWait instanceof AtomicBoolean) {
((AtomicBoolean) healthCheckWait).compareAndSet(true, false);
}
final WebSocketSessionContext sessionContext = getSessionContext(webSocketSession);
sessionContext.changeHealthCheckSuccess();
}

@Override
Expand All @@ -246,10 +262,10 @@ protected void handlePongMessage(WebSocketSession webSocketSession, org.springfr
super.handlePongMessage(webSocketSession, message);
}

private void bindingResponseAggregator(WebSocketSession webSocketSession, String applicationName) {
private void bindingResponseAggregator(WebSocketSession webSocketSession, WebSocketSessionContext webSocketSessionContext, String applicationName) {
logger.info("bindingResponseAggregator. session:{}, applicationName:{}.", webSocketSession, applicationName);

webSocketSession.getAttributes().put(APPLICATION_NAME_KEY, applicationName);
webSocketSessionContext.setApplicationName(applicationName);
if (StringUtils.isEmpty(applicationName)) {
return;
}
Expand All @@ -264,8 +280,9 @@ private void bindingResponseAggregator(WebSocketSession webSocketSession, String
responseAggregator.addWebSocketSession(webSocketSession);
}

private void unbindingResponseAggregator(WebSocketSession webSocketSession) {
String applicationName = (String) webSocketSession.getAttributes().get(APPLICATION_NAME_KEY);
private void unbindingResponseAggregator(WebSocketSession webSocketSession, WebSocketSessionContext sessionContext) {

final String applicationName = sessionContext.getApplicationName();
logger.info("unbindingResponseAggregator. session:{}, applicationName:{}.", webSocketSession, applicationName);
if (StringUtils.isEmpty(applicationName)) {
return;
Expand Down Expand Up @@ -343,38 +360,19 @@ public void run() {
logger.info("HealthCheckTimerTask started.");

// check session state.
List<WebSocketSession> webSocketSessionList = new ArrayList<>(sessionRepository);
for (WebSocketSession session : webSocketSessionList) {
if (!session.isOpen()) {
continue;
}

Object untilWait = session.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
if (untilWait instanceof AtomicBoolean) {
if (((AtomicBoolean) untilWait).get()) {
closeSession(session, CloseStatus.SESSION_NOT_RELIABLE);
}
} else {
session.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(false));
}
}
List<WebSocketSession> snapshot = filterHealthCheckSuccess(sessionRepository);

// send healthCheck packet
String pingTextMessage = messageConverter.getPingTextMessage();
TextMessage pingMessage = new TextMessage(pingTextMessage);

webSocketSessionList = new ArrayList<>(sessionRepository);
for (WebSocketSession session : webSocketSessionList) {
for (WebSocketSession session : snapshot) {
if (!session.isOpen()) {
continue;
}

Object untilWait = session.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
if (untilWait instanceof AtomicBoolean) {
((AtomicBoolean) untilWait).compareAndSet(false, true);
} else {
session.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(true));
}
// reset healthCheck state
final WebSocketSessionContext sessionContext = getSessionContext(session);
sessionContext.changeHealthCheckFail();

sendPingMessage(session, pingMessage);
}
Expand All @@ -385,6 +383,25 @@ public void run() {
}
}

private List<WebSocketSession> filterHealthCheckSuccess(List<WebSocketSession> sessionRepository) {
List<WebSocketSession> snapshot = new ArrayList<>(sessionRepository.size());

for (WebSocketSession session : sessionRepository) {
if (!session.isOpen()) {
continue;
}

final WebSocketSessionContext sessionContext = getSessionContext(session);
if (!sessionContext.getHealthCheckState()) {
// health check fail
closeSession(session, CloseStatus.SESSION_NOT_RELIABLE);
} else {
snapshot.add(session);
}
}
return snapshot;
}


private void sendPingMessage(WebSocketSession session, TextMessage pingMessage) {
try {
Expand Down
Expand Up @@ -248,7 +248,8 @@ public void flush(Executor executor) throws Exception {
private TextMessage createWebSocketTextMessage(AgentActiveThreadCountList activeThreadCountList) {
Map resultMap = createResultMap(activeThreadCountList, System.currentTimeMillis());
try {
TextMessage responseTextMessage = new TextMessage(messageConverter.getResponseTextMessage(ActiveThreadCountHandler.API_ACTIVE_THREAD_COUNT, resultMap));
String response = messageConverter.getResponseTextMessage(ActiveThreadCountHandler.API_ACTIVE_THREAD_COUNT, resultMap);
TextMessage responseTextMessage = new TextMessage(response);
return responseTextMessage;
} catch (JsonProcessingException e) {
logger.warn("failed while to convert message. applicationName:{}, original:{}, message:{}.", applicationName, resultMap, e.getMessage(), e);
Expand Down
Expand Up @@ -16,10 +16,10 @@

package com.navercorp.pinpoint.web.websocket;

import com.google.common.collect.ImmutableList;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
Expand All @@ -30,7 +30,7 @@ public class PinpointWebSocketHandlerManager {
private final List<PinpointWebSocketHandler> webSocketHandlerRepository;

public PinpointWebSocketHandlerManager(List<PinpointWebSocketHandler> pinpointWebSocketHandlers) {
webSocketHandlerRepository = Collections.unmodifiableList(new ArrayList<>(pinpointWebSocketHandlers));
webSocketHandlerRepository = ImmutableList.copyOf(pinpointWebSocketHandlers);
}

@PostConstruct
Expand All @@ -48,7 +48,7 @@ public void tearDown() {
}

public List<PinpointWebSocketHandler> getWebSocketHandlerRepository() {
return new ArrayList<>(webSocketHandlerRepository);
return webSocketHandlerRepository;
}

}
@@ -0,0 +1,72 @@
/*
* Copyright 2017 NAVER Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.navercorp.pinpoint.web.websocket;

import org.springframework.web.socket.WebSocketSession;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* @author Woonduk Kang(emeroad)
*/
public class WebSocketSessionContext {

public static final String WEBSOCKET_SESSION_CONTEXT_KEY = "pinpoint.websocket.session.context.key";

private final AtomicBoolean healthCheckSuccess;
private String applicationName;

static WebSocketSessionContext getSessionContext(WebSocketSession webSocketSession) {
final Object context = webSocketSession.getAttributes().get(WEBSOCKET_SESSION_CONTEXT_KEY);
if (context instanceof WebSocketSessionContext) {
return (WebSocketSessionContext) context;
}
return null;
}

public WebSocketSessionContext() {
this.healthCheckSuccess = new AtomicBoolean(true);
}

public boolean changeHealthCheckSuccess() {
return healthCheckSuccess.compareAndSet(false, true);
}

public boolean changeHealthCheckFail() {
return healthCheckSuccess.compareAndSet(true, false);
}

public boolean getHealthCheckState() {
return healthCheckSuccess.get();
}

public String getApplicationName() {
return applicationName;
}

public void setApplicationName(String applicationName) {
this.applicationName = applicationName;
}

@Override
public String toString() {
return "WebSocketSessionContext{" +
"changeHealthCheckSuccess=" + healthCheckSuccess +
", applicationName='" + applicationName + '\'' +
'}';
}
}

0 comments on commit aea766b

Please sign in to comment.