Skip to content
Permalink
Browse files

Test OAuth login

  • Loading branch information...
soywiz committed Jun 21, 2018
1 parent 1e011e4 commit 56119d2879d9300cf51d66ea7114ff815f7db752
@@ -15,6 +15,8 @@ mainClassName = "io.ktor.server.netty.DevelopmentEngine"
sourceSets {
main.kotlin.srcDirs = [ 'src' ]
main.resources.srcDirs = [ 'resources' ]
test.kotlin.srcDirs = [ 'test' ]
test.resources.srcDirs = [ 'testresources' ]
}

repositories {
@@ -32,6 +34,8 @@ dependencies {
compile "io.ktor:ktor-auth-jwt:$ktor_version"
compile "io.ktor:ktor-client-apache:$ktor_version"
compile "ch.qos.logback:logback-classic:$logback_version"

testCompile "io.ktor:ktor-server-tests:$ktor_version"
}

kotlin.experimental.coroutines = 'enable'
@@ -85,18 +85,24 @@ val loginProviders = listOf(
private val exec = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 4)

fun Application.OAuthLoginApplication() {
OAuthLoginApplicationWithDeps(
oauthHttpClient = HttpClient(Apache).apply {
environment.monitor.subscribe(ApplicationStopping) {
close()
}
}
)
}

fun Application.OAuthLoginApplicationWithDeps(oauthHttpClient: HttpClient) {
val authOauthForLogin = "authOauthForLogin"

install(DefaultHeaders)
install(CallLogging)
install(Locations)
install(Authentication) {
oauth(authOauthForLogin) {
client = HttpClient(Apache).apply {
environment.monitor.subscribe(ApplicationStopping) {
close()
}
}
client = oauthHttpClient
providerLookup = {
loginProviders[application.locations.resolve<login>(login::class, this).type]
}
@@ -0,0 +1,70 @@
package io.ktor.samples.auth

import io.ktor.client.*
import io.ktor.content.*
import io.ktor.http.*
import io.ktor.server.testing.*
import org.junit.Test
import kotlin.test.*

class OAuthTest {
@Test
fun testOAuthLogin() {
withTestApplication {
val testClientFactory = TestHttpClientFactory()
application.OAuthLoginApplicationWithDeps(
oauthHttpClient = HttpClient(testClientFactory)
)

lateinit var state: String

fun String.maskState() = Regex("state=(\\w+)").replace(this, "state=****")

handleRequest(HttpMethod.Get, "/login/google") {
addHeader("Host", "127.0.0.1")
}.let { call ->
val location = call.response.headers["Location"] ?: ""
assertEquals(
"https://accounts.google.com/o/oauth2/auth?client_id=***.apps.googleusercontent.com&redirect_uri=http%3A%2F%2F127.0.0.1%2Flogin%2Fgoogle&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fplus.login&state=****&response_type=code",
location.maskState()
)
val stateInfo = Regex("state=(\\w+)").find(location)
state = stateInfo!!.groupValues[1]
}

testClientFactory.addResponse("https://www.googleapis.com/oauth2/v3/token") { request ->
val textContent = request.content as TextContent
assertEquals(ContentType.Application.FormUrlEncoded, textContent.contentType)
assertEquals(
"client_id=***.apps.googleusercontent.com&client_secret=***&grant_type=authorization_code&state=****&code=mycode&redirect_uri=http%3A%2F%2F127.0.0.1%2Flogin%2Fgoogle",
textContent.text.maskState()
)
TextContent(
"""{
"access_token": "myaccesstoken",
"token_type": "mytokentype",
"expires_in": 3600,
"refresh_token": "myrefreshtoken"
}""".trimIndent(), contentType = ContentType.Application.Json
)
}

handleRequest(HttpMethod.Get, "/login/google?state=$state&code=mycode") {
addHeader("Host", "127.0.0.1")
}.let { call ->
assertEquals("""
<!DOCTYPE html>
<html>
<head>
<title>Logged in</title>
</head>
<body>
<h1>You are logged in</h1>
<p>Your token is OAuth2(accessToken=myaccesstoken, tokenType=mytokentype, expiresIn=3600, refreshToken=myrefreshtoken, extraParameters=Parameters [access_token=[myaccesstoken], refresh_token=[myrefreshtoken], token_type=[mytokentype], expires_in=[3600]])</p>
</body>
</html>
""".trimIndent(), call.response.content?.trim())
}
}
}
}
@@ -0,0 +1,85 @@
package io.ktor.samples.auth

import io.ktor.client.call.*
import io.ktor.client.engine.*
import io.ktor.client.request.*
import io.ktor.client.response.*
import io.ktor.content.*
import io.ktor.http.*
import io.ktor.util.*
import kotlinx.coroutines.experimental.*
import kotlinx.coroutines.experimental.io.*
import java.util.*
import kotlin.coroutines.experimental.*

class TestHttpClientFactory : HttpClientEngineFactory<TestHttpClientFactory.Config> {
fun addResponse(url: String, headers: HeadersBuilder.() -> Unit = {}, response: (HttpRequest) -> OutgoingContent) {
config.responses[url] = FakeResponse(
response
) { respBody ->
appendAll(respBody.headers)
headers()
}
}

data class FakeResponse(val body: (HttpRequest) -> OutgoingContent, val headers: HeadersBuilder.(OutgoingContent) -> Unit)

class Config : HttpClientEngineConfig() {
val responses = LinkedHashMap<String, FakeResponse>()
}

val config = Config()

override fun create(block: TestHttpClientFactory.Config.() -> Unit): HttpClientEngine {
val config = config.apply(block)
return Engine(config)
}

class Engine(val config: TestHttpClientFactory.Config) : HttpClientEngine {
override val dispatcher: CoroutineDispatcher = DefaultDispatcher

override fun close() = Unit

override suspend fun execute(call: HttpClientCall, data: HttpRequestData): HttpEngineCall {
val context = coroutineContext
val url = data.url.fullUrl
val response = config.responses[url] ?: error("Can't find response for $url")

val request = object : HttpRequest {
override val attributes: Attributes = Attributes().apply { data.attributes(this) }
override val call: HttpClientCall = call
override val content: OutgoingContent = data.body as OutgoingContent
override val executionContext: Job = Job()
override val headers: Headers = data.headers
override val method: HttpMethod = data.method
override val url: Url = data.url
}
val body = response.body(request)

return HttpEngineCall(
request,
object : HttpResponse {
override val call: HttpClientCall = call
override val content: ByteReadChannel = writer(context) {
when (body) {
is OutgoingContent.NoContent -> Unit
is OutgoingContent.ByteArrayContent -> channel.writeFully(body.bytes())
is OutgoingContent.ReadChannelContent -> body.readFrom().copyAndClose(channel)
is OutgoingContent.WriteChannelContent -> body.writeTo(channel)
}
}.channel
override val executionContext: Job = Job()
override val headers: Headers = HeadersBuilder().apply { response.headers(this, body) }.build()
override val requestTime: Date = Date()
override val responseTime: Date = Date()
override val status: HttpStatusCode = body.status ?: HttpStatusCode.OK
override val version: HttpProtocolVersion = HttpProtocolVersion.HTTP_1_1
override fun close() = Unit
}
)
}
}
}

private val Url.hostWithPortIfRequired: String get() = if (port == protocol.defaultPort) host else hostWithPort
private val Url.fullUrl: String get() = "${protocol.name}://$hostWithPortIfRequired$fullPath"

0 comments on commit 56119d2

Please sign in to comment.
You can’t perform that action at this time.