Skip to content

Commit

Permalink
add placeholder draft for cohere helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
ystoneman committed Jun 5, 2024
1 parent dec01a0 commit 998b643
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/libs/agent-runtime/cohere/index.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
// sort-imports-ignore
// TODO: FOR COHERE See if this shims is needed for Cohere
// import '@anthropic-ai/sdk/shims/web';
import { CohereClient, CohereError, CohereTimeoutError } from "cohere-ai";
import { ClientOptions } from 'openai';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
// TODO: FOR COHERE Add cohere to types
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider, UserMessageContentPart } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { desensitizeUrl } from '../utils/desensitizeUrl';

// TODO: FOR COHERE create cohereHelpers
import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
import { buildCohereMessages, buildCohereTools } from '../utils/cohereHelpers';
import { StreamingResponse } from '../utils/response';

// TODO: FOR COHERE create stream util for cohere
Expand Down Expand Up @@ -127,10 +123,10 @@ export class LobeCohereAI implements LobeRuntimeAI {
max_tokens,
model,
temperature,
tools: buildAnthropicTools(tools),
tools: buildCohereTools(tools),
p,
message: typeof message === 'string' ? message : message.join(' '),
chatHistory: chatHistory.map((m) => ({ role: m.role, message: m.content })),
chat_history: chatHistory.map((m) => ({ role: m.role.toUpperCase(), message: m.content })),
};
}

Expand Down
60 changes: 60 additions & 0 deletions src/libs/agent-runtime/utils/cohereHelpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { parseDataUri } from './uriParser';
import { OpenAIChatMessage, UserMessageContentPart } from '../types';
import Cohere from 'cohere-ai';

export const buildCohereBlock = (
content: UserMessageContentPart,
): { type: string; content?: string; data?: string; mime_type?: string } => {
switch (content.type) {
case 'text': {
return { type: 'text', content: content.text };
}

case 'image_url': {
const { mimeType, base64 } = parseDataUri(content.image_url.url);
return {
type: 'image',
data: base64,
mime_type: mimeType,
};
}

default: {
throw new Error(`Unsupported content type: ${content.type}`);
}
}
};

export const buildCohereMessage = (
message: OpenAIChatMessage,
): { role: string; content: string | object } => {
const content = message.content as string | UserMessageContentPart[];

switch (message.role) {
case 'system':
case 'user':
case 'assistant': {
return {
role: message.role,
content: typeof content === 'string' ? content : content.map(buildCohereBlock),
};
}

default: {
throw new Error(`Unsupported message role: ${message.role}`);
}
}
};

export const buildCohereMessages = (
oaiMessages: OpenAIChatMessage[],
): { role: string; content: string | object }[] => {
return oaiMessages.map(buildCohereMessage);
};

export const buildCohereTools = (tools?: OpenAI.ChatCompletionTool[]) =>
tools?.map(tool => ({
name: tool.function.name,
description: tool.function.description,
input_schema: tool.function.parameters,
}));

0 comments on commit 998b643

Please sign in to comment.