Skip to content

Commit

Permalink
Fix for issue 16 : SQL 'in' operator now accepts empty lists and inhi…
Browse files Browse the repository at this point in the history
…bits

"not in (empty)" tautology and returns an always false for "in (empty)"

  http://github.com/max-l/Squeryl/issues#issue/16
  • Loading branch information
max-l committed Sep 19, 2010
1 parent 05698c4 commit 4cd99be
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/main/scala/org/squeryl/dsl/FieldTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ trait NumericalExpression[A] extends TypedExpressionNode[A] {
def in[B <% NumericalExpression[_]](e: Query[B]) = new BinaryOperatorNodeLogicalBoolean(this, e.ast, "in")
def notIn[B <% NumericalExpression[_]](e: Query[B]) = new BinaryOperatorNodeLogicalBoolean(this, e.ast, "not in")

def in(l: ListNumerical) = new BinaryOperatorNodeLogicalBoolean(this, l, "in")
def notIn(l: ListNumerical) = new BinaryOperatorNodeLogicalBoolean(this, l, "not in")
def in(l: ListNumerical) = new InListExpression(this, l, true)
def notIn(l: ListNumerical) = new InListExpression(this, l, false)

def between[B,C](b: NumericalExpression[B], c: NumericalExpression[C]) = new BetweenExpression(this, b, c)

Expand Down Expand Up @@ -147,8 +147,8 @@ trait EnumExpression[A] extends NonNumericalExpression[A] {
trait StringExpression[A] extends NonNumericalExpression[A] {
outer =>

def in(e: ListString) = new BinaryOperatorNodeLogicalBoolean(this, e, "in")
def notIn(e: ListString) = new BinaryOperatorNodeLogicalBoolean(this, e, "not in")
def in(e: ListString) = new InListExpression(this, e, true)
def notIn(e: ListString) = new InListExpression(this, e, false)

//def between(lower: BaseScalarString, upper: BaseScalarString): LogicalBoolean = error("implement me") //new BinaryOperatorNode(this, lower, div) with LogicalBoolean
def like(e: StringExpression[_]) = new BinaryOperatorNodeLogicalBoolean(this, e, "like")
Expand All @@ -164,8 +164,8 @@ trait StringExpression[A] extends NonNumericalExpression[A] {

trait DateExpression[A] extends NonNumericalExpression[A] {

def in(e: ListDate) = new BinaryOperatorNodeLogicalBoolean(this, e, "in")
def notIn(e: ListDate) = new BinaryOperatorNodeLogicalBoolean(this, e, "not in")
def in(e: ListDate) = new InListExpression(this, e, true)
def notIn(e: ListDate) = new InListExpression(this, e, false)

def ~ = this
}
24 changes: 23 additions & 1 deletion src/main/scala/org/squeryl/dsl/ast/ExpressionNode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ trait ExpressionNode {


trait ListExpressionNode extends ExpressionNode {

def quotesElement = false

def isEmpty: Boolean
}

trait ListNumerical extends ListExpressionNode
Expand All @@ -105,6 +108,22 @@ trait ListString extends ListExpressionNode {

class EqualityExpression(override val left: TypedExpressionNode[_], override val right: TypedExpressionNode[_]) extends BinaryOperatorNodeLogicalBoolean(left, right, "=")

class InListExpression(left: ExpressionNode, right: ListExpressionNode, inclusion: Boolean) extends BinaryOperatorNodeLogicalBoolean(left, right, if(inclusion) "in" else "not in") {

override def inhibited =
if(right.isEmpty)
(! inclusion)
else
super.inhibited

override def doWrite(sw: StatementWriter) =
if(inclusion && right.isEmpty)
sw.write("(1 = 0)")
else
super.doWrite(sw)
}


class BinaryOperatorNodeLogicalBoolean(left: ExpressionNode, right: ExpressionNode, op: String)
extends BinaryOperatorNode(left,right, op) with LogicalBoolean {

Expand Down Expand Up @@ -324,7 +343,10 @@ class ConstantExpressionNode[T](val value: T) extends ExpressionNode {
}

class ConstantExpressionNodeList[T](val value: Traversable[T]) extends ExpressionNode with ListExpressionNode {


def isEmpty =
value == Nil

def doWrite(sw: StatementWriter) =
if(quotesElement)
sw.write(this.value.map(e=>"'" +e+"'").mkString("(",",",")"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ trait QueryExpressionElements extends ExpressionNode {

def whereClause: Option[ExpressionNode]

def hasUnInhibitedWhereClause =
whereClause != None &&
(! whereClause.get.inhibited) &&
(whereClause.get.children.filter(c => !c.inhibited) != Nil)

def havingClause: Option[ExpressionNode]

def groupByClause: Iterable[ExpressionNode]
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/org/squeryl/internals/DatabaseAdapter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ trait DatabaseAdapter {
}
}

if(qen.whereClause != None && qen.whereClause.get.children.filter(c => !c.inhibited) != Nil) {
if(qen.hasUnInhibitedWhereClause) {
sw.write("Where")
sw.nextLine
sw.writeIndented {
Expand Down
23 changes: 23 additions & 0 deletions src/test/scala/org/squeryl/tests/musicdb/MusicDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ class MusicDbTestRun extends QueryTester {
def working = {
import testInstance._

testInTautology

testNotInTautology

testDynamicWhereClause1

testEnums
Expand Down Expand Up @@ -762,6 +766,25 @@ class MusicDbTestRun extends QueryTester {
select(a)
)

def testInTautology = {

val q = artists.where(_.firstName in Nil).toList

assertEquals(Nil, q, 'testInTautology)

passed('testInTautology)
}

def testNotInTautology = {

val allArtists = artists.map(_.id).toSet

val q = artists.where(_.firstName notIn Nil).map(_.id).toSet

assertEquals(allArtists, q, 'testNotInTautology)

passed('testNotInTautology)
}

// //class EnumE[A <: Enumeration#Value](val a: A) {
// class EnumE[A](val a: A) {
Expand Down

0 comments on commit 4cd99be

Please sign in to comment.