Roundtrip null values of any type #147

Closed
1 change: 1 addition & 0 deletions build.sbt
@@ -88,6 +88,7 @@ mimaDefaultSettings ++ Seq(
ProblemFilters.excludePackage("com.databricks.spark.csv.CsvRelation"),
ProblemFilters.excludePackage("com.databricks.spark.csv.util.InferSchema"),
ProblemFilters.excludePackage("com.databricks.spark.sql.readers"),
+ ProblemFilters.excludePackage("com.databricks.spark.csv.util.TypeCast"),
// We allowed the private `CsvRelation` type to leak into the public method signature:
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"com.databricks.spark.csv.DefaultSource.createRelation")
8 changes: 8 additions & 0 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -35,6 +35,7 @@ class CsvParser extends Serializable {
private var parseMode: String = ParseModes.DEFAULT
private var ignoreLeadingWhiteSpace: Boolean = false
private var ignoreTrailingWhiteSpace: Boolean = false
+ private var treatEmptyValuesAsNulls: Boolean = false
private var parserLib: String = ParserLibs.DEFAULT
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
@@ -84,6 +85,11 @@ class CsvParser extends Serializable {
this
}

+ def withTreatEmptyValuesAsNulls(treatAsNull: Boolean): CsvParser = {
+   this.treatEmptyValuesAsNulls = treatAsNull
+   this
+ }
+
def withParserLib(parserLib: String): CsvParser = {
this.parserLib = parserLib
this
@@ -114,6 +120,7 @@
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
+ treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
@@ -132,6 +139,7 @@
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
+ treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
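
The CsvParser changes thread the new flag from the fluent builder down to CsvRelation. A minimal usage sketch, mirroring the test added below — the schema, the input path, and the in-scope sqlContext are illustrative assumptions, not part of this PR:

import com.databricks.spark.csv.CsvParser
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical schema and input path, for illustration only.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", IntegerType, nullable = true)))

val df = new CsvParser()
  .withSchema(schema)
  .withUseHeader(true)
  .withTreatEmptyValuesAsNulls(true) // empty fields parse as null, even for StringType
  .csvFile(sqlContext, "/path/to/people.csv")
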
4 changes: 3 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -43,6 +43,7 @@ case class CsvRelation protected[spark] (
parserLib: String,
ignoreLeadingWhiteSpace: Boolean,
ignoreTrailingWhiteSpace: Boolean,
treatEmptyValuesAsNulls: Boolean,
userSchema: StructType = null,
inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with InsertableRelation {
@@ -113,7 +114,8 @@ case class CsvRelation protected[spark] (
index = 0
while (index < schemaFields.length) {
val field = schemaFields(index)
- rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
+ rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
+   treatEmptyValuesAsNulls)
index = index + 1
}
Some(Row.fromSeq(rowArray))
9 changes: 9 additions & 0 deletions src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -112,6 +112,14 @@ class DefaultSource
} else {
throw new Exception("Ignore white space flag can be true or false")
}
+ val treatEmptyValuesAsNulls = parameters.getOrElse("treatEmptyValuesAsNulls", "false")
+ val treatEmptyValuesAsNullsFlag = if (treatEmptyValuesAsNulls == "false") {
+   false
+ } else if (treatEmptyValuesAsNulls == "true") {
+   true
+ } else {
+   throw new Exception("Treat empty values as null flag can be true or false")
+ }

val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
// TODO validate charset?
@@ -137,6 +145,7 @@
parserLib,
ignoreLeadingWhiteSpaceFlag,
ignoreTrailingWhiteSpaceFlag,
+ treatEmptyValuesAsNullsFlag,
schema,
inferSchemaFlag)(sqlContext)
}
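
DefaultSource reads the flag out of the options map and accepts only the literal strings "true" and "false"; anything else throws the Exception above. A sketch of passing it through the generic data-source API, assuming the Spark 1.3-era SQLContext.load entry point (the path is a placeholder):

// Hypothetical usage; the option value must be exactly "true" or "false".
val cars = sqlContext.load(
  "com.databricks.spark.csv",
  Map(
    "path" -> "cars.csv",
    "header" -> "true",
    "treatEmptyValuesAsNulls" -> "true"))
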
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/package.scala
@@ -52,6 +52,7 @@ package object csv {
parserLib = parserLib,
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
+ treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
@@ -76,6 +77,7 @@
parserLib = parserLib,
ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
+ treatEmptyValuesAsNulls = false,
inferCsvSchema = inferSchema)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}
@@ -116,11 +118,13 @@ package object csv {
case None => None
}

+ val nullValue = parameters.getOrElse("nullValue", "null")
+
val csvFormatBase = CSVFormat.DEFAULT
.withDelimiter(delimiterChar)
.withEscape(escapeChar)
.withSkipHeaderRecord(false)
- .withNullString("null")
+ .withNullString(nullValue)

val csvFormat = quoteChar match {
case Some(c) => csvFormatBase.withQuote(c)
@@ -139,7 +143,7 @@
.withDelimiter(delimiterChar)
.withEscape(escapeChar)
.withSkipHeaderRecord(false)
- .withNullString("null")
+ .withNullString(nullValue)

val csvFormat = quoteChar match {
case Some(c) => csvFormatBase.withQuote(c)
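
On the write path, the null marker that was hard-coded to the string "null" is now taken from the new nullValue option, which is what lets nulls round-trip as empty fields. A one-line sketch (df is a placeholder DataFrame; the output path is illustrative):

// Write null fields as empty strings instead of the literal "null".
df.saveAsCsvFile("/tmp/null-numbers.csv", Map("header" -> "true", "nullValue" -> ""))
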
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -35,8 +35,12 @@ object TypeCast {
* @param datum string value
* @param castType SparkSQL type
*/
- private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
-   if (datum == "" && nullable && !castType.isInstanceOf[StringType]){
+ private[csv] def castTo(
+     datum: String,
+     castType: DataType,
+     nullable: Boolean = true,
+     treatEmptyValuesAsNulls: Boolean = false): Any = {
+   if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){

Member:
Would you please elaborate why it is a good idea to have null for StringType? I am not against it, but I want to see a clear use case in which an empty string should be parsed as null.

Contributor (author):
Ah, just saw this. Parsing empty strings as nulls is compatible with the default behavior of multiple Hadoop-based systems, including the one we use the most (Scalding). It also lets us indicate missing values (a null is much clearer than "" for a missing value, and consistent with the other data types).

Member:
Thanks. Now let's get it past the style checker and I will merge this.


null
} else {
castType match {
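
The widened condition only changes behavior for empty tokens in nullable string columns; non-string types already parsed "" as null, and that stays the same. Since castTo is private[csv], the calls below sketch the semantics rather than a public API:

import com.databricks.spark.csv.util.TypeCast
import org.apache.spark.sql.types.{IntegerType, StringType}

TypeCast.castTo("", IntegerType)                                // null, as before
TypeCast.castTo("", StringType)                                 // "", as before
TypeCast.castTo("", StringType, treatEmptyValuesAsNulls = true) // null, the new behavior
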
24 changes: 24 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -167,6 +167,30 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
}

+ test("DSL test roundtrip nulls") {
+   // Create temp directory
+   TestUtils.deleteRecursively(new File(tempEmptyDir))
+   new File(tempEmptyDir).mkdirs()
+   val copyFilePath = tempEmptyDir + "null-numbers.csv"
+   val agesSchema = StructType(List(StructField("name", StringType, true),
+     StructField("age", IntegerType, true)))
+
+   val agesRows = Seq(Row("alice", 35), Row("bob", null), Row(null, 24))
+   val agesRdd = sqlContext.sparkContext.parallelize(agesRows)
+   val agesDf = sqlContext.createDataFrame(agesRdd, agesSchema)
+
+   agesDf.saveAsCsvFile(copyFilePath, Map("header" -> "true", "nullValue" -> ""))
+
+   val agesCopy = new CsvParser()
+     .withSchema(agesSchema)
+     .withUseHeader(true)
+     .withTreatEmptyValuesAsNulls(true)
+     .withParserLib(parserLib)
+     .csvFile(sqlContext, copyFilePath)
+
+   assert(agesCopy.count == agesRows.size)
+   assert(agesCopy.collect.toSet == agesRows.toSet)
+ }

test("DSL test with alternative delimiter and quote") {
val results = new CsvParser()