From b4782881c41c1235e79e15c990f401c42eb2bfce Mon Sep 17 00:00:00 2001
From: Martin Mauch
Date: Mon, 11 Sep 2017 00:13:36 +0200
Subject: [PATCH 1/2] gitignore log files

---
 .gitignore                                         |   1 +
 .../crealytics/spark/excel/Extractor.scala         | 113 ++++++++++++++++++
 .../crealytics/spark/excel/InferSchema.scala       |   2 +-
 .../spark/excel/executor/DefaultSource.scala       |  25 ++++
 .../executor/ExcelExecutorRelation.scala           |  45 +++++++
 .../ExcelCreatableRelationProvider.scala           |  48 ++++++++
 .../crealytics/spark/excel/utils/Loan.scala        |   7 ++
 .../spark/excel/utils/ParameterChecker.scala       |  12 ++
 .../spark/excel/utils/RichRow.scala                |  28 +++++
 .../spark/excel/IntegrationSuite.scala             |  58 +++++----
 .../crealytics/spark/excel/RichRowSuite.scala      |   2 +-
 11 files changed, 315 insertions(+), 26 deletions(-)
 create mode 100644 src/main/scala/com/crealytics/spark/excel/Extractor.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/executor/DefaultSource.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/utils/Loan.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/utils/ParameterChecker.scala
 create mode 100644 src/main/scala/com/crealytics/spark/excel/utils/RichRow.scala

diff --git a/.gitignore b/.gitignore
index 726c3d2a..e40f3988 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ project/target/
 .ensime*
 *.swp
 .idea
+*.log
diff --git a/src/main/scala/com/crealytics/spark/excel/Extractor.scala b/src/main/scala/com/crealytics/spark/excel/Extractor.scala
new file mode 100644
index 00000000..eaae2fac
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/Extractor.scala
@@ -0,0 +1,113 @@
+package com.crealytics.spark.excel
+
+import java.io.InputStream
+import java.math.BigDecimal
+import java.sql.Timestamp
+import java.text.SimpleDateFormat
+
+import org.apache.poi.ss.usermodel.{WorkbookFactory, Row => SheetRow, _}
+import org.apache.spark.sql.types._
+
+import scala.collection.JavaConverters._
+
+case class Extractor(useHeader: Boolean,
+                     inputStream: InputStream,
+                     sheetName: Option[String],
+                     startColumn: Int = 0,
+                     endColumn: Int = Int.MaxValue,
+                     timestampFormat: Option[String] = None) {
+  private lazy val workbook = WorkbookFactory.create(inputStream)
+  private lazy val sheet = findSheet(workbook, sheetName)
+
+  import com.crealytics.spark.excel.utils.RichRow._
+
+  private val timestampParser = timestampFormat.map(d => new SimpleDateFormat(d))
+
+  private def parseTimestamp(stringValue: String): Timestamp = {
+    timestampParser match {
+      case Some(parser) => new Timestamp(parser.parse(stringValue).getTime)
+      case None => Timestamp.valueOf(stringValue)
+    }
+  }
+
+  def firstRowWithData: Vector[Option[Cell]] = sheet.asScala
+    .find(_ != null)
+    .getOrElse(throw new RuntimeException(s"Sheet $sheet doesn't seem to contain any data"))
+    .eachCellIterator(startColumn, endColumn)
+    .to[Vector]
+
+  def lookups(requiredColumns: Array[String],
+              schema: StructType
+             ): Vector[(SheetRow) => Any] = requiredColumns.map { c =>
+    val columnNameRegex = s"(.*?)(_color)?".r
+    val columnNameRegex(columnName, isColor) = c
+    val columnIndex = schema.indexWhere(_.name == columnName)
+
+    val cellExtractor: Cell => Any = if (isColor == null) {
+      castTo(_, schema(columnIndex).dataType)
+    } else {
+      _.getCellStyle.getFillForegroundColorColor match {
+        case null => ""
+        case c: org.apache.poi.xssf.usermodel.XSSFColor => c.getARGBHex
+        case uct => throw new RuntimeException(s"Unknown color type $uct: ${uct.getClass}")
+      }
+    }
+    { row: SheetRow =>
+      val cell = row.getCell(columnIndex + startColumn)
+      if (cell == null) {
+        null
+      } else {
+        cellExtractor(cell)
+      }
+    }
+  }.to[Vector]
+
+  private def castTo(cell: Cell, castType: DataType): Any = {
+    if (cell.getCellTypeEnum == CellType.BLANK) {
+      return null
+    }
+    val dataFormatter = new DataFormatter()
+    lazy val stringValue = dataFormatter.formatCellValue(cell)
+    lazy val numericValue = cell.getNumericCellValue
+    lazy val bigDecimal = new BigDecimal(stringValue.replaceAll(",", ""))
+    castType match {
+      case _: ByteType => numericValue.toByte
+      case _: ShortType => numericValue.toShort
+      case _: IntegerType => numericValue.toInt
+      case _: LongType => numericValue.toLong
+      case _: FloatType => numericValue.toFloat
+      case _: DoubleType => numericValue
+      case _: BooleanType => cell.getBooleanCellValue
+      case _: DecimalType => bigDecimal
+      case _: TimestampType => parseTimestamp(stringValue)
+      case _: DateType => new java.sql.Date(DateUtil.getJavaDate(numericValue).getTime)
+      case _: StringType => stringValue
+      case t => throw new RuntimeException(s"Unsupported cast from $cell to $t")
+    }
+  }
+
+  private def findSheet(workBook: Workbook,
+                        sheetName: Option[String]): Sheet = {
+    sheetName.map { sn =>
+      Option(workBook.getSheet(sn)).getOrElse(
+        throw new IllegalArgumentException(s"Unknown sheet $sn")
+      )
+    }.getOrElse(workBook.sheetIterator.next)
+  }
+
+  def dataRows: Iterator[SheetRow] = sheet.rowIterator.asScala.drop(if (useHeader) 1 else 0)
+
+  def extract(schema: StructType,
+              requiredColumns: Array[String]): Vector[Vector[Any]] = {
+    dataRows
+      .map(row => lookups(requiredColumns, schema).map(l => l(row)))
+      .toVector
+  }
+
+  def stringsAndCellTypes: Seq[Vector[Int]] = dataRows.map { row: SheetRow =>
+    row.eachCellIterator(startColumn, endColumn).map { cell =>
+      cell.fold(Cell.CELL_TYPE_BLANK)(_.getCellType)
+    }.toVector
+  }.toVector
+}
+
diff --git a/src/main/scala/com/crealytics/spark/excel/InferSchema.scala b/src/main/scala/com/crealytics/spark/excel/InferSchema.scala
index 324e9198..d4c1444f 100644
--- a/src/main/scala/com/crealytics/spark/excel/InferSchema.scala
+++ b/src/main/scala/com/crealytics/spark/excel/InferSchema.scala
@@ -33,7 +33,7 @@ private[excel] object InferSchema {
       header: Array[String]): StructType = {
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
     val rootTypes: Array[DataType] = rowsRDD.aggregate(startType)(
-      inferRowType _,
+      inferRowType,
       mergeRowTypes)
 
     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
diff --git a/src/main/scala/com/crealytics/spark/excel/executor/DefaultSource.scala b/src/main/scala/com/crealytics/spark/excel/executor/DefaultSource.scala
new file mode 100644
index 00000000..04ce784d
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/executor/DefaultSource.scala
@@ -0,0 +1,25 @@
+package com.crealytics.spark.excel.executor
+
+import com.crealytics.spark.excel.utils.{ExcelCreatableRelationProvider, ParameterChecker}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.StructType
+
+class DefaultSource
+  extends ExcelCreatableRelationProvider
+    with DataSourceRegister {
+  override def shortName(): String = "large_excel"
+
+  override def createRelation(sqlContext: SQLContext,
+                              parameters: Map[String, String],
+                              userSchema: StructType
+                             ): ExcelExecutorRelation = ExcelExecutorRelation(
+    location = ParameterChecker.check(parameters, "path"),
+    sheetName = parameters.get("sheetName"),
+    useHeader = ParameterChecker.check(parameters, "useHeader").toBoolean,
+    treatEmptyValuesAsNulls = parameters.get("treatEmptyValuesAsNulls").fold(true)(_.toBoolean),
+    schema = userSchema,
+    startColumn = parameters.get("startColumn").fold(0)(_.toInt),
+    endColumn = parameters.get("endColumn").fold(Int.MaxValue)(_.toInt)
+  )(sqlContext)
+}
diff --git a/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala b/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
new file mode 100644
index 00000000..30010c5f
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
@@ -0,0 +1,45 @@
+package com.crealytics.spark.excel.executor
+
+import com.crealytics.spark.excel.Extractor
+import com.crealytics.spark.excel.utils.Loan
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.SerializableWritable
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+case class ExcelExecutorRelation(
+                                  location: String,
+                                  sheetName: Option[String],
+                                  useHeader: Boolean,
+                                  treatEmptyValuesAsNulls: Boolean,
+                                  schema: StructType,
+                                  startColumn: Int = 0,
+                                  endColumn: Int = Int.MaxValue
+                                )
+                                (@transient val sqlContext: SQLContext)
+  extends BaseRelation with TableScan with PrunedScan {
+
+  override def buildScan: RDD[Row] = buildScan(schema.map(_.name).toArray)
+
+  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
+    val confBroadcast = sqlContext.sparkContext.broadcast(
+      new SerializableWritable(sqlContext.sparkContext.hadoopConfiguration)
+    )
+
+    sqlContext.sparkContext.parallelize(Seq(location), 1).flatMap { loc =>
+      val path = new Path(loc)
+      val rows = Loan.withCloseable(FileSystem.get(path.toUri, confBroadcast.value.value).open(path)) {
+        inputStream =>
+          Extractor(useHeader,
+            inputStream,
+            sheetName,
+            startColumn,
+            endColumn).extract(schema, requiredColumns)
+      }
+      rows.map(Row.fromSeq)
+    }
+  }
+}
+
diff --git a/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala b/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
new file mode 100644
index 00000000..b48200a6
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
@@ -0,0 +1,48 @@
+package com.crealytics.spark.excel.utils
+
+import com.crealytics.spark.excel.ExcelFileSaver
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, SchemaRelationProvider}
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+
+trait ExcelCreatableRelationProvider extends CreatableRelationProvider with
+  SchemaRelationProvider {
+  override def createRelation(
+                               sqlContext: SQLContext,
+                               mode: SaveMode,
+                               parameters: Map[String, String],
+                               data: DataFrame): BaseRelation = {
+    val path = ParameterChecker.check(parameters, "path")
+    val sheetName = parameters.getOrElse("sheetName", "Sheet1")
+    val useHeader = ParameterChecker.check(parameters, "useHeader").toBoolean
+    val timestampFormat = parameters.getOrElse("timestampFormat", ExcelFileSaver.DEFAULT_TIMESTAMP_FORMAT)
+    val filesystemPath = new Path(path)
+    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+    val doSave = if (fs.exists(filesystemPath)) {
+      mode match {
+        case SaveMode.Append =>
+          sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+        case SaveMode.Overwrite =>
+          fs.delete(filesystemPath, true)
+          true
+        case SaveMode.ErrorIfExists =>
+          sys.error(s"path $path already exists.")
+        case SaveMode.Ignore => false
+      }
+    } else {
+      true
+    }
+    if (doSave) {
+      // Only save data when the save mode is not ignore.
+      new ExcelFileSaver(fs).save(
+        filesystemPath,
+        data,
+        sheetName = sheetName,
+        useHeader = useHeader,
+        timestampFormat = timestampFormat
+      )
+    }
+
+    createRelation(sqlContext, parameters, data.schema)
+  }
+}
diff --git a/src/main/scala/com/crealytics/spark/excel/utils/Loan.scala b/src/main/scala/com/crealytics/spark/excel/utils/Loan.scala
new file mode 100644
index 00000000..27022900
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/utils/Loan.scala
@@ -0,0 +1,7 @@
+package com.crealytics.spark.excel.utils
+
+object Loan {
+  def withCloseable[R <: AutoCloseable, T](closeable: R)(f: R => T): T =
+    try f(closeable)
+    finally if (closeable != null) closeable.close()
+}
diff --git a/src/main/scala/com/crealytics/spark/excel/utils/ParameterChecker.scala b/src/main/scala/com/crealytics/spark/excel/utils/ParameterChecker.scala
new file mode 100644
index 00000000..99d7a1c2
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/utils/ParameterChecker.scala
@@ -0,0 +1,12 @@
+package com.crealytics.spark.excel.utils
+
+object ParameterChecker {
+  // Forces a Parameter to exist, otherwise an exception is thrown.
+  def check(map: Map[String, String], param: String): String = {
+    if (!map.contains(param)) {
+      throw new IllegalArgumentException(s"Parameter ${'"'}$param${'"'} is missing in options.")
+    } else {
+      map.apply(param)
+    }
+  }
+}
diff --git a/src/main/scala/com/crealytics/spark/excel/utils/RichRow.scala b/src/main/scala/com/crealytics/spark/excel/utils/RichRow.scala
new file mode 100644
index 00000000..394bc25d
--- /dev/null
+++ b/src/main/scala/com/crealytics/spark/excel/utils/RichRow.scala
@@ -0,0 +1,28 @@
+package com.crealytics.spark.excel.utils
+
+import org.apache.poi.ss.usermodel.{Cell, Row}
+import org.apache.poi.ss.usermodel.Row.MissingCellPolicy
+
+object RichRow {
+  implicit class RichRow(val row: Row) extends AnyVal {
+
+    def eachCellIterator(startColumn: Int, endColumn: Int): Iterator[Option[Cell]] = new Iterator[Option[Cell]] {
+      private val lastCellInclusive = row.getLastCellNum - 1
+      private val endCol = Math.min(endColumn, Math.max(startColumn, lastCellInclusive))
+      require(startColumn >= 0 && startColumn <= endCol)
+
+      private var nextCol = startColumn
+
+      override def hasNext: Boolean = nextCol <= endCol && nextCol <= lastCellInclusive
+
+      override def next(): Option[Cell] = {
+        val next = if (nextCol > endCol) throw new NoSuchElementException(s"column index = $nextCol")
+        else Option(row.getCell(nextCol, MissingCellPolicy.RETURN_NULL_AND_BLANK))
+        nextCol += 1
+        next
+      }
+    }
+
+  }
+
+}
diff --git a/src/test/scala/com/crealytics/spark/excel/IntegrationSuite.scala b/src/test/scala/com/crealytics/spark/excel/IntegrationSuite.scala
index 0cc83395..20d90870 100644
--- a/src/test/scala/com/crealytics/spark/excel/IntegrationSuite.scala
+++ b/src/test/scala/com/crealytics/spark/excel/IntegrationSuite.scala
@@ -1,7 +1,7 @@
 package com.crealytics.spark.excel
 
 import java.io.File
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 import org.scalacheck.{Arbitrary, Gen, Shrink}
 import Arbitrary.{arbLong => _, arbString => _, _}
@@ -18,26 +18,26 @@ import org.apache.spark.sql.functions.lit
 
 object IntegrationSuite {
   case class ExampleData(
-    aBoolean: Boolean,
-    aByte: Byte,
-    aShort: Short,
-    anInt: Int,
-    aLong: Long,
-    aDouble: Double,
-    aString: String,
-    aTimestamp: java.sql.Timestamp,
-    aDate: java.sql.Date
-  )
-
-  val exampleDataSchema = ScalaReflection.schemaFor[ExampleData].dataType.asInstanceOf[StructType]
-
-  implicit val arbitraryDateFourDigits = Arbitrary[java.sql.Date](
+                          aBoolean: Boolean,
+                          aByte: Byte,
+                          aShort: Short,
+                          anInt: Int,
+                          aLong: Long,
+                          aDouble: Double,
+                          aString: String,
+                          aTimestamp: java.sql.Timestamp,
+                          aDate: java.sql.Date
+                        )
+
+  val exampleDataSchema: StructType = ScalaReflection.schemaFor[ExampleData].dataType.asInstanceOf[StructType]
+
+  implicit val arbitraryDateFourDigits: Arbitrary[Date] = Arbitrary[java.sql.Date](
     Gen
       .chooseNum[Long](0L, (new java.util.Date).getTime + 1000000)
       .map(new java.sql.Date(_))
   )
 
-  implicit val arbitraryTimestamp = Arbitrary[java.sql.Timestamp](
+  implicit val arbitraryTimestamp: Arbitrary[Timestamp] = Arbitrary[java.sql.Timestamp](
     Gen
       .chooseNum[Long](0L, (new java.util.Date).getTime + 1000000)
       .map(new java.sql.Timestamp(_))
@@ -48,7 +48,9 @@ object IntegrationSuite {
   // We're restricting our tests to Int-sized Longs in order not to fail
   // because of this issue.
   implicit val arbitraryLongWithLosslessDoubleConvertability: Arbitrary[Long] =
-    Arbitrary[Long] { arbitrary[Int].map(_.toLong) }
+    Arbitrary[Long] {
+      arbitrary[Int].map(_.toLong)
+    }
 
   implicit val arbitraryStringWithoutUnicodeCharacters: Arbitrary[String] =
     Arbitrary[String](Gen.alphaNumStr)
@@ -58,26 +60,26 @@ object IntegrationSuite {
 }
 
 class IntegrationSuite extends FunSuite with PropertyChecks with DataFrameSuiteBase {
 
+  import IntegrationSuite._
   import spark.implicits._
 
   implicit def shrinkOnlyNumberOfRows[A]: Shrink[List[A]] = Shrink.shrinkContainer[List, A]
 
-  val PackageName = "com.crealytics.spark.excel"
   val sheetName = "test sheet"
 
-  def writeThenRead(df: DataFrame): DataFrame = {
+  def writeThenRead(df: DataFrame, packageName: String): DataFrame = {
     val fileName = File.createTempFile("spark_excel_test_", ".xlsx").getAbsolutePath
 
     df.write
-      .format(PackageName)
+      .format(packageName)
       .option("sheetName", sheetName)
      .option("useHeader", "true")
      .mode("overwrite")
      .save(fileName)
 
-    spark.read.format(PackageName)
+    spark.read.format(packageName)
      .option("sheetName", sheetName)
      .option("useHeader", "true")
      .option("treatEmptyValuesAsNulls", "true")
@@ -91,7 +93,15 @@ class IntegrationSuite extends FunSuite with PropertyChecks with DataFrameSuiteB
     forAll(rowsGen, MinSuccessful(20)) { rows =>
       val expected = spark.createDataset(rows).toDF
 
-      assertDataFrameEquals(expected, writeThenRead(expected))
+      assertDataFrameEquals(expected, writeThenRead(expected, "com.crealytics.spark.excel"))
+    }
+  }
+
+  test("parses known datatypes correctly using executor-based") {
+    forAll(rowsGen, MinSuccessful(20)) { rows =>
+      val expected = spark.createDataset(rows).toDF
+
+      assertDataFrameEquals(expected, writeThenRead(expected, "com.crealytics.spark.excel.executor"))
     }
   }
 
@@ -106,9 +116,9 @@ class IntegrationSuite extends FunSuite with PropertyChecks with DataFrameSuiteB
       // Generate the same DataFrame but with empty strings
       val expectedWithEmptyStr = expected.withColumn("aString", lit("": String))
       // Set the schema so that aString is nullable
-      expectedWithEmptyStr.schema.fields.update(6, StructField("aString", DataTypes.StringType, true))
+      expectedWithEmptyStr.schema.fields.update(6, StructField("aString", DataTypes.StringType, nullable = true))
 
-      assertDataFrameEquals(expectedWithEmptyStr, writeThenRead(expectedWithNull))
+      assertDataFrameEquals(expectedWithEmptyStr, writeThenRead(expectedWithNull, "com.crealytics.spark.excel"))
     }
   }
 }
diff --git a/src/test/scala/com/crealytics/spark/excel/RichRowSuite.scala b/src/test/scala/com/crealytics/spark/excel/RichRowSuite.scala
index 6df3442c..80baa0b4 100644
--- a/src/test/scala/com/crealytics/spark/excel/RichRowSuite.scala
+++ b/src/test/scala/com/crealytics/spark/excel/RichRowSuite.scala
@@ -26,7 +26,7 @@ trait RowGenerator extends MockFactory {
 }
 
 class RichRowSuite extends FunSuite with PropertyChecks with RowGenerator {
-
+  import com.crealytics.spark.excel.RichRow._
   test("Invalid cell range should throw an error") {
     forAll(rowGen) { g =>
       (g.start > g.end) ==> Try {

From 63cc995b47d97e17266ac309f29c6616a57f1e4b Mon Sep 17 00:00:00 2001
From: dragan
Date: Mon, 11 Sep 2017 19:19:31 +0200
Subject: [PATCH 2/2] added executor-based data source.

---
 src/main/scala/com/crealytics/spark/excel/Extractor.scala     | 2 --
 .../spark/excel/executor/ExcelExecutorRelation.scala          | 6 ++++--
 .../spark/excel/utils/ExcelCreatableRelationProvider.scala    | 2 ++
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/main/scala/com/crealytics/spark/excel/Extractor.scala b/src/main/scala/com/crealytics/spark/excel/Extractor.scala
index eaae2fac..a937ecf0 100644
--- a/src/main/scala/com/crealytics/spark/excel/Extractor.scala
+++ b/src/main/scala/com/crealytics/spark/excel/Extractor.scala
@@ -19,8 +19,6 @@ case class Extractor(useHeader: Boolean,
   private lazy val workbook = WorkbookFactory.create(inputStream)
   private lazy val sheet = findSheet(workbook, sheetName)
 
-  import com.crealytics.spark.excel.utils.RichRow._
-
   private val timestampParser = timestampFormat.map(d => new SimpleDateFormat(d))
 
   private def parseTimestamp(stringValue: String): Timestamp = {
diff --git a/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala b/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
index 30010c5f..6d7b4f43 100644
--- a/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
+++ b/src/main/scala/com/crealytics/spark/excel/executor/ExcelExecutorRelation.scala
@@ -16,7 +16,8 @@ case class ExcelExecutorRelation(
                                   treatEmptyValuesAsNulls: Boolean,
                                   schema: StructType,
                                   startColumn: Int = 0,
-                                  endColumn: Int = Int.MaxValue
+                                  endColumn: Int = Int.MaxValue,
+                                  timestampFormat: Option[String] = None
                                 )
                                 (@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan {
@@ -36,7 +37,8 @@ case class ExcelExecutorRelation(
             inputStream,
             sheetName,
             startColumn,
-            endColumn).extract(schema, requiredColumns)
+            endColumn,
+            timestampFormat).extract(schema, requiredColumns)
       }
       rows.map(Row.fromSeq)
     }
diff --git a/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala b/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
index b48200a6..a95f8ca4 100644
--- a/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
+++ b/src/main/scala/com/crealytics/spark/excel/utils/ExcelCreatableRelationProvider.scala
@@ -15,6 +15,7 @@ trait ExcelCreatableRelationProvider extends CreatableRelationProvider with
     val path = ParameterChecker.check(parameters, "path")
     val sheetName = parameters.getOrElse("sheetName", "Sheet1")
     val useHeader = ParameterChecker.check(parameters, "useHeader").toBoolean
+    val dateFormat = parameters.getOrElse("dateFormat", ExcelFileSaver.DEFAULT_DATE_FORMAT)
     val timestampFormat = parameters.getOrElse("timestampFormat", ExcelFileSaver.DEFAULT_TIMESTAMP_FORMAT)
     val filesystemPath = new Path(path)
     val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
@@ -39,6 +40,7 @@ trait ExcelCreatableRelationProvider extends CreatableRelationProvider with
         data,
         sheetName = sheetName,
         useHeader = useHeader,
+        dateFormat = dateFormat,
         timestampFormat = timestampFormat
       )
     }