diff --git a/x-pack/packages/ml/response_stream/server/index.ts b/x-pack/packages/ml/response_stream/server/index.ts index 2beb7223026174..ae337aa99898f2 100644 --- a/x-pack/packages/ml/response_stream/server/index.ts +++ b/x-pack/packages/ml/response_stream/server/index.ts @@ -8,5 +8,6 @@ export { streamFactory, type StreamFactoryReturnType, + type StreamResponseWithHeaders, type UncompressedResponseStream, } from './stream_factory'; diff --git a/x-pack/packages/ml/response_stream/server/stream_factory.ts b/x-pack/packages/ml/response_stream/server/stream_factory.ts index 79895525ba40ce..779e3457c7dac3 100644 --- a/x-pack/packages/ml/response_stream/server/stream_factory.ts +++ b/x-pack/packages/ml/response_stream/server/stream_factory.ts @@ -33,15 +33,16 @@ type StreamType = T extends string ? T : never; -// Fallback to never is there for backwards compatibility. -export interface StreamFactoryReturnType { +export interface StreamResponseWithHeaders { + body: zlib.Gzip | UncompressedResponseStream; + headers?: ResponseHeaders; +} + +export interface StreamFactoryReturnType { DELIMITER: string; end: () => void; push: (d: StreamType, drain?: boolean) => void; - responseWithHeaders: { - body: zlib.Gzip | UncompressedResponseStream; - headers?: ResponseHeaders; - }; + responseWithHeaders: StreamResponseWithHeaders; } /** @@ -133,7 +134,7 @@ export function streamFactory( function repeat() { if (!tryToEnd) { if (responseSizeSinceLastKeepAlive < FLUSH_PAYLOAD_SIZE) { - push({ flushPayload, type: 'flushPayload' } as unknown as StreamType); + push({ flushPayload, type: 'flushPayload' } as StreamType); } responseSizeSinceLastKeepAlive = 0; setTimeout(repeat, FLUSH_KEEP_ALIVE_INTERVAL_MS); @@ -211,7 +212,7 @@ export function streamFactory( } } - const responseWithHeaders: StreamFactoryReturnType['responseWithHeaders'] = { + const responseWithHeaders: StreamResponseWithHeaders = { body: stream, headers: { ...(isCompressed ? { 'content-encoding': 'gzip' } : {}), diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 0e7925df202811..ce2833fa144804 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -13,7 +13,7 @@ import { KibanaRequest, ResponseHeaders } from '@kbn/core-http-server'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import type { AnalyticsServiceSetup } from '@kbn/core-analytics-server'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; -import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; @@ -51,7 +51,7 @@ export interface StaticReturnType { headers: ResponseHeaders; } export type AgentExecutorResponse = T extends true - ? StreamFactoryReturnType['responseWithHeaders'] + ? StreamResponseWithHeaders : StaticReturnType; export type AgentExecutor = ( diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index e579996eb13d19..e76b7cd3377835 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -8,7 +8,7 @@ import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { schema } from '@kbn/config-schema'; import { @@ -323,42 +323,41 @@ export const postActionsConnectorExecuteRoute = ( page: 1, }); - const result: StreamFactoryReturnType['responseWithHeaders'] | StaticReturnType = - await callAgentExecutor({ - abortSignal, - alertsIndexPattern: request.body.alertsIndexPattern, - anonymizationFields: anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined, - actions, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, - assistantTools, - connectorId, - elserId, - esClient, - isStream: - // TODO implement llmClass for bedrock streaming - // tracked here: https://github.com/elastic/security-team/issues/7363 - request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', - llmType: getLlmType(actionTypeId), - kbResource: ESQL_RESOURCE, - langChainMessages, - logger, - onNewReplacements, - onLlmResponse, - request, - replacements: request.body.replacements, - size: request.body.size, - telemetry, - traceOptions: { + const result: StreamResponseWithHeaders | StaticReturnType = await callAgentExecutor({ + abortSignal, + alertsIndexPattern: request.body.alertsIndexPattern, + anonymizationFields: anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined, + actions, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, + assistantTools, + connectorId, + elserId, + esClient, + isStream: + // TODO implement llmClass for bedrock streaming + // tracked here: https://github.com/elastic/security-team/issues/7363 + request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + llmType: getLlmType(actionTypeId), + kbResource: ESQL_RESOURCE, + langChainMessages, + logger, + onNewReplacements, + onLlmResponse, + request, + replacements: request.body.replacements, + size: request.body.size, + telemetry, + traceOptions: { + projectName: langSmithProject, + tracers: getLangSmithTracer({ + apiKey: langSmithApiKey, projectName: langSmithProject, - tracers: getLangSmithTracer({ - apiKey: langSmithApiKey, - projectName: langSmithProject, - logger, - }), - }, - }); + logger, + }), + }, + }); telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { actionTypeId, @@ -371,9 +370,7 @@ export const postActionsConnectorExecuteRoute = ( request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', }); - return response.ok< - StreamFactoryReturnType['responseWithHeaders']['body'] | StaticReturnType['body'] - >(result); + return response.ok(result); } catch (err) { logger.error(err); const error = transformError(err);