From fab982b5bfee14c6768ef24a9763d61f26680884 Mon Sep 17 00:00:00 2001 From: Marcos Pereira Date: Tue, 24 Jan 2017 23:35:14 -0200 Subject: [PATCH] Removes global state from CSRF components (#6854) --- .../play/filters/csrf/AddCSRFTokenAction.java | 15 +- .../filters/csrf/RequireCSRFCheckAction.java | 22 +-- .../scala/play/filters/csrf/CSRFActions.scala | 131 ++++++++++-------- .../scala/play/filters/csrf/CSRFFilter.scala | 12 +- .../main/scala/play/filters/csrf/csrf.scala | 9 +- .../play/filters/cors/CORSWithCSRFSpec.scala | 8 +- 6 files changed, 114 insertions(+), 83 deletions(-) diff --git a/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/AddCSRFTokenAction.java b/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/AddCSRFTokenAction.java index 8cd4e69bd26..3dc3fd10c04 100644 --- a/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/AddCSRFTokenAction.java +++ b/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/AddCSRFTokenAction.java @@ -7,6 +7,7 @@ import javax.inject.Inject; +import play.api.http.SessionConfiguration; import play.api.libs.crypto.CSRFTokenSigner; import play.api.mvc.Session; import play.mvc.Action; @@ -19,25 +20,29 @@ public class AddCSRFTokenAction extends Action { private final CSRFConfig config; + private final SessionConfiguration sessionConfiguration; private final CSRF.TokenProvider tokenProvider; private final CSRFTokenSigner tokenSigner; @Inject - public AddCSRFTokenAction(CSRFConfig config, CSRF.TokenProvider tokenProvider, CSRFTokenSigner tokenSigner) { + public AddCSRFTokenAction(CSRFConfig config, SessionConfiguration sessionConfiguration, CSRF.TokenProvider tokenProvider, CSRFTokenSigner tokenSigner) { this.config = config; + this.sessionConfiguration = sessionConfiguration; this.tokenProvider = tokenProvider; this.tokenSigner = tokenSigner; } private final CSRF.Token$ Token = CSRF.Token$.MODULE$; - private final CSRFAction$ CSRFAction = CSRFAction$.MODULE$; @Override public CompletionStage call(Http.Context ctx) { + + CSRFActionHelper csrfActionHelper = new CSRFActionHelper(sessionConfiguration, config, tokenSigner); + play.api.mvc.Request request = - CSRFAction.tagRequestFromHeader(ctx.request()._underlyingRequest(), config, tokenSigner); + csrfActionHelper.tagRequestFromHeader(ctx.request()._underlyingRequest()); - if (CSRFAction.getTokenToValidate(request, config, tokenSigner).isEmpty()) { + if (csrfActionHelper.getTokenToValidate(request).isEmpty()) { // No token in header and we have to create one if not found, so create a new token String newToken = tokenProvider.generateToken(); @@ -46,7 +51,7 @@ public CompletionStage call(Http.Context ctx) { ctx.args.put(Token.NameRequestTag(), config.tokenName()); // Create a new Scala RequestHeader with the token - request = CSRFAction.tagRequest(request, new CSRF.Token(config.tokenName(), newToken)); + request = csrfActionHelper.tagRequest(request, new CSRF.Token(config.tokenName(), newToken)); // Also add it to the response if (config.cookieName().isDefined()) { diff --git a/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/RequireCSRFCheckAction.java b/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/RequireCSRFCheckAction.java index c0c0a726578..6cdfae18cdc 100644 --- a/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/RequireCSRFCheckAction.java +++ b/framework/src/play-filters-helpers/src/main/java/play/filters/csrf/RequireCSRFCheckAction.java @@ -8,6 +8,7 @@ import javax.inject.Inject; +import play.api.http.SessionConfiguration; import play.api.libs.crypto.CSRFTokenSigner; import play.api.mvc.RequestHeader; import play.api.mvc.Session; @@ -20,34 +21,37 @@ public class RequireCSRFCheckAction extends Action { private final CSRFConfig config; + private final SessionConfiguration sessionConfiguration; private final CSRF.TokenProvider tokenProvider; - private final CSRFTokenSigner crypto; + private final CSRFTokenSigner tokenSigner; private final Injector injector; @Inject - public RequireCSRFCheckAction(CSRFConfig config, CSRF.TokenProvider tokenProvider, CSRFTokenSigner csrfTokenSigner, Injector injector) { + public RequireCSRFCheckAction(CSRFConfig config, SessionConfiguration sessionConfiguration, CSRF.TokenProvider tokenProvider, CSRFTokenSigner csrfTokenSigner, Injector injector) { this.config = config; + this.sessionConfiguration = sessionConfiguration; this.tokenProvider = tokenProvider; - this.crypto = csrfTokenSigner; + this.tokenSigner = csrfTokenSigner; this.injector = injector; } - private final CSRFAction$ CSRFAction = CSRFAction$.MODULE$; - @Override public CompletionStage call(Http.Context ctx) { - RequestHeader request = CSRFAction.tagRequestFromHeader(ctx._requestHeader(), config, crypto); + + CSRFActionHelper csrfActionHelper = new CSRFActionHelper(sessionConfiguration, config, tokenSigner); + + RequestHeader request = csrfActionHelper.tagRequestFromHeader(ctx._requestHeader()); // Check for bypass - if (!CSRFAction.requiresCsrfCheck(request, config)) { + if (!csrfActionHelper.requiresCsrfCheck(request)) { return delegate.call(ctx); } else { // Get token from cookie/session - Option headerToken = CSRFAction.getTokenToValidate(request, config, crypto); + Option headerToken = csrfActionHelper.getTokenToValidate(request); if (headerToken.isDefined()) { String tokenToCheck = null; // Get token from query string - Option queryStringToken = CSRFAction.getHeaderToken(request, config); + Option queryStringToken = csrfActionHelper.getHeaderToken(request); if (queryStringToken.isDefined()) { tokenToCheck = queryStringToken.get(); } else { diff --git a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFActions.scala b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFActions.scala index 49c4614a6cb..e10a0fb71aa 100644 --- a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFActions.scala +++ b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFActions.scala @@ -12,6 +12,7 @@ import akka.stream.scaladsl.{ Flow, Keep, Sink, Source } import akka.stream.stage.{ DetachedContext, DetachedStage } import akka.util.ByteString import play.api.http.HeaderNames._ +import play.api.http.SessionConfiguration import play.api.libs.crypto.CSRFTokenSigner import play.api.libs.streams.Accumulator import play.api.mvc._ @@ -35,16 +36,18 @@ class CSRFAction( config: CSRFConfig = CSRFConfig(), tokenSigner: CSRFTokenSigner, tokenProvider: TokenProvider, + sessionConfiguration: SessionConfiguration, errorHandler: => ErrorHandler = CSRF.DefaultErrorHandler)(implicit mat: Materializer) extends EssentialAction { - import CSRFAction._ import play.core.Execution.Implicits.trampoline + lazy val csrfActionHelper = new CSRFActionHelper(sessionConfiguration, config, tokenSigner) + private def checkFailed(req: RequestHeader, msg: String): Accumulator[ByteString, Result] = - Accumulator.done(clearTokenIfInvalid(req, config, errorHandler, msg)) + Accumulator.done(csrfActionHelper.clearTokenIfInvalid(req, errorHandler, msg)) def apply(untaggedRequest: RequestHeader) = { - val request = tagRequestFromHeader(untaggedRequest, config, tokenSigner) + val request = csrfActionHelper.tagRequestFromHeader(untaggedRequest) // this function exists purely to aid readability def continue = next(request) @@ -52,15 +55,15 @@ class CSRFAction( // Only filter unsafe methods and content types if (config.checkMethod(request.method) && config.checkContentType(request.contentType)) { - if (!requiresCsrfCheck(request, config)) { + if (!csrfActionHelper.requiresCsrfCheck(request)) { continue } else { // Only proceed with checks if there is an incoming token in the header, otherwise there's no point - getTokenToValidate(request, config, tokenSigner).map { headerToken => + csrfActionHelper.getTokenToValidate(request).map { headerToken => // First check if there's a token in the query string or header, if we find one, don't bother handling the body - getHeaderToken(request, config).map { queryStringToken => + csrfActionHelper.getHeaderToken(request).map { queryStringToken => if (tokenProvider.compareTokens(headerToken, queryStringToken)) { filterLogger.trace("[CSRF] Valid token found in query string") @@ -95,17 +98,17 @@ class CSRFAction( } } - } else if (getTokenToValidate(request, config, tokenSigner).isEmpty && config.createIfNotFound(request)) { + } else if (csrfActionHelper.getTokenToValidate(request).isEmpty && config.createIfNotFound(request)) { // No token in header and we have to create one if not found, so create a new token val newToken = tokenProvider.generateToken // The request - val requestWithNewToken = tagRequest(request, Token(config.tokenName, newToken)) + val requestWithNewToken = csrfActionHelper.tagRequest(request, Token(config.tokenName, newToken)) // Once done, add it to the result next(requestWithNewToken).map(result => - CSRFAction.addTokenToResponse(config, newToken, request, result)) + csrfActionHelper.addTokenToResponse(newToken, request, result)) } else { filterLogger.trace("[CSRF] No check necessary") @@ -150,7 +153,7 @@ class CSRFAction( ).mapFuture { validatedBodySource => action(request).run(validatedBodySource) }.recoverWith { - case NoTokenInBody => clearTokenIfInvalid(request, config, errorHandler, "No CSRF token found in body") + case NoTokenInBody => csrfActionHelper.clearTokenIfInvalid(request, errorHandler, "No CSRF token found in body") } } @@ -302,7 +305,7 @@ private class BodyHandler(config: CSRFConfig, checkBody: ByteString => Boolean) } } else { // CSRF check failed - ctx.fail(CSRFAction.NoTokenInBody) + ctx.fail(NoTokenInBody) } } else { // Buffer @@ -353,37 +356,41 @@ private class BodyHandler(config: CSRFConfig, checkBody: ByteString => Boolean) continue = true ctx.absorbTermination() } else { - ctx.fail(CSRFAction.NoTokenInBody) + ctx.fail(NoTokenInBody) } } } } -object CSRFAction { +private[csrf] object NoTokenInBody extends RuntimeException(null, null, false, false) - private[csrf] object NoTokenInBody extends RuntimeException(null, null, false, false) +private[csrf] class CSRFActionHelper( + sessionConfiguration: SessionConfiguration, + csrfConfig: CSRFConfig, + tokenSigner: CSRFTokenSigner +) { /** * Get the header token, that is, the token that should be validated. */ - private[csrf] def getTokenToValidate(request: RequestHeader, config: CSRFConfig, tokenSigner: CSRFTokenSigner) = { + private[csrf] def getTokenToValidate(request: RequestHeader) = { val tagToken = request.tags.get(Token.RequestTag) - val cookieToken = config.cookieName.flatMap(cookie => request.cookies.get(cookie).map(_.value)) - val sessionToken = request.session.get(config.tokenName) + val cookieToken = csrfConfig.cookieName.flatMap(cookie => request.cookies.get(cookie).map(_.value)) + val sessionToken = request.session.get(csrfConfig.tokenName) cookieToken orElse sessionToken orElse tagToken filter { token => // return None if the token is invalid - !config.signTokens || tokenSigner.extractSignedToken(token).isDefined + !csrfConfig.signTokens || tokenSigner.extractSignedToken(token).isDefined } } /** * Tag incoming requests with the token in the header */ - private[csrf] def tagRequestFromHeader(request: RequestHeader, config: CSRFConfig, tokenSigner: CSRFTokenSigner): RequestHeader = { - getTokenToValidate(request, config, tokenSigner).fold(request) { tokenValue => - val token = Token(config.tokenName, tokenValue) + private[csrf] def tagRequestFromHeader(request: RequestHeader): RequestHeader = { + getTokenToValidate(request).fold(request) { tokenValue => + val token = Token(csrfConfig.tokenName, tokenValue) val newReq = tagRequest(request, token) - if (config.signTokens) { + if (csrfConfig.signTokens) { // Extract the signed token, and then resign it. This makes the token random per request, preventing the BREACH // vulnerability val newTokenValue = tokenSigner.extractSignedToken(token.value).map(tokenSigner.signToken) @@ -394,8 +401,8 @@ object CSRFAction { } } - private[csrf] def tagRequestFromHeader[A](request: Request[A], config: CSRFConfig, tokenSigner: CSRFTokenSigner): Request[A] = { - Request(tagRequestFromHeader(request: RequestHeader, config, tokenSigner), request.body) + private[csrf] def tagRequestFromHeader[A](request: Request[A]): Request[A] = { + Request(tagRequestFromHeader(request: RequestHeader), request.body) } private[csrf] def tagRequest(request: RequestHeader, token: Token): RequestHeader = { @@ -409,37 +416,37 @@ object CSRFAction { Request(tagRequest(request: RequestHeader, token), request.body) } - private[csrf] def getHeaderToken(request: RequestHeader, config: CSRFConfig) = { - val queryStringToken = request.getQueryString(config.tokenName) - val headerToken = request.headers.get(config.headerName) + private[csrf] def getHeaderToken(request: RequestHeader) = { + val queryStringToken = request.getQueryString(csrfConfig.tokenName) + val headerToken = request.headers.get(csrfConfig.headerName) queryStringToken orElse headerToken } - private[csrf] def requiresCsrfCheck(request: RequestHeader, config: CSRFConfig): Boolean = { - if (config.bypassCorsTrustedOrigins && request.tags.contains(CORSFilter.RequestTag)) { + private[csrf] def requiresCsrfCheck(request: RequestHeader): Boolean = { + if (csrfConfig.bypassCorsTrustedOrigins && request.tags.contains(CORSFilter.RequestTag)) { filterLogger.trace("[CSRF] Bypassing check because CORSFilter request tag found") false } else { - config.shouldProtect(request) + csrfConfig.shouldProtect(request) } } - private[csrf] def addTokenToResponse(config: CSRFConfig, newToken: String, request: RequestHeader, result: Result) = { + private[csrf] def addTokenToResponse(newToken: String, request: RequestHeader, result: Result) = { if (isCached(result)) { filterLogger.trace("[CSRF] Not adding token to cached response") result } else { filterLogger.trace("[CSRF] Adding token to result: " + result) - config.cookieName.map { + csrfConfig.cookieName.map { // cookie name => - result.withCookies(Cookie(name, newToken, path = Session.path, domain = Session.domain, - secure = config.secureCookie, httpOnly = config.httpOnlyCookie)) + result.withCookies(Cookie(name, newToken, path = sessionConfiguration.path, domain = sessionConfiguration.domain, + secure = csrfConfig.secureCookie, httpOnly = csrfConfig.httpOnlyCookie)) } getOrElse { - val newSession = result.session(request) + (config.tokenName -> newToken) + val newSession = result.session(request) + (csrfConfig.tokenName -> newToken) result.withSession(newSession) } } @@ -449,18 +456,18 @@ object CSRFAction { private[csrf] def isCached(result: Result): Boolean = result.header.headers.get(CACHE_CONTROL).fold(false)(!_.contains("no-cache")) - private[csrf] def clearTokenIfInvalid(request: RequestHeader, config: CSRFConfig, errorHandler: ErrorHandler, msg: String): Future[Result] = { + private[csrf] def clearTokenIfInvalid(request: RequestHeader, errorHandler: ErrorHandler, msg: String): Future[Result] = { import play.core.Execution.Implicits.trampoline errorHandler.handle(request, msg) map { result => CSRF.getToken(request).fold( - config.cookieName.flatMap { cookie => + csrfConfig.cookieName.flatMap { cookie => request.cookies.get(cookie).map { token => result.discardingCookies( - DiscardingCookie(cookie, domain = Session.domain, path = Session.path, secure = config.secureCookie)) + DiscardingCookie(cookie, domain = sessionConfiguration.domain, path = sessionConfiguration.path, secure = csrfConfig.secureCookie)) } }.getOrElse { - result.withSession(result.session(request) - config.tokenName) + result.withSession(result.session(request) - csrfConfig.tokenName) } )(_ => result) } @@ -472,22 +479,27 @@ object CSRFAction { * * Apply this to all actions that require a CSRF check. */ -case class CSRFCheck @Inject() (config: CSRFConfig, tokenSigner: CSRFTokenSigner) { - - private class CSRFCheckAction[A](tokenProvider: TokenProvider, errorHandler: ErrorHandler, wrapped: Action[A]) extends Action[A] { +case class CSRFCheck @Inject() (config: CSRFConfig, tokenSigner: CSRFTokenSigner, sessionConfiguration: SessionConfiguration) { + + private class CSRFCheckAction[A]( + tokenProvider: TokenProvider, + errorHandler: ErrorHandler, + wrapped: Action[A], + csrfActionHelper: CSRFActionHelper + ) extends Action[A] { def parser = wrapped.parser def executionContext = wrapped.executionContext def apply(untaggedRequest: Request[A]) = { - val request = CSRFAction.tagRequestFromHeader(untaggedRequest, config, tokenSigner) + val request = csrfActionHelper.tagRequestFromHeader(untaggedRequest) // Maybe bypass - if (!CSRFAction.requiresCsrfCheck(request, config) || !config.checkContentType(request.contentType)) { + if (!csrfActionHelper.requiresCsrfCheck(request) || !config.checkContentType(request.contentType)) { wrapped(request) } else { // Get token from header - CSRFAction.getTokenToValidate(request, config, tokenSigner).flatMap { headerToken => + csrfActionHelper.getTokenToValidate(request).flatMap { headerToken => // Get token from query string - CSRFAction.getHeaderToken(request, config) + csrfActionHelper.getHeaderToken(request) // Or from body if not found .orElse({ val form = request.body match { @@ -504,7 +516,7 @@ case class CSRFCheck @Inject() (config: CSRFConfig, tokenSigner: CSRFTokenSigner case queryToken if tokenProvider.compareTokens(queryToken, headerToken) => wrapped(request) } }.getOrElse { - CSRFAction.clearTokenIfInvalid(request, config, errorHandler, "CSRF token check failed") + csrfActionHelper.clearTokenIfInvalid(request, errorHandler, "CSRF token check failed") } } } @@ -514,13 +526,13 @@ case class CSRFCheck @Inject() (config: CSRFConfig, tokenSigner: CSRFTokenSigner * Wrap an action in a CSRF check. */ def apply[A](action: Action[A], errorHandler: ErrorHandler): Action[A] = - new CSRFCheckAction(new TokenProviderProvider(config, tokenSigner).get, errorHandler, action) + new CSRFCheckAction(new TokenProviderProvider(config, tokenSigner).get, errorHandler, action, new CSRFActionHelper(sessionConfiguration, config, tokenSigner)) /** * Wrap an action in a CSRF check. */ def apply[A](action: Action[A]): Action[A] = - new CSRFCheckAction(new TokenProviderProvider(config, tokenSigner).get, CSRF.DefaultErrorHandler, action) + new CSRFCheckAction(new TokenProviderProvider(config, tokenSigner).get, CSRF.DefaultErrorHandler, action, new CSRFActionHelper(sessionConfiguration, config, tokenSigner)) } /** @@ -528,25 +540,30 @@ case class CSRFCheck @Inject() (config: CSRFConfig, tokenSigner: CSRFTokenSigner * * Apply this to all actions that render a form that contains a CSRF token. */ -case class CSRFAddToken @Inject() (config: CSRFConfig, crypto: CSRFTokenSigner) { - - private class CSRFAddTokenAction[A](config: CSRFConfig, tokenProvider: TokenProvider, wrapped: Action[A]) extends Action[A] { +case class CSRFAddToken @Inject() (config: CSRFConfig, crypto: CSRFTokenSigner, sessionConfiguration: SessionConfiguration) { + + private class CSRFAddTokenAction[A]( + config: CSRFConfig, + tokenProvider: TokenProvider, + wrapped: Action[A], + csrfActionHelper: CSRFActionHelper + ) extends Action[A] { def parser = wrapped.parser def executionContext = wrapped.executionContext def apply(untaggedRequest: Request[A]) = { - val request = CSRFAction.tagRequestFromHeader(untaggedRequest, config, crypto) + val request = csrfActionHelper.tagRequestFromHeader(untaggedRequest) - if (CSRFAction.getTokenToValidate(request, config, crypto).isEmpty) { + if (csrfActionHelper.getTokenToValidate(request).isEmpty) { // No token in header and we have to create one if not found, so create a new token val newToken = tokenProvider.generateToken // The request - val requestWithNewToken = CSRFAction.tagRequest(request, Token(config.tokenName, newToken)) + val requestWithNewToken = csrfActionHelper.tagRequest(request, Token(config.tokenName, newToken)) // Once done, add it to the result import play.core.Execution.Implicits.trampoline wrapped(requestWithNewToken).map(result => - CSRFAction.addTokenToResponse(config, newToken, request, result)) + csrfActionHelper.addTokenToResponse(newToken, request, result)) } else { wrapped(request) } @@ -557,5 +574,5 @@ case class CSRFAddToken @Inject() (config: CSRFConfig, crypto: CSRFTokenSigner) * Wrap an action in an action that ensures there is a CSRF token. */ def apply[A](action: Action[A]): Action[A] = - new CSRFAddTokenAction(config, new TokenProviderProvider(config, crypto).get, action) + new CSRFAddTokenAction(config, new TokenProviderProvider(config, crypto).get, action, new CSRFActionHelper(sessionConfiguration, config, crypto)) } diff --git a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFFilter.scala b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFFilter.scala index d43a5e94171..01ddd679167 100644 --- a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFFilter.scala +++ b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/CSRFFilter.scala @@ -6,6 +6,7 @@ package play.filters.csrf import javax.inject.{ Inject, Provider } import akka.stream.Materializer +import play.api.http.SessionConfiguration import play.api.libs.crypto.CSRFTokenSigner import play.api.mvc._ import play.core.j.JavaContextComponents @@ -26,19 +27,20 @@ import play.filters.csrf.CSRF._ class CSRFFilter( config: => CSRFConfig, tokenSigner: => CSRFTokenSigner, + sessionConfiguration: => SessionConfiguration, val tokenProvider: TokenProvider, val errorHandler: ErrorHandler = CSRF.DefaultErrorHandler)(implicit mat: Materializer) extends EssentialFilter { @Inject - def this(config: Provider[CSRFConfig], tokenSignerProvider: Provider[CSRFTokenSigner], tokenProvider: TokenProvider, errorHandler: ErrorHandler)(mat: Materializer) = { - this(config.get, tokenSignerProvider.get, tokenProvider, errorHandler)(mat) + def this(config: Provider[CSRFConfig], tokenSignerProvider: Provider[CSRFTokenSigner], sessionConfiguration: SessionConfiguration, tokenProvider: TokenProvider, errorHandler: ErrorHandler)(mat: Materializer) = { + this(config.get, tokenSignerProvider.get, sessionConfiguration, tokenProvider, errorHandler)(mat) } // Java constructor for manually constructing the filter - def this(config: CSRFConfig, tokenSigner: play.libs.crypto.CSRFTokenSigner, tokenProvider: TokenProvider, errorHandler: CSRFErrorHandler, contextComponents: JavaContextComponents)(mat: Materializer) = { - this(config, tokenSigner.asScala, tokenProvider, new JavaCSRFErrorHandlerAdapter(errorHandler, contextComponents))(mat) + def this(config: CSRFConfig, tokenSigner: play.libs.crypto.CSRFTokenSigner, sessionConfiguration: SessionConfiguration, tokenProvider: TokenProvider, errorHandler: CSRFErrorHandler, contextComponents: JavaContextComponents)(mat: Materializer) = { + this(config, tokenSigner.asScala, sessionConfiguration, tokenProvider, new JavaCSRFErrorHandlerAdapter(errorHandler, contextComponents))(mat) } - def apply(next: EssentialAction): EssentialAction = new CSRFAction(next, config, tokenSigner, tokenProvider, errorHandler) + def apply(next: EssentialAction): EssentialAction = new CSRFAction(next, config, tokenSigner, tokenProvider, sessionConfiguration, errorHandler) } diff --git a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/csrf.scala b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/csrf.scala index 60c1505454a..7de9f23f73a 100644 --- a/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/csrf.scala +++ b/framework/src/play-filters-helpers/src/main/scala/play/filters/csrf/csrf.scala @@ -9,7 +9,7 @@ import javax.inject.{ Inject, Provider, Singleton } import akka.stream.Materializer import com.typesafe.config.ConfigMemorySize import play.api._ -import play.api.http.HttpErrorHandler +import play.api.http.{ HttpConfiguration, HttpErrorHandler } import play.api.inject.{ Binding, Module } import play.api.libs.crypto.CSRFTokenSigner import play.api.mvc.Results._ @@ -282,13 +282,14 @@ trait CSRFComponents { def configuration: Configuration def csrfTokenSigner: CSRFTokenSigner def httpErrorHandler: HttpErrorHandler + def httpConfiguration: HttpConfiguration implicit def materializer: Materializer lazy val csrfConfig: CSRFConfig = CSRFConfig.fromConfiguration(configuration) lazy val csrfTokenProvider: CSRF.TokenProvider = new CSRF.TokenProviderProvider(csrfConfig, csrfTokenSigner).get lazy val csrfErrorHandler: CSRF.ErrorHandler = new CSRFHttpErrorHandler(httpErrorHandler) - lazy val csrfFilter: CSRFFilter = new CSRFFilter(csrfConfig, csrfTokenSigner, csrfTokenProvider, csrfErrorHandler) - lazy val csrfCheck: CSRFCheck = new CSRFCheck(csrfConfig, csrfTokenSigner) - lazy val csrfAddToken: CSRFAddToken = new CSRFAddToken(csrfConfig, csrfTokenSigner) + lazy val csrfFilter: CSRFFilter = new CSRFFilter(csrfConfig, csrfTokenSigner, httpConfiguration.session, csrfTokenProvider, csrfErrorHandler) + lazy val csrfCheck: CSRFCheck = CSRFCheck(csrfConfig, csrfTokenSigner, httpConfiguration.session) + lazy val csrfAddToken: CSRFAddToken = CSRFAddToken(csrfConfig, csrfTokenSigner, httpConfiguration.session) } diff --git a/framework/src/play-filters-helpers/src/test/scala/play/filters/cors/CORSWithCSRFSpec.scala b/framework/src/play-filters-helpers/src/test/scala/play/filters/cors/CORSWithCSRFSpec.scala index 73118f6f318..e8db6b98ad6 100644 --- a/framework/src/play-filters-helpers/src/test/scala/play/filters/cors/CORSWithCSRFSpec.scala +++ b/framework/src/play-filters-helpers/src/test/scala/play/filters/cors/CORSWithCSRFSpec.scala @@ -7,7 +7,7 @@ import java.time.{ Clock, Instant, ZoneId } import javax.inject.Inject import play.api.Application -import play.api.http.{ ContentTypes, HttpFilters, SecretConfiguration } +import play.api.http.{ ContentTypes, HttpFilters, SecretConfiguration, SessionConfiguration } import play.api.inject.bind import play.api.libs.crypto.{ DefaultCSRFTokenSigner, DefaultCookieSigner } import play.api.mvc.{ DefaultActionBuilder, Results } @@ -27,17 +27,19 @@ object CORSWithCSRFSpec { } class CORSWithCSRFRouter @Inject() (action: DefaultActionBuilder) extends Router { - val signer = { + private val signer = { val secretConfiguration = SecretConfiguration("0123456789abcdef", None) val clock = Clock.fixed(Instant.ofEpochMilli(0L), ZoneId.systemDefault) val signer = new DefaultCookieSigner(secretConfiguration) new DefaultCSRFTokenSigner(signer, clock) } + private val sessionConfiguration = SessionConfiguration() + override def routes = { case p"/error" => action { req => throw sys.error("error") } case _ => - val csrfCheck = new CSRFCheck(play.filters.csrf.CSRFConfig(), signer) + val csrfCheck = CSRFCheck(play.filters.csrf.CSRFConfig(), signer, sessionConfiguration) csrfCheck(action(Results.Ok), CSRF.DefaultErrorHandler) } override def withPrefix(prefix: String) = this