Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursively check if assignment target expression is valid, don't copy returned values #288

Merged
merged 3 commits into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 0 additions & 12 deletions runtime/ast/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ type Expression interface {
AcceptExp(ExpressionVisitor) Repr
}

// TargetExpression

type TargetExpression interface {
isTargetExpression()
}

// BoolExpression

type BoolExpression struct {
Expand Down Expand Up @@ -285,8 +279,6 @@ type IdentifierExpression struct {

func (*IdentifierExpression) isExpression() {}

func (*IdentifierExpression) isTargetExpression() {}

func (*IdentifierExpression) isIfStatementTest() {}

func (e *IdentifierExpression) Accept(visitor Visitor) Repr {
Expand Down Expand Up @@ -385,8 +377,6 @@ type MemberExpression struct {

func (*MemberExpression) isExpression() {}

func (*MemberExpression) isTargetExpression() {}

func (*MemberExpression) isIfStatementTest() {}

func (*MemberExpression) isAccessExpression() {}
Expand Down Expand Up @@ -436,8 +426,6 @@ type IndexExpression struct {

func (*IndexExpression) isExpression() {}

func (*IndexExpression) isTargetExpression() {}

func (*IndexExpression) isIfStatementTest() {}

func (*IndexExpression) isAccessExpression() {}
Expand Down
2 changes: 1 addition & 1 deletion runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ func (interpreter *Interpreter) VisitReturnStatement(statement *ast.ReturnStatem
valueType := interpreter.Checker.Elaboration.ReturnStatementValueTypes[statement]
returnType := interpreter.Checker.Elaboration.ReturnStatementReturnTypes[statement]

value = interpreter.copyAndConvert(value, valueType, returnType)
value = interpreter.convertAndBox(value, valueType, returnType)

return functionReturn{value}
})
Expand Down
18 changes: 17 additions & 1 deletion runtime/sema/check_assignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (checker *Checker) visitAssignmentValueType(
// Check the target is valid (e.g. identifier expression,
// indexing expression, or member access expression)

if _, ok := targetExpression.(ast.TargetExpression); !ok {
if !IsValidAssignmentTargetExpression(targetExpression) {
checker.report(
&InvalidAssignmentTargetError{
Range: ast.NewRangeFromPositioned(targetExpression),
Expand Down Expand Up @@ -370,3 +370,19 @@ func (checker *Checker) visitMemberExpressionAssignment(

return member.TypeAnnotation.Type
}

func IsValidAssignmentTargetExpression(expression ast.Expression) bool {
switch expression := expression.(type) {
case *ast.IdentifierExpression:
return true

case *ast.IndexExpression:
return IsValidAssignmentTargetExpression(expression.TargetExpression)

case *ast.MemberExpression:
return IsValidAssignmentTargetExpression(expression.Expression)
turbolent marked this conversation as resolved.
Show resolved Hide resolved

default:
return false
}
}
4 changes: 2 additions & 2 deletions runtime/sema/check_swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) ast.Repr {

checkRight := true

if _, leftIsTarget := swap.Left.(ast.TargetExpression); !leftIsTarget {
if !IsValidAssignmentTargetExpression(swap.Left) {
checker.report(
&InvalidSwapExpressionError{
Side: common.OperandSideLeft,
Expand All @@ -63,7 +63,7 @@ func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) ast.Repr {
}
}

if _, rightIsTarget := swap.Right.(ast.TargetExpression); !rightIsTarget {
if !IsValidAssignmentTargetExpression(swap.Right) {
checker.report(
&InvalidSwapExpressionError{
Side: common.OperandSideRight,
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/check_variable_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (checker *Checker) visitVariableDeclaration(declaration *ast.VariableDeclar
// The first expression must be a target expression (e.g. identifier expression,
// indexing expression, or member access expression)

if _, firstIsTarget := declaration.Value.(ast.TargetExpression); !firstIsTarget {
if !IsValidAssignmentTargetExpression(declaration.Value) {
checker.report(
&InvalidAssignmentTargetError{
Range: ast.NewRangeFromPositioned(declaration.Value),
Expand Down
2 changes: 1 addition & 1 deletion runtime/sema/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,7 @@ type InvalidAssignmentTargetError struct {
}

func (e *InvalidAssignmentTargetError) Error() string {
return "cannot assign to expression"
return "cannot assign to unassignable expression"
}

func (*InvalidAssignmentTargetError) isSemanticError() {}
Expand Down
137 changes: 129 additions & 8 deletions runtime/tests/checker/assignment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,136 @@ func TestCheckInvalidAssignmentTargetExpression(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun f() {}
t.Run("function invocation result", func(t *testing.T) {

fun test() {
f() = 2
}
`)
t.Parallel()

errs := ExpectCheckerErrors(t, err, 1)
_, err := ParseAndCheck(t, `
fun f() {}

fun test() {
f() = 2
}
`)

errs := ExpectCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
})

t.Run("index into function invocation result", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun f(): [Int] {
return [1]
}

fun test() {
f()[0] = 2
}
`)

errs := ExpectCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
})

t.Run("assess member of function invocation result", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
struct S {
var x: Int

init() {
self.x = 1
}
}

let s = S()

fun f(): S {
return s
}

fun test() {
f().x = 2
}
`)

errs := ExpectCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
})

t.Run("index into identifier", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
let xs = [1]

fun test() {
xs[0] = 2
}
`)

require.NoError(t, err)
})

t.Run("access member of identifier", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
struct S {
var x: Int

init() {
self.x = 1
}
}

let s = S()

fun test() {
s.x = 2
}
`)

require.NoError(t, err)
})

t.Run("index into array literal", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun test() {
[1][0] = 2
}
`)

errs := ExpectCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
})

t.Run("index into dictionary literal", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun test() {
{"a": 1}["a"] = 2
}
`)

errs := ExpectCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
assert.IsType(t, &sema.InvalidAssignmentTargetError{}, errs[0])
})
turbolent marked this conversation as resolved.
Show resolved Hide resolved
}
48 changes: 48 additions & 0 deletions runtime/tests/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8127,3 +8127,51 @@ func TestInterpretNestedDestroy(t *testing.T) {
logs,
)
}

// TestInterpretInternalAssignment ensures that a modification of an "internal" value
// is not possible, because the value that is assigned into is a copy
//
func TestInterpretInternalAssignment(t *testing.T) {

t.Parallel()

inter := parseCheckAndInterpret(t, `
struct S {
priv let xs: {String: Int}

init() {
self.xs = {"a": 1}
}

fun getXS(): {String: Int} {
return self.xs
}
}

fun test(): [{String: Int}] {
let s = S()
let xs = s.getXS()
xs["b"] = 2
return [xs, s.getXS()]
}
`)

value, err := inter.Invoke("test")
require.NoError(t, err)

assert.Equal(t,
interpreter.NewArrayValueUnownedNonCopying(
interpreter.NewDictionaryValueUnownedNonCopying(
interpreter.NewStringValue("a"),
interpreter.NewIntValueFromInt64(1),
interpreter.NewStringValue("b"),
interpreter.NewIntValueFromInt64(2),
),
interpreter.NewDictionaryValueUnownedNonCopying(
interpreter.NewStringValue("a"),
interpreter.NewIntValueFromInt64(1),
),
),
value,
)
}