Skip to content

Commit

Permalink
Add support for custom transaction timeout into transactional endpoint.
Browse files Browse the repository at this point in the history
Allow propagation of custom timeout if specified in transactional service endpoint.

Extract header extraction from service into custom utility class.
  • Loading branch information
MishaDemianenko committed Sep 7, 2016
1 parent 02f7339 commit 4a05588
Show file tree
Hide file tree
Showing 15 changed files with 398 additions and 100 deletions.
Expand Up @@ -103,6 +103,12 @@ public void check()
throw new UnsupportedOperationException( "fake test class" );
}

@Override
public void check()
{
throw new UnsupportedOperationException( "fake test class" );
}

@Override
public TxStateHolder stateView()
{
Expand Down
Expand Up @@ -39,10 +39,10 @@
import org.neo4j.logging.LogProvider;
import org.neo4j.server.rest.web.ServerQuerySession;

import static org.neo4j.server.web.HttpHeaderUtils.getTransactionTimeout;

public class CypherExecutor extends LifecycleAdapter
{
static final String MAX_EXECUTION_TIME_HEADER = "max-execution-time";

private final Database database;
private ExecutionEngine executionEngine;
private GraphDatabaseQueryService service;
Expand Down Expand Up @@ -89,7 +89,7 @@ public QuerySession createSession( String query, Map<String, Object> parameters,

private InternalTransaction getInternalTransaction( HttpServletRequest request )
{
long customTimeout = getTransactionTimeLimit( request );
long customTimeout = getTransactionTimeout( request, log );
return customTimeout > 0 ? beginCustomTransaction( customTimeout ) : beginDefaultTransaction();
}

Expand All @@ -102,22 +102,4 @@ private InternalTransaction beginDefaultTransaction()
{
return service.beginTransaction( KernelTransaction.Type.implicit, AccessMode.Static.FULL );
}

private long getTransactionTimeLimit( HttpServletRequest request )
{
String headerValue = request.getHeader( MAX_EXECUTION_TIME_HEADER );
if ( headerValue != null )
{
try
{
return Long.parseLong( headerValue );
}
catch ( NumberFormatException e )
{
log.error( String.format( "Fail to parse `%s` header with value: '%s'. Should be a positive number.",
MAX_EXECUTION_TIME_HEADER, headerValue), e );
}
}
return -1;
}
}
Expand Up @@ -68,10 +68,11 @@ public TransactionFacade( TransitionalPeriodTransactionMessContainer kernel, Que
this.logProvider = logProvider;
}

public TransactionHandle newTransactionHandle( TransactionUriScheme uriScheme, boolean implicitTransaction, AccessMode mode )
throws TransactionLifecycleException
public TransactionHandle newTransactionHandle( TransactionUriScheme uriScheme, boolean implicitTransaction,
AccessMode mode, long customTransactionTimeout ) throws TransactionLifecycleException
{
return new TransactionHandle( kernel, engine, queryService, registry, uriScheme, implicitTransaction, mode, logProvider );
return new TransactionHandle( kernel, engine, queryService, registry, uriScheme, implicitTransaction, mode,
customTransactionTimeout, logProvider );
}

public TransactionHandle findTransactionHandle( long txId ) throws TransactionLifecycleException
Expand Down
Expand Up @@ -74,14 +74,15 @@ public class TransactionHandle implements TransactionTerminationHandle
private final TransactionUriScheme uriScheme;
private final Type type;
private final AccessMode mode;
private long customTransactionTimeout;
private final Log log;
private final long id;
private TransitionalTxManagementKernelTransaction context;
private GraphDatabaseQueryService queryService;

TransactionHandle( TransitionalPeriodTransactionMessContainer txManagerFacade, QueryExecutionEngine engine,
GraphDatabaseQueryService queryService, TransactionRegistry registry, TransactionUriScheme uriScheme,
boolean implicitTransaction, AccessMode mode, LogProvider logProvider )
boolean implicitTransaction, AccessMode mode, long customTransactionTimeout, LogProvider logProvider )
{
this.txManagerFacade = txManagerFacade;
this.engine = engine;
Expand All @@ -90,6 +91,7 @@ public class TransactionHandle implements TransactionTerminationHandle
this.uriScheme = uriScheme;
this.type = implicitTransaction ? Type.implicit : Type.explicit;
this.mode = mode;
this.customTransactionTimeout = customTransactionTimeout;
this.log = logProvider.getLog( getClass() );
this.id = registry.begin( this );
}
Expand Down Expand Up @@ -206,7 +208,7 @@ private void ensureActiveTransaction() throws InternalBeginTransactionError
{
try
{
context = txManagerFacade.newTransaction( type, mode );
context = txManagerFacade.newTransaction( type, mode, customTransactionTimeout );
}
catch ( RuntimeException e )
{
Expand Down Expand Up @@ -314,7 +316,8 @@ private void executeStatements( StatementDeserializer statements, ExecutionResul
}

hasPrevious = true;
QuerySession querySession = txManagerFacade.create( statement.statement(), statement.parameters(), queryService, type, mode, request );
QuerySession querySession = txManagerFacade.create( statement.statement(), statement.parameters(),
queryService, type, mode, customTransactionTimeout, request );
Result result = safelyExecute( statement, hasPeriodicCommit, querySession );
output.statementResult( result, statement.includeStats(), statement.resultDataContents() );
output.notifications( result.getNotifications() );
Expand Down
Expand Up @@ -47,28 +47,25 @@ public TransitionalPeriodTransactionMessContainer( GraphDatabaseFacade db )
this.txBridge = db.getDependencyResolver().resolveDependency( ThreadToStatementContextBridge.class );
}

public TransitionalTxManagementKernelTransaction newTransaction( Type type, AccessMode mode )
public TransitionalTxManagementKernelTransaction newTransaction( Type type, AccessMode mode,
long customTransactionTimeout )
{
return new TransitionalTxManagementKernelTransaction( db, type, mode, txBridge );
return new TransitionalTxManagementKernelTransaction( db, type, mode, customTransactionTimeout, txBridge );
}

public ThreadToStatementContextBridge getBridge()
{
return txBridge;
}

public QuerySession create(
String query,
Map<String, Object> parameters,
GraphDatabaseQueryService service,
Type type,
AccessMode mode,
HttpServletRequest request )

public QuerySession create( String query, Map<String, Object> parameters, GraphDatabaseQueryService service,
Type type, AccessMode mode, long customTransactionTimeout, HttpServletRequest request )
{
InternalTransaction transaction = db.beginTransaction( type, mode );
TransactionalContext context = new Neo4jTransactionalContext(
service, transaction, txBridge.get(), query, parameters, locker
);
InternalTransaction transaction = customTransactionTimeout > 0 ? db.beginTransaction( type, mode, customTransactionTimeout ) :
db.beginTransaction( type, mode);
TransactionalContext context = new Neo4jTransactionalContext( service, transaction, txBridge.get(), query,
parameters, locker );
return new ServerQuerySession( request, context );
}
}
Expand Up @@ -31,19 +31,21 @@ class TransitionalTxManagementKernelTransaction
private final GraphDatabaseFacade db;
private final KernelTransaction.Type type;
private final AccessMode mode;
private long customTransactionTimeout;
private final ThreadToStatementContextBridge bridge;

private InternalTransaction tx;
private KernelTransaction suspendedTransaction;

TransitionalTxManagementKernelTransaction( GraphDatabaseFacade db, KernelTransaction.Type type,
AccessMode mode, ThreadToStatementContextBridge bridge )
AccessMode mode, long customTransactionTimeout, ThreadToStatementContextBridge bridge )
{
this.db = db;
this.type = type;
this.mode = mode;
this.customTransactionTimeout = customTransactionTimeout;
this.bridge = bridge;
this.tx = db.beginTransaction( type, mode );
this.tx = startTransaction();
}

void suspendSinceTransactionsAreStillThreadBound()
Expand Down Expand Up @@ -108,6 +110,12 @@ void closeTransactionForPeriodicCommit()

void reopenAfterPeriodicCommit()
{
tx = db.beginTransaction( type, mode );
tx = startTransaction();
}

private InternalTransaction startTransaction()
{
return customTransactionTimeout > 0 ? db.beginTransaction( type, mode, customTransactionTimeout ) :
db.beginTransaction( type, mode );
}
}
Expand Up @@ -39,13 +39,15 @@
import javax.ws.rs.core.UriInfo;

import org.neo4j.kernel.api.security.AccessMode;
import org.neo4j.logging.Log;
import org.neo4j.server.rest.dbms.AuthorizedRequestWrapper;
import org.neo4j.server.rest.transactional.ExecutionResultSerializer;
import org.neo4j.server.rest.transactional.TransactionFacade;
import org.neo4j.server.rest.transactional.TransactionHandle;
import org.neo4j.server.rest.transactional.TransactionTerminationHandle;
import org.neo4j.server.rest.transactional.error.Neo4jError;
import org.neo4j.server.rest.transactional.error.TransactionLifecycleException;
import org.neo4j.server.web.HttpHeaderUtils;
import org.neo4j.udc.UsageData;

import static org.neo4j.udc.UsageDataKeys.Features.http_tx_endpoint;
Expand All @@ -61,12 +63,15 @@ public class TransactionalService
private final TransactionFacade facade;
private final UsageData usage;
private final TransactionUriScheme uriScheme;
private Log log;

public TransactionalService( @Context TransactionFacade facade, @Context UriInfo uriInfo, @Context UsageData usage )
public TransactionalService( @Context TransactionFacade facade, @Context UriInfo uriInfo, @Context UsageData usage,
@Context Log log )
{
this.facade = facade;
this.usage = usage;
this.uriScheme = new TransactionUriBuilder( uriInfo );
this.log = log;
}

@POST
Expand All @@ -79,7 +84,8 @@ public Response executeStatementsInNewTransaction( final InputStream input, @Con
{
usage.get( features ).flag( http_tx_endpoint );
AccessMode accessMode = AuthorizedRequestWrapper.getAccessModeFromHttpServletRequest( request );
TransactionHandle transactionHandle = facade.newTransactionHandle( uriScheme, false, accessMode );
long customTransactionTimeout = HttpHeaderUtils.getTransactionTimeout( request, log );
TransactionHandle transactionHandle = facade.newTransactionHandle( uriScheme, false, accessMode, customTransactionTimeout );
return createdResponse( transactionHandle, executeStatements( input, transactionHandle, uriInfo.getBaseUri(), request ) );
}
catch ( TransactionLifecycleException e )
Expand Down Expand Up @@ -137,7 +143,8 @@ public Response commitNewTransaction( final InputStream input, @Context final Ur
try
{
AccessMode accessMode = AuthorizedRequestWrapper.getAccessModeFromHttpServletRequest( request );
transactionHandle = facade.newTransactionHandle( uriScheme, true, accessMode );
long customTransactionTimeout = HttpHeaderUtils.getTransactionTimeout( request, log );
transactionHandle = facade.newTransactionHandle( uriScheme, true, accessMode, customTransactionTimeout );
}
catch ( TransactionLifecycleException e )
{
Expand Down
Expand Up @@ -19,14 +19,19 @@
*/
package org.neo4j.server.web;

import javax.ws.rs.core.MediaType;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.MediaType;

import org.neo4j.logging.Log;

public class HttpHeaderUtils {

public static final String MAX_EXECUTION_TIME_HEADER = "max-execution-time";

public static final Map<String, String> CHARSET = Collections.singletonMap("charset", StandardCharsets.UTF_8.name());

public static MediaType mediaTypeWithCharsetUtf8(String mediaType)
Expand All @@ -49,4 +54,30 @@ public static MediaType mediaTypeWithCharsetUtf8(MediaType mediaType)
paramsWithCharset.putAll(CHARSET);
return new MediaType(mediaType.getType(), mediaType.getSubtype(), paramsWithCharset);
}

/**
* Retrieve custom transaction timeout in milliseconds from numeric {@link #MAX_EXECUTION_TIME_HEADER} request
* header.
* If header is not set returns -1.
* @param request http request
* @param errorLog errors log for header parsing errors
* @return custom timeout if header set, -1 otherwise or when value is not a valid number.
*/
public static long getTransactionTimeout( HttpServletRequest request, Log errorLog )
{
String headerValue = request.getHeader( MAX_EXECUTION_TIME_HEADER );
if ( headerValue != null )
{
try
{
return Long.parseLong( headerValue );
}
catch ( NumberFormatException e )
{
errorLog.error( String.format( "Fail to parse `%s` header with value: '%s'. Should be a positive number.",
MAX_EXECUTION_TIME_HEADER, headerValue), e );
}
}
return -1;
}
}
Expand Up @@ -36,6 +36,7 @@
import org.neo4j.kernel.impl.factory.GraphDatabaseFacade;
import org.neo4j.kernel.impl.query.QueryExecutionEngine;
import org.neo4j.logging.AssertableLogProvider;
import org.neo4j.server.web.HttpHeaderUtils;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -79,7 +80,7 @@ public void startDefaultTransaction() throws Throwable
@Test
public void startTransactionWithCustomTimeout() throws Throwable
{
when( request.getHeader( CypherExecutor.MAX_EXECUTION_TIME_HEADER ) )
when( request.getHeader( HttpHeaderUtils.MAX_EXECUTION_TIME_HEADER ) )
.thenReturn( String.valueOf( CUSTOM_TRANSACTION_TIMEOUT ) );

CypherExecutor cypherExecutor = new CypherExecutor( database, logProvider );
Expand All @@ -95,7 +96,7 @@ public void startTransactionWithCustomTimeout() throws Throwable
@Test
public void startDefaultTransactionWhenHeaderHasIncorrectValue() throws Throwable
{
when( request.getHeader( CypherExecutor.MAX_EXECUTION_TIME_HEADER ) )
when( request.getHeader( HttpHeaderUtils.MAX_EXECUTION_TIME_HEADER ) )
.thenReturn( "not a number" );

CypherExecutor cypherExecutor = new CypherExecutor( database, logProvider );
Expand All @@ -111,7 +112,7 @@ public void startDefaultTransactionWhenHeaderHasIncorrectValue() throws Throwabl
@Test
public void startDefaultTransactionIfTimeoutIsNegative() throws Throwable
{
when( request.getHeader( CypherExecutor.MAX_EXECUTION_TIME_HEADER ) )
when( request.getHeader( HttpHeaderUtils.MAX_EXECUTION_TIME_HEADER ) )
.thenReturn( "-2" );

CypherExecutor cypherExecutor = new CypherExecutor( database, logProvider );
Expand All @@ -123,21 +124,6 @@ public void startDefaultTransactionIfTimeoutIsNegative() throws Throwable
logProvider.assertNoLoggingOccurred();
}

@Test
public void startDefaultTransactionIfExecutionGuardDisabled() throws Throwable
{
when( request.getHeader( CypherExecutor.MAX_EXECUTION_TIME_HEADER ) )
.thenReturn( String.valueOf( CUSTOM_TRANSACTION_TIMEOUT ) );

CypherExecutor cypherExecutor = new CypherExecutor( database, logProvider );
cypherExecutor.start();

cypherExecutor.createSession( request );

verify( databaseQueryService ).beginTransaction( KernelTransaction.Type.implicit, AccessMode.Static.FULL );
logProvider.assertNoLoggingOccurred();
}

private void initLogProvider()
{
logProvider = new AssertableLogProvider( true );
Expand All @@ -157,16 +143,20 @@ private void setUpMocks()

InternalTransaction transaction = new TopLevelTransaction( kernelTransaction, () -> statement );

AccessMode.Static accessMode = AccessMode.Static.FULL;
KernelTransaction.Type type = KernelTransaction.Type.implicit;
when( kernelTransaction.mode() ).thenReturn( accessMode );
when( kernelTransaction.transactionType() ).thenReturn( type );
when( database.getGraph() ).thenReturn( databaseFacade );
when( databaseFacade.getDependencyResolver() ).thenReturn( dependencyResolver );
when( dependencyResolver.resolveDependency( QueryExecutionEngine.class ) ).thenReturn( executionEngine );
when( dependencyResolver.resolveDependency( ThreadToStatementContextBridge.class ) ).thenReturn(
statementBridge );
when( dependencyResolver.resolveDependency( GraphDatabaseQueryService.class ) ).thenReturn(
databaseQueryService );
when( databaseQueryService.beginTransaction( KernelTransaction.Type.implicit, AccessMode.Static.FULL ) )
when( databaseQueryService.beginTransaction( type, accessMode ) )
.thenReturn( transaction );
when( databaseQueryService.beginTransaction( KernelTransaction.Type.implicit, AccessMode.Static.FULL,
when( databaseQueryService.beginTransaction( type, accessMode,
CUSTOM_TRANSACTION_TIMEOUT ) ).thenReturn( transaction );
when( databaseQueryService.getDependencyResolver() ).thenReturn( dependencyResolver );
}
Expand Down

0 comments on commit 4a05588

Please sign in to comment.