Merge remote-tracking branch 'origin/master' into run-tests-python-modules
JoshRosen committed Jun 25, 2015
2 parents 4f8902c + 7bac2fe commit d33e525
Showing 101 changed files with 3,544 additions and 1,252 deletions.
26 changes: 19 additions & 7 deletions R/pkg/R/client.R
@@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) {
con
}

launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) {
determineSparkSubmitBin <- function() {
if (.Platform$OS.type == "unix") {
sparkSubmitBinName = "spark-submit"
} else {
sparkSubmitBinName = "spark-submit.cmd"
}
sparkSubmitBinName
}

generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
if (jars != "") {
jars <- paste("--jars", jars)
}

if (packages != "") {
packages <- paste("--packages", packages)
}

combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ")
combinedArgs
}

launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
sparkSubmitBinName <- determineSparkSubmitBin()
if (sparkHome != "") {
sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName)
} else {
sparkSubmitBin <- sparkSubmitBinName
}

if (jars != "") {
jars <- paste("--jars", jars)
}

combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ")
combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages)
cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n")
invisible(system2(sparkSubmitBin, combinedArgs, wait = F))
}
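
For illustration only (not part of the diff), the sketch below shows how the new generateSparkSubmitArgs helper composes the flags that launchBackend hands to spark-submit; every input value here is made up.

# Assumes the generateSparkSubmitArgs() definition shown above; all values are hypothetical.
submitArgs <- generateSparkSubmitArgs(
  args            = "/tmp/sparkr-backend-init",     # hypothetical script path, passed through as-is
  sparkHome       = "/opt/spark",                   # not used here; launchBackend uses it to locate the binary
  jars            = "my-extra.jar",
  sparkSubmitOpts = "sparkr-shell",
  packages        = "com.example:my-package:0.1.0"  # hypothetical spark-packages coordinates
)
# submitArgs == "--jars my-extra.jar --packages com.example:my-package:0.1.0 sparkr-shell /tmp/sparkr-backend-init"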
16 changes: 16 additions & 0 deletions R/pkg/R/column.R
@@ -210,6 +210,22 @@ setMethod("cast",
}
})

#' Match a column with given values.
#'
#' @rdname column
#' @return matched values as a result of comparing with given values.
#' @examples
#' \dontrun{
#' filter(df, "age in (10, 30)")
#' where(df, df$age %in% c(10, 30))
#' }
setMethod("%in%",
signature(x = "Column"),
function(x, table) {
table <- listToSeq(as.list(table))
jc <- callJMethod(x@jc, "in", table)
return(column(jc))
})

#' Approx Count Distinct
#'
#' @rdname column
7 changes: 5 additions & 2 deletions R/pkg/R/sparkR.R
@@ -81,6 +81,7 @@ sparkR.stop <- function() {
#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
#' @param sparkRLibDir The path where R is installed on the worker nodes.
#' @param sparkPackages Character string vector of packages from spark-packages.org
#' @export
#' @examples
#'\dontrun{
@@ -100,7 +101,8 @@ sparkR.init <- function(
sparkEnvir = list(),
sparkExecutorEnv = list(),
sparkJars = "",
sparkRLibDir = "") {
sparkRLibDir = "",
sparkPackages = "") {

if (exists(".sparkRjsc", envir = .sparkREnv)) {
cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
@@ -129,7 +131,8 @@ sparkR.init <- function(
args = path,
sparkHome = sparkHome,
jars = jars,
sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"))
sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
sparkPackages = sparkPackages)
# wait at most 100 seconds for JVM to launch
wait <- 0.1
for (i in 1:25) {
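
As a usage sketch (not taken from the diff; the master URL and package coordinates are examples only), the new sparkPackages argument flows through launchBackend into spark-submit's --packages flag:

library(SparkR)

# Example coordinates in group:artifact:version form, resolved by spark-submit at launch.
sc <- sparkR.init(
  master        = "local[2]",
  appName       = "SparkR-packages-example",
  sparkPackages = "com.databricks:spark-csv_2.10:1.0.3"
)
sqlContext <- sparkRSQL.init(sc)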
16 changes: 15 additions & 1 deletion R/pkg/inst/profile/shell.R
@@ -27,7 +27,21 @@
sc <- SparkR::sparkR.init()
assign("sc", sc, envir=.GlobalEnv)
sqlContext <- SparkR::sparkRSQL.init(sc)
sparkVer <- SparkR:::callJMethod(sc, "version")
assign("sqlContext", sqlContext, envir=.GlobalEnv)
cat("\n Welcome to SparkR!")
cat("\n Welcome to")
cat("\n")
cat(" ____ __", "\n")
cat(" / __/__ ___ _____/ /__", "\n")
cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n")
cat(" /___/ .__/\\_,_/_/ /_/\\_\\")
if (nchar(sparkVer) == 0) {
cat("\n")
} else {
cat(" version ", sparkVer, "\n")
}
cat(" /_/", "\n")
cat("\n")

cat("\n Spark context is available as sc, SQL context is available as sqlContext\n")
}
32 changes: 32 additions & 0 deletions R/pkg/inst/tests/test_client.R
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

context("functions in client.R")

test_that("adding spark-testing-base as a package works", {
args <- generateSparkSubmitArgs("", "", "", "",
"holdenk:spark-testing-base:1.3.0_0.0.5")
expect_equal(gsub("[[:space:]]", "", args),
gsub("[[:space:]]", "",
"--packages holdenk:spark-testing-base:1.3.0_0.0.5"))
})

test_that("no package specified doesn't add packages flag", {
args <- generateSparkSubmitArgs("", "", "", "", "")
expect_equal(gsub("[[:space:]]", "", args),
"")
})
10 changes: 10 additions & 0 deletions R/pkg/inst/tests/test_sparkSQL.R
@@ -693,6 +693,16 @@ test_that("filter() on a DataFrame", {
filtered2 <- where(df, df$name != "Michael")
expect_true(count(filtered2) == 2)
expect_true(collect(filtered2)$age[2] == 19)

# test suites for %in%
filtered3 <- filter(df, "age in (19)")
expect_equal(count(filtered3), 1)
filtered4 <- filter(df, "age in (19, 30)")
expect_equal(count(filtered4), 2)
filtered5 <- where(df, df$age %in% c(19))
expect_equal(count(filtered5), 1)
filtered6 <- where(df, df$age %in% c(19, 30))
expect_equal(count(filtered6), 2)
})

test_that("join() on a DataFrame", {
@@ -139,6 +139,9 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
// Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
try {
while (records.hasNext()) {
@@ -147,8 +150,19 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
closeAndWriteOutput();
success = true;
} finally {
if (!success) {
sorter.cleanupAfterError();
if (sorter != null) {
try {
sorter.cleanupAfterError();
} catch (Exception e) {
// Only throw this error if we won't be masking another
// error.
if (success) {
throw e;
} else {
logger.error("In addition to a failure during writing, we failed during " +
"cleanup.", e);
}
}
}
}
}
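
The comments in the writer above describe a general pattern: if the write itself failed, a secondary failure during cleanup should be logged rather than thrown, so it cannot mask the original exception. A minimal standalone Scala sketch of that pattern (not Spark code; all names are made up):

// Sketch of "don't let cleanup failures mask the original error"; not Spark code.
def writeAll[T](records: Iterator[T], write: T => Unit, cleanup: () => Unit): Unit = {
  var success = false
  try {
    records.foreach(write)
    success = true
  } finally {
    try {
      cleanup()
    } catch {
      case e: Exception =>
        if (success) {
          // The write succeeded, so the cleanup failure is the real error: propagate it.
          throw e
        } else {
          // An exception from the try block is already propagating; log instead of masking it.
          System.err.println(s"In addition to a write failure, cleanup also failed: ${e.getMessage}")
        }
    }
  }
}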
@@ -17,29 +17,29 @@

package org.apache.spark.shuffle.hash

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.util.{Failure, Success, Try}
import java.io.InputStream

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.{Failure, Success}

import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
ShuffleBlockId}

private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
: Iterator[(BlockId, InputStream)] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager

val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))

@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}

def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
blocksByAddress,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

// Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
blockFetcherItr.map { blockPair =>
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
case Success(inputStream) => {
(blockId, inputStream)
}
case Failure(e) => {
blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
}

val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[T](context, completionIter) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): T = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}
}
}
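
The comment above about wrapping fetch failures inside a FetchFailedException for the scheduler is the key contract of the refactored fetcher. Below is a simplified Scala sketch of that wrapping step, using made-up stand-ins for Spark's BlockId and FetchFailedException types:

import scala.util.{Failure, Success, Try}

// Stand-in types for illustration; not Spark's real BlockId or FetchFailedException.
case class DemoBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
class DemoFetchFailedException(blockId: DemoBlockId, cause: Throwable)
  extends Exception(s"Failed to fetch $blockId", cause)

// Successful fetches pass through untouched; failures are rethrown as a dedicated
// exception type so a scheduler-like caller can tell fetch failures apart from other errors.
def unwrapFetched[T](fetched: Iterator[(DemoBlockId, Try[T])]): Iterator[(DemoBlockId, T)] =
  fetched.map {
    case (blockId, Success(value)) => (blockId, value)
    case (blockId, Failure(e))     => throw new DemoFetchFailedException(blockId, e)
  }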
@@ -17,16 +17,20 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

// Wrap the streams for compression based on configuration
val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}

val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val serializerInstance = ser.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure it's compatible with this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

// Sort the output if there is a sort ordering defined.
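
The NextIterator comment in read() above relies on a small but useful idea: wrap a record iterator so that a completion action (here, closing the underlying InputStream) runs exactly once, as soon as the records are exhausted. A minimal Scala sketch of that idea follows; it is not Spark's NextIterator or CompletionIterator, and it assumes the caller drives the iterator to exhaustion:

import java.io.InputStream

// Run a completion action exactly once, when the wrapped iterator reports no more elements.
class SelfClosingIterator[A](underlying: Iterator[A], onCompletion: () => Unit)
  extends Iterator[A] {

  private var completed = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !completed) {
      completed = true
      onCompletion() // e.g. close the underlying InputStream
    }
    more
  }

  override def next(): A = underlying.next()
}

// Hypothetical usage: close a block's stream once all of its deserialized records are read.
def selfClosingRecords(stream: InputStream,
                       deserialize: InputStream => Iterator[(Any, Any)]): Iterator[(Any, Any)] =
  new SelfClosingIterator(deserialize(stream), () => stream.close())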