diff --git a/compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts b/compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts index b94bffa040f..db9e2a0f5b2 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts @@ -294,6 +294,13 @@ export type HIRFunction = { }; export type FunctionEffect = + | { + kind: 'ImmutableFunctionCall'; + loc: SourceLocation; + lvalue: IdentifierId; + callee: IdentifierId; + global: boolean; + } | { kind: 'GlobalMutation'; error: CompilerErrorDetailOptions; diff --git a/compiler/packages/babel-plugin-react-compiler/src/Inference/InferReferenceEffects.ts b/compiler/packages/babel-plugin-react-compiler/src/Inference/InferReferenceEffects.ts index aac092707dd..0d4e422c574 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/Inference/InferReferenceEffects.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/Inference/InferReferenceEffects.ts @@ -31,6 +31,7 @@ import { isMutableEffect, isObjectType, isRefValueType, + isSetStateType, isUseRefType, } from '../HIR/HIR'; import {FunctionSignature} from '../HIR/ObjectShape'; @@ -251,6 +252,10 @@ export default function inferReferenceEffects( loc: eff.loc, }); } + case 'ImmutableFunctionCall': { + // Handled below + break; + } default: assertExhaustive( eff, @@ -258,9 +263,122 @@ export default function inferReferenceEffects( ); } }); - } else { - fn.effects = functionEffects; + + if (functionEffects.length > 0) { + const error = new CompilerError(); + let usedIdentifiers = new Set(); + let names = new Map(); + + function visitFunction(fn: HIRFunction): void { + for (const [, block] of fn.body.blocks) { + for (const instr of block.instructions) { + switch (instr.value.kind) { + case 'FunctionExpression': + case 'ObjectMethod': + names.set( + instr.lvalue.identifier.id, + instr.value.loweredFunc.func.id ?? '(anonymous function)', + ); + visitFunction(instr.value.loweredFunc.func); + break; + case 'LoadGlobal': + names.set(instr.lvalue.identifier.id, instr.value.binding.name); + break; + case 'LoadLocal': + case 'LoadContext': + names.set( + instr.lvalue.identifier.id, + instr.value.place.identifier.name?.value ?? + names.get(instr.value.place.identifier.id), + ); + break; + case 'StoreContext': + case 'StoreLocal': + names.set( + instr.lvalue.identifier.id, + instr.value.value.identifier.name?.value ?? + names.get(instr.value.value.identifier.id), + ); + names.set( + instr.value.lvalue.place.identifier.id, + instr.value.value.identifier.name?.value ?? + names.get(instr.value.value.identifier.id), + ); + break; + case 'PropertyLoad': + names.set( + instr.lvalue.identifier.id, + `${instr.value.object.identifier.name?.value ?? names.get(instr.value.object.identifier.id) ?? '(unknown)'}.${instr.value.property}`, + ); + break; + case 'ComputedLoad': + names.set( + instr.lvalue.identifier.id, + `${instr.value.object.identifier.name?.value ?? names.get(instr.value.object.identifier.id) ?? '(unknown)'}[...]`, + ); + break; + case 'Destructure': { + const destructuredName = + instr.value.value.identifier.name?.value ?? + names.get(instr.value.value.identifier.id); + const destructuredMsg = destructuredName + ? `(destructured from \`${destructuredName}\`)` + : '(destructured)'; + Array.from( + eachPatternOperand(instr.value.lvalue.pattern), + ).forEach(place => + names.set( + place.identifier.id, + `${place.identifier.name?.value ?? 'value'} ${destructuredMsg}`, + ), + ); + } + } + Array.from(eachInstructionOperand(instr)).forEach(operand => + usedIdentifiers.add(operand.identifier.id), + ); + } + for (const phi of block.phis) { + Array.from(phi.operands.values()).forEach(operand => + usedIdentifiers.add(operand.id), + ); + } + Array.from(eachTerminalOperand(block.terminal)).forEach(operand => + usedIdentifiers.add(operand.identifier.id), + ); + } + } + visitFunction(fn); + + const allowedNames = new Set(['invariant', 'recoverableViolation']); + + for (const effect of functionEffects) { + CompilerError.invariant(effect.kind === 'ImmutableFunctionCall', { + reason: + 'All effects other than ImmutableFunctionCall should have been handled earlier', + loc: null, + }); + if ( + !usedIdentifiers.has(effect.lvalue) && + (!effect.global || + !names.has(effect.callee) || + !allowedNames.has(names.get(effect.callee)!)) + ) { + const name = names.get(effect.callee) ?? '(unknown)'; + error.push({ + reason: `Function \'${name}\' is called with arguments that React Compiler expects to be immutable and its return value is ignored. This call is likely to perform unsafe side effects, which violates the rules of React.`, + loc: effect.loc, + severity: ErrorSeverity.InvalidReact, + }); + } + } + + if (error.hasErrors()) { + throw error; + } + } } + fn.effects = functionEffects; } // Maintains a mapping of top-level variables to the kind of value they hold @@ -433,11 +551,12 @@ class InferenceState { for (const effect of dependentEffects) { if ( effect.kind === 'GlobalMutation' || - effect.kind === 'ReactMutation' + effect.kind === 'ReactMutation' || + effect.kind === 'ImmutableFunctionCall' ) { // Known effects are always propagated upwards functionEffects.push(effect); - } else { + } else if (effect.kind === 'ContextMutation') { /** * Contextual effects need to be replayed against the current inference * state, which may know more about the value to which the effect applied. @@ -1416,6 +1535,32 @@ function inferBlock( } hasCaptureArgument ||= instrValue.callee.effect === Effect.Capture; + if ( + !isSetStateType(instrValue.callee.identifier) && + instrValue.callee.effect === Effect.Read && + signature?.hookKind == null + ) { + const allRead = instrValue.args.every(arg => { + switch (arg.kind) { + case 'Identifier': + return arg.effect === Effect.Read; + case 'Spread': + return arg.place.effect === Effect.Read; + default: + assertExhaustive(arg, 'Unexpected arg kind'); + } + }); + if (allRead) { + functionEffects.push({ + kind: 'ImmutableFunctionCall', + lvalue: instr.lvalue.identifier.id, + callee: instrValue.callee.identifier.id, + loc: instrValue.loc, + global: state.kind(instrValue.callee).kind === ValueKind.Global, + }); + } + } + state.initialize(instrValue, returnValueKind); state.define(instr.lvalue, instrValue); instr.lvalue.effect = hasCaptureArgument @@ -1544,6 +1689,33 @@ function inferBlock( } hasCaptureArgument ||= instrValue.receiver.effect === Effect.Capture; + if ( + !isSetStateType(instrValue.property.identifier) && + instrValue.receiver.effect === Effect.Read && + instrValue.property.effect === Effect.Read && + signature?.hookKind == null + ) { + const allRead = instrValue.args.every(arg => { + switch (arg.kind) { + case 'Identifier': + return arg.effect === Effect.Read; + case 'Spread': + return arg.place.effect === Effect.Read; + default: + assertExhaustive(arg, 'Unexpected arg kind'); + } + }); + if (allRead) { + functionEffects.push({ + kind: 'ImmutableFunctionCall', + lvalue: instr.lvalue.identifier.id, + callee: instrValue.property.identifier.id, + loc: instrValue.loc, + global: state.kind(instrValue.property).kind === ValueKind.Global, + }); + } + } + state.initialize(instrValue, returnValueKind); state.define(instr.lvalue, instrValue); instr.lvalue.effect = hasCaptureArgument @@ -2173,7 +2345,9 @@ function areArgumentsImmutableAndNonMutating( } function isEffectSafeOutsideRender(effect: FunctionEffect): boolean { - return effect.kind === 'GlobalMutation'; + return ( + effect.kind === 'GlobalMutation' || effect.kind === 'ImmutableFunctionCall' + ); } function getWriteErrorReason(abstractValue: AbstractValue): string {