From 6ee2da308ddf4259717eda98ea7de850cb015625 Mon Sep 17 00:00:00 2001 From: Vedran Pavic Date: Thu, 19 Dec 2019 21:46:08 +0100 Subject: [PATCH] Improve error handling --- .../BearerAuthenticationHandler.java | 2 +- .../vpavic/bearerauth/BearerTokenError.java | 22 +++++---- .../bearerauth/BearerTokenException.java | 13 +++++ .../bearerauth/WwwAuthenticateBuilder.java | 48 +++++++++++++++++++ .../WwwAuthenticateBuilderTests.java | 5 ++ .../ServletBearerAuthenticationFilter.java | 18 ++++++- .../WebFluxBearerAuthenticationFilter.java | 14 +++++- 7 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 bearerauth-core/src/main/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilder.java create mode 100644 bearerauth-core/src/test/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilderTests.java diff --git a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerAuthenticationHandler.java b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerAuthenticationHandler.java index 487a297..eea0938 100644 --- a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerAuthenticationHandler.java +++ b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerAuthenticationHandler.java @@ -44,7 +44,7 @@ public CompletionStage handle(HttpExchange httpExchange) { BearerToken bearerToken = this.bearerTokenExtractor.apply(httpExchange); if (bearerToken == null) { CompletableFuture result = new CompletableFuture<>(); - result.completeExceptionally(new BearerTokenException(BearerTokenError.INVALID_REQUEST)); + result.completeExceptionally(new BearerTokenException()); return result; } return this.authorizationContextResolver.apply(bearerToken).handle((authorizationContext, throwable) -> { diff --git a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenError.java b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenError.java index f303f58..2b29051 100644 --- a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenError.java +++ b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenError.java @@ -4,28 +4,32 @@ public class BearerTokenError { - public static final BearerTokenError INVALID_REQUEST = new BearerTokenError(400, "invalid_request"); + public static final BearerTokenError INVALID_REQUEST = BearerTokenError.of("invalid_request", 400); - public static final BearerTokenError INVALID_TOKEN = new BearerTokenError(401, "invalid_token"); + public static final BearerTokenError INVALID_TOKEN = BearerTokenError.of("invalid_token", 401); - public static final BearerTokenError INSUFFICIENT_SCOPE = new BearerTokenError(403, "insufficient_scope"); - - private final int status; + public static final BearerTokenError INSUFFICIENT_SCOPE = BearerTokenError.of("insufficient_scope", 403); private final String code; - public BearerTokenError(int status, String code) { + private final int status; + + private BearerTokenError(String code, int status) { Objects.requireNonNull(code, "code must not be null"); - this.status = status; this.code = code; + this.status = status; } - public int getStatus() { - return this.status; + public static BearerTokenError of(String code, int status) { + return new BearerTokenError(code, status); } public String getCode() { return this.code; } + public int getStatus() { + return this.status; + } + } diff --git a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenException.java b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenException.java index c557e4b..bf281d8 100644 --- a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenException.java +++ b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/BearerTokenException.java @@ -6,6 +6,15 @@ public class BearerTokenException extends RuntimeException { private final BearerTokenError error; + public BearerTokenException() { + this.error = null; + } + + public BearerTokenException(String message) { + super(message); + this.error = null; + } + public BearerTokenException(BearerTokenError error, String message) { super(message); Objects.requireNonNull(error, "error must not be null"); @@ -16,6 +25,10 @@ public BearerTokenException(BearerTokenError error) { this(error, error.getCode()); } + public int getStatus() { + return (this.error != null) ? this.error.getStatus() : 401; + } + public BearerTokenError getError() { return this.error; } diff --git a/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilder.java b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilder.java new file mode 100644 index 0000000..5a3eb3a --- /dev/null +++ b/bearerauth-core/src/main/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilder.java @@ -0,0 +1,48 @@ +package io.github.vpavic.bearerauth; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public final class WwwAuthenticateBuilder { + + private final BearerTokenException bearerTokenException; + + private String realm; + + private WwwAuthenticateBuilder(BearerTokenException bearerTokenException) { + Objects.requireNonNull(bearerTokenException, "bearerTokenException must not be null"); + this.bearerTokenException = bearerTokenException; + } + + public static WwwAuthenticateBuilder from(BearerTokenException bearerTokenException) { + return new WwwAuthenticateBuilder(bearerTokenException); + } + + public WwwAuthenticateBuilder withRealm(String realm) { + Objects.requireNonNull(realm, "realm must not be null"); + this.realm = realm; + return this; + } + + public String build() { + String wwwAuthenticate = "Bearer"; + List attributes = new ArrayList<>(); + if (this.realm != null) { + attributes.add(buildAttribute("realm", this.realm)); + } + BearerTokenError error = this.bearerTokenException.getError(); + if (error != null) { + attributes.add(buildAttribute("error", error.getCode())); + } + if (!attributes.isEmpty()) { + wwwAuthenticate += " " + String.join(", ", attributes); + } + return wwwAuthenticate; + } + + private static String buildAttribute(String name, String value) { + return name + "=\"" + value + "\""; + } + +} diff --git a/bearerauth-core/src/test/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilderTests.java b/bearerauth-core/src/test/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilderTests.java new file mode 100644 index 0000000..5c3210e --- /dev/null +++ b/bearerauth-core/src/test/java/io/github/vpavic/bearerauth/WwwAuthenticateBuilderTests.java @@ -0,0 +1,5 @@ +package io.github.vpavic.bearerauth; + +class WwwAuthenticateBuilderTests { + +} diff --git a/samples/sample-spring-servlet/src/main/java/sample/ServletBearerAuthenticationFilter.java b/samples/sample-spring-servlet/src/main/java/sample/ServletBearerAuthenticationFilter.java index fb0931c..05de044 100644 --- a/samples/sample-spring-servlet/src/main/java/sample/ServletBearerAuthenticationFilter.java +++ b/samples/sample-spring-servlet/src/main/java/sample/ServletBearerAuthenticationFilter.java @@ -3,8 +3,10 @@ import io.github.vpavic.bearerauth.AuthorizationContext; import io.github.vpavic.bearerauth.BearerAuthenticationHandler; import io.github.vpavic.bearerauth.BearerToken; +import io.github.vpavic.bearerauth.BearerTokenException; import io.github.vpavic.bearerauth.HttpExchange; import io.github.vpavic.bearerauth.MapAuthorizationContextResolver; +import io.github.vpavic.bearerauth.WwwAuthenticateBuilder; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -39,11 +41,23 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response throws IOException, ServletException { try { this.bearerAuthenticationHandler.handle(new ServletHttpExchange(request)).toCompletableFuture().get(); + chain.doFilter(request, response); } - catch (ExecutionException | InterruptedException ex) { + catch (ExecutionException ex) { + Throwable cause = ex.getCause(); + if (cause instanceof BearerTokenException) { + BearerTokenException bearerTokenException = (BearerTokenException) cause; + String wwwAuthenticate = WwwAuthenticateBuilder.from(bearerTokenException).build(); + response.addHeader("WWW-Authenticate", wwwAuthenticate); + response.sendError(bearerTokenException.getStatus()); + } + else { + throw new ServletException(ex); + } + } + catch (InterruptedException ex) { throw new ServletException(ex); } - chain.doFilter(request, response); } private static class ServletHttpExchange implements HttpExchange { diff --git a/samples/sample-spring-webflux/src/main/java/sample/WebFluxBearerAuthenticationFilter.java b/samples/sample-spring-webflux/src/main/java/sample/WebFluxBearerAuthenticationFilter.java index f62be38..58321b3 100644 --- a/samples/sample-spring-webflux/src/main/java/sample/WebFluxBearerAuthenticationFilter.java +++ b/samples/sample-spring-webflux/src/main/java/sample/WebFluxBearerAuthenticationFilter.java @@ -3,8 +3,13 @@ import io.github.vpavic.bearerauth.AuthorizationContext; import io.github.vpavic.bearerauth.BearerAuthenticationHandler; import io.github.vpavic.bearerauth.BearerToken; +import io.github.vpavic.bearerauth.BearerTokenException; import io.github.vpavic.bearerauth.HttpExchange; import io.github.vpavic.bearerauth.MapAuthorizationContextResolver; +import io.github.vpavic.bearerauth.WwwAuthenticateBuilder; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.stereotype.Component; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; @@ -34,7 +39,14 @@ public WebFluxBearerAuthenticationFilter() { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return Mono.fromCompletionStage(this.bearerAuthenticationHandler.handle(new WebFluxHttpExchange(exchange))) - .then(chain.filter(exchange)); + .then(chain.filter(exchange)) + .onErrorResume(BearerTokenException.class, ex -> { + String wwwAuthenticate = WwwAuthenticateBuilder.from(ex).build(); + ServerHttpResponse response = exchange.getResponse(); + response.getHeaders().set(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticate); + response.setStatusCode(HttpStatus.resolve(ex.getStatus())); + return Mono.empty(); + }); } private static class WebFluxHttpExchange implements HttpExchange {