Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARTEMIS-2226 last consumer should close the previous consumer #1

Open
wants to merge 2 commits into
base: ARTEMIS-2226
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ public final class ManagementHelper {

public static final SimpleString HDR_MESSAGE_ID = new SimpleString("_AMQ_Message_ID");

public static final SimpleString HDR_PROTOCOL_NAME = new SimpleString("_AMQ_Protocol_Name");

public static final SimpleString HDR_CLIENT_ID = new SimpleString("_AMQ_Client_ID");

// Attributes ----------------------------------------------------

// Static --------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,7 @@ synchronized void disconnect(boolean failure) {

private MQTTSessionState getSessionState(String clientId) {
/* [MQTT-3.1.2-4] Attach an existing session if one exists otherwise create a new one. */
MQTTSessionState state = MQTTSession.SESSIONS.get(clientId);
if (state == null) {
state = new MQTTSessionState(clientId);
MQTTSession.SESSIONS.put(clientId, state);
}

return state;
return session.getProtocolManager().getSessionState(clientId);
}

private String validateClientId(String clientId, boolean cleanSession) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -30,8 +32,14 @@
import io.netty.handler.codec.mqtt.MqttMessage;
import org.apache.activemq.artemis.api.core.ActiveMQBuffer;
import org.apache.activemq.artemis.api.core.BaseInterceptor;
import org.apache.activemq.artemis.api.core.SimpleString;
import org.apache.activemq.artemis.api.core.management.CoreNotificationType;
import org.apache.activemq.artemis.api.core.management.ManagementHelper;
import org.apache.activemq.artemis.core.postoffice.Binding;
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyServerConnection;
import org.apache.activemq.artemis.core.server.ActiveMQServer;
import org.apache.activemq.artemis.core.server.*;
import org.apache.activemq.artemis.core.server.Queue;
import org.apache.activemq.artemis.core.server.impl.ServerConsumerImpl;
import org.apache.activemq.artemis.core.server.management.Notification;
import org.apache.activemq.artemis.core.server.management.NotificationListener;
import org.apache.activemq.artemis.spi.core.protocol.AbstractProtocolManager;
Expand All @@ -40,11 +48,14 @@
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.apache.activemq.artemis.spi.core.remoting.Acceptor;
import org.apache.activemq.artemis.spi.core.remoting.Connection;
import org.apache.activemq.artemis.utils.collections.TypedProperties;

import static org.apache.activemq.artemis.api.core.management.CoreNotificationType.CONSUMER_CREATED;

/**
* MQTTProtocolManager
*/
class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInterceptor, MQTTConnection> implements NotificationListener {
public class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInterceptor, MQTTConnection> implements NotificationListener {

private static final List<String> websocketRegistryNames = Arrays.asList("mqtt", "mqttv3.1");

Expand All @@ -55,18 +66,54 @@ class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInter
private final List<MQTTInterceptor> outgoingInterceptors = new ArrayList<>();

//TODO Read in a list of existing client IDs from stored Sessions.
private Map<String, MQTTConnection> connectedClients = new ConcurrentHashMap<>();
private final Map<String, MQTTConnection> connectedClients = new ConcurrentHashMap<>();
private final Map<String, MQTTSessionState> sessionStates = new ConcurrentHashMap<>();

MQTTProtocolManager(ActiveMQServer server,
List<BaseInterceptor> incomingInterceptors,
List<BaseInterceptor> outgoingInterceptors) {
this.server = server;
this.updateInterceptors(incomingInterceptors, outgoingInterceptors);
server.getManagementService().addNotificationListener(this);
}

@Override
public void onNotification(Notification notification) {
// TODO handle notifications
if (!(notification.getType() instanceof CoreNotificationType))
return;

CoreNotificationType type = (CoreNotificationType) notification.getType();
if (type != CONSUMER_CREATED)
return;

TypedProperties props = notification.getProperties();

SimpleString protocolName = props.getSimpleStringProperty(ManagementHelper.HDR_PROTOCOL_NAME);

if (protocolName == null || !protocolName.toString().equals(MQTTProtocolManagerFactory.MQTT_PROTOCOL_NAME))
return;

int distance = props.getIntProperty(ManagementHelper.HDR_DISTANCE);

if (distance > 0) {
SimpleString queueName = props.getSimpleStringProperty(ManagementHelper.HDR_ROUTING_NAME);

Binding binding = server.getPostOffice().getBinding(queueName);
if (binding != null) {
Queue queue = (Queue) binding.getBindable();
String clientId = props.getSimpleStringProperty(ManagementHelper.HDR_CLIENT_ID).toString();
//If the client ID represents a client already connected to the server then the server MUST disconnect the existing client.
//Avoid consumers with the same client ID in the cluster appearing at different nodes at the same time
Collection<Consumer> consumersSet = queue.getConsumers((c) -> (c instanceof ServerConsumer) && clientId.equals(((ServerConsumer) c).getConnectionClientID()));
for (Consumer consumer : consumersSet) {
try {
((ServerConsumer) consumer).close(false);
} catch (Exception e) {
log.error(e);
}
}
}
}
}

@Override
Expand Down Expand Up @@ -195,4 +242,17 @@ public void removeConnectedClient(String clientId) {
public MQTTConnection addConnectedClient(String clientId, MQTTConnection connection) {
return connectedClients.put(clientId, connection);
}

public MQTTSessionState getSessionState(String clientId) {
/* [MQTT-3.1.2-4] Attach an existing session if one exists otherwise create a new one. */
return sessionStates.computeIfAbsent(clientId, MQTTSessionState::new);
}

public MQTTSessionState removeSessionState(String clientId) {
return sessionStates.remove(clientId);
}

public Map<String, MQTTSessionState> getSessionStates() {
return new HashMap<>(sessionStates);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@

public class MQTTSession {

static Map<String, MQTTSessionState> SESSIONS = new ConcurrentHashMap<>();

private final String id = UUID.randomUUID().toString();

private final String identity;

private MQTTProtocolHandler protocolHandler;

private MQTTSubscriptionManager subscriptionManager;
Expand Down Expand Up @@ -72,6 +73,8 @@ public MQTTSession(MQTTProtocolHandler protocolHandler,
this.protocolManager = protocolManager;
this.wildcardConfiguration = wildcardConfiguration;

identity = protocolHandler.getServer().getIdentity();

this.connection = connection;

mqttConnectionManager = new MQTTConnectionManager(this);
Expand Down Expand Up @@ -108,7 +111,7 @@ synchronized void stop() throws Exception {

if (isClean()) {
clean();
SESSIONS.remove(connection.getClientID());
protocolManager.removeSessionState(connection.getClientID());
}
}
stopped = true;
Expand Down Expand Up @@ -201,7 +204,4 @@ public CoreMessageObjectPools getCoreMessageObjectPools() {
return coreMessageObjectPools;
}

public static Map<String, MQTTSessionState> getSessions() {
return new HashMap<>(SESSIONS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.activemq.artemis.core.remoting.impl;

import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand All @@ -43,4 +44,8 @@ public void updateInterceptors(List<BaseInterceptor> incomingInterceptors,
}
}

public Map<String, ProtocolManager> getProtocolMap() {
return Collections.unmodifiableMap(protocolMap);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;

import java.util.function.Predicate;
import org.apache.activemq.artemis.api.core.ActiveMQException;
import org.apache.activemq.artemis.api.core.Message;
import org.apache.activemq.artemis.api.core.RoutingType;
Expand Down Expand Up @@ -287,6 +289,8 @@ int moveReferences(int flushLimit,

Collection<Consumer> getConsumers();

Set<Consumer> getConsumers(Predicate<Consumer> predicate);

Map<SimpleString, Consumer> getGroups();

void resetGroup(SimpleString groupID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,14 @@ private synchronized void doConsumerCreated(final ClientMessage message) throws
// Need to propagate the consumer add
TypedProperties props = new TypedProperties();

SimpleString protocolName = message.getSimpleStringProperty(ManagementHelper.HDR_PROTOCOL_NAME);
if (protocolName != null)
props.putSimpleStringProperty(ManagementHelper.HDR_PROTOCOL_NAME, protocolName);

SimpleString clientId = message.getSimpleStringProperty(ManagementHelper.HDR_CLIENT_ID);
if (clientId != null)
props.putSimpleStringProperty(ManagementHelper.HDR_CLIENT_ID, clientId);

props.putSimpleStringProperty(ManagementHelper.HDR_ADDRESS, binding.getAddress());

props.putSimpleStringProperty(ManagementHelper.HDR_CLUSTER_NAME, clusterName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

import java.util.function.Predicate;
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration;
import org.apache.activemq.artemis.api.core.ActiveMQException;
import org.apache.activemq.artemis.api.core.ActiveMQNullRefException;
Expand Down Expand Up @@ -1204,6 +1205,18 @@ public Set<Consumer> getConsumers() {
return consumersSet;
}

@Override
public Set<Consumer> getConsumers(Predicate<Consumer> predicate) {
Set<Consumer> consumersSet = new HashSet<>();
for (ConsumerHolder<? extends Consumer> consumerHolder : consumers) {
if (predicate.test(consumerHolder.consumer)) {
consumersSet.add(consumerHolder.consumer);
}
}
return consumersSet;
}


@Override
public synchronized Map<SimpleString, Consumer> getGroups() {
return new HashMap<>(groups);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.apache.activemq.artemis.core.server.management.Notification;
import org.apache.activemq.artemis.core.transaction.Transaction;
import org.apache.activemq.artemis.core.transaction.impl.TransactionImpl;
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.apache.activemq.artemis.spi.core.protocol.SessionCallback;
import org.apache.activemq.artemis.spi.core.remoting.ReadyListener;
import org.apache.activemq.artemis.utils.FutureLatch;
Expand Down Expand Up @@ -1544,4 +1545,8 @@ public String getConnectionLocalAddress() {
public String getConnectionRemoteAddress() {
return this.session.getRemotingConnection().getTransportConnection().getRemoteAddress();
}

public RemotingConnection getRemotingConnection() {
return this.session.getRemotingConnection();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,16 @@ public ServerConsumer createConsumer(final long consumerID,
props.putSimpleStringProperty(ManagementHelper.HDR_FILTERSTRING, filterString);
}

String protocolName = remotingConnection.getProtocolName();
if (protocolName != null) {
props.putSimpleStringProperty(ManagementHelper.HDR_PROTOCOL_NAME, SimpleString.toSimpleString(protocolName));
}

String clientId = remotingConnection.getClientID();
if (clientId != null) {
props.putSimpleStringProperty(ManagementHelper.HDR_CLIENT_ID, SimpleString.toSimpleString(clientId));
}

Notification notification = new Notification(null, CoreNotificationType.CONSUMER_CREATED, props);

if (logger.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.buffer.ByteBuf;
import java.util.function.Predicate;
import org.apache.activemq.artemis.api.core.ActiveMQBuffer;
import org.apache.activemq.artemis.api.core.ActiveMQException;
import org.apache.activemq.artemis.api.core.ActiveMQPropertyConversionException;
Expand Down Expand Up @@ -1301,6 +1302,11 @@ public Collection<Consumer> getConsumers() {
return null;
}

@Override
public Set<Consumer> getConsumers(Predicate<Consumer> predicate) {
return null;
}

@Override
public Map<SimpleString, Consumer> getGroups() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ public class MQTTTest extends MQTTTestSupport {
@Override
@Before
public void setUp() throws Exception {
Field sessions = MQTTSession.class.getDeclaredField("SESSIONS");
sessions.setAccessible(true);
sessions.set(null, new ConcurrentHashMap<>());
super.setUp();
}

Expand Down Expand Up @@ -1100,7 +1097,7 @@ public void testCleanSessionForSubscriptions() throws Exception {
notClean.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false);
notClean.disconnect();

assertEquals(1, MQTTSession.getSessions().size());
assertEquals(1, getSessions().size());

// MUST receive message from existing subscription from previous not clean session
notClean = mqttNotClean.blockingConnection();
Expand All @@ -1112,7 +1109,7 @@ public void testCleanSessionForSubscriptions() throws Exception {
notClean.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false);
notClean.disconnect();

assertEquals(1, MQTTSession.getSessions().size());
assertEquals(1, getSessions().size());

// MUST NOT receive message from previous not clean session as existing subscription should be gone
final MQTT mqttClean = createMQTTConnection(CLIENTID, true);
Expand All @@ -1124,7 +1121,7 @@ public void testCleanSessionForSubscriptions() throws Exception {
clean.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false);
clean.disconnect();

assertEquals(0, MQTTSession.getSessions().size());
assertEquals(0, getSessions().size());

// MUST NOT receive message from previous clean session as existing subscription should be gone
notClean = mqttNotClean.blockingConnection();
Expand All @@ -1133,7 +1130,7 @@ public void testCleanSessionForSubscriptions() throws Exception {
assertNull(msg);
notClean.disconnect();

assertEquals(1, MQTTSession.getSessions().size());
assertEquals(1, getSessions().size());
}

@Test(timeout = 60 * 1000)
Expand All @@ -1147,7 +1144,7 @@ public void testCleanSessionForMessages() throws Exception {
notClean.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false);
notClean.disconnect();

assertEquals(1, MQTTSession.getSessions().size());
assertEquals(1, getSessions().size());

// MUST NOT receive message from previous not clean session even when creating a new subscription
final MQTT mqttClean = createMQTTConnection(CLIENTID, true);
Expand All @@ -1159,7 +1156,7 @@ public void testCleanSessionForMessages() throws Exception {
clean.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false);
clean.disconnect();

assertEquals(0, MQTTSession.getSessions().size());
assertEquals(0, getSessions().size());
}

@Test(timeout = 60 * 1000)
Expand Down