Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 84 additions & 11 deletions internal/analysis/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,60 @@ type Diagnostic struct {
}

type CheckResult struct {
version parser.Version
nextDiagnosticId int32
unboundedAccountInSend parser.ValueExpr
emptiedAccount map[string]struct{}
unboundedSend bool
declaredVars map[string]parser.VarDeclaration
DeclaredVars map[string]parser.VarDeclaration
unusedVars map[string]parser.Range
varResolution map[*parser.Variable]parser.VarDeclaration
fnCallResolution map[*parser.FnCallIdentifier]FnCallResolution
Diagnostics []Diagnostic
Program parser.Program

stmtType Type
ExprTypes map[parser.ValueExpr]Type
VarTypes map[parser.VarDeclaration]Type
}

func (r *CheckResult) getExprType(expr parser.ValueExpr) Type {
exprType, ok := r.ExprTypes[expr]
if !ok {
t := TVar{}
r.ExprTypes[expr] = &t
return &t
}
return exprType
}

func (r *CheckResult) getVarDeclType(decl parser.VarDeclaration) Type {
exprType, ok := r.VarTypes[decl]
if !ok {
t := TVar{}
r.VarTypes[decl] = &t
return &t
}
return exprType
}

func (r *CheckResult) unifyNodeWith(expr parser.ValueExpr, t Type) {
exprT := r.getExprType(expr)
r.unify(expr.GetRange(), exprT, t)
}

func (r *CheckResult) unify(rng parser.Range, t1 Type, t2 Type) {
ok := Unify(t1, t2)
if ok {
return
}

r.Diagnostics = append(r.Diagnostics, Diagnostic{
Range: rng,
Kind: &AssetMismatch{
Expected: TypeToString(t1),
Got: TypeToString(t2),
},
})
}

func (r CheckResult) GetErrorsCount() int {
Expand Down Expand Up @@ -174,13 +217,15 @@ func (r CheckResult) ResolveBuiltinFn(v *parser.FnCallIdentifier) FnCallResoluti

func newCheckResult(program parser.Program) CheckResult {
return CheckResult{
version: program.GetVersion(),
Program: program,

emptiedAccount: make(map[string]struct{}),
declaredVars: make(map[string]parser.VarDeclaration),
DeclaredVars: make(map[string]parser.VarDeclaration),
unusedVars: make(map[string]parser.Range),
varResolution: make(map[*parser.Variable]parser.VarDeclaration),
fnCallResolution: make(map[*parser.FnCallIdentifier]FnCallResolution),
Program: program,
ExprTypes: make(map[parser.ValueExpr]Type),
VarTypes: make(map[parser.VarDeclaration]Type),
}
}

Expand All @@ -197,6 +242,7 @@ func (res *CheckResult) check() {

if varDecl.Origin != nil {
res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl))
}
Comment on lines 243 to 246
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard against nil Type when Origin is present

If a variable has an origin expression but no explicit type, dereferencing varDecl.Type panics. Fall back to TypeAny when Type is nil.

Apply this diff:

-			if varDecl.Origin != nil {
-				res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
-				res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl))
-			}
+			if varDecl.Origin != nil {
+				if varDecl.Type != nil {
+					res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
+				} else {
+					res.checkExpression(*varDecl.Origin, TypeAny)
+				}
+				res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl))
+			}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if varDecl.Origin != nil {
res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl))
}
if varDecl.Origin != nil {
if varDecl.Type != nil {
res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
} else {
res.checkExpression(*varDecl.Origin, TypeAny)
}
res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl))
}
🤖 Prompt for AI Agents
In internal/analysis/check.go around lines 243 to 246, guard against a nil
varDecl.Type when varDecl.Origin is present by using TypeAny as a fallback:
compute a local typ := varDecl.Type and if typ == nil set typ = TypeAny, then
call res.checkExpression(*varDecl.Origin, typ.Name) and
res.unifyNodeWith(*varDecl.Origin, res.getVarDeclType(varDecl)) (or pass typ
into getVarDeclType if appropriate) so no dereference of nil occurs.

}
}
Expand All @@ -214,6 +260,7 @@ func (res *CheckResult) check() {

func (res *CheckResult) checkStatement(statement parser.Statement) {
res.emptiedAccount = make(map[string]struct{})
res.stmtType = &TVar{}

switch statement := statement.(type) {
case *parser.SaveStatement:
Expand Down Expand Up @@ -305,6 +352,16 @@ func (res *CheckResult) checkFnCallArity(fnCall *parser.FnCall) {
type_ := sig[index]
res.checkExpression(arg, type_)
}

switch fnCall.Caller.Name {
case FnVarOriginBalance, FnVarOriginOverdraft:
// we run unify(<expr>, <asset>) in:
// <expr> := balance(@acc, <asset>)
res.unifyNodeWith(fnCall, res.getExprType(validArgs[1]))

case FnVarOriginGetAsset:
res.unifyNodeWith(fnCall, res.getExprType(validArgs[0]))
}
} else {
for _, arg := range validArgs {
res.checkExpression(arg, TypeAny)
Expand All @@ -328,15 +385,15 @@ func (res *CheckResult) checkVarType(typeDecl parser.TypeDecl) {

func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl parser.VarDeclaration) {
// check there aren't duplicate variables
if _, ok := res.declaredVars[variableName.Name]; ok {
if _, ok := res.DeclaredVars[variableName.Name]; ok {
res.pushDiagnostic(variableName.Range, DuplicateVariable{Name: variableName.Name})
} else {
res.declaredVars[variableName.Name] = decl
res.DeclaredVars[variableName.Name] = decl
res.unusedVars[variableName.Name] = variableName.Range
}
}

func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string {
func (res *CheckResult) checkFnCall(fnCall *parser.FnCall) string {
returnType := TypeAny

if resolution, ok := Builtins[fnCall.Caller.Name]; ok {
Expand All @@ -349,7 +406,7 @@ func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string {
}

// this must come after resolution
res.checkFnCallArity(&fnCall)
res.checkFnCallArity(fnCall)

return returnType
}
Expand All @@ -362,8 +419,9 @@ func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType strin
func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) string {
switch lit := lit.(type) {
case *parser.Variable:
if varDeclaration, ok := res.declaredVars[lit.Name]; ok {
if varDeclaration, ok := res.DeclaredVars[lit.Name]; ok {
res.varResolution[lit] = varDeclaration
res.unifyNodeWith(lit, res.getVarDeclType(varDeclaration))
} else {
res.pushDiagnostic(lit.Range, UnboundVariable{Name: lit.Name, Type: typeHint})
}
Expand All @@ -378,9 +436,16 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin
case *parser.MonetaryLiteral:
res.checkExpression(lit.Asset, TypeAsset)
res.checkExpression(lit.Amount, TypeNumber)
/*
we unify $mon and $asset in:
`let $mon := [$asset 42]`
*/
res.unifyNodeWith(lit, res.getExprType(lit.Asset))
return TypeMonetary

case *parser.BinaryInfix:
res.unifyNodeWith(lit.Left, res.getExprType(lit.Right))

switch lit.Operator {
case parser.InfixOperatorPlus:
return res.checkInfixOverload(lit, []string{TypeNumber, TypeMonetary})
Expand Down Expand Up @@ -415,14 +480,16 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin
case *parser.PercentageLiteral:
return TypePortion
case *parser.AssetLiteral:
t := TAsset(lit.Asset)
res.unifyNodeWith(lit, &t)
return TypeAsset
case *parser.NumberLiteral:
return TypeNumber
case *parser.StringLiteral:
return TypeString

case *parser.FnCall:
return res.checkFnCall(*lit)
return res.checkFnCall(lit)

default:
return TypeAny
Expand Down Expand Up @@ -459,8 +526,10 @@ func (res *CheckResult) checkSentValue(sentValue parser.SentValue) {
switch sentValue := sentValue.(type) {
case *parser.SentValueAll:
res.checkExpression(sentValue.Asset, TypeAsset)
res.unifyNodeWith(sentValue.Asset, res.stmtType)
case *parser.SentValueLiteral:
res.checkExpression(sentValue.Monetary, TypeMonetary)
res.unifyNodeWith(sentValue.Monetary, res.stmtType)
}
}

Expand Down Expand Up @@ -521,6 +590,7 @@ func (res *CheckResult) checkSource(source parser.Source) {
res.checkExpression(source.Color, TypeString)
if source.Bounded != nil {
res.checkExpression(*source.Bounded, TypeMonetary)
res.unifyNodeWith(*source.Bounded, res.stmtType)
}

case *parser.SourceInorder:
Expand All @@ -538,6 +608,7 @@ func (res *CheckResult) checkSource(source parser.Source) {
case *parser.SourceCapped:
onExit := res.enterCappedSource()

res.unifyNodeWith(source.Cap, res.stmtType)
res.checkExpression(source.Cap, TypeMonetary)
res.checkSource(source.From)

Expand Down Expand Up @@ -680,6 +751,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) {
case *parser.DestinationInorder:
for _, clause := range destination.Clauses {
res.checkExpression(clause.Cap, TypeMonetary)
res.unifyNodeWith(clause.Cap, res.stmtType)
res.checkKeptOrDestination(clause.To)
}
res.checkKeptOrDestination(destination.Remaining)
Expand All @@ -689,6 +761,7 @@ func (res *CheckResult) checkDestination(destination parser.Destination) {

for _, clause := range destination.Clauses {
res.checkExpression(clause.Cap, TypeMonetary)
res.unifyNodeWith(clause.Cap, res.stmtType)
res.checkKeptOrDestination(clause.To)
}
res.checkKeptOrDestination(destination.Remaining)
Expand Down
122 changes: 122 additions & 0 deletions internal/analysis/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -968,3 +968,125 @@ func TestInorderRedundantWhenEmptyColored(t *testing.T) {
checkSource(input),
)
}

func TestCheckAssetMismatch(t *testing.T) {

t.Parallel()

input := `

send [USD 100] (
source = max [EUR 10] from @a
destination = @dest
)`

require.Equal(t,
[]analysis.Diagnostic{
{
Range: parser.RangeOfIndexed(input, "[EUR 10]", 0),
Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"},
},
},
checkSource(input),
)
}

func TestCheckAssetMismatchInVar(t *testing.T) {

t.Parallel()

input := `

vars {
monetary $mon
}

send [EUR 0] (
source = max $mon from @a
destination = @b
)

send [USD 0] (
source = max $mon from @a
destination = @b
)

`

require.Equal(t,
[]analysis.Diagnostic{
{
Range: parser.RangeOfIndexed(input, "$mon", 2),
Kind: &analysis.AssetMismatch{Expected: "USD", Got: "EUR"},
},
},
checkSource(input),
)
}

func TestCheckBalanceAssetConstraint(t *testing.T) {
t.Parallel()

input := `
vars {
monetary $mon = balance(@acc, USD/2)
}

send [USD 42] (
source = max $mon from @a
destination = @b
)
`

require.Equal(t,
[]analysis.Diagnostic{
{
Range: parser.RangeOfIndexed(input, "$mon", 1),
Kind: &analysis.AssetMismatch{Expected: "USD", Got: "USD/2"},
},
},
checkSource(input),
)
}

func TestInferVars(t *testing.T) {
t.Parallel()

input := `
vars {
monetary $mon1
monetary $mon2
}

send $mon1 (
source = @a allowing overdraft up to $mon2
destination = @b
)
`

res := analysis.CheckSource(input)

t1 := res.VarTypes[res.DeclaredVars["mon1"]]

t2 := res.VarTypes[res.DeclaredVars["mon2"]]

require.Same(t, t1.Resolve(), t2.Resolve())
}

func TestInferGetAsset(t *testing.T) {
t.Parallel()

input := `
vars {
asset $ass = get_asset([USD/2 100])
}
`

res := analysis.CheckSource(input)

v := res.DeclaredVars["ass"]
t1 := res.VarTypes[v]

expected := analysis.TAsset("USD/2")
require.Equal(t, &expected, t1.Resolve())
}
13 changes: 13 additions & 0 deletions internal/analysis/diagnostic_kind.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ func (TypeMismatch) Severity() Severity {
return ErrorSeverity
}

type AssetMismatch struct {
Expected string
Got string
}

func (e AssetMismatch) Message() string {
return fmt.Sprintf("Asset mismatch (expected '%s', got '%s' instead)", e.Expected, e.Got)
}

func (AssetMismatch) Severity() Severity {
return ErrorSeverity
}

type RemainingIsNotLast struct{}

func (e RemainingIsNotLast) Message() string {
Expand Down
2 changes: 1 addition & 1 deletion internal/analysis/document_symbols.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type DocumentSymbol struct {
// results are sorted by start position
func (r *CheckResult) GetSymbols() []DocumentSymbol {
var symbols []DocumentSymbol
for k, v := range r.declaredVars {
for k, v := range r.DeclaredVars {
symbols = append(symbols, DocumentSymbol{
Name: k,
Kind: DocumentSymbolVariable,
Expand Down
Loading
Loading