diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 277a1ba1..c7395cfd 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -2,6 +2,7 @@ package analysis import ( "math/big" + "math/rand" "slices" "strings" @@ -87,6 +88,7 @@ var Builtins = map[string]FnCallResolution{ type Diagnostic struct { Range parser.Range Kind DiagnosticKind + Id int32 } type CheckResult struct { @@ -149,19 +151,22 @@ func newCheckResult(program parser.Program) CheckResult { } func (res *CheckResult) check() { - for _, varDecl := range res.Program.Vars { - if varDecl.Type != nil { - res.checkVarType(*varDecl.Type) - } + if res.Program.Vars != nil { + for _, varDecl := range res.Program.Vars.Declarations { + if varDecl.Type != nil { + res.checkVarType(*varDecl.Type) + } - if varDecl.Name != nil { - res.checkDuplicateVars(*varDecl.Name, varDecl) - } + if varDecl.Name != nil { + res.checkDuplicateVars(*varDecl.Name, varDecl) + } - if varDecl.Origin != nil { - res.checkVarOrigin(*varDecl.Origin, varDecl) + if varDecl.Origin != nil { + res.checkVarOrigin(*varDecl.Origin, varDecl) + } } } + for _, statement := range res.Program.Statements { res.unboundedAccountInSend = nil res.checkStatement(statement) @@ -169,10 +174,7 @@ func (res *CheckResult) check() { // after static AST traversal is complete, check for unused vars for name, rng := range res.unusedVars { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: rng, - Kind: &UnusedVar{Name: name}, - }) + res.pushDiagnostic(rng, UnusedVar{Name: name}) } } @@ -214,19 +216,12 @@ func CheckSource(source string) CheckResult { result := parser.Parse(source) res := newCheckResult(result.Value) for _, parserError := range result.Errors { - res.Diagnostics = append(res.Diagnostics, parsingErrorToDiagnostic(parserError)) + res.pushDiagnostic(parserError.Range, Parsing{Description: parserError.Msg}) } res.check() return res } -func parsingErrorToDiagnostic(parserError parser.ParserError) Diagnostic { - return Diagnostic{ - Range: parserError.Range, - Kind: &Parsing{Description: parserError.Msg}, - } -} - func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { resolution, resolved := res.fnCallResolution[fnCall.Caller] @@ -244,12 +239,9 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { if actualArgs < expectedArgs { // Too few args - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: fnCall.Range, - Kind: &BadArity{ - Expected: expectedArgs, - Actual: actualArgs, - }, + res.pushDiagnostic(fnCall.Range, BadArity{ + Expected: expectedArgs, + Actual: actualArgs, }) } else if actualArgs > expectedArgs { // Too many args @@ -262,12 +254,9 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { End: lastIllegalArg.GetRange().End, } - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: rng, - Kind: &BadArity{ - Expected: expectedArgs, - Actual: actualArgs, - }, + res.pushDiagnostic(rng, BadArity{ + Expected: expectedArgs, + Actual: actualArgs, }) } @@ -287,11 +276,8 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { res.checkExpression(arg, TypeAny) } - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: fnCall.Caller.Range, - Kind: &UnknownFunction{ - Name: fnCall.Caller.Name, - }, + res.pushDiagnostic(fnCall.Caller.Range, UnknownFunction{ + Name: fnCall.Caller.Name, }) } } @@ -302,20 +288,14 @@ func isTypeAllowed(typeName string) bool { func (res *CheckResult) checkVarType(typeDecl parser.TypeDecl) { if !isTypeAllowed(typeDecl.Name) { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: typeDecl.Range, - Kind: &InvalidType{Name: typeDecl.Name}, - }) + res.pushDiagnostic(typeDecl.Range, InvalidType{Name: typeDecl.Name}) } } func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl parser.VarDeclaration) { // check there aren't duplicate variables if _, ok := res.declaredVars[variableName.Name]; ok { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: variableName.Range, - Kind: &DuplicateVariable{Name: variableName.Name}, - }) + res.pushDiagnostic(variableName.Range, DuplicateVariable{Name: variableName.Name}) } else { res.declaredVars[variableName.Name] = decl res.unusedVars[variableName.Name] = variableName.Range @@ -337,20 +317,17 @@ func (res *CheckResult) checkVarOrigin(fnCall parser.FnCall, decl parser.VarDecl } func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType string) { - actualType := res.checkTypeOf(lit) + actualType := res.checkTypeOf(lit, requiredType) res.assertHasType(lit, requiredType, actualType) } -func (res *CheckResult) checkTypeOf(lit parser.ValueExpr) string { +func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) string { switch lit := lit.(type) { case *parser.Variable: if varDeclaration, ok := res.declaredVars[lit.Name]; ok { res.varResolution[lit] = varDeclaration } else { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: lit.Range, - Kind: &UnboundVariable{Name: lit.Name}, - }) + res.pushDiagnostic(lit.Range, UnboundVariable{Name: lit.Name, Type: typeHint}) } delete(res.unusedVars, lit.Name) @@ -408,19 +385,16 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr) string { } func (res *CheckResult) checkInfixOverload(bin *parser.BinaryInfix, allowed []string) string { - leftType := res.checkTypeOf(bin.Left) + leftType := res.checkTypeOf(bin.Left, allowed[0]) if leftType == TypeAny || slices.Contains(allowed, leftType) { res.checkExpression(bin.Right, leftType) return leftType } - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: bin.Left.GetRange(), - Kind: &TypeMismatch{ - Expected: strings.Join(allowed, "|"), - Got: leftType, - }, + res.pushDiagnostic(bin.Left.GetRange(), TypeMismatch{ + Expected: strings.Join(allowed, "|"), + Got: leftType, }) return TypeAny } @@ -430,14 +404,10 @@ func (res *CheckResult) assertHasType(lit parser.ValueExpr, requiredType string, return } - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: lit.GetRange(), - Kind: &TypeMismatch{ - Expected: requiredType, - Got: actualType, - }, + res.pushDiagnostic(lit.GetRange(), TypeMismatch{ + Expected: requiredType, + Got: actualType, }) - } func (res *CheckResult) checkSentValue(sentValue parser.SentValue) { @@ -455,10 +425,7 @@ func (res *CheckResult) checkSource(source parser.Source) { } if res.unboundedAccountInSend != nil { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: source.GetRange(), - Kind: &UnboundedAccountIsNotLast{}, - }) + res.pushDiagnostic(source.GetRange(), UnboundedAccountIsNotLast{}) } switch source := source.(type) { @@ -466,19 +433,13 @@ func (res *CheckResult) checkSource(source parser.Source) { res.checkExpression(source.ValueExpr, TypeAccount) if account, ok := source.ValueExpr.(*parser.AccountInterpLiteral); ok { if account.IsWorld() && res.unboundedSend { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: source.GetRange(), - Kind: &InvalidUnboundedAccount{}, - }) + res.pushDiagnostic(source.GetRange(), InvalidUnboundedAccount{}) } else if account.IsWorld() { res.unboundedAccountInSend = account } if _, emptied := res.emptiedAccount[account.String()]; emptied && !account.IsWorld() { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Kind: &EmptiedAccount{Name: account.String()}, - Range: account.Range, - }) + res.pushDiagnostic(account.Range, EmptiedAccount{Name: account.String()}) } res.emptiedAccount[account.String()] = struct{}{} @@ -497,10 +458,7 @@ func (res *CheckResult) checkSource(source parser.Source) { } if res.unboundedSend { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: source.Address.GetRange(), - Kind: &InvalidUnboundedAccount{}, - }) + res.pushDiagnostic(source.Address.GetRange(), InvalidUnboundedAccount{}) } res.checkExpression(source.Address, TypeAccount) @@ -528,10 +486,7 @@ func (res *CheckResult) checkSource(source parser.Source) { case *parser.SourceAllotment: if res.unboundedSend { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Kind: &NoAllotmentInSendAll{}, - Range: source.Range, - }) + res.pushDiagnostic(source.Range, NoAllotmentInSendAll{}) } var remainingAllotment *parser.RemainingAllotment = nil @@ -555,10 +510,7 @@ func (res *CheckResult) checkSource(source parser.Source) { if isLast { remainingAllotment = allotment } else { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: source.Range, - Kind: &RemainingIsNotLast{}, - }) + res.pushDiagnostic(source.Range, RemainingIsNotLast{}) } } @@ -637,10 +589,7 @@ func (res *CheckResult) tryEvaluatingPortionExpr(expr parser.ValueExpr) *big.Rat } if right.Cmp(big.NewInt(0)) == 0 { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Kind: &DivByZero{}, - Range: expr.Range, - }) + res.pushDiagnostic(expr.Range, DivByZero{}) return nil } @@ -708,10 +657,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) { if isLast { remainingAllotment = allotment } else { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: destination.Range, - Kind: &RemainingIsNotLast{}, - }) + res.pushDiagnostic(destination.Range, RemainingIsNotLast{}) } } @@ -746,36 +692,22 @@ func (res *CheckResult) checkHasBadAllotmentSum( if cmp == -1 && len(variableLiterals) == 1 { var value big.Rat - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: variableLiterals[0].GetRange(), - Kind: &FixedPortionVariable{ - Value: *value.Sub(big.NewRat(1, 1), &sum), - }, + res.pushDiagnostic(variableLiterals[0].GetRange(), FixedPortionVariable{ + Value: *value.Sub(big.NewRat(1, 1), &sum), }) } else { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: rng, - Kind: &BadAllotmentSum{ - Sum: sum, - }, - }) + res.pushDiagnostic(rng, BadAllotmentSum{Sum: sum}) } // sum == 1 case 0: for _, varLit := range variableLiterals { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: varLit.GetRange(), - Kind: &FixedPortionVariable{ - Value: *big.NewRat(0, 1), - }, + res.pushDiagnostic(varLit.GetRange(), FixedPortionVariable{ + Value: *big.NewRat(0, 1), }) } if remaining != nil { - res.Diagnostics = append(res.Diagnostics, Diagnostic{ - Range: remaining.Range, - Kind: &RedundantRemaining{}, - }) + res.pushDiagnostic(remaining.Range, RedundantRemaining{}) } } } @@ -820,3 +752,11 @@ func (res *CheckResult) enterCappedSource() func() { exitCloneUnboundedSend() } } + +func (res *CheckResult) pushDiagnostic(rng parser.Range, kind DiagnosticKind) { + res.Diagnostics = append(res.Diagnostics, Diagnostic{ + Range: rng, + Kind: kind, + Id: rand.Int31(), + }) +} diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 61e2edbb..0e3aca73 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -1,7 +1,6 @@ package analysis_test import ( - "math/big" "testing" "github.com/formancehq/numscript/internal/analysis" @@ -11,6 +10,14 @@ import ( "github.com/stretchr/testify/require" ) +func checkSource(input string) []analysis.Diagnostic { + res := analysis.CheckSource(input) + for i := range res.Diagnostics { + res.Diagnostics[i].Id = 0 + } + return res.Diagnostics +} + func TestInvalidType(t *testing.T) { t.Parallel() @@ -20,22 +27,12 @@ send [C 10] ( destination = $my_var )` - res := analysis.CheckSource(input) - require.Lenf(t, res.Diagnostics, 1, "xs: %#v", res.Diagnostics) - - d1 := res.Diagnostics[0] - assert.Equal(t, - parser.Range{ - Start: parser.Position{Character: 7}, - End: parser.Position{Character: 7 + len("invalid")}, + require.Equal(t, []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "invalid", 0), + Kind: analysis.InvalidType{Name: "invalid"}, }, - d1.Range, - ) - - assert.Equal(t, - &analysis.InvalidType{Name: "invalid"}, - d1.Kind, - ) + }, checkSource(input)) } func TestValidType(t *testing.T) { @@ -46,10 +43,8 @@ send [C 10] ( source = $my_var destination = $my_var )` - program := parser.Parse(input).Value - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 0) + require.Empty(t, checkSource(input)) } func TestDuplicateVariable(t *testing.T) { @@ -65,24 +60,12 @@ func TestDuplicateVariable(t *testing.T) { destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - parser.Range{ - Start: parser.Position{Line: 3, Character: 10}, - End: parser.Position{Line: 3, Character: 10 + len("$x")}, + require.Equal(t, []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$x", 1), + Kind: analysis.DuplicateVariable{Name: "x"}, }, - d1.Range, - ) - - assert.Equal(t, - &analysis.DuplicateVariable{Name: "x"}, - d1.Kind, - ) + }, checkSource(input)) } func TestUnboundVarInSaveAccount(t *testing.T) { @@ -90,24 +73,17 @@ func TestUnboundVarInSaveAccount(t *testing.T) { input := `save $unbound_mon from $unbound_acc` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 2) - - assert.Equal(t, - []analysis.Diagnostic{ - { - Kind: &analysis.UnboundVariable{Name: "unbound_mon"}, - Range: parser.RangeOfIndexed(input, "$unbound_mon", 0), - }, - { - Kind: &analysis.UnboundVariable{Name: "unbound_acc"}, - Range: parser.RangeOfIndexed(input, "$unbound_acc", 0), - }, + require.Equal(t, []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound_mon", 0), + Kind: analysis.UnboundVariable{Name: "unbound_mon", Type: "monetary"}, }, - diagnostics, - ) + { + Range: parser.RangeOfIndexed(input, "$unbound_acc", 0), + Kind: analysis.UnboundVariable{Name: "unbound_acc", Type: "account"}, + }, + }, checkSource(input)) + } func TestUnboundVarInInfixOp(t *testing.T) { @@ -120,21 +96,18 @@ func TestUnboundVarInInfixOp(t *testing.T) { ) ` - parseResult := parser.Parse(input) - require.Empty(t, parseResult.Errors) - assert.Equal(t, []analysis.Diagnostic{ { - Kind: &analysis.UnboundVariable{Name: "unbound_mon1"}, + Kind: analysis.UnboundVariable{Name: "unbound_mon1", Type: analysis.TypeMonetary}, Range: parser.RangeOfIndexed(input, "$unbound_mon1", 0), }, { - Kind: &analysis.UnboundVariable{Name: "unbound_mon2"}, + Kind: analysis.UnboundVariable{Name: "unbound_mon2", Type: analysis.TypeMonetary}, Range: parser.RangeOfIndexed(input, "$unbound_mon2", 0), }, }, - analysis.CheckProgram(parseResult.Value).Diagnostics, + checkSource(input), ) } @@ -154,18 +127,18 @@ save $str from $n diagnostics := analysis.CheckProgram(program).Diagnostics require.Len(t, diagnostics, 2) - assert.Equal(t, + require.Equal(t, []analysis.Diagnostic{ { - Kind: &analysis.TypeMismatch{Expected: "monetary", Got: "string"}, + Kind: analysis.TypeMismatch{Expected: "monetary", Got: "string"}, Range: parser.RangeOfIndexed(input, "$str", 1), }, { - Kind: &analysis.TypeMismatch{Expected: "account", Got: "number"}, + Kind: analysis.TypeMismatch{Expected: "account", Got: "number"}, Range: parser.RangeOfIndexed(input, "$n", 1), }, }, - diagnostics, + checkSource(input), ) } @@ -177,23 +150,14 @@ func TestUnboundVarInSource(t *testing.T) { destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - parser.Range{ - Start: parser.Position{Line: 1, Character: 28}, - End: parser.Position{Line: 1, Character: 28 + len("$unbound_var")}, + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound_var", 0), + Kind: analysis.UnboundVariable{Name: "unbound_var", Type: analysis.TypeAccount}, + }, }, - d1.Range, - ) - - assert.Equal(t, - &analysis.UnboundVariable{Name: "unbound_var"}, - d1.Kind, + checkSource(input), ) } @@ -205,23 +169,16 @@ func TestUnboundVarInSourceOneof(t *testing.T) { destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - assert.Equal(t, + require.Equal(t, []analysis.Diagnostic{ { Range: parser.RangeOfIndexed(input, "$unbound_var", 0), - Kind: &analysis.UnboundVariable{Name: "unbound_var"}, + Kind: analysis.UnboundVariable{Name: "unbound_var", Type: analysis.TypeAccount}, }, }, - diagnostics, + checkSource(input), ) - } - func TestUnboundVarInDest(t *testing.T) { t.Parallel() @@ -233,19 +190,14 @@ func TestUnboundVarInDest(t *testing.T) { } )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 1, "expected len to be 1") - - d1 := diagnostics[0] - assert.Equal(t, - parser.RangeOfIndexed(input, "$unbound_var", 0), - d1.Range, - ) - assert.Equal(t, - &analysis.UnboundVariable{Name: "unbound_var"}, - d1.Kind, + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound_var", 0), + Kind: analysis.UnboundVariable{Name: "unbound_var", Type: analysis.TypeAccount}, + }, + }, + checkSource(input), ) } @@ -260,10 +212,19 @@ func TestUnboundMany(t *testing.T) { destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 2) + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound1", 0), + Kind: analysis.UnboundVariable{Name: "unbound1", Type: analysis.TypeAccount}, + }, + { + Range: parser.RangeOfIndexed(input, "$unbound2", 0), + Kind: analysis.UnboundVariable{Name: "unbound2", Type: analysis.TypeAccount}, + }, + }, + checkSource(input), + ) } func TestUnboundCurrenciesVars(t *testing.T) { @@ -276,36 +237,34 @@ func TestUnboundCurrenciesVars(t *testing.T) { destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 2) + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound1", 0), + Kind: analysis.UnboundVariable{Name: "unbound1", Type: analysis.TypeMonetary}, + }, + { + Range: parser.RangeOfIndexed(input, "$unbound2", 0), + Kind: analysis.UnboundVariable{Name: "unbound2", Type: analysis.TypeMonetary}, + }, + }, + checkSource(input), + ) } -// TODO unbound vars in declr - func TestUnusedVarInSource(t *testing.T) { t.Parallel() input := `vars { monetary $unused_var }` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - parser.Range{ - Start: parser.Position{Character: 16}, - End: parser.Position{Character: 16 + len("$unused_var")}, + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unused_var", 0), + Kind: analysis.UnusedVar{Name: "unused_var"}, + }, }, - d1.Range, - ) - - assert.Equal(t, - &analysis.UnusedVar{Name: "unused_var"}, - d1.Kind, + checkSource(input), ) } @@ -319,23 +278,17 @@ send [$a 100] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "asset", - Got: "account", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$a", 1), + Kind: analysis.TypeMismatch{ + Expected: "asset", + Got: "account", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$a", 1), - d1.Range, + checkSource(input), ) } @@ -349,23 +302,17 @@ send [EUR/2 $n] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "number", - Got: "account", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$n", 1), + Kind: analysis.TypeMismatch{ + Expected: "number", + Got: "account", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$n", 1), - d1.Range, + checkSource(input), ) } @@ -379,23 +326,17 @@ send [COIN 100] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "monetary", - Got: "account", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$account", 1), + Kind: analysis.TypeMismatch{ + Expected: "monetary", + Got: "account", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$account", 1), - d1.Range, + checkSource(input), ) } @@ -409,23 +350,17 @@ send [COIN 100] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "account", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$x", 1), + Kind: analysis.TypeMismatch{ + Expected: "account", + Got: "portion", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$x", 1), - d1.Range, + checkSource(input), ) } @@ -439,23 +374,17 @@ send [COIN 100] ( destination = $x )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "account", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$x", 1), + Kind: analysis.TypeMismatch{ + Expected: "account", + Got: "portion", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$x", 1), - d1.Range, + checkSource(input), ) } @@ -469,23 +398,17 @@ send [COIN 100] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "account", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$x", 1), + Kind: analysis.TypeMismatch{ + Expected: "account", + Got: "portion", + }, + }, }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$x", 1), - d1.Range, + checkSource(input), ) } @@ -499,1172 +422,72 @@ send [COIN 100] ( destination = @dest )` - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "monetary", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$x", 1), + Kind: analysis.TypeMismatch{ + Expected: "monetary", + Got: "portion", + }, + }, }, - d1.Kind, + checkSource(input), ) +} - assert.Equal(t, - parser.RangeOfIndexed(input, "$x", 1), - d1.Range, - ) -} - -func TestWrongTypeForSrcAllotmentPortion(t *testing.T) { - t.Parallel() - - input := `vars { string $p } - -send [COIN 100] ( - source = { - $p from @a - remaining from @b - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "portion", - Got: "string", - }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$p", 1), - d1.Range, - ) -} - -func TestWrongTypeForDestAllotmentPortion(t *testing.T) { - t.Parallel() - - input := `vars { string $p } - -send [COIN 100] ( - source = @s - destination = { - $p to @a - remaining to @dest - } -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "portion", - Got: "string", - }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "$p", 1), - d1.Range, - ) -} - -func TestBadRemainingInSource(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 1/2 from @a - remaining from @b - 1/2 from @c - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.RemainingIsNotLast{}, - d1.Kind, - ) - - assert.Equal(t, - parser.Range{ - Start: parser.Position{Line: 1, Character: 12}, - End: parser.Position{Line: 5, Character: 5}, - }, - d1.Range, - ) - -} - -func TestBadRemainingInDest(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = @a - destination = { - 1/2 to @a - remaining to @b - 1/2 to @c - } -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.RemainingIsNotLast{}, - d1.Kind, - ) - - assert.Equal(t, - parser.Range{ - Start: parser.Position{Line: 2, Character: 17}, - End: parser.Position{Line: 6, Character: 5}, - }, - d1.Range, - ) - -} - -func TestBadAllotmentSumInSourceLessThanOne(t *testing.T) { - t.Parallel() - - input := ` -send [COIN 100] ( - source = { - 1/3 from @s1 - 1/3 from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - end := *parser.PositionOfIndexed(input, "}", 0) - end.Character++ - - assert.Equal(t, []analysis.Diagnostic{ - { - Range: parser.Range{ - Start: *parser.PositionOfIndexed(input, "{", 0), - End: end, - }, - Kind: &analysis.BadAllotmentSum{ - Sum: *big.NewRat(2, 3), - }, - }, - }, analysis.CheckProgram(program).Diagnostics) - -} - -func TestBadAllotmentPerc(t *testing.T) { - t.Parallel() - - input := ` -send [COIN 100] ( - source = { - 25% from @s1 - 50% from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - end := *parser.PositionOfIndexed(input, "}", 0) - end.Character++ - - assert.Equal(t, []analysis.Diagnostic{ - { - Range: parser.Range{ - Start: *parser.PositionOfIndexed(input, "{", 0), - End: end, - }, - Kind: &analysis.BadAllotmentSum{ - Sum: *big.NewRat(75, 100), - }, - }, - }, analysis.CheckProgram(program).Diagnostics) - -} - -func TestBadAllotmentComplexExpr(t *testing.T) { - t.Parallel() - - // same test as the previous one, with nested expr - input := ` -send [COIN 100] ( - source = { - (10 - 9)/(2 + 1) from @s1 - ((1 + 1) - 1)/3 from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - end := *parser.PositionOfIndexed(input, "}", 0) - end.Character++ - - assert.Equal(t, []analysis.Diagnostic{ - { - Range: parser.Range{ - Start: *parser.PositionOfIndexed(input, "{", 0), - End: end, - }, - Kind: &analysis.BadAllotmentSum{ - Sum: *big.NewRat(2, 3), - }, - }, - }, analysis.CheckProgram(program).Diagnostics) - -} - -func TestDivByZero(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 4/0 from @world - remaining kept - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Kind: &analysis.DivByZero{}, - Range: parser.RangeOfIndexed(input, "4/0", 0), - }, - }, diagnostics) -} - -func TestBadAllotmentSumInSourceMoreThanOne(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 2/3 from @s1 - 2/3 from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 1, "wrong diagnostics len") - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.BadAllotmentSum{ - Sum: *big.NewRat(4, 3), - }, - d1.Kind, - ) - - assert.Equal(t, - parser.Range{ - Start: parser.Position{Line: 1, Character: 12}, - End: parser.Position{Line: 4, Character: 5}, - }, - d1.Range, - ) - -} - -func TestBadAllotmentSumInDestinationLessThanOne(t *testing.T) { - t.Parallel() - - input := ` -send [COIN 100] ( - source = @src - destination = { - 1/3 to @d1 - 1/3 to @d2 - } -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 1, "wrong diagnostics len") - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.BadAllotmentSum{ - Sum: *big.NewRat(2, 3), - }, - d1.Kind, - ) -} - -func TestNoAllotmentLt1ErrIfVariable(t *testing.T) { - t.Parallel() - - input := `vars { - portion $portion1 - portion $portion2 -} - -send [COIN 100] ( - source = { - 1/3 from @s1 - 1/3 from @s2 - $portion1 from @s3 - $portion2 from @s4 - } - destination = @d -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 0) -} - -func TestAllotmentGt1ErrIfVariable(t *testing.T) { - t.Parallel() - - input := `vars { - portion $portion1 - portion $portion2 -} - -send [COIN 100] ( - source = @src - destination = { - 2/3 to @d1 - 2/3 to @d2 - $portion1 to @d3 - $portion2 to @d4 - } -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - assert.IsType(t, diagnostics[0].Kind, &analysis.BadAllotmentSum{}) -} - -func TestAllotmentErrOnlyOneVar(t *testing.T) { - t.Parallel() - - input := `vars { portion $portion } - -send [COIN 100] ( - source = { - 2/3 from @s1 - $portion from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - assert.Equal(t, diagnostics[0], analysis.Diagnostic{ - Kind: &analysis.FixedPortionVariable{ - Value: *big.NewRat(1, 3), - }, - Range: parser.RangeOfIndexed(input, "$portion", 1), - }) -} - -func TestAllotmentErrWhenVarIsZero(t *testing.T) { - t.Parallel() - - input := `vars { - portion $portion1 - portion $portion2 -} - -send [COIN 100] ( - source = { - 2/3 from @s1 - 1/3 from @s2 - $portion1 from @s3 - $portion2 from @s4 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 2) - - assert.Equal(t, diagnostics[0], analysis.Diagnostic{ - Kind: &analysis.FixedPortionVariable{ - Value: *big.NewRat(0, 1), - }, - Range: parser.RangeOfIndexed(input, "$portion1", 1), - }) - - assert.Equal(t, diagnostics[1], analysis.Diagnostic{ - Kind: &analysis.FixedPortionVariable{ - Value: *big.NewRat(0, 1), - }, - Range: parser.RangeOfIndexed(input, "$portion2", 1), - }) -} - -func TestNoBadAllotmentWhenRemaining(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 1/3 from @s1 - 1/3 from @s2 - remaining from @s3 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 0, "wrong diagnostics len") -} - -func TestBadAllotmentWhenRemainingButGt1(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 2/3 from @s1 - 2/3 from @s2 - remaining from @s3 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 1, "wrong diagnostics len") - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.BadAllotmentSum{ - Sum: *big.NewRat(4, 3), - }, - d1.Kind, - ) -} - -func TestRedundantRemainingWhenSumIsOne(t *testing.T) { - t.Parallel() - - input := `send [COIN 100] ( - source = { - 2/3 from @s1 - 1/3 from @s2 - remaining from @s3 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 1, "wrong diagnostics len") - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.RedundantRemaining{}, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "remaining", 0), - d1.Range, - ) -} - -func TestNoSingleAllotmentVariable(t *testing.T) { - t.Parallel() - - input := `vars { portion $allot } - -send [COIN 100] ( - source = { - $allot from @s1 - remaining from @s2 - } - destination = @dest -)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Lenf(t, diagnostics, 0, "wrong diagnostics len") -} - -func TestCheckNoUnboundFunctionCall(t *testing.T) { - t.Parallel() - - input := `invalid_fn_call()` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.UnknownFunction{Name: "invalid_fn_call"}, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "invalid_fn_call", 0), - d1.Range, - ) -} - -func TestAllowedFnCall(t *testing.T) { - t.Parallel() - - input := `set_tx_meta("for_cone", "true")` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 0) -} - -func TestCheckFnCallTypesWrongType(t *testing.T) { - t.Parallel() - - input := `set_tx_meta(@addr, 42)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "string", - Got: "account", - }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, "@addr", 0), - d1.Range, - ) -} - -func TestTooFewFnArgs(t *testing.T) { - t.Parallel() - - input := `set_tx_meta("arg")` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.BadArity{ - Expected: 2, - Actual: 1, - }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, `set_tx_meta("arg")`, 0), - d1.Range, - ) -} - -func TestTooManyFnArgs(t *testing.T) { - t.Parallel() - - input := `set_tx_meta("arg", "ok", 10, 20)` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.BadArity{ - Expected: 2, - Actual: 4, - }, - d1.Kind, - ) - - assert.Equal(t, - parser.RangeOfIndexed(input, `10, 20`, 0), - d1.Range, - ) -} - -func TestCheckTrailingCommaFnCall(t *testing.T) { - t.Parallel() - - input := `set_tx_meta("ciao", 42, 10, )` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) -} - -func TestCheckTypesOriginFn(t *testing.T) { - t.Parallel() - - input := ` - vars { - monetary $mon = meta(42, "str") - } - - send $mon ( - source = @s - destination = @d - ) - ` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "account", - Got: "number", - }, - d1.Kind, - ) -} - -func TestCheckReturnTypeOriginFn(t *testing.T) { - t.Parallel() - - input := ` - vars { - account $mon = balance(@account, EUR/2) - } - - send [EUR/2 100] ( - source = $mon - destination = @d - ) - ` - - program := parser.Parse(input).Value - - diagnostics := analysis.CheckProgram(program).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.TypeMismatch{ - Expected: "monetary", - Got: "account", - }, - d1.Kind, - ) -} - -func TestWorldOverdraft(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 100] ( - source = { - @a - @world allowing unbounded overdraft - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.InvalidWorldOverdraft{}, - d1.Kind, - ) - - assert.Equal(t, d1.Range, parser.RangeOfIndexed(input, "@world", 0)) -} - -func TestForbidAllotmentInSendAll(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 *] ( - source = { - 1/2 from @s1 - remaining from @s2 - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.NoAllotmentInSendAll{}, - d1.Kind, - ) -} - -func TestAllowAllotmentInCappedSendAll(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 *] ( - source = { - max [EUR/2 10] from { - 1/2 from @s1 - remaining from @s2 - } - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestDisallowAllotmentInCappedSendAllOutsideMax(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 *] ( - source = { - max [EUR/2 10] from @a - { - 1/2 from @s1 - remaining from @s2 - } - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - d1 := diagnostics[0] - assert.Equal(t, - &analysis.NoAllotmentInSendAll{}, - d1.Kind, - ) -} - -func TestNoForbidAllotmentInSendAll(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 *] ( - source = @a - destination = @dest - ) - - - send [EUR/2 100] ( - source = { - 1/2 from @s1 - remaining from @s2 - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestForbidUnboundedSrcInSendAll(t *testing.T) { - t.Parallel() - - input := ` - send [GEM *] ( - source = { - @ok - @illegal allowing unbounded overdraft // <- err - } - destination = @b - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - require.Equal(t, - diagnostics[0].Kind, - &analysis.InvalidUnboundedAccount{}, - ) - - require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@illegal", 0), - ) -} - -func TestAllowUnboundedSrcInSendAllWhenCapped(t *testing.T) { - t.Parallel() - - input := ` - send [GEM *] ( - source = max [GEM 100] from { - @ok - @illegal allowing unbounded overdraft - } - destination = @b - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestForbidWorldSrcInSendAll(t *testing.T) { - t.Parallel() - - input := ` - send [EUR/2 *] ( - source = @world - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) -} - -func TestForbidEmptiedAccount(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - @a - @b - @a // <- err - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - require.Equal(t, - diagnostics[0].Kind, - &analysis.EmptiedAccount{Name: "a"}, - ) - - require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@a", 1), - ) -} - -func TestResetEmptiedAccount(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = @a - destination = @dest - ) - - send [COIN 100] ( - source = @a - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestEmptiedAccountInMax(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - @emptied - max [COIN 10] from { - @a - @emptied // <- err - @b - } - @c - } - destination = @b - ) - - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - require.Equal(t, - diagnostics[0].Kind, - &analysis.EmptiedAccount{Name: "emptied"}, - ) - - require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@emptied", 1), - ) -} - -func TestEmptiedAccountDoNotLeakMaxed(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - max [COIN 10] from @emptied - @emptied - } - destination = @b - ) - - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestDoNotEmptyAccountInMax(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - @a - max [COIN 10] from { - @a1 - @emptied - @b1 - @emptied // <- err - } - } - destination = @b - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - require.Equal(t, - diagnostics[0].Kind, - &analysis.EmptiedAccount{Name: "emptied"}, - ) - - require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@emptied", 1), - ) -} - -func TestDoNotEmitEmptiedAccountOnAllotment(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - 1/2 from @emptied - 1/2 from @emptied - } - destination = @b - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) -} - -func TestDoNotAllowExprAfterWorld(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - @world - @another - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) - - require.Equal(t, - diagnostics[0].Kind, - &analysis.UnboundedAccountIsNotLast{}, - ) - - require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@another", 0), - ) -} - -func TestAllowWorldInNextExpr(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 1] ( - source = @world - destination = @dest - ) - - send [COIN 1] ( - source = @world - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) - -} - -func TestAllowWorldInMaxedExpr(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 10] ( - source = { - max [COIN 1] from @world - @x - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) - -} - -func TestDoNotAllowExprAfterWorldInsideMaxed(t *testing.T) { +func TestWrongTypeForSrcAllotmentPortion(t *testing.T) { t.Parallel() - input := ` - send [COIN 10] ( - source = max [COIN 1] from { - @world - @x - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) + input := `vars { string $p } - require.Equal(t, - diagnostics[0].Kind, - &analysis.UnboundedAccountIsNotLast{}, - ) +send [COIN 100] ( + source = { + $p from @a + remaining from @b + } + destination = @dest +)` require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@x", 0), + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$p", 1), + Kind: analysis.TypeMismatch{ + Expected: "portion", + Got: "string", + }, + }, + }, + checkSource(input), ) } -func TestDoNotAllowExprAfterUnbounded(t *testing.T) { +func TestWrongTypeForDestAllotmentPortion(t *testing.T) { t.Parallel() - input := ` - send [COIN 100] ( - source = { - @unbounded allowing unbounded overdraft - @another - } - destination = @dest - ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Len(t, diagnostics, 1) + input := `vars { string $p } - require.Equal(t, - diagnostics[0].Kind, - &analysis.UnboundedAccountIsNotLast{}, - ) +send [COIN 100] ( + source = @s + destination = { + $p to @a + remaining to @dest + } +)` require.Equal(t, - diagnostics[0].Range, - parser.RangeOfIndexed(input, "@another", 0), - ) -} - -func TestAllowExprAfterBoundedOverdraft(t *testing.T) { - t.Parallel() - - input := ` - send [COIN 100] ( - source = { - @unbounded allowing overdraft up to [COIN 10] - @another - } - destination = @dest + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$p", 1), + Kind: analysis.TypeMismatch{ + Expected: "portion", + Got: "string", + }, + }, + }, + checkSource(input), ) - ` - - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) } func TestCheckPlus(t *testing.T) { @@ -1673,59 +496,64 @@ func TestCheckPlus(t *testing.T) { t.Run("error in number+portion", func(t *testing.T) { input := `set_tx_meta("k", 1 + 1/2)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "1/2", 0), - Kind: &analysis.TypeMismatch{ - Expected: "number", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "1/2", 0), + Kind: analysis.TypeMismatch{ + Expected: "number", + Got: "portion", + }, }, }, - }, diagnostics) + checkSource(input), + ) }) t.Run("allow number+number", func(t *testing.T) { input := `set_tx_meta("k", 1 + 2)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) + require.Empty(t, checkSource(input)) }) t.Run("allow monetary+monetary", func(t *testing.T) { input := `set_tx_meta("k", [EUR/2 10] + [EUR/2 20])` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) + require.Empty(t, checkSource(input)) }) t.Run("error when left side is invalid", func(t *testing.T) { input := `set_tx_meta("k", @acc + @acc)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "@acc", 0), - Kind: &analysis.TypeMismatch{ - Expected: "number|monetary", - Got: "account", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "@acc", 0), + Kind: analysis.TypeMismatch{ + Expected: "number|monetary", + Got: "account", + }, }, }, - }, diagnostics) + checkSource(input), + ) }) t.Run("no type error when left side is any", func(t *testing.T) { input := `set_tx_meta("k", $unbound_var + @acc)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "$unbound_var", 0), - Kind: &analysis.UnboundVariable{ - Name: "unbound_var", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound_var", 0), + Kind: analysis.UnboundVariable{ + Name: "unbound_var", + Type: analysis.TypeNumber, + }, }, }, - }, diagnostics) + checkSource(input), + ) }) } @@ -1735,59 +563,64 @@ func TestCheckMinus(t *testing.T) { t.Run("error in number-portion", func(t *testing.T) { input := `set_tx_meta("k", 1 - 1/2)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "1/2", 0), - Kind: &analysis.TypeMismatch{ - Expected: "number", - Got: "portion", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "1/2", 0), + Kind: analysis.TypeMismatch{ + Expected: "number", + Got: "portion", + }, }, }, - }, diagnostics) + checkSource(input), + ) }) t.Run("allow number-number", func(t *testing.T) { input := `set_tx_meta("k", 1 - 2)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) + require.Empty(t, checkSource(input)) }) t.Run("allow monetary-monetary", func(t *testing.T) { input := `set_tx_meta("k", [EUR/2 10] - [EUR/2 20])` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Empty(t, diagnostics) + require.Empty(t, checkSource(input)) }) t.Run("error when left side is invalid", func(t *testing.T) { input := `set_tx_meta("k", @acc - @acc)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "@acc", 0), - Kind: &analysis.TypeMismatch{ - Expected: "number|monetary", - Got: "account", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "@acc", 0), + Kind: analysis.TypeMismatch{ + Expected: "number|monetary", + Got: "account", + }, }, }, - }, diagnostics) + checkSource(input), + ) }) t.Run("no type error when left side is any", func(t *testing.T) { input := `set_tx_meta("k", $unbound_var - @acc)` - diagnostics := analysis.CheckSource(input).Diagnostics - require.Equal(t, []analysis.Diagnostic{ - { - Range: parser.RangeOfIndexed(input, "$unbound_var", 0), - Kind: &analysis.UnboundVariable{ - Name: "unbound_var", + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$unbound_var", 0), + Kind: analysis.UnboundVariable{ + Name: "unbound_var", + Type: analysis.TypeNumber, + }, }, }, - }, diagnostics) + checkSource(input), + ) }) } diff --git a/internal/analysis/diagnostic_kind.go b/internal/analysis/diagnostic_kind.go index de8162e8..5d694910 100644 --- a/internal/analysis/diagnostic_kind.go +++ b/internal/analysis/diagnostic_kind.go @@ -46,11 +46,11 @@ type Parsing struct { Description string } -func (e *Parsing) Message() string { +func (e Parsing) Message() string { return e.Description } -func (*Parsing) Severity() Severity { +func (Parsing) Severity() Severity { return ErrorSeverity } @@ -59,7 +59,7 @@ type InvalidType struct { } // TODO evaluate suggestion using Levenshtein distance -func (e *InvalidType) Message() string { +func (e InvalidType) Message() string { allowedTypeList := "" for index, t := range AllowedTypes { if index != 0 { @@ -71,7 +71,7 @@ func (e *InvalidType) Message() string { return fmt.Sprintf("'%s' is not a valid type. Allowed types are: %s", e.Name, allowedTypeList) } -func (*InvalidType) Severity() Severity { +func (InvalidType) Severity() Severity { return ErrorSeverity } @@ -79,24 +79,25 @@ type DuplicateVariable struct { Name string } -func (e *DuplicateVariable) Message() string { +func (e DuplicateVariable) Message() string { return fmt.Sprintf("A variable with the name '$%s' was already declared", e.Name) } -func (*DuplicateVariable) Severity() Severity { +func (DuplicateVariable) Severity() Severity { return ErrorSeverity } type UnboundVariable struct { Name string + Type string } // TODO evaluate suggestion using Levenshtein distance -func (e *UnboundVariable) Message() string { +func (e UnboundVariable) Message() string { return fmt.Sprintf("The variable '$%s' was not declared", e.Name) } -func (*UnboundVariable) Severity() Severity { +func (UnboundVariable) Severity() Severity { return ErrorSeverity } @@ -104,11 +105,11 @@ type UnusedVar struct { Name string } -func (e *UnusedVar) Message() string { +func (e UnusedVar) Message() string { return fmt.Sprintf("The variable '$%s' is never used", e.Name) } -func (*UnusedVar) Severity() Severity { +func (UnusedVar) Severity() Severity { return WarningSeverity } @@ -117,20 +118,20 @@ type TypeMismatch struct { Got string } -func (e *TypeMismatch) Message() string { +func (e TypeMismatch) Message() string { return fmt.Sprintf("Type mismatch (expected '%s', got '%s' instead)", e.Expected, e.Got) } -func (*TypeMismatch) Severity() Severity { +func (TypeMismatch) Severity() Severity { return ErrorSeverity } type RemainingIsNotLast struct{} -func (e *RemainingIsNotLast) Message() string { +func (e RemainingIsNotLast) Message() string { return "A 'remaining' clause should be the last in an allotment expression" } -func (*RemainingIsNotLast) Severity() Severity { +func (RemainingIsNotLast) Severity() Severity { return ErrorSeverity } @@ -138,7 +139,7 @@ type BadAllotmentSum struct { Sum big.Rat } -func (e *BadAllotmentSum) Message() string { +func (e BadAllotmentSum) Message() string { one := big.NewRat(1, 1) switch e.Sum.Cmp(one) { @@ -153,7 +154,7 @@ func (e *BadAllotmentSum) Message() string { panic(fmt.Sprintf("unreachable state: allotment=%s", e.Sum.String())) } -func (*BadAllotmentSum) Severity() Severity { +func (BadAllotmentSum) Severity() Severity { return ErrorSeverity } @@ -161,11 +162,11 @@ type DivByZero struct { Sum big.Rat } -func (e *DivByZero) Message() string { +func (e DivByZero) Message() string { return "Cannot divide by zero" } -func (*DivByZero) Severity() Severity { +func (DivByZero) Severity() Severity { return ErrorSeverity } @@ -173,19 +174,19 @@ type FixedPortionVariable struct { Value big.Rat } -func (e *FixedPortionVariable) Message() string { +func (e FixedPortionVariable) Message() string { return fmt.Sprintf("Using a variable expression can lead to a runtime error if the expression doesn't resolve to %s.\n\nConsider using a hard-coded value or adding a 'remaining' clause to prevent the error", e.Value.String()) } -func (*FixedPortionVariable) Severity() Severity { +func (FixedPortionVariable) Severity() Severity { return WarningSeverity } type RedundantRemaining struct{} -func (e *RedundantRemaining) Message() string { +func (e RedundantRemaining) Message() string { return "Redundant 'remaining' clause (allotment already sums to 1)" } -func (*RedundantRemaining) Severity() Severity { +func (RedundantRemaining) Severity() Severity { return WarningSeverity } @@ -193,7 +194,7 @@ type UnknownFunction struct { Name string } -func (e *UnknownFunction) Message() string { +func (e UnknownFunction) Message() string { res, exists := Builtins[e.Name] if exists { return fmt.Sprintf("You cannot use this function here (try to use it in a %s context)", res.ContextName()) @@ -202,7 +203,7 @@ func (e *UnknownFunction) Message() string { return fmt.Sprintf("The function '%s' does not exist", e.Name) } -func (*UnknownFunction) Severity() Severity { +func (UnknownFunction) Severity() Severity { return ErrorSeverity } @@ -211,41 +212,41 @@ type BadArity struct { Actual int } -func (e *BadArity) Message() string { +func (e BadArity) Message() string { return fmt.Sprintf("Wrong number of arguments (expected %d, got %d instead)", e.Expected, e.Actual) } -func (*BadArity) Severity() Severity { +func (BadArity) Severity() Severity { return ErrorSeverity } type InvalidWorldOverdraft struct{} -func (e *InvalidWorldOverdraft) Message() string { +func (e InvalidWorldOverdraft) Message() string { return "@world is already set to be ovedraft" } -func (*InvalidWorldOverdraft) Severity() Severity { +func (InvalidWorldOverdraft) Severity() Severity { return WarningSeverity } type NoAllotmentInSendAll struct{} -func (e *NoAllotmentInSendAll) Message() string { +func (e NoAllotmentInSendAll) Message() string { return "Cannot take all balance of an allotment source" } -func (*NoAllotmentInSendAll) Severity() Severity { +func (NoAllotmentInSendAll) Severity() Severity { return WarningSeverity } type InvalidUnboundedAccount struct{} -func (e *InvalidUnboundedAccount) Message() string { +func (e InvalidUnboundedAccount) Message() string { return "Cannot take all balance of an unbounded source" } -func (*InvalidUnboundedAccount) Severity() Severity { +func (InvalidUnboundedAccount) Severity() Severity { return ErrorSeverity } @@ -253,20 +254,20 @@ type EmptiedAccount struct { Name string } -func (e *EmptiedAccount) Message() string { +func (e EmptiedAccount) Message() string { return fmt.Sprintf("@%s is already empty at this point", e.Name) } -func (*EmptiedAccount) Severity() Severity { +func (EmptiedAccount) Severity() Severity { return WarningSeverity } type UnboundedAccountIsNotLast struct{} -func (e *UnboundedAccountIsNotLast) Message() string { +func (e UnboundedAccountIsNotLast) Message() string { return "Inorder sources after an unbounded overdraft are never reached" } -func (*UnboundedAccountIsNotLast) Severity() Severity { +func (UnboundedAccountIsNotLast) Severity() Severity { return WarningSeverity } diff --git a/internal/analysis/hover.go b/internal/analysis/hover.go index e50b6aa7..7713ecb8 100644 --- a/internal/analysis/hover.go +++ b/internal/analysis/hover.go @@ -30,10 +30,12 @@ type BuiltinFnHover struct { func (*BuiltinFnHover) hover() {} func HoverOn(program parser.Program, position parser.Position) Hover { - for _, varDecl := range program.Vars { - hover := hoverOnVar(varDecl, position) - if hover != nil { - return hover + if program.Vars != nil { + for _, varDecl := range program.Vars.Declarations { + hover := hoverOnVar(varDecl, position) + if hover != nil { + return hover + } } } diff --git a/internal/interpreter/interpreter.go b/internal/interpreter/interpreter.go index e73fa802..cb7352ee 100644 --- a/internal/interpreter/interpreter.go +++ b/internal/interpreter/interpreter.go @@ -205,9 +205,11 @@ func RunProgram( FeatureFlags: featureFlags, } - err := st.parseVars(program.Vars, vars) - if err != nil { - return nil, err + if program.Vars != nil { + err := st.parseVars(program.Vars.Declarations, vars) + if err != nil { + return nil, err + } } // preload balances before executing the script diff --git a/internal/lsp/code_actions.go b/internal/lsp/code_actions.go new file mode 100644 index 00000000..d335f1d5 --- /dev/null +++ b/internal/lsp/code_actions.go @@ -0,0 +1,42 @@ +package lsp + +import ( + "fmt" + + "github.com/formancehq/numscript/internal/analysis" + "github.com/formancehq/numscript/internal/parser" +) + +func CreateVar(diagnostic analysis.UnboundVariable, program parser.Program) TextEdit { + declarationLine := fmt.Sprintf("\n %s $%s\n", diagnostic.Type, diagnostic.Name) + + if program.Vars == nil || len(program.Vars.Declarations) == 0 { + var rng Range + text := fmt.Sprintf("vars {%s}", declarationLine) + + if program.Vars != nil { + rng = toLspRange(program.Vars.Range) + } else { + text += "\n\n" + } + + return TextEdit{ + NewText: text, + Range: rng, + } + } + + lastVarEnd := program.Vars.Declarations[len(program.Vars.Declarations)-1].End + + varsEndPosition := program.Vars.Range.End + varsEndPosition.Character-- + + return TextEdit{ + NewText: declarationLine, + Range: Range{ + Start: toLspPosition(lastVarEnd), + End: toLspPosition(varsEndPosition), + }, + } + +} diff --git a/internal/lsp/codeactions_test.go b/internal/lsp/codeactions_test.go new file mode 100644 index 00000000..b6f5c9e3 --- /dev/null +++ b/internal/lsp/codeactions_test.go @@ -0,0 +1,201 @@ +package lsp_test + +import ( + "strings" + "testing" + + "github.com/formancehq/numscript/internal/analysis" + lsp "github.com/formancehq/numscript/internal/lsp" + "github.com/formancehq/numscript/internal/parser" + "github.com/stretchr/testify/require" +) + +func performAction(t *testing.T, + initial string, + expected string, + toEdit func(kind analysis.DiagnosticKind, program parser.Program) lsp.TextEdit, +) { + res := analysis.CheckSource(initial) + require.Len(t, res.Diagnostics, 1) + + first := res.Diagnostics[0] + + finalStr := performEdit(initial, toEdit(first.Kind, res.Program)) + + require.Equal(t, expected, finalStr) +} + +func TestCreateVarWhenNoVarsBlock(t *testing.T) { + initial := `send [USD/2 100] ( + source = max $example from @a + destination = @b +) +` + + final := `vars { + monetary $example +} + +send [USD/2 100] ( + source = max $example from @a + destination = @b +) +` + + performAction(t, initial, final, func(kind analysis.DiagnosticKind, program parser.Program) lsp.TextEdit { + return lsp.CreateVar(kind.(analysis.UnboundVariable), program) + }) + +} + +func TestCreateVarWhenAlreadyExistingVars(t *testing.T) { + initial := `vars { + monetary $example +} + +send [USD/2 100] ( + source = max $example from $account + destination = @b +) +` + + final := `vars { + monetary $example + account $account +} + +send [USD/2 100] ( + source = max $example from $account + destination = @b +) +` + + performAction(t, initial, final, func(kind analysis.DiagnosticKind, program parser.Program) lsp.TextEdit { + return lsp.CreateVar(kind.(analysis.UnboundVariable), program) + }) + +} + +func TestCreateVarWhenAlreadyExistingVarsSameLine(t *testing.T) { + initial := `vars { account $account } + +send [USD/2 100] ( + source = max $example from $account + destination = @b +) +` + + final := `vars { account $account + monetary $example +} + +send [USD/2 100] ( + source = max $example from $account + destination = @b +) +` + + performAction(t, initial, final, func(kind analysis.DiagnosticKind, program parser.Program) lsp.TextEdit { + return lsp.CreateVar(kind.(analysis.UnboundVariable), program) + }) + +} + +func TestCreateVarWhenEmptyVarsBlock(t *testing.T) { + initial := `vars { +} + +send [USD/2 100] ( + source = max [USD/2 100] from $account + destination = @b +) +` + + final := `vars { + account $account +} + +send [USD/2 100] ( + source = max [USD/2 100] from $account + destination = @b +) +` + + performAction(t, initial, final, func(kind analysis.DiagnosticKind, program parser.Program) lsp.TextEdit { + return lsp.CreateVar(kind.(analysis.UnboundVariable), program) + }) + +} + +func TestPositionToOffset(t *testing.T) { + str := `abc +def +ghi` + + require.Equal(t, positionToOffset(strings.Split(str, "\n"), lsp.Position{ + Line: 1, + Character: 1, + }), 5) + +} + +func TestPerformEdit(t *testing.T) { + initial := `a +ins<>here +c +` + + require.Equal(t, `a +ins___here +c +`, performEdit(initial, lsp.TextEdit{ + Range: lsp.Range{ + Start: lsp.Position{Line: 1, Character: 3}, + End: lsp.Position{Line: 1, Character: 5}, + }, + NewText: "___", + })) + +} + +func TestPerformEdit2(t *testing.T) { + initial := `abc` + + require.Equal(t, `LINE1 +LINE2 + +abc`, performEdit(initial, lsp.TextEdit{ + // Empty range + NewText: `LINE1 +LINE2 + +`, + })) + +} + +func positionToOffset(lines []string, position lsp.Position) int { + // TODO: check indexes are 0-based + + offset := 0 + for _, line := range lines[0:position.Line] { + // +1 for the newline which was trimmed in lines + offset += len(line) + 1 + } + + offset += int(position.Character) + + return offset +} + +func performEdit(initial string, textEdit lsp.TextEdit) string { + lines := strings.Split(initial, "\n") + + startOffset := positionToOffset(lines, textEdit.Range.Start) + endOffset := positionToOffset(lines, textEdit.Range.End) + + before := initial[0:startOffset] + after := initial[endOffset:] + + return before + textEdit.NewText + after +} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 437aef7d..20218068 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -3,6 +3,7 @@ package lsp import ( "encoding/json" "fmt" + "slices" "github.com/formancehq/numscript/internal/analysis" "github.com/formancehq/numscript/internal/parser" @@ -160,6 +161,42 @@ func (state *State) handleGetSymbols(params DocumentSymbolParams) []DocumentSymb return lspDocumentSymbols } +func (state *State) handleCodeAction(params CodeActionParams) []CodeAction { + doc, ok := state.documents.Get(params.TextDocument.URI) + if !ok { + return nil + } + + var actions []CodeAction + for _, d := range doc.CheckResult.Diagnostics { + index := slices.IndexFunc(params.Context.Diagnostics, func(lspDiagnostic Diagnostic) bool { + id, ok := lspDiagnostic.Data.(float64) + return ok && int32(id) == d.Id + }) + + var fixedDiagnostics []Diagnostic + if index != -1 { + fixedDiagnostics = append(fixedDiagnostics, params.Context.Diagnostics[index]) + } + + switch kind := d.Kind.(type) { + case analysis.UnboundVariable: + actions = append(actions, CodeAction{ + Title: "Create variable", + Kind: QuickFix, + Diagnostics: fixedDiagnostics, + Edit: WorkspaceEdit{ + Changes: map[string][]TextEdit{ + string(params.TextDocument.URI): {CreateVar(kind, doc.CheckResult.Program)}, + }, + }, + }) + } + } + + return actions +} + func Handle(r jsonrpc2.Request, state State) any { switch r.Method { case "initialize": @@ -172,6 +209,7 @@ func Handle(r jsonrpc2.Request, state State) any { HoverProvider: true, DefinitionProvider: true, DocumentSymbolProvider: true, + CodeActionProvider: true, }, // This is ugly. Is there a shortcut? ServerInfo: struct { @@ -211,6 +249,11 @@ func Handle(r jsonrpc2.Request, state State) any { json.Unmarshal([]byte(*r.Params), &p) return state.handleGetSymbols(p) + case "textDocument/codeAction": + var p CodeActionParams + json.Unmarshal([]byte(*r.Params), &p) + return state.handleCodeAction(p) + default: // Unhandled method // TODO should it panic? @@ -244,5 +287,6 @@ func toLspDiagnostic(d analysis.Diagnostic) Diagnostic { Range: toLspRange(d.Range), Severity: DiagnosticSeverity(d.Kind.Severity()), Message: d.Kind.Message(), + Data: d.Id, } } diff --git a/internal/parser/__snapshots__/parser_fault_tolerance_test.snap b/internal/parser/__snapshots__/parser_fault_tolerance_test.snap index 3302b777..c77f3b07 100755 --- a/internal/parser/__snapshots__/parser_fault_tolerance_test.snap +++ b/internal/parser/__snapshots__/parser_fault_tolerance_test.snap @@ -1,21 +1,27 @@ [TestFaultToleranceVarName - 1] parser.Program{ - Vars: { - { - Range: parser.Range{ - Start: parser.Position{Character:7, Line:0}, - End: parser.Position{Character:18, Line:0}, - }, - Name: (*parser.Variable)(nil), - Type: &parser.TypeDecl{ + Vars: &parser.VarDeclarations{ + Range: parser.Range{ + Start: parser.Position{}, + End: parser.Position{Character:21, Line:0}, + }, + Declarations: { + { Range: parser.Range{ Start: parser.Position{Character:7, Line:0}, - End: parser.Position{Character:15, Line:0}, + End: parser.Position{Character:18, Line:0}, }, - Name: "monetary", + Name: (*parser.Variable)(nil), + Type: &parser.TypeDecl{ + Range: parser.Range{ + Start: parser.Position{Character:7, Line:0}, + End: parser.Position{Character:15, Line:0}, + }, + Name: "monetary", + }, + Origin: (*parser.FnCall)(nil), }, - Origin: (*parser.FnCall)(nil), }, }, Statements: nil, @@ -24,7 +30,7 @@ parser.Program{ [TestFaultToleranceSend - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -41,7 +47,7 @@ parser.Program{ [TestFaultToleranceMonetary - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -58,7 +64,7 @@ parser.Program{ [TestFaultToleranceNoAddr - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -81,7 +87,7 @@ parser.Program{ [TestFaultToleranceInvalidDest - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -133,7 +139,7 @@ parser.Program{ [TestFaultToleranceInvalidSrcTk - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -192,7 +198,7 @@ parser.Program{ [TestFaultToleranceTrailingComma - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -223,7 +229,7 @@ parser.Program{ [TestFaultToleranceDestinationNoRemainingMispelledFrom - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -310,27 +316,33 @@ parser.Program{ [TestFaultToleranceIncompleteOrigin - 1] parser.Program{ - Vars: { - { - Range: parser.Range{ - Start: parser.Position{Character:1, Line:2}, - End: parser.Position{Character:11, Line:2}, - }, - Name: &parser.Variable{ - Range: parser.Range{ - Start: parser.Position{Character:7, Line:2}, - End: parser.Position{Character:9, Line:2}, - }, - Name: "a", - }, - Type: &parser.TypeDecl{ + Vars: &parser.VarDeclarations{ + Range: parser.Range{ + Start: parser.Position{Character:0, Line:1}, + End: parser.Position{Character:1, Line:3}, + }, + Declarations: { + { Range: parser.Range{ Start: parser.Position{Character:1, Line:2}, - End: parser.Position{Character:6, Line:2}, + End: parser.Position{Character:11, Line:2}, + }, + Name: &parser.Variable{ + Range: parser.Range{ + Start: parser.Position{Character:7, Line:2}, + End: parser.Position{Character:9, Line:2}, + }, + Name: "a", + }, + Type: &parser.TypeDecl{ + Range: parser.Range{ + Start: parser.Position{Character:1, Line:2}, + End: parser.Position{Character:6, Line:2}, + }, + Name: "asset", }, - Name: "asset", + Origin: (*parser.FnCall)(nil), }, - Origin: (*parser.FnCall)(nil), }, }, Statements: nil, @@ -339,7 +351,7 @@ parser.Program{ [TestFaultToleranceIncompleteSave - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ @@ -355,7 +367,7 @@ parser.Program{ [TestFaultToleranceIncompleteSave2 - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ @@ -396,7 +408,7 @@ parser.Program{ [TestFaultToleranceIncompleteSave3 - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ diff --git a/internal/parser/__snapshots__/parser_test.snap b/internal/parser/__snapshots__/parser_test.snap index 81b55bdd..3de4f8d8 100755 --- a/internal/parser/__snapshots__/parser_test.snap +++ b/internal/parser/__snapshots__/parser_test.snap @@ -1,6 +1,6 @@ [TestPlainAddress - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -62,7 +62,7 @@ parser.Program{ [TestVariable - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -120,7 +120,7 @@ parser.Program{ [TestAllotment - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -219,7 +219,7 @@ parser.Program{ [TestAllotmentPerc - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -371,7 +371,7 @@ parser.Program{ [TestAllotmentPercFloating - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -460,7 +460,7 @@ parser.Program{ [TestAllotmentDest - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -561,7 +561,7 @@ parser.Program{ [TestCapped - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -649,7 +649,7 @@ parser.Program{ [TestCappedVariable - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -724,7 +724,7 @@ parser.Program{ [TestNested - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -842,7 +842,7 @@ parser.Program{ [TestMultipleSends - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -956,32 +956,47 @@ parser.Program{ } --- [TestEmptyVars - 1] -parser.Program{} +parser.Program{ + Vars: &parser.VarDeclarations{ + Range: parser.Range{ + Start: parser.Position{}, + End: parser.Position{Character:8, Line:0}, + }, + Declarations: nil, + }, + Statements: nil, +} --- [TestSingleVar - 1] parser.Program{ - Vars: { - { - Range: parser.Range{ - Start: parser.Position{Character:7, Line:0}, - End: parser.Position{Character:23, Line:0}, - }, - Name: &parser.Variable{ + Vars: &parser.VarDeclarations{ + Range: parser.Range{ + Start: parser.Position{}, + End: parser.Position{Character:25, Line:0}, + }, + Declarations: { + { Range: parser.Range{ - Start: parser.Position{Character:16, Line:0}, + Start: parser.Position{Character:7, Line:0}, End: parser.Position{Character:23, Line:0}, }, - Name: "my_var", - }, - Type: &parser.TypeDecl{ - Range: parser.Range{ - Start: parser.Position{Character:7, Line:0}, - End: parser.Position{Character:15, Line:0}, + Name: &parser.Variable{ + Range: parser.Range{ + Start: parser.Position{Character:16, Line:0}, + End: parser.Position{Character:23, Line:0}, + }, + Name: "my_var", + }, + Type: &parser.TypeDecl{ + Range: parser.Range{ + Start: parser.Position{Character:7, Line:0}, + End: parser.Position{Character:15, Line:0}, + }, + Name: "monetary", }, - Name: "monetary", + Origin: (*parser.FnCall)(nil), }, - Origin: (*parser.FnCall)(nil), }, }, Statements: nil, @@ -990,7 +1005,7 @@ parser.Program{ [TestVariableMonetary - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1039,7 +1054,7 @@ parser.Program{ [TestAllotmentDestRemaining - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1165,7 +1180,7 @@ parser.Program{ [TestAllotmentVariableSource - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1250,7 +1265,7 @@ parser.Program{ [TestOverdraftUnbounded - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1304,7 +1319,7 @@ parser.Program{ [TestOverdraftUnboundedVariable - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1356,7 +1371,7 @@ parser.Program{ [TestBoundedOverdraft - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1427,7 +1442,7 @@ parser.Program{ [TestFunctionCallNoArgs - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -1449,7 +1464,7 @@ parser.Program{ [TestFunctionCallOneArg - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -1481,7 +1496,7 @@ parser.Program{ [TestFunctionCallManyArgs - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -1566,54 +1581,60 @@ parser.Program{ [TestVarOrigin - 1] parser.Program{ - Vars: { - { - Range: parser.Range{ - Start: parser.Position{Character:1, Line:2}, - End: parser.Position{Character:49, Line:2}, - }, - Name: &parser.Variable{ - Range: parser.Range{ - Start: parser.Position{Character:10, Line:2}, - End: parser.Position{Character:17, Line:2}, - }, - Name: "my_var", - }, - Type: &parser.TypeDecl{ + Vars: &parser.VarDeclarations{ + Range: parser.Range{ + Start: parser.Position{Character:0, Line:1}, + End: parser.Position{Character:1, Line:3}, + }, + Declarations: { + { Range: parser.Range{ Start: parser.Position{Character:1, Line:2}, - End: parser.Position{Character:9, Line:2}, - }, - Name: "monetary", - }, - Origin: &parser.FnCall{ - Range: parser.Range{ - Start: parser.Position{Character:20, Line:2}, End: parser.Position{Character:49, Line:2}, }, - Caller: &parser.FnCallIdentifier{ + Name: &parser.Variable{ Range: parser.Range{ - Start: parser.Position{Character:20, Line:2}, - End: parser.Position{Character:29, Line:2}, + Start: parser.Position{Character:10, Line:2}, + End: parser.Position{Character:17, Line:2}, + }, + Name: "my_var", + }, + Type: &parser.TypeDecl{ + Range: parser.Range{ + Start: parser.Position{Character:1, Line:2}, + End: parser.Position{Character:9, Line:2}, }, - Name: "origin_fn", + Name: "monetary", }, - Args: { - &parser.AccountInterpLiteral{ + Origin: &parser.FnCall{ + Range: parser.Range{ + Start: parser.Position{Character:20, Line:2}, + End: parser.Position{Character:49, Line:2}, + }, + Caller: &parser.FnCallIdentifier{ Range: parser.Range{ - Start: parser.Position{Character:30, Line:2}, - End: parser.Position{Character:41, Line:2}, - }, - Parts: { - parser.AccountTextPart{Name:"my_account"}, + Start: parser.Position{Character:20, Line:2}, + End: parser.Position{Character:29, Line:2}, }, + Name: "origin_fn", }, - &parser.StringLiteral{ - Range: parser.Range{ - Start: parser.Position{Character:43, Line:2}, - End: parser.Position{Character:48, Line:2}, + Args: { + &parser.AccountInterpLiteral{ + Range: parser.Range{ + Start: parser.Position{Character:30, Line:2}, + End: parser.Position{Character:41, Line:2}, + }, + Parts: { + parser.AccountTextPart{Name:"my_account"}, + }, + }, + &parser.StringLiteral{ + Range: parser.Range{ + Start: parser.Position{Character:43, Line:2}, + End: parser.Position{Character:48, Line:2}, + }, + String: "str", }, - String: "str", }, }, }, @@ -1625,7 +1646,7 @@ parser.Program{ [TestInorderSource - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1706,7 +1727,7 @@ parser.Program{ [TestInorderDestination - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1823,7 +1844,7 @@ parser.Program{ [TestSendAll - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1872,7 +1893,7 @@ parser.Program{ [TestAllotmentDestKept - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -1966,7 +1987,7 @@ parser.Program{ [TestWhitespaceInRatio - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -2072,7 +2093,7 @@ mismatched input 'ee' expecting {'oneof', '(', '[', '{', PERCENTAGE_PORTION_LITE [TestNegativeNumberLit - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -2134,7 +2155,7 @@ parser.Program{ [TestSaveStatementSimple - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ @@ -2183,7 +2204,7 @@ parser.Program{ [TestSaveAllStatement - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ @@ -2219,7 +2240,7 @@ parser.Program{ [TestSaveStatementVar - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SaveStatement{ Range: parser.Range{ @@ -2253,7 +2274,7 @@ parser.Program{ [TestInfix - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2375,7 +2396,7 @@ parser.Program{ [TestInfixPrec - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2440,7 +2461,7 @@ parser.Program{ [TestNumberSyntaxUnderscore - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2477,7 +2498,7 @@ parser.Program{ [TestParensInfixPrec - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2542,7 +2563,7 @@ parser.Program{ [TestOneofSource - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -2610,7 +2631,7 @@ parser.Program{ [TestOneofDestination - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.SendStatement{ Range: parser.Range{ @@ -2714,7 +2735,7 @@ parser.Program{ [TestDivInfix - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2765,7 +2786,7 @@ parser.Program{ [TestDivVariableDenominator - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2816,7 +2837,7 @@ parser.Program{ [TestInfixSumDiv - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2881,7 +2902,7 @@ parser.Program{ [TestStringTemplate - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ @@ -2918,7 +2939,7 @@ parser.Program{ [TestInterpAccount - 1] parser.Program{ - Vars: nil, + Vars: (*parser.VarDeclarations)(nil), Statements: { &parser.FnCall{ Range: parser.Range{ diff --git a/internal/parser/ast.go b/internal/parser/ast.go index 45de0fed..2825becf 100644 --- a/internal/parser/ast.go +++ b/internal/parser/ast.go @@ -309,7 +309,12 @@ type VarDeclaration struct { Origin *FnCall } +type VarDeclarations struct { + Range + Declarations []VarDeclaration +} + type Program struct { - Vars []VarDeclaration + Vars *VarDeclarations Statements []Statement } diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 1a4ad25a..25382a73 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -75,19 +75,23 @@ func ParseErrorsToString(errors []ParserError, source string) string { return buf } -func parseVarsDeclaration(varsCtx antlrParser.IVarsDeclarationContext) []VarDeclaration { +func parseVarsDeclaration(varsCtx antlrParser.IVarsDeclarationContext) *VarDeclarations { + if varsCtx == nil { return nil } - var vars []VarDeclaration + varBlock := VarDeclarations{ + Range: ctxToRange(varsCtx), + } + for _, varDecl := range varsCtx.AllVarDeclaration() { decl := parseVarDeclaration(varDecl) if decl != nil { - vars = append(vars, *decl) + varBlock.Declarations = append(varBlock.Declarations, *decl) } } - return vars + return &varBlock } func parseProgram(programCtx antlrParser.IProgramContext) Program { diff --git a/numscript.go b/numscript.go index a56b02ed..98df032a 100644 --- a/numscript.go +++ b/numscript.go @@ -18,7 +18,11 @@ type ParseResult struct { func (p ParseResult) GetNeededVariables() map[string]string { m := make(map[string]string) - for _, varDecl := range p.parseResult.Value.Vars { + if p.parseResult.Value.Vars == nil { + return m + } + + for _, varDecl := range p.parseResult.Value.Vars.Declarations { if varDecl.Name == nil || varDecl.Origin != nil { continue } diff --git a/numscript_test.go b/numscript_test.go index 9536786f..29c63748 100644 --- a/numscript_test.go +++ b/numscript_test.go @@ -35,6 +35,28 @@ func TestGetVars(t *testing.T) { } +func TestGetVarsEmpty(t *testing.T) { + parseResult := numscript.Parse(` + vars {} +`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + require.Equal(t, + map[string]string{}, + parseResult.GetNeededVariables(), + ) +} + +func TestGetVarsNovars(t *testing.T) { + parseResult := numscript.Parse(``) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + require.Equal(t, + map[string]string{}, + parseResult.GetNeededVariables(), + ) +} + func TestDoNotGetWorldBalance(t *testing.T) { parseResult := numscript.Parse(`send [COIN 100] ( source = @world