From 1e28af8b7af9d51ef85946ff7c6e65440a8bc81f Mon Sep 17 00:00:00 2001 From: ascandone Date: Tue, 16 Sep 2025 16:14:38 +0200 Subject: [PATCH 01/10] union find draft --- internal/analysis/union_find.go | 54 +++++++++++++++++ internal/analysis/union_find_test.go | 87 ++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 internal/analysis/union_find.go create mode 100644 internal/analysis/union_find_test.go diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go new file mode 100644 index 00000000..8db6ce70 --- /dev/null +++ b/internal/analysis/union_find.go @@ -0,0 +1,54 @@ +package analysis + +type Type interface { + Resolve() Type +} + +var _ Type = (*TVar)(nil) +var _ Type = (*Asset)(nil) + +// Impls + +func (t *TVar) Resolve() Type { + if t.resolution == nil { + return t + } + + resolved := t.resolution + + // TODO path compression + return resolved.Resolve() +} + +type TVar struct { + resolution Type +} + +type Asset string + +func (a *Asset) Resolve() Type { + return a +} + +func Unify(t1 Type, t2 Type) (ok bool) { + t1 = t1.Resolve() + t2 = t2.Resolve() + + switch t1 := t1.(type) { + case *Asset: + switch t2 := t2.(type) { + case *Asset: + return string(*t1) == string(*t2) + + case *TVar: + return Unify(t2, t1) + } + + case *TVar: + // t1 is a tvar, so we can always unify it with t2 + t1.resolution = t2 + return true + } + + return false +} diff --git a/internal/analysis/union_find_test.go b/internal/analysis/union_find_test.go new file mode 100644 index 00000000..6261808a --- /dev/null +++ b/internal/analysis/union_find_test.go @@ -0,0 +1,87 @@ +package analysis_test + +import ( + "testing" + + "github.com/formancehq/numscript/internal/analysis" + "github.com/stretchr/testify/require" +) + +func TestResolveConcrete(t *testing.T) { + t1 := analysis.Asset("USD") + out := t1.Resolve() + require.Equal(t, &t1, out) +} + +func TestUnifyConcreteWhenNotSame(t *testing.T) { + t1 := analysis.Asset("USD") + t2 := analysis.Asset("EUR") + ok := analysis.Unify(&t1, &t2) + require.False(t, ok) +} + +func TestUnifyConcreteWhenSame(t *testing.T) { + t1 := analysis.Asset("USD") + t2 := analysis.Asset("USD") + ok := analysis.Unify(&t1, &t2) + require.True(t, ok) +} + +func TestResolveUnbound(t *testing.T) { + t1 := &analysis.TVar{} + require.Same(t, t1.Resolve(), t1) +} + +func TestUnifyVarWithConcrete(t *testing.T) { + t1 := &analysis.TVar{} + t2 := analysis.Asset("USD") + + ok := analysis.Unify(t1, &t2) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t2) +} + +func TestUnifyTransitive(t *testing.T) { + t1 := &analysis.TVar{} + t2 := &analysis.TVar{} + t3 := &analysis.TVar{} + + // t1->t2->t3 + + ok := analysis.Unify(t1, t2) + require.True(t, ok) + + ok = analysis.Unify(t1, t3) + require.True(t, ok) + + t4 := analysis.Asset("USD") + ok = analysis.Unify(t1, &t4) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t4) + require.Same(t, t2.Resolve(), &t4) + require.Same(t, t3.Resolve(), &t4) +} + +func TestUnifyTransitiveInverse(t *testing.T) { + t1 := &analysis.TVar{} + t2 := &analysis.TVar{} + t3 := &analysis.TVar{} + + // t1->t2->t3 + + ok := analysis.Unify(t1, t2) + require.True(t, ok) + + ok = analysis.Unify(t1, t3) + require.True(t, ok) + + t4 := analysis.Asset("USD") + ok = analysis.Unify(t3, &t4) + require.True(t, ok) + + require.Same(t, t1.Resolve(), &t4) + require.Same(t, t2.Resolve(), &t4) + require.Same(t, t3.Resolve(), &t4) +} From 31abef73025a5f6e7a7562267be3e7c7b9468c57 Mon Sep 17 00:00:00 2001 From: ascandone Date: Tue, 16 Sep 2025 16:18:49 +0200 Subject: [PATCH 02/10] perf opt --- internal/analysis/union_find.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go index 8db6ce70..57c4f239 100644 --- a/internal/analysis/union_find.go +++ b/internal/analysis/union_find.go @@ -14,10 +14,12 @@ func (t *TVar) Resolve() Type { return t } - resolved := t.resolution + resolved := t.resolution.Resolve() - // TODO path compression - return resolved.Resolve() + // This bit doesn't change the behaviour but + t.resolution = resolved + + return resolved } type TVar struct { From 4fa77ac9f617e456bf2213bbe6c93d85ed6eb0f4 Mon Sep 17 00:00:00 2001 From: ascandone Date: Fri, 19 Sep 2025 12:24:36 +0200 Subject: [PATCH 03/10] fix cycles in unify --- internal/analysis/union_find.go | 27 +++++++++++++++++++++++++++ internal/analysis/union_find_test.go | 8 ++++++++ 2 files changed, 35 insertions(+) diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go index 57c4f239..5e95c648 100644 --- a/internal/analysis/union_find.go +++ b/internal/analysis/union_find.go @@ -1,5 +1,11 @@ package analysis +import ( + "fmt" + + "github.com/formancehq/numscript/internal/utils" +) + type Type interface { Resolve() Type } @@ -15,6 +21,9 @@ func (t *TVar) Resolve() Type { } resolved := t.resolution.Resolve() + if resolved == t { + return t + } // This bit doesn't change the behaviour but t.resolution = resolved @@ -47,6 +56,11 @@ func Unify(t1 Type, t2 Type) (ok bool) { } case *TVar: + // We must avoid cycles when unifying a var with itself + if t1 == t2 { + return true + } + // t1 is a tvar, so we can always unify it with t2 t1.resolution = t2 return true @@ -54,3 +68,16 @@ func Unify(t1 Type, t2 Type) (ok bool) { return false } + +func TypeToString(r Type) string { + r = r.Resolve() + switch r := r.(type) { + case *TVar: + return fmt.Sprintf("'%p", r) + + case *Asset: + return string(*r) + } + + return utils.NonExhaustiveMatchPanic[string](r) +} diff --git a/internal/analysis/union_find_test.go b/internal/analysis/union_find_test.go index 6261808a..2fc33346 100644 --- a/internal/analysis/union_find_test.go +++ b/internal/analysis/union_find_test.go @@ -27,6 +27,14 @@ func TestUnifyConcreteWhenSame(t *testing.T) { require.True(t, ok) } +func TestUnifyItselfIsNoop(t *testing.T) { + t1 := &analysis.TVar{} + ok := analysis.Unify(t1, t1) + require.True(t, ok) + + require.Same(t, t1.Resolve(), t1) +} + func TestResolveUnbound(t *testing.T) { t1 := &analysis.TVar{} require.Same(t, t1.Resolve(), t1) From a779e27be55d2c9efede4f36dd698af1181f2202 Mon Sep 17 00:00:00 2001 From: ascandone Date: Fri, 19 Sep 2025 12:58:23 +0200 Subject: [PATCH 04/10] improve inference --- internal/analysis/check.go | 68 ++++++++++++++++++++++++++-- internal/analysis/check_test.go | 56 +++++++++++++++++++++++ internal/analysis/diagnostic_kind.go | 13 ++++++ internal/analysis/union_find.go | 12 ++--- internal/analysis/union_find_test.go | 16 +++---- 5 files changed, 148 insertions(+), 17 deletions(-) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index df7d570f..998c0768 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -123,7 +123,6 @@ type Diagnostic struct { } type CheckResult struct { - version parser.Version nextDiagnosticId int32 unboundedAccountInSend parser.ValueExpr emptiedAccount map[string]struct{} @@ -134,6 +133,50 @@ type CheckResult struct { fnCallResolution map[*parser.FnCallIdentifier]FnCallResolution Diagnostics []Diagnostic Program parser.Program + + stmtType Type + ExprTypes map[parser.ValueExpr]Type + VarTypes map[parser.VarDeclaration]Type +} + +func (r *CheckResult) getExprType(expr parser.ValueExpr) Type { + exprType, ok := r.ExprTypes[expr] + if !ok { + t := TVar{} + r.ExprTypes[expr] = &t + return &t + } + return exprType +} + +func (r *CheckResult) getVarDeclType(decl parser.VarDeclaration) Type { + exprType, ok := r.VarTypes[decl] + if !ok { + t := TVar{} + r.VarTypes[decl] = &t + return &t + } + return exprType +} + +func (r *CheckResult) unifyNodeWith(expr parser.ValueExpr, t Type) { + exprT := r.getExprType(expr) + r.unify(expr.GetRange(), exprT, t) +} + +func (r *CheckResult) unify(rng parser.Range, t1 Type, t2 Type) { + ok := Unify(t1, t2) + if ok { + return + } + + r.Diagnostics = append(r.Diagnostics, Diagnostic{ + Range: rng, + Kind: &AssetMismatch{ + Expected: TypeToString(t1), + Got: TypeToString(t2), + }, + }) } func (r CheckResult) GetErrorsCount() int { @@ -174,13 +217,15 @@ func (r CheckResult) ResolveBuiltinFn(v *parser.FnCallIdentifier) FnCallResoluti func newCheckResult(program parser.Program) CheckResult { return CheckResult{ - version: program.GetVersion(), + Program: program, + emptiedAccount: make(map[string]struct{}), declaredVars: make(map[string]parser.VarDeclaration), unusedVars: make(map[string]parser.Range), varResolution: make(map[*parser.Variable]parser.VarDeclaration), fnCallResolution: make(map[*parser.FnCallIdentifier]FnCallResolution), - Program: program, + ExprTypes: make(map[parser.ValueExpr]Type), + VarTypes: make(map[parser.VarDeclaration]Type), } } @@ -214,6 +259,7 @@ func (res *CheckResult) check() { func (res *CheckResult) checkStatement(statement parser.Statement) { res.emptiedAccount = make(map[string]struct{}) + res.stmtType = &TVar{} switch statement := statement.(type) { case *parser.SaveStatement: @@ -364,6 +410,7 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin case *parser.Variable: if varDeclaration, ok := res.declaredVars[lit.Name]; ok { res.varResolution[lit] = varDeclaration + res.unifyNodeWith(lit, res.getVarDeclType(varDeclaration)) } else { res.pushDiagnostic(lit.Range, UnboundVariable{Name: lit.Name, Type: typeHint}) } @@ -378,6 +425,11 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin case *parser.MonetaryLiteral: res.checkExpression(lit.Asset, TypeAsset) res.checkExpression(lit.Amount, TypeNumber) + /* + we unify $mon and $asset in: + `let $mon := [$asset 42]` + */ + res.unifyNodeWith(lit, res.getExprType(lit.Asset)) return TypeMonetary case *parser.BinaryInfix: @@ -415,6 +467,8 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin case *parser.PercentageLiteral: return TypePortion case *parser.AssetLiteral: + t := TAsset(lit.Asset) + res.unifyNodeWith(lit, &t) return TypeAsset case *parser.NumberLiteral: return TypeNumber @@ -461,6 +515,12 @@ func (res *CheckResult) checkSentValue(sentValue parser.SentValue) { res.checkExpression(sentValue.Asset, TypeAsset) case *parser.SentValueLiteral: res.checkExpression(sentValue.Monetary, TypeMonetary) + + res.unifyNodeWith(sentValue.Monetary, res.stmtType) + res.unifyNodeWith( + sentValue.Monetary, + res.stmtType, + ) } } @@ -538,6 +598,8 @@ func (res *CheckResult) checkSource(source parser.Source) { case *parser.SourceCapped: onExit := res.enterCappedSource() + res.unifyNodeWith(source.Cap, res.stmtType) + res.checkExpression(source.Cap, TypeMonetary) res.checkSource(source.From) diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 5a5bf31c..6675fc8f 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -968,3 +968,59 @@ func TestInorderRedundantWhenEmptyColored(t *testing.T) { checkSource(input), ) } + +func TestCheckAssetMismatch(t *testing.T) { + + t.Parallel() + + input := ` + + send [USD 100] ( + source = max [EUR 10] from @a + destination = @dest +)` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "[EUR 10]", 0), + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, + }, + }, + checkSource(input), + ) +} + +func TestCheckAssetMismatchInVar(t *testing.T) { + + t.Parallel() + + input := ` + +vars { + monetary $mon +} + +send [EUR 0] ( + source = max $mon from @a + destination = @b +) + +send [USD 0] ( + source = max $mon from @a + destination = @b +) + +` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$mon", 2), + // TODO shoulnd't the error be the other way around? + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, + }, + }, + checkSource(input), + ) +} diff --git a/internal/analysis/diagnostic_kind.go b/internal/analysis/diagnostic_kind.go index 605042e6..8b2a376b 100644 --- a/internal/analysis/diagnostic_kind.go +++ b/internal/analysis/diagnostic_kind.go @@ -127,6 +127,19 @@ func (TypeMismatch) Severity() Severity { return ErrorSeverity } +type AssetMismatch struct { + Expected string + Got string +} + +func (e AssetMismatch) Message() string { + return fmt.Sprintf("Asset mismatch (expected '%s', got '%s' instead)", e.Expected, e.Got) +} + +func (AssetMismatch) Severity() Severity { + return ErrorSeverity +} + type RemainingIsNotLast struct{} func (e RemainingIsNotLast) Message() string { diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go index 5e95c648..1752ac28 100644 --- a/internal/analysis/union_find.go +++ b/internal/analysis/union_find.go @@ -11,7 +11,7 @@ type Type interface { } var _ Type = (*TVar)(nil) -var _ Type = (*Asset)(nil) +var _ Type = (*TAsset)(nil) // Impls @@ -35,9 +35,9 @@ type TVar struct { resolution Type } -type Asset string +type TAsset string -func (a *Asset) Resolve() Type { +func (a *TAsset) Resolve() Type { return a } @@ -46,9 +46,9 @@ func Unify(t1 Type, t2 Type) (ok bool) { t2 = t2.Resolve() switch t1 := t1.(type) { - case *Asset: + case *TAsset: switch t2 := t2.(type) { - case *Asset: + case *TAsset: return string(*t1) == string(*t2) case *TVar: @@ -75,7 +75,7 @@ func TypeToString(r Type) string { case *TVar: return fmt.Sprintf("'%p", r) - case *Asset: + case *TAsset: return string(*r) } diff --git a/internal/analysis/union_find_test.go b/internal/analysis/union_find_test.go index 2fc33346..91a580a9 100644 --- a/internal/analysis/union_find_test.go +++ b/internal/analysis/union_find_test.go @@ -8,21 +8,21 @@ import ( ) func TestResolveConcrete(t *testing.T) { - t1 := analysis.Asset("USD") + t1 := analysis.TAsset("USD") out := t1.Resolve() require.Equal(t, &t1, out) } func TestUnifyConcreteWhenNotSame(t *testing.T) { - t1 := analysis.Asset("USD") - t2 := analysis.Asset("EUR") + t1 := analysis.TAsset("USD") + t2 := analysis.TAsset("EUR") ok := analysis.Unify(&t1, &t2) require.False(t, ok) } func TestUnifyConcreteWhenSame(t *testing.T) { - t1 := analysis.Asset("USD") - t2 := analysis.Asset("USD") + t1 := analysis.TAsset("USD") + t2 := analysis.TAsset("USD") ok := analysis.Unify(&t1, &t2) require.True(t, ok) } @@ -42,7 +42,7 @@ func TestResolveUnbound(t *testing.T) { func TestUnifyVarWithConcrete(t *testing.T) { t1 := &analysis.TVar{} - t2 := analysis.Asset("USD") + t2 := analysis.TAsset("USD") ok := analysis.Unify(t1, &t2) require.True(t, ok) @@ -63,7 +63,7 @@ func TestUnifyTransitive(t *testing.T) { ok = analysis.Unify(t1, t3) require.True(t, ok) - t4 := analysis.Asset("USD") + t4 := analysis.TAsset("USD") ok = analysis.Unify(t1, &t4) require.True(t, ok) @@ -85,7 +85,7 @@ func TestUnifyTransitiveInverse(t *testing.T) { ok = analysis.Unify(t1, t3) require.True(t, ok) - t4 := analysis.Asset("USD") + t4 := analysis.TAsset("USD") ok = analysis.Unify(t3, &t4) require.True(t, ok) From e8f5a7623422360cec20a182116abf2f0dcdc7fc Mon Sep 17 00:00:00 2001 From: ascandone Date: Fri, 19 Sep 2025 13:44:01 +0200 Subject: [PATCH 05/10] inference of balance() constraints --- internal/analysis/check.go | 14 +++++++++++--- internal/analysis/check_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 998c0768..655a967d 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -242,6 +242,7 @@ func (res *CheckResult) check() { if varDecl.Origin != nil { res.checkExpression(*varDecl.Origin, varDecl.Type.Name) + res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl)) } } } @@ -351,6 +352,13 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { type_ := sig[index] res.checkExpression(arg, type_) } + + if fnCall.Caller.Name == FnVarOriginBalance { + // we run unify(, ) in: + // := balance(@acc, ) + assetArg := validArgs[1] + res.unifyNodeWith(fnCall, res.getExprType(assetArg)) + } } else { for _, arg := range validArgs { res.checkExpression(arg, TypeAny) @@ -382,7 +390,7 @@ func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl pa } } -func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string { +func (res *CheckResult) checkFnCall(fnCall *parser.FnCall) string { returnType := TypeAny if resolution, ok := Builtins[fnCall.Caller.Name]; ok { @@ -395,7 +403,7 @@ func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string { } // this must come after resolution - res.checkFnCallArity(&fnCall) + res.checkFnCallArity(fnCall) return returnType } @@ -476,7 +484,7 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin return TypeString case *parser.FnCall: - return res.checkFnCall(*lit) + return res.checkFnCall(lit) default: return TypeAny diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 6675fc8f..f0d49911 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -1024,3 +1024,28 @@ send [USD 0] ( checkSource(input), ) } + +func TestCheckBalanceAssetConstraint(t *testing.T) { + t.Parallel() + + input := ` +vars { + monetary $mon = balance(@acc, USD/2) +} + +send [USD 42] ( + source = max $mon from @a + destination = @b +) +` + + require.Equal(t, + []analysis.Diagnostic{ + { + Range: parser.RangeOfIndexed(input, "$mon", 1), + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "USD/2"}, + }, + }, + checkSource(input), + ) +} From 16357fafe096590c0f2b03a8c8437216132be888 Mon Sep 17 00:00:00 2001 From: ascandone Date: Fri, 19 Sep 2025 14:11:12 +0200 Subject: [PATCH 06/10] inference of infix --- internal/analysis/check.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 655a967d..47cdcbcc 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -441,6 +441,8 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin return TypeMonetary case *parser.BinaryInfix: + res.unifyNodeWith(lit.Left, res.getExprType(lit.Right)) + switch lit.Operator { case parser.InfixOperatorPlus: return res.checkInfixOverload(lit, []string{TypeNumber, TypeMonetary}) From ddf82c691be1ac91dd64d0a7e779739eba867bcf Mon Sep 17 00:00:00 2001 From: ascandone Date: Wed, 24 Sep 2025 12:40:26 +0200 Subject: [PATCH 07/10] more inference --- internal/analysis/check.go | 20 ++++++++++---------- internal/analysis/check_test.go | 24 ++++++++++++++++++++++++ internal/analysis/document_symbols.go | 2 +- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 47cdcbcc..01a5121e 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -127,7 +127,7 @@ type CheckResult struct { unboundedAccountInSend parser.ValueExpr emptiedAccount map[string]struct{} unboundedSend bool - declaredVars map[string]parser.VarDeclaration + DeclaredVars map[string]parser.VarDeclaration unusedVars map[string]parser.Range varResolution map[*parser.Variable]parser.VarDeclaration fnCallResolution map[*parser.FnCallIdentifier]FnCallResolution @@ -220,7 +220,7 @@ func newCheckResult(program parser.Program) CheckResult { Program: program, emptiedAccount: make(map[string]struct{}), - declaredVars: make(map[string]parser.VarDeclaration), + DeclaredVars: make(map[string]parser.VarDeclaration), unusedVars: make(map[string]parser.Range), varResolution: make(map[*parser.Variable]parser.VarDeclaration), fnCallResolution: make(map[*parser.FnCallIdentifier]FnCallResolution), @@ -382,10 +382,10 @@ func (res *CheckResult) checkVarType(typeDecl parser.TypeDecl) { func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl parser.VarDeclaration) { // check there aren't duplicate variables - if _, ok := res.declaredVars[variableName.Name]; ok { + if _, ok := res.DeclaredVars[variableName.Name]; ok { res.pushDiagnostic(variableName.Range, DuplicateVariable{Name: variableName.Name}) } else { - res.declaredVars[variableName.Name] = decl + res.DeclaredVars[variableName.Name] = decl res.unusedVars[variableName.Name] = variableName.Range } } @@ -416,7 +416,7 @@ func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType strin func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) string { switch lit := lit.(type) { case *parser.Variable: - if varDeclaration, ok := res.declaredVars[lit.Name]; ok { + if varDeclaration, ok := res.DeclaredVars[lit.Name]; ok { res.varResolution[lit] = varDeclaration res.unifyNodeWith(lit, res.getVarDeclType(varDeclaration)) } else { @@ -523,14 +523,10 @@ func (res *CheckResult) checkSentValue(sentValue parser.SentValue) { switch sentValue := sentValue.(type) { case *parser.SentValueAll: res.checkExpression(sentValue.Asset, TypeAsset) + res.unifyNodeWith(sentValue.Asset, res.stmtType) case *parser.SentValueLiteral: res.checkExpression(sentValue.Monetary, TypeMonetary) - res.unifyNodeWith(sentValue.Monetary, res.stmtType) - res.unifyNodeWith( - sentValue.Monetary, - res.stmtType, - ) } } @@ -591,6 +587,7 @@ func (res *CheckResult) checkSource(source parser.Source) { res.checkExpression(source.Color, TypeString) if source.Bounded != nil { res.checkExpression(*source.Bounded, TypeMonetary) + res.unifyNodeWith(*source.Bounded, res.stmtType) } case *parser.SourceInorder: @@ -611,6 +608,7 @@ func (res *CheckResult) checkSource(source parser.Source) { res.unifyNodeWith(source.Cap, res.stmtType) res.checkExpression(source.Cap, TypeMonetary) + res.unifyNodeWith(source.Cap, res.stmtType) res.checkSource(source.From) onExit() @@ -752,6 +750,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) { case *parser.DestinationInorder: for _, clause := range destination.Clauses { res.checkExpression(clause.Cap, TypeMonetary) + res.unifyNodeWith(clause.Cap, res.stmtType) res.checkKeptOrDestination(clause.To) } res.checkKeptOrDestination(destination.Remaining) @@ -761,6 +760,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) { for _, clause := range destination.Clauses { res.checkExpression(clause.Cap, TypeMonetary) + res.unifyNodeWith(clause.Cap, res.stmtType) res.checkKeptOrDestination(clause.To) } res.checkKeptOrDestination(destination.Remaining) diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index f0d49911..256a3480 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -1049,3 +1049,27 @@ send [USD 42] ( checkSource(input), ) } + +func TestInferVars(t *testing.T) { + t.Parallel() + + input := ` +vars { + monetary $mon1 + monetary $mon2 +} + +send $mon1 ( + source = @a allowing overdraft up to $mon2 + destination = @b +) +` + + res := analysis.CheckSource(input) + + t1 := res.VarTypes[res.DeclaredVars["mon1"]] + + t2 := res.VarTypes[res.DeclaredVars["mon2"]] + + require.Same(t, t1.Resolve(), t2.Resolve()) +} diff --git a/internal/analysis/document_symbols.go b/internal/analysis/document_symbols.go index 2b60a863..9b210efe 100644 --- a/internal/analysis/document_symbols.go +++ b/internal/analysis/document_symbols.go @@ -25,7 +25,7 @@ type DocumentSymbol struct { // results are sorted by start position func (r *CheckResult) GetSymbols() []DocumentSymbol { var symbols []DocumentSymbol - for k, v := range r.declaredVars { + for k, v := range r.DeclaredVars { symbols = append(symbols, DocumentSymbol{ Name: k, Kind: DocumentSymbolVariable, From 822694d07204f21eceef48b9d8ea2d92da61acd5 Mon Sep 17 00:00:00 2001 From: ascandone Date: Wed, 24 Sep 2025 12:52:09 +0200 Subject: [PATCH 08/10] more inference --- internal/analysis/check.go | 9 ++++++--- internal/analysis/check_test.go | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 01a5121e..2fd5957e 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -353,11 +353,14 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) { res.checkExpression(arg, type_) } - if fnCall.Caller.Name == FnVarOriginBalance { + switch fnCall.Caller.Name { + case FnVarOriginBalance, FnVarOriginOverdraft: // we run unify(, ) in: // := balance(@acc, ) - assetArg := validArgs[1] - res.unifyNodeWith(fnCall, res.getExprType(assetArg)) + res.unifyNodeWith(fnCall, res.getExprType(validArgs[1])) + + case FnVarOriginGetAsset: + res.unifyNodeWith(fnCall, res.getExprType(validArgs[0])) } } else { for _, arg := range validArgs { diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 256a3480..0b499f37 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -1073,3 +1073,21 @@ send $mon1 ( require.Same(t, t1.Resolve(), t2.Resolve()) } + +func TestInferGetAsset(t *testing.T) { + t.Parallel() + + input := ` +vars { + asset $ass = get_asset([USD/2 100]) +} +` + + res := analysis.CheckSource(input) + + v := res.DeclaredVars["ass"] + t1 := res.VarTypes[v] + + expected := analysis.TAsset("USD/2") + require.Equal(t, &expected, t1.Resolve()) +} From 1bff3bdf451279dba49c72085a030bd3941f5008 Mon Sep 17 00:00:00 2001 From: ascandone Date: Thu, 2 Oct 2025 15:48:57 +0200 Subject: [PATCH 09/10] chore: fix comment --- internal/analysis/union_find.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/analysis/union_find.go b/internal/analysis/union_find.go index 1752ac28..4fd7086f 100644 --- a/internal/analysis/union_find.go +++ b/internal/analysis/union_find.go @@ -25,7 +25,8 @@ func (t *TVar) Resolve() Type { return t } - // This bit doesn't change the behaviour but + // This bit doesn't change the behaviour but allows to return the path right away + // the next time we call Resolve() t.resolution = resolved return resolved From 50fad85906183a5ee22f2b4cf7bdb816ffdef3ce Mon Sep 17 00:00:00 2001 From: ascandone Date: Thu, 2 Oct 2025 15:50:39 +0200 Subject: [PATCH 10/10] fix: remove unnecessary call to unify() --- internal/analysis/check.go | 2 -- internal/analysis/check_test.go | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/internal/analysis/check.go b/internal/analysis/check.go index 2fd5957e..9c755800 100644 --- a/internal/analysis/check.go +++ b/internal/analysis/check.go @@ -609,9 +609,7 @@ func (res *CheckResult) checkSource(source parser.Source) { onExit := res.enterCappedSource() res.unifyNodeWith(source.Cap, res.stmtType) - res.checkExpression(source.Cap, TypeMonetary) - res.unifyNodeWith(source.Cap, res.stmtType) res.checkSource(source.From) onExit() diff --git a/internal/analysis/check_test.go b/internal/analysis/check_test.go index 0b499f37..7cb96553 100644 --- a/internal/analysis/check_test.go +++ b/internal/analysis/check_test.go @@ -1017,8 +1017,7 @@ send [USD 0] ( []analysis.Diagnostic{ { Range: parser.RangeOfIndexed(input, "$mon", 2), - // TODO shoulnd't the error be the other way around? - Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, + Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"}, }, }, checkSource(input),