Skip to content

Commit

Permalink
Merge pull request #15 from marmbrus/orderedRow
Browse files Browse the repository at this point in the history
Create OrderedRow class to allow ordering to be used by multiple operators.
  • Loading branch information
marmbrus committed Jan 23, 2014
2 parents ca2ff68 + 5ab18be commit 053a371
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 71 deletions.
6 changes: 6 additions & 0 deletions src/main/scala/catalyst/dsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ package object dsl {
def filter(dynamicUdf: (DynamicRow) => Boolean) =
Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(plan.output))), plan)

def sample(
fraction: Double,
withReplacement: Boolean = true,
seed: Int = (math.random * 1000).toInt) =
Sample(fraction, withReplacement, seed, plan)

def analyze = analysis.SimpleAnalyzer(plan)
}
}
79 changes: 18 additions & 61 deletions src/main/scala/catalyst/execution/basicOperators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import types._

import org.apache.spark.SparkContext._


case class Project(projectList: Seq[NamedExpression], child: SharkPlan) extends UnaryNode {
def output = projectList.map(_.toAttribute)

Expand All @@ -23,6 +22,15 @@ case class Filter(condition: Expression, child: SharkPlan) extends UnaryNode {
}
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SharkPlan)
extends UnaryNode {

def output = child.output

// TODO: How to pick seed?
def execute() = child.execute().sample(withReplacement, fraction, seed)
}

case class Union(left: SharkPlan, right: SharkPlan)(@transient sc: SharkContext)
extends BinaryNode {
// TODO: attributes output by union should be distinct for nullability purposes
Expand All @@ -46,72 +54,21 @@ case class StopAfter(limit: Int, child: SharkPlan)(@transient sc: SharkContext)
}

case class Sort(sortExprs: Seq[SortOrder], child: SharkPlan) extends UnaryNode {
val numPartitions = 1 // TODO: Set with input cardinality
val numPartitions = 8 // TODO: Set with input cardinality

private final val directions = sortExprs.map(_.direction).toIndexedSeq
private final val dataTypes = sortExprs.map(_.dataType).toIndexedSeq

private class SortKey(val keyValues: IndexedSeq[Any]) extends Ordered[SortKey] with Serializable {
def compare(other: SortKey): Int = {
var i = 0
while (i < keyValues.size) {
val left = keyValues(i)
val right = other.keyValues(i)
val curDirection = directions(i)
val curDataType = dataTypes(i)

logger.debug(s"Comparing $left, $right as $curDataType order $curDirection")
// TODO: Use numeric here too?
val comparison =
if (left == null && right == null) {
0
} else if (left == null) {
if (curDirection == Ascending) -1 else 1
} else if (right == null) {
if (curDirection == Ascending) 1 else -1
} else if (curDataType == IntegerType) {
if (curDirection == Ascending) {
left.asInstanceOf[Int] compare right.asInstanceOf[Int]
} else {
right.asInstanceOf[Int] compare left.asInstanceOf[Int]
}
} else if (curDataType == DoubleType) {
if (curDirection == Ascending) {
left.asInstanceOf[Double] compare right.asInstanceOf[Double]
} else {
right.asInstanceOf[Double] compare left.asInstanceOf[Double]
}
} else if (curDataType == LongType) {
if (curDirection == Ascending) {
left.asInstanceOf[Long] compare right.asInstanceOf[Long]
} else {
right.asInstanceOf[Long] compare left.asInstanceOf[Long]
}
} else if (curDataType == StringType) {
if (curDirection == Ascending) {
left.asInstanceOf[String] compare right.asInstanceOf[String]
} else {
right.asInstanceOf[String] compare left.asInstanceOf[String]
}
} else {
sys.error(s"Comparison not yet implemented for: $curDataType")
}

if (comparison != 0) return comparison
i += 1
}
return 0
}
}

// TODO: Don't include redundant expressions in both sortKey and row.
def execute() = attachTree(this, "sort") {
child.execute().map { row =>
val input = Vector(row)
val sortKey = new SortKey(sortExprs.map(s => Evaluate(s.child, input)).toIndexedSeq)

(sortKey, row)
}.sortByKey(ascending = true, numPartitions).map(_._2)
import scala.math.Ordering.Implicits._
implicit val ordering = new RowOrdering(sortExprs)

// TODO: Allow spark to take the ordering as an argument, also avoid needless pair creation.
child.execute()
.mapPartitions(iter => iter.map(row => (row, null)))
.sortByKey(ascending = true, numPartitions)
.map(_._1)
}

def output = child.output
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/catalyst/execution/planningStrategies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ trait PlanningStrategies {
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(group, agg, planLater(child))(sc) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
execution.LocalRelation(output, data.map(_.productIterator.toVector))(sc) :: Nil
case logical.StopAfter(limit, child) =>
Expand Down
30 changes: 29 additions & 1 deletion src/main/scala/catalyst/expressions/Row.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package catalyst
package expressions

import types._

/**
* Represents one row of output from a relational operator. Allows both generic access by ordinal,
* which will incur boxing overhead for primitives, as well as native primitive access.
Expand Down Expand Up @@ -76,4 +78,30 @@ class GenericRow(input: Seq[Any]) extends Row {
if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
values(i).asInstanceOf[Byte]
}
}
}

class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
def compare(a: Row, b: Row): Int = {
ordering.foreach { order =>
val left = Evaluate(order.child, Vector(a))
val right = Evaluate(order.child, Vector(b))

if (left == null && right == null) {
// Both null, continue looking.
} else if (left == null) {
return if (order.direction == Ascending) -1 else 1
} else if (right == null) {
return if (order.direction == Ascending) 1 else -1
} else {
val comparison = order.dataType match {
case n: NativeType if order.direction == Ascending =>
n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
case n: NativeType if order.direction == Descending =>
n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
}
if (comparison != 0) return comparison
}
}
return 0
}
}
2 changes: 1 addition & 1 deletion src/main/scala/catalyst/frontend/Hive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ object HiveQl {
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
Sample(numerator.toDouble / denominator.toDouble, relation)
Sample(numerator.toDouble / denominator.toDouble, false, (math.random * 1000).toInt, relation)
}.getOrElse(relation)

case Token("TOK_UNIQUEJOIN", joinArgs) =>
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/catalyst/plans/logical/basicOperators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
def references = Set.empty
}

case class Sample(percentage: Double, child: LogicalPlan) extends UnaryNode {
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
extends UnaryNode {

def output = child.output
def references = Set.empty
}
Expand Down
17 changes: 15 additions & 2 deletions src/main/scala/catalyst/types/dataTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@ abstract class DataType {

case object NullType extends DataType

abstract class NativeType extends DataType { type JvmType }
abstract class NativeType extends DataType {
type JvmType
val ordering: Ordering[JvmType]
}

case object StringType extends NativeType {
type JvmType = String
val ordering = implicitly[Ordering[JvmType]]
}
case object BinaryType extends NativeType {
case object BinaryType extends DataType {
type JvmType = Array[Byte]
}
case object BooleanType extends NativeType {
type JvmType = Boolean
val ordering = implicitly[Ordering[JvmType]]
}

abstract class NumericType extends NativeType {
Expand Down Expand Up @@ -49,24 +55,28 @@ case object LongType extends IntegralType {
type JvmType = Long
val numeric = implicitly[Numeric[Long]]
val integral = implicitly[Integral[Long]]
val ordering = implicitly[Ordering[JvmType]]
}

case object IntegerType extends IntegralType {
type JvmType = Int
val numeric = implicitly[Numeric[Int]]
val integral = implicitly[Integral[Int]]
val ordering = implicitly[Ordering[JvmType]]
}

case object ShortType extends IntegralType {
type JvmType = Short
val numeric = implicitly[Numeric[Short]]
val integral = implicitly[Integral[Short]]
val ordering = implicitly[Ordering[JvmType]]
}

case object ByteType extends IntegralType {
type JvmType = Byte
val numeric = implicitly[Numeric[Byte]]
val integral = implicitly[Integral[Byte]]
val ordering = implicitly[Ordering[JvmType]]
}

/** Matcher for any expressions that evaluate to [[FractionalType]]s */
Expand All @@ -84,18 +94,21 @@ case object DecimalType extends FractionalType {
type JvmType = BigDecimal
val numeric = implicitly[Numeric[BigDecimal]]
val fractional = implicitly[Fractional[BigDecimal]]
val ordering = implicitly[Ordering[JvmType]]
}

case object DoubleType extends FractionalType {
type JvmType = Double
val numeric = implicitly[Numeric[Double]]
val fractional = implicitly[Fractional[Double]]
val ordering = implicitly[Ordering[JvmType]]
}

case object FloatType extends FractionalType {
type JvmType = Float
val numeric = implicitly[Numeric[Float]]
val fractional = implicitly[Fractional[Float]]
val ordering = implicitly[Ordering[JvmType]]
}

case class ArrayType(elementType: DataType) extends DataType
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/catalyst/execution/DslQueryTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class DslQueryTests extends FunSuite with BeforeAndAfterAll {
}

test("random sample") {
testData.where(Rand > 0.5).orderBy(Rand.asc).toRdd.collect()
testData.sample(0.5).toRdd.collect()
}

test("sorting") {
Expand Down
4 changes: 0 additions & 4 deletions src/test/scala/catalyst/execution/HiveQueryTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ class HiveQueryTests extends HiveComparisonTest {
createQueryTest("string literal",
"SELECT 'test' FROM src")

test("Run random sample") { // Since this is non-deterministic we just check to make sure it runs for now.
"SELECT key, value FROM src WHERE RAND() > 0.5 ORDER BY RAND()".q.stringResult()
}

createQueryTest("Escape sequences",
"""SELECT key, '\\\t\\' FROM src WHERE key = 86""")

Expand Down

0 comments on commit 053a371

Please sign in to comment.