added executor-based data source. #25

Closed
wants to merge 2 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ project/target/
.ensime*
*.swp
.idea
*.log
111 changes: 111 additions & 0 deletions src/main/scala/com/crealytics/spark/excel/Extractor.scala
@@ -0,0 +1,111 @@
package com.crealytics.spark.excel

import java.io.InputStream
import java.math.BigDecimal
import java.sql.Timestamp
import java.text.SimpleDateFormat

import org.apache.poi.ss.usermodel.{WorkbookFactory, Row => SheetRow, _}
import org.apache.spark.sql.types._

import scala.collection.JavaConverters._

case class Extractor(useHeader: Boolean,
                     inputStream: InputStream,
                     sheetName: Option[String],
                     startColumn: Int = 0,
                     endColumn: Int = Int.MaxValue,
                     timestampFormat: Option[String] = None) {
  private lazy val workbook = WorkbookFactory.create(inputStream)
  private lazy val sheet = findSheet(workbook, sheetName)

  private val timestampParser = timestampFormat.map(d => new SimpleDateFormat(d))

  private def parseTimestamp(stringValue: String): Timestamp = {
    timestampParser match {
      case Some(parser) => new Timestamp(parser.parse(stringValue).getTime)
      case None => Timestamp.valueOf(stringValue)
    }
  }

  def firstRowWithData: Vector[Option[Cell]] = sheet.asScala
    .find(_ != null)
    .getOrElse(throw new RuntimeException(s"Sheet $sheet doesn't seem to contain any data"))
    .eachCellIterator(startColumn, endColumn)
    .to[Vector]

  def lookups(requiredColumns: Array[String],
              schema: StructType
             ): Vector[(SheetRow) => Any] = requiredColumns.map { c =>
    val columnNameRegex = s"(.*?)(_color)?".r
    val columnNameRegex(columnName, isColor) = c
    val columnIndex = schema.indexWhere(_.name == columnName)

    val cellExtractor: Cell => Any = if (isColor == null) {
      castTo(_, schema(columnIndex).dataType)
    } else {
      _.getCellStyle.getFillForegroundColorColor match {
        case null => ""
        case c: org.apache.poi.xssf.usermodel.XSSFColor => c.getARGBHex
        case uct => throw new RuntimeException(s"Unknown color type $uct: ${uct.getClass}")
      }
    }
    { row: SheetRow =>
      val cell = row.getCell(columnIndex + startColumn)
      if (cell == null) {
        null
      } else {
        cellExtractor(cell)
      }
    }
  }.to[Vector]

  private def castTo(cell: Cell, castType: DataType): Any = {
    if (cell.getCellTypeEnum == CellType.BLANK) {
      return null
    }
    val dataFormatter = new DataFormatter()
    lazy val stringValue = dataFormatter.formatCellValue(cell)
    lazy val numericValue = cell.getNumericCellValue
    lazy val bigDecimal = new BigDecimal(stringValue.replaceAll(",", ""))
    castType match {
      case _: ByteType => numericValue.toByte
      case _: ShortType => numericValue.toShort
      case _: IntegerType => numericValue.toInt
      case _: LongType => numericValue.toLong
      case _: FloatType => numericValue.toFloat
      case _: DoubleType => numericValue
      case _: BooleanType => cell.getBooleanCellValue
      case _: DecimalType => bigDecimal
      case _: TimestampType => parseTimestamp(stringValue)
      case _: DateType => new java.sql.Date(DateUtil.getJavaDate(numericValue).getTime)
      case _: StringType => stringValue
      case t => throw new RuntimeException(s"Unsupported cast from $cell to $t")
    }
  }

  private def findSheet(workBook: Workbook,
                        sheetName: Option[String]): Sheet = {
    sheetName.map { sn =>
      Option(workBook.getSheet(sn)).getOrElse(
        throw new IllegalArgumentException(s"Unknown sheet $sn")
      )
    }.getOrElse(workBook.sheetIterator.next)
  }

  def dataRows: Iterator[SheetRow] = sheet.rowIterator.asScala.drop(if (useHeader) 1 else 0)

  def extract(schema: StructType,
              requiredColumns: Array[String]): Vector[Vector[Any]] = {
    dataRows
      .map(row => lookups(requiredColumns, schema).map(l => l(row)))
      .toVector
  }

  def stringsAndCellTypes: Seq[Vector[Int]] = dataRows.map { row: SheetRow =>
    row.eachCellIterator(startColumn, endColumn).map { cell =>
      cell.fold(Cell.CELL_TYPE_BLANK)(_.getCellType)
    }.toVector
  }.toVector
}
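A minimal usage sketch of the new Extractor, assuming a local .xlsx file and a two-column schema (both hypothetical); the executor relation below drives it through HDFS instead:

import java.io.FileInputStream
import org.apache.spark.sql.types._
import com.crealytics.spark.excel.Extractor

// Hypothetical schema and file, purely to illustrate the call sequence.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)
))
val extractor = Extractor(
  useHeader = true,
  inputStream = new FileInputStream("/tmp/sample.xlsx"),
  sheetName = Some("Sheet1"),
  timestampFormat = Some("yyyy-MM-dd HH:mm:ss")
)
// extract() returns plain Scala vectors; ExcelExecutorRelation wraps them in Rows.
val rows: Vector[Vector[Any]] = extractor.extract(schema, Array("id", "name"))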

@@ -33,7 +33,7 @@ private[excel] object InferSchema {
      header: Array[String]): StructType = {
    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
    val rootTypes: Array[DataType] = rowsRDD.aggregate(startType)(
      inferRowType _,
      inferRowType,
      mergeRowTypes)

    val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
@@ -0,0 +1,25 @@
package com.crealytics.spark.excel.executor

import com.crealytics.spark.excel.utils.{ExcelCreatableRelationProvider, ParameterChecker}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

class DefaultSource
  extends ExcelCreatableRelationProvider
  with DataSourceRegister {
  override def shortName(): String = "large_excel"

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String],
                              userSchema: StructType
                             ): ExcelExecutorRelation = ExcelExecutorRelation(
    location = ParameterChecker.check(parameters, "path"),
    sheetName = parameters.get("sheetName"),
    useHeader = ParameterChecker.check(parameters, "useHeader").toBoolean,
    treatEmptyValuesAsNulls = parameters.get("treatEmptyValuesAsNulls").fold(true)(_.toBoolean),
    schema = userSchema,
    startColumn = parameters.get("startColumn").fold(0)(_.toInt),
    endColumn = parameters.get("endColumn").fold(Int.MaxValue)(_.toInt)
  )(sqlContext)
}
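A hedged read-side sketch (path and schema are hypothetical, and a SparkSession named spark is assumed): the source registers under the short name "large_excel", "path" and "useHeader" are required, and a user-supplied schema is needed because only SchemaRelationProvider is implemented.

import org.apache.spark.sql.types._

// Hypothetical schema and path; no schema inference happens on this code path.
val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)
))

val df = spark.read
  .format("large_excel")
  .schema(schema)
  .option("path", "hdfs:///data/large.xlsx")
  .option("useHeader", "true")
  .option("sheetName", "Sheet1") // optional; defaults to the first sheet
  .load()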
@@ -0,0 +1,47 @@
package com.crealytics.spark.excel.executor

import com.crealytics.spark.excel.Extractor
import com.crealytics.spark.excel.utils.Loan
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SerializableWritable
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

case class ExcelExecutorRelation(
  location: String,
  sheetName: Option[String],
  useHeader: Boolean,
  treatEmptyValuesAsNulls: Boolean,
  schema: StructType,
  startColumn: Int = 0,
  endColumn: Int = Int.MaxValue,
  timestampFormat: Option[String] = None
)
(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with PrunedScan {

  override def buildScan: RDD[Row] = buildScan(schema.map(_.name).toArray)

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val confBroadcast = sqlContext.sparkContext.broadcast(
      new SerializableWritable(sqlContext.sparkContext.hadoopConfiguration)
    )

    sqlContext.sparkContext.parallelize(Seq(location), 1).flatMap { loc =>
      val path = new Path(loc)
      val rows = Loan.withCloseable(FileSystem.get(path.toUri, confBroadcast.value.value).open(path)) {
        inputStream =>
          Extractor(useHeader,
            inputStream,
            sheetName,
            startColumn,
            endColumn,
            timestampFormat).extract(schema, requiredColumns)
      }
      rows.map(Row.fromSeq)
    }
  }
}

@@ -0,0 +1,50 @@
package com.crealytics.spark.excel.utils

import com.crealytics.spark.excel.ExcelFileSaver
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

trait ExcelCreatableRelationProvider extends CreatableRelationProvider with
  SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val path = ParameterChecker.check(parameters, "path")
    val sheetName = parameters.getOrElse("sheetName", "Sheet1")
    val useHeader = ParameterChecker.check(parameters, "useHeader").toBoolean
    val dateFormat = parameters.getOrElse("dateFormat", ExcelFileSaver.DEFAULT_DATE_FORMAT)
    val timestampFormat = parameters.getOrElse("timestampFormat", ExcelFileSaver.DEFAULT_TIMESTAMP_FORMAT)
    val filesystemPath = new Path(path)
    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
    val doSave = if (fs.exists(filesystemPath)) {
      mode match {
        case SaveMode.Append =>
          sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
        case SaveMode.Overwrite =>
          fs.delete(filesystemPath, true)
          true
        case SaveMode.ErrorIfExists =>
          sys.error(s"path $path already exists.")
        case SaveMode.Ignore => false
      }
    } else {
      true
    }
    if (doSave) {
      // Only save data when the save mode is not ignore.
      new ExcelFileSaver(fs).save(
        filesystemPath,
        data,
        sheetName = sheetName,
        useHeader = useHeader,
        dateFormat = dateFormat,
        timestampFormat = timestampFormat
      )
    }

    createRelation(sqlContext, parameters, data.schema)
  }
}
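A write-side sketch against this trait as exposed by the executor DefaultSource above (output path hypothetical): "path" and "useHeader" are required, sheetName/dateFormat/timestampFormat fall back to their defaults, and Append mode is rejected.

// df is an existing DataFrame; the output path is hypothetical.
df.write
  .format("large_excel")
  .option("path", "hdfs:///out/report.xlsx")
  .option("useHeader", "true")
  .option("sheetName", "Report")
  .mode("overwrite") // ErrorIfExists fails on an existing path, Ignore skips the save
  .save()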
7 changes: 7 additions & 0 deletions src/main/scala/com/crealytics/spark/excel/utils/Loan.scala
@@ -0,0 +1,7 @@
package com.crealytics.spark.excel.utils

object Loan {
  def withCloseable[R <: AutoCloseable, T](closeable: R)(f: R => T): T =
    try f(closeable)
    finally if (closeable != null) closeable.close()
}
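Usage sketch for Loan.withCloseable (file name hypothetical): the resource is closed even when the body throws.

import java.io.{BufferedReader, FileReader}
import com.crealytics.spark.excel.utils.Loan

// Reads one line and closes the reader in all cases.
val firstLine: String =
  Loan.withCloseable(new BufferedReader(new FileReader("/tmp/example.txt"))) { reader =>
    reader.readLine()
  }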
@@ -0,0 +1,12 @@
package com.crealytics.spark.excel.utils

object ParameterChecker {
  // Requires a parameter to exist; otherwise an exception is thrown.
  def check(map: Map[String, String], param: String): String = {
    if (!map.contains(param)) {
      throw new IllegalArgumentException(s"Parameter ${'"'}$param${'"'} is missing in options.")
    } else {
      map.apply(param)
    }
  }
}
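Behaviour sketch for ParameterChecker.check (option map hypothetical): present keys are returned, missing keys raise an IllegalArgumentException.

import com.crealytics.spark.excel.utils.ParameterChecker

val options = Map("path" -> "/data/in.xlsx") // hypothetical options
val path = ParameterChecker.check(options, "path") // "/data/in.xlsx"
// ParameterChecker.check(options, "useHeader")    // throws IllegalArgumentException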
28 changes: 28 additions & 0 deletions src/main/scala/com/crealytics/spark/excel/utils/RichRow.scala
@@ -0,0 +1,28 @@
package com.crealytics.spark.excel.utils

import org.apache.poi.ss.usermodel.{Cell, Row}
import org.apache.poi.ss.usermodel.Row.MissingCellPolicy

object RichRow {
  implicit class RichRow(val row: Row) extends AnyVal {

    def eachCellIterator(startColumn: Int, endColumn: Int): Iterator[Option[Cell]] = new Iterator[Option[Cell]] {
      private val lastCellInclusive = row.getLastCellNum - 1
      private val endCol = Math.min(endColumn, Math.max(startColumn, lastCellInclusive))
      require(startColumn >= 0 && startColumn <= endCol)

      private var nextCol = startColumn

      override def hasNext: Boolean = nextCol <= endCol && nextCol <= lastCellInclusive

      override def next(): Option[Cell] = {
        val next = if (nextCol > endCol) throw new NoSuchElementException(s"column index = $nextCol")
        else Option(row.getCell(nextCol, MissingCellPolicy.RETURN_NULL_AND_BLANK))
        nextCol += 1
        next
      }
    }

  }

}
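Usage sketch for the eachCellIterator extension (column range hypothetical): cells that were never created come back as None because RETURN_NULL_AND_BLANK yields null for them.

import org.apache.poi.ss.usermodel.Row
import com.crealytics.spark.excel.utils.RichRow._

// Render columns 0..4 of a row, substituting "" for cells that do not exist.
def cellStrings(row: Row): Vector[String] =
  row.eachCellIterator(startColumn = 0, endColumn = 4)
    .map(_.fold("")(_.toString))
    .toVector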