Skip to content

Commit

Permalink
Make reference count check atomic with release (#876)
Browse files Browse the repository at this point in the history
JAVA-4490

Co-authored-by: Valentin Kovalenko <valentin.kovalenko@mongodb.com>
  • Loading branch information
jyemin and stIncMale committed Feb 10, 2022
1 parent dd291c3 commit 058d9f4
Show file tree
Hide file tree
Showing 25 changed files with 131 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.mongodb.connection.ClusterType;
import com.mongodb.connection.ServerDescription;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.binding.AbstractReferenceCounted;
import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
import com.mongodb.internal.binding.AsyncConnectionSource;
import com.mongodb.internal.binding.AsyncReadWriteBinding;
Expand All @@ -36,15 +37,15 @@
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.connection.ClusterType.LOAD_BALANCED;

public class ClientSessionBinding implements AsyncReadWriteBinding {
public class ClientSessionBinding extends AbstractReferenceCounted implements AsyncReadWriteBinding {
private final AsyncClusterAwareReadWriteBinding wrapped;
private final AsyncClientSession session;
private final boolean ownsSession;
private final ClientSessionContext sessionContext;

public ClientSessionBinding(final AsyncClientSession session, final boolean ownsSession,
final AsyncClusterAwareReadWriteBinding wrapped) {
this.wrapped = notNull("wrapped", (wrapped));
this.wrapped = notNull("wrapped", wrapped).retain();
this.ownsSession = ownsSession;
this.session = notNull("session", session);
this.sessionContext = new AsyncClientSessionContext(session);
Expand Down Expand Up @@ -113,14 +114,9 @@ private void getPinnedConnectionSource(final boolean isRead, final SingleResultC
}
}

@Override
public int getCount() {
return wrapped.getCount();
}

@Override
public AsyncReadWriteBinding retain() {
wrapped.retain();
super.retain();
return this;
}

Expand All @@ -131,15 +127,15 @@ public void getReadConnectionSource(final int minWireVersion, final ReadPreferen
}

@Override
public void release() {
wrapped.release();
closeSessionIfCountIsZero();
}

private void closeSessionIfCountIsZero() {
if (getCount() == 0 && ownsSession) {
session.close();
public int release() {
int count = super.release();
if (count == 0) {
wrapped.release();
if (ownsSession) {
session.close();
}
}
return count;
}

private boolean isConnectionSourcePinningRequired() {
Expand All @@ -152,6 +148,7 @@ private class SessionBindingAsyncConnectionSource implements AsyncConnectionSour

SessionBindingAsyncConnectionSource(final AsyncConnectionSource wrapped) {
this.wrapped = wrapped;
ClientSessionBinding.this.retain();
}

@Override
Expand Down Expand Up @@ -214,9 +211,12 @@ public int getCount() {
}

@Override
public void release() {
wrapped.release();
closeSessionIfCountIsZero();
public int release() {
int count = wrapped.release();
if (count == 0) {
ClientSessionBinding.this.release();
}
return count;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ public ReferenceCounted retain() {
}

@Override
public void release() {
if (referenceCount.decrementAndGet() < 0) {
public int release() {
int decrementedValue = referenceCount.decrementAndGet();
if (decrementedValue < 0) {
throw new IllegalStateException("Attempted to decrement the reference count below 0");
}
return decrementedValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ public interface AsyncClusterAwareReadWriteBinding extends AsyncReadWriteBinding
* @param callback the to be passed the connection source
*/
void getConnectionSource(ServerAddress serverAddress, SingleResultCallback<AsyncConnectionSource> callback);

@Override
AsyncClusterAwareReadWriteBinding retain();
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public AsyncClusterBinding(final Cluster cluster, final ReadPreference readPrefe
}

@Override
public AsyncReadWriteBinding retain() {
public AsyncClusterAwareReadWriteBinding retain() {
super.retain();
return this;
}
Expand Down Expand Up @@ -208,9 +208,10 @@ public AsyncConnectionSource retain() {
}

@Override
public void release() {
super.release();
public int release() {
int count = super.release();
AsyncClusterBinding.this.release();
return count;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ public ConnectionSource retain() {
}

@Override
public void release() {
super.release();
public int release() {
int count = super.release();
ClusterBinding.this.release();
return count;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

package com.mongodb.internal.binding;

import com.mongodb.internal.VisibleForTesting;

import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;

/**
* An interface for reference-counted objects.
* <p>
Expand All @@ -39,9 +43,14 @@ public interface ReferenceCounted {
/**
* Gets the current reference count.
*
* <p>
* This method should only be used for testing. Production code should prefer using the count returned from {@link #release()}
* </p>
*
* @return the current count, which must be greater than or equal to 0.
* Returns 1 for a newly created object.
*/
@VisibleForTesting(otherwise = PRIVATE)
int getCount();

/**
Expand All @@ -54,6 +63,7 @@ public interface ReferenceCounted {
/**
* Release a reference to this object.
* @throws java.lang.IllegalStateException if the reference count is already 0
* @return the reference count after the release
*/
void release();
int release();
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ public ConnectionSource retain() {
}

@Override
public void release() {
super.release();
if (super.getCount() == 0) {
public int release() {
int count = super.release();
if (count == 0) {
SingleServerBinding.this.release();
}
return count;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ public boolean isConnectionPinningRequired() {
}

@Override
public void release() {
super.release();
if (getCount() == 0) {
public int release() {
int count = super.release();
if (count == 0) {
if (pinnedConnection != null) {
pinnedConnection.release();
}
}
return count;
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ public ReferenceCounted retain() {
}

@Override
public void release() {
if (referenceCount.decrementAndGet() < 0) {
public int release() {
int decrementedValue = referenceCount.decrementAndGet();
if (decrementedValue < 0) {
throw new IllegalStateException("Attempted to decrement the reference count below 0");
}
return decrementedValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,12 @@ public int getCount() {
}

@Override
public void release() {
wrapped.release();
if (getCount() == 0) {
public int release() {
int count = wrapped.release();
if (count == 0) {
server.operationEnd();
}
return count;
}

@Override
Expand Down Expand Up @@ -401,11 +402,12 @@ public int getCount() {
}

@Override
public void release() {
wrapped.release();
if (getCount() == 0) {
public int release() {
int count = wrapped.release();
if (count == 0) {
server.operationEnd();
}
return count;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

import java.util.List;

import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.connection.ServerType.SHARD_ROUTER;
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;

Expand All @@ -61,16 +60,16 @@ public DefaultServerConnection retain() {
}

@Override
public void release() {
super.release();
if (getCount() == 0) {
public int release() {
int count = super.release();
if (count == 0) {
wrapped.close();
}
return count;
}

@Override
public ConnectionDescription getDescription() {
isTrue("open", getCount() > 0);
return wrapped.getDescription();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public int getCount() {
}

@Override
public void release() {
wrapped.release();
public int release() {
return wrapped.release();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ public AsyncReadWriteBinding retain() {
}

@Override
public void release() {
wrapped.release();
public int release() {
return wrapped.release();
}

private class SessionBindingAsyncConnectionSource implements AsyncConnectionSource {
Expand Down Expand Up @@ -168,8 +168,8 @@ public AsyncConnectionSource retain() {
}

@Override
public void release() {
wrapped.release();
public int release() {
return wrapped.release();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,13 @@ public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionS
}

@Override
public void release() {
super.release();
if (getCount() == 0) {
public int release() {
int count = super.release();
if (count == 0) {
readConnection.release();
writeConnection.release();
}
return count;
}

private final class SingleAsyncConnectionSource extends AbstractReferenceCounted implements AsyncConnectionSource {
Expand Down Expand Up @@ -259,11 +260,12 @@ public AsyncConnectionSource retain() {
}

@Override
public void release() {
super.release();
if (super.getCount() == 0) {
public int release() {
int count = super.release();
if (count == 0) {
AsyncSingleConnectionBinding.this.release();
}
return count;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ public ReadWriteBinding retain() {
}

@Override
public void release() {
wrapped.release();
public int release() {
return wrapped.release();
}

@Override
Expand Down Expand Up @@ -136,8 +136,8 @@ public int getCount() {
}

@Override
public void release() {
wrapped.release();
public int release() {
return wrapped.release();
}
}

Expand Down

0 comments on commit 058d9f4

Please sign in to comment.