From 2f56f2e422f1e5a1b73503c2a77265abae00cb64 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Wed, 20 Mar 2024 20:12:53 +0100 Subject: [PATCH] Automatically remove unused nodes with side effects --- src/common/nodes/optimize.ts | 84 ++++++++++++++++++++++ src/main/cli/run.ts | 15 ++-- src/renderer/contexts/ExecutionContext.tsx | 34 ++++----- 3 files changed, 105 insertions(+), 28 deletions(-) create mode 100644 src/common/nodes/optimize.ts diff --git a/src/common/nodes/optimize.ts b/src/common/nodes/optimize.ts new file mode 100644 index 000000000..9eb4d5647 --- /dev/null +++ b/src/common/nodes/optimize.ts @@ -0,0 +1,84 @@ +import { Edge, Node } from 'reactflow'; +import { EdgeData, NodeData } from '../common-types'; +import { SchemaMap } from '../SchemaMap'; +import { getEffectivelyDisabledNodes } from './disabled'; +import { getNodesWithSideEffects } from './sideEffect'; + +const trimEdges = ( + nodes: Iterable>, + edges: readonly Edge[] +): Edge[] => { + const nodeIds = new Set(); + for (const n of nodes) { + nodeIds.add(n.id); + } + + return edges.filter((e) => nodeIds.has(e.source) && nodeIds.has(e.target)); +}; + +const removeUnusedSideEffectNodes = ( + nodes: readonly Node[], + edges: readonly Edge[], + schemata: SchemaMap +): Node[] => { + // eslint-disable-next-line no-param-reassign + edges = trimEdges(nodes, edges); + + const connectedNodes = new Set([...edges.map((e) => e.source), ...edges.map((e) => e.target)]); + + return nodes.filter((n) => { + if (connectedNodes.has(n.id)) { + // the node isn't unused + return true; + } + + const schema = schemata.get(n.data.schemaId); + if (!schema.hasSideEffects) { + // we only care about nodes with side effects + return true; + } + + // if all inputs don't require connections, that's fine too + const requireConnection = schema.inputs.some((i) => i.kind === 'generic' && !i.optional); + if (!requireConnection) { + return true; + } + + // the is unused, has side effects, and requires connections + return false; + }); +}; + +interface OptimizedChain { + nodes: Node[]; + edges: Edge[]; + report: { + /** How many effectively disabled nodes were removed. */ + removedDisabled: number; + /** How many side-effect-free nodes were removed. */ + removedSideEffectFree: number; + }; +} + +export const optimizeChain = ( + unoptimizedNodes: readonly Node[], + unoptimizedEdges: readonly Edge[], + schemata: SchemaMap +): OptimizedChain => { + // remove disabled nodes + const disabledNodes = new Set(getEffectivelyDisabledNodes(unoptimizedNodes, unoptimizedEdges)); + const enabledNodes = unoptimizedNodes.filter((n) => !disabledNodes.has(n)); + + // remove nodes without side effects + let withEffect = getNodesWithSideEffects(enabledNodes, unoptimizedEdges, schemata); + withEffect = removeUnusedSideEffectNodes(withEffect, unoptimizedEdges, schemata); + + return { + nodes: withEffect, + edges: trimEdges(withEffect, unoptimizedEdges), + report: { + removedDisabled: disabledNodes.size, + removedSideEffectFree: enabledNodes.length - withEffect.length, + }, + }; +}; diff --git a/src/main/cli/run.ts b/src/main/cli/run.ts index 5c16928c1..136d2a987 100644 --- a/src/main/cli/run.ts +++ b/src/main/cli/run.ts @@ -7,10 +7,9 @@ import { applyOverrides, readOverrideFile } from '../../common/input-override'; import { log } from '../../common/log'; import { checkNodeValidity } from '../../common/nodes/checkNodeValidity'; import { getConnectedInputs } from '../../common/nodes/connectedInputs'; -import { getEffectivelyDisabledNodes } from '../../common/nodes/disabled'; import { ChainLineage } from '../../common/nodes/lineage'; +import { optimizeChain } from '../../common/nodes/optimize'; import { parseFunctionDefinitions } from '../../common/nodes/parseFunctionDefinitions'; -import { getNodesWithSideEffects } from '../../common/nodes/sideEffect'; import { toBackendJson } from '../../common/nodes/toBackendJson'; import { TypeState } from '../../common/nodes/TypeState'; import { SaveFile } from '../../common/SaveFile'; @@ -214,20 +213,14 @@ export const runChainInCli = async (args: RunArguments) => { applyOverrides(saveFile.nodes, saveFile.edges, schemata, overrideFile); } - const disabledNodes = new Set( - getEffectivelyDisabledNodes(saveFile.nodes, saveFile.edges).map((n) => n.id) - ); - const nodesToOptimize = saveFile.nodes.filter((n) => !disabledNodes.has(n.id)); - const nodes = getNodesWithSideEffects(nodesToOptimize, saveFile.edges, schemata); - const nodesById = new Map(nodes.map((n) => [n.id, n])); - const edges = saveFile.edges.filter((e) => nodesById.has(e.source) && nodesById.has(e.target)); + const { nodes, edges, report } = optimizeChain(saveFile.nodes, saveFile.edges, schemata); // show an error if there are no nodes to run if (nodes.length === 0) { let message; - if (nodesToOptimize.length > 0) { + if (report.removedSideEffectFree > 0) { message = 'There are no nodes that have an effect. Try to view or output images/files.'; - } else if (disabledNodes.size > 0) { + } else if (report.removedDisabled > 0) { message = 'All nodes are disabled. There are no nodes to run.'; } else { message = 'There are no nodes to run.'; diff --git a/src/renderer/contexts/ExecutionContext.tsx b/src/renderer/contexts/ExecutionContext.tsx index 1bbed3d45..efd8b4ad1 100644 --- a/src/renderer/contexts/ExecutionContext.tsx +++ b/src/renderer/contexts/ExecutionContext.tsx @@ -9,8 +9,7 @@ import { log } from '../../common/log'; import { checkFeatures } from '../../common/nodes/checkFeatures'; import { checkNodeValidity } from '../../common/nodes/checkNodeValidity'; import { getConnectedInputs } from '../../common/nodes/connectedInputs'; -import { getEffectivelyDisabledNodes } from '../../common/nodes/disabled'; -import { getNodesWithSideEffects } from '../../common/nodes/sideEffect'; +import { optimizeChain } from '../../common/nodes/optimize'; import { toBackendJson } from '../../common/nodes/toBackendJson'; import { ipcRenderer } from '../../common/safeIpc'; import { getChainnerScope } from '../../common/types/chainner-scope'; @@ -139,8 +138,16 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> setManualOutputType, clearManualOutputTypes, } = useContext(GlobalContext); - const { schemata, url, backend, ownsBackend, restartingRef, features, featureStates } = - useContext(BackendContext); + const { + schemata, + url, + backend, + ownsBackend, + restartingRef, + features, + featureStates, + categories, + } = useContext(BackendContext); const { packageSettings } = useSettings(); const { sendAlert, sendToast } = useContext(AlertBoxContext); @@ -345,24 +352,15 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> }, [status, nodeChanges, edgeChanges, sendToast]); const runNodes = useCallback(async () => { - const allNodes = getNodes(); - const allEdges = getEdges(); - - const disabledNodes = new Set( - getEffectivelyDisabledNodes(allNodes, allEdges).map((n) => n.id) - ); - const nodesToOptimize = allNodes.filter((n) => !disabledNodes.has(n.id)); - const nodes = getNodesWithSideEffects(nodesToOptimize, allEdges, schemata); - const nodeIds = new Set(nodes.map((n) => n.id)); - const edges = allEdges.filter((e) => nodeIds.has(e.source) && nodeIds.has(e.target)); + const { nodes, edges, report } = optimizeChain(getNodes(), getEdges(), schemata); // show an error if there are no nodes to run if (nodes.length === 0) { let message; - if (nodesToOptimize.length > 0) { + if (report.removedSideEffectFree > 0) { message = 'There are no nodes that have an effect. Try to view or output images/files.'; - } else if (disabledNodes.size > 0) { + } else if (report.removedDisabled > 0) { message = 'All nodes are disabled. There are no nodes to run.'; } else { message = 'There are no nodes to run.'; @@ -375,7 +373,8 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> const invalidNodes = nodes.flatMap((node) => { const functionInstance = typeStateRef.current.functions.get(node.data.id); const schema = schemata.get(node.data.schemaId); - const { category, name } = schema; + const { name } = schema; + const category = categories.get(schema.category)?.name ?? schema.category; const validity = bothValid( checkFeatures(schema.features, features, featureStates), @@ -448,6 +447,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> getNodes, getEdges, schemata, + categories, sendAlert, typeStateRef, chainLineageRef,