diff --git a/x-pack/packages/ml/aiops_components/src/progress_controls/progress_controls.tsx b/x-pack/packages/ml/aiops_components/src/progress_controls/progress_controls.tsx index 8fff372570505c..03ebe7b57ed031 100644 --- a/x-pack/packages/ml/aiops_components/src/progress_controls/progress_controls.tsx +++ b/x-pack/packages/ml/aiops_components/src/progress_controls/progress_controls.tsx @@ -63,6 +63,8 @@ export const ProgressControls: FC> = (pr runAnalysisDisabled = false, } = props; + const progressOutput = Math.round(progress * 100); + const { euiTheme } = useEuiTheme(); const runningProgressBarStyles = useAnimatedProgressBarBackground(euiTheme.colors.success); const analysisCompleteStyle = { display: 'none' }; @@ -147,7 +149,7 @@ export const ProgressControls: FC> = (pr data-test-subj="aiopsProgressTitleMessage" id="xpack.aiops.progressTitle" defaultMessage="Progress: {progress}% — {progressMessage}" - values={{ progress: Math.round(progress * 100), progressMessage }} + values={{ progress: progressOutput, progressMessage }} /> @@ -156,7 +158,7 @@ export const ProgressControls: FC> = (pr aria-label={i18n.translate('xpack.aiops.progressAriaLabel', { defaultMessage: 'Progress', })} - value={Math.round(progress * 100)} + value={progressOutput} max={100} size="m" /> diff --git a/x-pack/packages/ml/aiops_log_rate_analysis/api/stream_reducer.ts b/x-pack/packages/ml/aiops_log_rate_analysis/api/stream_reducer.ts index ca6148c133ccad..c4bde0b90c0fd9 100644 --- a/x-pack/packages/ml/aiops_log_rate_analysis/api/stream_reducer.ts +++ b/x-pack/packages/ml/aiops_log_rate_analysis/api/stream_reducer.ts @@ -34,12 +34,8 @@ export const initialState: StreamState = { export function streamReducer( state: StreamState, - action: AiopsLogRateAnalysisApiAction | AiopsLogRateAnalysisApiAction[] + action: AiopsLogRateAnalysisApiAction ): StreamState { - if (Array.isArray(action)) { - return action.reduce(streamReducer, state); - } - switch (action.type) { case API_ACTION_NAME.ADD_SIGNIFICANT_ITEMS: return { ...state, significantItems: [...state.significantItems, ...action.payload] }; diff --git a/x-pack/packages/ml/response_stream/client/fetch_stream.ts b/x-pack/packages/ml/response_stream/client/fetch_stream.ts index 7c4ad7789a3b67..6c45d8c99ce732 100644 --- a/x-pack/packages/ml/response_stream/client/fetch_stream.ts +++ b/x-pack/packages/ml/response_stream/client/fetch_stream.ts @@ -44,7 +44,7 @@ export async function* fetchStream body?: B, ndjson = true, headers?: HttpFetchOptions['headers'] -): AsyncGenerator<[GeneratorError, ReducerAction | Array> | undefined]> { +): AsyncGenerator<[GeneratorError, ReducerAction | undefined]> { let stream: Readonly | undefined; try { @@ -112,7 +112,9 @@ export async function* fetchStream : parts ) as Array>; - yield [null, actions]; + for (const action of actions) { + yield [null, action]; + } } catch (error) { if (error.name !== 'AbortError') { yield [error.toString(), undefined]; diff --git a/x-pack/packages/ml/response_stream/client/index.ts b/x-pack/packages/ml/response_stream/client/index.ts index 750442161a5699..a8b02cecd9cf6b 100644 --- a/x-pack/packages/ml/response_stream/client/index.ts +++ b/x-pack/packages/ml/response_stream/client/index.ts @@ -5,4 +5,5 @@ * 2.0. */ +export { fetchStream } from './fetch_stream'; export { useFetchStream } from './use_fetch_stream'; diff --git a/x-pack/packages/ml/response_stream/client/string_reducer.ts b/x-pack/packages/ml/response_stream/client/string_reducer.ts index f77b31e1fed1e0..d14990947fd059 100644 --- a/x-pack/packages/ml/response_stream/client/string_reducer.ts +++ b/x-pack/packages/ml/response_stream/client/string_reducer.ts @@ -7,20 +7,14 @@ import type { Reducer, ReducerAction, ReducerState } from 'react'; -type StringReducerPayload = string | string[] | undefined; +type StringReducerPayload = string | undefined; export type StringReducer = Reducer; /** * The `stringReducer` is provided to handle plain string based streams with `streamFactory()`. * * @param state - The current state, being the string fetched so far. - * @param payload — The state update can be a plain string, an array of strings or `undefined`. - * * An array of strings will be joined without a delimiter and added to the current string. - * In combination with `useFetchStream`'s buffering this allows to do bulk updates - * within the reducer without triggering a React/DOM update on every stream chunk. - * * `undefined` can be used to reset the state to an empty string, for example, when a - * UI has the option to trigger a refetch of a stream. - * + * @param payload — The state update can be a plain string to be added or `undefined` to reset the state. * @returns The updated state, a string that combines the previous string and the payload. */ export function stringReducer( @@ -31,5 +25,5 @@ export function stringReducer( return ''; } - return `${state}${Array.isArray(payload) ? payload.join('') : payload}`; + return `${state}${payload}`; } diff --git a/x-pack/packages/ml/response_stream/client/use_fetch_stream.ts b/x-pack/packages/ml/response_stream/client/use_fetch_stream.ts index 309d53e8dd4bda..55950c6ee8f774 100644 --- a/x-pack/packages/ml/response_stream/client/use_fetch_stream.ts +++ b/x-pack/packages/ml/response_stream/client/use_fetch_stream.ts @@ -7,14 +7,12 @@ import { useEffect, - useReducer, useRef, useState, type Reducer, - type ReducerAction, type ReducerState, + type ReducerAction, } from 'react'; -import useThrottle from 'react-use/lib/useThrottle'; import type { HttpSetup, HttpFetchOptions } from '@kbn/core/public'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; @@ -22,6 +20,8 @@ import { isPopulatedObject } from '@kbn/ml-is-populated-object'; import { fetchStream } from './fetch_stream'; import { stringReducer, type StringReducer } from './string_reducer'; +const DATA_THROTTLE_MS = 100; + // This pattern with a dual ternary allows us to default to StringReducer // and if a custom reducer is supplied fall back to that one instead. // The complexity in here allows us to create a simpler API surface where @@ -57,6 +57,7 @@ function isReducerOptions(arg: unknown): arg is CustomReducer { * @param apiVersion Optional API version. * @param body Optional API request body. * @param customReducer Optional custom reducer and initial state. + * @param headers Optional headers. * @returns An object with streaming data and methods to act on the stream. */ export function useFetchStream>( @@ -75,11 +76,41 @@ export function useFetchStream>( ? customReducer : ({ reducer: stringReducer, initialState: '' } as FetchStreamCustomReducer); - const [data, dispatch] = useReducer( - reducerWithFallback.reducer, - reducerWithFallback.initialState - ); - const dataThrottled = useThrottle(data, 100); + // We used `useReducer` in previous iterations of this hook, but it caused + // a lot of unnecessary re-renders even in combination with `useThrottle`. + // We're now using `dataRef` to allow updates outside of the render cycle. + // When the stream is running, we'll update `data` with the `dataRef` value + // periodically. + const [data, setData] = useState(reducerWithFallback.initialState); + const dataRef = useRef(reducerWithFallback.initialState); + + // This effect is used to throttle the data updates while the stream is running. + // It will update the `data` state with the current `dataRef` value every 100ms. + useEffect(() => { + // We cannot check against `isRunning` in the `setTimeout` callback, because + // we would check against a stale value. Instead, we use a mutable + // object to keep track of the current state of the effect. + const effectState = { isActive: true }; + + if (isRunning) { + setData(dataRef.current); + + function updateData() { + setTimeout(() => { + setData(dataRef.current); + if (effectState.isActive) { + updateData(); + } + }, DATA_THROTTLE_MS); + } + + updateData(); + } + + return () => { + effectState.isActive = false; + }; + }, [isRunning]); const abortCtrl = useRef(new AbortController()); @@ -99,7 +130,7 @@ export function useFetchStream>( abortCtrl.current = new AbortController(); - for await (const [fetchStreamError, actions] of fetchStream>( + for await (const [fetchStreamError, action] of fetchStream>( http, endpoint, apiVersion, @@ -110,14 +141,26 @@ export function useFetchStream>( )) { if (fetchStreamError !== null) { addError(fetchStreamError); - } else if (Array.isArray(actions) && actions.length > 0) { - dispatch(actions as ReducerAction>); + } else if (action) { + dataRef.current = reducerWithFallback.reducer(dataRef.current, action) as ReducerState< + CustomReducer + >; } } setIsRunning(false); }; + // This custom dispatch function allows us to update the `dataRef` value and will + // then trigger an update of `data` right away as we don't want to have the + // throttling in place for these types of updates. + const dispatch = (action: ReducerAction['reducer']>) => { + dataRef.current = reducerWithFallback.reducer(dataRef.current, action) as ReducerState< + CustomReducer + >; + setData(dataRef.current); + }; + const cancel = () => { abortCtrl.current.abort(); setIsCancelled(true); @@ -131,10 +174,10 @@ export function useFetchStream>( return { cancel, - // To avoid a race condition where the stream already ended but `useThrottle` would - // yet have to trigger another update within the throttling interval, we'll return + // To avoid a race condition where the stream already ended but the throttling would + // yet have to trigger another update within the interval, we'll return // the unthrottled data once the stream is complete. - data: isRunning ? dataThrottled : data, + data: isRunning ? data : dataRef.current, dispatch, errors, isCancelled, diff --git a/x-pack/packages/ml/response_stream/server/index.ts b/x-pack/packages/ml/response_stream/server/index.ts index 2beb7223026174..ae337aa99898f2 100644 --- a/x-pack/packages/ml/response_stream/server/index.ts +++ b/x-pack/packages/ml/response_stream/server/index.ts @@ -8,5 +8,6 @@ export { streamFactory, type StreamFactoryReturnType, + type StreamResponseWithHeaders, type UncompressedResponseStream, } from './stream_factory'; diff --git a/x-pack/packages/ml/response_stream/server/stream_factory.ts b/x-pack/packages/ml/response_stream/server/stream_factory.ts index 29c570c2fb564d..779e3457c7dac3 100644 --- a/x-pack/packages/ml/response_stream/server/stream_factory.ts +++ b/x-pack/packages/ml/response_stream/server/stream_factory.ts @@ -26,35 +26,25 @@ export class UncompressedResponseStream extends Stream.PassThrough {} const DELIMITER = '\n'; -type StreamType = 'string' | 'ndjson'; +type StreamTypeUnion = string | object; +type StreamType = T extends string + ? string + : T extends object + ? T + : never; + +export interface StreamResponseWithHeaders { + body: zlib.Gzip | UncompressedResponseStream; + headers?: ResponseHeaders; +} -export interface StreamFactoryReturnType { +export interface StreamFactoryReturnType { DELIMITER: string; end: () => void; - push: (d: T, drain?: boolean) => void; - responseWithHeaders: { - body: zlib.Gzip | UncompressedResponseStream; - headers?: ResponseHeaders; - }; + push: (d: StreamType, drain?: boolean) => void; + responseWithHeaders: StreamResponseWithHeaders; } -/** - * Overload to set up a string based response stream with support - * for gzip compression depending on provided request headers. - * - * @param headers - Request headers. - * @param logger - Kibana logger. - * @param compressOverride - Optional flag to override header based compression setting. - * @param flushFix - Adds an attribute with a random string payload to overcome buffer flushing with certain proxy configurations. - * - * @returns An object with stream attributes and methods. - */ -export function streamFactory( - headers: Headers, - logger: Logger, - compressOverride?: boolean, - flushFix?: boolean -): StreamFactoryReturnType; /** * Sets up a response stream with support for gzip compression depending on provided * request headers. Any non-string data pushed to the stream will be streamed as NDJSON. @@ -66,13 +56,13 @@ export function streamFactory( * * @returns An object with stream attributes and methods. */ -export function streamFactory( +export function streamFactory( headers: Headers, logger: Logger, compressOverride: boolean = true, flushFix: boolean = false ): StreamFactoryReturnType { - let streamType: StreamType; + let streamType: 'string' | 'ndjson'; const isCompressed = compressOverride && acceptCompression(headers); const flushPayload = flushFix ? crypto.randomBytes(FLUSH_PAYLOAD_SIZE).toString('hex') @@ -82,7 +72,7 @@ export function streamFactory( const stream = isCompressed ? zlib.createGzip() : new UncompressedResponseStream(); // If waiting for draining of the stream, items will be added to this buffer. - const backPressureBuffer: T[] = []; + const backPressureBuffer: Array> = []; // Flag will be set when the "drain" listener is active so we can avoid setting multiple listeners. let waitForDrain = false; @@ -120,7 +110,7 @@ export function streamFactory( } } - function push(d: T, drain = false) { + function push(d: StreamType, drain = false) { logDebugMessage( `Push to stream. Current backPressure buffer size: ${backPressureBuffer.length}, drain flag: ${drain}` ); @@ -144,7 +134,7 @@ export function streamFactory( function repeat() { if (!tryToEnd) { if (responseSizeSinceLastKeepAlive < FLUSH_PAYLOAD_SIZE) { - push({ flushPayload } as unknown as T); + push({ flushPayload, type: 'flushPayload' } as StreamType); } responseSizeSinceLastKeepAlive = 0; setTimeout(repeat, FLUSH_KEEP_ALIVE_INTERVAL_MS); @@ -222,7 +212,7 @@ export function streamFactory( } } - const responseWithHeaders: StreamFactoryReturnType['responseWithHeaders'] = { + const responseWithHeaders: StreamResponseWithHeaders = { body: stream, headers: { ...(isCompressed ? { 'content-encoding': 'gzip' } : {}), diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 0e7925df202811..ce2833fa144804 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -13,7 +13,7 @@ import { KibanaRequest, ResponseHeaders } from '@kbn/core-http-server'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import type { AnalyticsServiceSetup } from '@kbn/core-analytics-server'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; -import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; @@ -51,7 +51,7 @@ export interface StaticReturnType { headers: ResponseHeaders; } export type AgentExecutorResponse = T extends true - ? StreamFactoryReturnType['responseWithHeaders'] + ? StreamResponseWithHeaders : StaticReturnType; export type AgentExecutor = ( diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index e579996eb13d19..e76b7cd3377835 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -8,7 +8,7 @@ import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { schema } from '@kbn/config-schema'; import { @@ -323,42 +323,41 @@ export const postActionsConnectorExecuteRoute = ( page: 1, }); - const result: StreamFactoryReturnType['responseWithHeaders'] | StaticReturnType = - await callAgentExecutor({ - abortSignal, - alertsIndexPattern: request.body.alertsIndexPattern, - anonymizationFields: anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined, - actions, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, - assistantTools, - connectorId, - elserId, - esClient, - isStream: - // TODO implement llmClass for bedrock streaming - // tracked here: https://github.com/elastic/security-team/issues/7363 - request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', - llmType: getLlmType(actionTypeId), - kbResource: ESQL_RESOURCE, - langChainMessages, - logger, - onNewReplacements, - onLlmResponse, - request, - replacements: request.body.replacements, - size: request.body.size, - telemetry, - traceOptions: { + const result: StreamResponseWithHeaders | StaticReturnType = await callAgentExecutor({ + abortSignal, + alertsIndexPattern: request.body.alertsIndexPattern, + anonymizationFields: anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined, + actions, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, + assistantTools, + connectorId, + elserId, + esClient, + isStream: + // TODO implement llmClass for bedrock streaming + // tracked here: https://github.com/elastic/security-team/issues/7363 + request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + llmType: getLlmType(actionTypeId), + kbResource: ESQL_RESOURCE, + langChainMessages, + logger, + onNewReplacements, + onLlmResponse, + request, + replacements: request.body.replacements, + size: request.body.size, + telemetry, + traceOptions: { + projectName: langSmithProject, + tracers: getLangSmithTracer({ + apiKey: langSmithApiKey, projectName: langSmithProject, - tracers: getLangSmithTracer({ - apiKey: langSmithApiKey, - projectName: langSmithProject, - logger, - }), - }, - }); + logger, + }), + }, + }); telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { actionTypeId, @@ -371,9 +370,7 @@ export const postActionsConnectorExecuteRoute = ( request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', }); - return response.ok< - StreamFactoryReturnType['responseWithHeaders']['body'] | StaticReturnType['body'] - >(result); + return response.ok(result); } catch (err) { logger.error(err); const error = transformError(err);