From d3a0c1bb1943735d331df8404f9a464456ea9e64 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Wed, 13 May 2015 22:27:37 -0700
Subject: [PATCH] add range api()

---
 .../scala/org/apache/spark/SparkContext.scala      | 58 ++++++++++++++++++++
 .../org/apache/spark/sql/SQLContext.scala          | 31 ++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 29 ++++++++++
 3 files changed, 118 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b59f562d05ead..5ae6c45679365 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -689,6 +689,64 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
   }
 
+  /**
+   * Creates a new RDD[Long] containing elements from `start` to `end` (exclusive), increased by
+   * `step` every element.
+   *
+   * @note if we need to cache this RDD, we should make sure each partition contains no more than
+   * 2 billion elements.
+   *
+   * @param start the start value.
+   * @param end the end value.
+   * @param step the incremental step.
+   * @param numSlices the partition number of the new RDD.
+   * @return a new RDD[Long] with elements in the defined range.
+   */
+  def range(
+      start: Long,
+      end: Long,
+      step: Long = 1,
+      numSlices: Int = defaultParallelism): RDD[Long] = withScope {
+    assertNotStopped()
+    if (step == 0) {
+      // when step is 0, range will run infinitely
+      throw new IllegalArgumentException("`step` cannot be 0")
+    }
+    val length =
+      if ((end - start) % step == 0) {
+        (end - start) / step
+      } else {
+        (end - start) / step + 1
+      }
+    parallelize(0 until numSlices, numSlices).mapPartitions(iter => {
+      val i = iter.next()
+      val partitionStart = (i * length) / numSlices * step + start
+      val partitionEnd = ((i + 1) * length) / numSlices * step + start
+
+      new Iterator[Long] {
+        var number: Long = _
+        initialize()
+
+        override def hasNext =
+          if (step > 0) {
+            number < partitionEnd
+          } else {
+            number > partitionEnd
+          }
+
+        override def next() = {
+          val ret = number
+          number += step
+          ret
+        }
+
+        private def initialize() = {
+          number = partitionStart
+        }
+      }
+    })
+  }
+
   /** Distribute a local Scala collection to form an RDD.
    *
    * This method is identical to `parallelize`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 521f3dc821795..d15a01dff7d8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1125,6 +1125,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
     catalog.unregisterTable(Seq(tableName))
   }
 
+  /**
+   * :: Experimental ::
+   * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+   * in a range from `start` to `end` (exclusive) with the given step value (1 by default).
+   *
+   * @since 1.4.0
+   * @group dataframe
+   */
+  @Experimental
+  def range(start: Long, end: Long, step: Long = 1): DataFrame = {
+    createDataFrame(
+      sparkContext.range(start, end, step).map(Row(_)),
+      StructType(StructField("id", LongType, nullable = false) :: Nil))
+  }
+
+  /**
+   * :: Experimental ::
+   * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+   * in a range from `start` to `end` (exclusive) with the given step value and number of
+   * partitions.
+   *
+   * @since 1.4.0
+   * @group dataframe
+   */
+  @Experimental
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
+    createDataFrame(
+      sparkContext.range(start, end, step, numPartitions).map(Row(_)),
+      StructType(StructField("id", LongType, nullable = false) :: Nil))
+  }
+
   /**
    * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
    * used for SQL parsing can be configured with 'spark.sql.dialect'.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1d5f6b3aad6fd..d82e138e991d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -532,4 +532,33 @@ class DataFrameSuite extends QueryTest {
     val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
     assert(!p.child.isInstanceOf[Project])
   }
+
+  test("range api") {
+    TestSQLContext.range(0, 10).registerTempTable("rangeTable1")
+    val res1 = TestSQLContext.sql("select id from rangeTable1")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+    TestSQLContext.range(3, 15, 3).registerTempTable("rangeTable2")
+    val res2 = TestSQLContext.sql("select id from rangeTable2")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+    TestSQLContext.range(1, -2).registerTempTable("rangeTable3")
+    val res3 = TestSQLContext.sql("select id from rangeTable3")
+    assert(res3.count == 0)
+    // start is positive, end is negative, step is negative
+    TestSQLContext.range(1, -2, -2).registerTempTable("rangeTable4")
+    val res4 = TestSQLContext.sql("select id from rangeTable4")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+    // start, end, step are negative
+    TestSQLContext.range(-3, -8, -2).registerTempTable("rangeTable5")
+    val res5 = TestSQLContext.sql("select id from rangeTable5")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+    // start, end are negative, step is positive
+    TestSQLContext.range(-8, -4, 2).registerTempTable("rangeTable6")
+    val res6 = TestSQLContext.sql("select id from rangeTable6")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+  }
 }
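
To illustrate how the new SparkContext.range splits the range across partitions, here is a minimal, Spark-free sketch of the boundary arithmetic used above. The object name RangeSliceSketch, the helper sliceBounds, and the example values are illustrative only and are not part of the patch; the formulas are copied from the implementation.

    object RangeSliceSketch {
      // Boundary arithmetic from the proposed SparkContext.range: partition i covers
      // [ (i * length) / numSlices * step + start, ((i + 1) * length) / numSlices * step + start )
      def sliceBounds(start: Long, end: Long, step: Long, numSlices: Int): Seq[(Long, Long)] = {
        require(step != 0, "`step` cannot be 0")
        // number of elements in the whole range, rounded up when step does not divide evenly
        val length =
          if ((end - start) % step == 0) (end - start) / step
          else (end - start) / step + 1
        (0 until numSlices).map { i =>
          ((i * length) / numSlices * step + start,
           ((i + 1) * length) / numSlices * step + start)
        }
      }

      def main(args: Array[String]): Unit = {
        // range(3, 15, step = 3) over 2 slices -> bounds (3,9) and (9,15),
        // i.e. elements 3, 6 in the first partition and 9, 12 in the second
        println(sliceBounds(3L, 15L, 3L, 2))
      }
    }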
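A rough usage sketch of the DataFrame-side API added to SQLContext, mirroring the new test. It assumes a Spark 1.4-style shell where a SQLContext instance named sqlContext is already available; the table name and example values are illustrative.

    import org.apache.spark.sql.functions.sum

    // assumes `sqlContext` is the SQLContext provided by spark-shell
    // `id` runs from 0 to 9: `end` is exclusive, and the column is a non-nullable LongType
    val df = sqlContext.range(0, 10)
    df.registerTempTable("rangeTable")
    df.count()                                                        // 10
    df.agg(sum("id")).collect()                                       // Array([45])
    sqlContext.sql("select id from rangeTable where id > 5").count()  // 4

    // explicit step and partition count via the four-argument overload
    val stepped = sqlContext.range(3, 15, 3, 2)    // rows 3, 6, 9, 12 spread over 2 partitions
    stepped.rdd.partitions.length                  // 2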