Skip to content

Commit

Permalink
Add configurations to set CORS allowed origins using regex and predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
seonWKim committed Mar 30, 2024
1 parent db3973d commit 871e82f
Show file tree
Hide file tree
Showing 13 changed files with 630 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.newCloseWebSocketFrame;

import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;

import org.slf4j.Logger;
Expand Down Expand Up @@ -93,19 +94,22 @@ public final class DefaultWebSocketService implements WebSocketService, WebSocke
private final Set<String> subprotocols;
private final Set<String> allowedOrigins;
private final boolean allowAnyOrigin;
private final Predicate<String> originMatchingPredicate;
private final boolean aggregateContinuation;

public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpService fallbackService,
int maxFramePayloadLength, boolean allowMaskMismatch,
Set<String> subprotocols, Set<String> allowedOrigins,
boolean allowAnyOrigin, boolean aggregateContinuation) {
boolean allowAnyOrigin, Predicate<String> originMatchingPredicate,
boolean aggregateContinuation) {
this.handler = handler;
this.fallbackService = fallbackService;
this.maxFramePayloadLength = maxFramePayloadLength;
this.allowMaskMismatch = allowMaskMismatch;
this.subprotocols = subprotocols;
this.allowedOrigins = allowedOrigins;
this.allowAnyOrigin = allowAnyOrigin;
this.originMatchingPredicate = originMatchingPredicate;
this.aggregateContinuation = aggregateContinuation;
}

Expand Down Expand Up @@ -286,17 +290,17 @@ private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders heade
"missing the origin header");
}

if (allowedOrigins.isEmpty()) {
if (allowedOrigins.isEmpty() && !originMatchingPredicate.test(origin)) {
// Only the same-origin is allowed.
if (!isSameOrigin(ctx, headers, origin)) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
"not allowed origin: " + origin);
}
return null;
}
if (!allowedOrigins.contains(origin)) {
if (!allowedOrigins.contains(origin) && !originMatchingPredicate.test(origin)) {
return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
"not allowed origin: " + origin + ", allowed: " + allowedOrigins);
"not allowed origin or pattern: " + origin + ", allowed: " + allowedOrigins);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
* Allowed origins.
* Sets this property to be {@code "*"} to allow any origin.
*/
String[] origins();
String[] origins() default {};

/**
* Specify allowed origins by regular expression.
*/
String originRegex() default "";

/**
* The path patterns that this policy is supposed to be applied to. If unspecified, all paths would be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;

import com.google.common.base.Ascii;
import com.google.common.collect.ImmutableList;
Expand All @@ -38,6 +40,7 @@

import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.server.Route;
import com.linecorp.armeria.server.annotation.AdditionalHeader;
import com.linecorp.armeria.server.annotation.decorator.CorsDecorator;
Expand All @@ -54,6 +57,8 @@
abstract class AbstractCorsPolicyBuilder {

private final Set<String> origins;
@Nullable
private Predicate<String> originPredicate;
private final List<Route> routes = new ArrayList<>();
private boolean credentialsAllowed;
private boolean nullOriginAllowed;
Expand All @@ -66,19 +71,27 @@ abstract class AbstractCorsPolicyBuilder {
private final Map<AsciiString, Supplier<?>> preflightResponseHeaders = new HashMap<>();
private boolean preflightResponseHeadersDisabled;

AbstractCorsPolicyBuilder() {
origins = Collections.emptySet();
}

AbstractCorsPolicyBuilder(List<String> origins) {
requireNonNull(origins, "origins");
checkArgument(!origins.isEmpty(), "origins is empty.");
for (int i = 0; i < origins.size(); i++) {
if (origins.get(i) == null) {
throw new NullPointerException("origins[" + i + ']');
}
}
this.origins = origins.stream().map(Ascii::toLowerCase).collect(toImmutableSet());
originPredicate = this.origins::contains;
}

AbstractCorsPolicyBuilder(Predicate<String> originPredicate) {
requireNonNull(originPredicate, "originPredicate");
origins = Collections.emptySet();
this.originPredicate = originPredicate;
}

AbstractCorsPolicyBuilder(Pattern originRegex) {
requireNonNull(originRegex, "originRegex");
origins = Collections.emptySet();
originPredicate = origin -> originRegex.matcher(origin).matches();
}

final void setConfig(CorsDecorator corsDecorator) {
Expand Down Expand Up @@ -428,8 +441,8 @@ public AbstractCorsPolicyBuilder disablePreflightResponseHeaders() {
* Returns a newly-created {@link CorsPolicy} based on the properties of this builder.
*/
CorsPolicy build() {
return new CorsPolicy(origins, routes, credentialsAllowed, maxAge, nullOriginAllowed,
exposedHeaders, allowAllRequestHeaders, allowedRequestHeaders,
return new CorsPolicy(origins, originPredicate, routes, credentialsAllowed, maxAge,
nullOriginAllowed, exposedHeaders, allowAllRequestHeaders, allowedRequestHeaders,
allowedRequestMethods, preflightResponseHeadersDisabled,
preflightResponseHeaders);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@

import static java.util.Objects.requireNonNull;

import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;

import com.linecorp.armeria.common.HttpMethod;

Expand All @@ -36,6 +39,7 @@ public final class ChainedCorsPolicyBuilder extends AbstractCorsPolicyBuilder {
private final CorsServiceBuilder serviceBuilder;

ChainedCorsPolicyBuilder(CorsServiceBuilder builder) {
super(Collections.singletonList("*"));
requireNonNull(builder, "builder");
serviceBuilder = builder;
}
Expand All @@ -46,6 +50,18 @@ public final class ChainedCorsPolicyBuilder extends AbstractCorsPolicyBuilder {
serviceBuilder = builder;
}

ChainedCorsPolicyBuilder(CorsServiceBuilder builder, Predicate<String> originPredicate) {
super(originPredicate);
requireNonNull(builder, "builder");
serviceBuilder = builder;
}

ChainedCorsPolicyBuilder(CorsServiceBuilder builder, Pattern originRegex) {
super(originRegex);
requireNonNull(builder, "builder");
serviceBuilder = builder;
}

/**
* Returns the parent {@link CorsServiceBuilder}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ public CorsPolicy getPolicy(@Nullable String origin, RoutingContext routingConte
if (isNullOrigin && policy.isNullOriginAllowed() &&
isPathMatched(policy, routingContext)) {
return policy;
} else if (!isNullOrigin && policy.origins().contains(lowerCaseOrigin) &&
isPathMatched(policy, routingContext)) {
return policy;
} else if (!isNullOrigin && isPathMatched(policy, routingContext)) {
if (policy.origins().contains(lowerCaseOrigin) ||
(policy.originPredicate() != null && policy.originPredicate().test(origin))) {
return policy;
}
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

import static java.util.Objects.requireNonNull;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Pattern;

import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.annotation.DecoratorFactoryFunction;
Expand All @@ -31,7 +35,28 @@ public final class CorsDecoratorFactoryFunction implements DecoratorFactoryFunct
@Override
public Function<? super HttpService, ? extends HttpService> newDecorator(CorsDecorator parameter) {
requireNonNull(parameter, "parameter");
final CorsServiceBuilder cb = CorsService.builder(parameter.origins());
if (parameter.origins().length == 0 && parameter.originRegex().isEmpty()) {
throw new IllegalArgumentException("Either origins or originRegex must be configured");
}

final CorsServiceBuilder cb;
final List<String> origins = Arrays.asList(parameter.origins());
if (!origins.isEmpty() && origins.contains("*")) {
cb = CorsService.builderForAnyOrigin();
} else {
Predicate<String> originPredicate = (unused) -> false;
for (String origin: origins) {
originPredicate = originPredicate.or(Predicate.isEqual(origin));
}

if (!parameter.originRegex().isEmpty()) {
final Pattern pattern = Pattern.compile(parameter.originRegex());
originPredicate = originPredicate.or(pattern.asPredicate());
}

cb = CorsService.builder(originPredicate);
}

cb.firstPolicyBuilder.setConfig(parameter);

final Function<? super HttpService, CorsService> decorator = cb.newDecorator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.google.common.base.Joiner;
Expand Down Expand Up @@ -76,7 +78,30 @@ public static CorsPolicyBuilder builder(Iterable<String> origins) {
return new CorsPolicyBuilder(origins);
}

/**
* Returns a new {@link CorsPolicyBuilder} with origins matching the {@code predicate}.
*/
public static CorsPolicyBuilder builder(Predicate<String> predicate) {
return new CorsPolicyBuilder(predicate);
}

/**
* Returns a new {@link CorsPolicyBuilder} with origins matching the {@code regex}.
*/
public static CorsPolicyBuilder builderForOriginRegex(String regex) {
return builderForOriginRegex(Pattern.compile(regex));
}

/**
* Returns a new {@link CorsPolicyBuilder} with origins matching the {@code regex}.
*/
public static CorsPolicyBuilder builderForOriginRegex(Pattern regex) {
return new CorsPolicyBuilder(regex);
}

private final Set<String> origins;
@Nullable
private final Predicate<String> originPredicate;
private final List<Route> routes;
private final boolean credentialsAllowed;
private final boolean nullOriginAllowed;
Expand All @@ -90,12 +115,14 @@ public static CorsPolicyBuilder builder(Iterable<String> origins) {
private final String joinedAllowedRequestMethods;
private final Map<AsciiString, Supplier<?>> preflightResponseHeaders;

CorsPolicy(Set<String> origins, List<Route> routes, boolean credentialsAllowed, long maxAge,
CorsPolicy(Set<String> origins, @Nullable Predicate<String> originPredicate,
List<Route> routes, boolean credentialsAllowed, long maxAge,
boolean nullOriginAllowed, Set<AsciiString> exposedHeaders,
boolean allowAllRequestHeaders, Set<AsciiString> allowedRequestHeaders,
EnumSet<HttpMethod> allowedRequestMethods, boolean preflightResponseHeadersDisabled,
Map<AsciiString, Supplier<?>> preflightResponseHeaders) {
this.origins = ImmutableSet.copyOf(origins);
this.originPredicate = originPredicate;
this.routes = ImmutableList.copyOf(routes);
this.credentialsAllowed = credentialsAllowed;
this.maxAge = maxAge;
Expand Down Expand Up @@ -136,6 +163,14 @@ public Set<String> origins() {
return origins;
}

/**
* Returns predicate to match origins.
*/
@Nullable
public Predicate<String> originPredicate() {
return originPredicate;
}

/**
* Returns the list of {@link Route}s that this policy is supposed to be applied to.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
*/
package com.linecorp.armeria.server.cors;

import java.util.Collections;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.regex.Pattern;

import com.google.common.collect.ImmutableList;

Expand All @@ -36,7 +39,9 @@
*/
public final class CorsPolicyBuilder extends AbstractCorsPolicyBuilder {

CorsPolicyBuilder() {}
CorsPolicyBuilder() {
super(Collections.singletonList("*"));
}

CorsPolicyBuilder(String... origins) {
super(ImmutableList.copyOf(origins));
Expand All @@ -46,6 +51,14 @@ public final class CorsPolicyBuilder extends AbstractCorsPolicyBuilder {
super(ImmutableList.copyOf(origins));
}

CorsPolicyBuilder(Predicate<String> predicate) {
super(predicate);
}

CorsPolicyBuilder(Pattern regex) {
super(regex);
}

/**
* Returns a newly-created {@link CorsPolicy} based on the properties of this builder.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import static java.util.Objects.requireNonNull;

import java.util.List;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.slf4j.Logger;
Expand Down Expand Up @@ -84,6 +86,30 @@ public static CorsServiceBuilder builder(Iterable<String> origins) {
return new CorsServiceBuilder(copied);
}

/**
* Returns a new {@link CorsServiceBuilder} with origins matching the {@code originPredicate}.
*/
public static CorsServiceBuilder builder(Predicate<String> originPredicate) {
requireNonNull(originPredicate, "originPredicate");
return new CorsServiceBuilder(originPredicate);
}

/**
* Returns a new {@link CorsServiceBuilder} with origins matching the {@code originRegex}.
*/
public static CorsServiceBuilder builderForOriginRegex(String originRegex) {
requireNonNull(originRegex, "originRegex");
return builderForOriginRegex(Pattern.compile(originRegex));
}

/**
* Returns a new {@link CorsServiceBuilder} with origins matching the {@code originRegex}.
*/
public static CorsServiceBuilder builderForOriginRegex(Pattern originRegex) {
requireNonNull(originRegex, "originRegex");
return new CorsServiceBuilder(originRegex);
}

private final CorsConfig config;

CorsService(HttpService delegate, CorsConfig config) {
Expand Down

0 comments on commit 871e82f

Please sign in to comment.