diff --git a/README.md b/README.md index b6718bf..289d4ac 100755 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ When reading files the API accepts several options: * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`. * `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec` or one of case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). Defaults to no compression when a codec is not specified. * `nullValue`: specificy a string that indicates a null value, any fields matching this string will be set as nulls in the DataFrame +* `quoteMode`: when to quote fields (`ALL`, `MINIMAL` (default), `NON_NUMERIC`, `NONE`), see [Quote Modes](https://commons.apache.org/proper/commons-csv/apidocs/org/apache/commons/csv/QuoteMode.html) The package also support saving simple (non-nested) DataFrame. When saving you can specify the delimiter and whether we should generate a header row for the table. See following examples for more details. diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index 6251b0d..4a01f01 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -15,7 +15,7 @@ */ package com.databricks.spark -import org.apache.commons.csv.CSVFormat +import org.apache.commons.csv.{CSVFormat, QuoteMode} import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark.sql.{DataFrame, SQLContext} @@ -121,11 +121,19 @@ package object csv { throw new Exception("Quotation cannot be more than one character.") } + val quoteModeString = parameters.getOrElse("quoteMode", "MINIMAL") + val quoteMode: QuoteMode = if (quoteModeString == null) { + null + } else { + QuoteMode.valueOf(quoteModeString.toUpperCase) + } + val nullValue = parameters.getOrElse("nullValue", "null") val csvFormat = defaultCsvFormat .withDelimiter(delimiterChar) .withQuote(quoteChar) + .withQuoteMode(quoteMode) .withEscape(escapeChar) .withSkipHeaderRecord(false) .withNullString(nullValue) @@ -141,6 +149,7 @@ package object csv { val csvFormat = defaultCsvFormat .withDelimiter(delimiterChar) .withQuote(quoteChar) + .withQuoteMode(quoteMode) .withEscape(escapeChar) .withSkipHeaderRecord(false) .withNullString(nullValue) diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index 00eb846..f693fce 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -18,6 +18,7 @@ package com.databricks.spark.csv import java.io.File import java.nio.charset.UnsupportedCharsetException import java.sql.Timestamp +import scala.io.Source import com.databricks.spark.csv.util.ParseModes import org.apache.hadoop.io.compress.GzipCodec @@ -442,6 +443,98 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) } + test("DSL save with a quoteMode") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = sqlContext.csvFile(carsFile, parserLib = parserLib) + val delimiter = "," + var quote = "\"" + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", + "quote" -> quote, "delimiter" -> delimiter, "quoteMode" -> "ALL")) + + val carsCopy = sqlContext.csvFile(copyFilePath + "/") + for(file <- new File(copyFilePath + "/").listFiles) { + if (!(file.getName.startsWith("_") || file.getName.startsWith("."))) { + for(line <- Source.fromFile(file).getLines()) { + for(column <- line.split(delimiter)) { + assert(column.startsWith(quote)) + assert(column.endsWith(quote)) + } + } + } + } + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with non numeric quoteMode") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = sqlContext.csvFile(carsFile, parserLib = parserLib, inferSchema = true) + val delimiter = "," + var quote = "\"" + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", + "quote" -> quote, "delimiter" -> delimiter, "quoteMode" -> "NON_NUMERIC")) + + val carsCopy = sqlContext.csvFile(copyFilePath + "/") + for(file <- new File(copyFilePath + "/").listFiles) { + if (!(file.getName.startsWith("_") || file.getName.startsWith("."))) { + for((line, lineno) <- Source.fromFile(file).getLines().zipWithIndex) { + val columns = line.split(delimiter) + if (lineno == 0) { + assert(columns(0).startsWith(quote)) + assert(columns(0).endsWith(quote)) + assert(columns(1).startsWith(quote)) + assert(columns(1).endsWith(quote)) + } else { + assert(!columns(0).startsWith(quote)) + assert(!columns(0).endsWith(quote)) + assert(columns(1).startsWith(quote)) + assert(columns(1).endsWith(quote)) + } + } + } + } + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with null quoteMode") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = sqlContext.csvFile(carsFile, parserLib = parserLib) + val delimiter = "," + var quote = "\"" + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", + "quote" -> quote, "delimiter" -> delimiter, "quoteMode" -> null)) + + val carsCopy = sqlContext.csvFile(copyFilePath + "/") + for(file <- new File(copyFilePath + "/").listFiles) { + if (!(file.getName.startsWith("_") || file.getName.startsWith("."))) { + for(line <- Source.fromFile(file).getLines()) { + for(column <- line.split(delimiter)) { + assert(!column.startsWith(quote)) + assert(!column.endsWith(quote)) + } + } + } + } + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + test("DSL save with a compression codec") { // Create temp directory TestUtils.deleteRecursively(new File(tempEmptyDir))