diff --git a/src/common/types/function.ts b/src/common/types/function.ts index 876bf0f38..41f11bb48 100644 --- a/src/common/types/function.ts +++ b/src/common/types/function.ts @@ -605,4 +605,16 @@ export class FunctionInstance { // we say that types A is assignable to type B if they are not disjoint return !isDisjointWith(iType, this.definition.convertInput(inputId, type)); } + + /** + * Returns a new function instance with the given input assigned to the given type. + */ + withInput(inputId: InputId, type: NonNeverType): FunctionInstance { + return FunctionInstance.fromPartialInputs(this.definition, (id) => { + if (id === inputId) { + return type; + } + return this.inputs.get(id); + }); + } } diff --git a/src/renderer/contexts/GlobalNodeState.tsx b/src/renderer/contexts/GlobalNodeState.tsx index 2abbd298d..287703403 100644 --- a/src/renderer/contexts/GlobalNodeState.tsx +++ b/src/renderer/contexts/GlobalNodeState.tsx @@ -861,7 +861,7 @@ export const GlobalProvider = memo( const outputType = sourceFn.outputs.get(sourceHandleId); if (outputType !== undefined && !targetFn.canAssign(targetHandleId, outputType)) { - const schema = schemata.get(targetNode.data.schemaId); + const { schema } = targetFn.definition; const input = schema.inputs.find((i) => i.id === targetHandleId)!; const inputType = withoutNull( targetFn.definition.inputDefaults.get(targetHandleId)! @@ -884,6 +884,28 @@ export const GlobalProvider = memo( ); } + if ( + outputType !== undefined && + targetFn.inputErrors.length === 0 && + targetFn.outputErrors.length === 0 + ) { + const assignedFn = targetFn.withInput(targetHandleId, outputType); + if (assignedFn.outputErrors.length > 0) { + // the assigned caused output error + const errorId = assignedFn.outputErrors[0].outputId; + + const { schema } = targetFn.definition; + const output = schema.outputs.find((o) => o.id === errorId)!; + + if (output.neverReason) { + return invalid( + `Connection would cause the following error: ${output.neverReason}` + ); + } + return invalid(`Connection would cause an output error.`); + } + } + const checkTargetChildren = (parentNode: Node): boolean => { const targetChildren = getOutgoers(parentNode, getNodes(), getEdges()); if (!targetChildren.length) { @@ -908,7 +930,7 @@ export const GlobalProvider = memo( return VALID; }, - [typeState.functions, getNode, getNodes, getEdges, schemata] + [typeState.functions, getNode, getNodes, getEdges] ); const [inputDataChanges, addInputDataChanges] = useChangeCounter();