Skip to content
This repository has been archived by the owner on Sep 12, 2021. It is now read-only.

Commit

Permalink
Polish the social state handler implementation (#508)
Browse files Browse the repository at this point in the history
- Remove old state provider implementation
- Optimize code
- Add more tests
  • Loading branch information
akkie committed Apr 29, 2017
1 parent c04c1d1 commit f725261
Show file tree
Hide file tree
Showing 20 changed files with 539 additions and 1,247 deletions.
1 change: 0 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Expand Up @@ -2,7 +2,6 @@


* [ ] Have you read [How to write the perfect pull request](https://github.com/blog/1943-how-to-write-the-perfect-pull-request)? * [ ] Have you read [How to write the perfect pull request](https://github.com/blog/1943-how-to-write-the-perfect-pull-request)?
* [ ] Have you read through the [contributor guidelines](https://github.com/mohiva/play-silhouette/blob/master/CONTRIBUTING.md)? * [ ] Have you read through the [contributor guidelines](https://github.com/mohiva/play-silhouette/blob/master/CONTRIBUTING.md)?
* [ ] Have you [squashed your commits](https://www.playframework.com/documentation/2.5.x/WorkingWithGit#Squashing-commits)?
* [ ] Have you added copyright headers to new files? * [ ] Have you added copyright headers to new files?
* [ ] Have you suggest documentation edits? * [ ] Have you suggest documentation edits?
* [ ] Have you added tests for any changed functionality? * [ ] Have you added tests for any changed functionality?
Expand Down
Expand Up @@ -23,6 +23,7 @@ import org.specs2.mock.Mockito
import org.specs2.specification.Scope import org.specs2.specification.Scope
import play.api.libs.concurrent.Execution.Implicits._ import play.api.libs.concurrent.Execution.Implicits._
import play.api.test.FakeRequest import play.api.test.FakeRequest
import test.SocialProviderSpec


import scala.concurrent.Future import scala.concurrent.Future
import scala.concurrent.duration._ import scala.concurrent.duration._
Expand Down Expand Up @@ -70,29 +71,25 @@ class CasProviderSpec extends SocialProviderSpec[CasInfo] with Mockito with Logg
"redirect to CAS server if service ticket is not present in request" in new Context { "redirect to CAS server if service ticket is not present in request" in new Context {
implicit val req = FakeRequest(GET, "/") implicit val req = FakeRequest(GET, "/")


result(provider.authenticate()) { result(provider.authenticate()) { result =>
case result => status(result) must equalTo(SEE_OTHER)
status(result) must equalTo(SEE_OTHER) redirectLocation(result) must beSome("https://cas-url/?service=https%3A%2F%2Fcas-redirect%2F")
redirectLocation(result) must beSome("https://cas-url/?service=https%3A%2F%2Fcas-redirect%2F")
} }
} }


"redirect to CAS server with the original requested URL if service ticket is not present in the request" in new Context { "redirect to CAS server with the original requested URL if service ticket is not present in the request" in new Context {
implicit val req = FakeRequest(GET, redirectURLWithOrigin) implicit val req = FakeRequest(GET, redirectURLWithOrigin)


result(provider.authenticate()) { result(provider.authenticate()) { result =>
case result => status(result) must equalTo(SEE_OTHER)
status(result) must equalTo(SEE_OTHER) redirectLocation(result) must beSome("https://cas-url/?service=https%3A%2F%2Fcas-redirect%2F")
redirectLocation(result) must beSome("https://cas-url/?service=https%3A%2F%2Fcas-redirect%2F")
} }
} }


"return a valid CASAuthInfo object if service ticket is present in request" in new Context { "return a valid CASAuthInfo object if service ticket is present in request" in new Context {
implicit val req = FakeRequest(GET, "/?ticket=%s".format(ticket)) implicit val req = FakeRequest(GET, "/?ticket=%s".format(ticket))


authInfo(provider.authenticate()) { authInfo(provider.authenticate())(authInfo => authInfo must be equalTo CasInfo(ticket))
case authInfo => authInfo must be equalTo CasInfo(ticket)
}
} }
} }


Expand All @@ -111,7 +108,7 @@ class CasProviderSpec extends SocialProviderSpec[CasInfo] with Mockito with Logg


await(futureProfile) must beLike[CommonSocialProfile] { await(futureProfile) must beLike[CommonSocialProfile] {
case profile => case profile =>
profile must be equalTo new CommonSocialProfile( profile must be equalTo CommonSocialProfile(
loginInfo = new LoginInfo(CasProvider.ID, userName), loginInfo = new LoginInfo(CasProvider.ID, userName),
firstName = Some(firstName), firstName = Some(firstName),
lastName = Some(lastName), lastName = Some(lastName),
Expand Down Expand Up @@ -145,7 +142,7 @@ class CasProviderSpec extends SocialProviderSpec[CasInfo] with Mockito with Logg


lazy val ticket = "ST-12345678" lazy val ticket = "ST-12345678"


lazy val casAuthInfo = new CasInfo(ticket) lazy val casAuthInfo = CasInfo(ticket)


lazy val principal = mock[AttributePrincipal].smart lazy val principal = mock[AttributePrincipal].smart


Expand Down
Expand Up @@ -32,12 +32,12 @@ import play.api.libs.json._
import play.api.libs.ws.WSResponse import play.api.libs.ws.WSResponse
import play.api.mvc._ import play.api.mvc._


import scala.concurrent.{ ExecutionContext, Future } import scala.concurrent.Future
import scala.reflect.ClassTag import scala.reflect.ClassTag
import scala.util.{ Failure, Success, Try } import scala.util.{ Failure, Success, Try }


/** /**
* The Oauth2 info. * The OAuth2 info.
* *
* @param accessToken The access token. * @param accessToken The access token.
* @param tokenType The token type. * @param tokenType The token type.
Expand Down Expand Up @@ -145,7 +145,7 @@ trait OAuth2Provider extends SocialStateProvider with OAuth2Constants with Logge
} }


/** /**
* Handles the Oauth2 flow. * Handles the OAuth2 flow.
* *
* The left flow is the authorization flow, which will be processed, if no `code` parameter exists * The left flow is the authorization flow, which will be processed, if no `code` parameter exists
* in the request. The right flow is the access token flow, which will be executed after a successful * in the request. The right flow is the access token flow, which will be executed after a successful
Expand Down Expand Up @@ -289,74 +289,6 @@ trait OAuth2Constants {
val AccessDenied = "access_denied" val AccessDenied = "access_denied"
} }


/**
* The OAuth2 state.
*
* This is to prevent the client for CSRF attacks as described in the OAuth2 RFC.
*
* @see https://tools.ietf.org/html/rfc6749#section-10.12
*/
trait OAuth2State {

/**
* Checks if the state is expired. This is an absolute timeout since the creation of
* the state.
*
* @return True if the state is expired, false otherwise.
*/
def isExpired: Boolean
}

/**
* Provides state for authentication providers.
*/
trait OAuth2StateProvider {

/**
* The type of the state implementation.
*/
type State <: OAuth2State

/**
* Builds the state.
*
* @param request The current request.
* @param ec The execution context to handle the asynchronous operations.
* @tparam B The type of the request body.
* @return The build state.
*/
def build[B](implicit request: ExtractableRequest[B], ec: ExecutionContext): Future[State]

/**
* Validates the provider and the client state.
*
* @param request The current request.
* @param ec The execution context to handle the asynchronous operations.
* @tparam B The type of the request body.
* @return The state on success, otherwise an failure.
*/
def validate[B](implicit request: ExtractableRequest[B], ec: ExecutionContext): Future[State]

/**
* Publishes the state to the client.
*
* @param result The result to send to the client.
* @param state The state to publish.
* @param request The current request.
* @tparam B The type of the request body.
* @return The result to send to the client.
*/
def publish[B](result: Result, state: State)(implicit request: ExtractableRequest[B]): Result

/**
* Returns a serialized value of the state.
*
* @param state The state to serialize.
* @return A serialized value of the state.
*/
def serialize(state: State): String
}

/** /**
* The OAuth2 settings. * The OAuth2 settings.
* *
Expand All @@ -383,4 +315,5 @@ case class OAuth2Settings(
scope: Option[String] = None, scope: Option[String] = None,
authorizationParams: Map[String, String] = Map.empty, authorizationParams: Map[String, String] = Map.empty,
accessTokenParams: Map[String, String] = Map.empty, accessTokenParams: Map[String, String] = Map.empty,
customProperties: Map[String, String] = Map.empty) customProperties: Map[String, String] = Map.empty
)
@@ -1,5 +1,5 @@
/** /**
* Copyright 2016 Mohiva Organisation (license at mohiva dot com) * Copyright 2017 Mohiva Organisation (license at mohiva dot com)
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,9 @@ package com.mohiva.play.silhouette.impl.providers


import com.mohiva.play.silhouette.api.AuthInfo import com.mohiva.play.silhouette.api.AuthInfo
import com.mohiva.play.silhouette.api.crypto.{ Base64, CookieSigner } import com.mohiva.play.silhouette.api.crypto.{ Base64, CookieSigner }
import com.mohiva.play.silhouette.api.exceptions.ProviderException
import com.mohiva.play.silhouette.api.util.ExtractableRequest import com.mohiva.play.silhouette.api.util.ExtractableRequest
import com.mohiva.play.silhouette.impl.providers.DefaultSocialStateHandler._
import com.mohiva.play.silhouette.impl.providers.SocialStateItem._ import com.mohiva.play.silhouette.impl.providers.SocialStateItem._
import play.api.libs.json.{ Format, JsValue, Json } import play.api.libs.json.{ Format, JsValue, Json }
import play.api.mvc.Result import play.api.mvc.Result
Expand Down Expand Up @@ -71,7 +73,7 @@ object SocialStateItem {
* *
* @return The serialized representation of the item. * @return The serialized representation of the item.
*/ */
override def toString = s"${Base64.encode(id)}-${Base64.encode(data)}" def asString = s"${Base64.encode(id)}-${Base64.encode(data)}"
} }


/** /**
Expand All @@ -93,7 +95,6 @@ object SocialStateItem {
} }
} }
} }

} }


/** /**
Expand All @@ -108,7 +109,7 @@ trait SocialStateProvider extends SocialProvider {
* sends to the browser (e.g.: in the case of OAuth where the user needs to be redirected to the service * sends to the browser (e.g.: in the case of OAuth where the user needs to be redirected to the service
* provider). * provider).
* *
* @param format The JSON format to the transform the user state into JSON. * @param format The JSON format to transform the user state into JSON.
* @param request The request. * @param request The request.
* @param classTag The class tag for the user state item. * @param classTag The class tag for the user state item.
* @tparam S The type of the user state item. * @tparam S The type of the user state item.
Expand Down Expand Up @@ -205,7 +206,8 @@ trait SocialStateHandler {
* *
* @param handlers The item handlers configured for this handler. * @param handlers The item handlers configured for this handler.
*/ */
class DefaultSocialStateHandler(val handlers: Set[SocialStateItemHandler], cookieSigner: CookieSigner) extends SocialStateHandler { class DefaultSocialStateHandler(val handlers: Set[SocialStateItemHandler], cookieSigner: CookieSigner)
extends SocialStateHandler {


/** /**
* The concrete instance of the state provider. * The concrete instance of the state provider.
Expand Down Expand Up @@ -259,20 +261,24 @@ class DefaultSocialStateHandler(val handlers: Set[SocialStateItemHandler], cooki
override def unserialize[B](state: String)( override def unserialize[B](state: String)(
implicit implicit
request: ExtractableRequest[B], request: ExtractableRequest[B],
ec: ExecutionContext): Future[SocialState] = { ec: ExecutionContext

): Future[SocialState] = {
Future.fromTry(cookieSigner.extract(state)).flatMap(state => state.split('.').toList match { Future.fromTry(cookieSigner.extract(state)).flatMap { state =>
case Nil => Future.successful(SocialState(Set())) state.split('.').toList match {
case items => case Nil | List("") =>
Future.sequence(items.map { Future.successful(SocialState(Set()))
case ItemStructure(item) => handlers.find(_.canHandle(item)) match { case items =>
case Some(handler) => handler.unserialize(item) Future.sequence {
case None => items.map {
throw new RuntimeException("None of the registered handlers can handle the given state item:" + item) case ItemStructure(item) => handlers.find(_.canHandle(item)) match {
} case Some(handler) => handler.unserialize(item)
case s => throw new RuntimeException("Cannot extract social state item from string: " + s) case None => throw new ProviderException(MissingItemHandlerError.format(item))
}).map(items => SocialState(items.toSet)) }
}) case item => throw new ProviderException(ItemExtractionError.format(item))
}
}.map(items => SocialState(items.toSet))
}
}
} }


/** /**
Expand All @@ -294,6 +300,19 @@ class DefaultSocialStateHandler(val handlers: Set[SocialStateItemHandler], cooki
} }
} }


/**
* The companion object for the [[DefaultSocialStateHandler]] class.
*/
object DefaultSocialStateHandler {

/**
* Some errors.
*/
val MissingItemHandlerError = "None of the registered handlers can handle the given state item: %s"
val ItemExtractionError = "Cannot extract social state item from string: %s"

}

/** /**
* Handles state for different purposes. * Handles state for different purposes.
*/ */
Expand Down

0 comments on commit f725261

Please sign in to comment.