Skip to content

Commit

Permalink
[Spark] Read side changes for v2 checkpoints
Browse files Browse the repository at this point in the history
This PR adds read side changes for v2 checkpoints.

Closes #2056

GitOrigin-RevId: 3673bb576aed5e1b572f2dfc4b69e829ae9555a6
  • Loading branch information
prakharjain09 authored and vkorukanti committed Sep 25, 2023
1 parent 4622db6 commit 6859c86
Show file tree
Hide file tree
Showing 11 changed files with 803 additions and 24 deletions.

Large diffs are not rendered by default.

69 changes: 66 additions & 3 deletions spark/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala
Expand Up @@ -54,17 +54,28 @@ import org.apache.spark.util.Utils
/**
* A class to help with comparing checkpoints with each other, where we may have had concurrent
* writers that checkpoint with different number of parts.
* The `numParts` field will be present only for multipart checkpoints (represented by
* Format.WITH_PARTS).
* The `fileName` field is present only for V2 Checkpoints (represented by Format.V2)
* These additional fields are used as a tie breaker when comparing multiple checkpoint
* instance of same Format for the same `version`.
*/
case class CheckpointInstance(
version: Long,
format: CheckpointInstance.Format,
fileName: Option[String] = None,
numParts: Option[Int] = None) extends Ordered[CheckpointInstance] {

// Assert that numParts are present when checkpoint format is Format.WITH_PARTS.
// For other formats, numParts must be None.
require((format == CheckpointInstance.Format.WITH_PARTS) == numParts.isDefined,
s"numParts ($numParts) must be present for checkpoint format" +
s" ${CheckpointInstance.Format.WITH_PARTS.name}")
// Assert that filePath is present only when checkpoint format is Format.V2.
// For other formats, filePath must be None.
require((format == CheckpointInstance.Format.V2) == fileName.isDefined,
s"fileName ($fileName) must be present for checkpoint format" +
s" ${CheckpointInstance.Format.V2.name}")

/**
* Returns a [[CheckpointProvider]] which can tell the files corresponding to this
Expand All @@ -81,7 +92,26 @@ case class CheckpointInstance(
val lastCheckpointInfo = lastCheckpointInfoHint.filter(cm => CheckpointInstance(cm) == this)
val cpFiles = filterFiles(deltaLog, filesForCheckpointConstruction)
format match {
case CheckpointInstance.Format.WITH_PARTS | CheckpointInstance.Format.SINGLE =>
// Treat single file checkpoints also as V2 Checkpoints because we don't know if it is
// actually a V2 checkpoint until we read it.
case CheckpointInstance.Format.V2 | CheckpointInstance.Format.SINGLE =>
assert(cpFiles.size == 1)
val fileStatus = cpFiles.head
if (format == CheckpointInstance.Format.V2) {
val hadoopConf = deltaLog.newDeltaHadoopConf()
UninitializedV2CheckpointProvider(
version,
fileStatus,
logPath,
hadoopConf,
deltaLog.options,
deltaLog.store,
lastCheckpointInfo)
} else {
UninitializedV1OrV2ParquetCheckpointProvider(
version, fileStatus, logPath, lastCheckpointInfo)
}
case CheckpointInstance.Format.WITH_PARTS =>
PreloadedCheckpointProvider(cpFiles, lastCheckpointInfo)
case CheckpointInstance.Format.SENTINEL =>
throw DeltaErrors.assertionFailedError(
Expand All @@ -93,6 +123,23 @@ case class CheckpointInstance(
filesForCheckpointConstruction: Seq[FileStatus]) : Seq[FileStatus] = {
val logPath = deltaLog.logPath
format match {
// Treat Single File checkpoints also as V2 Checkpoints because we don't know if it is
// actually a V2 checkpoint until we read it.
case format if format.usesSidecars =>
val checkpointFileName = format match {
case CheckpointInstance.Format.V2 => fileName.get
case CheckpointInstance.Format.SINGLE => checkpointFileSingular(logPath, version).getName
case other =>
throw new IllegalStateException(s"Unknown checkpoint format $other supporting sidecars")
}
val fileStatus = filesForCheckpointConstruction
.find(_.getPath.getName == checkpointFileName)
.getOrElse {
throw new IllegalStateException("Failed in getting the file information for:\n" +
fileName.get + "\namong\n" +
filesForCheckpointConstruction.map(_.getPath.getName).mkString(" -", "\n -", ""))
}
Seq(fileStatus)
case CheckpointInstance.Format.WITH_PARTS | CheckpointInstance.Format.SINGLE =>
val filePaths = if (format == CheckpointInstance.Format.WITH_PARTS) {
checkpointFileWithParts(logPath, version, numParts.get).toSet
Expand All @@ -119,28 +166,35 @@ case class CheckpointInstance(
* Single part checkpoint.
* 3. For Multi-part [[CheckpointInstance]]s corresponding to same version, the one with more
* parts is greater than the one with less parts.
* 4. For V2 Checkpoints corresponding to same version, we use the fileName as tie breaker.
*/
override def compare(other: CheckpointInstance): Int = {
(version, format, numParts) compare (other.version, other.format, other.numParts)
(version, format, numParts, fileName) compare
(other.version, other.format, other.numParts, other.fileName)
}
}

object CheckpointInstance {
sealed abstract class Format(val ordinal: Int, val name: String) extends Ordered[Format] {
override def compare(other: Format): Int = ordinal compare other.ordinal
def usesSidecars: Boolean = this.isInstanceOf[FormatUsesSidecars]
}
trait FormatUsesSidecars

object Format {
def unapply(name: String): Option[Format] = name match {
case SINGLE.name => Some(SINGLE)
case WITH_PARTS.name => Some(WITH_PARTS)
case V2.name => Some(V2)
case _ => None
}

/** single-file checkpoint format */
object SINGLE extends Format(0, "SINGLE")
object SINGLE extends Format(0, "SINGLE") with FormatUsesSidecars
/** multi-file checkpoint format */
object WITH_PARTS extends Format(1, "WITH_PARTS")
/** V2 Checkpoint format */
object V2 extends Format(2, "V2") with FormatUsesSidecars
/** Sentinel, for internal use only */
object SENTINEL extends Format(Int.MaxValue, "SENTINEL")
}
Expand All @@ -149,7 +203,14 @@ object CheckpointInstance {
// Three formats to worry about:
// * <version>.checkpoint.parquet
// * <version>.checkpoint.<i>.<n>.parquet
// * <version>.checkpoint.<u>.parquet where u is a unique string
path.getName.split("\\.") match {
case Array(v, "checkpoint", uniqueStr, format) if Seq("json", "parquet").contains(format) =>
CheckpointInstance(
version = v.toLong,
format = Format.V2,
numParts = None,
fileName = Some(path.getName))
case Array(v, "checkpoint", "parquet") =>
CheckpointInstance(v.toLong, Format.SINGLE, numParts = None)
case Array(v, "checkpoint", _, n, "parquet") =>
Expand Down Expand Up @@ -384,6 +445,8 @@ trait Checkpoints extends DeltaLogging {
case CheckpointInstance.Format.WITH_PARTS =>
assert(ci.numParts.nonEmpty, "Multi-Part Checkpoint must have non empty numParts")
matchingCheckpointInstances.length == ci.numParts.get
case CheckpointInstance.Format.V2 =>
matchingCheckpointInstances.length == 1
case CheckpointInstance.Format.SENTINEL =>
false
}
Expand Down
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils}
import org.apache.spark.sql.delta.sources._
import org.apache.spark.sql.delta.storage.LogStoreProvider
import org.apache.spark.sql.delta.util.FileNames
import com.google.common.cache.{CacheBuilder, RemovalNotification}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
Expand Down Expand Up @@ -86,6 +87,14 @@ class DeltaLog private(
import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.delta.util.FileNames._

/**
* Path to sidecar directory.
* This is intentionally kept `lazy val` as otherwise any other constructor codepaths in DeltaLog
* (e.g. SnapshotManagement etc) will see it as null as they are executed before this line is
* called.
*/
lazy val sidecarDirPath: Path = FileNames.sidecarDirPath(logPath)


protected def spark = SparkSession.active

Expand Down
14 changes: 13 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala
Expand Up @@ -207,12 +207,22 @@ class Snapshot(
* Pulls the protocol and metadata of the table from the files that are used to compute the
* Snapshot directly--without triggering a full state reconstruction. This is important, because
* state reconstruction depends on protocol and metadata for correctness.
*
* Also this method should only access methods defined in [[UninitializedCheckpointProvider]]
* which are not present in [[CheckpointProvider]]. This is because initialization of
* [[Snapshot.checkpointProvider]] depends on [[Snapshot.protocolAndMetadataReconstruction()]]
* and so if [[Snapshot.protocolAndMetadataReconstruction()]] starts depending on
* [[Snapshot.checkpointProvider]] then there will be cyclic dependency.
*/
protected def protocolAndMetadataReconstruction(): Array[(Protocol, Metadata)] = {
import implicits._

val schemaToUse = Action.logSchema(Set("protocol", "metaData"))
fileIndices.map(deltaLog.loadIndex(_, schemaToUse))
val checkpointOpt = checkpointProvider.topLevelFileIndex.map { index =>
deltaLog.loadIndex(index, schemaToUse)
.withColumn(COMMIT_VERSION_COLUMN, lit(checkpointProvider.version))
}
(checkpointOpt ++ deltaFileIndexOpt.map(deltaLog.loadIndex(_, schemaToUse)).toSeq)
.reduceOption(_.union(_)).getOrElse(emptyDF)
.select("protocol", "metaData", COMMIT_VERSION_COLUMN)
.where("protocol.minReaderVersion is not null or metaData.id is not null")
Expand Down Expand Up @@ -368,6 +378,8 @@ class Snapshot(
/** The [[CheckpointProvider]] for the underlying checkpoint */
lazy val checkpointProvider: CheckpointProvider = logSegment.checkpointProvider match {
case cp: CheckpointProvider => cp
case uninitializedProvider: UninitializedCheckpointProvider =>
CheckpointProvider(spark, this, checksumOpt, uninitializedProvider)
case o => throw new IllegalStateException(s"Unknown checkpoint provider: ${o.getClass.getName}")
}

Expand Down
Expand Up @@ -813,6 +813,14 @@ trait SnapshotManagement { self: DeltaLog =>
}

object SnapshotManagement {
// A thread pool for reading checkpoint files and collecting checkpoint v2 actions like
// checkpointMetadata, sidecarFiles.
private[delta] lazy val checkpointV2ThreadPool = {
val numThreads = SparkSession.active.sessionState.conf.getConf(
DeltaSQLConf.CHECKPOINT_V2_DRIVER_THREADPOOL_PARALLELISM)
DeltaThreadPool("checkpointV2-threadpool", numThreads)
}

protected[delta] lazy val deltaLogAsyncUpdateThreadPool = {
val tpe = ThreadUtils.newDaemonCachedThreadPool("delta-state-update", 8)
new DeltaThreadPool(tpe)
Expand Down
Expand Up @@ -59,6 +59,7 @@ class InMemoryLogReplay(
domainMetadatas.remove(a.domain)
case a: DomainMetadata if !a.removed =>
domainMetadatas(a.domain) = a
case _: CheckpointOnlyAction => // Ignore this while doing LogReplay
case a: Metadata =>
currentMetaData = a
case a: Protocol =>
Expand Down
Expand Up @@ -576,6 +576,14 @@ trait DeltaSQLConfBase {
// Checkpoint V2 Specific Configs
////////////////////////////////////

val CHECKPOINT_V2_DRIVER_THREADPOOL_PARALLELISM =
buildStaticConf("checkpointV2.threadpool.size")
.doc("The size of the threadpool for fetching CheckpointMetadata and SidecarFiles from a" +
" checkpoint.")
.internal()
.intConf
.createWithDefault(32)

val CHECKPOINT_V2_TOP_LEVEL_FILE_FORMAT =
buildConf("checkpointV2.topLevelFileFormat")
.internal()
Expand Down
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.delta.util
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.delta.{DeltaHistory, DeltaHistoryManager, SerializableFileStatus, SnapshotState}
import org.apache.spark.sql.delta.actions.{AddFile, Metadata, Protocol, RemoveFile, SingleAction}
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.convert.ConvertTargetFile
import org.apache.spark.sql.delta.sources.IndexedFile

Expand Down Expand Up @@ -77,6 +77,10 @@ private[delta] trait DeltaEncoders {
private lazy val _pmvEncoder = new DeltaEncoder[(Protocol, Metadata, Long)]
implicit def pmvEncoder: Encoder[(Protocol, Metadata, Long)] = _pmvEncoder.get

private lazy val _v2CheckpointActionsEncoder = new DeltaEncoder[(CheckpointMetadata, SidecarFile)]
implicit def v2CheckpointActionsEncoder: Encoder[(CheckpointMetadata, SidecarFile)] =
_v2CheckpointActionsEncoder.get

private lazy val _serializableFileStatusEncoder = new DeltaEncoder[SerializableFileStatus]
implicit def serializableFileStatusEncoder: Encoder[SerializableFileStatus] =
_serializableFileStatusEncoder.get
Expand Down

0 comments on commit 6859c86

Please sign in to comment.