diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt index 2045624092..7186a28e0b 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt @@ -54,13 +54,13 @@ public class Auth private constructor( val candidateProviders = HashSet(plugin.providers) while (call.response.status == HttpStatusCode.Unauthorized) { - val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate] + val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate) + val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList() - val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) } - val provider = when { - authHeader == null && candidateProviders.size == 1 -> candidateProviders.first() - authHeader == null -> return@intercept call - else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call + val (provider, authHeader) = when { + authHeaders.isEmpty() && candidateProviders.size == 1 -> candidateProviders.first() to null + authHeaders.isEmpty() -> return@intercept call + else -> findProviderAndHeader(candidateProviders, authHeaders) ?: return@intercept call } if (!provider.refreshToken(call.response)) return@intercept call @@ -76,6 +76,21 @@ public class Auth private constructor( return@intercept call } } + + private fun findProviderAndHeader( + providers: Collection, + authHeaders: List + ): Pair? { + authHeaders.forEach { header -> + providers.forEach { provider -> + if (provider.isApplicable(header)) { + return provider to header + } + } + } + + return null + } } } diff --git a/ktor-http/api/ktor-http.api b/ktor-http/api/ktor-http.api index c6bf689756..998d5f8675 100644 --- a/ktor-http/api/ktor-http.api +++ b/ktor-http/api/ktor-http.api @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H public final class io/ktor/http/auth/HttpAuthHeaderKt { public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader; + public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List; } public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent { diff --git a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt index 68d8138a7e..1a3a386676 100644 --- a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt +++ b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt @@ -8,7 +8,6 @@ import io.ktor.http.* import io.ktor.http.parsing.* import io.ktor.util.* import io.ktor.utils.io.charsets.* -import kotlin.native.concurrent.* private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~') private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/') @@ -19,12 +18,14 @@ private val escapeRegex: Regex = "\\\\.".toRegex() * Parses an authorization header [headerValue] into a [HttpAuthHeader]. * @return [HttpAuthHeader] or `null` if argument string is blank. * @throws [ParseException] on invalid header + * + * @see [parseAuthorizationHeaders] */ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { var index = 0 index = headerValue.skipSpaces(index) - var tokenStartIndex = index + val tokenStartIndex = index while (index < headerValue.length && headerValue[index].isToken()) { index++ } @@ -32,7 +33,6 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { // Auth scheme val authScheme = headerValue.substring(tokenStartIndex until index) index = headerValue.skipSpaces(index) - tokenStartIndex = index if (authScheme.isBlank()) { return null @@ -42,28 +42,88 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { return HttpAuthHeader.Parameterized(authScheme, emptyList()) } - val token68 = matchToken68(headerValue, index) + val (indexAfterToken68, token68) = matchToken68(headerValue, index) if (token68 != null) { - return HttpAuthHeader.Single(authScheme, token68) + return checkSingleHeader(indexAfterToken68, HttpAuthHeader.Single(authScheme, token68)) } - val parameters = matchParameters(headerValue, tokenStartIndex) - return HttpAuthHeader.Parameterized(authScheme, parameters) + val (endIndex, parameters) = matchParameters(headerValue, index) + return checkSingleHeader(endIndex, HttpAuthHeader.Parameterized(authScheme, parameters)) +} + +private fun checkSingleHeader(endIndex: Int, header: HttpAuthHeader): HttpAuthHeader { + return if (endIndex == -1) header else + throw ParseException("Function parseAuthorizationHeader can parse only one header") } -private fun matchParameters(headerValue: String, startIndex: Int): Map { +/** + * Parses an authorization header [headerValue] into a list of [HttpAuthHeader]. + * @return a list of [HttpAuthHeader] + * @throws [ParseException] on invalid header + */ +public fun parseAuthorizationHeaders(headerValue: String): List { + var index = 0 + val headers = mutableListOf() + while (index != -1) { + val (nextIndex, header) = parseAuthorizationHeader(headerValue, index) + headers.add(header) + index = nextIndex + } + return headers +} + +private fun parseAuthorizationHeader( + headerValue: String, + startIndex: Int, +): Pair { + var index = headerValue.skipSpaces(startIndex) + + // Auth scheme + val schemeStartIndex = index + while (index < headerValue.length && headerValue[index].isToken()) { + index++ + } + val authScheme = headerValue.substring(schemeStartIndex until index) + + if (authScheme.isBlank()) { + throw ParseException("Invalid authScheme value: it should be token, can't be blank") + } + + val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index) + if (isEndOfChallenge) { + return endChallengeIndex to HttpAuthHeader.Parameterized(authScheme, emptyList()) + } + + val (nextIndex, token68) = matchToken68(headerValue, endChallengeIndex) + if (token68 != null) { + return nextIndex to HttpAuthHeader.Single(authScheme, token68) + } + + val (nextIndexChallenge, parameters) = matchParameters(headerValue, index) + return nextIndexChallenge to HttpAuthHeader.Parameterized(authScheme, parameters) +} + +private fun matchParameters(headerValue: String, startIndex: Int): Pair> { val result = mutableMapOf() var index = startIndex while (index > 0 && index < headerValue.length) { - index = matchParameter(headerValue, index, result) - index = headerValue.skipDelimiter(index, ',') + val (nextIndex, wasParameter) = matchParameter(headerValue, index, result) + if (wasParameter) { + index = headerValue.skipDelimiter(nextIndex, ',') + } else { + return nextIndex to result + } } - return result + return index to result } -private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap): Int { +private fun matchParameter( + headerValue: String, + startIndex: Int, + parameters: MutableMap +): Pair { val keyStart = headerValue.skipSpaces(startIndex) var index = keyStart @@ -71,15 +131,15 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut while (index < headerValue.length && headerValue[index].isToken()) { index++ } - val key = headerValue.substring(keyStart until index) - // Take '=' + // Check if new challenge index = headerValue.skipSpaces(index) - if (index >= headerValue.length || headerValue[index] != '=') { - throw ParseException("Expected `=` after parameter key '$key': $headerValue") + if (index == headerValue.length || headerValue[index] != '=') { + return keyStart to false } + // Take '=' index++ index = headerValue.skipSpaces(index) @@ -113,10 +173,10 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut parameters[key] = if (quoted) value.unescaped() else value if (quoted) index++ - return index + return index to true } -private fun matchToken68(headerValue: String, startIndex: Int): String? { +private fun matchToken68(headerValue: String, startIndex: Int): Pair { var index = startIndex while (index < headerValue.length && headerValue[index].isToken68()) { @@ -127,12 +187,14 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? { index++ } - val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' } - if (onlySpaceRemaining) { - return headerValue.substring(startIndex until index) - } + val token68 = headerValue.substring(startIndex until index) - return null + val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index) + return if (isEndOfChallenge) { + endChallengeIndex to token68 + } else { + startIndex to null + } } /** @@ -355,13 +417,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) } private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int { var index = skipSpaces(startIndex) - while (index < length && this[index] != delimiter) { - index++ - } - if (index == length) return -1 - index++ + if (this[index] != delimiter) + throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}") + index++ return skipSpaces(index) } @@ -374,6 +434,14 @@ private fun String.skipSpaces(startIndex: Int): Int { return index } +private fun String.isEndOfChallenge(startIndex: Int): Pair { + val index = skipSpaces(startIndex) + if (index == length) return -1 to true + if (this[index] == ',') return index + 1 to true + + return index to false +} + private fun Char.isToken68(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN68_EXTRA private fun Char.isToken(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN_EXTRA diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt index b5bce68873..0519f3dd0f 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt @@ -9,19 +9,23 @@ import kotlin.random.* import kotlin.test.* class AuthorizeHeaderParserTest { - @Test fun empty() { + @Test + fun empty() { testParserParameterized("Basic", emptyMap(), "Basic") } - @Test fun emptyWithTrailingSpaces() { + @Test + fun emptyWithTrailingSpaces() { testParserParameterized("Basic", emptyMap(), "Basic ") } - @Test fun singleSimple() { + @Test + fun singleSimple() { testParserSingle("Basic", "abc==", "Basic abc==") } - @Test fun testParameterizedSimple() { + @Test + fun testParameterizedSimple() { testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1") testParserParameterized("Basic", mapOf("a" to "1"), "Basic a =1") testParserParameterized("Basic", mapOf("a" to "1"), "Basic a = 1") @@ -30,7 +34,8 @@ class AuthorizeHeaderParserTest { testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1 ") } - @Test fun testParameterizedSimpleTwoParams() { + @Test + fun testParameterizedSimpleTwoParams() { testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1, b=2") testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1,b=2") testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 ,b=2") @@ -38,19 +43,53 @@ class AuthorizeHeaderParserTest { testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2 ") } - @Test fun testParameterizedQuoted() { + @Test + fun testParameterizedQuoted() { testParserParameterized("Basic", mapOf("a" to "1 2"), "Basic a=\"1 2\"") } - @Test fun testParameterizedQuotedEscaped() { + @Test + fun testParameterizedQuotedEscaped() { testParserParameterized("Basic", mapOf("a" to "1 \" 2"), "Basic a=\"1 \\\" 2\"") testParserParameterized("Basic", mapOf("a" to "1 A 2"), "Basic a=\"1 \\A 2\"") } - @Test fun testParameterizedQuotedEscapedInTheMiddle() { + @Test + fun testParameterizedQuotedEscapedInTheMiddle() { testParserParameterized("Basic", mapOf("a" to "1 \" 2", "b" to "2"), "Basic a=\"1 \\\" 2\", b= 2") } + @Test + fun testMultipleChallengesParameters() { + val expected = listOf( + HttpAuthHeader.Parameterized("Digest", emptyMap()), + HttpAuthHeader.Parameterized("Bearer", mapOf("1" to "2", "3" to "4")), + HttpAuthHeader.Parameterized("Basic", emptyMap()), + ) + testParserMultipleChallenges(expected, "Digest, Bearer 1 = 2, 3=4, Basic ") + } + + @Test + fun testMultipleChallengesSingle() { + val expected = listOf( + HttpAuthHeader.Single("Bearer", "abc=="), + HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")), + HttpAuthHeader.Single("Basic", "def==="), + HttpAuthHeader.Parameterized("Digest", emptyMap()) + ) + testParserMultipleChallenges(expected, "Bearer abc==, Bearer abc=def, Basic def===, Digest") + } + + @Test + fun testMultipleChallengesAllHeaders() { + val expected = listOf( + HttpAuthHeader.Parameterized("Basic", emptyMap()), + HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")), + HttpAuthHeader.Single("Digest", "abc==") + ) + testParserMultipleChallenges(expected, "Basic, Bearer abc=def,Digest abc==") + } + private fun testParserSingle(scheme: String, value: String, headerValue: String) { val actual = parseAuthorizationHeader(headerValue)!! @@ -75,11 +114,31 @@ class AuthorizeHeaderParserTest { } } + private fun testParserMultipleChallenges(expected: List, headerValue: String) { + val actual = parseAuthorizationHeaders(headerValue) + + assertEquals(expected.size, actual.size) + (expected zip actual).forEach { (expectedHeader, actualHeader) -> + if (expectedHeader is HttpAuthHeader.Single) { + assertIs(actualHeader) + + assertEquals(expectedHeader.blob, actualHeader.blob) + } + if (expectedHeader is HttpAuthHeader.Parameterized) { + assertIs(actualHeader) + assertEquals( + expectedHeader.parameters.associateBy({ it.name }, { it.value }), + actualHeader.parameters.associateBy({ it.name }, { it.value }) + ) + } + } + } + private fun Random.nextString( length: Int, possible: Iterable = ('a'..'z') + ('A'..'Z') + ('0'..'9') ) = possible.toList().let { possibleElements -> - (0..length - 1).map { nextFrom(possibleElements) }.joinToString("") + (0 until length).map { nextFrom(possibleElements) }.joinToString("") } private fun Random.nextString(length: Int, possible: String) = nextString(length, possible.toList())