Skip to content

Commit

Permalink
Close auto commit transaction when execution and streaming fails
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-ince committed Mar 23, 2018
1 parent 2887607 commit 1198c65
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 47 deletions.
Expand Up @@ -236,40 +236,63 @@ else if ( ROLLBACK.matcher( statement ).matches() )
{ {
ctx.lastStatement = statement; ctx.lastStatement = statement;
} }
if ( spi.isPeriodicCommit( statement ) )
execute( ctx, spi, statement, params, spi.isPeriodicCommit( statement ) );

return AUTO_COMMIT;
}
}

void execute( MutableTransactionState ctx, SPI spi, String statement, MapValue params, boolean isPeriodicCommit )
throws KernelException
{
// only acquire a new transaction when the statement does not contain periodic commit
if ( !isPeriodicCommit )
{
ctx.currentTransaction = spi.beginTransaction( ctx.loginContext );
}

boolean failed = true;
try
{
BoltResultHandle resultHandle = spi.executeQuery( ctx.querySource, ctx.loginContext, statement, params );
startExecution( ctx, resultHandle );
failed = false;
}
finally
{
// if we acquired a transaction and a failure occurred, then simply close the transaction
if ( !isPeriodicCommit )
{ {
BoltResultHandle resultHandle = executeQuery( ctx, spi, statement, params, noop() ); if ( failed )
startExecution( ctx, resultHandle ); {
ctx.currentTransaction = null; // Periodic commit will change the current transaction, so closeTransaction( ctx, false );
// we can't trust this to point to the actual current transaction; }
return AUTO_COMMIT;
} }
else else
{ {
ctx.currentTransaction = spi.beginTransaction( ctx.loginContext ); // Periodic commit will change the current transaction, so
BoltResultHandle resultHandle = execute( ctx, spi, statement, params ); // we can't trust this to point to the actual current transaction;
startExecution( ctx, resultHandle ); ctx.currentTransaction = null;
return AUTO_COMMIT;
} }
} }
} }


/*
* In AUTO_COMMIT we must make sure to fail, close and set the current
* transaction to null.
*/
private BoltResultHandle execute( MutableTransactionState ctx, SPI spi, String statement, MapValue params )
{
return executeQuery( ctx, spi, statement, params, () -> closeTransaction( ctx, false ) );
}

@Override @Override
void streamResult( MutableTransactionState ctx, void streamResult( MutableTransactionState ctx,
ThrowingConsumer<BoltResult, Exception> resultConsumer ) throws Exception ThrowingConsumer<BoltResult, Exception> resultConsumer ) throws Exception
{ {
assert ctx.currentResult != null; assert ctx.currentResult != null;
consumeResult( ctx, resultConsumer );
closeTransaction( ctx, true ); boolean success = false;
try
{
success = consumeResult( ctx, resultConsumer );
}
finally
{
closeTransaction( ctx, success );
}
} }
}, },
EXPLICIT_TRANSACTION EXPLICIT_TRANSACTION
Expand Down Expand Up @@ -325,14 +348,7 @@ else if ( ROLLBACK.matcher( statement ).matches() )


private BoltResultHandle execute( MutableTransactionState ctx, SPI spi, String statement, MapValue params ) private BoltResultHandle execute( MutableTransactionState ctx, SPI spi, String statement, MapValue params )
{ {
return executeQuery( ctx, spi, statement, params, return executeQuery( ctx, spi, statement, params );
() ->
{
if ( ctx.currentTransaction != null )
{
ctx.currentTransaction.failure();
}
} );
} }


@Override @Override
Expand Down Expand Up @@ -400,7 +416,7 @@ void closeTransaction( MutableTransactionState ctx, boolean success ) throws Tra
} }
} }


void consumeResult( MutableTransactionState ctx, ThrowingConsumer<BoltResult,Exception> resultConsumer ) throws Exception boolean consumeResult( MutableTransactionState ctx, ThrowingConsumer<BoltResult,Exception> resultConsumer ) throws Exception
{ {
boolean success = false; boolean success = false;
try try
Expand All @@ -419,6 +435,7 @@ void consumeResult( MutableTransactionState ctx, ThrowingConsumer<BoltResult,Exc
ctx.currentResultHandle = null; ctx.currentResultHandle = null;
} }
} }
return success;
} }


void startExecution( MutableTransactionState ctx, BoltResultHandle resultHandle ) throws KernelException void startExecution( MutableTransactionState ctx, BoltResultHandle resultHandle ) throws KernelException
Expand All @@ -439,9 +456,9 @@ void startExecution( MutableTransactionState ctx, BoltResultHandle resultHandle
} }


private static BoltResultHandle executeQuery( MutableTransactionState ctx, SPI spi, String statement, private static BoltResultHandle executeQuery( MutableTransactionState ctx, SPI spi, String statement,
MapValue params, ThrowingAction<KernelException> onFail ) MapValue params )
{ {
return spi.executeQuery( ctx.querySource, ctx.loginContext, statement, params, onFail ); return spi.executeQuery( ctx.querySource, ctx.loginContext, statement, params );
} }


/** /**
Expand Down Expand Up @@ -509,9 +526,6 @@ interface SPI
boolean isPeriodicCommit( String query ); boolean isPeriodicCommit( String query );


BoltResultHandle executeQuery( BoltQuerySource querySource, BoltResultHandle executeQuery( BoltQuerySource querySource,
LoginContext loginContext, LoginContext loginContext, String statement, MapValue params );
String statement,
MapValue params,
ThrowingAction<KernelException> onFail );
} }
} }
Expand Up @@ -123,9 +123,7 @@ public boolean isPeriodicCommit( String query )


@Override @Override
public BoltResultHandle executeQuery( BoltQuerySource querySource, public BoltResultHandle executeQuery( BoltQuerySource querySource,
LoginContext loginContext, LoginContext loginContext, String statement, MapValue params )
String statement,
MapValue params, ThrowingAction<KernelException> onFail )
{ {
InternalTransaction internalTransaction = queryService.beginTransaction( implicit, loginContext ); InternalTransaction internalTransaction = queryService.beginTransaction( implicit, loginContext );
ClientConnectionInfo sourceDetails = new BoltConnectionInfo( querySource.principalName, ClientConnectionInfo sourceDetails = new BoltConnectionInfo( querySource.principalName,
Expand All @@ -151,13 +149,11 @@ public BoltResult start() throws KernelException
catch ( KernelException e ) catch ( KernelException e )
{ {
close( false ); close( false );
onFail.apply();
throw new QueryExecutionKernelException( e ); throw new QueryExecutionKernelException( e );
} }
catch ( Throwable e ) catch ( Throwable e )
{ {
close( false ); close( false );
onFail.apply();
throw e; throw e;
} }
} }
Expand Down
Expand Up @@ -95,6 +95,29 @@ public void describeTo( Description description )
}; };
} }


public static Matcher<RecordedBoltResponse> containsRecord( final Object... values )
{
return new BaseMatcher<RecordedBoltResponse>()
{
private AnyValue[] anyValues = Arrays.stream( values ).map( ValueUtils::of ).toArray( AnyValue[]::new );

@Override
public boolean matches( final Object item )
{

final RecordedBoltResponse response = (RecordedBoltResponse) item;
QueryResult.Record[] records = response.records();
return records.length > 0 && Arrays.equals( records[0].fields(), anyValues );
}

@Override
public void describeTo( Description description )
{
description.appendText( format( "with record %s", values ) );
}
};
}

public static Matcher<RecordedBoltResponse> succeededWithRecord( final Object... values ) public static Matcher<RecordedBoltResponse> succeededWithRecord( final Object... values )
{ {
return new BaseMatcher<RecordedBoltResponse>() return new BaseMatcher<RecordedBoltResponse>()
Expand Down
Expand Up @@ -422,7 +422,7 @@ public void shouldThrowDuringStreamResultIfPendingTerminationNoticeExists() thro
} }


@Test @Test
public void shouldCloseResultHandlesWhenExecutionFails() throws Exception public void shouldCloseResultAndTransactionHandlesWhenExecutionFails() throws Exception
{ {
KernelTransaction transaction = newTransaction(); KernelTransaction transaction = newTransaction();
TransactionStateMachine.BoltResultHandle resultHandle = newResultHandle( new RuntimeException( "some error" ) ); TransactionStateMachine.BoltResultHandle resultHandle = newResultHandle( new RuntimeException( "some error" ) );
Expand All @@ -442,10 +442,11 @@ public void shouldCloseResultHandlesWhenExecutionFails() throws Exception


assertNull( stateMachine.ctx.currentResultHandle ); assertNull( stateMachine.ctx.currentResultHandle );
assertNull( stateMachine.ctx.currentResult ); assertNull( stateMachine.ctx.currentResult );
assertNull( stateMachine.ctx.currentTransaction );
} }


@Test @Test
public void shouldCloseResultHandlesWhenConsumeFails() throws Exception public void shouldCloseResultAndTransactionHandlesWhenConsumeFails() throws Exception
{ {
KernelTransaction transaction = newTransaction(); KernelTransaction transaction = newTransaction();
TransactionStateMachineSPI stateMachineSPI = newTransactionStateMachineSPI( transaction ); TransactionStateMachineSPI stateMachineSPI = newTransactionStateMachineSPI( transaction );
Expand All @@ -472,6 +473,7 @@ public void shouldCloseResultHandlesWhenConsumeFails() throws Exception


assertNull( stateMachine.ctx.currentResultHandle ); assertNull( stateMachine.ctx.currentResultHandle );
assertNull( stateMachine.ctx.currentResult ); assertNull( stateMachine.ctx.currentResult );
assertNull( stateMachine.ctx.currentTransaction );
} }


@Test @Test
Expand Down Expand Up @@ -500,6 +502,7 @@ public void shouldCloseResultHandlesWhenExecutionFailsInExplicitTransaction() th


assertNull( stateMachine.ctx.currentResultHandle ); assertNull( stateMachine.ctx.currentResultHandle );
assertNull( stateMachine.ctx.currentResult ); assertNull( stateMachine.ctx.currentResult );
assertNotNull( stateMachine.ctx.currentTransaction );
} }


@Test @Test
Expand Down Expand Up @@ -535,6 +538,7 @@ public void shouldCloseResultHandlesWhenConsumeFailsInExplicitTransaction() thro


assertNull( stateMachine.ctx.currentResultHandle ); assertNull( stateMachine.ctx.currentResultHandle );
assertNull( stateMachine.ctx.currentResult ); assertNull( stateMachine.ctx.currentResult );
assertNotNull( stateMachine.ctx.currentTransaction );
} }


private static KernelTransaction newTransaction() private static KernelTransaction newTransaction()
Expand Down Expand Up @@ -571,8 +575,8 @@ private static TransactionStateMachineSPI newFailingTransactionStateMachineSPI(
TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class ); TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class );


when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( mock( KernelTransaction.class ) ); when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( mock( KernelTransaction.class ) );
when( stateMachineSPI.executeQuery( any(), any(), anyString(), any(), any() ) ).thenReturn( resultHandle ); when( stateMachineSPI.executeQuery( any(), any(), anyString(), any() ) ).thenReturn( resultHandle );
when( stateMachineSPI.executeQuery( any(), any(), eq( "FAIL" ), any(), any() ) ).thenThrow( new TransactionTerminatedException( failureStatus ) ); when( stateMachineSPI.executeQuery( any(), any(), eq( "FAIL" ), any() ) ).thenThrow( new TransactionTerminatedException( failureStatus ) );


return stateMachineSPI; return stateMachineSPI;
} }
Expand All @@ -583,7 +587,7 @@ private static TransactionStateMachineSPI newTransactionStateMachineSPI( KernelT
TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class ); TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class );


when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( transaction ); when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( transaction );
when( stateMachineSPI.executeQuery( any(), any(), anyString(), any(), any() ) ).thenReturn( resultHandle ); when( stateMachineSPI.executeQuery( any(), any(), anyString(), any() ) ).thenReturn( resultHandle );


return stateMachineSPI; return stateMachineSPI;
} }
Expand All @@ -594,7 +598,7 @@ private static TransactionStateMachineSPI newTransactionStateMachineSPI( KernelT
TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class ); TransactionStateMachineSPI stateMachineSPI = mock( TransactionStateMachineSPI.class );


when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( transaction ); when( stateMachineSPI.beginTransaction( any() ) ).thenReturn( transaction );
when( stateMachineSPI.executeQuery( any(), any(), anyString(), any(), any() ) ).thenReturn( resultHandle ); when( stateMachineSPI.executeQuery( any(), any(), anyString(), any() ) ).thenReturn( resultHandle );


return stateMachineSPI; return stateMachineSPI;
} }
Expand Down
Expand Up @@ -41,6 +41,7 @@
import org.neo4j.graphdb.Label; import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Transaction; import org.neo4j.graphdb.Transaction;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.impl.util.ValueUtils; import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.test.Barrier; import org.neo4j.test.Barrier;
import org.neo4j.test.DoubleLatch; import org.neo4j.test.DoubleLatch;
Expand All @@ -50,11 +51,16 @@
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonMap; import static java.util.Collections.singletonMap;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.AllOf.allOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.neo4j.bolt.testing.BoltMatchers.containsRecord;
import static org.neo4j.bolt.testing.BoltMatchers.failedWithStatus;
import static org.neo4j.bolt.testing.BoltMatchers.succeeded; import static org.neo4j.bolt.testing.BoltMatchers.succeeded;
import static org.neo4j.bolt.testing.BoltMatchers.succeededWithMetadata; import static org.neo4j.bolt.testing.BoltMatchers.succeededWithMetadata;
import static org.neo4j.bolt.testing.BoltMatchers.succeededWithRecord; import static org.neo4j.bolt.testing.BoltMatchers.succeededWithRecord;
import static org.neo4j.bolt.testing.BoltMatchers.wasIgnored;
import static org.neo4j.bolt.testing.NullResponseHandler.nullResponseHandler; import static org.neo4j.bolt.testing.NullResponseHandler.nullResponseHandler;
import static org.neo4j.values.virtual.VirtualValues.EMPTY_MAP; import static org.neo4j.values.virtual.VirtualValues.EMPTY_MAP;


Expand Down Expand Up @@ -376,6 +382,54 @@ public void beginShouldNotOverwriteLastStatement() throws Throwable
assertThat( recorder.nextResponse(), succeededWithRecord( 1L ) ); assertThat( recorder.nextResponse(), succeededWithRecord( 1L ) );
} }


@Test
public void shouldCloseAutoCommitTransactionAndRespondToNextStatementWhenRunFails() throws Throwable
{
// Given
final BoltStateMachine machine = env.newMachine( boltChannel );
machine.init( USER_AGENT, emptyMap(), null );
BoltResponseRecorder recorder = new BoltResponseRecorder();

// When
machine.run( "INVALID QUERY", EMPTY_MAP, recorder );
machine.pullAll( recorder );
machine.ackFailure( recorder );
machine.run( "RETURN 2", EMPTY_MAP, recorder );
machine.pullAll( recorder );

// Then
assertThat( recorder.nextResponse(), failedWithStatus( Status.Statement.SyntaxError ) );
assertThat( recorder.nextResponse(), wasIgnored() );
assertThat( recorder.nextResponse(), succeeded() );
assertThat( recorder.nextResponse(), succeeded() );
assertThat( recorder.nextResponse(), succeededWithRecord( 2L ) );
assertEquals( recorder.responseCount(), 0 );
}

@Test
public void shouldCloseAutoCommitTransactionAndRespondToNextStatementWhenStreamingFails() throws Throwable
{
// Given
final BoltStateMachine machine = env.newMachine( boltChannel );
machine.init( USER_AGENT, emptyMap(), null );
BoltResponseRecorder recorder = new BoltResponseRecorder();

// When
machine.run( "UNWIND [1, 0] AS x RETURN 1 / x", EMPTY_MAP, recorder );
machine.pullAll( recorder );
machine.ackFailure( recorder );
machine.run( "RETURN 2", EMPTY_MAP, recorder );
machine.pullAll( recorder );

// Then
assertThat( recorder.nextResponse(), succeeded() );
assertThat( recorder.nextResponse(), allOf( containsRecord( 1L ), failedWithStatus( Status.Statement.ArithmeticError ) ) );
assertThat( recorder.nextResponse(), succeeded() );
assertThat( recorder.nextResponse(), succeeded() );
assertThat( recorder.nextResponse(), succeededWithRecord( 2L ) );
assertEquals( recorder.responseCount(), 0 );
}

public static Server createHttpServer( public static Server createHttpServer(
DoubleLatch latch, Barrier.Control innerBarrier, int firstBatchSize, int otherBatchSize ) DoubleLatch latch, Barrier.Control innerBarrier, int firstBatchSize, int otherBatchSize )
{ {
Expand Down

0 comments on commit 1198c65

Please sign in to comment.