Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: type-safe invoke implementation #26

Merged
merged 5 commits into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/desktop/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ fn main() {

Ok(())
})
// NOTE: New cmd should be added to invoke/_shared.ts
.invoke_handler(tauri::generate_handler![
config::get_config,
path::read_directory,
Expand Down
6 changes: 3 additions & 3 deletions apps/desktop/src/features/inference-server/model-config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import {
} from "@localai/ui/select"
import { modelTypeList } from "@models/index"
import { TrashIcon } from "@radix-ui/react-icons"
import { invoke } from "@tauri-apps/api/tauri"
import { useState } from "react"

import { InvokeCommand, invoke } from "~features/invoke"
import { DownloadProgress } from "~features/model-downloader/download-progress"
import { DownloadState } from "~features/model-downloader/use-model-download"
import { useGlobal } from "~providers/global"
Expand All @@ -28,7 +28,7 @@ const TestModelButton = () => {
isSpinning={isTesting}
onClick={async () => {
setIsTesting(true)
await invoke("test_model", {
await invoke(InvokeCommand.TestModel, {
...model,
modelType
})
Expand Down Expand Up @@ -70,7 +70,7 @@ export const ModelConfig = () => {
)}
Icon={TrashIcon}
onClick={async () => {
await invoke("delete_model_file", {
await invoke(InvokeCommand.DeleteModelFile, {
path: model.path
})

Expand Down
17 changes: 6 additions & 11 deletions apps/desktop/src/features/inference-server/model-digest.tsx
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import { cn } from "@localai/theme/utils"
import { SpinnerButton } from "@localai/ui/button"
import { CrossCircledIcon, ReloadIcon } from "@radix-ui/react-icons"
import { invoke } from "@tauri-apps/api/tauri"
import { ShoppingCodeCheck } from "iconoir-react"
import { useEffect, useState } from "react"

import { InitState, useInit } from "~features/inference-server/use-init"
import { InvokeCommand, invoke } from "~features/invoke"
import type { ModelIntegrity } from "~features/invoke/model-integrity"
import type { ModelMetadata } from "~features/model-downloader/model-file"
import { DownloadState } from "~features/model-downloader/use-model-download"
import { useModel } from "~providers/model"

type ModelDigest = {
md5: string
sha256: string
blake3: string
}

export const getTruncatedHash = (hashValue: string) =>
`${hashValue.slice(0, 4)}...${hashValue.slice(-7)}`

Expand All @@ -31,13 +26,13 @@ const HashDisplay = ({ hashType = "", hashValue = "", truncated = false }) => {
}

export const getCachedIntegrity = async (path: string) =>
invoke<ModelDigest>("get_cached_integrity", {
invoke(InvokeCommand.GetCachedIntegrity, {
path
}).catch<ModelDigest>(() => null)
}).catch<ModelIntegrity>(() => null)

export function ModelDigest({ model }: { model: ModelMetadata }) {
const { downloadState } = useModel()
const [digestHash, setDigestHash] = useState<ModelDigest>(null)
const [digestHash, setDigestHash] = useState<ModelIntegrity>(null)
const [isCalculating, setIsCalculating] = useState(false)
const [showDetail, setShowDetail] = useState(false)
const { initState } = useInit(async () => {
Expand All @@ -55,7 +50,7 @@ export function ModelDigest({ model }: { model: ModelMetadata }) {
setDigestHash(null)
setIsCalculating(true)
try {
const resp = await invoke<ModelDigest>("compute_model_integrity", {
const resp = await invoke(InvokeCommand.ComputeModelIntegrity, {
path: model.path
})
setDigestHash(resp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import { invoke } from "@tauri-apps/api/tauri"
import { useReducer } from "react"

import { useInit } from "~features/inference-server/use-init"
import { InvokeCommand, invoke } from "~features/invoke"
import type { ModelMetadata } from "~features/model-downloader/model-file"

export type ModelStats = {
loadCount: number
}

type LaunchCountAction = { type: "initialize" | "increment"; payload?: number }

function launchCounter(state: number, action: LaunchCountAction) {
Expand All @@ -24,7 +20,7 @@ export const useModelStats = (model: ModelMetadata) => {
const [launchCount, dispatch] = useReducer(launchCounter, 0)

useInit(async () => {
const resp = await invoke<ModelStats>("get_model_stats", {
const resp = await invoke(InvokeCommand.GetModelStats, {
path: model.path
}).catch<null>(() => null)
if (resp) {
Expand Down
6 changes: 3 additions & 3 deletions apps/desktop/src/features/inference-server/use-model-type.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import { ModelType } from "@models/index"
import { invoke } from "@tauri-apps/api/tauri"
import { useState } from "react"

import { useInit } from "~features/inference-server/use-init"
import { InvokeCommand, invoke } from "~features/invoke"
import type { ModelMetadata } from "~features/model-downloader/model-file"

export function setModelType(model: ModelMetadata, modelType: ModelType) {
return invoke("set_model_type", {
return invoke(InvokeCommand.SetModelType, {
path: model.path,
modelType
})
}

export async function getModelType(model: ModelMetadata) {
return invoke<ModelType>("get_model_type", {
return invoke(InvokeCommand.GetModelType, {
path: model.path
}).catch<null>(() => null)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { useCallback, useMemo, useState } from "react"

import { useInit } from "~features/inference-server/use-init"
import type {
DirectoryState,
ModelMetadata
} from "~features/model-downloader/model-file"
import { InvokeCommand, invoke } from "~features/invoke"
import type { ModelMetadata } from "~features/model-downloader/model-file"

export const useModelsDirectory = () => {
const [modelsDirectory, setModelsDirectory] = useState("")
Expand All @@ -13,8 +11,7 @@ export const useModelsDirectory = () => {

useInit(async () => {
// get the models directory saved in config
const { invoke } = await import("@tauri-apps/api/tauri")
const resp = await invoke<DirectoryState>("initialize_models_dir")
const resp = await invoke(InvokeCommand.InitializeModelsDir)
if (!resp) {
return
}
Expand All @@ -25,8 +22,7 @@ export const useModelsDirectory = () => {
const updateModelsDirectory = useCallback(
async (dir = modelsDirectory) => {
setIsRefreshing(true)
const { invoke } = await import("@tauri-apps/api/tauri")
const resp = await invoke<DirectoryState>("update_models_dir", {
const resp = await invoke(InvokeCommand.UpdateModelsDir, {
dir
})
setModelsDirectory(resp.path)
Expand Down
3 changes: 3 additions & 0 deletions apps/desktop/src/features/invoke/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Improvement on Tauri's invoke

- For each cmd name, pair it with a defined set of input/output pairs and look it up from the invoke call
36 changes: 36 additions & 0 deletions apps/desktop/src/features/invoke/_shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
export type InvokeIO<Input = Record<string, any>, Output = any> = {
input: Input
output: Output
}

// This should match up with the list of command in apps/desktop/src-tauri/src/main.rs
export enum InvokeCommand {
GetConfig = "get_config",
ReadDirectory = "read_directory",
WriteFile = "write_file",
ReadFile = "read_file",
AppendThreadContent = "append_thread_content",
ReadThreadFile = "read_thread_file",
InitializeThreadsDir = "initialize_threads_dir",
UpdateThreadsDir = "update_threads_dir",
DeleteThreadFile = "delete_thread_file",
CreateThreadFile = "create_thread_file",
RenameThreadFile = "rename_thread_file",
UpdateModelsDir = "update_models_dir",
InitializeModelsDir = "initialize_models_dir",
DeleteModelFile = "delete_model_file",
GetDownloadProgress = "get_download_progress",
StartDownload = "start_download",
PauseDownload = "pause_download",
ResumeDownload = "resume_download",
GetCachedIntegrity = "get_cached_integrity",
ComputeModelIntegrity = "compute_model_integrity",
GetModelStats = "get_model_stats",
StartServer = "start_server",
StopServer = "stop_server",
LoadModel = "load_model",
GetModelType = "get_model_type",
SetModelType = "set_model_type",
TestModel = "test_model",
OpenDirectory = "open_directory"
}
41 changes: 41 additions & 0 deletions apps/desktop/src/features/invoke/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import type { ModelType } from "@models/_shared"

import type { ModelDownloaderCommandMap } from "~features/invoke/model-downloader"
import type { ModelIntegrityCommandMap } from "~features/invoke/model-integrity"
import type { ModelStatsCommandMap } from "~features/invoke/model-stats"
import type { ThreadCommandMap } from "~features/invoke/thread"
import type { DirectoryState } from "~features/model-downloader/model-file"

import { InvokeCommand, type InvokeIO } from "./_shared"

export { InvokeCommand }

type InvokeCommandMap = {
[commands in InvokeCommand]: InvokeIO
} & {
[InvokeCommand.OpenDirectory]: InvokeIO<{ path: string }>
[InvokeCommand.GetConfig]: InvokeIO<{ key: string }, string>

[InvokeCommand.InitializeModelsDir]: InvokeIO<never, DirectoryState>
[InvokeCommand.UpdateModelsDir]: InvokeIO<{ dir: string }, DirectoryState>
[InvokeCommand.DeleteModelFile]: InvokeIO<{ path: string }>

[InvokeCommand.GetModelType]: InvokeIO<{ path: string }, ModelType>
[InvokeCommand.SetModelType]: InvokeIO<{ path: string; modelType: ModelType }>

[InvokeCommand.StartServer]: InvokeIO<{ port: number }, string>
[InvokeCommand.StopServer]: InvokeIO<never, string>
} & ModelIntegrityCommandMap &
ModelStatsCommandMap &
ModelDownloaderCommandMap &
ThreadCommandMap

export async function invoke<T extends InvokeCommand>(
cmd: T,
...[args]: InvokeCommandMap[T]["input"] extends never
? []
: [InvokeCommandMap[T]["input"]]
) {
const { invoke: _invoke } = await import("@tauri-apps/api/tauri")
return _invoke<InvokeCommandMap[T]["output"]>(cmd, args)
}
33 changes: 33 additions & 0 deletions apps/desktop/src/features/invoke/model-downloader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import type { ModelInfo } from "@models/_shared"

import type { InvokeCommand, InvokeIO } from "~features/invoke/_shared"

export enum DownloadState {
None = "none",
Idle = "idle",
Downloading = "downloading",
Validating = "validating",
Completed = "completed",
Errored = "errored"
}

export type ProgressData = {
eventId: string
progress: number
size: number
downloadState: DownloadState

digest?: String
error?: string
}

export type ModelDownloaderCommandMap = {
[InvokeCommand.StartDownload]: InvokeIO<
Pick<ModelInfo, "name" | "downloadUrl" | "modelType"> & {
digest: string
}
>
[InvokeCommand.GetDownloadProgress]: InvokeIO<{ path: string }, ProgressData>
[InvokeCommand.PauseDownload]: InvokeIO<{ path: string }, ProgressData>
[InvokeCommand.ResumeDownload]: InvokeIO<{ path: string }, ProgressData>
}
14 changes: 14 additions & 0 deletions apps/desktop/src/features/invoke/model-integrity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import type { InvokeCommand, InvokeIO } from "~features/invoke/_shared"

export type ModelIntegrity = {
sha256: string
blake3: string
}

export type ModelIntegrityCommandMap = {
[InvokeCommand.GetCachedIntegrity]: InvokeIO<{ path: string }, ModelIntegrity>
[InvokeCommand.ComputeModelIntegrity]: InvokeIO<
{ path: string },
ModelIntegrity
>
}
9 changes: 9 additions & 0 deletions apps/desktop/src/features/invoke/model-stats.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import type { InvokeCommand, InvokeIO } from "~features/invoke/_shared"

export type ModelStats = {
loadCount: number
}

export type ModelStatsCommandMap = {
[InvokeCommand.GetModelStats]: InvokeIO<{ path: string }, number>
}
24 changes: 24 additions & 0 deletions apps/desktop/src/features/invoke/thread.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import type { InvokeCommand, InvokeIO } from "~features/invoke/_shared"
import type { DirectoryState } from "~features/model-downloader/model-file"

export type ThreadCommandMap = {
[InvokeCommand.AppendThreadContent]: InvokeIO<
{
path: string
content: string
},
string
>
[InvokeCommand.ReadThreadFile]: InvokeIO<{
path: string
eventId: string
}>
[InvokeCommand.InitializeThreadsDir]: InvokeIO<never, DirectoryState>
[InvokeCommand.UpdateThreadsDir]: InvokeIO<{ dir: string }, DirectoryState>
[InvokeCommand.DeleteThreadFile]: InvokeIO<{ path: string }>
[InvokeCommand.RenameThreadFile]: InvokeIO<
{ path: string; newName: string },
string
>
[InvokeCommand.CreateThreadFile]: InvokeIO<never, string>
}
4 changes: 2 additions & 2 deletions apps/desktop/src/features/layout/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import { Button } from "@localai/ui/button"
import { AppLayout } from "@localai/ui/layouts/app"
import { DotsHorizontalIcon, OpenInNewWindowIcon } from "@radix-ui/react-icons"
import { open as dialogOpen } from "@tauri-apps/api/dialog"
import { invoke } from "@tauri-apps/api/tauri"
import { Home } from "iconoir-react"
import type { ReactNode } from "react"

import { InvokeCommand, invoke } from "~features/invoke"
import { NavButton } from "~features/layout/nav-button"
import { NewThreadButton } from "~features/thread/new-thread"
import { ChatSideBar } from "~features/thread/side-bar"
Expand Down Expand Up @@ -38,7 +38,7 @@ const TopBar = () => {
title="Open threads directory"
className="rounded-l-none"
onClick={() => {
invoke("open_directory", {
invoke(InvokeCommand.OpenDirectory, {
path: threadsDirectory
})
}}>
Expand Down
8 changes: 4 additions & 4 deletions apps/desktop/src/features/model-downloader/model-selector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import {
SelectValue
} from "@localai/ui/select"
import { DownloadIcon } from "@radix-ui/react-icons"
import { invoke } from "@tauri-apps/api/tauri"
import { useMemo, useState } from "react"
import Balancer from "react-wrap-balancer"

import { getTruncatedHash } from "~features/inference-server/model-digest"
import { InvokeCommand, invoke } from "~features/invoke"
import { toGB } from "~features/model-downloader/model-file"
import { useModelsApi } from "~features/model-downloader/use-models-api"
import { useGlobal } from "~providers/global"
Expand Down Expand Up @@ -111,11 +111,11 @@ export const ModelSelector = () => {
onClick={async () => {
setIsDownloading(true)
try {
await invoke("start_download", {
await invoke(InvokeCommand.StartDownload, {
name: selectedModel.name,
downloadUrl: selectedModel.downloadUrl,
digest: selectedModel.blake3,
modelType: selectedModel.modelType
modelType: selectedModel.modelType,
digest: selectedModel.blake3
})
} catch (error) {
alert(error)
Expand Down
Loading