Skip to content

Commit

Permalink
Change originRegex into (origin) -> originRegex.matcher(origin).match…
Browse files Browse the repository at this point in the history
…es()
  • Loading branch information
seonWKim committed Jun 23, 2023
1 parent 2948ead commit 815cc51
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ abstract class AbstractCorsPolicyBuilder {

private Set<String> origins = Collections.emptySet();
@Nullable
private Pattern originRegex;
@Nullable
private Predicate<String> originPredicate;
private final List<Route> routes = new ArrayList<>();
private boolean credentialsAllowed;
Expand Down Expand Up @@ -92,7 +90,7 @@ abstract class AbstractCorsPolicyBuilder {

AbstractCorsPolicyBuilder(Pattern originRegex) {
requireNonNull(originRegex, "originRegex");
this.originRegex = originRegex;
this.originPredicate = (origin) -> originRegex.matcher(origin).matches();
}

final void setConfig(CorsDecorator corsDecorator) {
Expand Down Expand Up @@ -442,7 +440,7 @@ public AbstractCorsPolicyBuilder disablePreflightResponseHeaders() {
* Returns a newly-created {@link CorsPolicy} based on the properties of this builder.
*/
CorsPolicy build() {
return new CorsPolicy(origins, originPredicate, originRegex, routes, credentialsAllowed, maxAge,
return new CorsPolicy(origins, originPredicate, routes, credentialsAllowed, maxAge,
nullOriginAllowed, exposedHeaders, allowAllRequestHeaders, allowedRequestHeaders,
allowedRequestMethods, preflightResponseHeadersDisabled,
preflightResponseHeaders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,9 @@ public CorsPolicy getPolicy(@Nullable String origin, RoutingContext routingConte
if (isNullOrigin && policy.isNullOriginAllowed() &&
isPathMatched(policy, routingContext)) {
return policy;
} else if (!isNullOrigin && policy.origins().contains(lowerCaseOrigin) &&
isPathMatched(policy, routingContext)) {
return policy;
}

if (policy.originPredicate() != null) {
if (policy.originPredicate().test(origin) && isPathMatched(policy, routingContext)) {
return policy;
}
}

if (policy.originRegex() != null) {
if (policy.originRegex().matcher(origin).matches() && isPathMatched(policy, routingContext)) {
} else if (!isNullOrigin && isPathMatched(policy, routingContext)) {
if (policy.origins().contains(lowerCaseOrigin) ||
(policy.originPredicate() != null && policy.originPredicate().test(origin))) {
return policy;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ public static CorsPolicyBuilder builderForOriginRegex(Pattern regex) {
private final Set<String> origins;
@Nullable
private final Predicate<String> originPredicate;
@Nullable
private final Pattern originRegex;
private final List<Route> routes;
private final boolean credentialsAllowed;
private final boolean nullOriginAllowed;
Expand All @@ -117,15 +115,14 @@ public static CorsPolicyBuilder builderForOriginRegex(Pattern regex) {
private final String joinedAllowedRequestMethods;
private final Map<AsciiString, Supplier<?>> preflightResponseHeaders;

CorsPolicy(Set<String> origins, @Nullable Predicate<String> originPredicate, @Nullable Pattern originRegex,
CorsPolicy(Set<String> origins, @Nullable Predicate<String> originPredicate,
List<Route> routes, boolean credentialsAllowed, long maxAge,
boolean nullOriginAllowed, Set<AsciiString> exposedHeaders,
boolean allowAllRequestHeaders, Set<AsciiString> allowedRequestHeaders,
EnumSet<HttpMethod> allowedRequestMethods, boolean preflightResponseHeadersDisabled,
Map<AsciiString, Supplier<?>> preflightResponseHeaders) {
this.origins = ImmutableSet.copyOf(origins);
this.originPredicate = originPredicate;
this.originRegex = originRegex;
this.routes = ImmutableList.copyOf(routes);
this.credentialsAllowed = credentialsAllowed;
this.maxAge = maxAge;
Expand Down Expand Up @@ -174,14 +171,6 @@ public Predicate<String> originPredicate() {
return originPredicate;
}

/**
* Returns the regular expression to match origins.
*/
@Nullable
public Pattern originRegex() {
return originRegex;
}

/**
* Returns the list of {@link Route}s that this policy is supposed to be applied to.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,14 @@ public ChainedCorsPolicyBuilder andForOrigin(String origin) {
return andForOrigins(ImmutableList.of(origin));
}

/**
* Creates a new builder instance for a new {@link CorsPolicy}.
* @return {@link ChainedCorsPolicyBuilder} to support method chaining.
*/
public ChainedCorsPolicyBuilder andForOrigin(Predicate<String> originPredicate) {
return new ChainedCorsPolicyBuilder(this, originPredicate);
}

/**
* Creates a new builder instance for a new {@link CorsPolicy}.
* @return {@link ChainedCorsPolicyBuilder} to support method chaining.
Expand All @@ -481,8 +489,7 @@ public ChainedCorsPolicyBuilder andForOriginRegex(String regex) {
* @return {@link ChainedCorsPolicyBuilder} to support method chaining.
*/
public ChainedCorsPolicyBuilder andForOriginRegex(Pattern regex) {
// TODO
return new ChainedCorsPolicyBuilder(this);
return new ChainedCorsPolicyBuilder(this, regex);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.util.Arrays;
import java.util.function.Function;
import java.util.regex.Pattern;

import org.junit.ClassRule;
import org.junit.Test;
Expand All @@ -43,6 +44,7 @@
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.AdditionalHeader;
import com.linecorp.armeria.server.annotation.ConsumesJson;
import com.linecorp.armeria.server.annotation.Delete;
import com.linecorp.armeria.server.annotation.Get;
import com.linecorp.armeria.server.annotation.Options;
import com.linecorp.armeria.server.annotation.Param;
Expand Down Expand Up @@ -98,7 +100,7 @@ public void index() {}
}

@CorsDecorator(
originRegex = "http://example.*",
originRegex = "http:\\/\\/example.*",
allowedRequestMethods = HttpMethod.GET
)
private static class MyAnnotatedService4 {
Expand Down Expand Up @@ -296,7 +298,7 @@ public void actors(@Param String title) {}
sb.annotatedService("/cors14", new MyAnnotatedService3());

sb.service("/cors15", myService.decorate(
CorsService.builderForOriginRegex("http://example.*")
CorsService.builderForOriginRegex("^http:\\/\\/.*example.com$")
.shortCircuit()
.allowRequestMethods(HttpMethod.GET)
.newDecorator()));
Expand All @@ -307,6 +309,30 @@ public void actors(@Param String title) {}
.shortCircuit()
.allowRequestMethods(HttpMethod.GET)
.newDecorator()));

sb.annotatedService("/cors18", new Object() {
@Get("/index1")
public void index1() {}

@Post("/index2")
public void index2() {}

@Delete("/index3")
public void index3() {}
}, CorsService.builder()
.andForOriginRegex("^http:\\/\\/example.*")
.route("/cors18/index1")
.allowRequestMethods(HttpMethod.GET)
.and()
.andForOriginRegex(Pattern.compile(".*line.*"))
.route("/cors18/index2")
.allowRequestMethods(HttpMethod.POST)
.and()
.andForOrigin((origin) -> origin.contains("armeria"))
.route("/cors18/index3")
.allowRequestMethods(HttpMethod.DELETE)
.and()
.newDecorator());
}
};

Expand Down Expand Up @@ -712,9 +738,9 @@ public void testBuilderForOriginRegex() {

res = request(client, HttpMethod.GET, "/cors15", "http://example.com", "GET");
assertThat(res.status()).isEqualTo(HttpStatus.OK);
res = request(client, HttpMethod.GET, "/cors15", "http://example1.com", "GET");
res = request(client, HttpMethod.GET, "/cors15", "http://1.example.com", "GET");
assertThat(res.status()).isEqualTo(HttpStatus.OK);
res = request(client, HttpMethod.GET, "/cors15", "http://example.org", "GET");
res = request(client, HttpMethod.GET, "/cors15", "http://2.example.com", "GET");
assertThat(res.status()).isEqualTo(HttpStatus.OK);

res = request(client, HttpMethod.GET, "/cors15", "http://invalid.com", "GET");
Expand Down Expand Up @@ -755,4 +781,25 @@ public void testOriginPredicate() {
res = request(client, HttpMethod.GET, "/cors17", "http://invalid.com", "GET");
assertThat(res.status()).isEqualTo(HttpStatus.FORBIDDEN);
}

@Test
public void testOriginRegexAndPredicatePerRoute() {
final WebClient client = client();
AggregatedHttpResponse res;

res = preflightRequest(client, "/cors18/index1", "http://example.com", "GET");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isEqualTo("GET");
res = preflightRequest(client, "/cors18/index1", "http://invalid.com", "GET");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isNull();

res = preflightRequest(client, "/cors18/index2", "http://line.com", "POST");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isEqualTo("POST");
res = preflightRequest(client, "/cors18/index2", "http://invalid.com", "GET");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isNull();

res = preflightRequest(client, "/cors18/index3", "http://armeria.com", "DELETE");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isEqualTo("DELETE");
res = preflightRequest(client, "/cors18/index3", "http://invalid.com", "DELETE");
assertThat(res.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)).isNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import com.linecorp.armeria.common.websocket.WebSocketWriter;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.logging.LoggingService;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.netty.handler.codec.http.HttpHeaderValues;
Expand All @@ -50,12 +49,11 @@ class WebSocketServiceCorsTest {
static final ServerExtension server1 = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.tlsSelfSigned()
.http(0)
.service("/chat", WebSocketService.builder(new CustomWebSocketServiceHandler())
.allowedOrigins("*")
.build())
.decorator(LoggingService.newDecorator());
sb.route()
.path("/chat")
.build(WebSocketService.builder(new CustomWebSocketServiceHandler())
.allowedOrigins("*")
.build());
}
};

Expand Down

0 comments on commit 815cc51

Please sign in to comment.