Skip to content

Commit

Permalink
Make SignatureCalculator return a new Request, close #4477
Browse files Browse the repository at this point in the history
Motivation:

Some signing algorithms, such as OAuth1, can modify the url or the body.Motivation:
  • Loading branch information
slandelle committed Nov 13, 2023
1 parent bf62d32 commit ee259be
Show file tree
Hide file tree
Showing 21 changed files with 84 additions and 45 deletions.
Expand Up @@ -26,7 +26,7 @@
import io.netty.handler.codec.http.cookie.Cookie;
import java.net.InetAddress;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

public class Request {

Expand All @@ -43,7 +43,7 @@ public class Request {
private final InetAddress localIpV6Address;
private final Realm realm;
private final ProxyServer proxyServer;
private final Consumer<Request> signatureCalculator;
private final Function<Request, Request> signatureCalculator;
private final InetAddressNameResolver nameResolver;
private final boolean http2Enabled;
private final Http2PriorKnowledge http2PriorKnowledge;
Expand All @@ -63,7 +63,7 @@ public Request(
InetAddress localIpV6Address,
Realm realm,
ProxyServer proxyServer,
Consumer<Request> signatureCalculator,
Function<Request, Request> signatureCalculator,
InetAddressNameResolver nameResolver,
boolean http2Enabled,
Http2PriorKnowledge http2PriorKnowledge,
Expand All @@ -88,6 +88,28 @@ public Request(
this.wsSubprotocol = wsSubprotocol;
}

public Request copyWithCopiedHeaders() {
return new Request(
this.name,
this.method,
this.uri,
this.headers.copy(),
this.cookies,
this.body,
this.requestTimeout,
this.virtualHost,
this.autoOrigin,
this.localIpV4Address,
this.localIpV6Address,
this.realm,
this.proxyServer,
this.signatureCalculator,
this.nameResolver,
this.http2Enabled,
this.http2PriorKnowledge,
this.wsSubprotocol);
}

public Request copyWithHttp2PriorKnowledge(Http2PriorKnowledge http2PriorKnowledge) {
return new Request(
this.name,
Expand Down Expand Up @@ -162,7 +184,7 @@ public ProxyServer getProxyServer() {
return proxyServer;
}

public Consumer<Request> getSignatureCalculator() {
public Function<Request, Request> getSignatureCalculator() {
return signatureCalculator;
}

Expand Down
Expand Up @@ -40,7 +40,7 @@
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

public class RequestBuilder {

Expand All @@ -60,7 +60,7 @@ public class RequestBuilder {
private InetAddress localIpV6Address;
private Realm realm;
private ProxyServer proxyServer;
private Consumer<Request> signatureCalculator;
private Function<Request, Request> signatureCalculator;
private boolean http2Enabled;
private Http2PriorKnowledge http2PriorKnowledge = Http2PriorKnowledge.HTTP1_ONLY;
private String wsSubprotocol;
Expand Down Expand Up @@ -133,7 +133,7 @@ public RequestBuilder setProxyServer(ProxyServer proxyServer) {
return this;
}

public RequestBuilder setSignatureCalculator(Consumer<Request> signatureCalculator) {
public RequestBuilder setSignatureCalculator(Function<Request, Request> signatureCalculator) {
this.signatureCalculator = signatureCalculator;
return this;
}
Expand Down
Expand Up @@ -29,7 +29,7 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class WritableRequestBuilder {

Expand Down Expand Up @@ -88,11 +88,18 @@ private static WritableRequest buildRequestWithChunkedBody(

public static WritableRequest buildRequest(Request request, ByteBufAllocator alloc, boolean http2)
throws Exception {
return buildRequest0(signRequest(request), alloc, http2);
}

Consumer<Request> signatureCalculator = request.getSignatureCalculator();
if (signatureCalculator != null) {
signatureCalculator.accept(request);
}
private static Request signRequest(Request request) {
Function<Request, Request> signatureCalculator = request.getSignatureCalculator();
return signatureCalculator != null
? signatureCalculator.apply(request.copyWithCopiedHeaders())
: request;
}

private static WritableRequest buildRequest0(
Request request, ByteBufAllocator alloc, boolean http2) throws Exception {

HttpMethod method = request.getMethod();
String url = requestUrl(request.getUri(), request.getProxyServer(), http2);
Expand Down
Expand Up @@ -29,9 +29,9 @@
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

public class OAuthSignatureCalculator implements Consumer<Request> {
public class OAuthSignatureCalculator implements Function<Request, Request> {

private static final ThreadLocal<OAuthSignatureCalculatorInstance> INSTANCES =
ThreadLocal.withInitial(
Expand All @@ -52,7 +52,7 @@ public OAuthSignatureCalculator(ConsumerKey consumerAuth, RequestToken requestTo
}

@Override
public void accept(Request request) {
public Request apply(Request request) {

RequestBody body = request.getBody();
List<Param> formParams =
Expand All @@ -68,6 +68,7 @@ public void accept(Request request) {
consumerAuth, requestToken, request.getMethod(), request.getUri(), formParams);

request.getHeaders().set(AUTHORIZATION, authorization);
return request;

} catch (InvalidKeyException e) {
throw new IllegalArgumentException("Failed to compute OAuth signature", e);
Expand Down
Expand Up @@ -396,7 +396,7 @@ void testSignatureGenerationWithAsteriskInPath()
}

@Test
void testPercentEncodeKeyValues() throws Exception {
void testPercentEncodeKeyValues() {
// see https://github.com/AsyncHttpClient/async-http-client/issues/1415
String keyValue = "\u3b05\u000c\u375b";

Expand All @@ -413,6 +413,6 @@ void testPercentEncodeKeyValues() throws Exception {
null)
.build();

calc.accept(request);
calc.apply(request);
}
}
Expand Up @@ -35,9 +35,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.net.ssl.KeyManagerFactory;
Expand Down Expand Up @@ -700,8 +698,8 @@ public HttpProtocolBuilder disableUrlEncoding() {
* @return a new HttpProtocolBuilder instance
*/
@NonNull
public HttpProtocolBuilder sign(@NonNull Consumer<Request> calculator) {
return sign((request, session) -> calculator.accept(request));
public HttpProtocolBuilder sign(@NonNull Function<Request, Request> calculator) {
return sign((request, session) -> calculator.apply(request));
}

/**
Expand All @@ -712,7 +710,7 @@ public HttpProtocolBuilder sign(@NonNull Consumer<Request> calculator) {
* @return a new HttpProtocolBuilder instance
*/
@NonNull
public HttpProtocolBuilder sign(@NonNull BiConsumer<Request, Session> calculator) {
public HttpProtocolBuilder sign(@NonNull BiFunction<Request, Session, Request> calculator) {
return new HttpProtocolBuilder(wrapped.sign(SignatureCalculators.toScala(calculator)));
}

Expand Down
Expand Up @@ -26,8 +26,7 @@
import io.gatling.javaapi.http.internal.SignatureCalculators;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -517,8 +516,8 @@ public T proxy(@NonNull Proxy proxy) {
* @return a new DSL instance
*/
@NonNull
public T sign(@NonNull Consumer<Request> calculator) {
return sign((request, session) -> calculator.accept(request));
public T sign(@NonNull Function<Request, Request> calculator) {
return sign((request, session) -> calculator.apply(request));
}

/**
Expand All @@ -529,7 +528,7 @@ public T sign(@NonNull Consumer<Request> calculator) {
* @return a new DSL instance
*/
@NonNull
public T sign(@NonNull BiConsumer<Request, Session> calculator) {
public T sign(@NonNull BiFunction<Request, Session, Request> calculator) {
return make(wrapped -> wrapped.sign(SignatureCalculators.toScala(calculator)));
}

Expand Down
Expand Up @@ -16,14 +16,14 @@

package io.gatling.javaapi.http.internal

import java.util.function.BiConsumer
import java.util.function.BiFunction

import io.gatling.commons.validation._
import io.gatling.core.session.{ Session => ScalaSession }
import io.gatling.http.client.Request
import io.gatling.javaapi.core.Session

object SignatureCalculators {
def toScala(calculator: BiConsumer[Request, Session]): (Request, ScalaSession) => Validation[_] =
(request, session) => calculator.accept(request, new Session(session)).success
def toScala(calculator: BiFunction[Request, Session, Request]): (Request, ScalaSession) => Validation[Request] =
(request, session) => calculator.apply(request, new Session(session)).success
}
Expand Up @@ -95,8 +95,8 @@ public class HttpJavaCompileTest extends Simulation {
.silentResources()
.silentUri("regex")
.disableUrlEncoding()
.sign(request -> {})
.sign((request, session) -> {})
.sign(request -> request)
.sign((request, session) -> request)
.signWithOAuth1("consumerKey", "clientSharedSecret", "token", "tokenSecret")
.signWithOAuth1(
session -> "consumerKey",
Expand Down Expand Up @@ -267,8 +267,8 @@ public class HttpJavaCompileTest extends Simulation {
.virtualHost("virtualHost")
.virtualHost(session -> "virtualHost")
.disableUrlEncoding()
.sign(request -> {})
.sign((request, session) -> {})
.sign(request -> request)
.sign((request, session) -> request)
.signWithOAuth1("consumerKey", "clientSharedSecret", "token", "tokenSecret")
.signWithOAuth1(
session -> "consumerKey",
Expand Down
Expand Up @@ -173,7 +173,7 @@ final case class HttpProtocolRequestPart(
disableUrlEncoding: Boolean,
silentUri: Option[Pattern],
silentResources: Boolean,
signatureCalculator: Option[(Request, Session) => Validation[_]]
signatureCalculator: Option[(Request, Session) => Validation[Request]]
)

final case class HttpProtocolResponsePart(
Expand Down
Expand Up @@ -132,7 +132,7 @@ final case class HttpProtocolBuilder(protocol: HttpProtocol, useOpenSsl: Boolean
def silentResources: HttpProtocolBuilder = this.modify(_.protocol.requestPart.silentResources).setTo(true)
def silentUri(pattern: String): HttpProtocolBuilder = this.modify(_.protocol.requestPart.silentUri).setTo(Some(pattern.r.pattern))
def disableUrlEncoding: HttpProtocolBuilder = this.modify(_.protocol.requestPart.disableUrlEncoding).setTo(true)
def sign(calculator: (Request, Session) => Validation[_]): HttpProtocolBuilder =
def sign(calculator: (Request, Session) => Validation[Request]): HttpProtocolBuilder =
this.modify(_.protocol.requestPart.signatureCalculator).setTo(Some(calculator))
def signWithOAuth1(
consumerKey: Expression[String],
Expand Down
Expand Up @@ -62,7 +62,7 @@ final case class CommonAttributes(
realm: Option[Expression[Realm]],
virtualHost: Option[Expression[String]],
proxy: Option[ProxyServer],
signatureCalculator: Option[(Request, Session) => Validation[_]],
signatureCalculator: Option[(Request, Session) => Validation[Request]],
ignoreProtocolHeaders: Boolean
)

Expand Down Expand Up @@ -97,14 +97,14 @@ object RequestBuilder {
clientSharedSecret: Expression[String],
token: Expression[String],
tokenSecret: Expression[String]
): (Request, Session) => Validation[_] =
): (Request, Session) => Validation[Request] =
(request, session) =>
for {
ck <- consumerKey(session)
css <- clientSharedSecret(session)
tk <- token(session)
tks <- tokenSecret(session)
} yield new OAuthSignatureCalculator(new ConsumerKey(ck, css), new RequestToken(tk, tks)).accept(request)
} yield new OAuthSignatureCalculator(new ConsumerKey(ck, css), new RequestToken(tk, tks)).apply(request)
}

abstract class RequestBuilder[B <: RequestBuilder[B]] {
Expand Down Expand Up @@ -181,7 +181,7 @@ abstract class RequestBuilder[B <: RequestBuilder[B]] {

def proxy(httpProxy: Proxy): B = newInstance(modify(commonAttributes)(_.proxy).setTo(Some(httpProxy.proxyServer)))

def sign(calculator: (Request, Session) => Validation[_]): B = newInstance(modify(commonAttributes)(_.signatureCalculator).setTo(Some(calculator)))
def sign(calculator: (Request, Session) => Validation[Request]): B = newInstance(modify(commonAttributes)(_.signatureCalculator).setTo(Some(calculator)))

def signWithOAuth1(consumerKey: Expression[String], clientSharedSecret: Expression[String], token: Expression[String], tokenSecret: Expression[String]): B =
sign(RequestBuilder.oauth1SignatureCalculator(consumerKey, clientSharedSecret, token, tokenSecret))
Expand Down
Expand Up @@ -201,15 +201,15 @@ abstract class RequestExpressionBuilder(
}
}

private val maybeSignatureCalculator: Option[(Request, Session) => Validation[_]] =
private val maybeSignatureCalculator: Option[(Request, Session) => Validation[Request]] =
commonAttributes.signatureCalculator.orElse(httpProtocol.requestPart.signatureCalculator)
private def configureSignatureCalculator(session: Session, requestBuilder: ClientRequestBuilder): Unit =
maybeSignatureCalculator match {
case Some(signatureCalculator) =>
requestBuilder.setSignatureCalculator { request =>
signatureCalculator(request, session) match {
case Failure(message) => throw new IllegalArgumentException(s"Failed to compute signature: $message")
case _ =>
case Success(signed) => signed
}
}
case _ =>
Expand Down
Expand Up @@ -309,6 +309,7 @@ class HttpCompileTest extends Simulation {
val rawSignature = mac.doFinal(request.getUri.getQuery.getBytes("UTF-8"))
val authorization = Base64.getEncoder.encodeToString(rawSignature)
request.getHeaders.add("Authorization", authorization)
request
}
)
// proxy
Expand Down
Expand Up @@ -69,7 +69,10 @@ class HttpRequestBuilderSpec extends BaseSpec with ValidationValues with EmptySe
}

"signature calculator" should "work when passed as a SignatureCalculator instance" in {
httpRequestDef(_.sign((request, _) => request.getHeaders.add("X-Token", "foo")))
httpRequestDef(_.sign { (request, _) =>
request.getHeaders.add("X-Token", "foo")
request
})
.build(sessionBase)
.map { httpRequest =>
val writableRequest = WritableRequestBuilder.buildRequest(httpRequest.clientRequest, null, false)
Expand Down
Expand Up @@ -48,6 +48,8 @@ class DynatraceSampleJava {
"x-dynaTrace",
"VU=" + VU + ";SI=" + SI + ";TSN=" + TSN + ";LSN=" + LSN + ";LTN=" + LTN + ";PC=" + PC
);

return request;
}
);
//#dynatrace
Expand Down
Expand Up @@ -35,6 +35,8 @@ val httpProtocol = http

request.headers["x-dynaTrace"] =
"VU=$VU;SI=$SI;TSN=$TSN;LSN=$LSN;LTN=$LTN;PC=$PC"

request
}
//#dynatrace
}
Expand Up @@ -52,6 +52,7 @@ val httpProtocol = http
"x-dynaTrace",
s"VU=$VU;SI=$SI;TSN=$TSN;LSN=$LSN;LTN=$LTN;PC=$PC"
)
request
}
//#dynatrace
}

0 comments on commit ee259be

Please sign in to comment.