diff --git a/backend/src/api/api.py b/backend/src/api/api.py index 1b20837fe..0fb044e6f 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -89,6 +89,12 @@ def __init__( ) self.length_type: navi.ExpressionJson = length_type + def to_dict(self): + return { + "inputs": self.inputs, + "lengthType": self.length_type, + } + class IteratorOutputInfo: def __init__( @@ -103,6 +109,12 @@ def __init__( ) self.length_type: navi.ExpressionJson = length_type + def to_dict(self): + return { + "outputs": self.outputs, + "lengthType": self.length_type, + } + @dataclass(frozen=True) class NodeData: diff --git a/backend/src/server.py b/backend/src/server.py index 199bc10dc..94d3163ac 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -126,6 +126,8 @@ async def nodes(_request: Request): "groupLayout": [ g.to_dict() if isinstance(g, Group) else g for g in node.group_layout ], + "iteratorInputs": [x.to_dict() for x in node.iterator_inputs], + "iteratorOutputs": [x.to_dict() for x in node.iterator_outputs], "description": node.description, "seeAlso": node.see_also, "icon": node.icon, diff --git a/src/common/SchemaMap.ts b/src/common/SchemaMap.ts index 2c9f83a27..ea1092da8 100644 --- a/src/common/SchemaMap.ts +++ b/src/common/SchemaMap.ts @@ -13,6 +13,8 @@ const BLANK_SCHEMA: NodeSchema = { inputs: [], outputs: [], groupLayout: [], + iteratorInputs: [], + iteratorOutputs: [], icon: '', category: '' as CategoryId, nodeGroup: '' as NodeGroupId, diff --git a/src/common/common-types.ts b/src/common/common-types.ts index 222500db0..9ab87aaac 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -259,6 +259,15 @@ export type OutputHeight = Readonly>; export type OutputTypes = Readonly>>; export type GroupState = Readonly>; +export interface IteratorInputInfo { + readonly inputs: readonly InputId[]; + readonly lengthType: ExpressionJson; +} +export interface IteratorOutputInfo { + readonly outputs: readonly OutputId[]; + readonly lengthType: ExpressionJson; +} + export interface NodeSchema { readonly name: string; readonly category: CategoryId; @@ -270,6 +279,8 @@ export interface NodeSchema { readonly inputs: readonly Input[]; readonly outputs: readonly Output[]; readonly groupLayout: readonly (InputId | Group)[]; + readonly iteratorInputs: readonly IteratorInputInfo[]; + readonly iteratorOutputs: readonly IteratorOutputInfo[]; readonly schemaId: SchemaId; readonly hasSideEffects: boolean; readonly deprecated: boolean; diff --git a/src/common/nodes/lineage.ts b/src/common/nodes/lineage.ts new file mode 100644 index 000000000..0e1bac203 --- /dev/null +++ b/src/common/nodes/lineage.ts @@ -0,0 +1,167 @@ +import { Edge, Node } from 'reactflow'; +import { EdgeData, NodeData, NodeSchema } from '../common-types'; +import { SchemaMap } from '../SchemaMap'; +import { + EMPTY_ARRAY, + ParsedSourceHandle, + ParsedTargetHandle, + assertNever, + groupBy, + parseSourceHandle, + stringifyTargetHandle, +} from '../util'; + +/** + * Represents the iterator lineage of an output. + * + * Note: this class only provides a minimal interface to enable future extensions. + */ +export class Lineage { + private readonly sourceNode: string; + + private constructor(sourceNode: string) { + this.sourceNode = sourceNode; + } + + equals(other: Lineage): boolean { + return this.sourceNode === other.sourceNode; + } + + static fromSourceNode(nodeId: string): Lineage { + return new Lineage(nodeId); + } +} + +export class ChainLineage { + readonly schemata: SchemaMap; + + private readonly nodeSchemata: ReadonlyMap; + + private readonly byTargetNode: ReadonlyMap[]>; + + private readonly byTargetHandle: ReadonlyMap>; + + private readonly nodeLineageCache = new Map(); + + constructor( + schemata: SchemaMap, + nodes: readonly Node[], + edges: readonly Edge[] + ) { + this.schemata = schemata; + this.nodeSchemata = new Map(nodes.map((n) => [n.id, schemata.get(n.data.schemaId)])); + + this.byTargetHandle = new Map(edges.map((e) => [e.targetHandle!, e] as const)); + this.byTargetNode = groupBy(edges, (e) => e.target); + } + + static readonly EMPTY: ChainLineage = new ChainLineage(SchemaMap.EMPTY, [], []); + + getEdgeByTarget(handle: ParsedTargetHandle): Edge | undefined { + return this.byTargetHandle.get(stringifyTargetHandle(handle)); + } + + /** + * Returns the single lineage (if any) of all iterated inputs of the given node. + * + * Note: regular nodes are auto-iterated, so their lineage is that of the first iterated input (if any). + * + * Note: the input lineage of collector nodes is `null` if there are no connected iterated inputs (invalid chain). + */ + getInputLineage(nodeId: string): Lineage | null { + const schema = this.nodeSchemata.get(nodeId); + if (!schema) return null; + + switch (schema.nodeType) { + case 'newIterator': { + // iterator source nodes do not support iterated inputs + return null; + } + case 'regularNode': { + // regular nodes are auto-iterated, so their lineage is that of the first iterated input + let lineage = this.nodeLineageCache.get(nodeId); + if (lineage === undefined) { + lineage = null; + + const edges = this.byTargetNode.get(nodeId) ?? EMPTY_ARRAY; + for (const edge of edges) { + const inputLineage = this.getOutputLineage( + parseSourceHandle(edge.sourceHandle!) + ); + if (inputLineage !== null) { + lineage = inputLineage; + break; + } + } + + this.nodeLineageCache.set(nodeId, lineage); + } + return lineage; + } + case 'collector': { + // collectors already return non-iterator outputs + let lineage = this.nodeLineageCache.get(nodeId); + if (lineage === undefined) { + lineage = null; + + if (schema.iteratorInputs.length !== 1) { + throw new Error( + `Collector nodes should have exactly 1 iterator input info (${schema.schemaId})` + ); + } + const info = schema.iteratorInputs[0]; + + for (const inputId of info.inputs) { + const edge = this.getEdgeByTarget({ nodeId, inputId }); + // eslint-disable-next-line no-continue + if (!edge) continue; + + const handle = parseSourceHandle(edge.sourceHandle!); + const inputLineage = this.getOutputLineage(handle); + if (inputLineage !== null) { + lineage = inputLineage; + break; + } + } + + this.nodeLineageCache.set(nodeId, lineage); + } + return lineage; + } + default: + return assertNever(schema.nodeType); + } + } + + /** + * Returns the lineage of the given specific output. + */ + getOutputLineage({ nodeId, outputId }: ParsedSourceHandle): Lineage | null { + const schema = this.nodeSchemata.get(nodeId); + if (!schema) return null; + + switch (schema.nodeType) { + case 'regularNode': { + // for regular nodes, the lineage of all outputs is equal to + // the lineage of the first iterated input (if any). + return this.getInputLineage(nodeId); + } + case 'newIterator': { + // iterator source nodes create a new lineage + if (schema.iteratorOutputs.length !== 1) { + throw new Error( + `Iterator nodes should have exactly 1 iterator output info (${schema.schemaId})` + ); + } + const info = schema.iteratorOutputs[0]; + return info.outputs.includes(outputId) ? Lineage.fromSourceNode(nodeId) : null; + } + case 'collector': { + // collectors already return non-iterator outputs + return null; + } + default: + return assertNever(schema.nodeType); + } + } +} diff --git a/src/renderer/components/Handle.tsx b/src/renderer/components/Handle.tsx index 2000139d4..2f9819855 100644 --- a/src/renderer/components/Handle.tsx +++ b/src/renderer/components/Handle.tsx @@ -2,7 +2,6 @@ import { Box, Tooltip, chakra } from '@chakra-ui/react'; import React, { memo } from 'react'; import { Connection, Position, Handle as RFHandle } from 'reactflow'; import { useContext } from 'use-context-selector'; -import { NodeType } from '../../common/common-types'; import { Validity } from '../../common/Validity'; import { FakeNodeContext } from '../contexts/FakeExampleContext'; import { noContextMenu } from '../hooks/useContextMenu'; @@ -15,7 +14,7 @@ interface HandleElementProps { isValidConnection: (connection: Readonly) => boolean; validity: Validity; id: string; - nodeType: NodeType; + isIterated: boolean; } // Had to do this garbage to prevent chakra from clashing the position prop @@ -26,16 +25,12 @@ const HandleElement = memo( validity, type, id, - nodeType, + isIterated, ...props }: React.PropsWithChildren) => { const { isFake } = useContext(FakeNodeContext); - const isIterator = nodeType === 'newIterator'; - const isCollector = nodeType === 'collector'; - - const squaredHandle = - (isIterator && type === 'output') || (isCollector && type === 'input'); + const squaredHandle = isIterated; return ( ) => boolean; handleColors: readonly string[]; connectedColor: string | undefined; - nodeType: NodeType; + isIterated: boolean; } const getBackground = (colors: readonly string[]): string => { @@ -125,7 +120,7 @@ export const Handle = memo( isValidConnection, handleColors, connectedColor, - nodeType, + isIterated, }: HandleProps) => { const isConnected = !!connectedColor; @@ -154,8 +149,8 @@ export const Handle = memo( as={HandleElement} className={`${type}-handle`} id={id} + isIterated={isIterated} isValidConnection={isValidConnection} - nodeType={nodeType} sx={{ width: '16px', height: '16px', diff --git a/src/renderer/components/NodeDocumentation/NodeExample.tsx b/src/renderer/components/NodeDocumentation/NodeExample.tsx index 67a6f258f..98272ab3a 100644 --- a/src/renderer/components/NodeDocumentation/NodeExample.tsx +++ b/src/renderer/components/NodeDocumentation/NodeExample.tsx @@ -142,6 +142,13 @@ export const NodeExample = memo(({ accentColor, selectedSchema }: NodeExamplePro functionInstance: typeInfo.instance, }); + const { iteratedInputs, iteratedOutputs } = useMemo(() => { + return { + iteratedInputs: new Set(selectedSchema.iteratorInputs.flatMap((i) => i.inputs)), + iteratedOutputs: new Set(selectedSchema.iteratorOutputs.flatMap((i) => i.outputs)), + }; + }, [selectedSchema]); + return (
@@ -187,6 +194,8 @@ export const NodeExample = memo(({ accentColor, selectedSchema }: NodeExamplePro isLocked: false, connectedInputs: EMPTY_SET, connectedOutputs: EMPTY_SET, + iteratedInputs, + iteratedOutputs, type: typeInfo, testCondition: (condition: Condition): boolean => testInputConditionTypeInfo(condition, inputData, typeInfo), diff --git a/src/renderer/components/inputs/InputContainer.tsx b/src/renderer/components/inputs/InputContainer.tsx index cf7f7145b..3f7dc41aa 100644 --- a/src/renderer/components/inputs/InputContainer.tsx +++ b/src/renderer/components/inputs/InputContainer.tsx @@ -4,7 +4,7 @@ import { Box, Center, HStack, Text, Tooltip } from '@chakra-ui/react'; import React, { memo, useCallback, useMemo } from 'react'; import { Connection, Node, useReactFlow } from 'reactflow'; import { useContext } from 'use-context-selector'; -import { InputId, NodeData, NodeType } from '../../../common/common-types'; +import { InputId, NodeData } from '../../../common/common-types'; import { parseSourceHandle, parseTargetHandle, stringifyTargetHandle } from '../../../common/util'; import { VALID, invalid } from '../../../common/Validity'; import { BackendContext } from '../../contexts/BackendContext'; @@ -19,7 +19,7 @@ export interface InputHandleProps { id: string; inputId: InputId; connectableType: Type; - nodeType: NodeType; + isIterated: boolean; } export const InputHandle = memo( @@ -28,7 +28,7 @@ export const InputHandle = memo( id, inputId, connectableType, - nodeType, + isIterated, }: React.PropsWithChildren) => { const { isValidConnection, edgeChanges, useConnectingFrom, typeState } = useContext(GlobalVolatileContext); @@ -115,8 +115,8 @@ export const InputHandle = memo( } handleColors={handleColors} id={targetHandle} + isIterated={isIterated} isValidConnection={isValidConnectionForRf} - nodeType={nodeType} type="input" validity={validity} /> diff --git a/src/renderer/components/inputs/SchemaInput.tsx b/src/renderer/components/inputs/SchemaInput.tsx index 42d82cae1..5d3eba5b2 100644 --- a/src/renderer/components/inputs/SchemaInput.tsx +++ b/src/renderer/components/inputs/SchemaInput.tsx @@ -58,8 +58,9 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr setWidth, isLocked, connectedInputs, + iteratedInputs, + iteratedOutputs, type, - schema, } = nodeState; const functionDefinition = useContextSelector(BackendContext, (c) => @@ -128,7 +129,7 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr connectableType={connectableType} id={nodeId} inputId={inputId} - nodeType={schema.nodeType} + isIterated={iteratedInputs.has(inputId)} > {inputElement} @@ -145,7 +146,7 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr } id={nodeId} isConnected={nodeState.connectedOutputs.has(fused.outputId)} - nodeType={schema.nodeType} + isIterated={iteratedOutputs.has(fused.outputId)} outputId={fused.outputId} type={outputType} /> diff --git a/src/renderer/components/node/NodeOutputs.tsx b/src/renderer/components/node/NodeOutputs.tsx index b5bc3c868..2e4b81aea 100644 --- a/src/renderer/components/node/NodeOutputs.tsx +++ b/src/renderer/components/node/NodeOutputs.tsx @@ -48,7 +48,16 @@ interface NodeOutputProps { } export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { - const { id, schema, schemaId, outputHeight, setOutputHeight, nodeWidth, setWidth } = nodeState; + const { + id, + schema, + schemaId, + outputHeight, + setOutputHeight, + nodeWidth, + setWidth, + iteratedOutputs, + } = nodeState; const { functionDefinitions } = useContext(BackendContext); const { setManualOutputType } = useContext(GlobalContext); @@ -110,8 +119,8 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { generic={OutputIsGeneric[output.kind]} id={id} isConnected={nodeState.connectedOutputs.has(output.id)} + isIterated={iteratedOutputs.has(output.id)} key={`${id}-${output.id}`} - nodeType={schema.nodeType} output={output} type={type} > diff --git a/src/renderer/components/outputs/OutputContainer.tsx b/src/renderer/components/outputs/OutputContainer.tsx index 6c9658a20..82a1d6694 100644 --- a/src/renderer/components/outputs/OutputContainer.tsx +++ b/src/renderer/components/outputs/OutputContainer.tsx @@ -3,7 +3,7 @@ import { Box, Center, HStack, Text } from '@chakra-ui/react'; import React, { memo, useCallback, useMemo } from 'react'; import { Connection } from 'reactflow'; import { useContext } from 'use-context-selector'; -import { NodeType, Output, OutputId } from '../../../common/common-types'; +import { Output, OutputId } from '../../../common/common-types'; import { stringifySourceHandle } from '../../../common/util'; import { VALID, invalid } from '../../../common/Validity'; import { GlobalVolatileContext } from '../../contexts/GlobalNodeState'; @@ -14,14 +14,14 @@ import { TypeTags } from '../TypeTag'; export interface OutputHandleProps { id: string; outputId: OutputId; - nodeType: NodeType; definitionType: Type; type: Type | undefined; isConnected: boolean; + isIterated: boolean; } export const OutputHandle = memo( - ({ id, outputId, nodeType, definitionType, type, isConnected }: OutputHandleProps) => { + ({ id, outputId, isIterated, definitionType, type, isConnected }: OutputHandleProps) => { const { isValidConnection, useConnectingFrom } = useContext(GlobalVolatileContext); const [connectingFrom] = useConnectingFrom; @@ -64,8 +64,8 @@ export const OutputHandle = memo( connectedColor={isConnected ? handleColors[0] : undefined} handleColors={handleColors} id={sourceHandle} + isIterated={isIterated} isValidConnection={isValidConnectionForRf} - nodeType={nodeType} type="output" validity={validity} /> @@ -81,7 +81,7 @@ interface OutputContainerProps { type: Type | undefined; generic: boolean; isConnected: boolean; - nodeType: NodeType; + isIterated: boolean; } export const OutputContainer = memo( @@ -93,7 +93,7 @@ export const OutputContainer = memo( type, generic, isConnected, - nodeType, + isIterated, }: React.PropsWithChildren) => { let contents = children; if (output.hasHandle) { @@ -104,7 +104,7 @@ export const OutputContainer = memo( definitionType={definitionType} id={id} isConnected={isConnected} - nodeType={nodeType} + isIterated={isIterated} outputId={output.id} type={type} /> diff --git a/src/renderer/contexts/GlobalNodeState.tsx b/src/renderer/contexts/GlobalNodeState.tsx index 98d2d4943..48b06d820 100644 --- a/src/renderer/contexts/GlobalNodeState.tsx +++ b/src/renderer/contexts/GlobalNodeState.tsx @@ -25,6 +25,7 @@ import { import { IdSet } from '../../common/IdSet'; import { log } from '../../common/log'; import { getEffectivelyDisabledNodes } from '../../common/nodes/disabled'; +import { ChainLineage } from '../../common/nodes/lineage'; import { TypeState } from '../../common/nodes/TypeState'; import { ipcRenderer } from '../../common/safeIpc'; import { ParsedSaveData, SaveData, openSaveFile } from '../../common/SaveFile'; @@ -102,6 +103,7 @@ interface GlobalVolatile { getConnected: (id: string) => readonly [IdSet, IdSet]; isValidConnection: (connection: Readonly) => Validity; effectivelyDisabledNodes: ReadonlySet; + chainLineage: ChainLineage; zoom: number; collidingEdge: string | undefined; collidingNode: string | undefined; @@ -326,6 +328,11 @@ export const GlobalProvider = memo( }); }, [edgeChanges, nodeChanges, getNodes, getEdges]); + const [chainLineage, setChainLineage] = useState(ChainLineage.EMPTY); + useEffect(() => { + setChainLineage(new ChainLineage(schemata, getNodes(), getEdges())); + }, [edgeChanges, getNodes, getEdges, schemata]); + const [savePath, setSavePathInternal] = useSessionStorage('save-path', null); const [openRecent, pushOpenPath, removeRecentPath] = useOpenRecent(); const setSavePath = useCallback( @@ -1301,6 +1308,7 @@ export const GlobalProvider = memo( typeState, getConnected, effectivelyDisabledNodes, + chainLineage, isValidConnection, zoom, collidingEdge, diff --git a/src/renderer/helpers/nodeState.ts b/src/renderer/helpers/nodeState.ts index a55bd3f9d..2870032b5 100644 --- a/src/renderer/helpers/nodeState.ts +++ b/src/renderer/helpers/nodeState.ts @@ -15,7 +15,7 @@ import { import { IdSet } from '../../common/IdSet'; import { testInputCondition } from '../../common/nodes/condition'; import { FunctionInstance } from '../../common/types/function'; -import { EMPTY_ARRAY, EMPTY_SET } from '../../common/util'; +import { EMPTY_ARRAY, EMPTY_SET, parseSourceHandle } from '../../common/util'; import { BackendContext } from '../contexts/BackendContext'; import { GlobalContext, GlobalVolatileContext } from '../contexts/GlobalNodeState'; import { useMemoObject } from '../hooks/useMemo'; @@ -73,6 +73,8 @@ export interface NodeState { readonly isLocked: boolean; readonly connectedInputs: ReadonlySet; readonly connectedOutputs: ReadonlySet; + readonly iteratedInputs: ReadonlySet; + readonly iteratedOutputs: ReadonlySet; readonly type: TypeInfo; readonly testCondition: (condition: Condition) => boolean; } @@ -111,6 +113,38 @@ export const useNodeStateFromData = (data: NodeData): NodeState => { return [IdSet.toSet(inputsSet), IdSet.toSet(outputsSet)]; }, [connectedString]); + const chainLineage = useContextSelector(GlobalVolatileContext, (c) => c.chainLineage); + const [iteratedInputs, iteratedOutputs] = useMemo(() => { + if (schema.nodeType === 'regularNode') { + // eslint-disable-next-line @typescript-eslint/no-shadow + const iteratedInputs = new Set(); + for (const input of schema.inputs) { + const edge = chainLineage.getEdgeByTarget({ nodeId: id, inputId: input.id }); + // eslint-disable-next-line no-continue + if (!edge) continue; + + const inputLineage = chainLineage.getOutputLineage( + parseSourceHandle(edge.sourceHandle!) + ); + if (inputLineage !== null) { + iteratedInputs.add(input.id); + } + } + + if (iteratedInputs.size > 0) { + // regular nodes are auto-iterated + return [iteratedInputs, new Set(schema.outputs.map((o) => o.id))]; + } + return [iteratedInputs, EMPTY_SET]; + } + + // iterators and collectors only have their defined iterated inputs/outputs + return [ + new Set(schema.iteratorInputs.flatMap((i) => i.inputs)), + new Set(schema.iteratorOutputs.flatMap((o) => o.outputs)), + ]; + }, [chainLineage, id, schema]); + const type = useTypeInfo(id); const testCondition = useCallback( @@ -133,6 +167,8 @@ export const useNodeStateFromData = (data: NodeData): NodeState => { isLocked: isLocked ?? false, connectedInputs, connectedOutputs, + iteratedInputs, + iteratedOutputs, type, testCondition, });