diff --git a/README.md b/README.md index 2c5934d..3537c59 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ - [ ] 拖拽图片自动上传并生成节点 - 编组 - [ ] 框选节点编组 - - [ ] 编组相关基础功能 + - [x] 编组相关基础功能 - [ ] 局部 Flow 转换为组件 - 节点 - [ ] 中继节点,支持一个到多个变量中继 diff --git a/src/components/NodeComponent/InputComponent.tsx b/src/components/NodeComponent/SdNode/InputComponent.tsx similarity index 100% rename from src/components/NodeComponent/InputComponent.tsx rename to src/components/NodeComponent/SdNode/InputComponent.tsx diff --git a/src/components/NodeComponent/NodeImgPreview.tsx b/src/components/NodeComponent/SdNode/NodeImgPreview.tsx similarity index 92% rename from src/components/NodeComponent/NodeImgPreview.tsx rename to src/components/NodeComponent/SdNode/NodeImgPreview.tsx index 57fa0a5..6a5eff7 100644 --- a/src/components/NodeComponent/NodeImgPreview.tsx +++ b/src/components/NodeComponent/SdNode/NodeImgPreview.tsx @@ -1,4 +1,4 @@ -import { ImagePreview } from '@/components/NodeComponent/index' +import { ImagePreview } from '@/components/NodeComponent' import { emptyImg } from '@/components/theme' import { getBackendUrl } from '@/config' import { Image } from 'antd' diff --git a/src/components/NodeComponent/NodeInputs.tsx b/src/components/NodeComponent/SdNode/NodeInputs.tsx similarity index 94% rename from src/components/NodeComponent/NodeInputs.tsx rename to src/components/NodeComponent/SdNode/NodeInputs.tsx index e1f1258..360799b 100644 --- a/src/components/NodeComponent/NodeInputs.tsx +++ b/src/components/NodeComponent/SdNode/NodeInputs.tsx @@ -1,7 +1,7 @@ import React from 'react' import { Position } from 'reactflow' +import { SpaceCol } from '../style' import NodeSlot from './NodeSlot' -import { SpaceCol } from './style' interface NodeInputsProps { data: { diff --git a/src/components/NodeComponent/NodeOutpus.tsx b/src/components/NodeComponent/SdNode/NodeOutpus.tsx similarity index 93% rename from src/components/NodeComponent/NodeOutpus.tsx rename to src/components/NodeComponent/SdNode/NodeOutpus.tsx index 0c84ecf..af8b5fb 100644 --- a/src/components/NodeComponent/NodeOutpus.tsx +++ b/src/components/NodeComponent/SdNode/NodeOutpus.tsx @@ -1,7 +1,7 @@ import React from 'react' import { Position } from 'reactflow' +import { SpaceCol } from '../style' import NodeSlot from './NodeSlot' -import { SpaceCol } from './style' interface NodeOutpusProps { data: string[] diff --git a/src/components/NodeComponent/NodeParams.tsx b/src/components/NodeComponent/SdNode/NodeParams.tsx similarity index 92% rename from src/components/NodeComponent/NodeParams.tsx rename to src/components/NodeComponent/SdNode/NodeParams.tsx index 1d4bd90..cc07621 100644 --- a/src/components/NodeComponent/NodeParams.tsx +++ b/src/components/NodeComponent/SdNode/NodeParams.tsx @@ -1,5 +1,5 @@ import { InputComponent } from '@/components' -import NodeSlot from '@/components/NodeComponent/NodeSlot' +import NodeSlot from '@/components/NodeComponent/SdNode/NodeSlot' import { Flow } from '@/types' import React from 'react' import { Position } from 'reactflow' diff --git a/src/components/NodeComponent/NodeSlot.tsx b/src/components/NodeComponent/SdNode/NodeSlot.tsx similarity index 97% rename from src/components/NodeComponent/NodeSlot.tsx rename to src/components/NodeComponent/SdNode/NodeSlot.tsx index 3979113..78459d6 100644 --- a/src/components/NodeComponent/NodeSlot.tsx +++ b/src/components/NodeComponent/SdNode/NodeSlot.tsx @@ -3,7 +3,7 @@ import { isArray, startCase } from 'lodash-es' import React from 'react' import { Handle, HandleType, Position } from 'reactflow' import { shallow } from 'zustand/shallow' -import { Slot } from './style' +import { Slot } from '../style' interface NodeSlotProps { label: string diff --git a/src/components/NodeComponent/PreviewNode.tsx b/src/components/NodeComponent/SdNode/PreviewNode.tsx similarity index 96% rename from src/components/NodeComponent/PreviewNode.tsx rename to src/components/NodeComponent/SdNode/PreviewNode.tsx index e487028..e1ef0fc 100644 --- a/src/components/NodeComponent/PreviewNode.tsx +++ b/src/components/NodeComponent/SdNode/PreviewNode.tsx @@ -3,7 +3,7 @@ import { checkInput } from '@/utils' import { startCase } from 'lodash-es' import React from 'react' import styled from 'styled-components' -import { NodeCard, SpaceCol, SpaceGrid } from './style' +import { NodeCard, SpaceCol, SpaceGrid } from '../style' const Slot = styled.div<{ isRequired: 1 | 0 }>` background: ${({ isRequired, theme }) => (isRequired ? theme.colorPrimary : theme.colorBorder)}; diff --git a/src/components/NodeComponent/SelectUploadInput.tsx b/src/components/NodeComponent/SdNode/SelectUploadInput.tsx similarity index 100% rename from src/components/NodeComponent/SelectUploadInput.tsx rename to src/components/NodeComponent/SdNode/SelectUploadInput.tsx diff --git a/src/components/NodeComponent/SliderInput.tsx b/src/components/NodeComponent/SdNode/SliderInput.tsx similarity index 100% rename from src/components/NodeComponent/SliderInput.tsx rename to src/components/NodeComponent/SdNode/SliderInput.tsx diff --git a/src/components/NodeComponent/SdNode/index.tsx b/src/components/NodeComponent/SdNode/index.tsx new file mode 100644 index 0000000..e1a827b --- /dev/null +++ b/src/components/NodeComponent/SdNode/index.tsx @@ -0,0 +1,62 @@ +import { SpaceGrid } from '@/components/NodeComponent/style' +import { useAppStore } from '@/store' +import { Widget } from '@/types' +import { checkInput } from '@/utils' +import React from 'react' +import { NodeProps } from 'reactflow' +import { shallow } from 'zustand/shallow' +import NodeImgPreview from './NodeImgPreview' +import NodeInputs from './NodeInputs' +import NodeOutpus from './NodeOutpus' +import NodeParams from './NodeParams' + +const SdNode: React.FC> = (node) => { + const { imagePreviews, inputImgPreviews } = useAppStore( + (st) => ({ + imagePreviews: st.graph?.[node.id]?.images + ?.map((image, index) => { + return { + image, + index, + } + }) + .filter(Boolean), + inputImgPreviews: [ + { + image: { + filename: st.onGetNodeFieldsData(node.id, 'image'), + type: 'input', + }, + index: 0, + }, + ].filter((i) => i.image.filename), + onPropChange: st.onPropChange, + }), + shallow + ) + + const params: any[] = [] + const inputs: any[] = [] + const outputs: any[] = node.data.output + + for (const [property, input] of Object.entries(node.data.input.required)) { + if (checkInput.isParameterOrList(input)) { + params.push({ name: property, type: input[0], input }) + } else { + inputs.push({ name: property, type: input[0] }) + } + } + + return ( + <> + + + + + + + + ) +} + +export default React.memo(SdNode) diff --git a/src/components/NodeComponent/index.tsx b/src/components/NodeComponent/index.tsx index d8ead41..b384b2f 100644 --- a/src/components/NodeComponent/index.tsx +++ b/src/components/NodeComponent/index.tsx @@ -2,19 +2,15 @@ import { ActionIcon, Input } from '@/components' import { ColorMenu, colorList } from '@/components/NodeComponent/ColorMenu' import { useAppStore } from '@/store' import { ImageItem, type Widget } from '@/types' -import { checkInput } from '@/utils' import { CopyOutlined, DeleteOutlined, EditOutlined, HighlightOutlined, MoreOutlined } from '@ant-design/icons' import { Dropdown, Progress, type MenuProps } from 'antd' import { useTheme } from 'antd-style' -import { mix } from 'polished' -import React, { useState } from 'react' +import { mix, rgba } from 'polished' +import React, { useEffect, useRef, useState } from 'react' import { NodeResizeControl, type NodeProps } from 'reactflow' import { shallow } from 'zustand/shallow' -import NodeImgPreview from './NodeImgPreview' -import NodeInputs from './NodeInputs' -import NodeOutpus from './NodeOutpus' -import NodeParams from './NodeParams' -import { NodeCard, SpaceGrid } from './style' +import SdNode from './SdNode' +import { GroupCard, NodeCard } from './style' export const NODE_IDENTIFIER = 'sdNode' @@ -24,26 +20,10 @@ export interface ImagePreview { } const NodeComponent: React.FC> = (node) => { - const { progressBar, imagePreviews, onDuplicateNode, onDeleteNode, onModifyChange, inputImgPreviews } = useAppStore( + const ref: any = useRef(null) + const { progressBar, onDuplicateNode, onDeleteNode, onModifyChange } = useAppStore( (st) => ({ progressBar: st.nodeInProgress?.id === node.id ? st.nodeInProgress.progress : undefined, - imagePreviews: st.graph?.[node.id]?.images - ?.map((image, index) => { - return { - image, - index, - } - }) - .filter(Boolean), - inputImgPreviews: [ - { - image: { - filename: st.onGetNodeFieldsData(node.id, 'image'), - type: 'input', - }, - index: 0, - }, - ].filter((i) => i.image.filename), onPropChange: st.onPropChange, onDuplicateNode: st.onDuplicateNode, onDeleteNode: st.onDeleteNode, @@ -54,21 +34,10 @@ const NodeComponent: React.FC> = (node) => { const theme = useTheme() const [nicknameInput, setNicknameInput] = useState(false) - const params: any[] = [] - const inputs: any[] = [] - const outputs: any[] = node.data.output const isInProgress = progressBar !== undefined const isSelected = node.selected const name = node.data?.nickname || node.data.name - - for (const [property, input] of Object.entries(node.data.input.required)) { - if (checkInput.isParameterOrList(input)) { - params.push({ name: property, type: input[0], input }) - } else { - inputs.push({ name: property, type: input[0] }) - } - } - + const isGroup = node.data.name === 'Group' const handleNickname = (e: any) => { const nickname = e.target.value onModifyChange(node.id, 'nickname', nickname) @@ -110,10 +79,26 @@ const NodeComponent: React.FC> = (node) => { }, ] + const StyledCard = isGroup ? GroupCard : NodeCard + let background + if (isGroup) { + background = node.data?.color ? rgba(node.data.color, 0.2) : theme.colorFill + } else { + background = node.data?.color ? mix(0.8, theme.colorBgContainer, node.data.color) : theme.colorBgContainer + } + + useEffect(() => { + if (isGroup) { + const parenet = ref.current?.parentNode + parenet.setAttribute('type', 'group') + } + }, []) + return ( - > = (node) => { ) } active={isInProgress || isSelected ? 1 : 0} - hoverable + hoverable={!isGroup} extra={ isInProgress ? progressBar > 0 && @@ -140,14 +125,9 @@ const NodeComponent: React.FC> = (node) => { ) } > - - - - - - - {isSelected && } - + + {isSelected && } + ) } diff --git a/src/components/NodeComponent/style.ts b/src/components/NodeComponent/style.ts index 50c6a4f..f7440e7 100644 --- a/src/components/NodeComponent/style.ts +++ b/src/components/NodeComponent/style.ts @@ -3,17 +3,42 @@ import { Position } from 'reactflow' import styled, { css } from 'styled-components' export const NodeCard = styled(Card)<{ active: 1 | 0 }>` - min-width: 240px; box-shadow: ${({ theme }) => theme.boxShadowTertiary}; + min-width: 80px; + min-height: 120px; ${({ active, theme }) => active ? css` outline: 2px solid ${theme.colorPrimary}; ` - : ''} + : ''}; .ant-card-head { background: ${({ theme }) => theme.colorFillQuaternary} !important; padding-right: 3px; + border-bottom: unset; + height: 15px; + } +` + +export const GroupCard = styled(Card)<{ active: 1 | 0 }>` + min-width: 80px; + min-height: 120px; + height: 100%; + ${({ active, theme }) => + active + ? css` + outline: 2px solid ${theme.colorPrimary}; + ` + : ''}; + .react-flow__resize-control { + pointer-events: all !important; + } + .ant-card-head { + background: ${({ theme }) => theme.colorFillQuaternary} !important; + padding-right: 3px; + border-bottom: unset; + height: 15px; + pointer-events: all !important; } ` diff --git a/src/components/index.ts b/src/components/index.ts index 2d91c9b..53519b2 100644 --- a/src/components/index.ts +++ b/src/components/index.ts @@ -6,5 +6,5 @@ export { default as WorkflowPageComponent } from './ControlPanelComponent/Workfl export * from './EditorComponent' export { default as Header } from './Header' export { NODE_IDENTIFIER, default as NodeComponent } from './NodeComponent' -export { default as InputComponent } from './NodeComponent/InputComponent' -export { default as PreviewNode } from './NodeComponent/PreviewNode' +export { default as InputComponent } from './NodeComponent/SdNode/InputComponent' +export { default as PreviewNode } from './NodeComponent/SdNode/PreviewNode' diff --git a/src/layouts/GlobalStyle.ts b/src/layouts/GlobalStyle.ts index f44399a..88c7f87 100644 --- a/src/layouts/GlobalStyle.ts +++ b/src/layouts/GlobalStyle.ts @@ -15,6 +15,11 @@ const GlobalStyle = createGlobalStyle` font-family: 'Hack', 'IBM Plex Mono', 'ui-monospace', 'Consolas', monospace !important; } + .react-flow__node[type="group"] { + z-index: -1 !important; + pointer-events: none !important; + } + .react-json-view { background: transparent !important; diff --git a/src/pages/FlowEditor.tsx b/src/pages/FlowEditor.tsx index 5c2986d..4928d09 100644 --- a/src/pages/FlowEditor.tsx +++ b/src/pages/FlowEditor.tsx @@ -3,6 +3,7 @@ import { useAppStore } from '@/store' import { getPostion, getPostionCenter } from '@/utils' import { Connection } from '@reactflow/core/dist/esm/types' import { Edge } from '@reactflow/core/dist/esm/types/edges' +import { NodeDragHandler } from '@reactflow/core/dist/esm/types/nodes' import { useTheme } from 'antd-style' import React, { useCallback, useEffect, useRef, useState } from 'react' import ReactFlow, { Background, BackgroundVariant, Controls, MiniMap } from 'reactflow' @@ -17,21 +18,32 @@ const FlowEditor: React.FC = () => { const reactFlowRef: any = useRef(null) const edgeUpdateSuccessful = useRef(true) const [reactFlowInstance, setReactFlowInstance] = useState(null) - const { nodes, edges, onNodesChange, onEdgesChange, onConnect, onInit, onAddNode, onCopyNode, onPasteNode } = - useAppStore( - (st) => ({ - nodes: st.nodes, - edges: st.edges, - onNodesChange: st.onNodesChange, - onEdgesChange: st.onEdgesChange, - onConnect: st.onConnect, - onInit: st.onInit, - onAddNode: st.onAddNode, - onCopyNode: st.onCopyNode, - onPasteNode: st.onPasteNode, - }), - shallow - ) + const { + nodes, + edges, + onNodesChange, + onEdgesChange, + onConnect, + onInit, + onAddNode, + onCopyNode, + onPasteNode, + onSetNodesGroup, + } = useAppStore( + (st) => ({ + nodes: st.nodes, + edges: st.edges, + onInit: st.onInit, + onNodesChange: st.onNodesChange, + onEdgesChange: st.onEdgesChange, + onConnect: st.onConnect, + onAddNode: st.onAddNode, + onCopyNode: st.onCopyNode, + onPasteNode: st.onPasteNode, + onSetNodesGroup: st.onSetNodesGroup, + }), + shallow + ) const onEdgeUpdateStart = useCallback(() => { edgeUpdateSuccessful.current = false @@ -79,6 +91,18 @@ const FlowEditor: React.FC = () => { [reactFlowInstance] ) + const onNodeDrag: NodeDragHandler = useCallback( + (_, node, nodes) => { + if (nodes.length > 2 || node.data.name !== 'Group') return + const intersections = reactFlowInstance + .getIntersectingNodes(node) + .filter((n: any) => n.data.name !== 'Group' && (n.parentNode === node.id || !n.parentNode)) + .map((n: any) => n.id) + onSetNodesGroup(intersections, node) + }, + [reactFlowInstance] + ) + const handleCopy = useCallback(() => { const copyData = onCopyNode() navigator.clipboard.writeText(JSON.stringify(copyData)) @@ -123,14 +147,13 @@ const FlowEditor: React.FC = () => { edges={edges} fitView snapToGrid - snapGrid={[24, 24]} + snapGrid={[20, 20]} minZoom={0.05} nodeTypes={nodeTypes} deleteKeyCode={['Delete', 'Backspace']} multiSelectionKeyCode={['Shift']} panOnScroll={!isWindows} zoomOnScroll={isWindows} - onlyRenderVisibleElements disableKeyboardA11y={true} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} @@ -138,6 +161,7 @@ const FlowEditor: React.FC = () => { onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateEnd={onEdgeUpdateEnd} onConnect={onConnect} + onNodeDrag={onNodeDrag} onDrop={onDrop} onDragOver={onDragOver} onInit={(e: any) => { diff --git a/src/store/AppState.ts b/src/store/AppState.ts index 37ceaa1..4c598fc 100644 --- a/src/store/AppState.ts +++ b/src/store/AppState.ts @@ -47,6 +47,7 @@ export interface AppState { /****************************************************** *********************** Node ************************* ******************************************************/ + onSetNodesGroup: (childIds: NodeId[], groupNode: Node) => void onNodesChange: OnNodesChange onUpdateNodes: (id: string, data: any) => void onAddNode: (nodeItem: NodeItem) => void diff --git a/src/store/index.ts b/src/store/index.ts index 4481f65..23b9f7c 100644 --- a/src/store/index.ts +++ b/src/store/index.ts @@ -1,4 +1,5 @@ import { createPrompt, deleteFromQueue, getWidgetLibrary as getWidgets, sendPrompt } from '@/client' +import { Widget, WidgetKey } from '@/types' import { PersistedGraph, addConnection, @@ -21,6 +22,16 @@ import { devtools } from 'zustand/middleware' import { AppState } from './AppState' export * from './AppState' +const defaultWidgets: Record = { + Group: { + input: { required: {} }, + output: [], + output_name: [], + name: 'Group', + category: 'Utils', + }, +} + export const useAppStore = create()( devtools((set, get) => ({ /****************************************************** @@ -58,19 +69,36 @@ export const useAppStore = create()( onRefresh: async () => { const widgets = await getWidgets() - set({ widgets }, false, 'onRefresh') + set({ widgets: { ...widgets, ...defaultWidgets } }, false, 'onRefresh') }, onInit: async () => { setInterval(() => get().onPersistTemp(), 5000) const widgets = await getWidgets() - set({ widgets }, false, 'onInit') + set({ widgets: { ...widgets, ...defaultWidgets } }, false, 'onInit') get().onLoadWorkflow(retrieveTempWorkflow() ?? { data: {}, connections: [] }) }, /****************************************************** *********************** Node ************************* ******************************************************/ + onSetNodesGroup: (childIds, groupNode) => { + set((st) => ({ + nodes: st.nodes.map((n) => { + if (childIds.includes(n.id)) { + if (n.parentNode === groupNode.id) return n + n.parentNode = groupNode.id + n.position.x = n.position.x - groupNode.position.x + n.position.y = n.position.y - groupNode.position.y + } else if (n.parentNode === groupNode.id) { + n.parentNode = undefined + n.position.x = n.position.x + groupNode.position.x + n.position.y = n.position.y + groupNode.position.y + } + return n + }), + })) + }, onNodesChange: (changes) => { set((st) => ({ nodes: applyNodeChanges(changes, st.nodes) }), false, 'onNodesChange') @@ -185,6 +213,7 @@ export const useAppStore = create()( }, onPasteNode: (workflow, position) => { + const nodes = get().nodes const basePositon = getTopLeftPoint(Object.values(workflow.data).map((item) => item.position)) const idMap: { [id: string]: string } = {} // 存储原始节点 id 和新节点 id 的映射关系 const newWorkflow: PersistedGraph = { @@ -202,10 +231,27 @@ export const useAppStore = create()( }, key: uuid(), // 使用 uuid 生成新的唯一标识符 } + if (node.parentNode) { + if (!Object.keys(workflow.data).includes(node.parentNode)) { + newNode.parentNode = undefined + const groupNode = nodes.find((n) => n.id === node.parentNode) + if (groupNode) { + newNode.position.x = newNode.position.x + groupNode.position.x + newNode.position.y = newNode.position.y + groupNode.position.y + } + } else { + newNode.position = node.position + } + } newWorkflow.data[newNode.key] = newNode idMap[id] = newNode.key // 记录原始节点 id 和新节点 id 的映射关系 }) + Object.keys(newWorkflow.data).forEach((key) => { + const parentNodeId = newWorkflow.data[key]?.parentNode + if (parentNodeId) newWorkflow.data[key].parentNode = idMap[parentNodeId] + }) + // 更新 connection 中的 source 和 target workflow.connections.forEach((conn) => { const newConn = { @@ -231,6 +277,7 @@ export const useAppStore = create()( position: node.position, width: node.width, height: node.height, + parentNode: node.parentNode, key, }, true @@ -368,6 +415,7 @@ export const useAppStore = create()( width: node.width, height: node.height, key: key, + parentNode: node.parentNode, }) } else { console.warn(`Unknown widget ${node.value.widget}`) diff --git a/src/types/node.ts b/src/types/node.ts index 7a101d8..303829b 100644 --- a/src/types/node.ts +++ b/src/types/node.ts @@ -22,6 +22,7 @@ export interface NodeItem { key?: string width?: number height?: number + parentNode?: string } // 3. SDNode 相关类型定义 diff --git a/src/utils/node.ts b/src/utils/node.ts index dcc280b..a27713a 100644 --- a/src/utils/node.ts +++ b/src/utils/node.ts @@ -7,7 +7,7 @@ import { fromWidget } from './widget' // 用于添加、更新和获取节点的函数 export function addNode( state: AppState, - { widget, node, position, width, height, key }: NodeItem, + { widget, node, position, width, height, key, parentNode }: NodeItem, isCopy?: boolean ): AppState { const nextKey = key ? key : uuid() @@ -26,6 +26,7 @@ export function addNode( zIndex: maxZ + 1, width, height, + parentNode, style: { width, height, diff --git a/src/utils/persistence.ts b/src/utils/persistence.ts index 83821e3..93576bd 100644 --- a/src/utils/persistence.ts +++ b/src/utils/persistence.ts @@ -10,6 +10,7 @@ export interface PersistedNode { position: Position width?: number height?: number + parentNode?: string } export interface PersistedGraph { @@ -35,6 +36,7 @@ export function toPersisted(state: AppState): PersistedGraph { data[node.id] = { value, position: node.position } if (node.width) data[node.id].width = node.width if (node.height) data[node.id].height = node.height + if (node.parentNode) data[node.id].parentNode = node.parentNode } } return {