chore: load, unload model and inference synchronously
louis-jan committed Mar 25, 2024
1 parent 1ad794c commit 9551996
Showing 16 changed files with 226 additions and 173 deletions.
46 changes: 0 additions & 46 deletions core/src/extensions/ai-engines/RemoteOAIEngine.ts

This file was deleted.

core/src/extensions/engines/AIEngine.ts
@@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../types'
import { EngineManager } from './EngineManager'

/**
* Base AIEngine
@@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'

/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()

events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))

this.prePopulateModels()
}

/**
* Defines models
*/
models(): Promise<Model[]> {
return Promise.resolve([])
}

/**
* On extension load, subscribe to events.
* Registers AI Engines
*/
onLoad() {
this.prePopulateModels()
registerEngine() {
EngineManager.instance()?.register(this)
}

/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}

/*
* Inference request
*/
inference(data: MessageRequest) {}

/**
* Stop inference
*/
stopInference() {}

/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
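With the subscriptions moved into the base class, emitting ModelEvent.OnModelInit or ModelEvent.OnModelStop reaches every registered engine, and only the engine whose provider matches model.engine acts on it. A minimal sketch of that event-driven path; the '@janhq/core' import path and the logging are illustrative assumptions, not part of this commit:

import { events, Model, ModelEvent } from '@janhq/core'

// Ask whichever registered engine matches model.engine to load the model.
// AIEngine.loadModel() ignores models whose engine does not match its provider.
function requestModelLoad(model: Model) {
  events.emit(ModelEvent.OnModelInit, model)
}

// The matching engine answers with OnModelReady (or OnModelFail on error),
// and with OnModelStopped after an OnModelStop request.
events.on(ModelEvent.OnModelReady, (model: Model) => {
  console.log(`model ready: ${model.id}`)
})
events.on(ModelEvent.OnModelStopped, () => {
  console.log('model stopped')
})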
34 changes: 34 additions & 0 deletions core/src/extensions/engines/EngineManager.ts
@@ -0,0 +1,34 @@
import { log } from '../../core'
import { AIEngine } from './AIEngine'

/**
* Manages the registration and retrieval of inference engines.
*/
export class EngineManager {
public engines = new Map<string, AIEngine>()

/**
* Registers an engine.
* @param engine - The engine to register.
*/
register<T extends AIEngine>(engine: T) {
this.engines.set(engine.provider, engine)
}

/**
* Retrieves an engine by provider.
* @param provider - The name of the engine to retrieve.
* @returns The engine, if found.
*/
get<T extends AIEngine>(provider: string): T | undefined {
return this.engines.get(provider) as T | undefined
}

static instance(): EngineManager | undefined {
return window.core?.engineManager as EngineManager
}
}

/**
* The singleton instance of the ExtensionManager.
*/
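Because loadModel and unloadModel now return promises, a caller can also bypass the event round trip and await the engine directly through this registry. A rough sketch, assuming the '@janhq/core' import path and the hypothetical switchModel helper:

import { EngineManager, Model } from '@janhq/core'

async function switchModel(current: Model | undefined, next: Model): Promise<void> {
  // Stop the currently loaded model through the engine it belongs to.
  if (current) {
    await EngineManager.instance()?.get(current.engine.toString())?.unloadModel(current)
  }
  // Then load the next model through its own engine. loadModel resolves once the
  // engine has emitted ModelEvent.OnModelReady, so the two steps run in sequence.
  const engine = EngineManager.instance()?.get(next.engine.toString())
  if (!engine) throw new Error(`no engine registered for provider ${next.engine}`)
  await engine.loadModel(next)
}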
core/src/extensions/engines/LocalOAIEngine.ts
@@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@@ -26,10 +26,10 @@
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return

const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
@@ -42,24 +42,22 @@
)

if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()

executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
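Since loadModel now rejects with the error returned from the main process (in addition to emitting ModelEvent.OnModelFail), a caller that awaits it can surface failures with an ordinary try/catch. A short sketch under the same import assumption as above; loadOrReport is a hypothetical helper:

import { EngineManager, Model } from '@janhq/core'

async function loadOrReport(model: Model): Promise<boolean> {
  const engine = EngineManager.instance()?.get(model.engine.toString())
  if (!engine) return false
  try {
    // Resolves after the node module reports success and OnModelReady is emitted.
    await engine.loadModel(model)
    return true
  } catch (error) {
    // The rejection carries res.error from executeOnMain; OnModelFail was emitted too.
    console.error('model failed to load:', error)
    return false
  }
}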
core/src/extensions/engines/OAIEngine.ts
@@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@@ -43,12 +43,12 @@
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}

/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return

const timestamp = Date.now()
@@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
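OAIEngine keeps routing inference through the event bus: OnMessageSent starts a streamed completion on the engine whose provider matches the request, and OnInferenceStopped aborts it. A minimal sketch of the calling side; the import path and the empty payload for the stop event are assumptions:

import { events, InferenceEvent, MessageEvent, MessageRequest } from '@janhq/core'

// Start inference on whichever registered engine matches request.model.engine.
function sendInference(request: MessageRequest) {
  events.emit(MessageEvent.OnMessageSent, request)
}

// Cancel the in-flight request; OAIEngine.stopInference() aborts its controller.
function cancelInference() {
  events.emit(InferenceEvent.OnInferenceStopped, {})
}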
26 changes: 26 additions & 0 deletions core/src/extensions/engines/RemoteOAIEngine.ts
@@ -0,0 +1,26 @@
import { OAIEngine } from './OAIEngine'

/**
* Base OAI Remote Inference Provider
* Adds the authorization headers for inference requests (applicable to remote inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The API key for the remote inference endpoint
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
override onLoad() {
super.onLoad()
}

/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
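headers() covers both the Bearer-token and the api-key conventions, so OpenAI-style and Azure-style endpoints accept the same request. A sketch of how a transport might attach them; the endpoint URL, request body, and postCompletion helper are placeholders, and the real request logic lives in OAIEngine rather than here:

import { RemoteOAIEngine } from '@janhq/core'

async function postCompletion(engine: RemoteOAIEngine, body: unknown): Promise<unknown> {
  const response = await fetch('https://api.example.com/v1/chat/completions', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      // Merges in Authorization: Bearer <apiKey> and api-key: <apiKey>.
      ...(engine.headers() as Record<string, string>),
    },
    body: JSON.stringify(body),
  })
  if (!response.ok) throw new Error(`inference request failed: ${response.status}`)
  return response.json()
}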
File renamed without changes.
core/src/extensions/engines/index.ts
@@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
export * from './EngineManager'
2 changes: 1 addition & 1 deletion core/src/extensions/index.ts
@@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
export * from './ai-engines'
export * from './engines'
7 changes: 3 additions & 4 deletions extensions/inference-nitro-extension/src/index.ts
@@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}

override unloadModel(model: Model): void {
super.unloadModel(model)

if (model.engine && model.engine !== this.provider) return
override async unloadModel(model?: Model) {
if (model?.engine && model.engine !== this.provider) return

// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
return super.unloadModel(model)
}
}