Skip to content

Commit

Permalink
added support for unnamed returns
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl committed May 21, 2024
1 parent b993e6a commit 5590381
Show file tree
Hide file tree
Showing 17 changed files with 1,760 additions and 1,447 deletions.
35 changes: 29 additions & 6 deletions core/types/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ type Attribute struct {
func (a *Attribute) Clean(col *Column) error {
switch a.Type {
case MIN, MAX:
if !col.Type.Equals(IntType) {
if !col.Type.EqualsStrict(IntType) {
return fmt.Errorf("attribute %s is only valid for int columns", a.Type)
}
case MIN_LENGTH, MAX_LENGTH:
if !col.Type.Equals(TextType) {
if !col.Type.EqualsStrict(TextType) {
return fmt.Errorf("attribute %s is only valid for text columns", a.Type)
}
}
Expand Down Expand Up @@ -1208,10 +1208,11 @@ func (c *DataType) Copy() *DataType {
}
}

// Equals returns true if the type is equal to the other type.
// If either type is Unknown, it will return true.
func (c *DataType) Equals(other *DataType) bool {
// if unknown, return true
// EqualsStrict returns true if the type is equal to the other type.
// The types must be exactly the same, including metadata.
func (c *DataType) EqualsStrict(other *DataType) bool {
// if unknown, return true. unknown is a special case used
// internally when type checking is disabled.
if c.Name == unknownStr || other.Name == unknownStr {
return true
}
Expand All @@ -1227,6 +1228,15 @@ func (c *DataType) Equals(other *DataType) bool {
return strings.EqualFold(c.Name, other.Name)
}

// Equals returns true if the type is equal to the other type, or if either type is null.
func (c *DataType) Equals(other *DataType) bool {
if c.Name == nullStr || other.Name == nullStr {
return true
}

return c.EqualsStrict(other)
}

func (c *DataType) IsNumeric() bool {
return c.Name == intStr || c.Name == DecimalStr || c.Name == uint256Str || c.Name == unknownStr
}
Expand Down Expand Up @@ -1263,6 +1273,19 @@ var (
}
)

// ArrayType creates an array type of the given type.
// It panics if the type is already an array.
func ArrayType(t *DataType) *DataType {
if t.IsArray {
panic("cannot create an array of an array")
}
return &DataType{
Name: t.Name,
IsArray: true,
Metadata: t.Metadata,
}
}

const (
textStr = "text"
intStr = "int"
Expand Down
10 changes: 5 additions & 5 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,15 @@ func (p *preparedProcedure) coerceInputs(inputs []any) ([]any, error) {
panic("passed array to coerceScalar")
}

if typ.Equals(types.IntType) {
if typ.EqualsStrict(types.IntType) {
return conv.Int(val)
} else if typ.Equals(types.TextType) {
} else if typ.EqualsStrict(types.TextType) {
return conv.String(val)
} else if typ.Equals(types.BoolType) {
} else if typ.EqualsStrict(types.BoolType) {
return conv.Bool(val)
} else if typ.Equals(types.BlobType) {
} else if typ.EqualsStrict(types.BlobType) {
return conv.Blob(val)
} else if typ.Equals(types.UUIDType) {
} else if typ.EqualsStrict(types.UUIDType) {
return conv.UUID(val)
}

Expand Down
4 changes: 0 additions & 4 deletions parse/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ package parse

import "github.com/kwilteam/kwil-db/core/types"

// TODO: this does not actually do what we think it does.
// Recursive calls in the embedded sqlAnalyzer will call local sqlAnalyzer methods.
// This means that users can call unallowed expressions in actions.
// actionAnalyzer analyzes actions.
type actionAnalyzer struct {
sqlAnalyzer
// Mutative is true if the action mutates state.
Expand Down
25 changes: 19 additions & 6 deletions parse/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,13 @@ type sqlAnalyzer struct {
sqlResult sqlAnalyzeResult
}

// reset resets the sqlAnalyzer.
func (s *sqlAnalyzer) reset() {
// we don't need to touch the block context, since it does not change here.
s.sqlCtx = newSQLContext()
s.sqlResult = sqlAnalyzeResult{}
}

type sqlAnalyzeResult struct {
Mutative bool
}
Expand Down Expand Up @@ -1133,9 +1140,15 @@ func (s *sqlAnalyzer) VisitExpressionCase(p0 *ExpressionCase) any {
return s.expressionTypeErr(w[1])
}

// if return type is not set, set it to the first then
if returnType == nil {
returnType = then
}
// if the return type is of type null, we should keep trying
// to reset until we get a non-null type
if returnType.EqualsStrict(types.NullType) {
returnType = then
}

if !then.Equals(returnType) {
return s.typeErr(w[1], then, returnType)
Expand Down Expand Up @@ -1763,11 +1776,11 @@ func (s *sqlAnalyzer) VisitResultColumnExpression(p0 *ResultColumnExpression) an
// is a column.
if attr.Name == "" {
col, ok := p0.Expression.(*ExpressionColumn)
if !ok {
s.errs.AddErr(p0, ErrUnnamedResultColumn, "results must either be column references or have an alias")
// if returning a column and not aliased, we give it the column name.
// otherwise, we simply leave the name blank. It will not be referenceable
if ok {
attr.Name = col.Column
}

attr.Name = col.Column
}

return []*Attribute{attr}
Expand Down Expand Up @@ -2155,7 +2168,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any {
returns1, ok := p0.Call.Accept(p).(*types.DataType)
if ok {
// if it returns null, then we do not need to assign it to a variable.
if !returns1.Equals(types.NullType) {
if !returns1.EqualsStrict(types.NullType) {
callReturns = append(callReturns, returns1)
}
} else {
Expand All @@ -2178,7 +2191,7 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any {
// we do not have to capture all return values, but we need to ensure
// we do not have more receivers than return values.
if len(p0.Receivers) != len(callReturns) {
p.errs.AddErr(p0, ErrResultShape, `function/procedure "%s" returns %d value(s), statement has %d receiver(s)`, p0.Call.FunctionName(), len(callReturns), len(p0.Receivers))
p.errs.AddErr(p0, ErrResultShape, `function/procedure "%s" returns %d value(s), statement expects %d value(s)`, p0.Call.FunctionName(), len(callReturns), len(p0.Receivers))
return nil
}

Expand Down
15 changes: 9 additions & 6 deletions parse/antlr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1397,13 +1397,10 @@ func (s *schemaVisitor) VisitCase_expr(ctx *gen.Case_exprContext) any {
e.Case = ctx.GetCase_clause().Accept(s).(Expression)
}

for i := range ctx.AllWHEN() {
when := ctx.WHEN(i).Accept(s).(Expression)
then := ctx.THEN(i).Accept(s).(Expression)
for i := range ctx.AllWhen_then_clause() {
wt := ctx.AllWhen_then_clause()[i].Accept(s).([2]Expression)

e.WhenThen = append(e.WhenThen, [2]Expression{
when, then,
})
e.WhenThen = append(e.WhenThen, wt)
}

if ctx.GetElse_clause() != nil {
Expand All @@ -1414,6 +1411,12 @@ func (s *schemaVisitor) VisitCase_expr(ctx *gen.Case_exprContext) any {
return e
}

func (s *schemaVisitor) VisitWhen_then_clause(ctx *gen.When_then_clauseContext) any {
when := ctx.Sql_expr(0).Accept(s).(Expression)
then := ctx.Sql_expr(1).Accept(s).(Expression)
return [2]Expression{when, then}
}

func (s *schemaVisitor) VisitIn_sql_expr(ctx *gen.In_sql_exprContext) any {
e := &ExpressionIn{
Expression: ctx.Sql_expr().Accept(s).(Expression),
Expand Down
18 changes: 9 additions & 9 deletions parse/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.IntType) {
if !args[0].EqualsStrict(types.IntType) {
return nil, wrapErrArgumentType(types.IntType, args[0])
}

Expand All @@ -38,7 +38,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.TextType) {
if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

Expand All @@ -61,7 +61,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.TextType) {
if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

Expand All @@ -84,7 +84,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.TextType) {
if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

Expand All @@ -107,7 +107,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.TextType) {
if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

Expand All @@ -130,7 +130,7 @@ var (
return nil, fmt.Errorf("invalid number of arguments: expected at least 1, got %d", len(args))
}

if !args[0].Equals(types.TextType) {
if !args[0].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[0])
}

Expand All @@ -154,11 +154,11 @@ var (
return nil, wrapErrArgumentNumber(2, len(args))
}

if !args[0].Equals(types.UUIDType) {
if !args[0].EqualsStrict(types.UUIDType) {
return nil, wrapErrArgumentType(types.UUIDType, args[0])
}

if !args[1].Equals(types.TextType) {
if !args[1].EqualsStrict(types.TextType) {
return nil, wrapErrArgumentType(types.TextType, args[1])
}

Expand Down Expand Up @@ -320,7 +320,7 @@ var (
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].Equals(types.IntType) {
if !args[0].EqualsStrict(types.IntType) {
return nil, wrapErrArgumentType(types.IntType, args[0])
}

Expand Down
Loading

0 comments on commit 5590381

Please sign in to comment.