diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
deleted file mode 100644
index 4636c4600a01a..0000000000000
--- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-/**
- * This class exists to restrict the visibility of TaskContext setters.
- */
-private [spark] object TaskContextHelper {
-
- def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)
-
- def unset(): Unit = TaskContext.unset()
-
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c82ae4baa3630..508fe7b3303ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -50,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
* not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
+ * Here's a checklist to use when making or reviewing changes to this class:
+ *
+ * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to
+ * include the new structure. This will help to catch memory leaks.
*/
private[spark]
class DAGScheduler(
@@ -111,6 +115,8 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
+ private[scheduler] val outputCommitCoordinator = env.outputCommitCoordinator
+
// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
@@ -128,8 +134,6 @@ class DAGScheduler(
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
- private val outputCommitCoordinator = env.outputCommitCoordinator
-
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -641,13 +645,13 @@ class DAGScheduler(
val split = rdd.partitions(job.partitions(0))
val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
attemptNumber = 0, runningLocally = true)
- TaskContextHelper.setTaskContext(taskContext)
+ TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
- TaskContextHelper.unset()
+ TaskContext.unset()
}
} catch {
case e: Exception =>
@@ -710,9 +714,10 @@ class DAGScheduler(
// cancelling the stages because if the DAG scheduler is stopped, the entire application
// is in the process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
- runningStages.foreach { stage =>
- stage.latestInfo.stageFailed(stageFailedMessage)
- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
+ // The `toArray` here is necessary so that we don't iterate over `runningStages` while
+ // mutating it.
+ runningStages.toArray.foreach { stage =>
+ markStageAsFinished(stage, Some(stageFailedMessage))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
@@ -887,10 +892,9 @@ class DAGScheduler(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
- // Because we posted SparkListenerStageSubmitted earlier, we should post
- // SparkListenerStageCompleted here in case there are no tasks to run.
- outputCommitCoordinator.stageEnd(stage.id)
- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
+ // Because we posted SparkListenerStageSubmitted earlier, we should mark
+ // the stage as completed here in case there are no tasks to run
+ markStageAsFinished(stage, None)
val debugString = stage match {
case stage: ShuffleMapStage =>
@@ -902,7 +906,6 @@ class DAGScheduler(
s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
}
logDebug(debugString)
- runningStages -= stage
}
}
@@ -968,22 +971,6 @@ class DAGScheduler(
}
val stage = stageIdToStage(task.stageId)
-
- def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
- val serviceTime = stage.latestInfo.submissionTime match {
- case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
- case _ => "Unknown"
- }
- if (errorMessage.isEmpty) {
- logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.latestInfo.completionTime = Some(clock.getTimeMillis())
- } else {
- stage.latestInfo.stageFailed(errorMessage.get)
- logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
- }
- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
- runningStages -= stage
- }
event.reason match {
case Success =>
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
@@ -1099,7 +1086,6 @@ class DAGScheduler(
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
- runningStages -= failedStage
}
if (disallowStageRetryForTest) {
@@ -1215,6 +1201,26 @@ class DAGScheduler(
submitWaitingStages()
}
+ /**
+ * Marks a stage as finished and removes it from the list of running stages.
+ */
+ private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
+ val serviceTime = stage.latestInfo.submissionTime match {
+ case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
+ case _ => "Unknown"
+ }
+ if (errorMessage.isEmpty) {
+ logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
+ stage.latestInfo.completionTime = Some(clock.getTimeMillis())
+ } else {
+ stage.latestInfo.stageFailed(errorMessage.get)
+ logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
+ }
+ outputCommitCoordinator.stageEnd(stage.id)
+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
+ runningStages -= stage
+ }
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -1264,8 +1270,7 @@ class DAGScheduler(
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
- stage.latestInfo.stageFailed(failureReason)
- listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
+ markStageAsFinished(stage, Some(failureReason))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index 9e29fd13821dc..7c184b1dcb308 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -59,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
+ /**
+ * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
+ */
+ def isEmpty: Boolean = {
+ authorizedCommittersByStage.isEmpty
+ }
+
/**
* Called by tasks to ask whether they can commit their output to HDFS.
*
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 4d9f940813b8e..8b592867ee31d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
+import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
- TaskContextHelper.setTaskContext(context)
+ TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
if (_killed) {
@@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
- TaskContextHelper.unset()
+ TaskContext.unset()
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index f57921b768310..30b6184c77839 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
shuffleSpillCompress <- Set(true, false);
shuffleCompress <- Set(true, false)
) {
- val conf = new SparkConf()
+ val myConf = conf.clone()
.setAppName("test")
.setMaster("local")
.set("spark.shuffle.spill.compress", shuffleSpillCompress.toString)
.set("spark.shuffle.compress", shuffleCompress.toString)
.set("spark.shuffle.memoryFraction", "0.001")
resetSparkContext()
- sc = new SparkContext(conf)
+ sc = new SparkContext(myConf)
try {
sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
} catch {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 63360a0f189a3..eb759f0807a17 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -783,6 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(scheduler.runningStages.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
+ assert(scheduler.outputCommitCoordinator.isEmpty)
}
// Nothing in this test should break if the task info's fields are null, but
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index c08c76d226713..771a07183e26f 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import Row, SQLContext
sc = SparkContext(appName="SimpleTextClassificationPipeline")
-sqlCtx = SQLContext(sc)
+sqlContext = SQLContext(sc)
# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 4441d6a000a02..663f656883721 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1642,7 +1642,7 @@ moved into the udf object in `SQLContext`.
{% highlight java %}
-sqlCtx.udf.register("strLen", (s: String) => s.length())
+sqlContext.udf.register("strLen", (s: String) => s.length())
{% endhighlight %}
@@ -1650,7 +1650,7 @@ sqlCtx.udf.register("strLen", (s: String) => s.length())
{% highlight java %}
-sqlCtx.udf().register("strLen", (String s) -> { s.length(); });
+sqlContext.udf().register("strLen", (String s) -> { s.length(); });
{% endhighlight %}
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 879a52cef8ff0..0c1f24761d0de 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -282,6 +282,10 @@ def parse_args():
parser.add_option(
"--vpc-id", default=None,
help="VPC to launch instances in")
+ parser.add_option(
+ "--private-ips", action="store_true", default=False,
+ help="Use private IPs for instances rather than public if VPC/subnet " +
+ "requires that.")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -707,7 +711,7 @@ def get_instances(group_names):
# Deploy configuration files and run setup scripts on a newly launched
# or started EC2 cluster.
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
- master = master_nodes[0].public_dns_name
+ master = get_dns_name(master_nodes[0], opts.private_ips)
if deploy_ssh_key:
print "Generating cluster's SSH key on master..."
key_setup = """
@@ -719,8 +723,9 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
print "Transferring cluster's SSH key to slaves..."
for slave in slave_nodes:
- print slave.public_dns_name
- ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
+ slave_address = get_dns_name(slave, opts.private_ips)
+ print slave_address
+ ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar)
modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
'mapreduce', 'spark-standalone', 'tachyon']
@@ -809,7 +814,8 @@ def is_cluster_ssh_available(cluster_instances, opts):
Check if SSH is available on all the instances in a cluster.
"""
for i in cluster_instances:
- if not is_ssh_available(host=i.public_dns_name, opts=opts):
+ dns_name = get_dns_name(i, opts.private_ips)
+ if not is_ssh_available(host=dns_name, opts=opts):
return False
else:
return True
@@ -923,7 +929,7 @@ def get_num_disks(instance_type):
#
# root_dir should be an absolute path to the directory with the files we want to deploy.
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
- active_master = master_nodes[0].public_dns_name
+ active_master = get_dns_name(master_nodes[0], opts.private_ips)
num_disks = get_num_disks(opts.instance_type)
hdfs_data_dirs = "/mnt/ephemeral-hdfs/data"
@@ -948,10 +954,12 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
print "Deploying Spark via git hash; Tachyon won't be set up"
modules = filter(lambda x: x != "tachyon", modules)
+ master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
+ slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes]
template_vars = {
- "master_list": '\n'.join([i.public_dns_name for i in master_nodes]),
+ "master_list": '\n'.join(master_addresses),
"active_master": active_master,
- "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]),
+ "slave_list": '\n'.join(slave_addresses),
"cluster_url": cluster_url,
"hdfs_data_dirs": hdfs_data_dirs,
"mapred_local_dirs": mapred_local_dirs,
@@ -1011,7 +1019,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
#
# root_dir should be an absolute path.
def deploy_user_files(root_dir, opts, master_nodes):
- active_master = master_nodes[0].public_dns_name
+ active_master = get_dns_name(master_nodes[0], opts.private_ips)
command = [
'rsync', '-rv',
'-e', stringify_command(ssh_command(opts)),
@@ -1122,6 +1130,20 @@ def get_partition(total, num_partitions, current_partitions):
return num_slaves_this_zone
+# Gets the IP address, taking into account the --private-ips flag
+def get_ip_address(instance, private_ips=False):
+ ip = instance.ip_address if not private_ips else \
+ instance.private_ip_address
+ return ip
+
+
+# Gets the DNS name, taking into account the --private-ips flag
+def get_dns_name(instance, private_ips=False):
+ dns = instance.public_dns_name if not private_ips else \
+ instance.private_ip_address
+ return dns
+
+
def real_main():
(opts, action, cluster_name) = parse_args()
@@ -1230,7 +1252,7 @@ def real_main():
if any(master_nodes + slave_nodes):
print "The following instances will be terminated:"
for inst in master_nodes + slave_nodes:
- print "> %s" % inst.public_dns_name
+ print "> %s" % get_dns_name(inst, opts.private_ips)
print "ALL DATA ON ALL NODES WILL BE LOST!!"
msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name)
@@ -1294,13 +1316,17 @@ def real_main():
elif action == "login":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- master = master_nodes[0].public_dns_name
- print "Logging into master " + master + "..."
- proxy_opt = []
- if opts.proxy_port is not None:
- proxy_opt = ['-D', opts.proxy_port]
- subprocess.check_call(
- ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)])
+ if not master_nodes[0].public_dns_name and not opts.private_ips:
+ print "Master has no public DNS name. Maybe you meant to specify " \
+ "--private-ips?"
+ else:
+ master = get_dns_name(master_nodes[0], opts.private_ips)
+ print "Logging into master " + master + "..."
+ proxy_opt = []
+ if opts.proxy_port is not None:
+ proxy_opt = ['-D', opts.proxy_port]
+ subprocess.check_call(
+ ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)])
elif action == "reboot-slaves":
response = raw_input(
@@ -1318,7 +1344,11 @@ def real_main():
elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- print master_nodes[0].public_dns_name
+ if not master_nodes[0].public_dns_name and not opts.private_ips:
+ print "Master has no public DNS name. Maybe you meant to specify " \
+ "--private-ips?"
+ else:
+ print get_dns_name(master_nodes[0], opts.private_ips)
elif action == "stop":
response = raw_input(
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index dee794840a3e1..8159ffbe2d269 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -55,7 +55,7 @@ public void setAge(int age) {
public static void main(String[] args) throws Exception {
SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
- SQLContext sqlCtx = new SQLContext(ctx);
+ SQLContext sqlContext = new SQLContext(ctx);
System.out.println("=== Data source: RDD ===");
// Load a text file and convert each line to a Java Bean.
@@ -74,11 +74,11 @@ public Person call(String line) {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class);
+ DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
- DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
@@ -99,12 +99,12 @@ public String call(Row row) {
// Read in the parquet file created above.
// Parquet files are self-describing so the schema is preserved.
// The result of loading a parquet file is also a DataFrame.
- DataFrame parquetFile = sqlCtx.parquetFile("people.parquet");
+ DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
DataFrame teenagers2 =
- sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
+ sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
teenagerNames = teenagers2.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
@@ -120,7 +120,7 @@ public String call(Row row) {
// The path can be either a single text file or a directory storing text files.
String path = "examples/src/main/resources/people.json";
// Create a DataFrame from the file(s) pointed by path
- DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path);
+ DataFrame peopleFromJsonFile = sqlContext.jsonFile(path);
// Because the schema of a JSON dataset is automatically inferred, to write queries,
// it is better to take a look at what is the schema.
@@ -133,8 +133,8 @@ public String call(Row row) {
// Register this DataFrame as a table.
peopleFromJsonFile.registerTempTable("people");
- // SQL statements can be run by using the sql methods provided by sqlCtx.
- DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ // SQL statements can be run by using the sql methods provided by sqlContext.
+ DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
// The results of SQL queries are DataFrame and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
@@ -151,7 +151,7 @@ public String call(Row row) {
List<String> jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD<String> anotherPeopleRDD = ctx.parallelize(jsonData);
- DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd());
+ DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd());
// Take a look at the schema of this new DataFrame.
peopleFromJsonRDD.printSchema();
@@ -164,7 +164,7 @@ public String call(Row row) {
peopleFromJsonRDD.registerTempTable("people2");
- DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
+ DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2");
List<String> nameAndCity = peopleWithCity.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
index d281f4fa44282..c73edb7fd6b20 100644
--- a/examples/src/main/python/ml/simple_text_classification_pipeline.py
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -33,7 +33,7 @@
if __name__ == "__main__":
sc = SparkContext(appName="SimpleTextClassificationPipeline")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
index b5a70db2b9a3c..fcbf56cbf0c52 100644
--- a/examples/src/main/python/mllib/dataset_example.py
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -44,19 +44,19 @@ def summarize(dataset):
print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
exit(-1)
sc = SparkContext(appName="DatasetExample")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
if len(sys.argv) == 2:
input = sys.argv[1]
else:
input = "data/mllib/sample_libsvm_data.txt"
points = MLUtils.loadLibSVMFile(sc, input)
- dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache()
summarize(dataset0)
tempdir = tempfile.NamedTemporaryFile(delete=False).name
os.unlink(tempdir)
print "Save dataset as a Parquet file to %s." % tempdir
dataset0.saveAsParquetFile(tempdir)
print "Load it back and summarize it again."
- dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache()
summarize(dataset1)
shutil.rmtree(tempdir)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index e04d4088df7dc..2edea9b5b69ba 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -1,21 +1,20 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/
+
package org.apache.spark.streaming.flume
import java.net.InetSocketAddress
@@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
assert(counter === totalEventsPerChannel * channels.size)
}
- def assertChannelIsEmpty(channel: MemoryChannel) = {
+ def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
val queueRemaining = channel.getClass.getDeclaredField("queueRemaining")
queueRemaining.setAccessible(true)
val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits")
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 51d273af8da84..39e6754c81dbf 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
}
/** Class to create socket channel with compression */
- private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
+ private class CompressionChannelFactory(compressionLevel: Int)
+ extends NioClientSocketChannelFactory {
+
override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
val encoder = new ZlibEncoder(compressionLevel)
pipeline.addFirst("deflater", encoder)
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 24d78ecb3a97d..a19a72c58a705 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
msgTopic.publish(message)
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- Thread.sleep(50) // wait for Spark streaming to consume something from the message queue
+ // wait for Spark streaming to consume something from the message queue
+ Thread.sleep(50)
}
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 8d15150458d26..a570e4ed75fc3 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -38,12 +38,12 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val doubleRing = ring ++ ring
val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1)
assert(graph.edges.count() === doubleRing.size)
- assert(graph.edges.collect.forall(e => e.attr == 1))
+ assert(graph.edges.collect().forall(e => e.attr == 1))
// uniqueEdges option should uniquify edges and store duplicate count in edge attributes
val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut))
assert(uniqueGraph.edges.count() === ring.size)
- assert(uniqueGraph.edges.collect.forall(e => e.attr == 2))
+ assert(uniqueGraph.edges.collect().forall(e => e.attr == 2))
}
}
@@ -64,7 +64,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
- graph.triplets.collect.map { et =>
+ graph.triplets.collect().map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
@@ -75,15 +75,17 @@ class GraphSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val n = 5
val star = starGraph(sc, n)
- assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet ===
- (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet)
+ assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet
+ === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet)
}
}
test("partitionBy") {
withSpark { sc =>
- def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0)
- def nonemptyParts(graph: Graph[Int, Int]) = {
+ def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = {
+ Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0)
+ }
+ def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = {
graph.edges.partitionsRDD.mapPartitions { iter =>
Iterator(iter.next()._2.iterator.toList)
}.filter(_.nonEmpty)
@@ -102,7 +104,8 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1)
// partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into
// the same partition
- assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1)
+ assert(
+ nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1)
// partitionBy(EdgePartition2D) puts identical edges in the same partition
assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1)
@@ -140,10 +143,10 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val g = Graph(
sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))),
sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2))
- assert(g.triplets.collect.map(_.toTuple).toSet ===
+ assert(g.triplets.collect().map(_.toTuple).toSet ===
Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1)))
val gPart = g.partitionBy(EdgePartition2D)
- assert(gPart.triplets.collect.map(_.toTuple).toSet ===
+ assert(gPart.triplets.collect().map(_.toTuple).toSet ===
Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1)))
}
}
@@ -154,10 +157,10 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val star = starGraph(sc, n)
// mapVertices preserving type
val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2")
- assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet)
+ assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet)
// mapVertices changing type
val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length)
- assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet)
+ assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet)
}
}
@@ -177,12 +180,12 @@ class GraphSuite extends FunSuite with LocalSparkContext {
// Trigger initial vertex replication
graph0.triplets.foreach(x => {})
// Change type of replicated vertices, but preserve erased type
- val graph1 = graph0.mapVertices {
- case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double)
+ val graph1 = graph0.mapVertices { case (vid, integerOpt) =>
+ integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double)
}
// Access replicated vertices, exposing the erased type
val graph2 = graph1.mapTriplets(t => t.srcAttr.get)
- assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
+ assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
}
}
@@ -202,7 +205,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val n = 5
val star = starGraph(sc, n)
- assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet ===
+ assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet ===
(1L to n).map(x => Edge(0, x, "vv")).toSet)
}
}
@@ -211,7 +214,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val n = 5
val star = starGraph(sc, n)
- assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet)
+ assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet)
}
}
@@ -221,7 +224,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0)))
val graph = Graph(vertices, edges).reverse
val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _)
- assert(result.collect.toSet === Set((1L, 2)))
+ assert(result.collect().toSet === Set((1L, 2)))
}
}
@@ -237,7 +240,8 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet)
// And 4 edges.
- assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
+ assert(subgraph.edges.map(_.copy()).collect().toSet ===
+ (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
}
}
@@ -273,9 +277,9 @@ class GraphSuite extends FunSuite with LocalSparkContext {
sc.parallelize((1 to n).flatMap(x =>
List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v")
val star2 = doubleStar.groupEdges { (a, b) => a}
- assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) ===
- star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]))
- assert(star2.vertices.collect.toSet === star.vertices.collect.toSet)
+ assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) ===
+ star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]))
+ assert(star2.vertices.collect().toSet === star.vertices.collect().toSet)
}
}
@@ -300,21 +304,23 @@ class GraphSuite extends FunSuite with LocalSparkContext {
throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId))
}
Iterator((et.srcId, 1))
- }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet
+ }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet
assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet)
// outerJoinVertices followed by mapReduceTriplets(activeSetOpt)
- val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3)
+ val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3)
val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache()
val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache()
- val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) =>
+ newOpt.getOrElse(old)
+ }
val numOddNeighbors = changedGraph.mapReduceTriplets(et => {
// Map function should only run on edges with source in the active set
if (et.srcId % 2 != 1) {
throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId))
}
Iterator((et.dstId, 1))
- }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet
+ }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet
assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet)
}
@@ -340,17 +346,18 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val n = 5
val reverseStar = starGraph(sc, n).reverse.cache()
// outerJoinVertices changing type
- val reverseStarDegrees =
- reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) }
+ val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) {
+ (vid, a, bOpt) => bOpt.getOrElse(0)
+ }
val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets(
et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)),
- (a: Int, b: Int) => a + b).collect.toSet
+ (a: Int, b: Int) => a + b).collect().toSet
assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0)))
// outerJoinVertices preserving type
val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString }
val newReverseStar =
reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") }
- assert(newReverseStar.vertices.map(_._2).collect.toSet ===
+ assert(newReverseStar.vertices.map(_._2).collect().toSet ===
(0 to n).map(x => "v%d".format(x)).toSet)
}
}
@@ -361,7 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2)
val graph = Graph(verts, edges)
val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr))
- .collect.toSet
+ .collect().toSet
assert(triplets ===
Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a")))
}
@@ -417,7 +424,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val graph = Graph.fromEdgeTuples(edges, 1)
val neighborAttrSums = graph.mapReduceTriplets[Int](
et => Iterator((et.dstId, et.srcAttr)), _ + _)
- assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n)))
+ assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n)))
} finally {
sc.stop()
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
index a3e28efc75a98..d2ad9be555770 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkContext
*/
trait LocalSparkContext {
/** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
- def withSpark[T](f: SparkContext => T) = {
+ def withSpark[T](f: SparkContext => T): T = {
val conf = new SparkConf()
GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext("local", "test", conf)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index c9443d11c76cf..d0a7198d691d7 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.storage.StorageLevel
class VertexRDDSuite extends FunSuite with LocalSparkContext {
- def vertices(sc: SparkContext, n: Int) = {
+ private def vertices(sc: SparkContext, n: Int) = {
VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
}
@@ -52,7 +52,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache()
val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache()
val vertexC = vertexA.minus(vertexB)
- assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet)
+ assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet)
}
}
@@ -62,7 +62,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexB: RDD[(VertexId, Int)] =
sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache()
val vertexC = vertexA.minus(vertexB)
- assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet)
+ assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet)
}
}
@@ -72,7 +72,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1)))
assert(vertexA.partitions.size != vertexB.partitions.size)
val vertexC = vertexA.minus(vertexB)
- assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet)
+ assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet)
}
}
@@ -106,7 +106,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1)))
assert(vertexA.partitions.size != vertexB.partitions.size)
val vertexC = vertexA.diff(vertexB)
- assert(vertexC.map(_._1).collect.toSet === (8 until 16).toSet)
+ assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet)
}
}
@@ -116,11 +116,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val verts = vertices(sc, n).cache()
val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
// leftJoin with another VertexRDD
- assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet ===
(0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
// leftJoin with an RDD
val evensRDD = evens.map(identity)
- assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet ===
(0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
}
}
@@ -134,7 +134,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) =>
old - newOpt.getOrElse(0)
}
- assert(vertexC.filter(v => v._2 != 0).map(_._1).collect.toSet == (1 to 99 by 2).toSet)
+ assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet)
}
}
@@ -144,11 +144,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val verts = vertices(sc, n).cache()
val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
// innerJoin with another VertexRDD
- assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet ===
+ assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet ===
(0 to n by 2).map(x => (x.toLong, 0)).toSet)
// innerJoin with an RDD
val evensRDD = evens.map(identity)
- assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet ===
+ assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet ===
(0 to n by 2).map(x => (x.toLong, 0)).toSet) }
}
@@ -161,7 +161,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) =>
old - newVal
}
- assert(vertexC.filter(v => v._2 == 0).map(_._1).collect.toSet == (0 to 98 by 2).toSet)
+ assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet)
}
}
@@ -171,7 +171,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val verts = vertices(sc, n)
val messageTargets = (0 to n) ++ (0 to n by 2)
val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1)))
- assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet ===
+ assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet ===
(0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet)
}
}
@@ -183,7 +183,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b)
// test merge function
- assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9)))
+ assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9)))
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
index 3915be15b3434..4cc30a96408f8 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -32,7 +32,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
val ccGraph = gridGraph.connectedComponents()
- val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+ val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum()
assert(maxCCid === 0)
}
} // end of Grid connected components
@@ -42,7 +42,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
val ccGraph = gridGraph.connectedComponents()
- val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+ val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum()
assert(maxCCid === 0)
}
} // end of Grid connected components
@@ -50,8 +50,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
test("Chain Connected Components") {
withSpark { sc =>
- val chain1 = (0 until 9).map(x => (x, x+1) )
- val chain2 = (10 until 20).map(x => (x, x+1) )
+ val chain1 = (0 until 9).map(x => (x, x + 1))
+ val chain2 = (10 until 20).map(x => (x, x + 1))
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0)
val ccGraph = twoChains.connectedComponents()
@@ -73,12 +73,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
test("Reverse Chain Connected Components") {
withSpark { sc =>
- val chain1 = (0 until 9).map(x => (x, x+1) )
- val chain2 = (10 until 20).map(x => (x, x+1) )
+ val chain1 = (0 until 9).map(x => (x, x + 1))
+ val chain2 = (10 until 20).map(x => (x, x + 1))
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse
val ccGraph = twoChains.connectedComponents()
- val vertices = ccGraph.vertices.collect
+ val vertices = ccGraph.vertices.collect()
for ( (id, cc) <- vertices ) {
if (id < 10) {
assert(cc === 0)
@@ -120,9 +120,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
// Build the initial Graph
val graph = Graph(users, relationships, defaultUser)
val ccGraph = graph.connectedComponents()
- val vertices = ccGraph.vertices.collect
+ val vertices = ccGraph.vertices.collect()
for ( (id, cc) <- vertices ) {
- assert(cc == 0)
+ assert(cc === 0)
}
}
} // end of toy connected components
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index fc491ae327c2a..95804b07b1db0 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -19,15 +19,12 @@ package org.apache.spark.graphx.lib
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
-import org.apache.spark.graphx.lib._
import org.apache.spark.graphx.util.GraphGenerators
-import org.apache.spark.rdd._
+
object GridPageRank {
- def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
+ def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = {
val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int])
val outDegree = Array.fill(nRows * nCols)(0)
// Convert row column address into vertex ids (row major order)
@@ -35,13 +32,13 @@ object GridPageRank {
// Make the grid graph
for (r <- 0 until nRows; c <- 0 until nCols) {
val ind = sub2ind(r,c)
- if (r+1 < nRows) {
+ if (r + 1 < nRows) {
outDegree(ind) += 1
- inNbrs(sub2ind(r+1,c)) += ind
+ inNbrs(sub2ind(r + 1,c)) += ind
}
- if (c+1 < nCols) {
+ if (c + 1 < nCols) {
outDegree(ind) += 1
- inNbrs(sub2ind(r,c+1)) += ind
+ inNbrs(sub2ind(r,c + 1)) += ind
}
}
// compute the pagerank
@@ -64,7 +61,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
- .map { case (id, error) => error }.sum
+ .map { case (id, error) => error }.sum()
}
test("Star PageRank") {
@@ -80,12 +77,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
// Static PageRank should only take 2 iterations to converge
val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
if (pr1 != pr2) 1 else 0
- }.map { case (vid, test) => test }.sum
+ }.map { case (vid, test) => test }.sum()
assert(notMatching === 0)
val staticErrors = staticRanks2.map { case (vid, pr) =>
- val correct = (vid > 0 && pr == resetProb) ||
- (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
+ val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) ))
+ val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5)
if (!correct) 1 else 0
}
assert(staticErrors.sum === 0)
@@ -95,8 +92,6 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
}
} // end of test Star PageRank
-
-
test("Grid PageRank") {
withSpark { sc =>
val rows = 10
@@ -109,18 +104,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
- val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache()
+ val referenceRanks = VertexRDD(
+ sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache()
assert(compareRanks(staticRanks, referenceRanks) < errorTol)
assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
}
} // end of Grid PageRank
-
test("Chain PageRank") {
withSpark { sc =>
- val chain1 = (0 until 9).map(x => (x, x+1) )
- val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) }
+ val chain1 = (0 until 9).map(x => (x, x + 1))
+ val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) }
val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
val resetProb = 0.15
val tol = 0.0001
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
index df54aa37cad68..1f658c371ffcf 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -34,8 +34,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val edges = sc.parallelize(Seq.empty[Edge[Int]])
val graph = Graph(vertices, edges)
val sccGraph = graph.stronglyConnectedComponents(5)
- for ((id, scc) <- sccGraph.vertices.collect) {
- assert(id == scc)
+ for ((id, scc) <- sccGraph.vertices.collect()) {
+ assert(id === scc)
}
}
}
@@ -45,8 +45,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
val graph = Graph.fromEdgeTuples(rawEdges, -1)
val sccGraph = graph.stronglyConnectedComponents(20)
- for ((id, scc) <- sccGraph.vertices.collect) {
- assert(0L == scc)
+ for ((id, scc) <- sccGraph.vertices.collect()) {
+ assert(0L === scc)
}
}
}
@@ -60,13 +60,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val rawEdges = sc.parallelize(edges)
val graph = Graph.fromEdgeTuples(rawEdges, -1)
val sccGraph = graph.stronglyConnectedComponents(20)
- for ((id, scc) <- sccGraph.vertices.collect) {
- if (id < 3)
- assert(0L == scc)
- else if (id < 6)
- assert(3L == scc)
- else
- assert(id == scc)
+ for ((id, scc) <- sccGraph.vertices.collect()) {
+ if (id < 3) {
+ assert(0L === scc)
+ } else if (id < 6) {
+ assert(3L === scc)
+ } else {
+ assert(id === scc)
+ }
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8299155117efc..cc8b0721cf2b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -31,13 +31,11 @@ import org.apache.spark.storage.StorageLevel
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
- with HasRegParam with HasMaxIter with HasThreshold {
+ with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {
setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
}
-
-
/**
* :: AlphaComponent ::
*
@@ -55,6 +53,9 @@ class LogisticRegression
/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+ /** @group setParam */
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
@@ -67,7 +68,8 @@ class LogisticRegression
}
// Train model
- val lr = new LogisticRegressionWithLBFGS
+ val lr = new LogisticRegressionWithLBFGS()
+ .setIntercept(paramMap(fitIntercept))
lr.optimizer
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala
index fb881160bf180..3f3519c35ceb9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala
@@ -43,7 +43,8 @@ private[shared] object SharedParamCodeGen {
ParamDesc[Double]("threshold", "threshold in prediction"),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[String]("outputCol", "output column name"),
- ParamDesc[Int]("checkpointInterval", "checkpoint interval"))
+ ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
+ ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 33a499103f1e8..42af80058f643 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -220,4 +220,23 @@ trait HasCheckpointInterval extends Params {
/** @group getParam */
final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param fitIntercept (default: true).
+ */
+@DeveloperApi
+trait HasFitIntercept extends Params {
+
+ /**
+ * Param for whether to fit an intercept term.
+ * @group param
+ */
+ final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term")
+
+ setDefault(fitIntercept, true)
+
+ /** @group getParam */
+ final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
+}
// scalastyle:on
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b3d1bfcfbee0f..35d8c2e16c6cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol == "prediction")
assert(lr.getRawPredictionCol == "rawPrediction")
assert(lr.getProbabilityCol == "probability")
+ assert(lr.getFitIntercept == true)
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.getPredictionCol == "prediction")
assert(model.getRawPredictionCol == "rawPrediction")
assert(model.getProbabilityCol == "probability")
+ assert(model.intercept !== 0.0)
+ }
+
+ test("logistic regression doesn't fit intercept when fitIntercept is off") {
+ val lr = new LogisticRegression
+ lr.setFitIntercept(false)
+ val model = lr.fit(dataset)
+ assert(model.intercept === 0.0)
}
test("logistic regression with setters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 84ed5d23322e2..6109ed98323e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -30,7 +30,7 @@ class TestParams extends Params {
setDefault(maxIter -> 10)
- override def validate(paramMap: ParamMap) = {
+ override def validate(paramMap: ParamMap): Unit = {
val m = extractParamMap(paramMap)
require(m(maxIter) >= 0)
require(m.contains(inputCol))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index f9fe3e006ccb8..ea89b17b7c08f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -102,7 +102,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
def validateModelFit(
piData: Array[Double],
thetaData: Array[Array[Double]],
- model: NaiveBayesModel) = {
+ model: NaiveBayesModel): Unit = {
def closeFit(d1: Double, d2: Double, precision: Double): Boolean = {
(d1 - d2).abs <= precision
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index d50c43d439187..5683b55e8500a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.streaming.TestSuiteBase
class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
- override def maxWaitTimeMillis = 30000
+ override def maxWaitTimeMillis: Int = 30000
// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 7bf250eb5a383..0f2b26d462ad2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -199,9 +199,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
test("k-means|| initialization") {
case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] {
- @Override def compare(that: VectorWithCompare): Int = {
- if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) >
- that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1
+ override def compare(that: VectorWithCompare): Int = {
+ if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) >
+ that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) {
+ -1
+ } else {
+ 1
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 302d751eb8a94..15de10fd13a19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -141,7 +141,7 @@ private[clustering] object LDASuite {
(terms.toArray, termWeights.toArray)
}
- def tinyCorpus = Array(
+ def tinyCorpus: Array[(Long, Vector)] = Array(
Vectors.dense(1, 3, 0, 2, 8),
Vectors.dense(0, 2, 1, 0, 4),
Vectors.dense(2, 3, 12, 3, 1),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index 850c9fce507cd..f90025d535e45 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom
class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
- override def maxWaitTimeMillis = 30000
+ override def maxWaitTimeMillis: Int = 30000
test("accuracy for single center and equivalence to grand average") {
// set parameters
@@ -59,7 +59,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
// estimated center from streaming should exactly match the arithmetic mean of all data points
// because the decay factor is set to 1.0
val grandMean =
- input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
+ input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble
assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index 6395188a0842a..63f2ea916d457 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -181,7 +181,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa
val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed)
testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1)
- val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed)
+ val exponential =
+ RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed)
testGeneratedVectorRDD(exponential, rows, cols, parts, exponentialMean, exponentialMean, 0.1)
val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed)
@@ -197,7 +198,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] {
// This allows us to check that each partition has a different seed
override def nextValue(): Double = seed.toDouble
- override def setSeed(seed: Long) = this.seed = seed
+ override def setSeed(seed: Long): Unit = this.seed = seed
override def copy(): MockDistro = new MockDistro
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 8775c0ca9df84..b3798940ddc38 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -203,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
* @param numProductBlocks number of product blocks to partition products into
* @param negativeFactors whether the generated user/product factors can have negative entries
*/
+ // scalastyle:off
def testALS(
users: Int,
products: Int,
@@ -216,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
negativeFactors: Boolean = true) {
+ // scalastyle:on
+
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
features, samplingRate, implicitPrefs, negativeWeights, negativeFactors)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 43d61151e2471..d6c93cc0e49cd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -35,7 +35,7 @@ private object RidgeRegressionSuite {
class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
- def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
+ def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
predictions.zip(input).map { case (prediction, expected) =>
(prediction - expected.label) * (prediction - expected.label)
}.reduceLeft(_ + _) / predictions.size
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 24fd8df691817..26604dbe6c1ef 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.streaming.TestSuiteBase
class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
- override def maxWaitTimeMillis = 20000
+ override def maxWaitTimeMillis: Int = 20000
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
index e957fa5d25f4c..352193a67860c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
@@ -95,16 +95,16 @@ object TestingUtils {
/**
* Comparison using absolute tolerance.
*/
- def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison,
- x, eps, ABS_TOL_MSG)
+ def absTol(eps: Double): CompareDoubleRightSide =
+ CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG)
/**
* Comparison using relative tolerance.
*/
- def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison,
- x, eps, REL_TOL_MSG)
+ def relTol(eps: Double): CompareDoubleRightSide =
+ CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG)
- override def toString = x.toString
+ override def toString: String = x.toString
}
case class CompareVectorRightSide(
@@ -166,7 +166,7 @@ object TestingUtils {
x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
}, x, eps, REL_TOL_MSG)
- override def toString = x.toString
+ override def toString: String = x.toString
}
case class CompareMatrixRightSide(
@@ -229,7 +229,7 @@ object TestingUtils {
x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
}, x, eps, REL_TOL_MSG)
- override def toString = x.toString
+ override def toString: String = x.toString
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
index b0ecb33c28483..59e6c778806f4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
@@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite {
assert(!(17.8 ~= 17.59 absTol 0.2))
// Comparisons of numbers very close to zero, and both side of zeros
- assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
- assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
-
- assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
- assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ assert(
+ Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ assert(
+ Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+
+ assert(
+ -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
+ assert(
+ Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue)
}
test("Comparing vectors using relative error.") {
- //Comparisons of two dense vectors
+ // Comparisons of two dense vectors
assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01)
assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)
@@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite {
test("Comparing vectors using absolute error.") {
- //Comparisons of two dense vectors
+ // Comparisons of two dense vectors
assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~==
Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 0a16cbd8bff62..2a5e84a7dfdb4 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -29,11 +29,10 @@
def launch_gateway():
- SPARK_HOME = os.environ["SPARK_HOME"]
-
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
else:
+ SPARK_HOME = os.environ["SPARK_HOME"]
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and settings from spark-env.sh
on_windows = platform.system() == "Windows"
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 4ff7463498cce..7f42de531f3b4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -91,9 +91,9 @@ class LogisticRegressionModel(JavaModel):
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.feature tests")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx
+ globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 433b4fb5d22bf..1cfcd019dfb18 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -117,9 +117,9 @@ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.feature tests")
- sqlCtx = SQLContext(sc)
+ sqlContext = SQLContext(sc)
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx
+ globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 1a02fece9c5a5..81aa970a32f76 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -53,9 +53,9 @@
try:
# Try to access HiveConf, it will raise exception if Hive is not added
sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
- sqlCtx = HiveContext(sc)
+ sqlCtx = sqlContext = HiveContext(sc)
except py4j.protocol.Py4JError:
- sqlCtx = SQLContext(sc)
+ sqlCtx = sqlContext = SQLContext(sc)
print("""Welcome to
____ __
@@ -68,7 +68,7 @@
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1]))
-print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__)
+print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__)
if add_files is not None:
print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead")
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c2d81ba804110..93e2d176a5b6f 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -37,12 +37,12 @@
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
-def _monkey_patch_RDD(sqlCtx):
+def _monkey_patch_RDD(sqlContext):
def toDF(self, schema=None, sampleRatio=None):
"""
Converts current :class:`RDD` into a :class:`DataFrame`
- This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)``
+ This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``
:param schema: a StructType or list of names of columns
:param samplingRatio: the sample ratio of rows used for inferring
@@ -51,7 +51,7 @@ def toDF(self, schema=None, sampleRatio=None):
>>> rdd.toDF().collect()
[Row(name=u'Alice', age=1)]
"""
- return sqlCtx.createDataFrame(self, schema, sampleRatio)
+ return sqlContext.createDataFrame(self, schema, sampleRatio)
RDD.toDF = toDF
@@ -75,13 +75,13 @@ def __init__(self, sparkContext, sqlContext=None):
"""Creates a new SQLContext.
>>> from datetime import datetime
- >>> sqlCtx = SQLContext(sc)
+ >>> sqlContext = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
>>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
- >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+ >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
@@ -133,18 +133,18 @@ def registerFunction(self, name, f, returnType=StringType()):
:param samplingRatio: lambda function
:param returnType: a :class:`DataType` object
- >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
- >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+ >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+ >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
>>> from pyspark.sql.types import IntegerType
- >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
- >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
>>> from pyspark.sql.types import IntegerType
- >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
- >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
@@ -229,26 +229,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
:param samplingRatio: the sample ratio of rows used for inferring
>>> l = [('Alice', 1)]
- >>> sqlCtx.createDataFrame(l).collect()
+ >>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
- >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+ >>> sqlContext.createDataFrame(l, ['name', 'age']).collect()
[Row(name=u'Alice', age=1)]
>>> d = [{'name': 'Alice', 'age': 1}]
- >>> sqlCtx.createDataFrame(d).collect()
+ >>> sqlContext.createDataFrame(d).collect()
[Row(age=1, name=u'Alice')]
>>> rdd = sc.parallelize(l)
- >>> sqlCtx.createDataFrame(rdd).collect()
+ >>> sqlContext.createDataFrame(rdd).collect()
[Row(_1=u'Alice', _2=1)]
- >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df = sqlContext.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name=u'Alice', age=1)]
>>> from pyspark.sql import Row
>>> Person = Row('name', 'age')
>>> person = rdd.map(lambda r: Person(*r))
- >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2 = sqlContext.createDataFrame(person)
>>> df2.collect()
[Row(name=u'Alice', age=1)]
@@ -256,11 +256,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
>>> schema = StructType([
... StructField("name", StringType(), True),
... StructField("age", IntegerType(), True)])
- >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3 = sqlContext.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name=u'Alice', age=1)]
- >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
+ >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
[Row(name=u'Alice', age=1)]
"""
if isinstance(data, DataFrame):
@@ -316,7 +316,7 @@ def registerDataFrameAsTable(self, df, tableName):
Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
if (df.__class__ is DataFrame):
self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
@@ -330,7 +330,7 @@ def parquetFile(self, *paths):
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -352,7 +352,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
>>> shutil.rmtree(jsonFile)
>>> with open(jsonFile, 'w') as f:
... f.writelines(jsonStrings)
- >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> df1 = sqlContext.jsonFile(jsonFile)
>>> df1.printSchema()
root
|-- field1: long (nullable = true)
@@ -365,7 +365,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
... StructField("field2", StringType()),
... StructField("field3",
... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> df2 = sqlContext.jsonFile(jsonFile, schema)
>>> df2.printSchema()
root
|-- field2: string (nullable = true)
@@ -386,11 +386,11 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
If the schema is provided, applies the given schema to this JSON dataset.
Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
- >>> df1 = sqlCtx.jsonRDD(json)
+ >>> df1 = sqlContext.jsonRDD(json)
>>> df1.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
- >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+ >>> df2 = sqlContext.jsonRDD(json, df1.schema)
>>> df2.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
@@ -400,7 +400,7 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
... StructField("field3",
... StructType([StructField("field5", ArrayType(IntegerType()))]))
... ])
- >>> df3 = sqlCtx.jsonRDD(json, schema)
+ >>> df3 = sqlContext.jsonRDD(json, schema)
>>> df3.first()
Row(field2=u'row1', field3=Row(field5=None))
"""
@@ -480,8 +480,8 @@ def createExternalTable(self, tableName, path=None, source=None,
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
@@ -490,8 +490,8 @@ def sql(self, sqlQuery):
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.table("table1")
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -505,8 +505,8 @@ def tables(self, dbName=None):
The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
(a column with :class:`BooleanType` indicating if a table is a temporary one or not).
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.tables()
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
Row(tableName=u'table1', isTemporary=True)
"""
@@ -520,10 +520,10 @@ def tableNames(self, dbName=None):
If ``dbName`` is not specified, the current database will be used.
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> "table1" in sqlCtx.tableNames()
+ >>> sqlContext.registerDataFrameAsTable(df, "table1")
+ >>> "table1" in sqlContext.tableNames()
True
- >>> "table1" in sqlCtx.tableNames("db")
+ >>> "table1" in sqlContext.tableNames("db")
True
"""
if dbName is None:
@@ -578,11 +578,11 @@ def _get_hive_ctx(self):
class UDFRegistration(object):
"""Wrapper for user-defined function registration."""
- def __init__(self, sqlCtx):
- self.sqlCtx = sqlCtx
+ def __init__(self, sqlContext):
+ self.sqlContext = sqlContext
def register(self, name, f, returnType=StringType()):
- return self.sqlCtx.registerFunction(name, f, returnType)
+ return self.sqlContext.registerFunction(name, f, returnType)
register.__doc__ = SQLContext.registerFunction.__doc__
@@ -595,13 +595,12 @@ def _test():
globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
- _monkey_patch_RDD(sqlCtx)
globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c30326ebd133e..ef91a9c4f522d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -110,7 +110,7 @@ def saveAsParquetFile(self, path):
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df2.collect()) == sorted(df.collect())
True
"""
@@ -123,7 +123,7 @@ def registerTempTable(self, name):
that was used to create this :class:`DataFrame`.
>>> df.registerTempTable("people")
- >>> df2 = sqlCtx.sql("select * from people")
+ >>> df2 = sqlContext.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
@@ -1180,7 +1180,7 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 146ba6f3e0d98..daeb6916b58bc 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -161,7 +161,7 @@ def _test():
globs = pyspark.sql.functions.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.functions, globs=globs,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 45eb8b945dcb0..7e0124b13671b 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -434,7 +434,7 @@ def _parse_datatype_json_string(json_string):
>>> def check_datatype(datatype):
... pickled = pickle.loads(pickle.dumps(datatype))
... assert datatype == pickled
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... assert datatype == python_datatype
>>> for cls in _all_primitive_types.values():
@@ -1237,7 +1237,7 @@ def _test():
globs = pyspark.sql.types.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+ globs['sqlContext'] = SQLContext(sc)
globs['ExamplePoint'] = ExamplePoint
globs['ExamplePointUDT'] = ExamplePointUDT
(failure_count, test_count) = doctest.testmod(
diff --git a/sql/README.md b/sql/README.md
index fbb3200a3a4b4..237620e3fa808 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -56,6 +56,6 @@ res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,v
You can also build further queries on top of these `DataFrames` using the query DSL.
```
-scala> query.where('key > 30).select(avg('key)).collect()
+scala> query.where(query("key") > 30).select(avg(query("key"))).collect()
res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814])
```
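The Symbol-based syntax (`'key`) is replaced by column references obtained from the DataFrame itself. Sticking with the `query` DataFrame built earlier in this README, the same style extends to any predicate or aggregate, for example:

```scala
import org.apache.spark.sql.functions.avg

// query("key") returns a Column, so comparisons and aggregates compose
// without the older Symbol-based syntax.
val filtered = query.where(query("key") > 30)
filtered.select(avg(query("key"))).collect()
```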
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 119cb9c3a4400..b3aba4f68ddf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -293,7 +293,7 @@ class Analyzer(
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
- resolveGetField(child, fieldName)
+ q.resolveGetField(child, fieldName, resolver)
}
}
@@ -313,36 +313,6 @@ class Analyzer(
*/
protected def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
-
- /**
- * Returns the resolved `GetField`, and report error if no desired field or over one
- * desired fields are found.
- */
- protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
- def findField(fields: Array[StructField]): Int = {
- val checkField = (f: StructField) => resolver(f.name, fieldName)
- val ordinal = fields.indexWhere(checkField)
- if (ordinal == -1) {
- throw new AnalysisException(
- s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
- } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
- throw new AnalysisException(
- s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
- } else {
- ordinal
- }
- }
- expr.dataType match {
- case StructType(fields) =>
- val ordinal = findField(fields)
- StructGetField(expr, fields(ordinal), ordinal)
- case ArrayType(StructType(fields), containsNull) =>
- val ordinal = findField(fields)
- ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
- case otherType =>
- throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
- }
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 02f7c26a8ab6e..7967189cacb24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}.toSeq
}
- def schema: StructType = StructType.fromAttributes(output)
+ lazy val schema: StructType = StructType.fromAttributes(output)
/** Returns the output schema in the tree format. */
def schemaString: String = schema.treeString
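Making `schema` a `lazy val` means `StructType.fromAttributes(output)` runs once per plan node instead of on every access. A small, Spark-independent sketch of why that matters (names are illustrative only):

```scala
object LazyValVsDef {
  private var builds = 0
  private def buildSchema(): String = { builds += 1; "struct<key:int,value:string>" }

  def schemaDef: String = buildSchema()        // recomputed on every access
  lazy val schemaLazy: String = buildSchema()  // computed once, then cached

  def main(args: Array[String]): Unit = {
    schemaDef; schemaDef
    schemaLazy; schemaLazy
    println(builds) // 3: two builds for the def, one for the lazy val
  }
}
```

Caching is presumably safe here because transformations produce new plan nodes rather than mutating an existing node's `output` in place.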
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 2e9f3aa4ec4ad..d8f5858f5033e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -205,11 +205,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
try {
-
- // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
- // and aliased it with the last part of the name.
- // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
+          // The foldLeft adds a GetField for each remaining part of the identifier,
+          // and aliases the result with the last part of the identifier.
+ // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will add GetField("c", GetField("b", a)), and alias
// the final expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver))
val aliasName = nestedFields.last
@@ -234,10 +233,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/**
* Returns the resolved `GetField`, and report error if no desired field or over one
* desired fields are found.
- *
- * TODO: this code is duplicated from Analyzer and should be refactored to avoid this.
*/
- protected def resolveGetField(
+ def resolveGetField(
expr: Expression,
fieldName: String,
resolver: Resolver): Expression = {
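The updated comment describes folding the remaining name parts into nested field accesses. A self-contained sketch of that foldLeft pattern with toy expression classes (these are not Catalyst's real types):

```scala
object NestedFieldSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class GetField(child: Expr, fieldName: String) extends Expr

  def main(args: Array[String]): Unit = {
    // Resolving "a.b.c": "a" matched an attribute, so Seq("b", "c") remains.
    val resolved: Expr = Attr("a")
    val nestedFields = Seq("b", "c")

    // Each remaining part wraps the expression built so far.
    val fieldExpr = nestedFields.foldLeft(resolved)((expr, name) => GetField(expr, name))

    // Prints GetField(GetField(Attr(a),b),c); LogicalPlan.resolve then aliases
    // the final expression with the last part, "c".
    println(fieldExpr)
  }
}
```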
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index cf191715d29d6..87bc20f79c3cd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase {
test("flatMapValues") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
- (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)),
+ (s: DStream[String]) => {
+ s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10))
+ },
Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ),
true
)
@@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase {
stream.foreachRDD(_ => {}) // Dummy output stream
ssc.start()
Thread.sleep(2000)
- def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
+ def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = {
stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 91a2b2bba461d..54c30440a6e8d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase {
var ssc: StreamingContext = null
- override def batchDuration = Milliseconds(500)
+ override def batchDuration: Duration = Milliseconds(500)
override def beforeFunction() {
super.beforeFunction()
@@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase {
val input = (1 to 10).map(_ => Seq("a")).toSeq
val operation = (st: DStream[String]) => {
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
- Some((values.sum + state.getOrElse(0)))
+ Some(values.sum + state.getOrElse(0))
}
st.map(x => (x, 1))
.updateStateByKey(updateFunc)
@@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase {
testCheckpointedOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
(s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
- Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ Seq(
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq(),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)), Seq() ),
3
)
}
@@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase {
val n = 10
val w = 4
val input = (1 to n).map(_ => Seq("a")).toSeq
- val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
+ val output = Seq(
+ Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
val operation = (st: DStream[String]) => {
st.map(x => (x, 1))
.reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration)
@@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase {
classOf[TextOutputFormat[Text, IntWritable]])
output
},
- Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+ Seq(
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq(),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq()),
3
)
} finally {
@@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase {
classOf[NewTextOutputFormat[Text, IntWritable]])
output
},
- Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+ Seq(
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq(),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq()),
3
)
} finally {
@@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase {
output
}
},
- Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+ Seq(
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq(),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("", 2)),
+ Seq()),
3
)
} finally {
@@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase {
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
*/
- def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
+ def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] =
+ {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logInfo("Manual clock before advancing = " + clock.getTimeMillis())
for (i <- 1 to numBatches.toInt) {
@@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase {
logInfo("Manual clock after advancing = " + clock.getTimeMillis())
Thread.sleep(batchDuration.milliseconds)
- val outputStream = ssc.graph.getOutputStreams.filter { dstream =>
+ val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
}.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
outputStream.output.map(_.flatten)
@@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase {
private object CheckpointSuite extends Serializable {
var batchThreeShouldBlockIndefinitely: Boolean = true
-}
\ No newline at end of file
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 26435d8515815..0c4c06534a693 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging {
val directory = Utils.createTempDir()
val numBatches = 30
- override def batchDuration = Milliseconds(1000)
+ override def batchDuration: Duration = Milliseconds(1000)
- override def useManualClock = false
+ override def useManualClock: Boolean = false
override def afterFunction() {
Utils.deleteRecursively(directory)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 7ed6320a3d0bc..e6ac4975c5e68 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
- def output = outputBuffer.flatMap(x => x)
+ def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
outputStream.register()
ssc.start()
@@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val countStream = networkStream.count
val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
val outputStream = new TestOutputStream(countStream, outputBuffer)
- def output = outputBuffer.flatMap(x => x)
+ def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x)
outputStream.register()
ssc.start()
@@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val queueStream = ssc.queueStream(queue, oneAtATime = true)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(queueStream, outputBuffer)
- def output = outputBuffer.filter(_.size > 0)
+ def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0)
outputStream.register()
ssc.start()
@@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq("1", "2", "3", "4", "5")
val expectedOutput = input.map(Seq(_))
- //Thread.sleep(1000)
+
val inputIterator = input.toIterator
for (i <- 0 until input.size) {
// Enqueue more than 1 item per tick but they should dequeue one at a time
@@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val queueStream = ssc.queueStream(queue, oneAtATime = false)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(queueStream, outputBuffer)
- def output = outputBuffer.filter(_.size > 0)
+ def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0)
outputStream.register()
ssc.start()
@@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging {
logInfo("New connection")
try {
clientSocket.setTcpNoDelay(true)
- val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream))
+ val outputStream = new BufferedWriter(
+ new OutputStreamWriter(clientSocket.getOutputStream))
while(clientSocket.isConnected) {
val msg = queue.poll(100, TimeUnit.MILLISECONDS)
@@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging {
def stop() { servingThread.interrupt() }
- def port = serverSocket.getLocalPort
+ def port: Int = serverSocket.getLocalPort
}
/** This is a receiver to test multiple threads inserting data using block generator */
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index ef4873de2f5a9..c090eaec2928d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -96,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
testBlockStoring(handler) { case (data, blockIds, storeResults) =>
// Verify the data in block manager is correct
val storedData = blockIds.flatMap { blockId =>
- blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty)
+ blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty)
}.toList
storedData shouldEqual data
@@ -120,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
testBlockStoring(handler) { case (data, blockIds, storeResults) =>
// Verify the data in block manager is correct
val storedData = blockIds.flatMap { blockId =>
- blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty)
+ blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty)
}.toList
storedData shouldEqual data
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index 42fad769f0c1a..b63b37d9f9cef 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite
* Get all the data written in the given write ahead log files. By default, it will read all
* files in the test log directory.
*/
- def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = {
+ def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles)
+ : Seq[ReceivedBlockTrackerLogEvent] = {
logFiles.flatMap {
file => new WriteAheadLogReader(file, hadoopConf).toSeq
}.map { byteBuffer =>
@@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite
}
/** Create batch allocation object from the given info */
- def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = {
+ def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo])
+ : BatchAllocationEvent = {
BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos))))
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index aa20ad0b5374e..10c35cba8dc53 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
val errors = new ArrayBuffer[Throwable]
/** Check if all data structures are clean */
- def isAllEmpty = {
+ def isAllEmpty: Boolean = {
singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
arrayBuffers.isEmpty && errors.isEmpty
}
@@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
def pushBytes(
bytes: ByteBuffer,
optionalMetadata: Option[Any],
- optionalBlockId: Option[StreamBlockId]
- ) {
+ optionalBlockId: Option[StreamBlockId]) {
byteBuffers += bytes
}
def pushIterator(
iterator: Iterator[_],
optionalMetadata: Option[Any],
- optionalBlockId: Option[StreamBlockId]
- ) {
+ optionalBlockId: Option[StreamBlockId]) {
iterators += iterator
}
def pushArrayBuffer(
arrayBuffer: ArrayBuffer[_],
optionalMetadata: Option[Any],
- optionalBlockId: Option[StreamBlockId]
- ) {
+ optionalBlockId: Option[StreamBlockId]) {
arrayBuffers += arrayBuffer
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 2e5005ef6ff14..d1bbf39dc7897 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -213,7 +213,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
ssc = new StreamingContext(sc, Milliseconds(100))
var runningCount = 0
SlowTestReceiver.receivedAllRecords = false
- //Create test receiver that sleeps in onStop()
+ // Create test receiver that sleeps in onStop()
val totalNumRecords = 15
val recordsPerSecond = 1
val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond))
@@ -370,7 +370,8 @@ object TestReceiver {
}
/** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */
-class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging {
+class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int)
+ extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging {
var receivingThreadOption: Option[Thread] = None
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index f52562b0a0f73..852e8bb71d4f6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -38,8 +38,8 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
// To make sure that the processing start and end times in collected
// information are different for successive batches
- override def batchDuration = Milliseconds(100)
- override def actuallyWait = true
+ override def batchDuration: Duration = Milliseconds(100)
+ override def actuallyWait: Boolean = true
test("batch info reporting") {
val ssc = setupStreams(input, operation)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 3565d621e8a6c..c3cae8aeb6d15 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]],
val selectedInput = if (index < input.size) input(index) else Seq[T]()
// lets us test cases where RDDs are not created
- if (selectedInput == null)
+ if (selectedInput == null) {
return None
+ }
val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
logInfo("Created RDD " + rdd.id + " with " + selectedInput)
@@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
output.clear()
}
- def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+ def toTestOutputStream: TestOutputStream[T] = {
+ new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+ }
}
/**
@@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) {
trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
- def framework = this.getClass.getSimpleName
+ def framework: String = this.getClass.getSimpleName
// Master for Spark context
- def master = "local[2]"
+ def master: String = "local[2]"
// Batch duration
- def batchDuration = Seconds(1)
+ def batchDuration: Duration = Seconds(1)
// Directory where the checkpoint data will be saved
- lazy val checkpointDir = {
+ lazy val checkpointDir: String = {
val dir = Utils.createTempDir()
logDebug(s"checkpointDir: $dir")
dir.toString
}
// Number of partitions of the input parallel collections created for testing
- def numInputPartitions = 2
+ def numInputPartitions: Int = 2
// Maximum time to wait before the test times out
- def maxWaitTimeMillis = 10000
+ def maxWaitTimeMillis: Int = 10000
// Whether to use manual clock or not
- def useManualClock = true
+ def useManualClock: Boolean = true
// Whether to actually wait in real time before changing manual clock
- def actuallyWait = false
+ def actuallyWait: Boolean = false
- //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
+ // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
val conf = new SparkConf()
.setMaster(master)
.setAppName(framework)
@@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Wait until expected number of output items have been generated
val startTime = System.currentTimeMillis()
- while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ while (output.size < numExpectedOutput &&
+ System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
ssc.awaitTerminationOrTimeout(50)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index 87a0395efbf2a..998426ebb82e5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -32,7 +32,8 @@ import org.apache.spark._
/**
* Selenium tests for the Spark Web UI.
*/
-class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
+class UISeleniumSuite
+ extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
implicit var webDriver: WebDriver = _
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index a5d2bb2fde16c..c39ad05f41520 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel
class WindowOperationsSuite extends TestSuiteBase {
- override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer
+ override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer
- override def batchDuration = Seconds(1) // making sure its visible in this class
+  override def batchDuration: Duration = Seconds(1) // making sure it's visible in this class
val largerSlideInput = Seq(
Seq(("a", 1)),
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 7a6a2f3e577dd..c3602a5b73732 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc
import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter}
import org.apache.spark.util.Utils
-class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+class WriteAheadLogBackedBlockRDDSuite
+ extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName(this.getClass.getSimpleName)
+
val hadoopConf = new Configuration()
var sparkContext: SparkContext = null
@@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w
* @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log
* @param testStoreInBM Test whether blocks read from log are stored back into block manager
*/
- private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) {
+ private def testRDD(
+ numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) {
val numBlocks = numPartitionsInBM + numPartitionsInWAL
val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50))
@@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w
"Unexpected blocks in BlockManager"
)
- // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not
+    // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and others are not
require(
segments.takeRight(numPartitionsInWAL).forall(s =>
new File(s.path.stripPrefix("file://")).exists()),
@@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w
}
private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = {
- Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0))
+ Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0))
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
index 4150b60635ed6..7865b06c2e3c2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala
@@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase {
val receiverTracker = ssc.scheduler.receiverTracker
// Get the blocks belonging to a batch
- def getBlocksOfBatch(batchTime: Long) = {
+ def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = {
receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 8335659667f22..a3919c43b95b4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -291,7 +291,7 @@ object WriteAheadLogSuite {
manager
}
- /** Read data from a segments of a log file directly and return the list of byte buffers.*/
+  /** Read data from segments of a log file directly and return the list of byte buffers. */
def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = {
segments.map { segment =>
val reader = HdfsUtils.getInputStream(segment.path, hadoopConf)
diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala
index d0bf328f2b74d..d66750463033a 100644
--- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala
@@ -25,7 +25,8 @@ package org.apache.spark.streamingtest
*/
class ImplicitSuite {
- // We only want to test if `implict` works well with the compiler, so we don't need a real DStream.
+ // We only want to test if `implicit` works well with the compiler,
+ // so we don't need a real DStream.
def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null
def testToPairDStreamFunctions(): Unit = {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 79d55a09eb671..7219852c0a752 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -40,6 +40,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication}
import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException}
@@ -561,7 +562,14 @@ private[spark] class Client(
var lastState: YarnApplicationState = null
while (true) {
Thread.sleep(interval)
- val report = getApplicationReport(appId)
+ val report: ApplicationReport =
+ try {
+ getApplicationReport(appId)
+ } catch {
+ case e: ApplicationNotFoundException =>
+ logError(s"Application $appId not found.")
+ return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ }
val state = report.getYarnApplicationState
if (logApplicationReport) {
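
A minimal, hedged sketch (not the code above; the helper name `reportOrKilled` is invented here) of the pattern this hunk introduces: a report lookup that throws ApplicationNotFoundException is translated into a terminal (KILLED, KILLED) result, so the surrounding monitoring loop can return cleanly instead of failing:

import org.apache.hadoop.yarn.api.records.{ApplicationId, ApplicationReport, FinalApplicationStatus, YarnApplicationState}
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException

object YarnMonitorSketch {
  // Either a terminal (state, finalStatus) pair when the RM no longer knows the app,
  // or the live report to keep polling on.
  def reportOrKilled(
      yarnClient: YarnClient,
      appId: ApplicationId): Either[(YarnApplicationState, FinalApplicationStatus), ApplicationReport] = {
    try {
      Right(yarnClient.getApplicationReport(appId))
    } catch {
      case _: ApplicationNotFoundException =>
        // The application record is gone (e.g. killed externally); treat it as KILLED
        // rather than letting the monitoring loop crash.
        Left((YarnApplicationState.KILLED, FinalApplicationStatus.KILLED))
    }
  }
}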
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 8abdc26b43806..407dc1ac4d37d 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -34,7 +34,7 @@ private[spark] class YarnClientSchedulerBackend(
private var client: Client = null
private var appId: ApplicationId = null
- @volatile private var stopping: Boolean = false
+ private var monitorThread: Thread = null
/**
* Create a Yarn client to submit an application to the ResourceManager.
@@ -57,7 +57,8 @@ private[spark] class YarnClientSchedulerBackend(
client = new Client(args, conf)
appId = client.submitApplication()
waitForApplication()
- asyncMonitorApplication()
+ monitorThread = asyncMonitorApplication()
+ monitorThread.start()
}
/**
@@ -123,34 +124,19 @@ private[spark] class YarnClientSchedulerBackend(
* If the application has exited for any reason, stop the SparkContext.
* This assumes both `client` and `appId` have already been set.
*/
- private def asyncMonitorApplication(): Unit = {
+ private def asyncMonitorApplication(): Thread = {
assert(client != null && appId != null, "Application has not been submitted yet!")
val t = new Thread {
override def run() {
- while (!stopping) {
- var state: YarnApplicationState = null
- try {
- val report = client.getApplicationReport(appId)
- state = report.getYarnApplicationState()
- } catch {
- case e: ApplicationNotFoundException =>
- state = YarnApplicationState.KILLED
- }
- if (state == YarnApplicationState.FINISHED ||
- state == YarnApplicationState.KILLED ||
- state == YarnApplicationState.FAILED) {
- logError(s"Yarn application has already exited with state $state!")
- sc.stop()
- stopping = true
- }
- Thread.sleep(1000L)
- }
+ val (state, _) = client.monitorApplication(appId, logApplicationReport = false)
+ logError(s"Yarn application has already exited with state $state!")
+ sc.stop()
Thread.currentThread().interrupt()
}
}
t.setName("Yarn application state monitor")
t.setDaemon(true)
- t.start()
+ t
}
/**
@@ -158,7 +144,7 @@ private[spark] class YarnClientSchedulerBackend(
*/
override def stop() {
assert(client != null, "Attempted to stop this scheduler before starting it!")
- stopping = true
+ monitorThread.interrupt()
super.stop()
client.stop()
logInfo("Stopped")