Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate TX context in Kotlin coroutines #1274

Merged
merged 2 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data-tx/build.gradle
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
plugins {
id "io.micronaut.build.internal.data-kotlin"
id "io.micronaut.build.internal.data-module"
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2017-2022 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.transaction.interceptor;

import io.micronaut.aop.kotlin.KotlinInterceptedMethod;
import io.micronaut.context.annotation.Requires;
import io.micronaut.core.annotation.Internal;
import io.micronaut.transaction.support.TransactionSynchronizationManager;
import jakarta.inject.Singleton;
import kotlin.coroutines.CoroutineContext;

/**
* Helper to setup Kotlin coroutine context.
*
* @author Denis Stepanov
* @since 3.3
*/
@Internal
@Singleton
@Requires(classes = kotlin.coroutines.CoroutineContext.class)
final class CoroutineTxHelper {

public void setupCoroutineContext(KotlinInterceptedMethod kotlinInterceptedMethod) {
CoroutineContext existingContext = kotlinInterceptedMethod.getCoroutineContext();
kotlinInterceptedMethod.updateCoroutineContext(existingContext.plus(new TxSynchronousContext(TransactionSynchronizationManager.copyState())));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.micronaut.aop.InterceptedMethod;
import io.micronaut.aop.MethodInterceptor;
import io.micronaut.aop.MethodInvocationContext;
import io.micronaut.aop.kotlin.KotlinInterceptedMethod;
import io.micronaut.context.BeanLocator;
import io.micronaut.context.exceptions.ConfigurationException;
import io.micronaut.core.annotation.AnnotationValue;
Expand All @@ -35,13 +36,16 @@
import io.micronaut.transaction.exceptions.TransactionSystemException;
import io.micronaut.transaction.reactive.ReactiveTransactionOperations;
import io.micronaut.transaction.reactive.ReactiveTransactionStatus;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;

/**
Expand Down Expand Up @@ -74,14 +78,27 @@ public String toString() {

@NonNull
private final BeanLocator beanLocator;
private final CoroutineTxHelper coroutineTxHelper;

/**
* Default constructor.
*
* @param beanLocator The bean locator.
*/
public TransactionalInterceptor(@NonNull BeanLocator beanLocator) {
this(beanLocator, null);
}

/**
* Default constructor.
*
* @param beanLocator The bean locator.
* @param coroutineTxHelper The coroutine helper
*/
@Inject
public TransactionalInterceptor(@NonNull BeanLocator beanLocator, @Nullable CoroutineTxHelper coroutineTxHelper) {
this.beanLocator = beanLocator;
this.coroutineTxHelper = coroutineTxHelper;
}

@Override
Expand All @@ -92,6 +109,7 @@ public int getOrder() {
@Override
public Object intercept(MethodInvocationContext<Object, Object> context) {
InterceptedMethod interceptedMethod = InterceptedMethod.of(context);
boolean isKotlinSuspended = interceptedMethod instanceof KotlinInterceptedMethod;
try {
boolean isReactive = interceptedMethod.resultType() == InterceptedMethod.ResultType.PUBLISHER;
boolean isAsync = interceptedMethod.resultType() == InterceptedMethod.ResultType.COMPLETION_STAGE;
Expand All @@ -100,9 +118,10 @@ public Object intercept(MethodInvocationContext<Object, Object> context) {
.computeIfAbsent(context.getExecutableMethod(), executableMethod -> {
final String qualifier = executableMethod.stringValue(TransactionalAdvice.class).orElse(null);

if (isReactive || isAsync) {
ReactiveTransactionOperations<?> reactiveTransactionOperations
= beanLocator.findBean(ReactiveTransactionOperations.class, qualifier != null ? Qualifiers.byName(qualifier) : null).orElse(null);
ReactiveTransactionOperations<?> reactiveTransactionOperations
= beanLocator.findBean(ReactiveTransactionOperations.class, qualifier != null ? Qualifiers.byName(qualifier) : null).orElse(null);

if ((isReactive || isAsync) && !(isKotlinSuspended && reactiveTransactionOperations == null)) {
if (isReactive && reactiveTransactionOperations == null) {
throw new ConfigurationException("No reactive transaction management has been configured. Ensure you have correctly configured a reactive capable transaction manager");
} else {
Expand Down Expand Up @@ -133,7 +152,37 @@ public Object intercept(MethodInvocationContext<Object, Object> context) {
if (transactionInvocation.reactiveTransactionOperations != null) {
return interceptedMethod.handleResult(interceptedMethod.interceptResult());
} else {
throw new ConfigurationException("Async return type doesn't support transactional execution.");
if (isKotlinSuspended) {
final SynchronousTransactionManager<?> transactionManager = transactionInvocation.transactionManager;
final TransactionInfo transactionInfo = createTransactionIfNecessary(
transactionManager,
definition,
context.getExecutableMethod()
);
KotlinInterceptedMethod kotlinInterceptedMethod = (KotlinInterceptedMethod) interceptedMethod;
if (coroutineTxHelper != null) {
coroutineTxHelper.setupCoroutineContext(kotlinInterceptedMethod);
}
CompletionStage<?> result = interceptedMethod.interceptResultAsCompletionStage();
CompletableFuture newResult = new CompletableFuture();
result.whenComplete((o, throwable) -> {
if (throwable == null) {
commitTransactionAfterReturning(transactionInfo);
newResult.complete(o);
} else {
try {
completeTransactionAfterThrowing(transactionInfo, throwable);
} catch (Exception e) {
// Ignore rethrow
}
newResult.completeExceptionally(throwable);
}
cleanupTransactionInfo(transactionInfo);
});
return interceptedMethod.handleResult(newResult);
} else {
throw new ConfigurationException("Async return type doesn't support transactional execution.");
}
}
case SYNCHRONOUS:
final SynchronousTransactionManager<?> transactionManager = transactionInvocation.transactionManager;
Expand Down Expand Up @@ -171,11 +220,11 @@ private static TransactionInfo currentTransactionInfo() throws NoTransactionExce
* Return the transaction status of the current method invocation.
* Mainly intended for code that wants to set the current transaction
* rollback-only but not throw an application exception.
* @throws NoTransactionException if the transaction info cannot be found,
* because the method was invoked outside an AOP invocation context
*
* @return The current status
* @param <T> The connection type
* @return The current status
* @throws NoTransactionException if the transaction info cannot be found,
* because the method was invoked outside an AOP invocation context
*/
public static <T> TransactionStatus<T> currentTransactionStatus() throws NoTransactionException {
TransactionInfo info = currentTransactionInfo();
Expand All @@ -190,10 +239,11 @@ public static <T> TransactionStatus<T> currentTransactionStatus() throws NoTrans
* Create a transaction if necessary based on the given TransactionAttribute.
* <p>Allows callers to perform custom TransactionAttribute lookups through
* the TransactionAttributeSource.
* @param tm The transaction manager
* @param txAttr the TransactionAttribute (may be {@code null})
*
* @param tm The transaction manager
* @param txAttr the TransactionAttribute (may be {@code null})
* @param executableMethod the method that is being executed
* (used for monitoring and logging purposes)
* (used for monitoring and logging purposes)
* @return a TransactionInfo object, whether or not a transaction was created.
* The {@code hasTransaction()} method on TransactionInfo can be used to
* tell if there was a transaction created.
Expand All @@ -211,11 +261,12 @@ protected TransactionInfo createTransactionIfNecessary(@NonNull SynchronousTrans

/**
* Prepare a TransactionInfo for the given attribute and status object.
* @param tm The transaction manager
* @param txAttr the TransactionAttribute (may be {@code null})
*
* @param tm The transaction manager
* @param txAttr the TransactionAttribute (may be {@code null})
* @param executableMethod the fully qualified method name
* (used for monitoring and logging purposes)
* @param status the TransactionStatus for the current transaction
* (used for monitoring and logging purposes)
* @param status the TransactionStatus for the current transaction
* @return the prepared TransactionInfo object
*/
protected TransactionInfo prepareTransactionInfo(@NonNull SynchronousTransactionManager tm,
Expand All @@ -241,6 +292,7 @@ protected TransactionInfo prepareTransactionInfo(@NonNull SynchronousTransaction
/**
* Execute after successful completion of call, but not after an exception was handled.
* Do nothing if we didn't create a transaction.
*
* @param txInfo information about the current transaction
*/
protected void commitTransactionAfterReturning(@NonNull TransactionInfo txInfo) {
Expand All @@ -253,8 +305,9 @@ protected void commitTransactionAfterReturning(@NonNull TransactionInfo txInfo)
/**
* Handle a throwable, completing the transaction.
* We may commit or roll back, depending on the configuration.
*
* @param txInfo information about the current transaction
* @param ex throwable encountered
* @param ex throwable encountered
*/
protected void completeTransactionAfterThrowing(@NonNull TransactionInfo txInfo, Throwable ex) {
if (LOG.isTraceEnabled()) {
Expand Down Expand Up @@ -291,6 +344,7 @@ protected void completeTransactionAfterThrowing(@NonNull TransactionInfo txInfo,
/**
* Reset the TransactionInfo ThreadLocal.
* <p>Call this in all cases: exception or normal return!
*
* @param txInfo information about the current transaction (may be {@code null})
*/
protected void cleanupTransactionInfo(@Nullable TransactionInfo txInfo) {
Expand Down Expand Up @@ -331,8 +385,10 @@ protected TransactionAttribute resolveTransactionDefinition(
* @param <C> connection type
*/
private static final class TransactionInvocation<C> {
final @Nullable SynchronousTransactionManager<C> transactionManager;
final @Nullable ReactiveTransactionOperations<C> reactiveTransactionOperations;
final @Nullable
SynchronousTransactionManager<C> transactionManager;
final @Nullable
ReactiveTransactionOperations<C> reactiveTransactionOperations;
final TransactionAttribute definition;

TransactionInvocation(
Expand Down Expand Up @@ -365,13 +421,14 @@ protected static final class TransactionInfo<T> {

/**
* Constructs a new transaction info.
* @param transactionManager The transaction manager
*
* @param transactionManager The transaction manager
* @param transactionAttribute The transaction attribute
* @param executableMethod The joint point identification
* @param executableMethod The joint point identification
*/
protected TransactionInfo(@NonNull SynchronousTransactionManager<T> transactionManager,
@NonNull TransactionAttribute transactionAttribute,
@NonNull ExecutableMethod<Object, Object> executableMethod) {
@NonNull TransactionAttribute transactionAttribute,
@NonNull ExecutableMethod<Object, Object> executableMethod) {

this.transactionManager = transactionManager;
this.transactionAttribute = transactionAttribute;
Expand All @@ -397,6 +454,7 @@ public String getJoinpointIdentification() {

/**
* Create a new status.
*
* @param status The status.
*/
public void newTransactionStatus(@NonNull TransactionStatus<T> status) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.micronaut.transaction.support;

import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.order.OrderUtil;
import io.micronaut.transaction.TransactionDefinition;
Expand Down Expand Up @@ -497,4 +498,70 @@ public static void clear() {
ACTUAL_TRANSACTION_ACTIVE.remove();
}

/**
* Copy existing state.
*
* @return The state
* @since 3.3
*/
@Internal
public static State copyState() {
return new CopyState(RESOURCES.get(),
SYNCHRONIZATIONS.get(),
CURRENT_TRANSACTION_NAME.get(),
CURRENT_TRANSACTION_READ_ONLY.get(),
CURRENT_TRANSACTION_ISOLATION_LEVEL.get(),
ACTUAL_TRANSACTION_ACTIVE.get());
}

/**
* Restore the state.
* @param state The state
* @since 3.3
*/
@Internal
public static void restoreState(State state) {
if (state instanceof CopyState) {
CopyState copyState = (CopyState) state;
RESOURCES.set(copyState.resources == null ? null : new HashMap<>(copyState.resources));
SYNCHRONIZATIONS.set(copyState.synchronizations == null ? null : new LinkedHashSet<>(copyState.synchronizations));
CURRENT_TRANSACTION_NAME.set(copyState.currentTransactionName);
CURRENT_TRANSACTION_READ_ONLY.set(copyState.currentTransactionReadOnlyStatus);
CURRENT_TRANSACTION_ISOLATION_LEVEL.set(copyState.currentTransactionIsolationLevel);
ACTUAL_TRANSACTION_ACTIVE.set(copyState.actualTransactionActive);
} else {
throw new IllegalStateException("Unknown state: " + state);
}
}

/**
* The synchronization state.
*/
@Internal
public interface State {
}

private static final class CopyState implements State {
private final Map<Object, Object> resources;
private final Set<TransactionSynchronization> synchronizations;
private final String currentTransactionName;
private final Boolean currentTransactionReadOnlyStatus;
private final TransactionDefinition.Isolation currentTransactionIsolationLevel;
private final Boolean actualTransactionActive;

private CopyState(Map<Object, Object> resources,
Set<TransactionSynchronization> synchronizations,
String currentTransactionName,
Boolean currentTransactionReadOnlyStatus,
TransactionDefinition.Isolation currentTransactionIsolationLevel,
Boolean actualTransactionActive) {
this.resources = resources == null ? null : new HashMap<>(resources);
this.synchronizations = synchronizations == null ? null : new LinkedHashSet<>(synchronizations);
this.currentTransactionName = currentTransactionName;
this.currentTransactionReadOnlyStatus = currentTransactionReadOnlyStatus;
this.currentTransactionIsolationLevel = currentTransactionIsolationLevel;
this.actualTransactionActive = actualTransactionActive;
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2017-2022 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.transaction.interceptor

import io.micronaut.core.annotation.Internal
import io.micronaut.transaction.support.TransactionSynchronizationManager
import kotlinx.coroutines.ThreadContextElement
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext

@Internal
class TxSynchronousContext(
private val state: TransactionSynchronizationManager.State
) : ThreadContextElement<TransactionSynchronizationManager.State>, AbstractCoroutineContextElement(Key) {

companion object Key : CoroutineContext.Key<TxSynchronousContext>

override fun restoreThreadContext(context: CoroutineContext, oldState: TransactionSynchronizationManager.State) {
TransactionSynchronizationManager.restoreState(oldState)
}

override fun updateThreadContext(context: CoroutineContext): TransactionSynchronizationManager.State {
val copyState = TransactionSynchronizationManager.copyState()
TransactionSynchronizationManager.restoreState(state)
return copyState
}

}
Loading