Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly show mixed iterator-non-iterator inputs/outputs #2481

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def __init__(
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"inputs": self.inputs,
"lengthType": self.length_type,
}


class IteratorOutputInfo:
def __init__(
Expand All @@ -103,6 +109,12 @@ def __init__(
)
self.length_type: navi.ExpressionJson = length_type

def to_dict(self):
return {
"outputs": self.outputs,
"lengthType": self.length_type,
}


@dataclass(frozen=True)
class NodeData:
Expand Down
2 changes: 2 additions & 0 deletions backend/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ async def nodes(_request: Request):
"groupLayout": [
g.to_dict() if isinstance(g, Group) else g for g in node.group_layout
],
"iteratorInputs": [x.to_dict() for x in node.iterator_inputs],
"iteratorOutputs": [x.to_dict() for x in node.iterator_outputs],
"description": node.description,
"seeAlso": node.see_also,
"icon": node.icon,
Expand Down
2 changes: 2 additions & 0 deletions src/common/SchemaMap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ const BLANK_SCHEMA: NodeSchema = {
inputs: [],
outputs: [],
groupLayout: [],
iteratorInputs: [],
iteratorOutputs: [],
icon: '',
category: '' as CategoryId,
nodeGroup: '' as NodeGroupId,
Expand Down
11 changes: 11 additions & 0 deletions src/common/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,15 @@ export type OutputHeight = Readonly<Record<OutputId, number>>;
export type OutputTypes = Readonly<Partial<Record<OutputId, ExpressionJson | null>>>;
export type GroupState = Readonly<Record<GroupId, unknown>>;

export interface IteratorInputInfo {
readonly inputs: readonly InputId[];
readonly lengthType: ExpressionJson;
}
export interface IteratorOutputInfo {
readonly outputs: readonly OutputId[];
readonly lengthType: ExpressionJson;
}

export interface NodeSchema {
readonly name: string;
readonly category: CategoryId;
Expand All @@ -270,6 +279,8 @@ export interface NodeSchema {
readonly inputs: readonly Input[];
readonly outputs: readonly Output[];
readonly groupLayout: readonly (InputId | Group)[];
readonly iteratorInputs: readonly IteratorInputInfo[];
readonly iteratorOutputs: readonly IteratorOutputInfo[];
readonly schemaId: SchemaId;
readonly hasSideEffects: boolean;
readonly deprecated: boolean;
Expand Down
167 changes: 167 additions & 0 deletions src/common/nodes/lineage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import { Edge, Node } from 'reactflow';
import { EdgeData, NodeData, NodeSchema } from '../common-types';
import { SchemaMap } from '../SchemaMap';
import {
EMPTY_ARRAY,
ParsedSourceHandle,
ParsedTargetHandle,
assertNever,
groupBy,
parseSourceHandle,
stringifyTargetHandle,
} from '../util';

/**
* Represents the iterator lineage of an output.
*
* Note: this class only provides a minimal interface to enable future extensions.
*/
export class Lineage {
private readonly sourceNode: string;

private constructor(sourceNode: string) {
this.sourceNode = sourceNode;
}

equals(other: Lineage): boolean {
return this.sourceNode === other.sourceNode;
}

static fromSourceNode(nodeId: string): Lineage {
return new Lineage(nodeId);
}
}

export class ChainLineage {
readonly schemata: SchemaMap;

private readonly nodeSchemata: ReadonlyMap<string, NodeSchema>;

private readonly byTargetNode: ReadonlyMap<string, readonly Edge<EdgeData>[]>;

private readonly byTargetHandle: ReadonlyMap<string, Edge<EdgeData>>;

private readonly nodeLineageCache = new Map<string, Lineage | null>();

constructor(
schemata: SchemaMap,
nodes: readonly Node<NodeData>[],
edges: readonly Edge<EdgeData>[]
) {
this.schemata = schemata;
this.nodeSchemata = new Map(nodes.map((n) => [n.id, schemata.get(n.data.schemaId)]));

this.byTargetHandle = new Map(edges.map((e) => [e.targetHandle!, e] as const));
this.byTargetNode = groupBy(edges, (e) => e.target);
}

static readonly EMPTY: ChainLineage = new ChainLineage(SchemaMap.EMPTY, [], []);

getEdgeByTarget(handle: ParsedTargetHandle): Edge<EdgeData> | undefined {
return this.byTargetHandle.get(stringifyTargetHandle(handle));
}

/**
* Returns the single lineage (if any) of all iterated inputs of the given node.
*
* Note: regular nodes are auto-iterated, so their lineage is that of the first iterated input (if any).
*
* Note: the input lineage of collector nodes is `null` if there are no connected iterated inputs (invalid chain).
*/
getInputLineage(nodeId: string): Lineage | null {
const schema = this.nodeSchemata.get(nodeId);
if (!schema) return null;

switch (schema.nodeType) {
case 'newIterator': {
// iterator source nodes do not support iterated inputs
return null;
}
case 'regularNode': {
// regular nodes are auto-iterated, so their lineage is that of the first iterated input
let lineage = this.nodeLineageCache.get(nodeId);
if (lineage === undefined) {
lineage = null;

const edges = this.byTargetNode.get(nodeId) ?? EMPTY_ARRAY;
for (const edge of edges) {
const inputLineage = this.getOutputLineage(
parseSourceHandle(edge.sourceHandle!)
);
if (inputLineage !== null) {
lineage = inputLineage;
break;
}
}

this.nodeLineageCache.set(nodeId, lineage);
}
return lineage;
}
case 'collector': {
// collectors already return non-iterator outputs
let lineage = this.nodeLineageCache.get(nodeId);
if (lineage === undefined) {
lineage = null;

if (schema.iteratorInputs.length !== 1) {
throw new Error(
`Collector nodes should have exactly 1 iterator input info (${schema.schemaId})`
);
}
const info = schema.iteratorInputs[0];

for (const inputId of info.inputs) {
const edge = this.getEdgeByTarget({ nodeId, inputId });
// eslint-disable-next-line no-continue
if (!edge) continue;

const handle = parseSourceHandle(edge.sourceHandle!);
const inputLineage = this.getOutputLineage(handle);
if (inputLineage !== null) {
lineage = inputLineage;
break;
}
}

this.nodeLineageCache.set(nodeId, lineage);
}
return lineage;
}
default:
return assertNever(schema.nodeType);
}
}

/**
* Returns the lineage of the given specific output.
*/
getOutputLineage({ nodeId, outputId }: ParsedSourceHandle): Lineage | null {
const schema = this.nodeSchemata.get(nodeId);
if (!schema) return null;

switch (schema.nodeType) {
case 'regularNode': {
// for regular nodes, the lineage of all outputs is equal to
// the lineage of the first iterated input (if any).
return this.getInputLineage(nodeId);
}
case 'newIterator': {
// iterator source nodes create a new lineage
if (schema.iteratorOutputs.length !== 1) {
throw new Error(
`Iterator nodes should have exactly 1 iterator output info (${schema.schemaId})`
);
}
const info = schema.iteratorOutputs[0];
return info.outputs.includes(outputId) ? Lineage.fromSourceNode(nodeId) : null;
}
case 'collector': {
// collectors already return non-iterator outputs
return null;
}
default:
return assertNever(schema.nodeType);
}
}
}
17 changes: 6 additions & 11 deletions src/renderer/components/Handle.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { Box, Tooltip, chakra } from '@chakra-ui/react';
import React, { memo } from 'react';
import { Connection, Position, Handle as RFHandle } from 'reactflow';
import { useContext } from 'use-context-selector';
import { NodeType } from '../../common/common-types';
import { Validity } from '../../common/Validity';
import { FakeNodeContext } from '../contexts/FakeExampleContext';
import { noContextMenu } from '../hooks/useContextMenu';
Expand All @@ -15,7 +14,7 @@ interface HandleElementProps {
isValidConnection: (connection: Readonly<Connection>) => boolean;
validity: Validity;
id: string;
nodeType: NodeType;
isIterated: boolean;
}

// Had to do this garbage to prevent chakra from clashing the position prop
Expand All @@ -26,16 +25,12 @@ const HandleElement = memo(
validity,
type,
id,
nodeType,
isIterated,
...props
}: React.PropsWithChildren<HandleElementProps>) => {
const { isFake } = useContext(FakeNodeContext);

const isIterator = nodeType === 'newIterator';
const isCollector = nodeType === 'collector';

const squaredHandle =
(isIterator && type === 'output') || (isCollector && type === 'input');
const squaredHandle = isIterated;

return (
<Tooltip
Expand Down Expand Up @@ -101,7 +96,7 @@ export interface HandleProps {
isValidConnection: (connection: Readonly<Connection>) => boolean;
handleColors: readonly string[];
connectedColor: string | undefined;
nodeType: NodeType;
isIterated: boolean;
}

const getBackground = (colors: readonly string[]): string => {
Expand All @@ -125,7 +120,7 @@ export const Handle = memo(
isValidConnection,
handleColors,
connectedColor,
nodeType,
isIterated,
}: HandleProps) => {
const isConnected = !!connectedColor;

Expand Down Expand Up @@ -154,8 +149,8 @@ export const Handle = memo(
as={HandleElement}
className={`${type}-handle`}
id={id}
isIterated={isIterated}
isValidConnection={isValidConnection}
nodeType={nodeType}
sx={{
width: '16px',
height: '16px',
Expand Down
9 changes: 9 additions & 0 deletions src/renderer/components/NodeDocumentation/NodeExample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ export const NodeExample = memo(({ accentColor, selectedSchema }: NodeExamplePro
functionInstance: typeInfo.instance,
});

const { iteratedInputs, iteratedOutputs } = useMemo(() => {
return {
iteratedInputs: new Set(selectedSchema.iteratorInputs.flatMap((i) => i.inputs)),
iteratedOutputs: new Set(selectedSchema.iteratorOutputs.flatMap((i) => i.outputs)),
};
}, [selectedSchema]);

return (
<Center key={selectedSchema.schemaId}>
<FakeNodeProvider isFake>
Expand Down Expand Up @@ -187,6 +194,8 @@ export const NodeExample = memo(({ accentColor, selectedSchema }: NodeExamplePro
isLocked: false,
connectedInputs: EMPTY_SET,
connectedOutputs: EMPTY_SET,
iteratedInputs,
iteratedOutputs,
type: typeInfo,
testCondition: (condition: Condition): boolean =>
testInputConditionTypeInfo(condition, inputData, typeInfo),
Expand Down
8 changes: 4 additions & 4 deletions src/renderer/components/inputs/InputContainer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Box, Center, HStack, Text, Tooltip } from '@chakra-ui/react';
import React, { memo, useCallback, useMemo } from 'react';
import { Connection, Node, useReactFlow } from 'reactflow';
import { useContext } from 'use-context-selector';
import { InputId, NodeData, NodeType } from '../../../common/common-types';
import { InputId, NodeData } from '../../../common/common-types';
import { parseSourceHandle, parseTargetHandle, stringifyTargetHandle } from '../../../common/util';
import { VALID, invalid } from '../../../common/Validity';
import { BackendContext } from '../../contexts/BackendContext';
Expand All @@ -19,7 +19,7 @@ export interface InputHandleProps {
id: string;
inputId: InputId;
connectableType: Type;
nodeType: NodeType;
isIterated: boolean;
}

export const InputHandle = memo(
Expand All @@ -28,7 +28,7 @@ export const InputHandle = memo(
id,
inputId,
connectableType,
nodeType,
isIterated,
}: React.PropsWithChildren<InputHandleProps>) => {
const { isValidConnection, edgeChanges, useConnectingFrom, typeState } =
useContext(GlobalVolatileContext);
Expand Down Expand Up @@ -115,8 +115,8 @@ export const InputHandle = memo(
}
handleColors={handleColors}
id={targetHandle}
isIterated={isIterated}
isValidConnection={isValidConnectionForRf}
nodeType={nodeType}
type="input"
validity={validity}
/>
Expand Down
7 changes: 4 additions & 3 deletions src/renderer/components/inputs/SchemaInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr
setWidth,
isLocked,
connectedInputs,
iteratedInputs,
iteratedOutputs,
type,
schema,
} = nodeState;

const functionDefinition = useContextSelector(BackendContext, (c) =>
Expand Down Expand Up @@ -128,7 +129,7 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr
connectableType={connectableType}
id={nodeId}
inputId={inputId}
nodeType={schema.nodeType}
isIterated={iteratedInputs.has(inputId)}
>
{inputElement}
</InputHandle>
Expand All @@ -145,7 +146,7 @@ export const SchemaInput = memo(({ input, nodeState, afterInput }: SingleInputPr
}
id={nodeId}
isConnected={nodeState.connectedOutputs.has(fused.outputId)}
nodeType={schema.nodeType}
isIterated={iteratedOutputs.has(fused.outputId)}
outputId={fused.outputId}
type={outputType}
/>
Expand Down