Skip to content

Commit

Permalink
FIX-5017 Add authorization metrics (#5074)
Browse files Browse the repository at this point in the history
Motivation:

Fixing an open improvement issue.
#5017

Modifications:

- Modified the `AuthService.java` file to use `MoreMeters.newTimer`.
- Updated the tests. I think some are still remaining to be fixed, so
opening the PR as draft.
  - Any feedback is welcome as it's my first time! 😄  

Result:

- Closes #5017. (If this resolves the issue.)
- Describe the consequences that a user will face after this PR is
merged.

<!--
Visit this URL to learn more about how to write a pull request
description:

https://armeria.dev/community/developer-guide#how-to-write-pull-request-description
-->
The metrics would look like:
```
# HELP armeria_server_auth_seconds  
# TYPE armeria_server_auth_seconds summary
armeria_server_auth_seconds_count 1.0
armeria_server_auth_seconds_sum 0.018437417
# HELP armeria_server_auth_seconds_max  
# TYPE armeria_server_auth_seconds_max gauge
armeria_server_auth_seconds_max 0.018437417
```

---------

Co-authored-by: jrhee17 <guins_j@guins.org>
Co-authored-by: songmw725 <songmw725@gmail.com>
  • Loading branch information
3 people committed Mar 28, 2024
1 parent 1be0dd1 commit eb9470c
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static java.util.Objects.requireNonNull;

import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import org.slf4j.Logger;
Expand All @@ -28,11 +29,17 @@
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.metric.MeterIdPrefix;
import com.linecorp.armeria.common.metric.MoreMeters;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.SimpleDecoratingHttpService;

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;

/**
* Decorates an {@link HttpService} to provide HTTP authorization functionality.
*
Expand Down Expand Up @@ -75,33 +82,52 @@ public static AuthServiceBuilder builder() {
private final Authorizer<HttpRequest> authorizer;
private final AuthSuccessHandler defaultSuccessHandler;
private final AuthFailureHandler defaultFailureHandler;
@Nullable
private Timer successTimer;
@Nullable
private Timer failureTimer;
private final MeterIdPrefix meterIdPrefix;

AuthService(HttpService delegate, Authorizer<HttpRequest> authorizer,
AuthSuccessHandler defaultSuccessHandler, AuthFailureHandler defaultFailureHandler) {
AuthSuccessHandler defaultSuccessHandler, AuthFailureHandler defaultFailureHandler,
MeterIdPrefix meterIdPrefix) {
super(delegate);
this.authorizer = authorizer;
this.defaultSuccessHandler = defaultSuccessHandler;
this.defaultFailureHandler = defaultFailureHandler;
this.meterIdPrefix = meterIdPrefix;
}

@Override
public void serviceAdded(ServiceConfig cfg) throws Exception {
super.serviceAdded(cfg);
final MeterRegistry meterRegistry = cfg.server().meterRegistry();
successTimer = MoreMeters.newTimer(meterRegistry, meterIdPrefix.name(),
meterIdPrefix.tags("result", "success"));
failureTimer = MoreMeters.newTimer(meterRegistry, meterIdPrefix.name(),
meterIdPrefix.tags("result", "failure"));
}

@Override
public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception {
final long startNanos = System.nanoTime();

return HttpResponse.of(AuthorizerUtil.authorizeAndSupplyHandlers(authorizer, ctx, req)
.handleAsync((result, cause) -> {
try {
final HttpService delegate = (HttpService) unwrap();
if (cause == null) {
if (result != null) {
if (!result.isAuthorized()) {
return handleFailure(delegate, result.failureHandler(), ctx, req, null);
return handleFailure(delegate, result.failureHandler(), ctx, req, null, startNanos);
}
return handleSuccess(delegate, result.successHandler(), ctx, req);
return handleSuccess(delegate, result.successHandler(), ctx, req, startNanos);
}
cause = AuthorizerUtil.newNullResultException(authorizer);
}

return handleFailure(delegate, result != null ? result.failureHandler() : null,
ctx, req, cause);
ctx, req, cause, startNanos);
} catch (Exception e) {
return Exceptions.throwUnsafely(e);
}
Expand All @@ -110,8 +136,10 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc

private HttpResponse handleSuccess(HttpService delegate,
@Nullable AuthSuccessHandler authorizerSuccessHandler,
ServiceRequestContext ctx, HttpRequest req)
throws Exception {
ServiceRequestContext ctx, HttpRequest req,
long startNanos) throws Exception {
assert successTimer != null;
successTimer.record(System.nanoTime() - startNanos, TimeUnit.NANOSECONDS);
final AuthSuccessHandler handler = authorizerSuccessHandler == null ? defaultSuccessHandler
: authorizerSuccessHandler;
return handler.authSucceeded(delegate, ctx, req);
Expand All @@ -120,7 +148,9 @@ private HttpResponse handleSuccess(HttpService delegate,
private HttpResponse handleFailure(HttpService delegate,
@Nullable AuthFailureHandler authorizerFailureHandler,
ServiceRequestContext ctx, HttpRequest req,
@Nullable Throwable cause) throws Exception {
@Nullable Throwable cause, long startNanos) throws Exception {
assert failureTimer != null;
failureTimer.record(System.nanoTime() - startNanos, TimeUnit.NANOSECONDS);
final AuthFailureHandler handler = authorizerFailureHandler == null ? defaultFailureHandler
: authorizerFailureHandler;
if (cause != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.linecorp.armeria.common.auth.BasicToken;
import com.linecorp.armeria.common.auth.OAuth1aToken;
import com.linecorp.armeria.common.auth.OAuth2Token;
import com.linecorp.armeria.common.metric.MeterIdPrefix;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.Service;

Expand All @@ -45,6 +46,7 @@ public final class AuthServiceBuilder {
}
return HttpResponse.of(HttpStatus.UNAUTHORIZED);
};
private MeterIdPrefix meterIdPrefix = new MeterIdPrefix("armeria.server.auth");

/**
* Creates a new instance.
Expand Down Expand Up @@ -153,17 +155,41 @@ public AuthServiceBuilder onFailure(AuthFailureHandler failureHandler) {
return this;
}

/**
* Sets the {@link MeterIdPrefix} pattern to which metrics will be collected.
* By default, {@code armeria.server.auth} will be used as the metric name.
* <table>
* <caption>Metrics that will be generated by this class</caption>
* <tr>
* <th>metric name</th>
* <th>description</th>
* </tr>
* <tr>
* <td>{@code <name>#count{result="success"}}</td>
* <td>The number of successful authentication requests.</td>
* </tr>
* <tr>
* <td>{@code <name>#count{result="failure"}}</td>
* <td>The number of failed authentication requests.</td>
* </tr>
* </table>
*/
public AuthServiceBuilder meterIdPrefix(MeterIdPrefix meterIdPrefix) {
this.meterIdPrefix = requireNonNull(meterIdPrefix, "meterIdPrefix");
return this;
}

/**
* Returns a newly-created {@link AuthService} based on the {@link Authorizer}s added to this builder.
*/
public AuthService build(HttpService delegate) {
return new AuthService(requireNonNull(delegate, "delegate"), authorizer(),
successHandler, failureHandler);
successHandler, failureHandler, meterIdPrefix);
}

private AuthService build(HttpService delegate, Authorizer<HttpRequest> authorizer) {
return new AuthService(requireNonNull(delegate, "delegate"), authorizer,
successHandler, failureHandler);
successHandler, failureHandler, meterIdPrefix);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.linecorp.armeria.common.HttpHeaderNames.AUTHORIZATION;
import static com.linecorp.armeria.common.util.UnmodifiableFuture.completedFuture;
import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
Expand All @@ -42,16 +43,17 @@

import com.linecorp.armeria.client.BlockingWebClient;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.auth.AuthToken;
import com.linecorp.armeria.common.auth.BasicToken;
import com.linecorp.armeria.common.auth.OAuth1aToken;
import com.linecorp.armeria.common.auth.OAuth2Token;
import com.linecorp.armeria.common.metric.MoreMeters;
import com.linecorp.armeria.common.util.UnmodifiableFuture;
import com.linecorp.armeria.internal.testing.AnticipatedException;
import com.linecorp.armeria.server.AbstractHttpService;
Expand All @@ -61,6 +63,7 @@
import com.linecorp.armeria.server.logging.LoggingService;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.netty.util.AsciiString;

class AuthServiceTest {
Expand All @@ -75,13 +78,16 @@ String accessToken() {

private static final Function<HttpHeaders, InsecureToken> INSECURE_TOKEN_EXTRACTOR =
headers -> new InsecureToken();

private static final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry();
private static final AsciiString CUSTOM_TOKEN_HEADER = HttpHeaderNames.of("X-Custom-Authorization");

private static final AtomicReference<Throwable> peeledException = new AtomicReference<>();

@RegisterExtension
static final ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {
sb.meterRegistry(meterRegistry);
final HttpService ok = new AbstractHttpService() {
@Override
protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req) {
Expand Down Expand Up @@ -207,6 +213,17 @@ protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req) {
})
.newDecorator())
.decorate(LoggingService.newDecorator()));

sb.service("/peeled_exception", AuthService.builder()
.add((ctx, data) -> {
return UnmodifiableFuture.exceptionallyCompletedFuture(
new AnticipatedException());
})
.onFailure((delegate, ctx, req, cause) -> {
peeledException.set(cause);
return HttpResponse.of(HttpStatus.FORBIDDEN);
}).build((ctx, req) -> HttpResponse.of("OK"))
);
}
};

Expand Down Expand Up @@ -293,9 +310,9 @@ void testOAuth1a() throws Exception {
@Test
void testOAuth2() throws Exception {
final BlockingWebClient webClient = WebClient.builder(server.httpUri())
.auth(AuthToken.ofOAuth2("dummy_oauth2_token"))
.build()
.blocking();
.auth(AuthToken.ofOAuth2("dummy_oauth2_token"))
.build()
.blocking();
assertThat(webClient.get("/oauth2").status()).isEqualTo(HttpStatus.OK);
try (CloseableHttpClient hc = HttpClients.createMinimal()) {
try (CloseableHttpResponse res = hc.execute(
Expand Down Expand Up @@ -419,21 +436,21 @@ void testOnFailureException() throws Exception {

@Test
void shouldPeelRedundantAuthorizerExceptions() throws Exception {
final AtomicReference<Throwable> causeRef = new AtomicReference<>();
final AuthService service =
AuthService.builder()
.add((ctx, data) -> {
return UnmodifiableFuture.exceptionallyCompletedFuture(
new AnticipatedException());
})
.onFailure((delegate, ctx, req, cause) -> {
causeRef.set(cause);
return HttpResponse.of(HttpStatus.FORBIDDEN);
}).build((ctx, req) -> HttpResponse.of("OK"));
final ServiceRequestContext ctx = ServiceRequestContext.of(HttpRequest.of(HttpMethod.GET, "/"));
final HttpResponse response = service.serve(ctx, ctx.request());
assertThat(response.aggregate().join().status()).isEqualTo(HttpStatus.FORBIDDEN);
assertThat(causeRef.get()).isInstanceOf(AnticipatedException.class);
assertThat(server.blockingWebClient().get("/peeled_exception").status())
.isEqualTo(HttpStatus.FORBIDDEN);
assertThat(peeledException.get()).isInstanceOf(AnticipatedException.class);
}

@Test
void shouldRecordMetrics() {
final double before = MoreMeters.measureAll(meterRegistry)
.getOrDefault("armeria.server.auth#count", 0.0);
final AggregatedHttpResponse res = server.blockingWebClient(cb -> cb.auth(AuthToken.ofBasic("brown",
"cony")))
.get("/basic");
assertThat(res.status().code()).isEqualTo(200);
await().untilAsserted(() -> assertThat(MoreMeters.measureAll(meterRegistry))
.containsEntry("armeria.server.auth#count{result=success}", before + 1));
}

private static HttpUriRequestBase getRequest(String path, String authorization) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ public void serviceAdded(ServiceConfig cfg) throws Exception {
.config()
.dependencyInjector();
decorated = handlerRegistry.applyDecorators(delegate, dependencyInjector);
for (HttpService decorator : decorated.values()) {
decorator.serviceAdded(cfg);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ public class GrpcHttpJsonTranscodingServiceAnnotatedAuthServiceTest {
static final ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) throws Exception {
final GrpcService grpcService = GrpcService.builder()
final GrpcService grpcService = GrpcService
.builder()
.addService(new AuthenticatedHttpJsonTranscodingTestService())
.enableHttpJsonTranscoding(true)
.build();
Expand All @@ -86,21 +87,21 @@ protected void configure(ServerBuilder sb) throws Exception {

@Test
void testAuthenticatedRpcMethod() throws Exception {
final Transcoding.GetMessageRequestV1 requestMessage = Transcoding.GetMessageRequestV1.newBuilder()
.setName("messages/1").build();
final Transcoding.GetMessageRequestV1 requestMessage =
Transcoding.GetMessageRequestV1.newBuilder().setName("messages/1").build();
final Throwable exception = assertThrows(Throwable.class,
() -> grpcClient.getMessageV1(requestMessage).getText());
() -> grpcClient.getMessageV1(requestMessage).getText());
assertThat(exception).isInstanceOf(StatusRuntimeException.class);
assertThat(((StatusRuntimeException) exception).getStatus().getCode())
.isEqualTo(Status.UNAUTHENTICATED.getCode());

final Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of(TEST_CREDENTIAL_KEY, Metadata.ASCII_STRING_MARSHALLER),
"some-credential-string");
"some-credential-string");
final Transcoding.Message result =
grpcClient.withInterceptors(
MetadataUtils.newAttachHeadersInterceptor(metadata)
).getMessageV1(requestMessage);
MetadataUtils.newAttachHeadersInterceptor(metadata)
).getMessageV1(requestMessage);
assertThat(result.getText()).isEqualTo("messages/1");
}

Expand All @@ -110,11 +111,11 @@ void testAuthenticatedHttpJsonTranscoding() throws Exception {
assertThat(failResponse.status()).isEqualTo(HttpStatus.UNAUTHORIZED);

final JsonNode root = webClient.prepare()
.get("/v1/messages/1")
.header(TEST_CREDENTIAL_KEY, "some-credential-string")
.asJson(JsonNode.class)
.execute()
.content();
.get("/v1/messages/1")
.header(TEST_CREDENTIAL_KEY, "some-credential-string")
.asJson(JsonNode.class)
.execute()
.content();
assertThat(root.get("text").asText()).isEqualTo("messages/1");
}

Expand Down

0 comments on commit eb9470c

Please sign in to comment.