Skip to content

Commit

Permalink
Gemini: support Pro 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Apr 10, 2024
1 parent e3290e1 commit 5dc9c8f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 63 deletions.
69 changes: 69 additions & 0 deletions src/modules/llms/server/gemini/gemini.models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import type { GeminiModelSchema } from './gemini.wiretypes';
import type { ModelDescriptionSchema } from '../llm.server.types';
import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms';


// Models whose displayName contains any of these substrings are dropped from the list.
const filterUnallowedNames = ['Legacy'];
// Models exposing any of these generation methods are dropped (non-chat endpoints: QA/embedding).
const filterUnallowedInterfaces: GeminiModelSchema['supportedGenerationMethods'] = ['generateAnswer', 'embedContent', 'embedText'];

// Model IDs that are aliases ("symlinks") pointing at another concrete model.
const geminiLinkModels = ['models/gemini-pro', 'models/gemini-pro-vision'];

// interfaces mapping
// Generation methods that map to our chat interface.
const geminiChatInterfaces: GeminiModelSchema['supportedGenerationMethods'] = ['generateContent'];
// Model-ID substrings that indicate image-input (vision) support.
const geminiVisionNames = ['-vision'];


/**
 * List filter for Gemini models: drops models whose display name contains a
 * blocklisted substring, and models that expose any non-chat generation
 * method (answering/embedding endpoints).
 */
export function geminiFilterModels(geminiModel: GeminiModelSchema): boolean {
  const nameIsBlocked = filterUnallowedNames.some(unallowed => geminiModel.displayName.includes(unallowed));
  const interfaceIsBlocked = filterUnallowedInterfaces.some(unallowed => geminiModel.supportedGenerationMethods.includes(unallowed));
  return !nameIsBlocked && !interfaceIsBlocked;
}

/**
 * Comparator for the models list: visible models come before hidden ones,
 * and within each visibility group labels are ordered descending.
 */
export function geminiSortModels(a: ModelDescriptionSchema, b: ModelDescriptionSchema): number {
  // differing visibility: the hidden model always sinks to the bottom
  if (a.hidden !== b.hidden)
    return a.hidden ? 1 : -1;
  // same visibility: reverse-alphabetical by label
  return b.label.localeCompare(a.label);
}

/**
 * Maps one Gemini API model entry onto our ModelDescriptionSchema.
 *
 * - known alias IDs ("symlinks") get a 🔗 label pointing at the model they
 *   resolve to (matched by identical displayName, different name) and are hidden
 * - models without a chat generation method are hidden as well
 * - the context window is reported as input + output token limits
 */
export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, allModels: GeminiModelSchema[]): ModelDescriptionSchema {
  const { description, displayName, name: modelId, supportedGenerationMethods, inputTokenLimit, outputTokenLimit, version, topK, topP, temperature } = geminiModel;

  // symlink resolution: find the concrete model this alias points at
  const isSymlink = geminiLinkModels.includes(modelId);
  let label = displayName;
  if (isSymlink) {
    const symlinked = allModels.find(m => m.displayName === displayName && m.name !== modelId);
    label = `🔗 ${displayName.replace('1.0', '')}${symlinked ? symlinked.name : '?'}`;
  }

  // chat support drives both the interfaces list and visibility
  const hasChatInterfaces = supportedGenerationMethods.some(iface => geminiChatInterfaces.includes(iface));

  const interfaces: ModelDescriptionSchema['interfaces'] = [];
  if (hasChatInterfaces) {
    interfaces.push(LLM_IF_OAI_Chat);
    // vision support is inferred from the model ID suffix
    if (geminiVisionNames.some(name => modelId.includes(name)))
      interfaces.push(LLM_IF_OAI_Vision);
  }

  return {
    id: modelId,
    label,
    // created: ...
    // updated: ...
    description: description + ` (Version: ${version}, Defaults: temperature=${temperature}, topP=${topP}, topK=${topK}, interfaces=[${supportedGenerationMethods.join(',')}])`,
    contextWindow: inputTokenLimit + outputTokenLimit,
    maxCompletionTokens: outputTokenLimit,
    // pricing: needs per-character and per-image pricing for the Pro models
    // rateLimits: e.g. { reqPerMinute: 60 } for the Pro models
    interfaces,
    hidden: isSymlink || !hasChatInterfaces,
  };
}
47 changes: 8 additions & 39 deletions src/modules/llms/server/gemini/gemini.router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server';
import { fetchJsonOrTRPCError } from '~/server/api/trpc.router.fetchers';

import { fixupHost } from '~/common/util/urlUtils';

import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms';
import { llmsChatGenerateOutputSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
import { llmsChatGenerateOutputSchema, llmsListModelsOutputSchema } from '../llm.server.types';

import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';

import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelSchema, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes';
import { geminiFilterModels, geminiModelToModelDescription, geminiSortModels } from '~/modules/llms/server/gemini/gemini.models';


// Default hosts
Expand Down Expand Up @@ -146,43 +145,13 @@ export const llmGeminiRouter = createTRPCRouter({
// as the List API already provides all the info on all the models

// map to our output schema
const models = detailedModels
.filter(geminiFilterModels)
.map(geminiModel => geminiModelToModelDescription(geminiModel, detailedModels))
.sort(geminiSortModels);

return {
models: detailedModels.map((geminiModel) => {
const { description, displayName, inputTokenLimit, name, outputTokenLimit, supportedGenerationMethods } = geminiModel;

const isSymlink = ['models/gemini-pro', 'models/gemini-pro-vision'].includes(name);
const symlinked = isSymlink ? detailedModels.find(m => m.displayName === displayName && m.name !== name) : null;

const contextWindow = inputTokenLimit + outputTokenLimit;
const hidden = !supportedGenerationMethods.includes('generateContent') || isSymlink;

const { version, topK, topP, temperature } = geminiModel;
const descriptionLong = description + ` (Version: ${version}, Defaults: temperature=${temperature}, topP=${topP}, topK=${topK}, interfaces=[${supportedGenerationMethods.join(',')}])`;

// const isGeminiPro = name.includes('gemini-pro');
const isGeminiProVision = name.includes('gemini-pro-vision');

const interfaces: ModelDescriptionSchema['interfaces'] = [];
if (supportedGenerationMethods.includes('generateContent')) {
interfaces.push(LLM_IF_OAI_Chat);
if (isGeminiProVision)
interfaces.push(LLM_IF_OAI_Vision);
}

return {
id: name,
label: isSymlink ? `🔗 ${displayName.replace('1.0', '')}${symlinked ? symlinked.name : '?'}` : displayName,
// created: ...
// updated: ...
description: descriptionLong,
contextWindow: contextWindow,
maxCompletionTokens: outputTokenLimit,
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
interfaces: supportedGenerationMethods.includes('generateContent') ? [LLM_IF_OAI_Chat] : [],
hidden,
} satisfies ModelDescriptionSchema;
}),
models: models,
};
}),

Expand Down
51 changes: 27 additions & 24 deletions src/modules/llms/server/gemini/gemini.wiretypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,34 @@ export const geminiModelsStreamGenerateContentPath = '/v1beta/{model=models/*}:s

// models.list = /v1beta/models

// every generation method the Gemini API may report for a model
const geminiGenerationMethods = [
  'countMessageTokens',
  'countTextTokens',
  'countTokens',
  'createTunedModel',
  'createTunedTextModel',
  'embedContent',
  'embedText',
  'generateAnswer',
  'generateContent',
  'generateMessage',
  'generateText',
] as const;

// Schema for a single model entry as returned by the models.list endpoint
const geminiModelSchema = z.object({
  name: z.string(),
  version: z.string(),
  displayName: z.string(),
  description: z.string(),
  inputTokenLimit: z.number().int().min(1),
  outputTokenLimit: z.number().int().min(1),
  supportedGenerationMethods: z.array(z.enum(geminiGenerationMethods)),
  temperature: z.number().optional(),
  topP: z.number().optional(),
  topK: z.number().optional(),
});
export type GeminiModelSchema = z.infer<typeof geminiModelSchema>;

// Response schema for the models.list endpoint: a wrapper around the array of
// model entries. (The scraped diff had left the pre-commit inline object schema
// merged with the new array line, producing a duplicate `models` key; this is
// the post-commit shape, delegating each entry to geminiModelSchema.)
export const geminiModelsListOutputSchema = z.object({
  models: z.array(geminiModelSchema),
});


Expand Down

0 comments on commit 5dc9c8f

Please sign in to comment.