[SPARK-19058][SQL] fix partition related behaviors with DataFrameWriter.saveAsTable

## What changes were proposed in this pull request?

When we append data to a partitioned table with `DataFrameWriter.saveAsTable`, there are two issues:
1. the append does not work when a partition has a custom location.
2. every append recovers all partitions of the table, even though only the newly written partitions need to be added to the metastore.

This PR fixes them by moving the special partition handling code from `DataSourceAnalysis` to `InsertIntoHadoopFsRelationCommand`, so that the `DataFrameWriter.saveAsTable` code path can also benefit from it.
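For context, a minimal sketch of the user-facing scenario the fix targets (not part of this commit; the table name, partition location, and data are hypothetical):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical reproduction sketch for the two issues above.
val spark = SparkSession.builder().appName("saveAsTable-append-sketch").getOrCreate()
import spark.implicits._

// Create a partitioned data source table.
Seq((1L, "a")).toDF("id", "part")
  .write.format("parquet").partitionBy("part").saveAsTable("t")

// Register a partition whose files live outside the table's root directory.
spark.sql("ALTER TABLE t ADD PARTITION (part = 'b') LOCATION '/tmp/custom_part_b'")

// Append through DataFrameWriter.saveAsTable: before this patch the rows for part = 'b'
// ignored the custom location (issue 1) and the append recovered every partition of the
// table (issue 2); after it, the write honors the custom location and registers only the
// partitions it actually created.
Seq((2L, "b")).toDF("id", "part")
  .write.format("parquet").mode("append").saveAsTable("t")
```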

## How was this patch tested?

newly added regression tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#16460 from cloud-fan/append.
cloud-fan authored and cmonkey committed Jan 9, 2017
1 parent 7bc8fc6 commit 8be30f0
Showing 7 changed files with 142 additions and 112 deletions.
@@ -393,7 +393,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

// Drop the existing table
catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
createTable(tableIdent)
createTable(tableIdentWithDB)
// Refresh the cache of the table in the catalog.
catalog.refreshTable(tableIdentWithDB)

case _ => createTable(tableIdent)
}
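A brief hedged sketch of why the refresh added above matters after the drop-and-recreate in the `Overwrite` branch (the table name and DataFrames are hypothetical; `spark`, `df1`, and `df2` are assumed to exist):

```scala
// Overwriting an existing table drops and recreates it, so its entry in the relation
// cache must be refreshed; otherwise a later read of "t" could resolve against the
// dropped table's cached metadata.
df1.write.format("parquet").saveAsTable("t")                    // creates "t"
df2.write.format("parquet").mode("overwrite").saveAsTable("t")  // drops and recreates "t"
spark.table("t").show()                                         // reads the refreshed entry
```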
@@ -138,7 +138,7 @@ case class CreateDataSourceTableAsSelectCommand(
val tableIdentWithDB = table.identifier.copy(database = Some(db))
val tableName = tableIdentWithDB.unquotedString

val result = if (sessionState.catalog.tableExists(tableIdentWithDB)) {
if (sessionState.catalog.tableExists(tableIdentWithDB)) {
assert(mode != SaveMode.Overwrite,
s"Expect the table $tableName has been dropped when the save mode is Overwrite")

@@ -150,35 +150,34 @@
return Seq.empty
}

saveDataIntoTable(sparkSession, table, table.storage.locationUri, query, mode)
saveDataIntoTable(
sparkSession, table, table.storage.locationUri, query, mode, tableExists = true)
} else {
val tableLocation = if (table.tableType == CatalogTableType.MANAGED) {
Some(sessionState.catalog.defaultTablePath(table.identifier))
} else {
table.storage.locationUri
}
val result = saveDataIntoTable(sparkSession, table, tableLocation, query, mode)
val result = saveDataIntoTable(
sparkSession, table, tableLocation, query, mode, tableExists = false)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
// We will use the schema of resolved.relation as the schema of the table (instead of
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
schema = result.schema)
sessionState.catalog.createTable(newTable, ignoreIfExists = false)
result
}

result match {
case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty &&
sparkSession.sqlContext.conf.manageFilesourcePartitions =>
// Need to recover partitions into the metastore so our saved data is visible.
sparkSession.sessionState.executePlan(
AlterTableRecoverPartitionsCommand(table.identifier)).toRdd
case _ =>
result match {
case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty &&
sparkSession.sqlContext.conf.manageFilesourcePartitions =>
// Need to recover partitions into the metastore so our saved data is visible.
sparkSession.sessionState.executePlan(
AlterTableRecoverPartitionsCommand(table.identifier)).toRdd
case _ =>
}
}

// Refresh the cache of the table in the catalog.
sessionState.catalog.refreshTable(tableIdentWithDB)
Seq.empty[Row]
}

@@ -187,7 +186,8 @@ case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
tableLocation: Option[String],
data: LogicalPlan,
mode: SaveMode): BaseRelation = {
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
// Create the relation based on the input logical plan: `data`.
val pathOption = tableLocation.map("path" -> _)
val dataSource = DataSource(
@@ -196,7 +196,7 @@
partitionColumns = table.partitionColumnNames,
bucketSpec = table.bucketSpec,
options = table.storage.properties ++ pathOption,
catalogTable = Some(table))
catalogTable = if (tableExists) Some(table) else None)

try {
dataSource.write(mode, Dataset.ofRows(session, query))
@@ -473,22 +473,26 @@ case class DataSource(
s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
}.asInstanceOf[Attribute]
}
val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
}.head
}
// For partitioned relation r, r.schema's column ordering can be different from the column
// ordering of data.logicalPlan (partition columns are all moved after data column). This
// will be adjusted within InsertIntoHadoopFsRelation.
val plan =
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
customPartitionLocations = Map.empty,
partitionColumns = columns,
bucketSpec = bucketSpec,
fileFormat = format,
refreshFunction = _ => Unit, // No existing table needs to be refreshed.
options = options,
query = data.logicalPlan,
mode = mode,
catalogTable = catalogTable)
catalogTable = catalogTable,
fileIndex = fileIndex)
sparkSession.sessionState.executePlan(plan).toRdd
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
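As an aside on the column-ordering comment in the hunk above, a toy illustration with a hypothetical schema (not from the patch):

```scala
import org.apache.spark.sql.types._

// A partitioned relation keeps partition columns after the data columns, so the insert
// command reorders the incoming query's output instead of trusting the DataFrame's
// column order.
val dataFrameSchema = StructType(Seq(
  StructField("part", StringType),    // partition column listed first by the user
  StructField("id", LongType),
  StructField("value", StringType)))

val relationSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("value", StringType),
  StructField("part", StringType)))   // partition columns moved to the end
```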
@@ -197,91 +197,19 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {

val partitionSchema = actualQuery.resolve(
t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
val partitionsTrackedByCatalog =
t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
l.catalogTable.get.tracksPartitionsInCatalog

var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty

val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }

// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
l.catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
}

// Callback for updating metastore partition metadata after the insertion job completes.
// TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
if (partitionsTrackedByCatalog) {
val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
if (newPartitions.nonEmpty) {
AlterTableAddPartitionCommand(
l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(t.sparkSession)
}
if (overwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
l.catalogTable.get.identifier, deletedPartitions.toSeq,
ifExists = true, purge = false,
retainData = true /* already deleted */).run(t.sparkSession)
}
}
}
t.location.refresh()
}

val insertCmd = InsertIntoHadoopFsRelationCommand(
InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitions,
customPartitionLocations,
partitionSchema,
t.bucketSpec,
t.fileFormat,
refreshPartitionsCallback,
t.options,
actualQuery,
mode,
table)

insertCmd
}

/**
* Given a set of input partitions, returns those that have locations that differ from the
* Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
* the user.
*
* @return a mapping from partition specs to their custom locations
*/
private def getCustomPartitionLocations(
spark: SparkSession,
table: CatalogTable,
basePath: Path,
partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
val hadoopConf = spark.sessionState.newHadoopConf
val fs = basePath.getFileSystem(hadoopConf)
val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
partitions.flatMap { p =>
val defaultLocation = qualifiedBasePath.suffix(
"/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
val catalogLocation = new Path(p.location).makeQualified(
fs.getUri, fs.getWorkingDirectory).toString
if (catalogLocation != defaultLocation) {
Some(p.spec -> catalogLocation)
} else {
None
}
}.toMap
table,
Some(t.location))
}
}

@@ -106,7 +106,7 @@ object FileSourceStrategy extends Strategy with Logging {
val outputAttributes = readDataColumns ++ partitionColumns

val scan =
new FileSourceScanExec(
FileSourceScanExec(
fsRelation,
outputAttributes,
outputSchema,
@@ -23,11 +23,11 @@ import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.command._

/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
@@ -37,23 +37,18 @@ import org.apache.spark.sql.execution.command.RunnableCommand
* overwrites: when the spec is empty, all partitions are overwritten.
* When it covers a prefix of the partition keys, only partitions matching
* the prefix are overwritten.
* @param customPartitionLocations mapping of partition specs to their custom locations. The
* caller should guarantee that exactly those table partitions
* falling under the specified static partition keys are contained
* in this map, and that no other partitions are.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
customPartitionLocations: Map[TablePartitionSpec, String],
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
refreshFunction: Seq[TablePartitionSpec] => Unit,
options: Map[String, String],
@transient query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable])
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex])
extends RunnableCommand {

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
@@ -74,12 +69,30 @@ case class InsertIntoHadoopFsRelationCommand(
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)

val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions &&
catalogTable.isDefined &&
catalogTable.get.partitionColumnNames.nonEmpty &&
catalogTable.get.tracksPartitionsInCatalog

var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty

// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
val matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
fs, catalogTable.get, qualifiedOutputPath, matchingPartitions)
}

val pathExists = fs.exists(qualifiedOutputPath)
val doInsertion = (mode, pathExists) match {
case (SaveMode.ErrorIfExists, true) =>
throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
deleteMatchingPartitions(fs, qualifiedOutputPath)
deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations)
true
case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
@@ -98,6 +111,27 @@ case class InsertIntoHadoopFsRelationCommand(
outputPath = outputPath.toString,
isAppend = isAppend)

// Callback for updating metastore partition metadata after the insertion job completes.
def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
if (partitionsTrackedByCatalog) {
val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
if (newPartitions.nonEmpty) {
AlterTableAddPartitionCommand(
catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(sparkSession)
}
if (mode == SaveMode.Overwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
catalogTable.get.identifier, deletedPartitions.toSeq,
ifExists = true, purge = false,
retainData = true /* already deleted */).run(sparkSession)
}
}
}
}

FileFormatWriter.write(
sparkSession = sparkSession,
queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
@@ -108,8 +142,10 @@ hadoopConf = hadoopConf,
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
refreshFunction = refreshFunction,
refreshFunction = refreshPartitionsCallback,
options = options)

fileIndex.foreach(_.refresh())
} else {
logInfo("Skipping insertion into a relation that already exists.")
}
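A toy illustration (the partition specs are made up, not from the patch) of the set arithmetic that `refreshPartitionsCallback` above uses to keep the metastore in sync:

```scala
// TablePartitionSpec is Map[String, String] (partition column -> value).
val initialMatchingPartitions = Seq(Map("part" -> "a"), Map("part" -> "b"))
val updatedPartitions         = Seq(Map("part" -> "b"), Map("part" -> "c"))

// Partitions written by this job but unknown to the metastore are added
// (AlterTableAddPartitionCommand with ifNotExists = true).
val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
// => Set(Map(part -> c))

// On Overwrite, previously matching partitions that received no data are dropped
// (AlterTableDropPartitionCommand with retainData = true, since the files are already gone).
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
// => Set(Map(part -> a))
```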
@@ -121,7 +157,10 @@ case class InsertIntoHadoopFsRelationCommand(
* Deletes all partition files that match the specified static prefix. Partitions with custom
* locations are also cleared based on the custom locations map given to this class.
*/
private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = {
private def deleteMatchingPartitions(
fs: FileSystem,
qualifiedOutputPath: Path,
customPartitionLocations: Map[TablePartitionSpec, String]): Unit = {
val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
"/" + partitionColumns.flatMap { p =>
staticPartitions.get(p.name) match {
@@ -152,4 +191,29 @@
}
}
}

/**
* Given a set of input partitions, returns those that have locations that differ from the
* Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
* the user.
*
* @return a mapping from partition specs to their custom locations
*/
private def getCustomPartitionLocations(
fs: FileSystem,
table: CatalogTable,
qualifiedOutputPath: Path,
partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
partitions.flatMap { p =>
val defaultLocation = qualifiedOutputPath.suffix(
"/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
val catalogLocation = new Path(p.location).makeQualified(
fs.getUri, fs.getWorkingDirectory).toString
if (catalogLocation != defaultLocation) {
Some(p.spec -> catalogLocation)
} else {
None
}
}.toMap
}
}
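Finally, a small hedged example (hypothetical paths, not from the patch) of what `getCustomPartitionLocations` above classifies as a custom location:

```scala
import org.apache.hadoop.fs.Path

// For a table rooted at hdfs://nn/warehouse/t and partitioned by a single column `part`.
val qualifiedOutputPath = new Path("hdfs://nn/warehouse/t")

val defaultLocationForA = qualifiedOutputPath.suffix("/part=a").toString
// Equal to the catalog location of part='a' -> not custom, not returned.

val catalogLocationForB = "hdfs://nn/custom/part_b"
// Differs from qualifiedOutputPath.suffix("/part=b") -> part='b' is returned, so the
// writer (and deleteMatchingPartitions on Overwrite) uses this path for its files.
```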
