diff --git a/conf/scripts/ASYM_ENCRYPT_BlockTest/testASYM_ENCRYPT_NotBlockingJoin.btm b/conf/scripts/ASYM_ENCRYPT_BlockTest/testASYM_ENCRYPT_NotBlockingJoin.btm new file mode 100644 index 00000000000..83090b83a79 --- /dev/null +++ b/conf/scripts/ASYM_ENCRYPT_BlockTest/testASYM_ENCRYPT_NotBlockingJoin.btm @@ -0,0 +1,16 @@ + + +## Send unicast message to joiner *before* JOIN-RSP (https://issues.jboss.org/browse/JGRP-2131) +RULE InjectAdditionalUnicast +CLASS GMS +METHOD sendJoinResponse +HELPER org.jgroups.tests.helpers.SendUnicast +AT ENTRY +BIND gms=$0; + dest=$2; +IF TRUE + DO System.out.println("** sending unicast message to " + dest); + sendUnicast(gms, dest); +ENDRULE + + diff --git a/src/org/jgroups/protocols/ASYM_ENCRYPT.java b/src/org/jgroups/protocols/ASYM_ENCRYPT.java index 08206a8d0e6..879e5375f11 100644 --- a/src/org/jgroups/protocols/ASYM_ENCRYPT.java +++ b/src/org/jgroups/protocols/ASYM_ENCRYPT.java @@ -7,9 +7,8 @@ import org.jgroups.annotations.Property; import org.jgroups.conf.ClassConfigurator; import org.jgroups.protocols.pbcast.GMS; -import org.jgroups.util.AsciiString; -import org.jgroups.util.MessageBatch; -import org.jgroups.util.Util; +import org.jgroups.protocols.pbcast.JoinRsp; +import org.jgroups.util.*; import javax.crypto.Cipher; import javax.crypto.KeyGenerator; @@ -17,10 +16,14 @@ import javax.crypto.spec.SecretKeySpec; import java.security.*; import java.security.spec.X509EncodedKeySpec; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; /** * Encrypts and decrypts communication in JGroups by using a secret key distributed to all cluster members by the @@ -62,9 +65,11 @@ public class ASYM_ENCRYPT extends EncryptBase { protected volatile boolean is_key_server; protected KeyPair key_pair; // to store own's public/private Key protected Cipher asym_cipher; // decrypting cypher for secret key requests + protected final Lock queue_lock=new ReentrantLock(); + // queue all up msgs until the secret key has been received/created @ManagedAttribute(description="whether or not to queue received messages (until the secret key was received)") - protected volatile boolean queue_up_msgs=true; + protected boolean queue_up_msgs=true; // queues a bounded number of messages received during a null secret key (or fetching the key from a new coord) protected final BlockingQueue up_queue=new ArrayBlockingQueue<>(100); @@ -82,7 +87,7 @@ public class ASYM_ENCRYPT extends EncryptBase { @ManagedOperation(description="Triggers a request for the secret key to the current keyserver") public void sendKeyRequest() { if(key_server_addr == null) { - log.debug("%s: key server is currently not set", key_server_addr); + log.debug("%s: sending secret key request failed as the key server is currently not set", local_addr); return; } sendKeyRequest(key_server_addr); @@ -94,7 +99,7 @@ public void init() throws Exception { } public void stop() { - drainUpQueue(); + stopQueueing(); super.stop(); } @@ -110,8 +115,13 @@ public Object down(Event evt) { public Object up(Event evt) { if(evt.type() == Event.MSG) { Message msg=evt.arg(); - if(skip(msg)) + if(skip(msg)) { + GMS.GmsHeader hdr=(GMS.GmsHeader)msg.getHeader(GMS_ID); + Address key_server=getCoordinator(msg, hdr); + if(key_server != null) + sendKeyRequest(key_server); return up_prot.up(evt); + } } return super.up(evt); } @@ -122,6 +132,10 @@ public void up(MessageBatch batch) { try { up_prot.up(new Event(Event.MSG, msg)); batch.remove(msg); + GMS.GmsHeader hdr=(GMS.GmsHeader)msg.getHeader(GMS_ID); + Address key_server=getCoordinator(msg, hdr); + if(key_server != null) + sendKeyRequest(key_server); } catch(Throwable t) { log.error("failed passing up message from %s: %s, ex=%s", msg.src(), msg.printHeaders(), t); @@ -132,6 +146,32 @@ public void up(MessageBatch batch) { super.up(batch); // decrypt the rest of the messages in the batch (if any) } + /** Tries to find out if this is a JOIN_RSP or INSTALL_MERGE_VIEW message and returns the coordinator of the view */ + protected Address getCoordinator(Message msg, GMS.GmsHeader hdr) { + switch(hdr.getType()) { + case GMS.GmsHeader.JOIN_RSP: + try { + JoinRsp join_rsp=Util.streamableFromBuffer(JoinRsp.class, msg.getRawBuffer(), msg.getOffset(), msg.getLength()); + View new_view=join_rsp != null? join_rsp.getView() : null; + return new_view != null? new_view.getCoord() : null; + } + catch(Throwable t) { + log.error("%s: failed getting coordinator (keyserver) from JoinRsp: %s", local_addr, t); + } + break; + case GMS.GmsHeader.INSTALL_MERGE_VIEW: + try { + Tuple tuple=GMS._readViewAndDigest(msg.getRawBuffer(), msg.getOffset(), msg.getLength()); + View new_view=tuple != null? tuple.getVal1() : null; + return new_view != null? new_view.getCoord() : null; + } + catch(Throwable t) { + log.error("%s: failed getting coordinator (keyserver) from INSTALL_MERGE_VIEW: %s", local_addr, t); + } + break; + } + return null; + } /** Checks if a message needs to be encrypted/decrypted. Join and merge requests/responses don't need to be @@ -147,6 +187,8 @@ protected static boolean skip(Message msg) { case GMS.GmsHeader.MERGE_RSP: case GMS.GmsHeader.VIEW_ACK: case GMS.GmsHeader.INSTALL_MERGE_VIEW: + case GMS.GmsHeader.GET_DIGEST_REQ: + case GMS.GmsHeader.GET_DIGEST_RSP: return true; } return false; @@ -169,8 +211,7 @@ protected static boolean skip(Message msg) { } @Override protected boolean process(Message msg) { - if(queue_up_msgs || secret_key == null) { - up_queue.offer(msg); + if(enqueue(msg)) { log.trace("%s: queuing %s message from %s as secret key hasn't been retrieved from keyserver %s yet, hdrs: %s", local_addr, msg.dest() == null? "mcast" : "unicast", msg.src(), key_server_addr, msg.printHeaders()); if(last_key_request == 0 || System.currentTimeMillis() - last_key_request > 2000) { @@ -276,7 +317,7 @@ else if(left_mbrs) try { this.secret_key=createSecretKey(); initSymCiphers(sym_algorithm, secret_key); - drainUpQueue(); + stopQueueing(); } catch(Exception ex) { log.error("%s: failed creating secret key and initializing ciphers", local_addr, ex); @@ -286,9 +327,10 @@ else if(left_mbrs) /** If the keyserver changed, send a request for the secret key to the keyserver */ protected void handleNewKeyServer(Address newKeyServer, boolean merge_view, boolean left_mbrs) { if(keyServerChanged(newKeyServer) || merge_view || left_mbrs) { - secret_key=null; - sym_version=null; - queue_up_msgs=true; + // secret_key=null; + // sym_version=null; + // queue_up_msgs=true; + startQueueing(); key_server_addr=newKeyServer; is_key_server=false; log.debug("%s: sending request for secret key to the new keyserver %s", local_addr, key_server_addr); @@ -303,26 +345,26 @@ protected boolean keyServerChanged(Address newKeyServer) { protected void setKeys(SecretKey key, byte[] version) throws Exception { - if(Arrays.equals(this.sym_version, version)) - return; - - // System.out.printf("%s: ******** setting sym_version (%s) to %s\n", local_addr, - // Util.byteArrayToHexString(this.sym_version), Util.byteArrayToHexString(version)); - - Cipher decoding_cipher=secret_key != null? decoding_ciphers.take() : null; - // put the previous key into the map, keep the cipher: no leak, as we'll clear decoding_ciphers in initSymCiphers() - if(decoding_cipher != null) - key_map.put(new AsciiString(version), decoding_cipher); - secret_key=key; - initSymCiphers(key.getAlgorithm(), key); - sym_version=version; - drainUpQueue(); + synchronized(this) { + if(Arrays.equals(this.sym_version, version)) { + stopQueueing(); + return; + } + Cipher decoding_cipher=secret_key != null? decoding_ciphers.take() : null; + // put the previous key into the map, keep the cipher: no leak, as we'll clear decoding_ciphers in initSymCiphers() + if(decoding_cipher != null) + key_map.put(new AsciiString(version), decoding_cipher); + secret_key=key; + initSymCiphers(key.getAlgorithm(), key); + sym_version=version; + } + stopQueueing(); } protected void sendSecretKey(SecretKey secret_key, PublicKey public_key, Address source) throws Exception { byte[] encryptedKey=encryptSecretKey(secret_key, public_key); - Message newMsg=new Message(source, local_addr, encryptedKey) + Message newMsg=new Message(source, encryptedKey).src(local_addr) .putHeader(this.id, new EncryptHeader(EncryptHeader.SECRET_KEY_RSP, symVersion())); log.debug("%s: sending secret key to %s", local_addr, source); down_prot.down(new Event(Event.MSG,newMsg)); @@ -344,8 +386,10 @@ protected byte[] encryptSecretKey(SecretKey secret_key, PublicKey public_key) th /** send client's public key to server and request server's public key */ protected void sendKeyRequest(Address key_server) { - Message newMsg=new Message(key_server, local_addr, key_pair.getPublic().getEncoded()) - .putHeader(this.id,new EncryptHeader(EncryptHeader.SECRET_KEY_REQ, sym_version)); + if(key_server == null) + return; + Message newMsg=new Message(key_server, key_pair.getPublic().getEncoded()).src(local_addr) + .putHeader(this.id,new EncryptHeader(EncryptHeader.SECRET_KEY_REQ, null)); down_prot.down(new Event(Event.MSG,newMsg)); } @@ -373,12 +417,41 @@ protected SecretKeySpec decodeKey(byte[] encodedKey) throws Exception { } } - // doesn't have to be 100% correct: leftover messages wll be delivered later and will be discarded as dupes, as - // retransmission is likely to have kicked in before anyway - protected void drainUpQueue() { - queue_up_msgs=false; - Message queued_msg; - while((queued_msg=up_queue.poll()) != null) { + protected void startQueueing() { + queue_lock.lock(); + try { + queue_up_msgs=true; + } + finally { + queue_lock.unlock(); + } + } + + protected boolean enqueue(Message msg) { + queue_lock.lock(); + try { + return queue_up_msgs && up_queue.offer(msg); + } + finally { + queue_lock.unlock(); + } + } + + + // Drains the queued messages. Doesn't have to be 100% correct: leftover messages wll be delivered later and will + // be discarded as dupes, as retransmission is likely to have kicked in before, anyway + protected void stopQueueing() { + List sink=new ArrayList<>(up_queue.size()); + queue_lock.lock(); + try { + queue_up_msgs=false; + up_queue.drainTo(sink); + } + finally { + queue_lock.unlock(); + } + + for(Message queued_msg: sink) { try { Message decrypted_msg=decryptMessage(null, queued_msg.copy()); if(decrypted_msg != null) diff --git a/src/org/jgroups/protocols/EncryptBase.java b/src/org/jgroups/protocols/EncryptBase.java index 8ca665bf8c4..dc3cc068ef5 100644 --- a/src/org/jgroups/protocols/EncryptBase.java +++ b/src/org/jgroups/protocols/EncryptBase.java @@ -193,8 +193,8 @@ protected synchronized void initSymCiphers(String algorithm, SecretKey secret) t encoding_ciphers.clear(); decoding_ciphers.clear(); for(int i=0; i < cipher_pool_size; i++ ) { - encoding_ciphers.add(createCipher(Cipher.ENCRYPT_MODE, secret, algorithm)); - decoding_ciphers.add(createCipher(Cipher.DECRYPT_MODE, secret, algorithm)); + encoding_ciphers.offer(createCipher(Cipher.ENCRYPT_MODE, secret, algorithm)); + decoding_ciphers.offer(createCipher(Cipher.DECRYPT_MODE, secret, algorithm)); }; //set the version diff --git a/tests/byteman/org/jgroups/tests/byteman/ASYM_ENCRYPT_BlockTest.java b/tests/byteman/org/jgroups/tests/byteman/ASYM_ENCRYPT_BlockTest.java new file mode 100644 index 00000000000..e6840c4148f --- /dev/null +++ b/tests/byteman/org/jgroups/tests/byteman/ASYM_ENCRYPT_BlockTest.java @@ -0,0 +1,93 @@ +package org.jgroups.tests.byteman; + +import org.jboss.byteman.contrib.bmunit.BMNGRunner; +import org.jboss.byteman.contrib.bmunit.BMScript; +import org.jgroups.Global; +import org.jgroups.JChannel; +import org.jgroups.auth.MD5Token; +import org.jgroups.protocols.*; +import org.jgroups.protocols.pbcast.GMS; +import org.jgroups.protocols.pbcast.NAKACK2; +import org.jgroups.protocols.pbcast.STABLE; +import org.jgroups.stack.Protocol; +import org.jgroups.util.MyReceiver; +import org.jgroups.util.Util; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +/** + * Tests https://issues.jboss.org/browse/JGRP-2131 + * @author Bela Ban + * @since 3.6.12, 4.0.0 + */ +@Test(groups=Global.BYTEMAN,singleThreaded=true) +public class ASYM_ENCRYPT_BlockTest extends BMNGRunner { + protected JChannel a, b; + protected MyReceiver ra, rb; + + @BeforeMethod protected void setup() throws Exception { + a=create("A"); + b=create("B"); + a.setReceiver(ra=new MyReceiver()); + b.setReceiver(rb=new MyReceiver()); + a.connect(ASYM_ENCRYPT_BlockTest.class.getSimpleName()); + b.connect(ASYM_ENCRYPT_BlockTest.class.getSimpleName()); + Util.waitUntilAllChannelsHaveSameView(10000, 1000, a,b); + } + + @AfterMethod protected void tearDown() { + Util.close(b, a); + } + + @BMScript(dir="scripts/ASYM_ENCRYPT_BlockTest", value="testASYM_ENCRYPT_NotBlockingJoin") + public void testASYM_ENCRYPT_NotBlockingJoin() throws Exception { + a.send(b.getAddress(), "one"); + b.send(a.getAddress(), "two"); + + + for(int i=0; i < 10; i++) { + if(ra.size() >= 1 && rb.size() >= 1) // fail fast if size > 1 + break; + Util.sleep(1000); + } + + System.out.printf("A's messages:\n%s\nB's messages:\n%s\n", print(ra), print(rb)); + assert ra.size() == 1 : String.format("A has %d messages", ra.size()); + assert rb.size() >= 1 : String.format("B has %d messages", rb.size()); + boolean match=false; + for(Object obj: rb.list()) { + if(obj.equals("one")) { + match=true; + break; + } + } + assert match; + assert "two".equals(ra.list().get(0)); + } + + + protected JChannel create(String name) throws Exception { + Protocol[] protocols={ + new SHARED_LOOPBACK(), + new SHARED_LOOPBACK_PING(), + new ASYM_ENCRYPT().encryptEntireMessage(false).symKeylength(128) + .symAlgorithm("AES/ECB/PKCS5Padding").asymKeylength(512).asymAlgorithm("RSA"), + new NAKACK2(), + new UNICAST3(), + new STABLE(), + new AUTH().setAuthToken(new MD5Token("jdgservercluster", "MD5")), + new GMS().joinTimeout(5000), + new FRAG2().fragSize(8000) + }; + + return new JChannel(protocols).name(name); + } + + protected static String print(MyReceiver r) { + StringBuilder sb=new StringBuilder(); + for(Object str: r.list()) + sb.append(str + "\n"); + return sb.toString(); + } +} diff --git a/tests/byteman/org/jgroups/tests/helpers/SendUnicast.java b/tests/byteman/org/jgroups/tests/helpers/SendUnicast.java new file mode 100644 index 00000000000..e6bdc403442 --- /dev/null +++ b/tests/byteman/org/jgroups/tests/helpers/SendUnicast.java @@ -0,0 +1,18 @@ +package org.jgroups.tests.helpers; + +import org.jgroups.Address; +import org.jgroups.Event; +import org.jgroups.Message; +import org.jgroups.protocols.pbcast.GMS; + +/** + * @author Bela Ban + * @since 4.0 + */ +public class SendUnicast { + public void sendUnicast(GMS gms,Address dest) { + Message msg=new Message(dest, "sorry for the interruption :-)"); + gms.down(new Event(Event.MSG, msg)); + // System.out.printf("** injected message %s\n", msg.printHeaders()); + } +}