Skip to content

Commit

Permalink
NXP-20250: Check for scroll timeout and prevent leak on bad usage
Browse files Browse the repository at this point in the history
  • Loading branch information
bdelbosc authored and nuxeojenkins committed Aug 25, 2016
1 parent e6516fc commit 5dc980b
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1067,15 +1067,14 @@ DocumentModel restoreToVersion(DocumentRef docRef, DocumentRef versionRef, boole
*/
IterableQueryResult queryAndFetch(String query, String queryType, boolean distinctDocuments, Object... params);


/**
* Executes the given query and returns the first batch of results, next batch must be requested
* within the {@code keepAliveSeconds} delay.
*
* @param query The NXQL query to execute
* @param batchSize The expected result batch size, note that more results can be returned when the backend don't
* implement properly this feature
* @param keepAliveSeconds The scroll context lifetime in seconds, use {@code -1} for a default value.
* @param keepAliveSeconds The scroll context lifetime in seconds
* @return A {@link ScrollResult} including the search results and a scroll id, to be passed to the subsequent
* calls to {@link #scroll(String)}
*
Expand All @@ -1086,6 +1085,8 @@ DocumentModel restoreToVersion(DocumentRef docRef, DocumentRef versionRef, boole
/**
* Get the next batch of result, the {@code scrollId} is part of the previous {@link ScrollResult} response.
*
* @throws NuxeoException when the {@code scrollId} is unknown or when the scroll operation has timed out
*
* @since 8.4
*/
ScrollResult scroll(String scrollId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public interface ScrollResult {
String getScrollId();

/**
* Returns the list of document ids.
* Returns the list of document ids
*/
List<String> getResultIds();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ public class MemRepository extends DBSRepositoryBase {

private static final Log log = LogFactory.getLog(MemRepository.class);

protected static final String NOSCROLL_ID = "noscroll";

// for debug
private final AtomicLong temporaryIdCounter = new AtomicLong(0);

Expand Down Expand Up @@ -371,13 +373,16 @@ public ScrollResult scroll(DBSExpressionEvaluator evaluator, int batchSize, int
ids.add(id);
}
}
return new ScrollResultImpl("noscroll", ids);
return new ScrollResultImpl(NOSCROLL_ID, ids);
}

@Override
public ScrollResult scroll(String scrollId) {
// Id are already in memory, theay are returned as a single batch
return ScrollResultImpl.emptyResult();
if (NOSCROLL_ID.equals(scrollId)) {
// Id are already in memory, they are returned as a single batch
return ScrollResultImpl.emptyResult();
}
throw new NuxeoException("Unknown or timed out scrollId");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,7 @@ public PartialList<Map<String, Serializable>> queryAndFetch(DBSExpressionEvaluat

@Override
public ScrollResult scroll(DBSExpressionEvaluator evaluator, int batchSize, int keepAliveSeconds) {
if (keepAliveSeconds != -1 && log.isInfoEnabled()) {
log.info("scroll keepAlive is not supported, the default MongoDB cursor timeout is 10min");
}
checkForTimedoutScroll();
MongoDBQueryBuilder builder = new MongoDBQueryBuilder(this, evaluator.getExpression(),
evaluator.getSelectClause(), null, evaluator.pathResolver, evaluator.fulltextSearchDisabled);
builder.walk();
Expand All @@ -900,25 +898,37 @@ public ScrollResult scroll(DBSExpressionEvaluator evaluator, int batchSize, int

DBCursor cursor = coll.find(query, keys);
String scrollId = UUID.randomUUID().toString();
registerCursor(scrollId, cursor, batchSize);
registerCursor(scrollId, cursor, batchSize, keepAliveSeconds);
return scroll(scrollId);
}

protected void registerCursor(String scrollId, DBCursor cursor, int batchSize) {
cursorResults.put(scrollId, new CursorResult(cursor, batchSize));
synchronized protected void checkForTimedoutScroll() {
Set<String> scrollIds = new HashSet<String>(cursorResults.keySet());
scrollIds.stream().forEach(x -> cursorResults.get(x).timedOut(x));
}

protected void registerCursor(String scrollId, DBCursor cursor, int batchSize, int keepAliveSeconds) {
cursorResults.put(scrollId, new CursorResult(cursor, batchSize, keepAliveSeconds));
}

@Override
public ScrollResult scroll(String scrollId) {
CursorResult cursorResult = cursorResults.get(scrollId);
if (cursorResult == null) {
return emptyResult();
throw new NuxeoException("Unknown or timed out scrollId");
} else if (cursorResult.timedOut(scrollId)) {
throw new NuxeoException("Timed out scrollId");
}
cursorResult.touch();
List<String> ids = new ArrayList<>(cursorResult.batchSize);
synchronized (cursorResult) {
if (cursorResult.cursor == null || !cursorResult.cursor.hasNext()) {
unregisterCursor(scrollId);
return emptyResult();
}
while (ids.size() < cursorResult.batchSize) {
if (cursorResult.cursor == null || !cursorResult.cursor.hasNext()) {
unregisterCursor(scrollId);
cursorResult.close();
break;
} else {
DBObject ob = cursorResult.cursor.next();
Expand Down Expand Up @@ -1149,10 +1159,28 @@ protected class CursorResult {
// Note that MongoDB cursor automatically timeout after 10 minutes of inactivity by default.
protected DBCursor cursor;
protected final int batchSize;
protected long lastCallTimestamp;
protected final int keepAliveSeconds;

public CursorResult(DBCursor cursor, int batchSize) {
public CursorResult(DBCursor cursor, int batchSize, int keepAliveSeconds) {
this.cursor = cursor;
this.batchSize = batchSize;
this.keepAliveSeconds = keepAliveSeconds;
lastCallTimestamp = System.currentTimeMillis();
}

void touch() {
lastCallTimestamp = System.currentTimeMillis();
}

boolean timedOut(String scrollId) {
long now = System.currentTimeMillis();
if (now - lastCallTimestamp > (keepAliveSeconds * 1000)) {
log.warn("Scroll " + scrollId + " timed out");
unregisterCursor(scrollId);
return true;
}
return false;
}

public void close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ public class JDBCMapper extends JDBCRowMapper implements Mapper {

protected boolean clusteringEnabled;

protected static final String NOSCROLL_ID = "noscroll";

/**
* Creates a new Mapper.
*
Expand Down Expand Up @@ -933,11 +935,16 @@ public ScrollResult scroll(String query, int batchSize, int keepAliveSeconds) {
if (!dialect.supportsScroll()) {
return defaultScroll(query);
}
checkForTimedoutScroll();
return scrollSearch(query, batchSize, keepAliveSeconds);
}

synchronized protected void checkForTimedoutScroll() {
Set<String> scrollIds = new HashSet<String>(cursorResults.keySet());
scrollIds.stream().forEach(x -> cursorResults.get(x).timedOut(x));
}

protected ScrollResult scrollSearch(String query, int batchSize, int keepAliveSeconds) {
// TODO: handle keepAlive, at least check for timeout cursor at registerCluster to prevent leak or DOS
QueryMaker queryMaker = findQueryMaker("NXQL");
QueryFilter queryFilter = new QueryFilter(null, null, null, null, Collections.emptyList(), 0, 0);
QueryMaker.Query q = queryMaker.buildQuery(sqlInfo, model, pathResolver, query, queryFilter, null);
Expand All @@ -961,37 +968,55 @@ protected ScrollResult scrollSearch(String query, int batchSize, int keepAliveSe
}
ResultSet rs = ps.executeQuery();
String scrollId = UUID.randomUUID().toString();
registerCursor(scrollId, ps, rs, batchSize);
registerCursor(scrollId, ps, rs, batchSize, keepAliveSeconds);
return scroll(scrollId);
} catch (SQLException e) {
throw new NuxeoException("Error on query", e);
}
}

protected class CursorResult {
protected PreparedStatement preparedStatement;
protected ResultSet resultSet;
protected int batchSize;
protected final int keepAliveSeconds;
protected final PreparedStatement preparedStatement;
protected final ResultSet resultSet;
protected final int batchSize;
protected long lastCallTimestamp;

CursorResult(PreparedStatement preparedStatement, ResultSet resultSet, int batchSize) {
CursorResult(PreparedStatement preparedStatement, ResultSet resultSet, int batchSize, int keepAliveSeconds) {
this.preparedStatement = preparedStatement;
this.resultSet = resultSet;
this.batchSize = batchSize;
this.keepAliveSeconds = keepAliveSeconds;
lastCallTimestamp = System.currentTimeMillis();
}

boolean timedOut(String scrollId) {
long now = System.currentTimeMillis();
if (now - lastCallTimestamp > (keepAliveSeconds * 1000)) {
log.warn("Scroll " + scrollId + " timed out");
unregisterCursor(scrollId);
return true;
}
return false;
}

void close() throws SQLException {
void touch() {
lastCallTimestamp = System.currentTimeMillis();
}

synchronized void close() throws SQLException {
if (resultSet != null) {
resultSet.close();
}
if (preparedStatement != null) {
preparedStatement.close();
}
}

}

protected void registerCursor(String scrollId, PreparedStatement ps, ResultSet rs, int batchSize) {
cursorResults.put(scrollId, new CursorResult(ps, rs, batchSize));
protected void registerCursor(String scrollId, PreparedStatement ps, ResultSet rs, int batchSize,
int keepAliveSeconds) {
cursorResults.put(scrollId, new CursorResult(ps, rs, batchSize, keepAliveSeconds));
}

protected void unregisterCursor(String scrollId) {
Expand All @@ -1000,12 +1025,12 @@ protected void unregisterCursor(String scrollId) {
try {
cursor.close();
} catch (SQLException e) {
throw new NuxeoException("Failed to close cursor");
log.error("Failed to close cursor for scroll: " + scrollId, e);
// do not propagate exception on cleaning
}
}
}


protected ScrollResult defaultScroll(String query) {
// the database has no proper support for cursor just return everything in one batch
QueryMaker queryMaker = findQueryMaker("NXQL");
Expand All @@ -1019,19 +1044,22 @@ protected ScrollResult defaultScroll(String query) {
} catch (SQLException e) {
throw new NuxeoException("Invalid scroll query: " + query, e);
}
return new ScrollResultImpl("noscroll", ids);
return new ScrollResultImpl(NOSCROLL_ID, ids);
}

@Override
public ScrollResult scroll(String scrollId) {
if (!dialect.supportsScroll()) {
if (NOSCROLL_ID.equals(scrollId) || !dialect.supportsScroll()) {
// there is only one batch in this case
return emptyResult();
}
CursorResult cursorResult = cursorResults.get(scrollId);
if (cursorResult == null) {
return emptyResult();
throw new NuxeoException("Unknown or timed out scrollId");
} else if (cursorResult.timedOut(scrollId)) {
throw new NuxeoException("Timed out scrollId");
}
cursorResult.touch();
List<String> ids = new ArrayList<>(cursorResult.batchSize);
synchronized (cursorResult) {
try {
Expand All @@ -1043,12 +1071,12 @@ public ScrollResult scroll(String scrollId) {
if (cursorResult.resultSet.next()) {
ids.add(cursorResult.resultSet.getString(1));
} else {
unregisterCursor(scrollId);
cursorResult.close();
break;
}
}
} catch (SQLException e) {
throw new NuxeoException("Error in scroll", e);
throw new NuxeoException("Error during scroll", e);
}
}
return new ScrollResultImpl(scrollId, ids);
Expand Down
Loading

0 comments on commit 5dc980b

Please sign in to comment.