Skip to content

Commit

Permalink
Merge pull request #240 from tristantarrant/saslmerge
Browse files Browse the repository at this point in the history
JGRP-1967 SASL: correctly handle merge requests
  • Loading branch information
belaban committed Oct 6, 2015
2 parents ddb7aa6 + e2465d9 commit cd75779
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 35 deletions.
39 changes: 29 additions & 10 deletions src/org/jgroups/protocols/SASL.java
Expand Up @@ -218,13 +218,13 @@ public Object up(Event evt) {
Message msg = (Message) evt.getArg();
SaslHeader saslHeader = (SaslHeader) msg.getHeader(SASL_ID);
GmsHeader gmsHeader = (GmsHeader) msg.getHeader(GMS_ID);
if (needsAuthentication(gmsHeader)) {
Address remoteAddress = msg.getSrc();
if (needsAuthentication(gmsHeader, remoteAddress)) {
if (saslHeader == null)
throw new IllegalStateException("Found GMS join or merge request but no SASL header");
if (!serverChallenge(gmsHeader, saslHeader, msg))
return null; // failed auth, don't pass up
} else if (saslHeader != null) {
Address remoteAddress = msg.getSrc();
SaslContext saslContext = sasl_context.get(remoteAddress);
if (saslContext == null) {
throw new IllegalStateException(String.format(
Expand All @@ -251,7 +251,7 @@ public Object up(Event evt) {
} catch (SaslException e) {
disposeContext(remoteAddress);
if (log.isWarnEnabled()) {
log.warn("failed to validate CHALLENGE from " + remoteAddress + ", token", e);
log.warn(getAddress() + ": failed to validate CHALLENGE from " + remoteAddress + ", token", e);
}
}
break;
Expand Down Expand Up @@ -299,7 +299,8 @@ public void up(MessageBatch batch) {
for (Message msg : batch) {
// If we have a join or merge request --> authenticate, else pass up
GmsHeader gmsHeader = (GmsHeader) msg.getHeader(GMS_ID);
if (needsAuthentication(gmsHeader)) {
Address remoteAddress = msg.getSrc();
if (needsAuthentication(gmsHeader, remoteAddress)) {
SaslHeader saslHeader = (SaslHeader) msg.getHeader(id);
if (saslHeader == null) {
log.warn("Found GMS join or merge request but no SASL header");
Expand All @@ -323,10 +324,11 @@ public Object down(Event evt) {
case Event.MSG:
Message msg = (Message) evt.getArg();
GmsHeader hdr = (GmsHeader) msg.getHeader(GMS_ID);
if (needsAuthentication(hdr)) {
Address remoteAddress = msg.getDest();
if (needsAuthentication(hdr, remoteAddress)) {
// We are a client who needs to authenticate
SaslClientContext ctx = null;
Address remoteAddress = msg.getDest();

try {
ctx = new SaslClientContext(saslClientFactory, mech, server_name != null ? server_name : remoteAddress.toString(), client_callback_handler, sasl_props, client_subject);
sasl_context.put(remoteAddress, ctx);
Expand All @@ -344,10 +346,27 @@ public Object down(Event evt) {
return down_prot.down(evt);
}

protected static boolean needsAuthentication(GmsHeader hdr) {
return (hdr != null)
&& (hdr.getType() == GmsHeader.JOIN_REQ || hdr.getType() == GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER || hdr
.getType() == GmsHeader.MERGE_REQ);
private boolean isSelf(Address remoteAddress) {
return remoteAddress.equals(local_addr);
}

private boolean needsAuthentication(GmsHeader hdr, Address remoteAddress) {
if (hdr != null) {
switch (hdr.getType()) {
case GMS.GmsHeader.JOIN_REQ:
case GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER:
return true;
case GMS.GmsHeader.MERGE_REQ:
return !isSelf(remoteAddress);
case GMS.GmsHeader.JOIN_RSP:
case GMS.GmsHeader.MERGE_RSP:
return false;
default:
return false;
}
} else {
return false;
}
}

protected boolean serverChallenge(GmsHeader gmsHeader, SaslHeader saslHeader, Message msg) {
Expand Down
144 changes: 119 additions & 25 deletions tests/junit-functional/org/jgroups/protocols/SASLTest.java
@@ -1,44 +1,52 @@
package org.jgroups.protocols;

import static org.testng.AssertJUnit.assertTrue;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;

import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Global;
import org.jgroups.JChannel;
import org.jgroups.Membership;
import org.jgroups.View;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.NAKACK2;
import org.jgroups.protocols.pbcast.STABLE;
import org.jgroups.stack.Protocol;
import org.jgroups.stack.ProtocolStack;
import org.jgroups.util.Util;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.Test;

import javax.security.auth.callback.*;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import java.io.IOException;

import static org.testng.AssertJUnit.assertTrue;

@Test(groups = Global.FUNCTIONAL, singleThreaded = true)
public class SASLTest {
private static final String REALM = "MyRealm";
private JChannel a;
private JChannel b;

private static JChannel createChannel(String channelName,String mech,String username) throws Exception {
private static JChannel createChannel(String channelName, String mech, String username) throws Exception {
SASL sasl = new SASL();
sasl.setMech(mech);
sasl.setClientCallbackHandler(new MyCallbackHandler(username));
sasl.setServerCallbackHandler(new MyCallbackHandler(username));
sasl.setTimeout(5000);
sasl.sasl_props.put("com.sun.security.sasl.digest.realm", REALM);
return new JChannel(
new Protocol[] {
new SHARED_LOOPBACK(),
new PING(),
new NAKACK2(),
new UNICAST3(),
new STABLE(),
sasl,
new GMS() }
).name(channelName);
sasl.setLevel("trace");
GMS gms = new GMS();
gms.setJoinTimeout(3000);
return new JChannel(new Protocol[] { new SHARED_LOOPBACK(), new PING(), new MERGE3(), new NAKACK2(),
new UNICAST3(), new STABLE(), sasl, gms }).name(channelName);
}

public void testSASLDigestMD5() throws Exception {
Expand All @@ -49,8 +57,7 @@ public void testSASLDigestMD5() throws Exception {
assertTrue(b.isConnected());
}


@Test(expectedExceptions=SecurityException.class)
@Test(expectedExceptions = SecurityException.class)
public void testSASLDigestMD5Failure() throws Throwable {
a = createChannel("A", "DIGEST-MD5", "jack");
b = createChannel("B", "DIGEST-MD5", "jill");
Expand All @@ -63,12 +70,98 @@ public void testSASLDigestMD5Failure() throws Throwable {
}
}

public void testSASLDigestMD5Merge() throws Exception {
a = createChannel("A", "DIGEST-MD5", "jack");
b = createChannel("B", "DIGEST-MD5", "jack");
a.connect("SaslTest");
b.connect("SaslTest");
assertTrue(b.isConnected());
print(a, b);
createPartitions(a, b);
print(a, b);
assertTrue(checkViewSize(1, a, b));
dropDiscard(a, b);
mergePartitions(a, b);
for(int i = 0; i < 10 && !checkViewSize(2, a, b); i++) {
Util.sleep(500);
}
assertTrue(viewContains(a.getView(), a, b));
assertTrue(viewContains(b.getView(), a, b));
}

private boolean viewContains(View view, JChannel... channels) {
boolean b = true;
for (JChannel ch : channels) {
b = b && view.containsMember(ch.getAddress());
}
return b;
}

private void dropDiscard(JChannel... channels) {
for (JChannel ch : channels) {
ch.getProtocolStack().removeProtocol(DISCARD.class);
}
}

private boolean checkViewSize(int expectedSize, JChannel... channels) {
boolean b = true;
for (JChannel ch : channels) {
b = b && ch.getView().size() == expectedSize;
}
return b;
}

@AfterMethod
public void cleanup() {
a.close();
b.close();
}

private static void createPartitions(JChannel... channels) throws Exception {
for (JChannel ch : channels) {
DISCARD discard = new DISCARD();
discard.setDiscardAll(true);
ch.getProtocolStack().insertProtocol(discard, ProtocolStack.ABOVE, TP.class);
}

for (JChannel ch : channels) {
View view = View.create(ch.getAddress(), 10, ch.getAddress());
GMS gms = (GMS) ch.getProtocolStack().findProtocol(GMS.class);
gms.installView(view);
}
}

private static void mergePartitions(JChannel... channels) throws Exception {
Membership membership = new Membership();
for (JChannel ch : channels) {
membership.add(ch.getAddress());
}
membership.sort();
Address leaderAddress = membership.elementAt(0);
JChannel leader = findChannelByAddress(leaderAddress, channels);
GMS gms = (GMS) leader.getProtocolStack().findProtocol(GMS.class);
gms.setLevel("trace");
Map<Address, View> views = new HashMap<Address, View>();
for (JChannel ch : channels) {
views.put(ch.getAddress(), ch.getView());
}
gms.up(new Event(Event.MERGE, views));
}

private static JChannel findChannelByAddress(Address address, JChannel... channels) {
for (JChannel ch : channels) {
if (ch.getAddress().equals(address)) {
return ch;
}
}
return null;
}

private static void print(JChannel... channels) {
for (JChannel ch : channels) {
System.out.println(ch.getName() + ": " + ch.getView());
}
}

public static class MyCallbackHandler implements CallbackHandler {
final private String password;
Expand All @@ -79,16 +172,17 @@ public MyCallbackHandler(String password) {

@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
for(Callback callback : callbacks) {
for (Callback callback : callbacks) {
if (callback instanceof NameCallback) {
NameCallback nameCallback = (NameCallback)callback;
NameCallback nameCallback = (NameCallback) callback;
nameCallback.setName("user");
} else if (callback instanceof PasswordCallback) {
PasswordCallback passwordCallback = (PasswordCallback)callback;
PasswordCallback passwordCallback = (PasswordCallback) callback;
passwordCallback.setPassword(password.toCharArray());
} else if (callback instanceof AuthorizeCallback) {
AuthorizeCallback authorizeCallback = (AuthorizeCallback)callback;
authorizeCallback.setAuthorized(authorizeCallback.getAuthenticationID().equals(authorizeCallback.getAuthorizationID()));
AuthorizeCallback authorizeCallback = (AuthorizeCallback) callback;
authorizeCallback.setAuthorized(
authorizeCallback.getAuthenticationID().equals(authorizeCallback.getAuthorizationID()));
} else if (callback instanceof RealmCallback) {
RealmCallback realmCallback = (RealmCallback) callback;
realmCallback.setText(REALM);
Expand Down

0 comments on commit cd75779

Please sign in to comment.