Skip to content

Commit

Permalink
tweak types
Browse files Browse the repository at this point in the history
  • Loading branch information
walterra committed May 7, 2024
1 parent da2152b commit b13bec0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 49 deletions.
1 change: 1 addition & 0 deletions x-pack/packages/ml/response_stream/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
export {
streamFactory,
type StreamFactoryReturnType,
type StreamResponseWithHeaders,
type UncompressedResponseStream,
} from './stream_factory';
17 changes: 9 additions & 8 deletions x-pack/packages/ml/response_stream/server/stream_factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ type StreamType<T extends StreamTypeUnion> = T extends string
? T
: never;

// Fallback to never is there for backwards compatibility.
export interface StreamFactoryReturnType<T extends StreamTypeUnion = never> {
export interface StreamResponseWithHeaders {
body: zlib.Gzip | UncompressedResponseStream;
headers?: ResponseHeaders;
}

export interface StreamFactoryReturnType<T extends StreamTypeUnion> {
DELIMITER: string;
end: () => void;
push: (d: StreamType<T>, drain?: boolean) => void;
responseWithHeaders: {
body: zlib.Gzip | UncompressedResponseStream;
headers?: ResponseHeaders;
};
responseWithHeaders: StreamResponseWithHeaders;
}

/**
Expand Down Expand Up @@ -133,7 +134,7 @@ export function streamFactory<T extends StreamTypeUnion>(
function repeat() {
if (!tryToEnd) {
if (responseSizeSinceLastKeepAlive < FLUSH_PAYLOAD_SIZE) {
push({ flushPayload, type: 'flushPayload' } as unknown as StreamType<T>);
push({ flushPayload, type: 'flushPayload' } as StreamType<T>);
}
responseSizeSinceLastKeepAlive = 0;
setTimeout(repeat, FLUSH_KEEP_ALIVE_INTERVAL_MS);
Expand Down Expand Up @@ -211,7 +212,7 @@ export function streamFactory<T extends StreamTypeUnion>(
}
}

const responseWithHeaders: StreamFactoryReturnType<T>['responseWithHeaders'] = {
const responseWithHeaders: StreamResponseWithHeaders = {
body: stream,
headers: {
...(isCompressed ? { 'content-encoding': 'gzip' } : {}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { KibanaRequest, ResponseHeaders } from '@kbn/core-http-server';
import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain';
import type { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common';
import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server';
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
Expand Down Expand Up @@ -51,7 +51,7 @@ export interface StaticReturnType {
headers: ResponseHeaders;
}
export type AgentExecutorResponse<T extends boolean> = T extends true
? StreamFactoryReturnType['responseWithHeaders']
? StreamResponseWithHeaders
: StaticReturnType;

export type AgentExecutor<T extends boolean> = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import { IRouter, Logger } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { getRequestAbortedSignal } from '@kbn/data-plugin/server';
import { StreamFactoryReturnType } from '@kbn/ml-response-stream/server';
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';

import { schema } from '@kbn/config-schema';
import {
Expand Down Expand Up @@ -323,42 +323,41 @@ export const postActionsConnectorExecuteRoute = (
page: 1,
});

const result: StreamFactoryReturnType['responseWithHeaders'] | StaticReturnType =
await callAgentExecutor({
abortSignal,
alertsIndexPattern: request.body.alertsIndexPattern,
anonymizationFields: anonymizationFieldsRes
? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data)
: undefined,
actions,
isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false,
assistantTools,
connectorId,
elserId,
esClient,
isStream:
// TODO implement llmClass for bedrock streaming
// tracked here: https://github.com/elastic/security-team/issues/7363
request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai',
llmType: getLlmType(actionTypeId),
kbResource: ESQL_RESOURCE,
langChainMessages,
logger,
onNewReplacements,
onLlmResponse,
request,
replacements: request.body.replacements,
size: request.body.size,
telemetry,
traceOptions: {
const result: StreamResponseWithHeaders | StaticReturnType = await callAgentExecutor({
abortSignal,
alertsIndexPattern: request.body.alertsIndexPattern,
anonymizationFields: anonymizationFieldsRes
? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data)
: undefined,
actions,
isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false,
assistantTools,
connectorId,
elserId,
esClient,
isStream:
// TODO implement llmClass for bedrock streaming
// tracked here: https://github.com/elastic/security-team/issues/7363
request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai',
llmType: getLlmType(actionTypeId),
kbResource: ESQL_RESOURCE,
langChainMessages,
logger,
onNewReplacements,
onLlmResponse,
request,
replacements: request.body.replacements,
size: request.body.size,
telemetry,
traceOptions: {
projectName: langSmithProject,
tracers: getLangSmithTracer({
apiKey: langSmithApiKey,
projectName: langSmithProject,
tracers: getLangSmithTracer({
apiKey: langSmithApiKey,
projectName: langSmithProject,
logger,
}),
},
});
logger,
}),
},
});

telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, {
actionTypeId,
Expand All @@ -371,9 +370,7 @@ export const postActionsConnectorExecuteRoute = (
request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai',
});

return response.ok<
StreamFactoryReturnType['responseWithHeaders']['body'] | StaticReturnType['body']
>(result);
return response.ok<StreamResponseWithHeaders['body'] | StaticReturnType['body']>(result);
} catch (err) {
logger.error(err);
const error = transformError(err);
Expand Down

0 comments on commit b13bec0

Please sign in to comment.