Skip to content

Commit

Permalink
more engine tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nbonamy committed May 15, 2024
1 parent 7bb637e commit c226454
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 25 deletions.
2 changes: 1 addition & 1 deletion build/build_number.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
187
188
10 changes: 5 additions & 5 deletions src/services/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@ export const canProcessFormat = (engine: string, model: string, format: string)
}

// Refresh the model list for every configured engine.
// Iterating over availableEngines keeps this in sync when new engines
// (e.g. 'google') are added, instead of maintaining a hard-coded call list.
export const loadAllModels = async () => {
  for (const engine in availableEngines) {
    await loadModels(engine)
  }
}

export const loadModels = async (engine: string) => {
Expand All @@ -74,6 +72,8 @@ export const loadModels = async (engine: string) => {
await loadMistralAIModels()
} else if (engine === 'anthropic') {
await loadAnthropicModels()
} else if (engine === 'google') {
await loadGoogleModels()
} else if (engine === 'groq') {
await loadGroqModels()
}
Expand Down
17 changes: 16 additions & 1 deletion tests/unit/engine_ready.test.ts → tests/unit/engine.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@

import { beforeEach, expect, test } from 'vitest'
import { isEngineReady } from '../../src/services/llm'
import { isEngineReady, igniteEngine } from '../../src/services/llm'
import { store } from '../../src/services/store'
import defaults from '../../defaults/settings.json'
import OpenAI from '../../src/services/openai'
import Ollama from '../../src/services/ollama'
import MistralAI from '../../src/services/mistralai'
import Anthropic from '../../src/services/anthropic'
import Google from '../../src/services/google'
import Groq from '../../src/services/groq'

const model = [{ id: 'llava:latest', name: 'llava:latest', meta: {} }]

Expand Down Expand Up @@ -59,3 +65,12 @@ test('Google Configuration', () => {
store.config.engines.google.apiKey = '123'
expect(isEngineReady('google')).toBe(true)
})

// igniteEngine must instantiate the provider class matching each engine id.
test('Ignite Engine', async () => {
  const expectations: [string, unknown][] = [
    ['openai', OpenAI],
    ['ollama', Ollama],
    ['mistralai', MistralAI],
    ['anthropic', Anthropic],
    ['google', Google],
    ['groq', Groq],
  ]
  for (const [engineId, engineClass] of expectations) {
    expect(await igniteEngine(engineId, store.config)).toBeInstanceOf(engineClass)
  }
})
18 changes: 13 additions & 5 deletions tests/unit/engine_anthropic.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import defaults from '../../defaults/settings.json'
import Message from '../../src/models/message'
import Anthropic from '../../src/services/anthropic'
import * as _Anthropic from '@anthropic-ai/sdk'
import { loadAnthropicModels } from '../../src/services/llm'
import { Model } from '../../src/types/config.d'

vi.mock('@anthropic-ai/sdk', async() => {
const Anthropic = vi.fn()
Expand Down Expand Up @@ -44,14 +46,20 @@ beforeEach(() => {
store.config.engines.anthropic.apiKey = '123'
})

test('Anthropic Basic', async () => {
const anthropic = new Anthropic(store.config)
expect(anthropic.getName()).toBe('anthropic')
expect(await anthropic.getModels()).toStrictEqual([
// loadAnthropicModels should populate the chat model list and default
// the chat model to the first entry.
// NOTE(review): the original expectation listed 'claude-3-sonnet-20240229'
// twice (diff artifact); deduplicated here — confirm intended ordering.
test('Anthropic Load Models', async () => {
  expect(await loadAnthropicModels()).toBe(true)
  const models = store.config.engines.anthropic.models.chat
  expect(models.map((m: Model) => { return { id: m.id, name: m.name }})).toStrictEqual([
    { id: 'claude-3-haiku-20240307', name: 'Claude 3 Haiku' },
    { id: 'claude-3-opus-20240229', name: 'Claude 3 Opus' },
    { id: 'claude-3-sonnet-20240229', name: 'Claude 3 Sonnet' },
  ])
  expect(store.config.engines.anthropic.model.chat).toStrictEqual(models[0].id)
})

test('Anthropic Basic', async () => {
const anthropic = new Anthropic(store.config)
expect(anthropic.getName()).toBe('anthropic')
expect(anthropic.getVisionModels()).toStrictEqual([])
expect(anthropic.isVisionModel('claude-3-haiku-20240307')).toBe(true)
expect(anthropic.isVisionModel('claude-3-sonnet-20240229')).toBe(true)
Expand Down
19 changes: 13 additions & 6 deletions tests/unit/engine_google.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import { store } from '../../src/services/store'
import defaults from '../../defaults/settings.json'
import Message from '../../src/models/message'
import Google from '../../src/services/google'
import { loadGoogleModels } from '../../src/services/llm'
import { EnhancedGenerateContentResponse } from '@google/generative-ai'
import { Model } from '../../src/types/config.d'

vi.mock('@google/generative-ai', async() => {
return {
Expand Down Expand Up @@ -41,15 +43,20 @@ beforeEach(() => {
store.config.engines.google.apiKey = '123'
})

test('Google Basic', async () => {
const google = new Google(store.config)
expect(google.getName()).toBe('google')
expect(await google.getModels()).toStrictEqual([
test('Google Load Models', async () => {
expect(await loadGoogleModels()).toBe(true)
const models = store.config.engines.google.models.chat
expect(models.map((m: Model) => { return { id: m.id, name: m.name }})).toStrictEqual([
{ id: 'models/gemini-1.5-pro-latest', name: 'Gemini 1.5 Pro' },
{ id: 'gemini-1.5-flash-latest', name: 'Gemini 1.5 Flash' },
{ id: 'models/gemini-pro', name: 'Gemini 1.0 Pro' },
])
//expect(_Google.default.prototype.models.list).toHaveBeenCalled()
])
expect(store.config.engines.google.model.chat).toStrictEqual(models[0].id)
})

test('Google Basic', async () => {
const google = new Google(store.config)
expect(google.getName()).toBe('google')
expect(google.isVisionModel('models/gemini-pro')).toBe(false)
expect(google.isVisionModel('gemini-1.5-flash-latest')).toBe(true)
expect(google.isVisionModel('models/gemini-1.5-pro-latest')).toBe(true)
Expand Down
18 changes: 16 additions & 2 deletions tests/unit/engine_mistralai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@ import defaults from '../../defaults/settings.json'
import Message from '../../src/models/message'
import MistralAI from '../../src/services/mistralai'
import MistralClient from '../../src/vendor/mistralai'
import { loadMistralAIModels } from '../../src/services/llm'
import { Model } from '../../src/types/config.d'

vi.mock('../../src/vendor/mistralai', async() => {
const MistralClient = vi.fn()
MistralClient.prototype.apiKey = '123'
MistralClient.prototype.listModels = vi.fn(() => {
return { data: [{ id: 'model', name: 'model' }] }
return { data: [
{ id: 'model2', name: 'model2' },
{ id: 'model1', name: 'model1' },
] }
})
MistralClient.prototype.chat = vi.fn(() => {
return { choices: [ { message: { content: 'response' } } ] }
Expand All @@ -31,10 +36,19 @@ beforeEach(() => {
store.config.engines.mistralai.apiKey = '123'
})

// loadMistralAIModels should populate the chat model list and default
// the chat model to the first entry.
test('MistralAI Load Models', async () => {
  expect(await loadMistralAIModels()).toBe(true)
  const chatModels = store.config.engines.mistralai.models.chat
  const summary = chatModels.map((m: Model) => ({ id: m.id, name: m.name }))
  expect(summary).toStrictEqual([
    { id: 'model1', name: 'model1' },
    { id: 'model2', name: 'model2' },
  ])
  expect(store.config.engines.mistralai.model.chat).toStrictEqual(chatModels[0].id)
})

test('MistralAI Basic', async () => {
const mistralai = new MistralAI(store.config)
expect(mistralai.getName()).toBe('mistralai')
expect(await mistralai.getModels()).toStrictEqual([{ id: 'model', name: 'model' }])
expect(mistralai.isVisionModel('mistral-medium')).toBe(false)
expect(mistralai.isVisionModel('mistral-large')).toBe(false)
expect(mistralai.getRountingModel()).toBeNull()
Expand Down
19 changes: 17 additions & 2 deletions tests/unit/engine_ollama.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@ import defaults from '../../defaults/settings.json'
import Message from '../../src/models/message'
import Ollama from '../../src/services/ollama'
import * as _ollama from 'ollama/dist/browser.mjs'
import { loadOllamaModels } from '../../src/services/llm'
import { Model } from '../../src/types/config.d'

vi.mock('ollama/browser', async() => {
return { default : {
list: vi.fn(() => {
return { models: [{ id: 'model', name: 'model' }] }
return { models: [
{ model: 'model2', name: 'model2' },
{ model: 'model1', name: 'model1' },
] }
}),
chat: vi.fn((opts) => {
if (opts.stream) {
Expand All @@ -33,10 +38,20 @@ beforeEach(() => {
store.config.engines.ollama.apiKey = '123'
})

// loadOllamaModels should populate the chat model list and default
// the chat model to the first entry.
// NOTE(review): the default uses models[0].name here, unlike other engines
// which use .id — confirm this asymmetry is intentional for ollama.
test('Ollama Load Models', async () => {
  expect(await loadOllamaModels()).toBe(true)
  const models = store.config.engines.ollama.models.chat
  // removed leftover console.log(models) debug statement
  expect(models.map((m: Model) => { return { id: m.id, name: m.name }})).toStrictEqual([
    { id: 'model1', name: 'model1' },
    { id: 'model2', name: 'model2' },
  ])
  expect(store.config.engines.ollama.model.chat).toStrictEqual(models[0].name)
})

test('Ollama Basic', async () => {
const ollama = new Ollama(store.config)
expect(ollama.getName()).toBe('ollama')
expect(await ollama.getModels()).toStrictEqual([{ id: 'model', name: 'model' }])
expect(ollama.isVisionModel('llava:latest')).toBe(true)
expect(ollama.isVisionModel('llama2:latest')).toBe(false)
expect(ollama.getRountingModel()).toBeNull()
Expand Down
33 changes: 30 additions & 3 deletions tests/unit/engine_openai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ import Message from '../../src/models/message'
import OpenAI from '../../src/services/openai'
import * as _OpenAI from 'openai'
import { ChatCompletionChunk } from 'openai/resources'
import { loadOpenAIModels } from '../../src/services/llm'
import { Model } from '../../src/types/config.d'

vi.mock('openai', async() => {
const OpenAI = vi.fn()
OpenAI.prototype.apiKey = '123'
OpenAI.prototype.models = {
list: vi.fn(() => {
return { data: [{ id: 'model', name: 'model' }] }
return { data: [
{ id: 'gpt-model2', name: 'model2' },
{ id: 'gpt-model1', name: 'model1' },
{ id: 'dall-e-model2', name: 'model2' },
{ id: 'dall-e-model1', name: 'model1' },
] }
})
}
OpenAI.prototype.chat = {
Expand Down Expand Up @@ -47,11 +54,31 @@ beforeEach(() => {
store.config.engines.openai.apiKey = '123'
})

// loadOpenAIModels should call the models endpoint, keep gpt-* entries in
// the chat list, and default the chat model to the first entry.
test('OpenAI Load Chat Models', async () => {
  expect(await loadOpenAIModels()).toBe(true)
  expect(_OpenAI.default.prototype.models.list).toHaveBeenCalled()
  const chatModels = store.config.engines.openai.models.chat
  const summary = chatModels.map((m: Model) => ({ id: m.id, name: m.name }))
  expect(summary).toStrictEqual([
    { id: 'gpt-model1', name: 'gpt-model1' },
    { id: 'gpt-model2', name: 'gpt-model2' },
  ])
  expect(store.config.engines.openai.model.chat).toStrictEqual(chatModels[0].id)
})

// loadOpenAIModels should call the models endpoint, keep dall-e-* entries in
// the image list, and default the image model to the first entry.
test('OpenAI Load Image Models', async () => {
  expect(await loadOpenAIModels()).toBe(true)
  expect(_OpenAI.default.prototype.models.list).toHaveBeenCalled()
  const imageModels = store.config.engines.openai.models.image
  const summary = imageModels.map((m: Model) => ({ id: m.id, name: m.name }))
  expect(summary).toStrictEqual([
    { id: 'dall-e-model1', name: 'dall-e-model1' },
    { id: 'dall-e-model2', name: 'dall-e-model2' },
  ])
  expect(store.config.engines.openai.model.image).toStrictEqual(imageModels[0].id)
})

test('OpenAI Basic', async () => {
const openAI = new OpenAI(store.config)
expect(openAI.getName()).toBe('openai')
expect(await openAI.getModels()).toStrictEqual([{ id: 'model', name: 'model' }])
expect(_OpenAI.default.prototype.models.list).toHaveBeenCalled()
expect(openAI.isVisionModel('gpt-3.5')).toBe(false)
expect(openAI.isVisionModel('gpt-3.5-turbo')).toBe(false)
expect(openAI.isVisionModel('gpt-4')).toBe(false)
Expand Down

0 comments on commit c226454

Please sign in to comment.