diff --git a/runtime/sema/check_assignment.go b/runtime/sema/check_assignment.go index dedef05e5..e8a3dd03e 100644 --- a/runtime/sema/check_assignment.go +++ b/runtime/sema/check_assignment.go @@ -135,6 +135,34 @@ func (checker *Checker) checkAssignment( return } +func (checker *Checker) rootOfAccessChain(target ast.Expression) (baseVariable *Variable, accessChain []Type) { + var inAccessChain = true + + // seek the variable expression (if it exists) at the base of the access chain + for inAccessChain { + switch targetExp := target.(type) { + case *ast.IdentifierExpression: + baseVariable = checker.valueActivations.Find(targetExp.Identifier.Identifier) + if baseVariable != nil { + accessChain = append(accessChain, baseVariable.Type) + } + inAccessChain = false + case *ast.IndexExpression: + target = targetExp.TargetExpression + elementType := checker.Elaboration.IndexExpressionTypes(targetExp).IndexedType.ElementType(true) + accessChain = append(accessChain, elementType) + case *ast.MemberExpression: + target = targetExp.Expression + memberType, _, _, _ := checker.visitMember(targetExp, true) + accessChain = append(accessChain, memberType) + default: + inAccessChain = false + } + } + + return +} + // We have to prevent any writes to references, since we cannot know where the value // pointed to by the reference may have come from. Similarly, we can never safely assign // to a resource; because resources are moved instead of copied, we cannot currently @@ -162,31 +190,7 @@ func (checker *Checker) enforceViewAssignment(assignment ast.Statement, target a return } - var baseVariable *Variable - var accessChain = make([]Type, 0) - var inAccessChain = true - - // seek the variable expression (if it exists) at the base of the access chain - for inAccessChain { - switch targetExp := target.(type) { - case *ast.IdentifierExpression: - baseVariable = checker.valueActivations.Find(targetExp.Identifier.Identifier) - if baseVariable != nil { - accessChain = append(accessChain, baseVariable.Type) - } - inAccessChain = false - case *ast.IndexExpression: - target = targetExp.TargetExpression - elementType := checker.Elaboration.IndexExpressionTypes(targetExp).IndexedType.ElementType(true) - accessChain = append(accessChain, elementType) - case *ast.MemberExpression: - target = targetExp.Expression - memberType, _, _, _ := checker.visitMember(targetExp, true) - accessChain = append(accessChain, memberType) - default: - inAccessChain = false - } - } + baseVariable, accessChain := checker.rootOfAccessChain(target) // if the base of the access chain is not a variable, then we cannot make any static guarantees about // whether or not it is a local struct-kinded variable. E.g. in the case of `(b ? s1 : s2).x`, we can't @@ -312,7 +316,7 @@ func (checker *Checker) visitAssignmentValueType( func (checker *Checker) visitIdentifierExpressionAssignment( target *ast.IdentifierExpression, ) (targetType Type) { - identifier := target.Identifier.Identifier + identifier := target.Identifier // check identifier was declared before variable := checker.findAndCheckValueVariable(target, true) @@ -320,11 +324,15 @@ func (checker *Checker) visitIdentifierExpressionAssignment( return InvalidType } + if variable.Type.IsResourceType() { + checker.checkResourceVariableCapturingInFunction(variable, identifier) + } + // check identifier is not a constant if variable.IsConstant { checker.report( &AssignmentToConstantError{ - Name: identifier, + Name: identifier.Identifier, Range: ast.NewRangeFromPositioned(checker.memoryGauge, target), }, ) @@ -483,7 +491,24 @@ func (checker *Checker) visitMemberExpressionAssignment( reportAssignmentToConstant() } - return memberType + if memberType.IsResourceType() { + // if the member is a resource, check that it is not captured in a function, + // based off the activation depth of the root of the access chain, i.e. `a` in `a.b.c` + // we only want to make this check for transactions, as they are the only "resource-like" types + // (that can contain resources and must destroy them in their `execute` blocks), that are themselves + // not checked by the capturing logic, since they are not themselves resources. + baseVariable, _ := checker.rootOfAccessChain(target) + + if baseVariable == nil { + return + } + + if _, isTransaction := baseVariable.Type.(*TransactionType); isTransaction { + checker.checkResourceVariableCapturingInFunction(baseVariable, member.Identifier) + } + } + + return } func IsValidAssignmentTargetExpression(expression ast.Expression) bool { diff --git a/runtime/tests/checker/access_test.go b/runtime/tests/checker/access_test.go index 916a7258f..e91c640d3 100644 --- a/runtime/tests/checker/access_test.go +++ b/runtime/tests/checker/access_test.go @@ -1826,6 +1826,7 @@ func TestCheckAccessImportGlobalValueVariableDeclarationWithSecondValue(t *testi `) require.NoError(t, err) + // these capture x and y because they are created in a different file _, err = ParseAndCheckWithOptions(t, ` import x, y, createR from "imported" @@ -1849,7 +1850,9 @@ func TestCheckAccessImportGlobalValueVariableDeclarationWithSecondValue(t *testi }, ) - errs := RequireCheckerErrors(t, err, 7) + errs := RequireCheckerErrors(t, err, 9) + + // For `x` require.IsType(t, &sema.InvalidAccessError{}, errs[0]) assert.Equal(t, @@ -1859,23 +1862,29 @@ func TestCheckAccessImportGlobalValueVariableDeclarationWithSecondValue(t *testi require.IsType(t, &sema.ResourceCapturingError{}, errs[1]) - require.IsType(t, &sema.AssignmentToConstantError{}, errs[2]) + require.IsType(t, &sema.ResourceCapturingError{}, errs[2]) + + require.IsType(t, &sema.AssignmentToConstantError{}, errs[3]) assert.Equal(t, "x", - errs[2].(*sema.AssignmentToConstantError).Name, + errs[3].(*sema.AssignmentToConstantError).Name, ) - require.IsType(t, &sema.ResourceCapturingError{}, errs[3]) - require.IsType(t, &sema.ResourceCapturingError{}, errs[4]) - require.IsType(t, &sema.AssignmentToConstantError{}, errs[5]) + // For `y` + + require.IsType(t, &sema.ResourceCapturingError{}, errs[5]) + + require.IsType(t, &sema.ResourceCapturingError{}, errs[6]) + + require.IsType(t, &sema.AssignmentToConstantError{}, errs[7]) assert.Equal(t, "y", - errs[5].(*sema.AssignmentToConstantError).Name, + errs[7].(*sema.AssignmentToConstantError).Name, ) - require.IsType(t, &sema.ResourceCapturingError{}, errs[6]) + require.IsType(t, &sema.ResourceCapturingError{}, errs[8]) } func TestCheckContractNestedDeclarationPrivateAccess(t *testing.T) { diff --git a/runtime/tests/checker/resources_test.go b/runtime/tests/checker/resources_test.go index 99def38e6..db21d40bc 100644 --- a/runtime/tests/checker/resources_test.go +++ b/runtime/tests/checker/resources_test.go @@ -10055,3 +10055,144 @@ func TestCheckIndexingResourceLoss(t *testing.T) { assert.IsType(t, &sema.InvalidNestedResourceMoveError{}, errs[1]) }) } + +func TestCheckInvalidResourceCaptureOnLeft(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + var x: @AnyResource? <- nil + fun () { + x <-! [] + } + destroy x + } + `) + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceCapturingError{}, errs[0]) +} + +func TestCheckInvalidNestedResourceCapture(t *testing.T) { + + t.Parallel() + + t.Run("on right", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + transaction { + var x: @AnyResource? + prepare() { + self.x <- nil + } + execute { + fun() { + let y <- self.x + destroy y + } + } + } + `) + require.NoError(t, err) + }) + + t.Run("resource field on right", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R { + var x: @AnyResource? + init() { + self.x <- nil + } + fun foo() { + fun() { + let y <- self.x <- nil + destroy y + } + } + } + `) + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceCapturingError{}, errs[0]) + }) + + t.Run("on left", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + transaction { + var x: @AnyResource? + prepare() { + self.x <- nil + } + execute { + fun() { + self.x <-! nil + } + destroy self.x + } + } + `) + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceCapturingError{}, errs[0]) + }) + + t.Run("on left method scope", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + transaction { + var x: @AnyResource? + prepare() { + self.x <- nil + } + execute { + self.x <-! nil + destroy self.x + } + } + `) + require.NoError(t, err) + }) + + t.Run("contract self variable on left", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + contract C { + var x: @AnyResource? + init() { + self.x <- nil + } + fun foo() { + fun() { + self.x <-! nil + } + } + } + `) + require.NoError(t, err) + }) + + t.Run("contract self variable on left method scope", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + contract C { + var x: @AnyResource? + init() { + self.x <- nil + } + fun foo() { + self.x <-! nil + } + } + `) + require.NoError(t, err) + }) +} diff --git a/runtime/tests/checker/transactions_test.go b/runtime/tests/checker/transactions_test.go index 7262162c3..f3c512be2 100644 --- a/runtime/tests/checker/transactions_test.go +++ b/runtime/tests/checker/transactions_test.go @@ -523,3 +523,34 @@ func TestCheckInvalidTransactionSelfMoveIntoDictionaryLiteral(t *testing.T) { assert.IsType(t, &sema.InvalidMoveError{}, errs[0]) } + +func TestCheckInvalidTransactionResourceLoss(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + access(all) resource R{} + transaction { + var r: @R? + + prepare() { + self.r <- nil + } + + execute { + let writeback = fun() { + self.r <-! create R() + } + + var x <- self.r + + destroy x + writeback() + } + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceCapturingError{}, errs[0]) +}