Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix RaftLock and RaftSemaphore idempotency bug
  • Loading branch information
metanet authored and mdogan committed Feb 1, 2019
1 parent 60cda09 commit f160d0b
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 75 deletions.
Expand Up @@ -57,7 +57,7 @@ AcquireResult acquire(LockEndpoint endpoint, long commitIndex, UUID invocationUi
// if acquire() is being retried
if (invocationUid.equals(invocationRefUids.get(endpoint))
|| (owner != null && owner.invocationUid().equals(invocationUid))) {
return AcquireResult.successful(owner.commitIndex());
return AcquireResult.acquired(owner.commitIndex());
}

invocationRefUids.remove(endpoint);
Expand All @@ -70,24 +70,24 @@ AcquireResult acquire(LockEndpoint endpoint, long commitIndex, UUID invocationUi
if (endpoint.equals(owner.endpoint())) {
invocationRefUids.put(endpoint, invocationUid);
lockCount++;
return AcquireResult.successful(owner.commitIndex());
return AcquireResult.acquired(owner.commitIndex());
}

Collection<LockInvocationKey> cancelledWaitKeys = cancelWaitKeys(endpoint);
Collection<LockInvocationKey> cancelledWaitKeys = cancelWaitKeys(endpoint, invocationUid);

if (wait) {
waitKeys.add(key);
}

return AcquireResult.failed(cancelledWaitKeys);
return AcquireResult.notAcquired(cancelledWaitKeys);
}

private Collection<LockInvocationKey> cancelWaitKeys(LockEndpoint endpoint) {
private Collection<LockInvocationKey> cancelWaitKeys(LockEndpoint endpoint, UUID invocationUid) {
List<LockInvocationKey> cancelled = new ArrayList<LockInvocationKey>(0);
Iterator<LockInvocationKey> it = waitKeys.iterator();
while (it.hasNext()) {
LockInvocationKey waitKey = it.next();
if (waitKey.endpoint().equals(endpoint)) {
if (waitKey.endpoint().equals(endpoint) && !waitKey.invocationUid().equals(invocationUid)) {
cancelled.add(waitKey);
it.remove();
}
Expand Down Expand Up @@ -116,18 +116,31 @@ private ReleaseResult release(LockEndpoint endpoint, int releaseCount, UUID invo

LockInvocationKey newOwner = waitKeys.poll();
if (newOwner != null) {
List<LockInvocationKey> keys = new ArrayList<LockInvocationKey>();
keys.add(newOwner);

Iterator<LockInvocationKey> iter = waitKeys.iterator();
while (iter.hasNext()) {
LockInvocationKey key = iter.next();
if (newOwner.invocationUid().equals(key.invocationUid())) {
assert newOwner.endpoint().equals(key.endpoint());
keys.add(key);
iter.remove();
}
}

owner = newOwner;
lockCount = 1;

return ReleaseResult.successful(Collections.singleton(newOwner));
return ReleaseResult.successful(keys);
} else {
owner = null;
}

return ReleaseResult.SUCCESSFUL;
}

return ReleaseResult.failed(cancelWaitKeys(endpoint));
return ReleaseResult.failed(cancelWaitKeys(endpoint, invocationUid));
}

ReleaseResult forceRelease(long expectedFence, UUID invocationUid) {
Expand Down Expand Up @@ -237,11 +250,11 @@ public String toString() {

static class AcquireResult {

private static AcquireResult successful(long fence) {
private static AcquireResult acquired(long fence) {
return new AcquireResult(fence, Collections.<LockInvocationKey>emptyList());
}

private static AcquireResult failed(Collection<LockInvocationKey> cancelled) {
private static AcquireResult notAcquired(Collection<LockInvocationKey> cancelled) {
return new AcquireResult(INVALID_FENCE, cancelled);
}

Expand Down
Expand Up @@ -84,7 +84,7 @@ AcquireResult acquire(SemaphoreInvocationKey key, boolean wait) {
return new AcquireResult(key.permits(), Collections.<SemaphoreInvocationKey>emptyList());
}

Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(key.sessionId(), key.threadId());
Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(key.sessionId(), key.threadId(), key.invocationUid());

if (!isAvailable(key.permits())) {
if (wait) {
Expand Down Expand Up @@ -123,7 +123,7 @@ ReleaseResult release(long sessionId, long threadId, UUID invocationUid, int per
if (sessionId != NO_SESSION_ID) {
SessionState state = sessionStates.get(sessionId);
if (state == null) {
return ReleaseResult.failed(cancelWaitKeys(sessionId, threadId));
return ReleaseResult.failed(cancelWaitKeys(sessionId, threadId, invocationUid));
}

if (state.invocationRefUids.containsKey(Tuple2.of(threadId, invocationUid))) {
Expand All @@ -132,7 +132,7 @@ ReleaseResult release(long sessionId, long threadId, UUID invocationUid, int per
}

if (state.acquiredPermits < permits) {
return ReleaseResult.failed(cancelWaitKeys(sessionId, threadId));
return ReleaseResult.failed(cancelWaitKeys(sessionId, threadId, invocationUid));
}

state.acquiredPermits -= permits;
Expand All @@ -142,18 +142,19 @@ ReleaseResult release(long sessionId, long threadId, UUID invocationUid, int per
available += permits;

// order is important...
Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId);
Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId, invocationUid);
Collection<SemaphoreInvocationKey> acquired = assignPermitsToWaitKeys();

return ReleaseResult.successful(acquired, cancelled);
}

private Collection<SemaphoreInvocationKey> cancelWaitKeys(long sessionId, long threadId) {
private Collection<SemaphoreInvocationKey> cancelWaitKeys(long sessionId, long threadId, UUID invocationUid) {
List<SemaphoreInvocationKey> cancelled = new ArrayList<SemaphoreInvocationKey>(0);
Iterator<SemaphoreInvocationKey> iter = waitKeys.iterator();
while (iter.hasNext()) {
SemaphoreInvocationKey waitKey = iter.next();
if (waitKey.sessionId() == sessionId && waitKey.threadId() == threadId) {
if (waitKey.sessionId() == sessionId && waitKey.threadId() == threadId
&& !waitKey.invocationUid().equals(invocationUid)) {
cancelled.add(waitKey);
iter.remove();
}
Expand All @@ -164,16 +165,20 @@ private Collection<SemaphoreInvocationKey> cancelWaitKeys(long sessionId, long t

private Collection<SemaphoreInvocationKey> assignPermitsToWaitKeys() {
List<SemaphoreInvocationKey> assigned = new ArrayList<SemaphoreInvocationKey>();
Set<UUID> assignedInvocationUids = new HashSet<UUID>();
Iterator<SemaphoreInvocationKey> iterator = waitKeys.iterator();
while (iterator.hasNext()) {
SemaphoreInvocationKey key = iterator.next();
if (key.permits() > available) {
break;
if (assignedInvocationUids.contains(key.invocationUid())) {
iterator.remove();
assigned.add(key);
} else if (key.permits() <= available) {
iterator.remove();
if (assignedInvocationUids.add(key.invocationUid())) {
assigned.add(key);
assignPermitsToInvocation(key.sessionId(), key.threadId(), key.invocationUid(), key.permits());
}
}

iterator.remove();
assigned.add(key);
assignPermitsToInvocation(key.sessionId(), key.threadId(), key.invocationUid(), key.permits());
}

return assigned;
Expand All @@ -188,7 +193,7 @@ AcquireResult drain(long sessionId, long threadId, UUID invocationUid) {
}
}

Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId);
Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId, invocationUid);

int drained = available;
if (drained > 0) {
Expand All @@ -204,7 +209,7 @@ ReleaseResult change(long sessionId, long threadId, UUID invocationUid, int perm
return ReleaseResult.failed(Collections.<SemaphoreInvocationKey>emptyList());
}

Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId);
Collection<SemaphoreInvocationKey> cancelled = cancelWaitKeys(sessionId, threadId, invocationUid);

if (sessionId != NO_SESSION_ID) {
SessionState state = sessionStates.get(sessionId);
Expand Down
Expand Up @@ -60,7 +60,7 @@ private AbstractSessionManager getSessionManager() {
}

@Test
public void testLockCancelsPendingLockRequest() {
public void testRetriedLockDoesNotCancelPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now
Expand All @@ -70,8 +70,7 @@ public void testLockCancelsPendingLockRequest() {
RaftInvocationManager invocationManager = getRaftInvocationManager(lockInstance);
UUID invUid = newUnsecureUUID();

InternalCompletableFuture<Object> f = invocationManager
.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));
invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));

assertTrueEventually(new AssertTask() {
@Override
Expand All @@ -85,6 +84,43 @@ public void run() {

invocationManager.invoke(groupId, new LockOp(name, sessionId, getThreadId(), invUid));

assertTrueAllTheTime(new AssertTask() {
@Override
public void run() {
RaftLockService service = getNodeEngineImpl(lockInstance).getService(RaftLockService.SERVICE_NAME);
LockRegistry registry = service.getRegistryOrNull(groupId);
assertEquals(1, registry.getWaitTimeouts().size());
}
}, 10);
}

@Test(timeout = 30000)
public void testNewLockCancelsPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now

final RaftGroupId groupId = lock.getGroupId();
long sessionId = getSessionManager().getSession(groupId);
RaftInvocationManager invocationManager = getRaftInvocationManager(lockInstance);
UUID invUid1 = newUnsecureUUID();
UUID invUid2 = newUnsecureUUID();

InternalCompletableFuture<Object> f = invocationManager
.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid1, MINUTES.toMillis(5)));

assertTrueEventually(new AssertTask() {
@Override
public void run() {
RaftLockService service = getNodeEngineImpl(lockInstance).getService(RaftLockService.SERVICE_NAME);
LockRegistry registry = service.getRegistryOrNull(groupId);
assertNotNull(registry);
assertEquals(1, registry.getWaitTimeouts().size());
}
});

invocationManager.invoke(groupId, new LockOp(name, sessionId, getThreadId(), invUid2));

try {
f.join();
fail();
Expand All @@ -93,7 +129,7 @@ public void run() {
}

@Test
public void testTryLockWithTimeoutCancelsPendingLockRequest() {
public void testRetriedTryLockWithTimeoutDoesNotCancelPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now
Expand All @@ -103,8 +139,7 @@ public void testTryLockWithTimeoutCancelsPendingLockRequest() {
RaftInvocationManager invocationManager = getRaftInvocationManager(lockInstance);
UUID invUid = newUnsecureUUID();

InternalCompletableFuture<Object> f = invocationManager
.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));
invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));


assertTrueEventually(new AssertTask() {
Expand All @@ -119,6 +154,43 @@ public void run() {

invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));

assertTrueEventually(new AssertTask() {
@Override
public void run() {
RaftLockService service = getNodeEngineImpl(lockInstance).getService(RaftLockService.SERVICE_NAME);
LockRegistry registry = service.getRegistryOrNull(groupId);
assertEquals(2, registry.getWaitTimeouts().size());
}
});
}

@Test(timeout = 30000)
public void testNewTryLockWithTimeoutCancelsPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now

final RaftGroupId groupId = lock.getGroupId();
long sessionId = getSessionManager().getSession(groupId);
RaftInvocationManager invocationManager = getRaftInvocationManager(lockInstance);
UUID invUid1 = newUnsecureUUID();
UUID invUid2 = newUnsecureUUID();

InternalCompletableFuture<Object> f = invocationManager
.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid1, MINUTES.toMillis(5)));

assertTrueEventually(new AssertTask() {
@Override
public void run() {
RaftLockService service = getNodeEngineImpl(lockInstance).getService(RaftLockService.SERVICE_NAME);
LockRegistry registry = service.getRegistryOrNull(groupId);
assertNotNull(registry);
assertEquals(1, registry.getWaitTimeouts().size());
}
});

invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid2, MINUTES.toMillis(5)));

try {
f.join();
fail();
Expand All @@ -127,7 +199,7 @@ public void run() {
}

@Test
public void testTryLockWithoutTimeoutCancelsPendingLockRequest() {
public void testRetriedTryLockWithoutTimeoutDoesNotCancelPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now
Expand All @@ -137,8 +209,7 @@ public void testTryLockWithoutTimeoutCancelsPendingLockRequest() {
RaftInvocationManager invocationManager = getRaftInvocationManager(lockInstance);
UUID invUid = newUnsecureUUID();

InternalCompletableFuture<Object> f = invocationManager
.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));
invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, MINUTES.toMillis(5)));

assertTrueEventually(new AssertTask() {
@Override
Expand All @@ -152,15 +223,18 @@ public void run() {

invocationManager.invoke(groupId, new TryLockOp(name, sessionId, getThreadId(), invUid, 0));

try {
f.join();
fail();
} catch (WaitKeyCancelledException ignored) {
}
assertTrueAllTheTime(new AssertTask() {
@Override
public void run() {
RaftLockService service = getNodeEngineImpl(lockInstance).getService(RaftLockService.SERVICE_NAME);
LockRegistry registry = service.getRegistryOrNull(groupId);
assertEquals(1, registry.getWaitTimeouts().size());
}
}, 10);
}

@Test
public void testUnlockCancelsPendingLockRequest() {
@Test(timeout = 30000)
public void testNewUnlockCancelsPendingLockRequest() {
lockByOtherThread(lock);

// there is a session id now
Expand Down

0 comments on commit f160d0b

Please sign in to comment.