Pass around AWSCredentialsProviders instead of AWSCredentials
This patch refactors the library internals to pass around `AWSCredentialsProvider` instances instead of `AWSCredentials`, avoiding issues where temporary credentials obtained at the start of a read or write operation expire before they are reused later in the operation.

This fixes #200.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #284 from JoshRosen/credential-expiry.
JoshRosen committed Oct 19, 2016
1 parent 51c29e6 commit bdf4462
Showing 6 changed files with 46 additions and 30 deletions.
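
To make the failure mode concrete, the following stand-alone sketch (not part of the patch) contrasts the two flows. `uploadPart` and the object name are hypothetical stand-ins; only the AWS SDK types and calls are real.

import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}

object CredentialFlowSketch {
  // Old flow: resolve credentials once and reuse the snapshot. If the chain handed back
  // temporary (session) credentials, they can expire before the last step runs.
  def writeWithSnapshot(uploadPart: AWSCredentials => Unit): Unit = {
    val snapshot: AWSCredentials = new DefaultAWSCredentialsProviderChain().getCredentials
    (1 to 100).foreach(_ => uploadPart(snapshot))
  }

  // New flow: pass the provider around and call getCredentials at each point of use,
  // so a refreshing provider can hand out keys that are still valid.
  def writeWithProvider(uploadPart: AWSCredentials => Unit): Unit = {
    val provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain()
    (1 to 100).foreach(_ => uploadPart(provider.getCredentials))
  }
}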
src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala
@@ -18,7 +18,7 @@ package com.databricks.spark.redshift

import java.net.URI

import com.amazonaws.auth.{BasicAWSCredentials, AWSCredentials, AWSSessionCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, AWSSessionCredentials, BasicAWSCredentials, DefaultAWSCredentialsProviderChain}
import org.apache.hadoop.conf.Configuration

import com.databricks.spark.redshift.Parameters.MergedParameters
@@ -44,11 +44,20 @@ private[redshift] object AWSCredentialsUtils {
})
}

def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentials = {
def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = {
new AWSCredentialsProvider {
override def getCredentials: AWSCredentials = credentials
override def refresh(): Unit = {}
}
}

def load(params: MergedParameters, hadoopConfiguration: Configuration): AWSCredentialsProvider = {
params.temporaryAWSCredentials.getOrElse(loadFromURI(params.rootTempDir, hadoopConfiguration))
}

private def loadFromURI(tempPath: String, hadoopConfiguration: Configuration): AWSCredentials = {
private def loadFromURI(
tempPath: String,
hadoopConfiguration: Configuration): AWSCredentialsProvider = {
// scalastyle:off
// A good reference on Hadoop's configuration loading / precedence is
// https://github.com/apache/hadoop/blob/trunk/hadoop-tools/hadoop-aws/src/site/markdown/tools/hadoop-aws/index.md
@@ -63,7 +72,7 @@ private[redshift] object AWSCredentialsUtils {
Option(uri.getUserInfo).flatMap { userInfo =>
if (userInfo.contains(":")) {
val Array(accessKey, secretKey) = userInfo.split(":")
Some(new BasicAWSCredentials(accessKey, secretKey))
Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
} else {
None
}
@@ -75,13 +84,13 @@ private[redshift] object AWSCredentialsUtils {
val accessKey = hadoopConfiguration.get(s"fs.$uriScheme.$accessKeyConfig", null)
val secretKey = hadoopConfiguration.get(s"fs.$uriScheme.$secretKeyConfig", null)
if (accessKey != null && secretKey != null) {
Some(new BasicAWSCredentials(accessKey, secretKey))
Some(staticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)))
} else {
None
}
}.getOrElse {
// Finally, fall back on the instance profile provider
new DefaultAWSCredentialsProviderChain().getCredentials
new DefaultAWSCredentialsProviderChain()
}
case other =>
throw new IllegalArgumentException(s"Unrecognized scheme $other; expected s3, s3n, or s3a")
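As a usage illustration (not from the patch), the new `staticCredentialsProvider` helper composes directly with the AWS SDK v1 clients. The object name and key values below are placeholders, and the snippet assumes it lives in the library's own package because the helper is `private[redshift]`.

package com.databricks.spark.redshift

import com.amazonaws.auth.BasicAWSCredentials
import com.amazonaws.services.s3.AmazonS3Client

object StaticProviderUsageSketch {
  // Wrap long-lived keys in a provider that returns the same credentials on every
  // call (its refresh() is a no-op).
  val staticProvider = AWSCredentialsUtils.staticCredentialsProvider(
    new BasicAWSCredentials("PLACEHOLDER_ACCESS_KEY", "PLACEHOLDER_SECRET_KEY"))

  // AmazonS3Client accepts an AWSCredentialsProvider, so the SDK asks the provider
  // for credentials each time it signs a request.
  val s3Client = new AmazonS3Client(staticProvider)
}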
src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
@@ -16,7 +16,7 @@

package com.databricks.spark.redshift

import com.amazonaws.auth.AWSCredentials
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
@@ -26,7 +26,9 @@ import org.slf4j.LoggerFactory
/**
* Redshift Source implementation for Spark SQL
*/
class DefaultSource(jdbcWrapper: JDBCWrapper, s3ClientFactory: AWSCredentials => AmazonS3Client)
class DefaultSource(
jdbcWrapper: JDBCWrapper,
s3ClientFactory: AWSCredentialsProvider => AmazonS3Client)
extends RelationProvider
with SchemaRelationProvider
with CreatableRelationProvider {
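A hedged sketch of how a caller might satisfy the changed `s3ClientFactory` type (again not part of the patch; `DefaultJDBCWrapper` is assumed to be the library's default JDBC helper, and the wrapper object is hypothetical):

package com.databricks.spark.redshift

import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client

object DefaultSourceWiringSketch {
  // The factory now receives a provider rather than resolved credentials, and the
  // S3 client is built directly from it so credential resolution stays lazy.
  val source = new DefaultSource(
    DefaultJDBCWrapper, // assumed: the library's default JDBCWrapper instance
    (provider: AWSCredentialsProvider) => new AmazonS3Client(provider))
}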
9 changes: 6 additions & 3 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -16,7 +16,7 @@

package com.databricks.spark.redshift

import com.amazonaws.auth.{AWSCredentials, BasicSessionCredentials}
import com.amazonaws.auth.{AWSCredentialsProvider, BasicSessionCredentials}

/**
* All user-specifiable parameters for spark-redshift, along with their validation rules and
@@ -235,12 +235,15 @@ private[redshift] object Parameters {
* the user when Hadoop is configured to authenticate to S3 via IAM roles assigned to EC2
* instances.
*/
def temporaryAWSCredentials: Option[AWSCredentials] = {
def temporaryAWSCredentials: Option[AWSCredentialsProvider] = {
for (
accessKey <- parameters.get("temporary_aws_access_key_id");
secretAccessKey <- parameters.get("temporary_aws_secret_access_key");
sessionToken <- parameters.get("temporary_aws_session_token")
) yield new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken)
) yield {
AWSCredentialsUtils.staticCredentialsProvider(
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}
}
}
src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -22,7 +22,7 @@ import java.net.URI

import scala.collection.JavaConverters._

import com.amazonaws.auth.AWSCredentials
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import com.eclipsesource.json.Json
import org.apache.spark.rdd.RDD
@@ -38,7 +38,7 @@ import com.databricks.spark.redshift.Parameters.MergedParameters
*/
private[redshift] case class RedshiftRelation(
jdbcWrapper: JDBCWrapper,
s3ClientFactory: AWSCredentials => AmazonS3Client,
s3ClientFactory: AWSCredentialsProvider => AmazonS3Client,
params: MergedParameters,
userSchema: Option[StructType])
(@transient val sqlContext: SQLContext)
@@ -111,7 +111,7 @@ private[redshift] case class RedshiftRelation(
} else {
// Unload data from Redshift into a temporary directory in S3:
val tempDir = params.createPerQueryTempDir()
val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir)
val unloadSql = buildUnloadStmt(requiredColumns, filters, tempDir, creds)
log.info(unloadSql)
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
try {
@@ -162,13 +162,14 @@ private[redshift] case class RedshiftRelation(
private def buildUnloadStmt(
requiredColumns: Array[String],
filters: Array[Filter],
tempDir: String): String = {
tempDir: String,
creds: AWSCredentialsProvider): String = {
assert(!requiredColumns.isEmpty)
// Always quote column names:
val columnList = requiredColumns.map(col => s""""$col"""").mkString(", ")
val whereClause = FilterPushdown.buildWhereClause(schema, filters)
val creds = AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration)
val credsString: String = AWSCredentialsUtils.getRedshiftCredentialsString(params, creds)
val credsString: String =
AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
val query = {
// Since the query passed to UNLOAD will be enclosed in single quotes, we need to escape
// any backslashes and single quotes that appear in the query itself
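The point of threading `creds` into `buildUnloadStmt` is that `getCredentials` now runs only when the statement text is produced, so a refreshing provider contributes keys that are still valid. Below is a rough stand-alone sketch of that pattern, not the library's code: `buildCredentialsClause` and the object name are hypothetical stand-ins for `getRedshiftCredentialsString`, and the clause format follows Redshift's documented `aws_access_key_id`/`aws_secret_access_key`/`token` syntax.

import com.amazonaws.auth.{AWSCredentialsProvider, AWSSessionCredentials}

object DeferredCredentialsSketch {
  // Resolve credentials from the provider only at statement-build time, instead of
  // snapshotting them when the read or write begins.
  def buildCredentialsClause(provider: AWSCredentialsProvider): String = {
    provider.getCredentials match {
      case session: AWSSessionCredentials =>
        s"aws_access_key_id=${session.getAWSAccessKeyId};" +
          s"aws_secret_access_key=${session.getAWSSecretKey};token=${session.getSessionToken}"
      case creds =>
        s"aws_access_key_id=${creds.getAWSAccessKeyId};aws_secret_access_key=${creds.getAWSSecretKey}"
    }
  }
}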
src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -19,7 +19,7 @@ package com.databricks.spark.redshift
import java.net.URI
import java.sql.{Connection, Date, SQLException, Timestamp}

import com.amazonaws.auth.AWSCredentials
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import org.apache.hadoop.fs.{FileSystem, Path}

@@ -59,7 +59,7 @@ import org.apache.spark.sql.types._
*/
private[redshift] class RedshiftWriter(
jdbcWrapper: JDBCWrapper,
s3ClientFactory: AWSCredentials => AmazonS3Client) {
s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) {

private val log = LoggerFactory.getLogger(getClass)

@@ -89,9 +89,10 @@ private[redshift] class RedshiftWriter(
private def copySql(
sqlContext: SQLContext,
params: MergedParameters,
creds: AWSCredentials,
creds: AWSCredentialsProvider,
manifestUrl: String): String = {
val credsString: String = AWSCredentialsUtils.getRedshiftCredentialsString(params, creds)
val credsString: String =
AWSCredentialsUtils.getRedshiftCredentialsString(params, creds.getCredentials)
val fixedUrl = Utils.fixS3Url(manifestUrl)
s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"AVRO 'auto' manifest ${params.extraCopyOptions}"
@@ -116,7 +117,7 @@ private[redshift] class RedshiftWriter(
conn: Connection,
data: DataFrame,
params: MergedParameters,
creds: AWSCredentials,
creds: AWSCredentialsProvider,
manifestUrl: Option[String]): Unit = {

// If the table doesn't exist, we need to create it first, using JDBC to infer column types
@@ -334,7 +335,7 @@ private[redshift] class RedshiftWriter(
"https://github.com/databricks/spark-redshift/pull/157")
}

val creds: AWSCredentials =
val creds: AWSCredentialsProvider =
AWSCredentialsUtils.load(params, sqlContext.sparkContext.hadoopConfiguration)

Utils.assertThatFileSystemIsNotS3BlockFileSystem(
src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala
@@ -65,7 +65,7 @@ class AWSCredentialsUtilsSuite extends FunSuite {
"temporary_aws_session_token" -> "token"
))

val creds = AWSCredentialsUtils.load(params, conf)
val creds = AWSCredentialsUtils.load(params, conf).getCredentials
assert(creds.isInstanceOf[AWSSessionCredentials])
assert(creds.getAWSAccessKeyId === "key_id")
assert(creds.getAWSSecretKey === "secret")
@@ -78,13 +78,13 @@ class AWSCredentialsUtilsSuite extends FunSuite {
conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")

{
val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3://URIID:URIKEY@bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "URIID")
assert(creds.getAWSSecretKey === "URIKEY")
}

{
val creds = AWSCredentialsUtils.load("s3://bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3://bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "CONFID")
assert(creds.getAWSSecretKey === "CONFKEY")
}
@@ -97,13 +97,13 @@ class AWSCredentialsUtilsSuite extends FunSuite {
conf.set("fs.s3n.awsSecretAccessKey", "CONFKEY")

{
val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3n://URIID:URIKEY@bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "URIID")
assert(creds.getAWSSecretKey === "URIKEY")
}

{
val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3n://bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "CONFID")
assert(creds.getAWSSecretKey === "CONFKEY")
}
@@ -116,13 +116,13 @@ class AWSCredentialsUtilsSuite extends FunSuite {
conf.set("fs.s3a.secret.key", "CONFKEY")

{
val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3a://URIID:URIKEY@bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "URIID")
assert(creds.getAWSSecretKey === "URIKEY")
}

{
val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf)
val creds = AWSCredentialsUtils.load("s3a://bucket/path", conf).getCredentials
assert(creds.getAWSAccessKeyId === "CONFID")
assert(creds.getAWSSecretKey === "CONFKEY")
}
