Skip to content

Commit

Permalink
🚑 feat: 优先使用本地模型文件,提高使用体验
Browse files Browse the repository at this point in the history
  • Loading branch information
rdmclin2 committed Jun 3, 2024
1 parent 83b35ee commit a564806
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 50 deletions.
10 changes: 3 additions & 7 deletions src/app/chat/ViewerMode/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
'use client';

import classNames from 'classnames';
import { isEqual } from 'lodash-es';
import React, { memo } from 'react';
import { Flexbox } from 'react-layout-kit';

Expand All @@ -11,7 +10,7 @@ import Alert from '@/features/Alert';
import ChatDialog from '@/features/ChatDialog';
import MessageInput from '@/features/ChatInput/MessageInput';
import { useGlobalStore } from '@/store/global';
import { sessionSelectors, useSessionStore } from '@/store/session';
import { useSessionStore } from '@/store/session';

import { useStyles } from './style';

Expand All @@ -21,15 +20,12 @@ export default memo(() => {
s.setChatDialog,
]);
const { styles } = useStyles();
const [currentAgent] = useSessionStore((s) => [sessionSelectors.currentAgent(s), isEqual]);
const activeId = useSessionStore((s) => s.activeId);

return (
<Flexbox flex={1} style={{ position: 'relative' }}>
<div className={styles.viewer}>
<AgentViewer
height={`calc(100vh - ${HEADER_HEIGHT}px)`}
modelUrl={currentAgent?.meta.model}
/>
<AgentViewer height={`calc(100vh - ${HEADER_HEIGHT}px)`} agentId={activeId} />
</div>
{showChatDialog ? (
<ChatDialog className={classNames(styles.dialog, styles.content)} setOpen={setChatDialog} />
Expand Down
8 changes: 4 additions & 4 deletions src/features/AgentViewer/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ import ToolBar from './ToolBar';
import { useStyles } from './style';

interface Props {
agentId: string;
className?: string;
height?: number | string;
modelUrl?: string;
style?: React.CSSProperties;
width?: number | string;
}

function AgentViewer(props: Props) {
const { className, style, height, modelUrl, width } = props;
const { className, style, height, agentId, width } = props;
const { styles } = useStyles();
const ref = useRef<HTMLDivElement>(null);
const viewer = useGlobalStore((s) => s.viewer);

const { loading, loadVrm } = useLoadVrm(viewer);

useEffect(() => {
loadVrm(modelUrl);
}, [modelUrl]);
loadVrm(agentId);
}, [agentId]);

const canvasRef = useCallback(
(canvas: HTMLCanvasElement) => {
Expand Down
10 changes: 6 additions & 4 deletions src/features/MarketInfo/SubscribeButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Flexbox } from 'react-layout-kit';
import { agentSelectors, useAgentStore } from '@/store/agent';
import { Agent } from '@/types/agent';
import { fetchWithProgress } from '@/utils/fetch';
import { generateRemoteModelKey } from '@/utils/model';
import { getModelPathByAgentId } from '@/utils/model';
import { setItem } from '@/utils/storage';

interface SubscribeButtonProps {
Expand All @@ -31,8 +31,9 @@ const SubscribeButton = (props: SubscribeButtonProps) => {
disabled={downloading}
onClick={async () => {
if (isSubscribed) {
removeLocalAgent(agent.agentId);
message.success('已取消订阅');
removeLocalAgent(agent.agentId).then(() => {
message.success('已取消订阅');
});
} else {
if (agent.meta.model) {
setDownloading(true);
Expand All @@ -43,7 +44,8 @@ const SubscribeButton = (props: SubscribeButtonProps) => {
setPercent((loaded / total) * 100);
},
});
await setItem(generateRemoteModelKey(agent.agentId), blob);
const modelKey = getModelPathByAgentId(agent.agentId);
await setItem(modelKey, blob);
} catch (e) {
console.error(e);
message.error('下载失败');
Expand Down
38 changes: 21 additions & 17 deletions src/hooks/useLoadVrm.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
import { useState } from 'react';

import { Viewer } from '@/features/vrmViewer/viewer';
import { isModelKey } from '@/utils/model';
import { agentSelectors, useAgentStore } from '@/store/agent';
import { getModelPathByAgentId } from '@/utils/model';
import storage from '@/utils/storage';

export const useLoadVrm = (viewer: Viewer) => {
const [loading, setLoading] = useState(false);
const getAgentModelById = useAgentStore((s) => agentSelectors.getAgentModelById(s));

const loadVrm = async (url?: string) => {
let vrmUrl = url;
if (url && isModelKey(url)) {
const blob = await storage.getItem(url);
if (blob) {
vrmUrl = window.URL.createObjectURL(blob as Blob);
} else {
vrmUrl = undefined;
}
}
if (vrmUrl) {
setLoading(true);
viewer.loadVrm(vrmUrl).finally(() => {
setLoading(false);
});
} else {
const loadVrm = async (agentId: string) => {
// 获取模型路径
let vrmUrl = getAgentModelById(agentId);
// 如果没有模型路径,卸载模型
if (!vrmUrl) {
viewer.unloadVRM();
return;
}

// 根据 AgentId 获取本地模型数据
const blob = await storage.getItem(getModelPathByAgentId(agentId));

if (blob) {
vrmUrl = window.URL.createObjectURL(blob as Blob);
}

setLoading(true);
viewer.loadVrm(vrmUrl).finally(() => {
setLoading(false);
});
};

return { loading, loadVrm };
Expand Down
18 changes: 12 additions & 6 deletions src/panels/RolePanel/RoleEdit/Model/ViewerWithUpload/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,31 @@ import EmptyGuide from '@/components/EmptyGuide';
import { ROLE_VIEWER_HEIGHT, ROLE_VIEWER_WIDTH } from '@/constants/common';
import AgentViewer from '@/features/AgentViewer';
import { agentSelectors, useAgentStore } from '@/store/agent';
import { generateLocalModelKey } from '@/utils/model';
import { useGlobalStore } from '@/store/global';
import { getModelPathByAgentId } from '@/utils/model';
import storage from '@/utils/storage';

interface ViewerWithUploadProps {
style?: CSSProperties;
}

const ViewerWithUpload = memo<ViewerWithUploadProps>(({ style }) => {
const [currentAgentModel, updateAgentConfig] = useAgentStore((s) => [
const viewer = useGlobalStore((s) => s.viewer);

const [currentAgentId, currentAgentModel, updateAgentConfig] = useAgentStore((s) => [
agentSelectors.currentAgentId(s),
agentSelectors.currentAgentModel(s),
s.updateAgentConfig,
]);

const handleUploadAvatar = (file: any) => {
const { name, size } = file;
const blob = new Blob([file], { type: 'application/octet-stream' });
const modelKey = generateLocalModelKey(name, size);
const modelKey = getModelPathByAgentId(currentAgentId!);

storage.setItem(modelKey, blob).then(() => {
updateAgentConfig({ meta: { model: modelKey } });
const vrmUrl = window.URL.createObjectURL(blob as Blob);
viewer.loadVrm(vrmUrl);
});
};

Expand All @@ -36,10 +42,10 @@ const ViewerWithUpload = memo<ViewerWithUploadProps>(({ style }) => {
style={style}
openFileDialogOnClick={!currentAgentModel}
>
{currentAgentModel ? (
{currentAgentModel && currentAgentId ? (
<AgentViewer
height={ROLE_VIEWER_HEIGHT}
modelUrl={currentAgentModel}
agentId={currentAgentId}
width={ROLE_VIEWER_WIDTH}
/>
) : (
Expand Down
7 changes: 5 additions & 2 deletions src/store/agent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import { Agent, AgentMeta, GenderEnum } from '@/types/agent';
import { TouchAction, TouchAreaEnum } from '@/types/touch';
import { TTS } from '@/types/tts';
import { mergeWithUndefined } from '@/utils/common';
import { getModelPathByAgentId } from '@/utils/model';
import storage from '@/utils/storage';

import { initialState } from './initialState';
import { agentSelectors } from './selectors/agent';
Expand Down Expand Up @@ -77,7 +79,7 @@ export interface AgentStore {
* 移除本地角色
* @param agentId
*/
removeLocalAgent: (agentId: string) => void;
removeLocalAgent: (agentId: string) => Promise<void>;
/**
* 删除触摸配置
*/
Expand Down Expand Up @@ -272,7 +274,7 @@ const createAgentStore: StateCreator<AgentStore, [['zustand/devtools', never]]>
});
set({ localAgentList: newList });
},
removeLocalAgent: (agentId) => {
removeLocalAgent: async (agentId) => {
const { localAgentList } = get();
const newList = produce(localAgentList, (draft) => {
const index = draft.findIndex((item) => item.agentId === agentId);
Expand All @@ -281,6 +283,7 @@ const createAgentStore: StateCreator<AgentStore, [['zustand/devtools', never]]>
draft.splice(index, 1);
}
});
await storage.removeItem(getModelPathByAgentId(agentId));
set({ currentIdentifier: LOBE_VIDOL_DEFAULT_AGENT_ID, localAgentList: newList });
},
});
Expand Down
8 changes: 8 additions & 0 deletions src/store/agent/selectors/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ const currentAgentId = (s: AgentStore): string | undefined => {
return currentAgent.agentId;
};

const getAgentModelById = (s: AgentStore) => {
return (id: string): string | undefined => {
const agent = s.getAgentById(id);
return agent?.meta.model;
};
};

const isDefaultAgent = (s: AgentStore) => {
return (id: string): boolean => {
const agent = s.getAgentById(id);
Expand All @@ -88,6 +95,7 @@ export const agentSelectors = {
currentAgentTTS,
currentAgentTouch,
filterAgentListIds,
getAgentModelById,
agentListIds,
isDefaultAgent,
showSideBar,
Expand Down
3 changes: 0 additions & 3 deletions src/store/session/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ const getAgentById = (s: SessionStore) => {
const currentAgent = (s: SessionStore): Agent | undefined => {
const { activeId } = s;
const agentStore = useAgentStore.getState();
if (activeId === LOBE_VIDOL_DEFAULT_AGENT_ID) {
return agentStore.defaultAgent;
}
return agentStore.getAgentById(activeId);
};

Expand Down
10 changes: 3 additions & 7 deletions src/utils/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@ import storage from '@/utils/storage';

const MODEL_SCHEMA = 'model';

export const generateLocalModelKey = (name: string, size: number) => {
return `${MODEL_SCHEMA}://${name}-${size}`;
};

export const generateRemoteModelKey = (id: string) => {
export const getModelPathByAgentId = (id: string) => {
return `${MODEL_SCHEMA}://${id}`;
};

export const isModelKey = (key: string) => {
export const isModelPath = (key: string) => {
return key.startsWith(MODEL_SCHEMA);
};

export const checkLocalModel = async (agentId: string) => {
const key = generateRemoteModelKey(agentId);
const key = getModelPathByAgentId(agentId);
const model = await storage.getItem(key);
return !!model;
};

0 comments on commit a564806

Please sign in to comment.