diff --git a/cortex-js/src/domain/abstracts/engine.abstract.ts b/cortex-js/src/domain/abstracts/engine.abstract.ts index d203218fa..92335fc63 100644 --- a/cortex-js/src/domain/abstracts/engine.abstract.ts +++ b/cortex-js/src/domain/abstracts/engine.abstract.ts @@ -6,6 +6,10 @@ import { Extension } from './extension.abstract'; export abstract class EngineExtension extends Extension { abstract onLoad(): void; + transformPayload?: Function; + + transformResponse?: Function; + abstract inference( dto: any, headers: Record, @@ -17,4 +21,5 @@ export abstract class EngineExtension extends Extension { ): Promise {} async unloadModel(modelId: string): Promise {} + } diff --git a/cortex-js/src/domain/abstracts/oai.abstract.ts b/cortex-js/src/domain/abstracts/oai.abstract.ts index 7718b8bfd..02d215cf2 100644 --- a/cortex-js/src/domain/abstracts/oai.abstract.ts +++ b/cortex-js/src/domain/abstracts/oai.abstract.ts @@ -1,7 +1,8 @@ import { HttpService } from '@nestjs/axios'; import { EngineExtension } from './engine.abstract'; -import stream from 'stream'; +import stream, { Transform } from 'stream'; import { firstValueFrom } from 'rxjs'; +import _ from 'lodash'; export abstract class OAIEngineExtension extends EngineExtension { abstract apiUrl: string; @@ -17,22 +18,47 @@ export abstract class OAIEngineExtension extends EngineExtension { createChatDto: any, headers: Record, ): Promise { - const { stream } = createChatDto; + const payload = this.transformPayload ? this.transformPayload(createChatDto) : createChatDto; + const { stream: isStream } = payload; + const additionalHeaders = _.omit(headers, ['content-type', 'authorization']); const response = await firstValueFrom( - this.httpService.post(this.apiUrl, createChatDto, { + this.httpService.post(this.apiUrl, payload, { headers: { 'Content-Type': headers['content-type'] ?? 'application/json', Authorization: this.apiKey ? `Bearer ${this.apiKey}` : headers['authorization'], + ...additionalHeaders, }, - responseType: stream ? 
'stream' : 'json', + responseType: isStream ? 'stream' : 'json', }), ); + if (!response) { throw new Error('No response'); } - - return response.data; + if(!this.transformResponse) { + return response.data; + } + if (isStream) { + const transformResponse = this.transformResponse.bind(this); + const lineStream = new Transform({ + transform(chunk, encoding, callback) { + const lines = chunk.toString().split('\n'); + const transformedLines = []; + for (const line of lines) { + if (line.trim().length > 0) { + const transformedLine = transformResponse(line); + if (transformedLine) { + transformedLines.push(transformedLine); + } + } + } + callback(null, transformedLines.join('')); + } + }); + return response.data.pipe(lineStream); + } + return this.transformResponse(response.data); } } diff --git a/cortex-js/src/extensions/anthropic.engine.ts b/cortex-js/src/extensions/anthropic.engine.ts new file mode 100644 index 000000000..7ce37aff7 --- /dev/null +++ b/cortex-js/src/extensions/anthropic.engine.ts @@ -0,0 +1,97 @@ +import stream from 'stream'; +import { HttpService } from '@nestjs/axios'; +import { OAIEngineExtension } from '../domain/abstracts/oai.abstract'; +import { ConfigsUsecases } from '@/usecases/configs/configs.usecase'; +import { EventEmitter2 } from '@nestjs/event-emitter'; +import _ from 'lodash'; + +/** + * Engine extension that sends chat completions to the Anthropic Messages API. + * It transforms OpenAI-style request payloads and responses to and from Anthropic's format, + * and listens for config updates to refresh its stored API key. 
+ */ +export default class AnthropicEngineExtension extends OAIEngineExtension { + apiUrl = 'https://api.anthropic.com/v1/messages'; + name = 'anthropic'; + productName = 'Anthropic Inference Engine'; + description = 'This extension enables Anthropic chat completion API calls'; + version = '0.0.1'; + apiKey?: string; + + constructor( + protected readonly httpService: HttpService, + protected readonly configsUsecases: ConfigsUsecases, + protected readonly eventEmmitter: EventEmitter2, + ) { + super(httpService); + + eventEmmitter.on('config.updated', async (data) => { + if (data.group === this.name) { + this.apiKey = data.value; + } + }); + } + + async onLoad() { + const configs = (await this.configsUsecases.getGroupConfigs( + this.name, + )) as unknown as { apiKey: string }; + this.apiKey = configs?.apiKey; + if (!configs?.apiKey) + await this.configsUsecases.saveConfig('apiKey', '', this.name); + } + + override async inference(dto: any, headers: Record): Promise { + headers['x-api-key'] = this.apiKey as string + headers['Content-Type'] = 'application/json' + headers['anthropic-version'] = '2023-06-01' + return super.inference(dto, headers) + } + + transformPayload = (data: any): any => { + return _.pick(data, ['messages', 'model', 'stream', 'max_tokens']); + } + + transformResponse = (data: any): string => { + // handling stream response + if (typeof data === 'string' && data.trim().length === 0) { + return ''; + } + if (typeof data === 'string' && data.startsWith('event: ')) { + return '' + } + if (typeof data === 'string' && data.startsWith('data: ')) { + data = data.replace('data: ', ''); + const parsedData = JSON.parse(data); + if (parsedData.type !== 'content_block_delta') { + return '' + } + const text = parsedData.delta?.text; + //convert to have this format data.choices[0]?.delta?.content + return JSON.stringify({ + choices: [ + { + delta: { + content: text + } + } + ] + }) + } + // non-stream response + if (data.content && data.content.length > 0 && 
data.content[0].text) { + return JSON.stringify({ + choices: [ + { + delta: { + content: data.content[0].text, + }, + }, + ], + }); + } + + console.error('Invalid response format:', data); + return ''; + } +} diff --git a/cortex-js/src/extensions/extensions.module.ts b/cortex-js/src/extensions/extensions.module.ts index ad8430be6..22f806e7f 100644 --- a/cortex-js/src/extensions/extensions.module.ts +++ b/cortex-js/src/extensions/extensions.module.ts @@ -6,6 +6,7 @@ import { HttpModule, HttpService } from '@nestjs/axios'; import { ConfigsUsecases } from '@/usecases/configs/configs.usecase'; import { ConfigsModule } from '@/usecases/configs/configs.module'; import { EventEmitter2, EventEmitterModule } from '@nestjs/event-emitter'; +import AnthropicEngineExtension from './anthropic.engine'; const provider = { provide: 'EXTENSIONS_PROVIDER', @@ -18,6 +19,7 @@ const provider = { new OpenAIEngineExtension(httpService, configUsecases, eventEmitter), new GroqEngineExtension(httpService, configUsecases, eventEmitter), new MistralEngineExtension(httpService, configUsecases, eventEmitter), + new AnthropicEngineExtension(httpService, configUsecases, eventEmitter), ], }; diff --git a/cortex-js/src/infrastructure/commanders/chat.command.ts b/cortex-js/src/infrastructure/commanders/chat.command.ts index fb4ba0427..9c39287a5 100644 --- a/cortex-js/src/infrastructure/commanders/chat.command.ts +++ b/cortex-js/src/infrastructure/commanders/chat.command.ts @@ -49,7 +49,6 @@ export class ChatCommand extends CommandRunner { async run(passedParams: string[], options: ChatOptions): Promise { let modelId = passedParams[0]; - const checkingSpinner = ora('Checking model...').start(); // First attempt to get message from input or options // Extract input from 1 to end of array let message = options.message ?? 
passedParams.slice(1).join(' '); @@ -68,11 +67,9 @@ export class ChatCommand extends CommandRunner { } else if (models.length > 0) { modelId = await this.modelInquiry(models); } else { - checkingSpinner.fail('Model ID is required'); exit(1); } } - checkingSpinner.succeed(`Model found`); if (!message) options.attach = true; const result = await this.chatCliUsecases.chat( diff --git a/cortex-js/src/infrastructure/commanders/types/engine.interface.ts b/cortex-js/src/infrastructure/commanders/types/engine.interface.ts index 91a08d919..93101ad11 100644 --- a/cortex-js/src/infrastructure/commanders/types/engine.interface.ts +++ b/cortex-js/src/infrastructure/commanders/types/engine.interface.ts @@ -7,4 +7,5 @@ export enum Engines { groq = 'groq', mistral = 'mistral', openai = 'openai', + anthropic = 'anthropic', } diff --git a/cortex-js/src/usecases/chat/chat.usecases.ts b/cortex-js/src/usecases/chat/chat.usecases.ts index 1ce580f42..44b3289f2 100644 --- a/cortex-js/src/usecases/chat/chat.usecases.ts +++ b/cortex-js/src/usecases/chat/chat.usecases.ts @@ -35,7 +35,6 @@ export class ChatUsecases { const engine = (await this.extensionRepository.findOne( model!.engine ?? Engines.llamaCPP, )) as EngineExtension | undefined; - if (engine == null) { throw new Error(`No engine found with name: ${model.engine}`); }