Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh token branch-0.7 fix #391

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions server/src/main/protobuf/protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ message QueryTableRequest {
repeated string predicateHints = 1;
optional string jsonPredicateHints = 6;
optional int64 limitHint = 2;
// Whether or not to return a refresh token in the response. Only used in latest snapshot query
// AND first page query. For long running queries, delta sharing spark may make additional request
// to refresh pre-signed urls, and there might be table changes between the initial request and
// the refresh request. The refresh token will contain version information to make sure that
// the refresh request returns the same set of files.
optional bool includeRefreshToken = 8;
// The refresh token used to refresh pre-signed urls. Only used in latest snapshot query AND
// first page query.
optional string refreshToken = 9;

// Only one of the three parameters can be supported in a single query.
// If none of them is specified, the query is for the latest version.
Expand Down Expand Up @@ -98,3 +107,13 @@ message PageToken {
optional string share = 2;
optional string schema = 3;
}

// Define a special class to generate the refresh token for latest snapshot query.
message RefreshToken {
// Id of the table being queried.
optional string id = 1;
// Only used in queryTable at snapshot, refers to the version being queried.
optional int64 version = 2;
// The expiration timestamp of the refresh token in milliseconds.
optional int64 expiration_timestamp = 3;
}
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ class DeltaSharingService(serverConfig: ServerConfig) {
version = None,
timestamp = None,
startingVersion = None,
endingVersion = None
endingVersion = None,
includeRefreshToken = false,
refreshToken = None
)
streamingOutput(Some(v), actions)
}
Expand All @@ -295,8 +297,8 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Param("schema") schema: String,
@Param("table") table: String,
request: QueryTableRequest): HttpResponse = processRequest {
val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion)
.filter(_.isDefined).size
val numVersionParams =
Seq(request.version, request.timestamp, request.startingVersion).count(_.isDefined)
if (numVersionParams > 1) {
throw new DeltaSharingIllegalArgumentException(ErrorStrings.multipleParametersSetErrorMsg(
Seq("version", "timestamp", "startingVersion"))
Expand All @@ -308,6 +310,16 @@ class DeltaSharingService(serverConfig: ServerConfig) {
if (request.startingVersion.isDefined && request.startingVersion.get < 0) {
throw new DeltaSharingIllegalArgumentException("startingVersion cannot be negative.")
}
if (numVersionParams > 0 && request.includeRefreshToken.contains(true)) {
throw new DeltaSharingIllegalArgumentException(
"includeRefreshToken cannot be used when querying a specific version."
)
}
if (numVersionParams > 0 && request.refreshToken.isDefined) {
throw new DeltaSharingIllegalArgumentException(
"refreshToken cannot be used when querying a specific version."
)
}

val start = System.currentTimeMillis
val tableConfig = sharedTableManager.getTable(share, schema, table)
Expand Down Expand Up @@ -339,7 +351,9 @@ class DeltaSharingService(serverConfig: ServerConfig) {
request.version,
request.timestamp,
request.startingVersion,
request.endingVersion
request.endingVersion,
request.includeRefreshToken.getOrElse(false),
request.refreshToken
)
if (version < tableConfig.startVersion) {
throw new DeltaSharingIllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ case class ServerConfig(
// Whether to evaluate user provided `jsonPredicateHints`
@BeanProperty var evaluateJsonPredicateHints: Boolean,
// The timeout of an incoming web request in seconds. Set to 0 for no timeout
@BeanProperty var requestTimeoutSeconds: Long
@BeanProperty var requestTimeoutSeconds: Long,
// The TTL of the refresh token generated in queryTable API (in milliseconds).
@BeanProperty var refreshTokenTtlMs: Int
) extends ConfigItem {
import ServerConfig._

Expand All @@ -76,7 +78,8 @@ case class ServerConfig(
stalenessAcceptable = false,
evaluatePredicateHints = false,
evaluateJsonPredicateHints = false,
requestTimeoutSeconds = 30
requestTimeoutSeconds = 30,
refreshTokenTtlMs = 3600000 // 1 hour
)
}

Expand Down
11 changes: 10 additions & 1 deletion server/src/main/scala/io/delta/sharing/server/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ case class SingleAction(
cdf: AddCDCFile = null,
remove: RemoveFile = null,
metaData: Metadata = null,
protocol: Protocol = null) {
protocol: Protocol = null,
endStreamAction: EndStreamAction = null) {

def unwrap: Action = {
if (file != null) {
Expand All @@ -40,6 +41,8 @@ case class SingleAction(
metaData
} else if (protocol != null) {
protocol
} else if (endStreamAction != null) {
endStreamAction
} else {
null
}
Expand Down Expand Up @@ -128,6 +131,12 @@ case class RemoveFile(
override def wrap: SingleAction = SingleAction(remove = this)
}

case class EndStreamAction(
refreshToken: String
) extends Action {
override def wrap: SingleAction = SingleAction(endStreamAction = this)
}

object Action {
// The maximum version of the protocol that this version of Delta Standalone understands.
val maxReaderVersion = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package io.delta.standalone.internal

import java.net.URI
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
Expand All @@ -38,6 +39,7 @@ import org.apache.hadoop.fs.azurebfs.AzureBlobFileSystem
import org.apache.hadoop.fs.s3a.S3AFileSystem
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StructType}
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

import io.delta.sharing.server.{
model,
Expand All @@ -51,6 +53,7 @@ import io.delta.sharing.server.{
WasbFileSigner
}
import io.delta.sharing.server.config.{ServerConfig, TableConfig}
import io.delta.sharing.server.protocol.RefreshToken

/**
* A class to load Delta tables from `TableConfig`. It also caches the loaded tables internally
Expand All @@ -72,7 +75,8 @@ class DeltaSharedTableLoader(serverConfig: ServerConfig) {
tableConfig,
serverConfig.preSignedUrlTimeoutSeconds,
serverConfig.evaluatePredicateHints,
serverConfig.evaluateJsonPredicateHints)
serverConfig.evaluateJsonPredicateHints,
serverConfig.refreshTokenTtlMs)
})
if (!serverConfig.stalenessAcceptable) {
deltaSharedTable.update()
Expand All @@ -93,7 +97,8 @@ class DeltaSharedTable(
tableConfig: TableConfig,
preSignedUrlTimeoutSeconds: Long,
evaluatePredicateHints: Boolean,
evaluateJsonPredicateHints: Boolean) {
evaluateJsonPredicateHints: Boolean,
refreshTokenTtlMs: Int) {

private val conf = withClassLoader {
new Configuration()
Expand Down Expand Up @@ -188,16 +193,21 @@ class DeltaSharedTable(
version: Option[Long],
timestamp: Option[String],
startingVersion: Option[Long],
endingVersion: Option[Long]
endingVersion: Option[Long],
includeRefreshToken: Boolean,
refreshToken: Option[String]
): (Long, Seq[model.SingleAction]) = withClassLoader {
// TODO Support `limitHint`
if (Seq(version, timestamp, startingVersion).filter(_.isDefined).size >= 2) {
throw new DeltaSharingIllegalArgumentException(
ErrorStrings.multipleParametersSetErrorMsg(Seq("version", "timestamp", "startingVersion"))
)
}
val snapshot = if (version.orElse(startingVersion).isDefined) {
deltaLog.getSnapshotForVersionAsOf(version.orElse(startingVersion).get)
// Validate refreshToken if it's specified
val refreshTokenOpt = refreshToken.map(decodeAndValidateRefreshToken)
val specifiedVersion = version.orElse(startingVersion).orElse(refreshTokenOpt.map(_.getVersion))
val snapshot = if (specifiedVersion.isDefined) {
deltaLog.getSnapshotForVersionAsOf(specifiedVersion.get)
} else if (timestamp.isDefined) {
val ts = DeltaSharingHistoryManager.getTimestamp("timestamp", timestamp.get)
try {
Expand Down Expand Up @@ -268,7 +278,7 @@ class DeltaSharedTable(
} else {
filteredFiles
}
filteredFiles.map { addFile =>
val signedFiles = filteredFiles.map { addFile =>
val cloudPath = absolutePath(deltaLog.dataPath, addFile.path)
val signedUrl = fileSigner.sign(cloudPath)
val modelAddFile = model.AddFile(
Expand All @@ -283,6 +293,22 @@ class DeltaSharedTable(
)
modelAddFile.wrap
}
signedFiles ++ {
// For backwards compatibility, return an `endStreamAction` object only when
// `includeRefreshToken` is true
if (includeRefreshToken) {
val refreshTokenStr = encodeRefreshToken(
RefreshToken(
id = Some(tableConfig.id),
version = Some(snapshot.version),
expirationTimestamp = Some(System.currentTimeMillis() + refreshTokenTtlMs)
)
)
Seq(model.EndStreamAction(refreshTokenStr).wrap)
} else {
Nil
}
}
} else {
Nil
}
Expand Down Expand Up @@ -512,4 +538,30 @@ class DeltaSharedTable(
new Path(path, p)
}
}

private def decodeAndValidateRefreshToken(tokenStr: String): RefreshToken = {
val token = try {
RefreshToken.parseFrom(Base64.getUrlDecoder.decode(tokenStr))
} catch {
case NonFatal(_) =>
throw new DeltaSharingIllegalArgumentException(
s"Error decoding refresh token: $tokenStr."
)
}
if (token.getExpirationTimestamp < System.currentTimeMillis()) {
throw new DeltaSharingIllegalArgumentException(
"The refresh token has expired. Please restart the query."
)
}
if (token.getId != tableConfig.id) {
throw new DeltaSharingIllegalArgumentException(
"The table specified in the refresh token does not match the table being queried."
)
}
token
}

private def encodeRefreshToken(token: RefreshToken): String = {
Base64.getUrlEncoder.encodeToString(token.toByteArray)
}
}
2 changes: 2 additions & 0 deletions server/src/universal/conf/delta-sharing-server.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,5 @@ stalenessAcceptable: false
evaluatePredicateHints: false
# Whether to evaluate user provided `jsonPredicateHints`
evaluateJsonPredicateHints: false
# The TTL of the refresh token generated in queryTable API (in milliseconds).
refreshTokenTtlMs: 3600000
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ private[sharing] trait DeltaSharingClient {
limit: Option[Long],
versionAsOf: Option[Long],
timestampAsOf: Option[String],
jsonPredicateHints: Option[String]): DeltaTableFiles
jsonPredicateHints: Option[String],
refreshToken: Option[String]): DeltaTableFiles

def getFiles(table: Table, startingVersion: Long, endingVersion: Option[Long]): DeltaTableFiles

Expand All @@ -81,7 +82,9 @@ private[sharing] case class QueryTableRequest(
timestamp: Option[String],
startingVersion: Option[Long],
endingVersion: Option[Long],
jsonPredicateHints: Option[String]
jsonPredicateHints: Option[String],
includeRefreshToken: Option[Boolean],
refreshToken: Option[String]
)

private[sharing] case class ListSharesResponse(
Expand All @@ -99,7 +102,8 @@ private[spark] class DeltaSharingRestClient(
numRetries: Int = 10,
maxRetryDuration: Long = Long.MaxValue,
sslTrustAll: Boolean = false,
forStreaming: Boolean = false) extends DeltaSharingClient {
forStreaming: Boolean = false
) extends DeltaSharingClient with Logging {

@volatile private var created = false

Expand Down Expand Up @@ -228,7 +232,10 @@ private[spark] class DeltaSharingRestClient(
limit: Option[Long],
versionAsOf: Option[Long],
timestampAsOf: Option[String],
jsonPredicateHints: Option[String]): DeltaTableFiles = {
jsonPredicateHints: Option[String],
refreshToken: Option[String]): DeltaTableFiles = {
// Retrieve refresh token when querying the latest snapshot.
val includeRefreshToken = versionAsOf.isEmpty && timestampAsOf.isEmpty
val encodedShareName = URLEncoder.encode(table.share, "UTF-8")
val encodedSchemaName = URLEncoder.encode(table.schema, "UTF-8")
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
Expand All @@ -243,15 +250,26 @@ private[spark] class DeltaSharingRestClient(
timestampAsOf,
None,
None,
jsonPredicateHints
jsonPredicateHints,
Some(includeRefreshToken),
refreshToken
)
)
val (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines)
val refreshTokenOpt = endStreamAction.flatMap { e =>
Option(e.refreshToken).flatMap { token =>
if (token.isEmpty) None else Some(token)
}
}
if (includeRefreshToken && refreshTokenOpt.isEmpty) {
logWarning("includeRefreshToken=true but refresh token is not returned.")
}
require(versionAsOf.isEmpty || versionAsOf.get == version)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
val protocol = JsonUtils.fromJson[SingleAction](filteredLines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
val files = lines.drop(2).map(line => JsonUtils.fromJson[SingleAction](line).file)
DeltaTableFiles(version, protocol, metadata, files)
val metadata = JsonUtils.fromJson[SingleAction](filteredLines(1)).metaData
val files = filteredLines.drop(2).map(line => JsonUtils.fromJson[SingleAction](line).file)
DeltaTableFiles(version, protocol, metadata, files, refreshToken = refreshTokenOpt)
}

override def getFiles(
Expand All @@ -265,7 +283,19 @@ private[spark] class DeltaSharingRestClient(
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/query")
val (version, lines) = getNDJson(
target, QueryTableRequest(Nil, None, None, None, Some(startingVersion), endingVersion, None))
target,
QueryTableRequest(
/* predicateHint */ Nil,
/* limitHint */ None,
/* version */ None,
/* timestamp */ None,
Some(startingVersion),
endingVersion,
/* jsonPredicateHints */ None,
/* includeRefreshToken */ None,
/* refreshToken */ None
)
)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
Expand Down Expand Up @@ -326,6 +356,17 @@ private[spark] class DeltaSharingRestClient(
)
}

// Check the last line and extract EndStreamAction if there is one.
private def maybeExtractEndStreamAction(
lines: Seq[String]): (Seq[String], Option[EndStreamAction]) = {
val endStreamAction = JsonUtils.fromJson[SingleAction](lines.last).endStreamAction
if (endStreamAction == null) {
(lines, None)
} else {
(lines.init, Some(endStreamAction))
}
}

private def getEncodedCDFParams(
cdfOptions: Map[String, String],
includeHistoricalMetadata: Boolean): String = {
Expand Down