Skip to content

Commit

Permalink
Merge pull request #7989 from boggle/3.0-type-checking
Browse files Browse the repository at this point in the history
Refactor type checking for functions and expressions
  • Loading branch information
Mats-SX committed Sep 21, 2016
2 parents debfe02 + 23ec213 commit 3b0bac5
Show file tree
Hide file tree
Showing 65 changed files with 484 additions and 323 deletions.
Expand Up @@ -121,20 +121,22 @@ case class Add(lhs: Expression, rhs: Expression)(val position: InputPosition)

case class UnaryAdd(rhs: Expression)(val position: InputPosition)
extends Expression with LeftUnaryOperatorExpression with PrefixFunctionTyping {
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTFloat), outputType = CTFloat)

override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTFloat), outputType = CTFloat)
)

override def canonicalOperatorSymbol = "+"
}

case class Subtract(lhs: Expression, rhs: Expression)(val position: InputPosition)
extends Expression with BinaryOperatorExpression with InfixFunctionTyping {
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
Signature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)

override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
ExpressionSignature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
)

override def semanticCheck(ctx: SemanticContext): SemanticCheck =
Expand All @@ -151,24 +153,26 @@ case class Subtract(lhs: Expression, rhs: Expression)(val position: InputPositio

case class UnarySubtract(rhs: Expression)(val position: InputPosition)
extends Expression with LeftUnaryOperatorExpression with PrefixFunctionTyping {
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTFloat), outputType = CTFloat)

override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTFloat), outputType = CTFloat)
)

override def canonicalOperatorSymbol = "-"
}

case class Multiply(lhs: Expression, rhs: Expression)(val position: InputPosition)
extends Expression with BinaryOperatorExpression with InfixFunctionTyping {

// 1 * 1 => 1
// 1 * 1.1 => 1.1
// 1.1 * 1 => 1.1
// 1.1 * 1.1 => 1.21
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
Signature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
ExpressionSignature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
)

override def semanticCheck(ctx: SemanticContext): SemanticCheck =
Expand All @@ -185,42 +189,45 @@ case class Multiply(lhs: Expression, rhs: Expression)(val position: InputPositio

case class Divide(lhs: Expression, rhs: Expression)(val position: InputPosition)
extends Expression with BinaryOperatorExpression with InfixFunctionTyping {

// 1 / 1 => 1
// 1 / 1.1 => 0.909
// 1.1 / 1 => 1.1
// 1.1 / 1.1 => 1.0
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
Signature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
ExpressionSignature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
)

override def canonicalOperatorSymbol = "/"
}

case class Modulo(lhs: Expression, rhs: Expression)(val position: InputPosition)
extends Expression with BinaryOperatorExpression with InfixFunctionTyping {

// 1 % 1 => 0
// 1 % 1.1 => 1.0
// 1.1 % 1 => 0.1
// 1.1 % 1.1 => 0.0
val signatures = Vector(
Signature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
Signature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
Signature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTInteger, CTInteger), outputType = CTInteger),
ExpressionSignature(argumentTypes = Vector(CTInteger, CTFloat), outputType = CTFloat),
ExpressionSignature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
)

override def canonicalOperatorSymbol = "%"
}

case class Pow(lhs: Expression, rhs: Expression)(val position: InputPosition)
extends Expression with BinaryOperatorExpression with InfixFunctionTyping {

// 1 ^ 1 => 1.1
// 1 ^ 1.1 => 1.0
// 1.1 ^ 1 => 1.1
// 1.1 ^ 1.1 => 1.1105
val signatures = Vector(
Signature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
override val signatures = Vector(
ExpressionSignature(argumentTypes = Vector(CTFloat, CTFloat), outputType = CTFloat)
)

override def canonicalOperatorSymbol = "^"
Expand Down
Expand Up @@ -91,7 +91,7 @@ abstract class Expression extends ASTNode with ASTExpression with SemanticChecki
def dependencies: Set[Variable] =
this.treeFold(TreeAcc[Set[Variable]](Set.empty)) {
case scope: ScopeExpression => {
case acc =>
acc =>
val newAcc = acc.push(scope.variables)
(newAcc, Some((x) => x.pop))
}
Expand All @@ -107,14 +107,14 @@ abstract class Expression extends ASTNode with ASTExpression with SemanticChecki
def inputs: Seq[(Expression, Set[Variable])] =
this.treeFold(TreeAcc[Seq[(Expression, Set[Variable])]](Seq.empty)) {
case scope: ScopeExpression=> {
case acc =>
val newAcc = acc.push(scope.variables).map { case pairs => pairs :+ (scope -> acc.toSet) }
acc =>
val newAcc = acc.push(scope.variables).map(pairs => pairs :+ (scope -> acc.toSet))
(newAcc, Some((x) => x.pop))
}

case expr: Expression => {
case acc =>
val newAcc = acc.map { case pairs => pairs :+ (expr -> acc.toSet) }
acc =>
val newAcc = acc.map(pairs => pairs :+ (expr -> acc.toSet))
(newAcc, Some(identity))
}
}.data
Expand Down Expand Up @@ -148,49 +148,20 @@ abstract class Expression extends ASTNode with ASTExpression with SemanticChecki
}
}

trait SimpleTyping { self: Expression =>
trait SimpleTyping {
self: Expression =>

protected def possibleTypes: TypeSpec

def semanticCheck(ctx: SemanticContext): SemanticCheck = specifyType(possibleTypes)
}

trait FunctionTyping { self: Expression =>

case class Signature(argumentTypes: IndexedSeq[CypherType], outputType: CypherType)
trait FunctionTyping extends ExpressionCallTypeChecking {
self: Expression =>

def signatures: Seq[Signature]

def semanticCheck(ctx: ast.Expression.SemanticContext): SemanticCheck =
override def semanticCheck(ctx: ast.Expression.SemanticContext): SemanticCheck =
arguments.semanticCheck(ctx) chain
checkTypes

def checkTypes: SemanticCheck = s => {
val initSignatures = signatures.filter(_.argumentTypes.length == arguments.length)

val (remainingSignatures: Seq[Signature], result) = arguments.foldLeft((initSignatures, success(s))) {
case (accumulator@(Seq(), _), _) =>
accumulator
case ((possibilities, r1), arg) =>
val argTypes = possibilities.foldLeft(TypeSpec.none) { _ | _.argumentTypes.head.covariant }
val r2 = arg.expectType(argTypes)(r1.state)

val actualTypes = arg.types(r2.state)
val remainingPossibilities = possibilities.filter {
sig => actualTypes containsAny sig.argumentTypes.head.covariant
} map {
sig => sig.copy(argumentTypes = sig.argumentTypes.tail)
}
(remainingPossibilities, SemanticCheckResult(r2.state, r1.errors ++ r2.errors))
}

val outputType = remainingSignatures match {
case Seq() => TypeSpec.all
case _ => remainingSignatures.foldLeft(TypeSpec.none) { _ | _.outputType.invariant }
}
specifyType(outputType)(result.state) match {
case Left(err) => SemanticCheckResult(result.state, result.errors :+ err)
case Right(state) => SemanticCheckResult(state, result.errors)
}
}
typeChecker.checkTypes(self)
}

trait PrefixFunctionTyping extends FunctionTyping { self: Expression =>
Expand Down
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.frontend.v3_0.ast

import org.neo4j.cypher.internal.frontend.v3_0.symbols.{CypherType, TypeSpec}
import org.neo4j.cypher.internal.frontend.v3_0.{SemanticCheckResult, ast, _}

trait ExpressionCallTypeChecking {

def signatures: Seq[ExpressionSignature] = Seq.empty

protected final def signatureLengths = typeChecker.signatureLengths
protected final lazy val typeChecker: ExpressionCallTypeChecker = ExpressionCallTypeChecker(signatures)
}

case class ExpressionCallTypeChecker(signatures: Seq[ExpressionSignature]) {

val signatureLengths = signatures.map(_.argumentTypes.length)

def checkTypes(invocation: Expression): SemanticCheck = s => {
val initSignatures = signatures.filter(_.argumentTypes.length == invocation.arguments.length)

val (remainingSignatures: Seq[ExpressionSignature], result) =
invocation.arguments.foldLeft((initSignatures, SemanticCheckResult.success(s))) {
case (accumulator@(Seq(), _), _) =>
accumulator
case ((possibilities, r1), arg) =>
val argTypes = possibilities.foldLeft(TypeSpec.none) { _ | _.argumentTypes.head.covariant }
val r2 = arg.expectType(argTypes)(r1.state)

val actualTypes = arg.types(r2.state)
val remainingPossibilities = possibilities.filter {
sig => actualTypes containsAny sig.argumentTypes.head.covariant
} map {
sig => sig.copy(argumentTypes = sig.argumentTypes.tail)
}
(remainingPossibilities, SemanticCheckResult(r2.state, r1.errors ++ r2.errors))
}

val outputType = remainingSignatures match {
case Seq() => TypeSpec.all
case _ => remainingSignatures.foldLeft(TypeSpec.none) { _ | _.outputType.invariant }
}

invocation.specifyType(outputType)(result.state) match {
case Left(err) => SemanticCheckResult(result.state, result.errors :+ err)
case Right(state) => SemanticCheckResult(state, result.errors)
}
}
}
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.frontend.v3_0.ast

import org.neo4j.cypher.internal.frontend.v3_0.symbols.CypherType

case class ExpressionSignature(argumentTypes: IndexedSeq[CypherType], outputType: CypherType)
Expand Up @@ -136,49 +136,15 @@ abstract class Function extends SemanticChecking {
FunctionInvocation(asFunctionName, distinct = false, IndexedSeq(lhs, rhs))(position)
}

trait SimpleTypedFunction extends ExpressionCallTypeChecking {
self: Function =>

trait SimpleTypedFunction { self: Function =>
case class Signature(argumentTypes: IndexedSeq[CypherType], outputType: CypherType)

val signatures: Seq[Signature]

private lazy val signatureLengths = signatures.map(_.argumentTypes.length)

def semanticCheck(ctx: ast.Expression.SemanticContext, invocation: ast.FunctionInvocation): SemanticCheck =
checkMinArgs(invocation, signatureLengths.min) chain checkMaxArgs(invocation, signatureLengths.max) chain
checkTypes(invocation)

private def checkTypes(invocation: ast.FunctionInvocation): SemanticCheck = s => {
val initSignatures = signatures.filter(_.argumentTypes.length == invocation.arguments.length)

val (remainingSignatures: Seq[Signature], result) = invocation.arguments.foldLeft((initSignatures, SemanticCheckResult.success(s))) {
case (accumulator@(Seq(), _), _) =>
accumulator
case ((possibilities, r1), arg) =>
val argTypes = possibilities.foldLeft(TypeSpec.none) { _ | _.argumentTypes.head.covariant }
val r2 = arg.expectType(argTypes)(r1.state)

val actualTypes = arg.types(r2.state)
val remainingPossibilities = possibilities.filter {
sig => actualTypes containsAny sig.argumentTypes.head.covariant
} map {
sig => sig.copy(argumentTypes = sig.argumentTypes.tail)
}
(remainingPossibilities, SemanticCheckResult(r2.state, r1.errors ++ r2.errors))
}

val outputType = remainingSignatures match {
case Seq() => TypeSpec.all
case _ => remainingSignatures.foldLeft(TypeSpec.none) { _ | _.outputType.invariant }
}
invocation.specifyType(outputType)(result.state) match {
case Left(err) => SemanticCheckResult(result.state, result.errors :+ err)
case Right(state) => SemanticCheckResult(state, result.errors)
}
}
override def semanticCheck(ctx: ast.Expression.SemanticContext, invocation: ast.FunctionInvocation): SemanticCheck =
checkMinArgs(invocation, signatureLengths.min) chain
checkMaxArgs(invocation, signatureLengths.max) chain
typeChecker.checkTypes(invocation)
}


abstract class AggregatingFunction extends Function {
override def semanticCheckHook(ctx: ast.Expression.SemanticContext, invocation: ast.FunctionInvocation): SemanticCheck =
when(ctx == ast.Expression.SemanticContext.Simple) {
Expand Down

0 comments on commit 3b0bac5

Please sign in to comment.