Skip to content

Commit

Permalink
gateway: fix Pulsar topics with non default tenant (LangStream#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed May 23, 2024
1 parent 9e621df commit ea87260
Show file tree
Hide file tree
Showing 23 changed files with 525 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@

import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicOffsetPosition;
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
Expand All @@ -42,44 +45,14 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsumeGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

@Getter
public static class ProduceException extends Exception {

private final ProduceResponse.Status status;

public ProduceException(String message, ProduceResponse.Status status) {
super(message);
this.status = status;
}
}

public static class ProduceGatewayRequestValidator
implements GatewayRequestHandler.GatewayRequestValidator {
@Override
public List<String> getAllRequiredParameters(Gateway gateway) {
return gateway.getParameters();
}

@Override
public void validateOptions(Map<String, String> options) {
for (Map.Entry<String, String> option : options.entrySet()) {
switch (option.getKey()) {
default -> throw new IllegalArgumentException(
"Unknown option " + option.getKey());
}
}
}
}

private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;

private volatile TopicConnectionsRuntime topicConnectionsRuntime;

Expand All @@ -90,8 +63,11 @@ public void validateOptions(Map<String, String> options) {
private AuthenticatedGatewayRequestContext requestContext;
private List<Function<Record, Boolean>> filters;

public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
public ConsumeGateway(
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
}

public void setup(
Expand Down Expand Up @@ -126,9 +102,16 @@ public void setup(
default -> TopicOffsetPosition.absolute(
Base64.getDecoder().decode(positionParameter));
};
TopicDefinition topicDefinition = requestContext.application().resolveTopic(topic);
StreamingClusterRuntime streamingClusterRuntime =
clusterRuntimeRegistry.getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();
reader =
topicConnectionsRuntime.createReader(
streamingCluster, Map.of("topic", topic), position);
streamingCluster, Map.of("topic", resolvedTopicName), position);
reader.start();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.code.SimpleRecord;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
Expand Down Expand Up @@ -82,15 +86,18 @@ public void validateOptions(Map<String, String> options) {
}

private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final TopicProducerCache topicProducerCache;
private TopicProducer producer;
private List<Header> commonHeaders;
private String logRef;

public ProduceGateway(
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
this.topicProducerCache = topicProducerCache;
}

Expand All @@ -116,17 +123,27 @@ public void start(
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

TopicDefinition topicDefinition = requestContext.application().resolveTopic(topic);
StreamingClusterRuntime streamingClusterRuntime =
clusterRuntimeRegistry.getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();

// we need to cache the producer per topic and per config, since an application update could
// change the configuration
final TopicProducerCache.Key key =
new TopicProducerCache.Key(
requestContext.tenant(),
requestContext.applicationId(),
requestContext.gateway().getId(),
topic,
resolvedTopicName,
configString);
producer =
topicProducerCache.getOrCreate(key, () -> setupProducer(topic, streamingCluster));
topicProducerCache.getOrCreate(
key, () -> setupProducer(resolvedTopicName, streamingCluster));
}

@AllArgsConstructor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
Expand Down Expand Up @@ -78,6 +79,7 @@ public class GatewayResource {
protected static final ObjectMapper MAPPER = new ObjectMapper();
protected static final String SERVICE_REQUEST_ID_HEADER = "langstream-service-request-id";
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final TopicProducerCache topicProducerCache;
private final ApplicationStore applicationStore;
private final GatewayRequestHandler gatewayRequestHandler;
Expand Down Expand Up @@ -121,6 +123,7 @@ ProduceResponse produce(
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry,
topicProducerCache)) {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
Expand Down Expand Up @@ -259,12 +262,14 @@ private CompletableFuture<ResponseEntity> handleServiceWithTopics(
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry,
topicProducerCache); ) {

final ConsumeGateway consumeGateway =
new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry);
completableFuture.thenRunAsync(
() -> {
if (consumeGateway != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.runner;

import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@Slf4j
public class ClusterRuntimeRegistryBean {
@Bean
public ClusterRuntimeRegistry registry() {
return new ClusterRuntimeRegistry();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ai.langstream.apigateway.websocket;

import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.TopicProducerCache;
Expand Down Expand Up @@ -49,6 +50,7 @@ public class WebSocketConfig implements WebSocketConfigurer {

private final ApplicationStore applicationStore;
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final ClusterRuntimeRegistry clusterRuntimeRegistry;
private final GatewayRequestHandler gatewayRequestHandler;
private final TopicProducerCache topicProducerCache;
private final ExecutorService consumeThreadPool =
Expand All @@ -64,19 +66,22 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
CONSUME_PATH)
.addHandler(
new ProduceHandler(
applicationStore,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
PRODUCE_PATH)
.addHandler(
new ChatHandler(
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache),
CHAT_PATH)
.setAllowedOrigins("*")
Expand All @@ -93,5 +98,6 @@ public ServletServerContainerFactoryBean createWebSocketContainer() {
@PreDestroy
public void onDestroy() {
consumeThreadPool.shutdown();
clusterRuntimeRegistry.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import ai.langstream.api.events.GatewayEventData;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.code.SimpleRecord;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
Expand All @@ -52,14 +56,17 @@ public abstract class AbstractHandler extends TextWebSocketHandler {
protected static final String ATTRIBUTE_PRODUCE_GATEWAY = "__produce_gateway";
protected static final String ATTRIBUTE_CONSUME_GATEWAY = "__consume_gateway";
protected final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
protected final ClusterRuntimeRegistry clusterRuntimeRegistry;
protected final ApplicationStore applicationStore;
private final TopicProducerCache topicProducerCache;

public AbstractHandler(
ApplicationStore applicationStore,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.clusterRuntimeRegistry = clusterRuntimeRegistry;
this.applicationStore = applicationStore;
this.topicProducerCache = topicProducerCache;
}
Expand Down Expand Up @@ -187,11 +194,20 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont

topicConnectionsRuntime.init(streamingCluster);

TopicDefinition topicDefinition =
context.application().resolveTopic(gateway.getEventsTopic());
StreamingClusterRuntime streamingClusterRuntime =
new ClusterRuntimeRegistry().getStreamingClusterRuntime(streamingCluster);
Topic topicImplementation =
streamingClusterRuntime.createTopicImplementation(
topicDefinition, streamingCluster);
final String resolvedTopicName = topicImplementation.topicName();

try (final TopicProducer producer =
topicConnectionsRuntime.createProducer(
"langstream-events",
streamingCluster,
Map.of("topic", gateway.getEventsTopic()))) {
Map.of("topic", resolvedTopicName))) {
producer.start();

final EventSources.GatewaySource source =
Expand Down Expand Up @@ -246,7 +262,8 @@ protected void setupReader(
List<Function<Record, Boolean>> filters,
AuthenticatedGatewayRequestContext context)
throws Exception {
final ConsumeGateway consumeGateway = new ConsumeGateway(topicConnectionsRuntimeRegistry);
final ConsumeGateway consumeGateway =
new ConsumeGateway(topicConnectionsRuntimeRegistry, clusterRuntimeRegistry);
try {
consumeGateway.setup(topic, filters, context);
} catch (Exception ex) {
Expand All @@ -261,7 +278,10 @@ protected void setupProducer(
String topic, List<Header> commonHeaders, AuthenticatedGatewayRequestContext context)
throws Exception {
final ProduceGateway produceGateway =
new ProduceGateway(topicConnectionsRuntimeRegistry, topicProducerCache);
new ProduceGateway(
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache);

try {
produceGateway.start(topic, commonHeaders, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
Expand All @@ -47,8 +48,13 @@ public ChatHandler(
ApplicationStore applicationStore,
ExecutorService executor,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
ClusterRuntimeRegistry clusterRuntimeRegistry,
TopicProducerCache topicProducerCache) {
super(applicationStore, topicConnectionsRuntimeRegistry, topicProducerCache);
super(
applicationStore,
topicConnectionsRuntimeRegistry,
clusterRuntimeRegistry,
topicProducerCache);
this.executor = executor;
}

Expand Down
Loading

0 comments on commit ea87260

Please sign in to comment.