Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 58 additions & 118 deletions internal/analysis/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package analysis

import (
"math/big"
"math/rand"
"slices"
"strings"

Expand Down Expand Up @@ -87,6 +88,7 @@ var Builtins = map[string]FnCallResolution{
type Diagnostic struct {
Range parser.Range
Kind DiagnosticKind
Id int32
}

type CheckResult struct {
Expand Down Expand Up @@ -149,30 +151,30 @@ func newCheckResult(program parser.Program) CheckResult {
}

func (res *CheckResult) check() {
for _, varDecl := range res.Program.Vars {
if varDecl.Type != nil {
res.checkVarType(*varDecl.Type)
}
if res.Program.Vars != nil {
for _, varDecl := range res.Program.Vars.Declarations {
if varDecl.Type != nil {
res.checkVarType(*varDecl.Type)
}

if varDecl.Name != nil {
res.checkDuplicateVars(*varDecl.Name, varDecl)
}
if varDecl.Name != nil {
res.checkDuplicateVars(*varDecl.Name, varDecl)
}

if varDecl.Origin != nil {
res.checkVarOrigin(*varDecl.Origin, varDecl)
if varDecl.Origin != nil {
res.checkVarOrigin(*varDecl.Origin, varDecl)
}
}
}

for _, statement := range res.Program.Statements {
res.unboundedAccountInSend = nil
res.checkStatement(statement)
}

// after static AST traversal is complete, check for unused vars
for name, rng := range res.unusedVars {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: rng,
Kind: &UnusedVar{Name: name},
})
res.pushDiagnostic(rng, UnusedVar{Name: name})
}
}

Expand Down Expand Up @@ -214,19 +216,12 @@ func CheckSource(source string) CheckResult {
result := parser.Parse(source)
res := newCheckResult(result.Value)
for _, parserError := range result.Errors {
res.Diagnostics = append(res.Diagnostics, parsingErrorToDiagnostic(parserError))
res.pushDiagnostic(parserError.Range, Parsing{Description: parserError.Msg})
}
res.check()
return res
}

func parsingErrorToDiagnostic(parserError parser.ParserError) Diagnostic {
return Diagnostic{
Range: parserError.Range,
Kind: &Parsing{Description: parserError.Msg},
}
}

func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) {
resolution, resolved := res.fnCallResolution[fnCall.Caller]

Expand All @@ -244,12 +239,9 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) {

if actualArgs < expectedArgs {
// Too few args
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: fnCall.Range,
Kind: &BadArity{
Expected: expectedArgs,
Actual: actualArgs,
},
res.pushDiagnostic(fnCall.Range, BadArity{
Expected: expectedArgs,
Actual: actualArgs,
})
} else if actualArgs > expectedArgs {
// Too many args
Expand All @@ -262,12 +254,9 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) {
End: lastIllegalArg.GetRange().End,
}

res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: rng,
Kind: &BadArity{
Expected: expectedArgs,
Actual: actualArgs,
},
res.pushDiagnostic(rng, BadArity{
Expected: expectedArgs,
Actual: actualArgs,
})
}

Expand All @@ -287,11 +276,8 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) {
res.checkExpression(arg, TypeAny)
}

res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: fnCall.Caller.Range,
Kind: &UnknownFunction{
Name: fnCall.Caller.Name,
},
res.pushDiagnostic(fnCall.Caller.Range, UnknownFunction{
Name: fnCall.Caller.Name,
})
}
}
Expand All @@ -302,20 +288,14 @@ func isTypeAllowed(typeName string) bool {

func (res *CheckResult) checkVarType(typeDecl parser.TypeDecl) {
if !isTypeAllowed(typeDecl.Name) {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: typeDecl.Range,
Kind: &InvalidType{Name: typeDecl.Name},
})
res.pushDiagnostic(typeDecl.Range, InvalidType{Name: typeDecl.Name})
}
}

func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl parser.VarDeclaration) {
// check there aren't duplicate variables
if _, ok := res.declaredVars[variableName.Name]; ok {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: variableName.Range,
Kind: &DuplicateVariable{Name: variableName.Name},
})
res.pushDiagnostic(variableName.Range, DuplicateVariable{Name: variableName.Name})
} else {
res.declaredVars[variableName.Name] = decl
res.unusedVars[variableName.Name] = variableName.Range
Expand All @@ -337,20 +317,17 @@ func (res *CheckResult) checkVarOrigin(fnCall parser.FnCall, decl parser.VarDecl
}

func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType string) {
actualType := res.checkTypeOf(lit)
actualType := res.checkTypeOf(lit, requiredType)
res.assertHasType(lit, requiredType, actualType)
}

func (res *CheckResult) checkTypeOf(lit parser.ValueExpr) string {
func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) string {
switch lit := lit.(type) {
case *parser.Variable:
if varDeclaration, ok := res.declaredVars[lit.Name]; ok {
res.varResolution[lit] = varDeclaration
} else {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: lit.Range,
Kind: &UnboundVariable{Name: lit.Name},
})
res.pushDiagnostic(lit.Range, UnboundVariable{Name: lit.Name, Type: typeHint})
}
delete(res.unusedVars, lit.Name)

Expand Down Expand Up @@ -408,19 +385,16 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr) string {
}

func (res *CheckResult) checkInfixOverload(bin *parser.BinaryInfix, allowed []string) string {
leftType := res.checkTypeOf(bin.Left)
leftType := res.checkTypeOf(bin.Left, allowed[0])

if leftType == TypeAny || slices.Contains(allowed, leftType) {
res.checkExpression(bin.Right, leftType)
return leftType
}

res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: bin.Left.GetRange(),
Kind: &TypeMismatch{
Expected: strings.Join(allowed, "|"),
Got: leftType,
},
res.pushDiagnostic(bin.Left.GetRange(), TypeMismatch{
Expected: strings.Join(allowed, "|"),
Got: leftType,
})
return TypeAny
}
Expand All @@ -430,14 +404,10 @@ func (res *CheckResult) assertHasType(lit parser.ValueExpr, requiredType string,
return
}

res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: lit.GetRange(),
Kind: &TypeMismatch{
Expected: requiredType,
Got: actualType,
},
res.pushDiagnostic(lit.GetRange(), TypeMismatch{
Expected: requiredType,
Got: actualType,
})

}

func (res *CheckResult) checkSentValue(sentValue parser.SentValue) {
Expand All @@ -455,30 +425,21 @@ func (res *CheckResult) checkSource(source parser.Source) {
}

if res.unboundedAccountInSend != nil {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: source.GetRange(),
Kind: &UnboundedAccountIsNotLast{},
})
res.pushDiagnostic(source.GetRange(), UnboundedAccountIsNotLast{})
}

switch source := source.(type) {
case *parser.SourceAccount:
res.checkExpression(source.ValueExpr, TypeAccount)
if account, ok := source.ValueExpr.(*parser.AccountInterpLiteral); ok {
if account.IsWorld() && res.unboundedSend {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: source.GetRange(),
Kind: &InvalidUnboundedAccount{},
})
res.pushDiagnostic(source.GetRange(), InvalidUnboundedAccount{})
} else if account.IsWorld() {
res.unboundedAccountInSend = account
}

if _, emptied := res.emptiedAccount[account.String()]; emptied && !account.IsWorld() {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Kind: &EmptiedAccount{Name: account.String()},
Range: account.Range,
})
res.pushDiagnostic(account.Range, EmptiedAccount{Name: account.String()})
}

res.emptiedAccount[account.String()] = struct{}{}
Expand All @@ -497,10 +458,7 @@ func (res *CheckResult) checkSource(source parser.Source) {
}

if res.unboundedSend {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: source.Address.GetRange(),
Kind: &InvalidUnboundedAccount{},
})
res.pushDiagnostic(source.Address.GetRange(), InvalidUnboundedAccount{})
}

res.checkExpression(source.Address, TypeAccount)
Expand Down Expand Up @@ -528,10 +486,7 @@ func (res *CheckResult) checkSource(source parser.Source) {

case *parser.SourceAllotment:
if res.unboundedSend {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Kind: &NoAllotmentInSendAll{},
Range: source.Range,
})
res.pushDiagnostic(source.Range, NoAllotmentInSendAll{})
}

var remainingAllotment *parser.RemainingAllotment = nil
Expand All @@ -555,10 +510,7 @@ func (res *CheckResult) checkSource(source parser.Source) {
if isLast {
remainingAllotment = allotment
} else {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: source.Range,
Kind: &RemainingIsNotLast{},
})
res.pushDiagnostic(source.Range, RemainingIsNotLast{})
}
}

Expand Down Expand Up @@ -637,10 +589,7 @@ func (res *CheckResult) tryEvaluatingPortionExpr(expr parser.ValueExpr) *big.Rat
}

if right.Cmp(big.NewInt(0)) == 0 {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Kind: &DivByZero{},
Range: expr.Range,
})
res.pushDiagnostic(expr.Range, DivByZero{})
return nil
}

Expand Down Expand Up @@ -708,10 +657,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) {
if isLast {
remainingAllotment = allotment
} else {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: destination.Range,
Kind: &RemainingIsNotLast{},
})
res.pushDiagnostic(destination.Range, RemainingIsNotLast{})
}
}

Expand Down Expand Up @@ -746,36 +692,22 @@ func (res *CheckResult) checkHasBadAllotmentSum(

if cmp == -1 && len(variableLiterals) == 1 {
var value big.Rat
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: variableLiterals[0].GetRange(),
Kind: &FixedPortionVariable{
Value: *value.Sub(big.NewRat(1, 1), &sum),
},
res.pushDiagnostic(variableLiterals[0].GetRange(), FixedPortionVariable{
Value: *value.Sub(big.NewRat(1, 1), &sum),
})
} else {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: rng,
Kind: &BadAllotmentSum{
Sum: sum,
},
})
res.pushDiagnostic(rng, BadAllotmentSum{Sum: sum})
}

// sum == 1
case 0:
for _, varLit := range variableLiterals {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: varLit.GetRange(),
Kind: &FixedPortionVariable{
Value: *big.NewRat(0, 1),
},
res.pushDiagnostic(varLit.GetRange(), FixedPortionVariable{
Value: *big.NewRat(0, 1),
})
}
if remaining != nil {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: remaining.Range,
Kind: &RedundantRemaining{},
})
res.pushDiagnostic(remaining.Range, RedundantRemaining{})
}
}
}
Expand Down Expand Up @@ -820,3 +752,11 @@ func (res *CheckResult) enterCappedSource() func() {
exitCloneUnboundedSend()
}
}

func (res *CheckResult) pushDiagnostic(rng parser.Range, kind DiagnosticKind) {
res.Diagnostics = append(res.Diagnostics, Diagnostic{
Range: rng,
Kind: kind,
Id: rand.Int31(),
})
}
Loading