diff --git a/internal/analysis/check.go b/internal/analysis/check.go index df7d570f..9c755800 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -123,17 +123,60 @@ type Diagnostic struct { } type CheckResult struct { - version parser.Version nextDiagnosticId int32 unboundedAccountInSend parser.ValueExpr emptiedAccount map[string]struct{} unboundedSend bool - declaredVars map[string]parser.VarDeclaration + DeclaredVars map[string]parser.VarDeclaration unusedVars map[string]parser.Range varResolution map[*parser.Variable]parser.VarDeclaration fnCallResolution map[*parser.FnCallIdentifier]FnCallResolution Diagnostics []Diagnostic Program parser.Program + + stmtType Type + ExprTypes map[parser.ValueExpr]Type + VarTypes map[parser.VarDeclaration]Type +} + +func (r *CheckResult) getExprType(expr parser.ValueExpr) Type { + exprType, ok := r.ExprTypes[expr] + if !ok { + t := TVar{} + r.ExprTypes[expr] = &t + return &t + } + return exprType +} + +func (r *CheckResult) getVarDeclType(decl parser.VarDeclaration) Type { + exprType, ok := r.VarTypes[decl] + if !ok { + t := TVar{} + r.VarTypes[decl] = &t + return &t + } + return exprType +} + +func (r *CheckResult) unifyNodeWith(expr parser.ValueExpr, t Type) { + exprT := r.getExprType(expr) + r.unify(expr.GetRange(), exprT, t) +} + +func (r *CheckResult) unify(rng parser.Range, t1 Type, t2 Type) { + ok := Unify(t1, t2) + if ok { + return + } + + r.Diagnostics = append(r.Diagnostics, Diagnostic{ + Range: rng, + Kind: &AssetMismatch{ + Expected: TypeToString(t1), + Got: TypeToString(t2), + }, + }) } func (r CheckResult) GetErrorsCount() int { @@ -174,13 +217,15 @@ func (r CheckResult) ResolveBuiltinFn(v *parser.FnCallIdentifier) FnCallResoluti func newCheckResult(program parser.Program) CheckResult { return CheckResult{ - version: program.GetVersion(), + Program: program, + emptiedAccount: make(map[string]struct{}), - declaredVars: make(map[string]parser.VarDeclaration), + DeclaredVars: make(map[string]parser.VarDeclaration), unusedVars: make(map[string]parser.Range), varResolution: make(map[*parser.Variable]parser.VarDeclaration), fnCallResolution: make(map[*parser.FnCallIdentifier]FnCallResolution), - Program: program, + ExprTypes: make(map[parser.ValueExpr]Type), + VarTypes: make(map[parser.VarDeclaration]Type), } } @@ -197,6 +242,7 @@ func (res *CheckResult) check() { if varDecl.Origin != nil { res.checkExpression(*varDecl.Origin, varDecl.Type.Name) + res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl)) } } } @@ -214,6 +260,7 @@ func (res *CheckResult) check() { func (res *CheckResult) checkStatement(statement parser.Statement) { res.emptiedAccount = make(map[string]struct{}) + res.stmtType = &TVar{} switch statement := statement.(type) { case *parser.SaveStatement: @@ -305,6 +352,16 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { type_ := sig[index] res.checkExpression(arg, type_) } + + switch fnCall.Caller.Name { + case FnVarOriginBalance, FnVarOriginOverdraft: + // we run unify(, ) in: + // := balance(@acc, ) + res.unifyNodeWith(fnCall, res.getExprType(validArgs[1])) + + case FnVarOriginGetAsset: + res.unifyNodeWith(fnCall, res.getExprType(validArgs[0])) + } } else { for _, arg := range validArgs { res.checkExpression(arg, TypeAny) @@ -328,15 +385,15 @@ func (res *CheckResult) checkVarType(typeDecl parser.TypeDecl) { func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl parser.VarDeclaration) { // check there aren't duplicate variables - if _, ok := res.declaredVars[variableName.Name]; ok { + if _, ok := res.DeclaredVars[variableName.Name]; ok { res.pushDiagnostic(variableName.Range, DuplicateVariable{Name: variableName.Name}) } else { - res.declaredVars[variableName.Name] = decl + res.DeclaredVars[variableName.Name] = decl res.unusedVars[variableName.Name] = variableName.Range } } -func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string { +func (res *CheckResult) checkFnCall(fnCall *parser.FnCall) string { returnType := TypeAny if resolution, ok := Builtins[fnCall.Caller.Name]; ok { @@ -349,7 +406,7 @@ func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string { } // this must come after resolution - res.checkFnCallArity(&fnCall) + res.checkFnCallArity(fnCall) return returnType } @@ -362,8 +419,9 @@ func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType strin func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) string { switch lit := lit.(type) { case *parser.Variable: - if varDeclaration, ok := res.declaredVars[lit.Name]; ok { + if varDeclaration, ok := res.DeclaredVars[lit.Name]; ok { res.varResolution[lit] = varDeclaration + res.unifyNodeWith(lit, res.getVarDeclType(varDeclaration)) } else { res.pushDiagnostic(lit.Range, UnboundVariable{Name: lit.Name, Type: typeHint}) } @@ -378,9 +436,16 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin case *parser.MonetaryLiteral: res.checkExpression(lit.Asset, TypeAsset) res.checkExpression(lit.Amount, TypeNumber) + /* + we unify $mon and $asset in: + `let $mon := [$asset 42]` + */ + res.unifyNodeWith(lit, res.getExprType(lit.Asset)) return TypeMonetary case *parser.BinaryInfix: + res.unifyNodeWith(lit.Left, res.getExprType(lit.Right)) + switch lit.Operator { case parser.InfixOperatorPlus: return res.checkInfixOverload(lit, []string{TypeNumber, TypeMonetary}) @@ -415,6 +480,8 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin case *parser.PercentageLiteral: return TypePortion case *parser.AssetLiteral: + t := TAsset(lit.Asset) + res.unifyNodeWith(lit, &t) return TypeAsset case *parser.NumberLiteral: return TypeNumber @@ -422,7 +489,7 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin return TypeString case *parser.FnCall: - return res.checkFnCall(*lit) + return res.checkFnCall(lit) default: return TypeAny @@ -459,8 +526,10 @@ func (res *CheckResult) checkSentValue(sentValue parser.SentValue) { switch sentValue := sentValue.(type) { case *parser.SentValueAll: res.checkExpression(sentValue.Asset, TypeAsset) + res.unifyNodeWith(sentValue.Asset, res.stmtType) case *parser.SentValueLiteral: res.checkExpression(sentValue.Monetary, TypeMonetary) + res.unifyNodeWith(sentValue.Monetary, res.stmtType) } } @@ -521,6 +590,7 @@ func (res *CheckResult) checkSource(source parser.Source) { res.checkExpression(source.Color, TypeString) if source.Bounded != nil { res.checkExpression(*source.Bounded, TypeMonetary) + res.unifyNodeWith(*source.Bounded, res.stmtType) } case *parser.SourceInorder: @@ -538,6 +608,7 @@ func (res *CheckResult) checkSource(source parser.Source) { case *parser.SourceCapped: onExit := res.enterCappedSource() + res.unifyNodeWith(source.Cap, res.stmtType) res.checkExpression(source.Cap, TypeMonetary) res.checkSource(source.From) @@ -680,6 +751,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) { case *parser.DestinationInorder: for _, clause := range destination.Clauses { res.checkExpression(clause.Cap, TypeMonetary) + res.unifyNodeWith(clause.Cap, res.stmtType) res.checkKeptOrDestination(clause.To) } res.checkKeptOrDestination(destination.Remaining) @@ -689,6 +761,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) { for _, clause := range destination.Clauses { res.checkExpression(clause.Cap, TypeMonetary) + res.unifyNodeWith(clause.Cap, res.stmtType) res.checkKeptOrDestination(clause.To) } res.checkKeptOrDestination(destination.Remaining) diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 5a5bf31c..7cb96553 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -968,3 +968,125 @@ func TestInorderRedundantWhenEmptyColored(t *testing.T) { checkSource(input), ) } + +func TestCheckAssetMismatch(t *testing.T) { + + t.Parallel() + + input := ` + + send [USD 100] ( + source = max [EUR 10] from @a + destination = @dest +)` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "[EUR 10]", 0), + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, + }, + }, + checkSource(input), + ) +} + +func TestCheckAssetMismatchInVar(t *testing.T) { + + t.Parallel() + + input := ` + +vars { + monetary $mon +} + +send [EUR 0] ( + source = max $mon from @a + destination = @b +) + +send [USD 0] ( + source = max $mon from @a + destination = @b +) + +` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$mon", 2), + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, + }, + }, + checkSource(input), + ) +} + +func TestCheckBalanceAssetConstraint(t *testing.T) { + t.Parallel() + + input := ` +vars { + monetary $mon = balance(@acc, USD/2) +} + +send [USD 42] ( + source = max $mon from @a + destination = @b +) +` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$mon", 1), + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "USD/2"}, + }, + }, + checkSource(input), + ) +} + +func TestInferVars(t *testing.T) { + t.Parallel() + + input := ` +vars { + monetary $mon1 + monetary $mon2 +} + +send $mon1 ( + source = @a allowing overdraft up to $mon2 + destination = @b +) +` + + res := analysis.CheckSource(input) + + t1 := res.VarTypes[res.DeclaredVars["mon1"]] + + t2 := res.VarTypes[res.DeclaredVars["mon2"]] + + require.Same(t, t1.Resolve(), t2.Resolve()) +} + +func TestInferGetAsset(t *testing.T) { + t.Parallel() + + input := ` +vars { + asset $ass = get_asset([USD/2 100]) +} +` + + res := analysis.CheckSource(input) + + v := res.DeclaredVars["ass"] + t1 := res.VarTypes[v] + + expected := analysis.TAsset("USD/2") + require.Equal(t, &expected, t1.Resolve()) +} diff --git a/internal/analysis/diagnostic_kind.go b/internal/analysis/diagnostic_kind.go index 605042e6..8b2a376b 100644 --- a/internal/analysis/diagnostic_kind.go +++ b/internal/analysis/diagnostic_kind.go @@ -127,6 +127,19 @@ func (TypeMismatch) Severity() Severity { return ErrorSeverity } +type AssetMismatch struct { + Expected string + Got string +} + +func (e AssetMismatch) Message() string { + return fmt.Sprintf("Asset mismatch (expected '%s', got '%s' instead)", e.Expected, e.Got) +} + +func (AssetMismatch) Severity() Severity { + return ErrorSeverity +} + type RemainingIsNotLast struct{} func (e RemainingIsNotLast) Message() string { diff --git a/internal/analysis/document_symbols.go b/internal/analysis/document_symbols.go index 2b60a863..9b210efe 100644 --- a/internal/analysis/document_symbols.go +++ b/internal/analysis/document_symbols.go @@ -25,7 +25,7 @@ type DocumentSymbol struct { // results are sorted by start position func (r *CheckResult) GetSymbols() []DocumentSymbol { var symbols []DocumentSymbol - for k, v := range r.declaredVars { + for k, v := range r.DeclaredVars { symbols = append(symbols, DocumentSymbol{ Name: k, Kind: DocumentSymbolVariable, diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go new file mode 100644 index 00000000..4fd7086f --- /dev/null +++ b/internal/analysis/union_find.go @@ -0,0 +1,84 @@ +package analysis + +import ( + "fmt" + + "github.com/formancehq/numscript/internal/utils" +) + +type Type interface { + Resolve() Type +} + +var _ Type = (*TVar)(nil) +var _ Type = (*TAsset)(nil) + +// Impls + +func (t *TVar) Resolve() Type { + if t.resolution == nil { + return t + } + + resolved := t.resolution.Resolve() + if resolved == t { + return t + } + + // This bit doesn't change the behaviour but allows to return the path right away + // the next time we call Resolve() + t.resolution = resolved + + return resolved +} + +type TVar struct { + resolution Type +} + +type TAsset string + +func (a *TAsset) Resolve() Type { + return a +} + +func Unify(t1 Type, t2 Type) (ok bool) { + t1 = t1.Resolve() + t2 = t2.Resolve() + + switch t1 := t1.(type) { + case *TAsset: + switch t2 := t2.(type) { + case *TAsset: + return string(*t1) == string(*t2) + + case *TVar: + return Unify(t2, t1) + } + + case *TVar: + // We must avoid cycles when unifying a var with itself + if t1 == t2 { + return true + } + + // t1 is a tvar, so we can always unify it with t2 + t1.resolution = t2 + return true + } + + return false +} + +func TypeToString(r Type) string { + r = r.Resolve() + switch r := r.(type) { + case *TVar: + return fmt.Sprintf("'%p", r) + + case *TAsset: + return string(*r) + } + + return utils.NonExhaustiveMatchPanic[string](r) +} diff --git a/internal/analysis/union_find_test.go b/internal/analysis/union_find_test.go new file mode 100644 index 00000000..91a580a9 --- /dev/null +++ b/internal/analysis/union_find_test.go @@ -0,0 +1,95 @@ +package analysis_test + +import ( + "testing" + + "github.com/formancehq/numscript/internal/analysis" + "github.com/stretchr/testify/require" +) + +func TestResolveConcrete(t *testing.T) { + t1 := analysis.TAsset("USD") + out := t1.Resolve() + require.Equal(t, &t1, out) +} + +func TestUnifyConcreteWhenNotSame(t *testing.T) { + t1 := analysis.TAsset("USD") + t2 := analysis.TAsset("EUR") + ok := analysis.Unify(&t1, &t2) + require.False(t, ok) +} + +func TestUnifyConcreteWhenSame(t *testing.T) { + t1 := analysis.TAsset("USD") + t2 := analysis.TAsset("USD") + ok := analysis.Unify(&t1, &t2) + require.True(t, ok) +} + +func TestUnifyItselfIsNoop(t *testing.T) { + t1 := &analysis.TVar{} + ok := analysis.Unify(t1, t1) + require.True(t, ok) + + require.Same(t, t1.Resolve(), t1) +} + +func TestResolveUnbound(t *testing.T) { + t1 := &analysis.TVar{} + require.Same(t, t1.Resolve(), t1) +} + +func TestUnifyVarWithConcrete(t *testing.T) { + t1 := &analysis.TVar{} + t2 := analysis.TAsset("USD") + + ok := analysis.Unify(t1, &t2) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t2) +} + +func TestUnifyTransitive(t *testing.T) { + t1 := &analysis.TVar{} + t2 := &analysis.TVar{} + t3 := &analysis.TVar{} + + // t1->t2->t3 + + ok := analysis.Unify(t1, t2) + require.True(t, ok) + + ok = analysis.Unify(t1, t3) + require.True(t, ok) + + t4 := analysis.TAsset("USD") + ok = analysis.Unify(t1, &t4) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t4) + require.Same(t, t2.Resolve(), &t4) + require.Same(t, t3.Resolve(), &t4) +} + +func TestUnifyTransitiveInverse(t *testing.T) { + t1 := &analysis.TVar{} + t2 := &analysis.TVar{} + t3 := &analysis.TVar{} + + // t1->t2->t3 + + ok := analysis.Unify(t1, t2) + require.True(t, ok) + + ok = analysis.Unify(t1, t3) + require.True(t, ok) + + t4 := analysis.TAsset("USD") + ok = analysis.Unify(t3, &t4) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t4) + require.Same(t, t2.Resolve(), &t4) + require.Same(t, t3.Resolve(), &t4) +}