Skip to content

Commit

Permalink
Merge pull request #1360 from onflow/supun/number-supertypes
Browse files Browse the repository at this point in the history
Disallow arithmetic, comparison and bitwise operations on number supertypes
  • Loading branch information
SupunS committed Feb 3, 2022
2 parents d3124c5 + 4d806f5 commit 2b4b36a
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 276 deletions.
36 changes: 22 additions & 14 deletions runtime/sema/check_binary_expression.go
Expand Up @@ -130,9 +130,9 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)

case BinaryOperationKindEquality:
return checker.checkBinaryExpressionEquality(
expression, operation, operationKind,
expression, operation,
leftType, rightType,
leftIsInvalid, rightIsInvalid, anyInvalid,
anyInvalid,
)

default:
Expand Down Expand Up @@ -163,16 +163,16 @@ func (checker *Checker) VisitBinaryExpression(expression *ast.BinaryExpression)
switch operationKind {
case BinaryOperationKindBooleanLogic:
return checker.checkBinaryExpressionBooleanLogic(
expression, operation, operationKind,
expression, operation,
leftType, rightType,
leftIsInvalid, rightIsInvalid, anyInvalid,
)

case BinaryOperationKindNilCoalescing:
resultType := checker.checkBinaryExpressionNilCoalescing(
expression, operation, operationKind,
expression, operation,
leftType, rightType,
leftIsInvalid, rightIsInvalid, anyInvalid,
leftIsInvalid, rightIsInvalid,
)

checker.Elaboration.BinaryExpressionResultTypes[expression] = resultType
Expand Down Expand Up @@ -255,12 +255,23 @@ func (checker *Checker) checkBinaryExpressionArithmeticOrNonEqualityComparisonOr
}
}

// check both types are equal
shouldReportInvalidOperands := func(leftType, rightType Type) bool {
// If errors are already reported, then avoid reporting them again.
if reportedInvalidOperands || anyInvalid {
return false
}

// Both types should be equal.
if !leftType.Equal(rightType) {
return true
}

if !reportedInvalidOperands &&
!anyInvalid &&
!leftType.Equal(rightType) {
// Arithmetic, bitwise and non-equality comparison operators
// are not supported for numeric supertypes.
return isNumericSuperType(leftType)
}

if shouldReportInvalidOperands(leftType, rightType) {
checker.report(
&InvalidBinaryOperandsError{
Operation: operation,
Expand Down Expand Up @@ -288,9 +299,8 @@ func (checker *Checker) checkBinaryExpressionArithmeticOrNonEqualityComparisonOr
func (checker *Checker) checkBinaryExpressionEquality(
expression *ast.BinaryExpression,
operation ast.Operation,
operationKind BinaryOperationKind,
leftType, rightType Type,
leftIsInvalid, rightIsInvalid, anyInvalid bool,
anyInvalid bool,
) (resultType Type) {

resultType = BoolType
Expand Down Expand Up @@ -319,7 +329,6 @@ func (checker *Checker) checkBinaryExpressionEquality(
func (checker *Checker) checkBinaryExpressionBooleanLogic(
expression *ast.BinaryExpression,
operation ast.Operation,
operationKind BinaryOperationKind,
leftType, rightType Type,
leftIsInvalid, rightIsInvalid, anyInvalid bool,
) Type {
Expand Down Expand Up @@ -371,9 +380,8 @@ func (checker *Checker) checkBinaryExpressionBooleanLogic(
func (checker *Checker) checkBinaryExpressionNilCoalescing(
expression *ast.BinaryExpression,
operation ast.Operation,
operationKind BinaryOperationKind,
leftType, rightType Type,
leftIsInvalid, rightIsInvalid, anyInvalid bool,
leftIsInvalid, rightIsInvalid bool,
) Type {
leftOptional, leftIsOptional := leftType.(*OptionalType)

Expand Down
53 changes: 47 additions & 6 deletions runtime/sema/type.go
Expand Up @@ -753,6 +753,7 @@ type IntegerRangedType interface {
Type
MinInt() *big.Int
MaxInt() *big.Int
IsSuperType() bool
}

type FractionalRangedType interface {
Expand Down Expand Up @@ -856,6 +857,7 @@ type NumericType struct {
supportsSaturatingDivide bool
memberResolvers map[string]MemberResolver
memberResolversOnce sync.Once
isSuperType bool
}

var _ IntegerRangedType = &NumericType{}
Expand Down Expand Up @@ -1003,6 +1005,15 @@ func (t *NumericType) initializeMemberResolvers() {
})
}

func (t *NumericType) AsSuperType() *NumericType {
t.isSuperType = true
return t
}

func (t *NumericType) IsSuperType() bool {
return t.isSuperType
}

// FixedPointNumericType represents all the types in the fixed-point range.
//
type FixedPointNumericType struct {
Expand All @@ -1019,6 +1030,7 @@ type FixedPointNumericType struct {
supportsSaturatingDivide bool
memberResolvers map[string]MemberResolver
memberResolversOnce sync.Once
isSuperType bool
}

var _ FractionalRangedType = &FixedPointNumericType{}
Expand Down Expand Up @@ -1195,25 +1207,38 @@ func (t *FixedPointNumericType) initializeMemberResolvers() {
})
}

func (t *FixedPointNumericType) AsSuperType() *FixedPointNumericType {
t.isSuperType = true
return t
}

func (t *FixedPointNumericType) IsSuperType() bool {
return t.isSuperType
}

// Numeric types

var (

// NumberType represents the super-type of all number types
NumberType = NewNumericType(NumberTypeName).
WithTag(NumberTypeTag)
WithTag(NumberTypeTag).
AsSuperType()

// SignedNumberType represents the super-type of all signed number types
SignedNumberType = NewNumericType(SignedNumberTypeName).
WithTag(SignedNumberTypeTag)
WithTag(SignedNumberTypeTag).
AsSuperType()

// IntegerType represents the super-type of all integer types
IntegerType = NewNumericType(IntegerTypeName).
WithTag(IntegerTypeTag)
WithTag(IntegerTypeTag).
AsSuperType()

// SignedIntegerType represents the super-type of all signed integer types
SignedIntegerType = NewNumericType(SignedIntegerTypeName).
WithTag(SignedIntegerTypeTag)
WithTag(SignedIntegerTypeTag).
AsSuperType()

// IntType represents the arbitrary-precision integer type `Int`
IntType = NewNumericType(IntTypeName).
Expand Down Expand Up @@ -1359,11 +1384,13 @@ var (

// FixedPointType represents the super-type of all fixed-point types
FixedPointType = NewNumericType(FixedPointTypeName).
WithTag(FixedPointTypeTag)
WithTag(FixedPointTypeTag).
AsSuperType()

// SignedFixedPointType represents the super-type of all signed fixed-point types
SignedFixedPointType = NewNumericType(SignedFixedPointTypeName).
WithTag(SignedFixedPointTypeTag)
WithTag(SignedFixedPointTypeTag).
AsSuperType()

// Fix64Type represents the 64-bit signed decimal fixed-point type `Fix64`
// which has a scale of Fix64Scale, and checks for overflow and underflow
Expand Down Expand Up @@ -4614,6 +4641,8 @@ func (t *ReferenceType) Resolve(_ *TypeParameterTypeOrderedMap) Type {
// AddressType represents the address type
type AddressType struct{}

var _ IntegerRangedType = &AddressType{}

func (*AddressType) IsType() {}

func (t *AddressType) Tag() TypeTag {
Expand Down Expand Up @@ -4680,6 +4709,10 @@ func (*AddressType) MaxInt() *big.Int {
return AddressTypeMaxIntBig
}

func (*AddressType) IsSuperType() bool {
return false
}

func (*AddressType) Unify(_ Type, _ *TypeParameterTypeOrderedMap, _ func(err error), _ ast.Range) bool {
return false
}
Expand Down Expand Up @@ -6264,3 +6297,11 @@ func getFieldNames(members []*Member) []string {

return fields
}

func isNumericSuperType(typ Type) bool {
if numberType, ok := typ.(IntegerRangedType); ok {
return numberType.IsSuperType()
}

return false
}
9 changes: 8 additions & 1 deletion runtime/tests/checker/integer_test.go
Expand Up @@ -527,7 +527,14 @@ func TestCheckIntegerLiteralArguments(t *testing.T) {
),
)

require.NoError(t, err)
switch ty {
case sema.IntegerType,
sema.SignedIntegerType:
errs := ExpectCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidBinaryOperandsError{}, errs[0])
default:
require.NoError(t, err)
}
})
}
}
Expand Down
84 changes: 84 additions & 0 deletions runtime/tests/checker/operations_test.go
Expand Up @@ -466,3 +466,87 @@ func TestCheckInvalidCompositeEquality(t *testing.T) {
test(compositeKind)
}
}

func TestCheckNumericSuperTypeBinaryOperations(t *testing.T) {

t.Parallel()

supertypes := []sema.Type{
sema.NumberType,
sema.SignedNumberType,
sema.IntegerType,
sema.SignedIntegerType,
sema.FixedPointType,
sema.SignedFixedPointType,
}

t.Run("non saturating operations", func(t *testing.T) {

t.Parallel()

operations := []ast.Operation{
ast.OperationPlus,
ast.OperationMinus,
ast.OperationMul,
ast.OperationDiv,
ast.OperationMod,
ast.OperationBitwiseAnd,
ast.OperationBitwiseOr,
ast.OperationBitwiseXor,
ast.OperationBitwiseRightShift,
ast.OperationBitwiseLeftShift,
ast.OperationLess,
ast.OperationLessEqual,
ast.OperationGreater,
ast.OperationGreaterEqual,
}

for _, supertype := range supertypes {
for _, op := range operations {
t.Run(fmt.Sprintf("%s,%s", supertype.String(), op.String()), func(t *testing.T) {
code := fmt.Sprintf(`
fun test(a: %[1]s, b: %[1]s): AnyStruct {
return a %[2]s b
}`,
supertype.String(),
op.Symbol(),
)

_, err := ParseAndCheck(t, code)
errs := ExpectCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidBinaryOperandsError{}, errs[0])
})
}
}
})

t.Run("saturating operations", func(t *testing.T) {
t.Parallel()

saturatingFunctions := []string{
sema.NumericTypeSaturatingAddFunctionName,
sema.NumericTypeSaturatingSubtractFunctionName,
sema.NumericTypeSaturatingMultiplyFunctionName,
sema.NumericTypeSaturatingMultiplyFunctionName,
sema.NumericTypeSaturatingDivideFunctionName,
}

for _, supertype := range supertypes {
for _, saturatingFunc := range saturatingFunctions {
t.Run(fmt.Sprintf("%s,%s", supertype.String(), saturatingFunc), func(t *testing.T) {
code := fmt.Sprintf(`
fun test(a: %[1]s, b: %[1]s): AnyStruct {
return a.%[2]s(b)
}`,
supertype.String(),
saturatingFunc,
)

_, err := ParseAndCheck(t, code)
errs := ExpectCheckerErrors(t, err, 1)
assert.IsType(t, &sema.NotDeclaredMemberError{}, errs[0])
})
}
}
})
}

0 comments on commit 2b4b36a

Please sign in to comment.