Skip to content

Commit

Permalink
add new exception GatewayWebsocketSessionAbortedException which is us…
Browse files Browse the repository at this point in the history
…ed for terminating the websocket session;

extend Connect with an optional killSwitch;
shutdown websocket streams in service requests done phase of CoordinatedShutdown;

Signed-off-by: Stefan Maute <stefan.maute@bosch.io>
  • Loading branch information
Stefan Maute committed Sep 15, 2022
1 parent 7174e8c commit 46cb671
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2022 Contributors to the Eclipse Foundation
*
* See the NOTICE file(s) distributed with this work for additional
* information regarding copyright ownership.
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0
*
* SPDX-License-Identifier: EPL-2.0
*/
package org.eclipse.ditto.gateway.api;

import java.net.URI;

import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;

import org.eclipse.ditto.base.model.common.HttpStatus;
import org.eclipse.ditto.base.model.exceptions.DittoRuntimeException;
import org.eclipse.ditto.base.model.exceptions.DittoRuntimeExceptionBuilder;
import org.eclipse.ditto.base.model.headers.DittoHeaders;
import org.eclipse.ditto.base.model.json.JsonParsableException;
import org.eclipse.ditto.json.JsonObject;

/**
* Error response for websocket sessions that are aborted due to service restart.
*
* @since 3.0.0
*/
@JsonParsableException(errorCode = GatewayWebsocketSessionAbortedException.ERROR_CODE)
public class GatewayWebsocketSessionAbortedException extends DittoRuntimeException implements GatewayException {

/**
* Error code of this exception.
*/
public static final String ERROR_CODE = ERROR_CODE_PREFIX + "streaming.session.aborted";

/**
* Error message of this exception.
*/
public static final String MESSAGE = "The web socket session is aborted due to a service restart.";

private static final HttpStatus STATUS_CODE = HttpStatus.INTERNAL_SERVER_ERROR;

private GatewayWebsocketSessionAbortedException(final DittoHeaders dittoHeaders,
@Nullable final String message,
@Nullable final String description,
@Nullable final Throwable cause,
@Nullable final URI href) {

super(ERROR_CODE, STATUS_CODE, dittoHeaders, message, description, cause, href);
}

/**
* Create a {@code GatewayWebsocketSessionAbortedException}.
*
* @param dittoHeaders the Ditto headers.
* @return the exception.
*/
public static GatewayWebsocketSessionAbortedException of(final DittoHeaders dittoHeaders) {
return new Builder()
.message(MESSAGE)
.description("Please try again later.")
.dittoHeaders(dittoHeaders)
.build();
}

/**
* Constructs a new {@code SubscriptionAbortedException} object with the exception message extracted from the
* given JSON object.
*
* @param jsonObject the JSON to read the {@link org.eclipse.ditto.base.model.exceptions.DittoRuntimeException.JsonFields#MESSAGE} field from.
* @param dittoHeaders the headers of the command which resulted in this exception.
* @return the new SubscriptionAbortedException.
* @throws NullPointerException if any argument is {@code null}.
* @throws org.eclipse.ditto.json.JsonMissingFieldException if this JsonObject did not contain an error message.
* @throws org.eclipse.ditto.json.JsonParseException if the passed in {@code jsonObject} was not in the expected
* format.
*/
public static GatewayWebsocketSessionAbortedException fromJson(final JsonObject jsonObject, final DittoHeaders dittoHeaders) {
return DittoRuntimeException.fromJson(jsonObject, dittoHeaders, new Builder());
}

@Override
public DittoRuntimeException setDittoHeaders(final DittoHeaders dittoHeaders) {
return new Builder()
.message(getMessage())
.description(getDescription().orElse(null))
.cause(getCause())
.href(getHref().orElse(null))
.dittoHeaders(dittoHeaders)
.build();
}

/**
* A mutable builder with a fluent API for a {@link GatewayWebsocketSessionAbortedException}.
*/
@NotThreadSafe
public static final class Builder extends DittoRuntimeExceptionBuilder<GatewayWebsocketSessionAbortedException> {

private Builder() {}

@Override
protected GatewayWebsocketSessionAbortedException doBuild(final DittoHeaders dittoHeaders,
@Nullable final String message,
@Nullable final String description,
@Nullable final Throwable cause,
@Nullable final URI href) {
return new GatewayWebsocketSessionAbortedException(dittoHeaders, message, description, cause, href);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ private Route createSseRoute(final RequestContext ctx, final CompletionStage<Dit
final var authorizationContext = dittoHeaders.getAuthorizationContext();
final var connect = new Connect(withQueue.getSourceQueue(), connectionCorrelationId,
STREAMING_TYPE_SSE, jsonSchemaVersion, null, Set.of(),
authorizationContext);
authorizationContext, null);
final var startStreaming =
StartStreaming.getBuilder(StreamingType.EVENTS, connectionCorrelationId,
authorizationContext)
Expand Down Expand Up @@ -420,7 +420,7 @@ private Route createMessagesSseRoute(final RequestContext ctx,
final var connect =
new Connect(withQueue.getSourceQueue(), connectionCorrelationId,
STREAMING_TYPE_SSE, jsonSchemaVersion, null, Set.of(),
authorizationContext);
authorizationContext, null);
final String resourcePathRqlStatement;
if (INBOX_OUTBOX_WITH_SUBJECT_PATTERN.matcher(messagePath).matches()) {
resourcePathRqlStatement = String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.eclipse.ditto.base.model.signals.Signal;
import org.eclipse.ditto.base.model.signals.acks.Acknowledgement;
import org.eclipse.ditto.gateway.api.GatewayInternalErrorException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionAbortedException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionClosedException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionExpiredException;
import org.eclipse.ditto.gateway.service.endpoints.routes.AbstractRoute;
Expand Down Expand Up @@ -116,7 +117,9 @@
import akka.stream.FanOutShape2;
import akka.stream.FlowShape;
import akka.stream.Graph;
import akka.stream.KillSwitches;
import akka.stream.Materializer;
import akka.stream.SharedKillSwitch;
import akka.stream.SinkShape;
import akka.stream.UniformFanInShape;
import akka.stream.javadsl.Flow;
Expand Down Expand Up @@ -166,6 +169,8 @@ public final class WebSocketRoute implements WebSocketRouteBuilder {
.tag(DIRECTION, "dropped");
private static final String MDC_CONNECTION_CORRELATION_ID = "connection-correlation-id";

private final SharedKillSwitch wsKillSwitch = KillSwitches.shared(WebSocketRoute.class.getSimpleName());

private final ActorRef streamingActor;
private final StreamingConfig streamingConfig;
private final Materializer materializer;
Expand All @@ -178,6 +183,7 @@ public final class WebSocketRoute implements WebSocketRouteBuilder {
private HeaderTranslator headerTranslator;
private WebSocketConfigProvider webSocketConfigProvider;


private WebSocketRoute(final ActorSystem actorSystem,
final ActorRef streamingActor,
final StreamingConfig streamingConfig,
Expand Down Expand Up @@ -308,15 +314,14 @@ private CompletionStage<HttpResponse> createWebSocket(final WebSocketUpgrade upg
.thenApply(websocketConfig -> {
final Pair<Connect, Flow<DittoRuntimeException, Message, NotUsed>> outgoing =
createOutgoing(version, connectionCorrelationId, authContext, dittoHeaders, adapter,
request,
websocketConfig, signalEnrichmentFacade, logger);
request, websocketConfig, signalEnrichmentFacade, logger);

final Flow<Message, DittoRuntimeException, NotUsed> incoming =
createIncoming(version, connectionCorrelationId, authContext, dittoHeaders, adapter,
request,
websocketConfig, outgoing.first(), logger);
request, websocketConfig, outgoing.first(), logger);

return upgradeToWebSocket.handleMessagesWith(incoming.via(outgoing.second()));
return upgradeToWebSocket.handleMessagesWith(
incoming.via(wsKillSwitch.flow()).via(outgoing.second()));
}));
}

Expand Down Expand Up @@ -408,8 +413,8 @@ private Flow<Message, DittoRuntimeException, NotUsed> createIncoming(final JsonS
Patterns.ask(streamingActor, connect, LOCAL_ASK_TIMEOUT)
.thenApply(result -> Source.repeat((ActorRef) result))
);

final var noOpStreamControlMessage = NoOp.getInstance();

return setAckRequestThenMergeLeftAndRight.zipWith(sessionActorSource, Pair::create)
.to(Sink.foreach(pair -> {
final var actorRef = pair.second();
Expand All @@ -420,7 +425,6 @@ private Flow<Message, DittoRuntimeException, NotUsed> createIncoming(final JsonS
}));
}


private Flow<Message, String, NotUsed> getStrictifyFlow(final HttpRequest request, final Logger logger) {
return Flow.<Message>create()
.via(Flow.fromFunction(msg -> {
Expand All @@ -444,7 +448,6 @@ private Flow<Message, String, NotUsed> getStrictifyFlow(final HttpRequest reques
}))
.withAttributes(Attributes.createLogLevels(Logging.DebugLevel(), Logging.DebugLevel(),
Logging.WarningLevel()));

}

private Graph<FanOutShape2<String, Either<StreamControlMessage, Signal<?>>, DittoRuntimeException>, NotUsed>
Expand Down Expand Up @@ -561,11 +564,17 @@ private Pair<Connect, Flow<DittoRuntimeException, Message, NotUsed>> createOutgo
withQueue -> {
webSocketSupervisor.supervise(withQueue.getSupervisedStream(), connectionCorrelationId,
additionalHeaders);
return new Connect(withQueue.getSourceQueue(), connectionCorrelationId, STREAMING_TYPE_WS, version,
optJsonWebToken.map(JsonWebToken::getExpirationTime).orElse(null),
readDeclaredAcknowledgementLabels(additionalHeaders), connectionAuthContext);
return new Connect(withQueue.getSourceQueue(), connectionCorrelationId, STREAMING_TYPE_WS,
version, optJsonWebToken.map(JsonWebToken::getExpirationTime).orElse(null),
readDeclaredAcknowledgementLabels(additionalHeaders), connectionAuthContext,
wsKillSwitch);
})
.recoverWithRetries(1, new PFBuilder<Throwable, Source<SessionedJsonifiable, NotUsed>>()
.match(GatewayWebsocketSessionAbortedException.class,
ex -> {
logger.info("WebSocket connection aborted because of service restart!");
return Source.empty();
})
.match(GatewayWebsocketSessionExpiredException.class,
ex -> {
logger.info("WebSocket connection terminated because JWT expired!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.eclipse.ditto.edge.service.acknowledgements.things.ThingModifyCommandAckRequestSetter;
import org.eclipse.ditto.edge.service.placeholders.EntityIdPlaceholder;
import org.eclipse.ditto.gateway.api.GatewayInternalErrorException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionAbortedException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionClosedException;
import org.eclipse.ditto.gateway.api.GatewayWebsocketSessionExpiredException;
import org.eclipse.ditto.gateway.service.security.authentication.jwt.JwtAuthenticationResultProvider;
Expand Down Expand Up @@ -83,10 +84,14 @@
import akka.Done;
import akka.actor.AbstractActorWithTimers;
import akka.actor.ActorRef;
import akka.actor.Cancellable;
import akka.actor.CoordinatedShutdown;
import akka.actor.Props;
import akka.actor.Terminated;
import akka.japi.pf.PFBuilder;
import akka.japi.pf.ReceiveBuilder;
import akka.pattern.Patterns;
import akka.stream.KillSwitch;
import akka.stream.javadsl.SourceQueueWithComplete;
import scala.PartialFunction;

Expand All @@ -95,6 +100,12 @@
*/
final class StreamingSessionActor extends AbstractActorWithTimers {

/**
* Ask-timeout in shutdown tasks. Its duration should be long enough for session requests to succeed but
* ultimately does not matter because each shutdown phase has its own timeout.
*/
private static final Duration SHUTDOWN_ASK_TIMEOUT = Duration.ofMinutes(2L);

/**
* Maximum lifetime of an expiring session.
* If a session is established with JWT lasting more than this duration, the session will persist forever.
Expand All @@ -116,9 +127,11 @@ final class StreamingSessionActor extends AbstractActorWithTimers {
private final AcknowledgementAggregatorActorStarter ackregatorStarter;
private final Set<AcknowledgementLabel> declaredAcks;
private final ThreadSafeDittoLoggingAdapter logger;

private AuthorizationContext authorizationContext;

private Cancellable cancellableShutdownTask;
@Nullable private final KillSwitch killSwitch;

@SuppressWarnings("unused")
private StreamingSessionActor(final Connect connect,
final DittoProtocolSub dittoProtocolSub,
Expand All @@ -140,6 +153,7 @@ private StreamingSessionActor(final Connect connect,
this.jwtAuthenticationResultProvider = jwtAuthenticationResultProvider;
outstandingSubscriptionAcks = EnumSet.noneOf(StreamingType.class);
authorizationContext = connect.getConnectionAuthContext();
killSwitch = connect.getKillSwitch().orElse(null);
streamingSessions = new EnumMap<>(StreamingType.class);
ackregatorStarter = AcknowledgementAggregatorActorStarter.of(getContext(),
streamingConfig.getAcknowledgementConfig(),
Expand Down Expand Up @@ -197,14 +211,24 @@ static Props props(final Connect connect,

@Override
public void preStart() {
declareAcknowledgementLabels(declaredAcks);

final var coordinatedShutdown = CoordinatedShutdown.get(getContext().getSystem());
final var serviceRequestsDoneTask = "service-requests-done-streaming-session-actor" ;
cancellableShutdownTask = coordinatedShutdown.addCancellableTask(CoordinatedShutdown.PhaseServiceRequestsDone(),
serviceRequestsDoneTask,
() -> Patterns.ask(getSelf(), Control.SERVICE_REQUESTS_DONE, SHUTDOWN_ASK_TIMEOUT)
.thenApply(reply -> Done.done())
);

eventAndResponsePublisher.watchCompletion()
.whenComplete((done, error) -> getSelf().tell(Control.TERMINATED, getSelf()));
declareAcknowledgementLabels(declaredAcks);
}

@Override
public void postStop() {
logger.info("Closing <{}> streaming session.", type);
cancellableShutdownTask.cancel();
cancelSessionTimeout();
eventAndResponsePublisher.complete();
}
Expand Down Expand Up @@ -383,6 +407,7 @@ private Receive createSelfTerminationBehavior() {
.match(Terminated.class, this::handleTerminated)
.matchEquals(Control.TERMINATED, this::handleTerminated)
.matchEquals(Control.SESSION_TERMINATION, this::handleSessionTermination)
.matchEquals(Control.SERVICE_REQUESTS_DONE, this::serviceRequestsDone)
.build();
}

Expand Down Expand Up @@ -745,6 +770,14 @@ private static Optional<DittoHeaderInvalidException> checkForAcksWithoutResponse
}
}

private void serviceRequestsDone(final Control serviceRequestsDone) {
logger.info("{}: abort ongoing websocket session", serviceRequestsDone);
if (killSwitch != null) {
killSwitch.abort(GatewayWebsocketSessionAbortedException.of(DittoHeaders.empty()));
}
getSender().tell(Done.getInstance(), ActorRef.noSender());
}

/**
* Messages to self to perform an outstanding acknowledgement if not already acknowledged.
*/
Expand Down Expand Up @@ -778,10 +811,11 @@ private ConfirmUnsubscription(final StreamingType streamingType) {

}

private enum Control {
public enum Control {
TERMINATED,
SESSION_TERMINATION,
RESUBSCRIBE
RESUBSCRIBE,
SERVICE_REQUESTS_DONE
}

}
Loading

0 comments on commit 46cb671

Please sign in to comment.