Skip to content

Commit

Permalink
add range api()
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed May 14, 2015
1 parent d3db2fd commit d3a0c1b
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
58 changes: 58 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,64 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}

/**
* Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by
* `step` every element.
*
* @note if we need to cache this RDD, we should make sure each partition contains no more than
* 2 billion element.
*
* @param start the start value.
* @param end the end value.
* @param step the
* @param numSlices the partition number of the new RDD.
* @return
*/
/**
 * Creates a new RDD[Long] containing elements from `start` to `end` (exclusive), increased by
 * `step` every element.
 *
 * @note if we need to cache this RDD, we should make sure each partition contains no more than
 * 2 billion elements.
 *
 * @param start the start value.
 * @param end the end value (exclusive).
 * @param step the incremental step; must not be 0.
 * @param numSlices the partition number of the new RDD.
 * @return a new RDD[Long] containing the requested range of numbers.
 */
def range(
    start: Long,
    end: Long,
    step: Long = 1,
    numSlices: Int = defaultParallelism): RDD[Long] = withScope {
  assertNotStopped()
  if (step == 0) {
    // when step is 0, range will run infinite
    throw new IllegalArgumentException("`step` cannot be 0")
  }
  // Total number of elements, rounding up when the last step is partial. Integer division
  // handles negative steps correctly because (end - start) and step share the same sign.
  // NOTE(review): (end - start) can overflow Long for extreme ranges — confirm callers stay
  // within representable bounds.
  val length =
    if ((end - start) % step == 0) {
      (end - start) / step
    } else {
      (end - start) / step + 1
    }
  // Create exactly `numSlices` partitions holding one index each. Using `0 to numSlices`
  // (inclusive) without a slice count would yield numSlices + 1 indices spread over the
  // default parallelism, leaving some partition iterators empty (next() would throw) and
  // others with extra indices that would be silently dropped.
  parallelize(0 until numSlices, numSlices).mapPartitions { iter =>
    val i = iter.next()
    // Partition i covers range indices [i * length / numSlices, (i + 1) * length / numSlices),
    // mapped back into the value domain via `step` and `start`.
    val partitionStart = (i * length) / numSlices * step + start
    val partitionEnd = ((i + 1) * length) / numSlices * step + start

    new Iterator[Long] {
      private var number: Long = partitionStart

      override def hasNext: Boolean =
        if (step > 0) number < partitionEnd else number > partitionEnd

      override def next(): Long = {
        val ret = number
        number += step
        ret
      }
    }
  }
}

/** Distribute a local Scala collection to form an RDD.
*
* This method is identical to `parallelize`.
Expand Down
31 changes: 31 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
catalog.unregisterTable(Seq(tableName))
}

/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
* in an range from `start` to `end`(exclusive) with step value 1.
*
* @since 1.4.0
* @group dataframe
*/
@Experimental
def range(start: Long, end: Long, step: Long = 1): DataFrame = {
createDataFrame(
sparkContext.range(start, end, step).map(Row(_)),
StructType(StructField("id", LongType, nullable = false) :: Nil))
}

/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
* in an range from `start` to `end`(exclusive) with step value 1, with partition numbers
* specified.
*
* @since 1.4.0
* @group dataframe
*/
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
createDataFrame(
sparkContext.range(start, end, step, numPartitions).map(Row(_)),
StructType(StructField("id", LongType, nullable = false) :: Nil))
}

/**
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
Expand Down
29 changes: 29 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -532,4 +532,33 @@ class DataFrameSuite extends QueryTest {
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
assert(!p.child.isInstanceOf[Project])
}

test("range api") {
  // Default step of 1 over a positive range: 0..9.
  TestSQLContext.range(0, 10).registerTempTable("rangeTable1")
  val df1 = TestSQLContext.sql("select id from rangeTable1")
  assert(df1.count() == 10)
  assert(df1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
  // Explicit positive step: 3, 6, 9, 12.
  TestSQLContext.range(3, 15, 3).registerTempTable("rangeTable2")
  val df2 = TestSQLContext.sql("select id from rangeTable2")
  assert(df2.count() == 4)
  assert(df2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
  // Positive step with start > end produces an empty range.
  TestSQLContext.range(1, -2).registerTempTable("rangeTable3")
  val df3 = TestSQLContext.sql("select id from rangeTable3")
  assert(df3.count() == 0)
  // start is positive, end is negative, step is negative: 1, -1.
  TestSQLContext.range(1, -2, -2).registerTempTable("rangeTable4")
  val df4 = TestSQLContext.sql("select id from rangeTable4")
  assert(df4.count() == 2)
  assert(df4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
  // start, end, step all negative: -3, -5, -7.
  TestSQLContext.range(-3, -8, -2).registerTempTable("rangeTable5")
  val df5 = TestSQLContext.sql("select id from rangeTable5")
  assert(df5.count() == 3)
  assert(df5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
  // start, end negative with positive step: -8, -6.
  TestSQLContext.range(-8, -4, 2).registerTempTable("rangeTable6")
  val df6 = TestSQLContext.sql("select id from rangeTable6")
  assert(df6.count() == 2)
  assert(df6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
}
}

0 comments on commit d3a0c1b

Please sign in to comment.