diff --git a/auth/pom.xml b/auth/pom.xml index 42c6b25c3de..c15f57f5d14 100644 --- a/auth/pom.xml +++ b/auth/pom.xml @@ -21,6 +21,8 @@ 3.10 1.3.2 4.13 + 2.8.0 + 0.20.0 @@ -28,20 +30,24 @@ feast-common ${project.version} + + org.springframework + spring-context-support + net.devh grpc-server-spring-boot-starter - 2.4.0.RELEASE + ${grpc.spring.boot.starter.version} org.springframework.security spring-security-oauth2-resource-server - 5.3.0.RELEASE + ${spring.security.version} org.springframework.security spring-security-oauth2-jose - 5.3.0.RELEASE + ${spring.security.version} org.projectlombok @@ -56,10 +62,6 @@ com.fasterxml.jackson.core jackson-databind - - junit - junit - io.swagger swagger-annotations @@ -91,6 +93,71 @@ jsr305 3.0.2 + + org.springframework + spring-test + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + org.springframework.boot + spring-boot-starter-web + + + io.springfox + springfox-swagger2 + ${springfox-version} + + + io.springfox + springfox-swagger-ui + ${springfox-version} + + + javax.xml.bind + jaxb-api + 2.2.11 + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + org.openapitools + jackson-databind-nullable + 0.1.0 + + + + javax.validation + validation-api + + + org.springframework.boot + spring-boot-starter-test + test + + + org.junit.vintage + junit-vintage-engine + + + + + junit + junit + 4.12 + + + com.google.auth + google-auth-library-oauth2-http + ${google-auth-library-oauth2-http-version} + @@ -131,6 +198,10 @@ feast.auth.generated.client.api + + org.jacoco + jacoco-maven-plugin + diff --git a/auth/src/main/java/feast/auth/authorization/HttpAuthorizationProvider.java b/auth/src/main/java/feast/auth/authorization/HttpAuthorizationProvider.java index 6abe76f3b20..44c0a49634d 100644 --- a/auth/src/main/java/feast/auth/authorization/HttpAuthorizationProvider.java +++ b/auth/src/main/java/feast/auth/authorization/HttpAuthorizationProvider.java @@ -16,16 +16,17 @@ */ package 
feast.auth.authorization; +import feast.auth.config.CacheConfiguration; import feast.auth.generated.client.api.DefaultApi; import feast.auth.generated.client.invoker.ApiClient; import feast.auth.generated.client.invoker.ApiException; import feast.auth.generated.client.model.CheckAccessRequest; +import feast.auth.utils.AuthUtils; import java.util.Map; -import org.hibernate.validator.internal.constraintvalidators.bv.EmailValidator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.cache.annotation.Cacheable; import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.jwt.Jwt; /** * HTTPAuthorizationProvider uses an external HTTP service for authorizing requests. Please see @@ -41,7 +42,7 @@ public class HttpAuthorizationProvider implements AuthorizationProvider { * The default subject claim is the key within the Authentication object where the user's identity * can be found */ - private final String DEFAULT_SUBJECT_CLAIM = "email"; + private final String subjectClaim; /** * Initializes the HTTPAuthorizationProvider @@ -58,26 +59,29 @@ public HttpAuthorizationProvider(Map options) { ApiClient apiClient = new ApiClient(); apiClient.setBasePath(options.get("authorizationUrl")); this.defaultApiClient = new DefaultApi(apiClient); + subjectClaim = options.get("subjectClaim"); } /** - * Validates whether a user has access to a project + * Validates whether a user has access to a project. @Cacheable is using {@link + * CacheConfiguration} settings to cache output of the method {@link AuthorizationResult} for a + * specified duration set in cache settings. 
* * @param projectId Name of the Feast project * @param authentication Spring Security Authentication object * @return AuthorizationResult result of authorization query */ + @Cacheable(value = CacheConfiguration.AUTHORIZATION_CACHE, keyGenerator = "authKeyGenerator") public AuthorizationResult checkAccessToProject(String projectId, Authentication authentication) { CheckAccessRequest checkAccessRequest = new CheckAccessRequest(); Object context = getContext(authentication); - String subject = getSubjectFromAuth(authentication, DEFAULT_SUBJECT_CLAIM); + String subject = AuthUtils.getSubjectFromAuth(authentication, subjectClaim); String resource = "projects:" + projectId; checkAccessRequest.setAction("ALL"); checkAccessRequest.setContext(context); checkAccessRequest.setResource(resource); checkAccessRequest.setSubject(subject); - try { // Make authorization request to external service feast.auth.generated.client.model.AuthorizationResult authResult = @@ -112,31 +116,4 @@ private Object getContext(Authentication authentication) { // Not implemented yet, left empty return new Object(); } - - /** - * Get user email from their authentication object. 
- * - * @param authentication Spring Security Authentication object, used to extract user details - * @param subjectClaim Indicates the claim where the subject can be found - * @return String user email - */ - private String getSubjectFromAuth(Authentication authentication, String subjectClaim) { - Jwt principle = ((Jwt) authentication.getPrincipal()); - Map claims = principle.getClaims(); - String subjectValue = (String) claims.get(subjectClaim); - - if (subjectValue.isEmpty()) { - throw new IllegalStateException( - String.format("JWT does not have a valid claim %s.", subjectClaim)); - } - - if (subjectClaim.equals("email")) { - boolean validEmail = (new EmailValidator()).isValid(subjectValue, null); - if (!validEmail) { - throw new IllegalStateException("JWT contains an invalid email address"); - } - } - - return subjectValue; - } } diff --git a/auth/src/main/java/feast/auth/config/CacheConfiguration.java b/auth/src/main/java/feast/auth/config/CacheConfiguration.java new file mode 100644 index 00000000000..e8c46b3613c --- /dev/null +++ b/auth/src/main/java/feast/auth/config/CacheConfiguration.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.auth.config; + +import com.google.common.cache.CacheBuilder; +import feast.auth.utils.AuthUtils; +import java.lang.reflect.Method; +import java.util.concurrent.TimeUnit; +import lombok.Getter; +import lombok.Setter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.cache.Cache; +import org.springframework.cache.CacheManager; +import org.springframework.cache.annotation.CachingConfigurer; +import org.springframework.cache.annotation.EnableCaching; +import org.springframework.cache.concurrent.ConcurrentMapCache; +import org.springframework.cache.concurrent.ConcurrentMapCacheManager; +import org.springframework.cache.interceptor.CacheErrorHandler; +import org.springframework.cache.interceptor.CacheResolver; +import org.springframework.cache.interceptor.KeyGenerator; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.core.Authentication; + +/** CacheConfiguration class defines Cache settings for HttpAuthorizationProvider class. */ +@Configuration +@EnableCaching +@Setter +@Getter +public class CacheConfiguration implements CachingConfigurer { + + private static final int CACHE_SIZE = 10000; + + public static int TTL = 60; + + public static final String AUTHORIZATION_CACHE = "authorization"; + + @Autowired SecurityProperties secutiryProps; + + @Bean + public CacheManager cacheManager() { + ConcurrentMapCacheManager cacheManager = + new ConcurrentMapCacheManager(AUTHORIZATION_CACHE) { + + @Override + protected Cache createConcurrentMapCache(final String name) { + return new ConcurrentMapCache( + name, + CacheBuilder.newBuilder() + .expireAfterWrite(TTL, TimeUnit.SECONDS) + .maximumSize(CACHE_SIZE) + .build() + .asMap(), + false); + } + }; + + return cacheManager; + } + + /* + * KeyGenerator used by {@link Cacheable} for caching authorization requests. 
+ * Key format : checkAccessToProject-- + */ + @Bean + public KeyGenerator authKeyGenerator() { + return (Object target, Method method, Object... params) -> { + String projectId = (String) params[0]; + Authentication authentication = (Authentication) params[1]; + String subject = + AuthUtils.getSubjectFromAuth( + authentication, secutiryProps.getAuthorization().getOptions().get("subjectClaim")); + return String.format("%s-%s-%s", method.getName(), projectId, subject); + }; + } + + @Override + public CacheResolver cacheResolver() { + // TODO Auto-generated method stub + return null; + } + + @Override + public KeyGenerator keyGenerator() { + return null; + } + + @Override + public CacheErrorHandler errorHandler() { + // TODO Auto-generated method stub + return null; + } +} diff --git a/auth/src/main/java/feast/auth/config/SecurityConfig.java b/auth/src/main/java/feast/auth/config/SecurityConfig.java index f377c76a874..8229702b3ed 100644 --- a/auth/src/main/java/feast/auth/config/SecurityConfig.java +++ b/auth/src/main/java/feast/auth/config/SecurityConfig.java @@ -83,13 +83,13 @@ GrpcAuthenticationReader authenticationReader() { } /** - * Creates an AccessDecisionManager if authorization is enabled. This object determines the policy - * used to make authorization decisions. + * Creates an AccessDecisionManager if authentication is enabled. This object determines the + * policy used to make authentication decisions. 
* * @return AccessDecisionManager */ @Bean - @ConditionalOnProperty(prefix = "feast.security.authorization", name = "enabled") + @ConditionalOnProperty(prefix = "feast.security.authentication", name = "enabled") AccessDecisionManager accessDecisionManager() { final List> voters = new ArrayList<>(); voters.add(new AccessPredicateVoter()); diff --git a/auth/src/main/java/feast/auth/credentials/CoreAuthenticationProperties.java b/auth/src/main/java/feast/auth/credentials/CoreAuthenticationProperties.java new file mode 100644 index 00000000000..e307dfb1c83 --- /dev/null +++ b/auth/src/main/java/feast/auth/credentials/CoreAuthenticationProperties.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.auth.credentials; + +import feast.common.validators.OneOfStrings; +import java.util.Map; + +public class CoreAuthenticationProperties { + // needs to be set to true if authentication is enabled on core + private boolean enabled; + + // authentication provider to use + @OneOfStrings({"google", "oauth"}) + private String provider; + + // K/V options to initialize the provider. 
+ Map options; + + public boolean isEnabled() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getProvider() { + return provider; + } + + public void setProvider(String provider) { + this.provider = provider; + } + + public Map getOptions() { + return options; + } + + public void setOptions(Map options) { + this.options = options; + } +} diff --git a/auth/src/main/java/feast/auth/credentials/GoogleAuthCredentials.java b/auth/src/main/java/feast/auth/credentials/GoogleAuthCredentials.java new file mode 100644 index 00000000000..709b803ce08 --- /dev/null +++ b/auth/src/main/java/feast/auth/credentials/GoogleAuthCredentials.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.auth.credentials; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import com.google.auth.oauth2.IdTokenCredentials; +import com.google.auth.oauth2.ServiceAccountCredentials; +import io.grpc.CallCredentials; +import io.grpc.Metadata; +import io.grpc.Status; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.Executor; + +/* + * Google auth provider's callCredentials Implementation for serving. + * Used by CoreSpecService to connect to core. 
+ */ +public class GoogleAuthCredentials extends CallCredentials { + private final IdTokenCredentials credentials; + private static final String BEARER_TYPE = "Bearer"; + private static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); + + public GoogleAuthCredentials(Map options) throws IOException { + + String targetAudience = options.getOrDefault("audience", "https://localhost"); + ServiceAccountCredentials serviceCreds = + (ServiceAccountCredentials) + ServiceAccountCredentials.getApplicationDefault() + .createScoped(Arrays.asList("openid", "email")); + + credentials = + IdTokenCredentials.newBuilder() + .setIdTokenProvider(serviceCreds) + .setTargetAudience(targetAudience) + .build(); + } + + @Override + public void applyRequestMetadata( + RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) { + appExecutor.execute( + () -> { + try { + credentials.refreshIfExpired(); + Metadata headers = new Metadata(); + headers.put( + AUTHORIZATION_METADATA_KEY, + String.format("%s %s", BEARER_TYPE, credentials.getIdToken().getTokenValue())); + applier.apply(headers); + } catch (Throwable e) { + applier.fail(Status.UNAUTHENTICATED.withCause(e)); + } + }); + } + + @Override + public void thisUsesUnstableApi() { + // TODO Auto-generated method stub + + } +} diff --git a/auth/src/main/java/feast/auth/credentials/OAuthCredentials.java b/auth/src/main/java/feast/auth/credentials/OAuthCredentials.java new file mode 100644 index 00000000000..e7ad47f3778 --- /dev/null +++ b/auth/src/main/java/feast/auth/credentials/OAuthCredentials.java @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.auth.credentials; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import com.nimbusds.jose.util.JSONObjectUtils; +import io.grpc.CallCredentials; +import io.grpc.Metadata; +import io.grpc.Status; +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.Executor; +import javax.security.sasl.AuthenticationException; +import net.minidev.json.JSONObject; +import okhttp3.FormBody; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; + +/* + * Oauth Credentials Implementation for serving. + * Used by CoreSpecService to connect to core. 
+ */ +public class OAuthCredentials extends CallCredentials { + + private static final String JWK_ENDPOINT_URI = "jwkEndpointURI"; + static final String APPLICATION_JSON = "application/json"; + static final String CONTENT_TYPE = "content-type"; + static final String BEARER_TYPE = "Bearer"; + static final String GRANT_TYPE = "grant_type"; + static final String CLIENT_ID = "client_id"; + static final String CLIENT_SECRET = "client_secret"; + static final String AUDIENCE = "audience"; + static final String OAUTH_URL = "oauth_url"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); + + private OkHttpClient httpClient; + private Request request; + private String accessToken; + private Instant tokenExpiryTime; + private NimbusJwtDecoder jwtDecoder; + + public OAuthCredentials(Map options) { + this.httpClient = new OkHttpClient(); + if (!(options.containsKey(GRANT_TYPE) + && options.containsKey(CLIENT_ID) + && options.containsKey(AUDIENCE) + && options.containsKey(CLIENT_SECRET) + && options.containsKey(OAUTH_URL) + && options.containsKey(JWK_ENDPOINT_URI))) { + throw new AssertionError( + "please configure the properties:" + + " grant_type, client_id, client_secret, audience, oauth_url, jwkEndpointURI"); + } + RequestBody requestBody = + new FormBody.Builder() + .add(GRANT_TYPE, options.get(GRANT_TYPE)) + .add(CLIENT_ID, options.get(CLIENT_ID)) + .add(CLIENT_SECRET, options.get(CLIENT_SECRET)) + .add(AUDIENCE, options.get(AUDIENCE)) + .build(); + this.request = + new Request.Builder() + .url(options.get(OAUTH_URL)) + .addHeader(CONTENT_TYPE, APPLICATION_JSON) + .post(requestBody) + .build(); + this.jwtDecoder = NimbusJwtDecoder.withJwkSetUri(options.get(JWK_ENDPOINT_URI)).build(); + } + + @Override + public void thisUsesUnstableApi() { + // TODO Auto-generated method stub + + } + + @Override + public void applyRequestMetadata( + RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier) { + 
appExecutor.execute( + () -> { + try { + // Fetches new token if it is not available or if token has expired. + if (this.accessToken == null || Instant.now().isAfter(this.tokenExpiryTime)) { + Response response = httpClient.newCall(request).execute(); + if (!response.isSuccessful()) { + throw new AuthenticationException(response.message()); + } + JSONObject json = JSONObjectUtils.parse(response.body().string()); + this.accessToken = json.getAsString("access_token"); + this.tokenExpiryTime = jwtDecoder.decode(this.accessToken).getExpiresAt(); + } + Metadata headers = new Metadata(); + headers.put( + AUTHORIZATION_METADATA_KEY, String.format("%s %s", BEARER_TYPE, this.accessToken)); + applier.apply(headers); + } catch (Throwable e) { + applier.fail(Status.UNAUTHENTICATED.withCause(e)); + } + }); + } +} diff --git a/auth/src/main/java/feast/auth/service/AuthorizationService.java b/auth/src/main/java/feast/auth/service/AuthorizationService.java new file mode 100644 index 00000000000..24942611857 --- /dev/null +++ b/auth/src/main/java/feast/auth/service/AuthorizationService.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.auth.service; + +import feast.auth.authorization.AuthorizationProvider; +import feast.auth.authorization.AuthorizationResult; +import feast.auth.config.SecurityProperties; +import lombok.AllArgsConstructor; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.stereotype.Service; + +@AllArgsConstructor +@Service +public class AuthorizationService { + + private final SecurityProperties securityProperties; + private final AuthorizationProvider authorizationProvider; + + @Autowired + public AuthorizationService( + SecurityProperties securityProperties, + ObjectProvider authorizationProvider) { + this.securityProperties = securityProperties; + this.authorizationProvider = authorizationProvider.getIfAvailable(); + } + + /** + * Determine whether a user has access to a project. + * + * @param securityContext Spring Security Context used to identify a user or service. + * @param project Name of the project for which membership should be tested. 
+ */ + public void authorizeRequest(SecurityContext securityContext, String project) { + Authentication authentication = securityContext.getAuthentication(); + if (!this.securityProperties.getAuthorization().isEnabled()) { + return; + } + + AuthorizationResult result = + this.authorizationProvider.checkAccessToProject(project, authentication); + if (!result.isAllowed()) { + throw new AccessDeniedException(result.getFailureReason().orElse("Access Denied")); + } + } +} diff --git a/auth/src/main/java/feast/auth/utils/AuthUtils.java b/auth/src/main/java/feast/auth/utils/AuthUtils.java new file mode 100644 index 00000000000..d211165c86e --- /dev/null +++ b/auth/src/main/java/feast/auth/utils/AuthUtils.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.auth.utils; + +import java.util.Map; +import org.hibernate.validator.internal.constraintvalidators.bv.EmailValidator; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.jwt.Jwt; + +public class AuthUtils { + + // Suppresses default constructor, ensuring non-instantiability. + private AuthUtils() {} + + /** + * Get user email from their authentication object. 
+ * + * @param authentication Spring Security Authentication object, used to extract user details + * @param subjectClaim Indicates the claim where the subject can be found + * @return String user email + */ + public static String getSubjectFromAuth(Authentication authentication, String subjectClaim) { + Jwt principle = ((Jwt) authentication.getPrincipal()); + Map claims = principle.getClaims(); + String subjectValue = (String) claims.getOrDefault(subjectClaim, ""); + + if (subjectValue.isEmpty()) { + throw new IllegalStateException( + String.format("JWT does not have a valid claim %s.", subjectClaim)); + } + + if (subjectClaim.equals("email")) { + boolean validEmail = (new EmailValidator()).isValid(subjectValue, null); + if (!validEmail) { + throw new IllegalStateException("JWT contains an invalid email address"); + } + } + return subjectValue; + } +} diff --git a/auth/src/test/java/feast/auth/authorization/HttpAuthorizationProviderCachingTest.java b/auth/src/test/java/feast/auth/authorization/HttpAuthorizationProviderCachingTest.java new file mode 100644 index 00000000000..7d683470264 --- /dev/null +++ b/auth/src/test/java/feast/auth/authorization/HttpAuthorizationProviderCachingTest.java @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.auth.authorization; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import feast.auth.config.CacheConfiguration; +import feast.auth.config.SecurityProperties; +import feast.auth.config.SecurityProperties.AuthenticationProperties; +import feast.auth.config.SecurityProperties.AuthorizationProperties; +import feast.auth.generated.client.api.DefaultApi; +import feast.auth.generated.client.model.AuthorizationResult; +import feast.auth.generated.client.model.CheckAccessRequest; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.mockito.internal.util.reflection.FieldSetter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringRunner; + +@RunWith(SpringRunner.class) +@ContextConfiguration( + classes = {CacheConfiguration.class, HttpAuthorizationProviderCachingTest.Config.class}) +public class HttpAuthorizationProviderCachingTest { + + // static since field needs to updated in provider() bean + private static DefaultApi api = Mockito.mock(DefaultApi.class); + + @Autowired AuthorizationProvider provider; + + @Configuration + static class Config { + @Bean + SecurityProperties securityProps() { + // setting TTL static variable in SecurityProperties bean, since CacheConfiguration bean is + // dependent on SecurityProperties. 
+ CacheConfiguration.TTL = 1; + AuthenticationProperties authentication = Mockito.mock(AuthenticationProperties.class); + AuthorizationProperties authorization = new AuthorizationProperties(); + authorization.setEnabled(true); + authorization.setProvider("http"); + Map options = new HashMap<>(); + options.put("authorizationUrl", "localhost"); + options.put("subjectClaim", "email"); + authorization.setOptions(options); + SecurityProperties sp = new SecurityProperties(); + sp.setAuthentication(authentication); + sp.setAuthorization(authorization); + return sp; + } + + @Bean + AuthorizationProvider provider() throws NoSuchFieldException, SecurityException { + Map options = new HashMap<>(); + options.put("authorizationUrl", "localhost"); + options.put("subjectClaim", "email"); + HttpAuthorizationProvider provider = new HttpAuthorizationProvider(options); + FieldSetter.setField(provider, provider.getClass().getDeclaredField("defaultApiClient"), api); + return provider; + } + } + + @Test + public void testCheckAccessToProjectShouldReadFromCacheWhenAvailable() throws Exception { + Authentication auth = Mockito.mock(Authentication.class); + Jwt jwt = Mockito.mock(Jwt.class); + Map claims = new HashMap<>(); + claims.put("email", "test@test.com"); + doReturn(jwt).when(auth).getCredentials(); + doReturn(jwt).when(auth).getPrincipal(); + doReturn(claims).when(jwt).getClaims(); + doReturn("test_token").when(jwt).getTokenValue(); + AuthorizationResult authResult = new AuthorizationResult(); + authResult.setAllowed(true); + doReturn(authResult).when(api).checkAccessPost(any(CheckAccessRequest.class)); + + // Should save the result in cache + provider.checkAccessToProject("test", auth); + // Should read from cache + provider.checkAccessToProject("test", auth); + verify(api, times(1)).checkAccessPost(any(CheckAccessRequest.class)); + + // cache ttl is set to 1 second for testing. 
+ Thread.sleep(1100); + + // Should make an invocation to external service + provider.checkAccessToProject("test", auth); + verify(api, times(2)).checkAccessPost(any(CheckAccessRequest.class)); + // Should read from cache + provider.checkAccessToProject("test", auth); + verify(api, times(2)).checkAccessPost(any(CheckAccessRequest.class)); + } +} diff --git a/common/pom.xml b/common/pom.xml index db681090fe5..a8d652c0c6c 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -44,11 +44,15 @@ - + dev.feast datatypes-java ${project.version} compile + + + com.google.protobuf + protobuf-java-util @@ -58,13 +62,54 @@ javax.validation validation-api - 2.0.0.Final + + com.google.auto.value + auto-value-annotations + + + com.google.auto.value + auto-value + + + com.google.code.gson + gson + + + net.devh + grpc-server-spring-boot-starter + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.security + spring-security-core + + + org.springframework.boot + spring-boot-starter-data-jpa + + + + + org.slf4j + slf4j-api + + + junit junit - 4.12 + test + + + org.hamcrest + hamcrest-library test - \ No newline at end of file + diff --git a/common/src/main/java/feast/common/interceptors/GrpcMessageInterceptor.java b/common/src/main/java/feast/common/interceptors/GrpcMessageInterceptor.java new file mode 100644 index 00000000000..53dec6a0294 --- /dev/null +++ b/common/src/main/java/feast/common/interceptors/GrpcMessageInterceptor.java @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.common.interceptors; + +import com.google.protobuf.Empty; +import com.google.protobuf.Message; +import feast.common.logging.AuditLogger; +import feast.common.logging.entry.MessageAuditLogEntry; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import org.slf4j.event.Level; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; + +/** + * GrpcMessageInterceptor intercepts a GRPC calls to log handling of GRPC messages to the Audit Log. + * Intercepts the incoming and outgoing messages logs them to the audit log, together with method + * name and assumed authenticated identity (if authentication is enabled). NOTE: + * GrpcMessageInterceptor assumes that all service calls are unary (ie single request/response). + */ +public class GrpcMessageInterceptor implements ServerInterceptor { + @Override + public Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + MessageAuditLogEntry.Builder entryBuilder = MessageAuditLogEntry.newBuilder(); + // default response message to empty proto in log entry. 
+ entryBuilder.setResponse(Empty.newBuilder().build()); + + // Unpack service & method name from call + // full method name is in format ./ + String fullMethodName = call.getMethodDescriptor().getFullMethodName(); + entryBuilder.setService( + fullMethodName.substring(fullMethodName.lastIndexOf(".") + 1, fullMethodName.indexOf("/"))); + entryBuilder.setMethod(fullMethodName.substring(fullMethodName.indexOf("/") + 1)); + + // Attempt Extract current authenticated identity. + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + String identity = (authentication == null) ? "" : authentication.getName(); + entryBuilder.setIdentity(identity); + + // Register forwarding call to intercept outgoing response and log to audit log + call = + new SimpleForwardingServerCall(call) { + @Override + public void sendMessage(RespT message) { + // 2. Track the response & Log entry to audit logger + super.sendMessage(message); + entryBuilder.setResponse((Message) message); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + // 3. Log the message log entry to the audit log + Level logLevel = (status.isOk()) ? Level.INFO : Level.ERROR; + entryBuilder.setStatusCode(status.getCode()); + AuditLogger.logMessage(logLevel, entryBuilder); + } + }; + + ServerCall.Listener listener = next.startCall(call, headers); + return new SimpleForwardingServerCallListener(listener) { + @Override + // Register listener to intercept incoming request messages and log to audit log + public void onMessage(ReqT message) { + super.onMessage(message); + // 1. Track the request. 
+ entryBuilder.setRequest((Message) message); + } + }; + } +} diff --git a/common/src/main/java/feast/common/logging/AuditLogger.java b/common/src/main/java/feast/common/logging/AuditLogger.java new file mode 100644 index 00000000000..4e779a75aaf --- /dev/null +++ b/common/src/main/java/feast/common/logging/AuditLogger.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.common.logging; + +import feast.common.logging.config.LoggingProperties; +import feast.common.logging.config.LoggingProperties.AuditLogProperties; +import feast.common.logging.entry.ActionAuditLogEntry; +import feast.common.logging.entry.AuditLogEntry; +import feast.common.logging.entry.AuditLogEntryKind; +import feast.common.logging.entry.LogResource; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.common.logging.entry.MessageAuditLogEntry; +import feast.common.logging.entry.TransitionAuditLogEntry; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Marker; +import org.slf4j.MarkerFactory; +import org.slf4j.event.Level; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.info.BuildProperties; +import org.springframework.stereotype.Component; + +@Slf4j +@Component +public class AuditLogger { + private static final Marker AUDIT_MARKER = MarkerFactory.getMarker("AUDIT_MARK"); + private static 
AuditLogProperties properties; + private static BuildProperties buildProperties; + + @Autowired + public AuditLogger(LoggingProperties loggingProperties, BuildProperties buildProperties) { + // Spring runs this constructor when creating the AuditLogger bean, + // which allows us to populate the AuditLogger class with dependencies. + // This allows us to use the dependencies in the AuditLogger's static methods + AuditLogger.properties = loggingProperties.getAudit(); + AuditLogger.buildProperties = buildProperties; + } + + /** + * Log the handling of a Protobuf message by a service call. + * + * @param entryBuilder with all fields set except instance. + */ + public static void logMessage(Level level, MessageAuditLogEntry.Builder entryBuilder) { + log( + level, + entryBuilder + .setComponent(buildProperties.getArtifact()) + .setVersion(buildProperties.getVersion()) + .build()); + } + + /** + * Log an action being taken on a specific resource + * + * @param level describing the severity of the log. + * @param action name of the action being taken on specific resource. + * @param resourceType the type of resource being logged. + * @param resourceId resource specific identifier identifing the instance of the resource. + */ + public static void logAction( + Level level, String action, ResourceType resourceType, String resourceId) { + log( + level, + ActionAuditLogEntry.of( + buildProperties.getArtifact(), + buildProperties.getArtifact(), + LogResource.of(resourceType, resourceId), + action)); + } + + /** + * Log a transition in state/status in a specific resource. + * + * @param level describing the severity of the log. + * @param status name of end status which the resource transition to. + * @param resourceType the type of resource being logged. + * @param resourceId resource specific identifier identifing the instance of the resource. 
+ */ + public static void logTransition( + Level level, String status, ResourceType resourceType, String resourceId) { + log( + level, + TransitionAuditLogEntry.of( + buildProperties.getArtifact(), + buildProperties.getArtifact(), + LogResource.of(resourceType, resourceId), + status)); + } + + /** + * Log given {@link AuditLogEntry} at the given logging {@link Level} to the Audit log. + * + * @param level describing the severity of the log. + * @param entry the {@link AuditLogEntry} to push to the audit log. + */ + private static void log(Level level, AuditLogEntry entry) { + // Check if audit logging is of this specific log entry enabled. + if (!properties.isEnabled()) { + return; + } + if (entry.getKind().equals(AuditLogEntryKind.MESSAGE) + && !properties.isMessageLoggingEnabled()) { + return; + } + + // Log event to audit log through enabled formats + String entryJSON = entry.toJSON(); + switch (level) { + case TRACE: + log.trace(AUDIT_MARKER, entryJSON); + break; + case DEBUG: + log.debug(AUDIT_MARKER, entryJSON); + break; + case INFO: + log.info(AUDIT_MARKER, entryJSON); + break; + case WARN: + log.warn(AUDIT_MARKER, entryJSON); + break; + case ERROR: + log.error(AUDIT_MARKER, entryJSON); + break; + } + } +} diff --git a/core/src/main/java/feast/core/job/TerminateJobTask.java b/common/src/main/java/feast/common/logging/config/LoggingProperties.java similarity index 53% rename from core/src/main/java/feast/core/job/TerminateJobTask.java rename to common/src/main/java/feast/common/logging/config/LoggingProperties.java index c408578a3bd..54932c9ca8f 100644 --- a/core/src/main/java/feast/core/job/TerminateJobTask.java +++ b/common/src/main/java/feast/common/logging/config/LoggingProperties.java @@ -1,6 +1,6 @@ /* * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors + * Copyright 2018-2019 The Feast Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. @@ -14,30 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.core.job; +package feast.common.logging.config; -import feast.core.log.Action; -import feast.core.model.Job; -import lombok.Builder; +import javax.validation.constraints.NotNull; import lombok.Getter; import lombok.Setter; -/** Task to terminate given {@link Job} by using {@link JobManager} */ @Getter @Setter -@Builder(setterPrefix = "set") -public class TerminateJobTask implements JobTask { - private Job job; - private JobManager jobManager; +public class LoggingProperties { + @NotNull private AuditLogProperties audit; - @Override - public Job call() { - JobTask.logAudit( - Action.ABORT, - job, - "Aborting job %s for runner %s", - job.getId(), - jobManager.getRunnerType().toString()); - return jobManager.abortJob(job); + @Getter + @Setter + public static class AuditLogProperties { + // Whether to enable/disable audit logging entirely. + private boolean enabled; + + // Whether to enable/disable message level (ie request/response) audit logging. + private boolean messageLoggingEnabled; } } diff --git a/common/src/main/java/feast/common/logging/entry/ActionAuditLogEntry.java b/common/src/main/java/feast/common/logging/entry/ActionAuditLogEntry.java new file mode 100644 index 00000000000..cec85b736a6 --- /dev/null +++ b/common/src/main/java/feast/common/logging/entry/ActionAuditLogEntry.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.common.logging.entry;
+
+import com.google.auto.value.AutoValue;
+
+/** ActionAuditLogEntry records an action being taken on a specific resource */
+@AutoValue
+public abstract class ActionAuditLogEntry extends AuditLogEntry {
+ /** The name of the action taken on the resource. */
+ public abstract String getAction();
+
+ /** The target resource on which the action was taken. */
+ public abstract LogResource getResource();
+
+ /**
+ * Create an {@link AuditLogEntry} that records an action being taken on a specific resource.
+ *
+ * @param component The name of the Feast component producing this {@link AuditLogEntry}.
+ * @param version The version of Feast producing this {@link AuditLogEntry}.
+ * @param resource The target resource on which the action was taken.
+ * @param action The name of the action being taken on the given resource. 
+ */
+ public static ActionAuditLogEntry of(
+ String component, String version, LogResource resource, String action) {
+ return new AutoValue_ActionAuditLogEntry(
+ component, version, AuditLogEntryKind.ACTION, action, resource);
+ }
+} diff --git a/common/src/main/java/feast/common/logging/entry/AuditLogEntry.java b/common/src/main/java/feast/common/logging/entry/AuditLogEntry.java new file mode 100644 index 00000000000..9aa8fcb8c5c --- /dev/null +++ b/common/src/main/java/feast/common/logging/entry/AuditLogEntry.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.common.logging.entry;
+
+import com.google.gson.Gson;
+
+/**
+ * AuditLogEntry represents a single audit log entry. An audit log entry can be converted into a string
+ * with {@link #toString()} for a human readable representation, or into structured JSON with {@link
+ * #toJSON()} for a machine parsable representation. 
+ */ +public abstract class AuditLogEntry { + /** Declare Log Type to allow external Logging systems to filter out {@link AuditLogEntry} */ + public final String logType = "FeastAuditLogEntry"; + + public final String application = "Feast"; + + /** The name of the Feast component producing this {@link AuditLogEntry} */ + public abstract String getComponent(); + + /** The version of Feast producing this {@link AuditLogEntry} */ + public abstract String getVersion(); + + public abstract AuditLogEntryKind getKind(); + + /** Return a structured JSON representation of this {@link AuditLogEntry} */ + public String toJSON() { + Gson gson = new Gson(); + return gson.toJson(this); + } +} diff --git a/core/src/main/java/feast/core/log/Resource.java b/common/src/main/java/feast/common/logging/entry/AuditLogEntryKind.java similarity index 78% rename from core/src/main/java/feast/core/log/Resource.java rename to common/src/main/java/feast/common/logging/entry/AuditLogEntryKind.java index d8e484b3885..d673f6bdb30 100644 --- a/core/src/main/java/feast/core/log/Resource.java +++ b/common/src/main/java/feast/common/logging/entry/AuditLogEntryKind.java @@ -14,13 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package feast.core.log; +package feast.common.logging.entry; -/** Resources interacted with, for audit logging purposes */ -public enum Resource { - FEATURE, - FEATURE_GROUP, - ENTITY, - STORAGE, - JOB +/** AuditLogEntryKind lists the various kinds of {@link AuditLogEntry} */ +public enum AuditLogEntryKind { + MESSAGE, + ACTION, + TRANSITION, } diff --git a/core/src/main/java/feast/core/job/JobTask.java b/common/src/main/java/feast/common/logging/entry/LogResource.java similarity index 53% rename from core/src/main/java/feast/core/job/JobTask.java rename to common/src/main/java/feast/common/logging/entry/LogResource.java index c7809c4ab82..02e7589f976 100644 --- a/core/src/main/java/feast/core/job/JobTask.java +++ b/common/src/main/java/feast/common/logging/entry/LogResource.java @@ -1,6 +1,6 @@ /* * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors + * Copyright 2018-2019 The Feast Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.core.job; +package feast.common.logging.entry; -import feast.core.log.Action; -import feast.core.log.AuditLogger; -import feast.core.log.Resource; -import feast.core.model.Job; -import java.util.concurrent.Callable; +import com.google.auto.value.AutoValue; -public interface JobTask extends Callable { - static void logAudit(Action action, Job job, String detail, Object... 
args) { - AuditLogger.log(Resource.JOB, job.getId(), action, detail, args); +@AutoValue +/** + * LogResource is used in {@link AuditLogEntry} to reference a specific resource as the subject of + * the log + */ +public abstract class LogResource { + public enum ResourceType { + JOB, + FEATURE_SET, } - @Override - Job call() throws RuntimeException; + public abstract ResourceType getType(); + + public abstract String getId(); + + public static LogResource of(ResourceType type, String id) { + return new AutoValue_LogResource(type, id); + } } diff --git a/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java b/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java new file mode 100644 index 00000000000..745cc1283ae --- /dev/null +++ b/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.common.logging.entry; + +import com.google.auto.value.AutoValue; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; +import com.google.protobuf.Empty; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.util.JsonFormat; +import io.grpc.Status.Code; +import java.lang.reflect.Type; +import java.util.UUID; + +/** MessageAuditLogEntry records the handling of a Protobuf message by a service call. */ +@AutoValue +public abstract class MessageAuditLogEntry extends AuditLogEntry { + /** Id used to identify the service call that the log entry is recording */ + public abstract UUID getId(); + + /** The name of the service that was used to handle the service call. */ + public abstract String getService(); + + /** The name of the method that was used to handle the service call. */ + public abstract String getMethod(); + + /** The request Protobuf {@link Message} that was passed to the Service in the service call. */ + public abstract Message getRequest(); + + /** + * The response Protobuf {@link Message} that was passed to the Service in the service call. May + * be an {@link Empty} protobuf no request could be collected due to an error. + */ + public abstract Message getResponse(); + + /** + * The authenticated identity that was assumed during the handling of the service call. For + * example, the user id or email that identifies the user making the call. Empty if the service + * call is not authenticated. + */ + public abstract String getIdentity(); + + /** The result status code of the service call. 
*/ + public abstract Code getStatusCode(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setId(UUID id); + + public abstract Builder setComponent(String component); + + public abstract Builder setVersion(String component); + + public abstract Builder setKind(AuditLogEntryKind kind); + + public abstract Builder setService(String name); + + public abstract Builder setMethod(String name); + + public abstract Builder setRequest(Message request); + + public abstract Builder setResponse(Message response); + + public abstract Builder setIdentity(String identity); + + public abstract Builder setStatusCode(Code statusCode); + + public abstract MessageAuditLogEntry build(); + } + + public static MessageAuditLogEntry.Builder newBuilder() { + return new AutoValue_MessageAuditLogEntry.Builder() + .setKind(AuditLogEntryKind.MESSAGE) + .setId(UUID.randomUUID()); + } + + @Override + public String toJSON() { + // GSON requires custom typeadapter (serializer) to convert Protobuf messages to JSON properly + Gson gson = + new GsonBuilder() + .registerTypeAdapter( + Message.class, + new JsonSerializer() { + @Override + public JsonElement serialize( + Message message, Type type, JsonSerializationContext context) { + try { + String messageJSON = JsonFormat.printer().print(message); + return new JsonParser().parse(messageJSON); + } catch (InvalidProtocolBufferException e) { + + throw new RuntimeException( + "Unexpected exception converting Protobuf to JSON", e); + } + } + }) + .create(); + return gson.toJson(this); + } +} diff --git a/common/src/main/java/feast/common/logging/entry/TransitionAuditLogEntry.java b/common/src/main/java/feast/common/logging/entry/TransitionAuditLogEntry.java new file mode 100644 index 00000000000..0f139b7bdbd --- /dev/null +++ b/common/src/main/java/feast/common/logging/entry/TransitionAuditLogEntry.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * 
Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.common.logging.entry;
+
+import com.google.auto.value.AutoValue;
+
+/** TransitionAuditLogEntry records a transition in state/status in a specific resource. */
+@AutoValue
+public abstract class TransitionAuditLogEntry extends AuditLogEntry {
+ /** The resource on which the state/status transition occurred. */
+ public abstract LogResource getResource();
+
+ /** The end status to which the resource transitioned. */
+ public abstract String getStatus();
+
+ /**
+ * Construct a new {@link AuditLogEntry} to record a transition in state/status in a specific
+ * resource.
+ *
+ * @param component The name of the Feast component producing this {@link AuditLogEntry}.
+ * @param version The version of Feast producing this {@link AuditLogEntry}.
+ * @param resource the resource on which the transition occurred
+ * @param status the end status to which the resource transitioned. 
+ */ + public static TransitionAuditLogEntry of( + String component, String version, LogResource resource, String status) { + return new AutoValue_TransitionAuditLogEntry( + component, version, AuditLogEntryKind.TRANSITION, resource, status); + } +} diff --git a/common/src/main/resources/log4j2.xml b/common/src/main/resources/log4j2.xml new file mode 100644 index 00000000000..c75c2db13cc --- /dev/null +++ b/common/src/main/resources/log4j2.xml @@ -0,0 +1,48 @@ + + + + + + + %d{yyyy-MM-dd HH:mm:ss.SSS} %5p ${hostName} --- [%15.15t] %-40.40c{1.} : %m%n%ex + + + {"time":"%d{yyyy-MM-dd'T'HH:mm:ssXXX}","hostname":"${hostName}","severity":"%p","message":%m}%n%ex + + + + + + + + + + + + + + + + + + + + + + + diff --git a/common/src/test/java/feast/common/logging/entry/AuditLogEntryTest.java b/common/src/test/java/feast/common/logging/entry/AuditLogEntryTest.java new file mode 100644 index 00000000000..a332e0be799 --- /dev/null +++ b/common/src/test/java/feast/common/logging/entry/AuditLogEntryTest.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.common.logging.entry; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.proto.serving.ServingAPIProto.FeatureReference; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues; +import feast.proto.types.ValueProto.Value; +import io.grpc.Status; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; + +public class AuditLogEntryTest { + public List getTestAuditLogs() { + GetOnlineFeaturesRequest requestSpec = + GetOnlineFeaturesRequest.newBuilder() + .setOmitEntitiesInResponse(false) + .addAllFeatures( + Arrays.asList( + FeatureReference.newBuilder().setName("feature1").build(), + FeatureReference.newBuilder().setName("feature2").build())) + .build(); + + GetOnlineFeaturesResponse responseSpec = + GetOnlineFeaturesResponse.newBuilder() + .addAllFieldValues( + Arrays.asList( + FieldValues.newBuilder() + .putFields("feature", Value.newBuilder().setInt32Val(32).build()) + .build(), + FieldValues.newBuilder() + .putFields("feature2", Value.newBuilder().setInt32Val(64).build()) + .build())) + .build(); + + return Arrays.asList( + MessageAuditLogEntry.newBuilder() + .setComponent("feast-serving") + .setVersion("0.6") + .setService("ServingService") + .setMethod("getOnlineFeatures") + .setRequest(requestSpec) + .setResponse(responseSpec) + .setStatusCode(Status.OK.getCode()) + .setIdentity("adam@no.such.email") + .build(), + ActionAuditLogEntry.of( + "core", "0.6", LogResource.of(ResourceType.JOB, "kafka-to-redis"), "CREATE"), + TransitionAuditLogEntry.of( + "core", + "0.6", + LogResource.of(ResourceType.FEATURE_SET, "project/feature_set"), + "READY")); + } + + @Test + 
public void shouldReturnJSONRepresentationOfAuditLog() { + for (AuditLogEntry auditLog : getTestAuditLogs()) { + // Check that auditLog's toJSON() returns valid JSON + String logJSON = auditLog.toJSON(); + System.out.println(logJSON); + JsonParser parser = new JsonParser(); + + // check basic fields are present in JSON representation. + JsonObject logObject = parser.parse(logJSON).getAsJsonObject(); + assertThat(logObject.getAsJsonPrimitive("logType").getAsString(), equalTo(auditLog.logType)); + assertThat( + logObject.getAsJsonPrimitive("kind").getAsString(), equalTo(auditLog.getKind().name())); + } + } +} diff --git a/core/pom.xml b/core/pom.xml index 4a0d3791c92..a7ec3374737 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -75,6 +75,13 @@ dev.feast feast-ingestion ${project.version} + + + + org.slf4j + slf4j-simple + + dev.feast @@ -100,7 +107,7 @@ javax.inject 1 - + org.springframework.boot spring-boot-starter-web @@ -117,37 +124,42 @@ org.springframework.security spring-security-core - 5.3.0.RELEASE + ${spring.security.version} org.springframework.security spring-security-config - 5.3.0.RELEASE + ${spring.security.version} org.springframework.security.oauth spring-security-oauth2 - 2.4.0.RELEASE + ${spring.security.oauth2.version} org.springframework.security spring-security-oauth2-client - 5.3.0.RELEASE + ${spring.security.version} org.springframework.security spring-security-web - 5.3.0.RELEASE + ${spring.security.version} + + + org.springframework.security + spring-security-oauth2-resource-server + ${spring.security.version} org.springframework.security spring-security-oauth2-jose - 5.3.0.RELEASE - + ${spring.security.version} + net.devh grpc-server-spring-boot-starter - 2.4.0.RELEASE + ${grpc.spring.boot.starter.version} com.nimbusds @@ -157,14 +169,14 @@ org.springframework.security spring-security-oauth2-core - 5.3.0.RELEASE + ${spring.security.version} - + org.springframework.boot spring-boot-starter-data-jpa - + org.springframework.boot 
spring-boot-starter-actuator @@ -175,17 +187,17 @@ org.springframework.boot spring-boot-configuration-processor - + io.grpc grpc-services - + io.grpc grpc-stub - + com.google.protobuf protobuf-java-util @@ -288,8 +300,6 @@ javax.xml.bind jaxb-api - - org.flywaydb flyway-core @@ -300,11 +310,10 @@ hibernate-validator-annotation-processor 6.1.2.Final - org.mockito mockito-core - 2.23.0 + ${mockito.version} test diff --git a/core/src/main/java/feast/core/config/CoreSecurityConfig.java b/core/src/main/java/feast/core/config/CoreSecurityConfig.java index 6689db60c1d..3e4c2baa9eb 100644 --- a/core/src/main/java/feast/core/config/CoreSecurityConfig.java +++ b/core/src/main/java/feast/core/config/CoreSecurityConfig.java @@ -28,7 +28,7 @@ @Configuration @Slf4j -@ComponentScan("feast.auth") +@ComponentScan(basePackages = {"feast.auth.config", "feast.auth.service"}) public class CoreSecurityConfig { /** diff --git a/core/src/main/java/feast/core/config/FeastProperties.java b/core/src/main/java/feast/core/config/FeastProperties.java index 799000631d5..5beb18d7377 100644 --- a/core/src/main/java/feast/core/config/FeastProperties.java +++ b/core/src/main/java/feast/core/config/FeastProperties.java @@ -17,6 +17,9 @@ package feast.core.config; import feast.auth.config.SecurityProperties; +import feast.auth.config.SecurityProperties.AuthenticationProperties; +import feast.auth.config.SecurityProperties.AuthorizationProperties; +import feast.common.logging.config.LoggingProperties; import feast.common.validators.OneOfStrings; import feast.core.config.FeastProperties.StreamProperties.FeatureStreamOptions; import java.net.InetAddress; @@ -41,11 +44,13 @@ import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.info.BuildProperties; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; @Getter @Setter @Configuration 
+@ComponentScan("feast.common.logging") @ConfigurationProperties(prefix = "feast", ignoreInvalidFields = true) public class FeastProperties { @@ -72,11 +77,19 @@ public FeastProperties() {} /* Feast Kafka stream properties */ private StreamProperties stream; - private SecurityProperties security; + @NotNull private SecurityProperties security; @Bean SecurityProperties securityProperties() { - return this.getSecurity(); + return getSecurity(); + } + + /* Feast Audit Logging properties */ + @NotNull private LoggingProperties logging; + + @Bean + LoggingProperties loggingProperties() { + return getLogging(); } /** Feast job properties. These properties are used for ingestion jobs. */ @@ -278,5 +291,19 @@ public void validate() { + e.getMessage()); } } + + // Validate AuthenticationProperties + Set> authenticationPropsViolations = + validator.validate(getSecurity().getAuthentication()); + if (!authenticationPropsViolations.isEmpty()) { + throw new ConstraintViolationException(authenticationPropsViolations); + } + + // Validate AuthorizationProperties + Set> authorizationPropsViolations = + validator.validate(getSecurity().getAuthorization()); + if (!authorizationPropsViolations.isEmpty()) { + throw new ConstraintViolationException(authorizationPropsViolations); + } } } diff --git a/core/src/main/java/feast/core/grpc/CoreServiceImpl.java b/core/src/main/java/feast/core/grpc/CoreServiceImpl.java index d014f582a3f..19d73f8f8f8 100644 --- a/core/src/main/java/feast/core/grpc/CoreServiceImpl.java +++ b/core/src/main/java/feast/core/grpc/CoreServiceImpl.java @@ -18,12 +18,14 @@ import com.google.api.gax.rpc.InvalidArgumentException; import com.google.protobuf.InvalidProtocolBufferException; +import feast.auth.service.AuthorizationService; +import feast.common.interceptors.GrpcMessageInterceptor; import feast.core.config.FeastProperties; import feast.core.exception.RetrievalException; import feast.core.grpc.interceptors.MonitoringInterceptor; import feast.core.model.Project; 
-import feast.core.service.AccessManagementService; import feast.core.service.JobService; +import feast.core.service.ProjectService; import feast.core.service.SpecService; import feast.core.service.StatsService; import feast.proto.core.CoreServiceGrpc.CoreServiceImplBase; @@ -37,31 +39,35 @@ import lombok.extern.slf4j.Slf4j; import net.devh.boot.grpc.server.service.GrpcService; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.context.SecurityContextHolder; /** Implementation of the feast core GRPC service. */ @Slf4j -@GrpcService(interceptors = {MonitoringInterceptor.class}) +@GrpcService(interceptors = {GrpcMessageInterceptor.class, MonitoringInterceptor.class}) public class CoreServiceImpl extends CoreServiceImplBase { private final FeastProperties feastProperties; private SpecService specService; private JobService jobService; private StatsService statsService; - private AccessManagementService accessManagementService; + private ProjectService projectService; + private final AuthorizationService authorizationService; @Autowired public CoreServiceImpl( SpecService specService, - AccessManagementService accessManagementService, + ProjectService projectService, StatsService statsService, JobService jobService, - FeastProperties feastProperties) { + FeastProperties feastProperties, + AuthorizationService authorizationService) { this.specService = specService; - this.accessManagementService = accessManagementService; + this.projectService = projectService; this.jobService = jobService; this.feastProperties = feastProperties; this.statsService = statsService; + this.authorizationService = authorizationService; } @Override @@ -178,10 +184,11 @@ public void listStores( public void applyFeatureSet( ApplyFeatureSetRequest request, StreamObserver responseObserver) { - accessManagementService.checkIfProjectMember( - SecurityContextHolder.getContext(), 
request.getFeatureSet().getSpec().getProject()); + String projectId = null; try { + projectId = request.getFeatureSet().getSpec().getProject(); + authorizationService.authorizeRequest(SecurityContextHolder.getContext(), projectId); ApplyFeatureSetResponse response = specService.applyFeatureSet(request.getFeatureSet()); responseObserver.onNext(response); responseObserver.onCompleted(); @@ -192,6 +199,13 @@ public void applyFeatureSet( e); responseObserver.onError( Status.ALREADY_EXISTS.withDescription(e.getMessage()).withCause(e).asRuntimeException()); + } catch (AccessDeniedException e) { + log.info(String.format("User prevented from accessing project: %s", projectId)); + responseObserver.onError( + Status.PERMISSION_DENIED + .withDescription(e.getMessage()) + .withCause(e) + .asRuntimeException()); } catch (Exception e) { log.error("Exception has occurred in ApplyFeatureSet method: ", e); responseObserver.onError( @@ -217,7 +231,7 @@ public void updateStore( public void createProject( CreateProjectRequest request, StreamObserver responseObserver) { try { - accessManagementService.createProject(request.getName()); + projectService.createProject(request.getName()); responseObserver.onNext(CreateProjectResponse.getDefaultInstance()); responseObserver.onCompleted(); } catch (Exception e) { @@ -230,12 +244,11 @@ public void createProject( @Override public void archiveProject( ArchiveProjectRequest request, StreamObserver responseObserver) { - - accessManagementService.checkIfProjectMember( - SecurityContextHolder.getContext(), request.getName()); - + String projectId = null; try { - accessManagementService.archiveProject(request.getName()); + projectId = request.getName(); + authorizationService.authorizeRequest(SecurityContextHolder.getContext(), projectId); + projectService.archiveProject(projectId); responseObserver.onNext(ArchiveProjectResponse.getDefaultInstance()); responseObserver.onCompleted(); } catch (IllegalArgumentException e) { @@ -249,6 +262,13 @@ public 
void archiveProject( log.error("Attempted to archive an unsupported project:", e); responseObserver.onError( Status.UNIMPLEMENTED.withDescription(e.getMessage()).withCause(e).asRuntimeException()); + } catch (AccessDeniedException e) { + log.info(String.format("User prevented from accessing project: %s", projectId)); + responseObserver.onError( + Status.PERMISSION_DENIED + .withDescription(e.getMessage()) + .withCause(e) + .asRuntimeException()); } catch (Exception e) { log.error("Exception has occurred in the createProject method: ", e); responseObserver.onError( @@ -260,7 +280,7 @@ public void archiveProject( public void listProjects( ListProjectsRequest request, StreamObserver responseObserver) { try { - List projects = accessManagementService.listProjects(); + List projects = projectService.listProjects(); responseObserver.onNext( ListProjectsResponse.newBuilder() .addAllProjects(projects.stream().map(Project::getName).collect(Collectors.toList())) diff --git a/core/src/main/java/feast/core/grpc/HealthServiceImpl.java b/core/src/main/java/feast/core/grpc/HealthServiceImpl.java index b83a05b7f03..0a1f10109ca 100644 --- a/core/src/main/java/feast/core/grpc/HealthServiceImpl.java +++ b/core/src/main/java/feast/core/grpc/HealthServiceImpl.java @@ -16,7 +16,7 @@ */ package feast.core.grpc; -import feast.core.service.AccessManagementService; +import feast.core.service.ProjectService; import io.grpc.Status; import io.grpc.health.v1.HealthGrpc.HealthImplBase; import io.grpc.health.v1.HealthProto.HealthCheckRequest; @@ -30,18 +30,18 @@ @Slf4j @GrpcService public class HealthServiceImpl extends HealthImplBase { - private final AccessManagementService accessManagementService; + private final ProjectService projectService; @Autowired - public HealthServiceImpl(AccessManagementService accessManagementService) { - this.accessManagementService = accessManagementService; + public HealthServiceImpl(ProjectService projectService) { + this.projectService = projectService; } 
@Override public void check( HealthCheckRequest request, StreamObserver responseObserver) { try { - accessManagementService.listProjects(); + projectService.listProjects(); responseObserver.onNext( HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build()); responseObserver.onCompleted(); diff --git a/core/src/main/java/feast/core/job/CreateJobTask.java b/core/src/main/java/feast/core/job/CreateJobTask.java deleted file mode 100644 index a0d8d3a1d52..00000000000 --- a/core/src/main/java/feast/core/job/CreateJobTask.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.core.job; - -import feast.core.log.Action; -import feast.core.model.Job; -import feast.core.model.JobStatus; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** Task that starts recently created {@link Job} by using {@link JobManager}. 
*/ -@Getter -@Setter -@Builder(setterPrefix = "set") -public class CreateJobTask implements JobTask { - final Logger log = LoggerFactory.getLogger(CreateJobTask.class); - - private Job job; - private JobManager jobManager; - - @Override - public Job call() { - String runnerName = jobManager.getRunnerType().toString(); - - job.setRunner(jobManager.getRunnerType()); - job.setStatus(JobStatus.PENDING); - - try { - JobTask.logAudit(Action.SUBMIT, job, "Building graph and submitting to %s", runnerName); - - job = jobManager.startJob(job); - var extId = job.getExtId(); - if (extId.isEmpty()) { - throw new RuntimeException( - String.format("Could not submit job: \n%s", "unable to retrieve job external id")); - } - - var auditMessage = "Job submitted to runner %s with ext id %s."; - JobTask.logAudit(Action.STATUS_CHANGE, job, auditMessage, runnerName, extId); - - return job; - } catch (Exception e) { - log.error(e.getMessage()); - var auditMessage = "Job failed to be submitted to runner %s. Job status changed to ERROR."; - JobTask.logAudit(Action.STATUS_CHANGE, job, auditMessage, runnerName); - - job.setStatus(JobStatus.ERROR); - return job; - } - } -} diff --git a/core/src/main/java/feast/core/job/JobManager.java b/core/src/main/java/feast/core/job/JobManager.java index 20b5a861084..90f5f873033 100644 --- a/core/src/main/java/feast/core/job/JobManager.java +++ b/core/src/main/java/feast/core/job/JobManager.java @@ -29,10 +29,11 @@ public interface JobManager { Runner getRunnerType(); /** - * Start an import job. Start should change the status of the Job from PENDING to RUNNING. + * Start an import job. The JobManager should also attach external id that is specific to + * JobManager implementation * * @param job job to start - * @return Job + * @return Running Job with extId set. */ Job startJob(Job job); @@ -45,11 +46,10 @@ public interface JobManager { Job updateJob(Job job); /** - * Abort a job given runner-specific job ID. 
Abort should change the status of the Job from - * RUNNING to ABORTING. + * Abort a job given runner-specific job ID. * * @param job to abort. - * @return The aborted Job + * @return The Aborting Job */ Job abortJob(Job job); diff --git a/core/src/main/java/feast/core/job/UpgradeJobTask.java b/core/src/main/java/feast/core/job/UpgradeJobTask.java deleted file mode 100644 index e7de8f5e275..00000000000 --- a/core/src/main/java/feast/core/job/UpgradeJobTask.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package feast.core.job; - -import feast.core.log.Action; -import feast.core.model.Job; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; - -/** Task that upgrade given {@link Job} by restarting it in {@link JobManager} */ -@Getter -@Setter -@Builder(setterPrefix = "set") -public class UpgradeJobTask implements JobTask { - private JobManager jobManager; - private Job job; - - @Override - public Job call() { - JobTask.logAudit( - Action.UPDATE, - job, - "Updating job %s for runner %s", - job.getId(), - jobManager.getRunnerType().toString()); - - return jobManager.updateJob(job); - } -} diff --git a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java index a2937ee621e..5a1cdadadc5 100644 --- a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java +++ b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java @@ -123,7 +123,6 @@ public Job startJob(Job job) { .collect(Collectors.toSet()), false); job.setExtId(extId); - job.setStatus(JobStatus.RUNNING); return job; } catch (RuntimeException e) { @@ -183,7 +182,6 @@ public Job abortJob(Job job) { Strings.lenientFormat("Unable to drain job with id: %s", dataflowJobId), e); } - job.setStatus(JobStatus.ABORTING); return job; } diff --git a/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java b/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java index 804d258f46d..a87afa1bcf6 100644 --- a/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java +++ b/core/src/main/java/feast/core/job/dataflow/DataflowRunnerConfig.java @@ -32,7 +32,7 @@ public class DataflowRunnerConfig extends RunnerConfig { public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { this.project = runnerConfigOptions.getProject(); this.region = runnerConfigOptions.getRegion(); - this.zone = runnerConfigOptions.getZone(); + this.workerZone = 
runnerConfigOptions.getWorkerZone(); this.serviceAccount = runnerConfigOptions.getServiceAccount(); this.network = runnerConfigOptions.getNetwork(); this.subnetwork = runnerConfigOptions.getSubnetwork(); @@ -44,6 +44,8 @@ public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { this.deadLetterTableSpec = runnerConfigOptions.getDeadLetterTableSpec(); this.diskSizeGb = runnerConfigOptions.getDiskSizeGb(); this.labels = runnerConfigOptions.getLabelsMap(); + this.enableStreamingEngine = runnerConfigOptions.getEnableStreamingEngine(); + this.workerDiskType = runnerConfigOptions.getWorkerDiskType(); validate(); } @@ -54,7 +56,7 @@ public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { @NotBlank public String region; /* GCP availability zone for operations. */ - @NotBlank public String zone; + @NotBlank public String workerZone; /* Run the job as a specific service account, instead of the default GCE robot. */ public String serviceAccount; @@ -91,6 +93,12 @@ public DataflowRunnerConfig(DataflowRunnerConfigOptions runnerConfigOptions) { public Map labels; + /* If true job will be run on StreamingEngine instead of VMs */ + public Boolean enableStreamingEngine; + + /* Type of persistent disk to be used by workers */ + public String workerDiskType; + /** Validates Dataflow runner configuration options */ public void validate() { ValidatorFactory factory = Validation.buildDefaultValidatorFactory(); diff --git a/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java b/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java index 8ee326affa1..4d952a1b10e 100644 --- a/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java +++ b/core/src/main/java/feast/core/job/direct/DirectRunnerJobManager.java @@ -84,7 +84,6 @@ public Job startJob(Job job) { DirectJob directJob = new DirectJob(job.getId(), pipelineResult); jobs.add(directJob); job.setExtId(job.getId()); - job.setStatus(JobStatus.RUNNING); 
return job; } catch (Exception e) { log.error("Error submitting job", e); @@ -156,7 +155,6 @@ public Job abortJob(Job job) { jobs.remove(job.getExtId()); } - job.setStatus(JobStatus.ABORTING); return job; } diff --git a/core/src/main/java/feast/core/job/task/CreateJobTask.java b/core/src/main/java/feast/core/job/task/CreateJobTask.java new file mode 100644 index 00000000000..fa63cc67aa9 --- /dev/null +++ b/core/src/main/java/feast/core/job/task/CreateJobTask.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.core.job.task; + +import feast.common.logging.AuditLogger; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.core.job.JobManager; +import feast.core.model.Job; +import feast.core.model.JobStatus; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.event.Level; + +/** Task that starts recently created {@link Job} by using {@link feast.core.job.JobManager}. */ +@Slf4j +public class CreateJobTask extends JobTask { + + public CreateJobTask(Job job, JobManager jobManager) { + super(job, jobManager); + } + + @Override + public Job call() { + try { + String runnerName = jobManager.getRunnerType().toString(); + changeJobStatus(JobStatus.PENDING); + + // Start job with jobManager. 
+ job.setRunner(jobManager.getRunnerType()); + job = jobManager.startJob(job); + + log.info(String.format("Build graph and submitting to %s", runnerName)); + AuditLogger.logAction(Level.INFO, JobTasks.CREATE.name(), ResourceType.JOB, job.getId()); + + // Check for expected external job id + if (job.getExtId().isEmpty()) { + throw new RuntimeException( + String.format( + "Could not submit job %s: unable to retrieve job external id", job.getId())); + } + + log.info( + String.format("Job submitted to runner %s with ext id %s.", runnerName, job.getExtId())); + changeJobStatus(JobStatus.RUNNING); + return job; + } catch (Exception e) { + handleException(e); + return job; + } + } +} diff --git a/core/src/main/java/feast/core/job/task/JobTask.java b/core/src/main/java/feast/core/job/task/JobTask.java new file mode 100644 index 00000000000..60a51809d67 --- /dev/null +++ b/core/src/main/java/feast/core/job/task/JobTask.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.core.job.task; + +import feast.common.logging.AuditLogger; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.core.job.JobManager; +import feast.core.model.Job; +import feast.core.model.JobStatus; +import java.util.concurrent.Callable; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.event.Level; + +@Getter +@Setter +@Slf4j +public abstract class JobTask implements Callable { + protected Job job; + protected JobManager jobManager; + + public JobTask(Job job, JobManager jobManager) { + this.job = job; + this.jobManager = jobManager; + } + + @Override + public abstract Job call() throws RuntimeException; + + /** + * Change Job Status to the given status and logs changes in Job Status to audit and normal log. + */ + protected void changeJobStatus(JobStatus newStatus) { + JobStatus currentStatus = job.getStatus(); + if (currentStatus != newStatus) { + job.setStatus(newStatus); + log.info( + "Job status updated: changed from {} to {}", currentStatus, newStatus); + + AuditLogger.logTransition(Level.INFO, newStatus.name(), ResourceType.JOB, job.getId()); + // transition already recorded via AuditLogger above; no extra debug output + } + } + + /** + * Handle Exception when executing JobTask by transition Job to ERROR status and logging exception + */ + protected void handleException(Exception e) { + log.error("Unexpected exception performing JobTask", e); + // throwable passed to the logger so the stack trace is captured in the log output + changeJobStatus(JobStatus.ERROR); + } +} diff --git a/core/src/main/java/feast/core/log/Action.java b/core/src/main/java/feast/core/job/task/JobTasks.java similarity index 69% rename from core/src/main/java/feast/core/log/Action.java rename to core/src/main/java/feast/core/job/task/JobTasks.java index 3eb24c080a4..85b52bdfdeb 100644 --- a/core/src/main/java/feast/core/log/Action.java +++ b/core/src/main/java/feast/core/job/task/JobTasks.java @@ -1,6 +1,6 @@ /* * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast 
Authors + * Copyright 2018-2020 The Feast Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.core.log; +package feast.core.job.task; -/** Actions taken for audit logging purposes */ -public enum Action { - // Job-related actions - SUBMIT, - STATUS_CHANGE, +/** Enum listing of the available Job Tasks to perform on Jobs */ +public enum JobTasks { + CREATE, + UPDATE_STATUS, + RESTART, ABORT, - - // Spec-related - UPDATE, - REGISTER, - - // Storage-related - ADD, - SCHEMA_UPDATE, } diff --git a/core/src/main/java/feast/core/job/task/RestartJobTask.java b/core/src/main/java/feast/core/job/task/RestartJobTask.java new file mode 100644 index 00000000000..990e9191ddf --- /dev/null +++ b/core/src/main/java/feast/core/job/task/RestartJobTask.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.core.job.task; + +import feast.common.logging.AuditLogger; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.core.job.JobManager; +import feast.core.model.Job; +import feast.core.model.JobStatus; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.event.Level; + +/** Task that restarts given {@link Job} by restarting it in {@link JobManager} */ +@Slf4j +public class RestartJobTask extends JobTask { + public RestartJobTask(Job job, JobManager jobManager) { + super(job, jobManager); + } + + @Override + public Job call() { + try { + job = jobManager.restartJob(job); + log.info("Restart job {} for runner {}", job.getId(), job.getRunner().toString()); + AuditLogger.logAction(Level.INFO, JobTasks.RESTART.name(), ResourceType.JOB, job.getId()); + + changeJobStatus(JobStatus.RUNNING); + return job; + } catch (Exception e) { + handleException(e); + return job; + } + } +} diff --git a/core/src/main/java/feast/core/job/task/TerminateJobTask.java b/core/src/main/java/feast/core/job/task/TerminateJobTask.java new file mode 100644 index 00000000000..5099d2d2049 --- /dev/null +++ b/core/src/main/java/feast/core/job/task/TerminateJobTask.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.core.job.task; + +import feast.common.logging.AuditLogger; +import feast.common.logging.entry.LogResource.ResourceType; +import feast.core.job.JobManager; +import feast.core.model.Job; +import feast.core.model.JobStatus; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.event.Level; + +/** Task to terminate given {@link Job} by using {@link JobManager} */ +@Slf4j +public class TerminateJobTask extends JobTask { + public TerminateJobTask(Job job, JobManager jobManager) { + super(job, jobManager); + } + + @Override + public Job call() { + try { + job = jobManager.abortJob(job); + log.info( + String.format( + "Aborted job %s for runner %s", job.getId(), jobManager.getRunnerType().toString())); + AuditLogger.logAction(Level.INFO, JobTasks.ABORT.name(), ResourceType.JOB, job.getId()); + + changeJobStatus(JobStatus.ABORTING); + return job; + } catch (Exception e) { + handleException(e); + return job; + } + } +} diff --git a/core/src/main/java/feast/core/job/UpdateJobStatusTask.java b/core/src/main/java/feast/core/job/task/UpdateJobStatusTask.java similarity index 58% rename from core/src/main/java/feast/core/job/UpdateJobStatusTask.java rename to core/src/main/java/feast/core/job/task/UpdateJobStatusTask.java index 9ee4d2f1eec..c793ab6a815 100644 --- a/core/src/main/java/feast/core/job/UpdateJobStatusTask.java +++ b/core/src/main/java/feast/core/job/task/UpdateJobStatusTask.java @@ -14,37 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package feast.core.job; +package feast.core.job.task; -import feast.core.log.Action; +import feast.core.job.JobManager; import feast.core.model.Job; import feast.core.model.JobStatus; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; /** * Task that retrieves status from {@link JobManager} on given {@link Job} and update the job * accordingly in-place */ -@Getter -@Setter -@Builder(setterPrefix = "set") -public class UpdateJobStatusTask implements JobTask { - private Job job; - private JobManager jobManager; +public class UpdateJobStatusTask extends JobTask { + public UpdateJobStatusTask(Job job, JobManager jobManager) { + super(job, jobManager); + } @Override public Job call() { - JobStatus currentStatus = job.getStatus(); - JobStatus newStatus = jobManager.getJobStatus(job); + try { + JobStatus newStatus = jobManager.getJobStatus(job); + changeJobStatus(newStatus); - if (newStatus != currentStatus) { - var auditMessage = "Job status updated: changed from %s to %s"; - JobTask.logAudit(Action.STATUS_CHANGE, job, auditMessage, currentStatus, newStatus); + return job; + } catch (Exception e) { + handleException(e); + return job; } - - job.setStatus(newStatus); - return job; } } diff --git a/core/src/main/java/feast/core/log/AuditLogger.java b/core/src/main/java/feast/core/log/AuditLogger.java deleted file mode 100644 index 5349b5548b0..00000000000 --- a/core/src/main/java/feast/core/log/AuditLogger.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.core.log; - -import com.google.common.base.Strings; -import java.util.Date; -import java.util.Map; -import java.util.TreeMap; -import lombok.extern.log4j.Log4j2; -import org.apache.logging.log4j.Level; -import org.apache.logging.log4j.message.ObjectMessage; - -@Log4j2 -public class AuditLogger { - private static final Level AUDIT_LEVEL = Level.getLevel("AUDIT"); - - /** - * Log to stdout a json formatted audit log. - * - * @param resource type of resource - * @param id id of resource, if any - * @param action action taken - * @param detail additional detail. Supports string formatting. - * @param args arguments to the detail string - */ - public static void log( - Resource resource, String id, Action action, String detail, Object... 
args) { - Map map = new TreeMap<>(); - map.put("timestamp", new Date().toString()); - map.put("resource", resource.toString()); - map.put("id", id); - map.put("action", action.toString()); - map.put("detail", Strings.lenientFormat(detail, args)); - ObjectMessage msg = new ObjectMessage(map); - - log.log(AUDIT_LEVEL, msg); - } -} diff --git a/core/src/main/java/feast/core/service/JobCoordinatorService.java b/core/src/main/java/feast/core/service/JobCoordinatorService.java index 59cc619fa65..7cfdd342105 100644 --- a/core/src/main/java/feast/core/service/JobCoordinatorService.java +++ b/core/src/main/java/feast/core/service/JobCoordinatorService.java @@ -26,6 +26,7 @@ import feast.core.dao.FeatureSetRepository; import feast.core.dao.JobRepository; import feast.core.job.*; +import feast.core.job.task.*; import feast.core.model.*; import feast.core.model.FeatureSet; import feast.core.model.Job; @@ -129,6 +130,7 @@ void startOrUpdateJobs(List tasks) { } } catch (ExecutionException | InterruptedException | TimeoutException e) { log.warn("Unable to start or update job: {}", e.getMessage()); + e.printStackTrace(); } completedTasks++; } @@ -162,7 +164,7 @@ List makeJobUpdateTasks(Iterable>> sourceToStor if (job.isDeployed()) { if (!job.isRunning()) { - jobTasks.add(UpdateJobStatusTask.builder().setJob(job).setJobManager(jobManager).build()); + jobTasks.add(new UpdateJobStatusTask(job, jobManager)); // Mark that it is not safe to stop jobs without disrupting ingestion isSafeToStopJobs = false; @@ -180,9 +182,9 @@ List makeJobUpdateTasks(Iterable>> sourceToStor isSafeToStopJobs = false; - jobTasks.add(CreateJobTask.builder().setJob(job).setJobManager(jobManager).build()); + jobTasks.add(new CreateJobTask(job, jobManager)); } else { - jobTasks.add(UpdateJobStatusTask.builder().setJob(job).setJobManager(jobManager).build()); + jobTasks.add(new UpdateJobStatusTask(job, jobManager)); } } else { job.setId(groupingStrategy.createJobId(job)); @@ -192,7 +194,7 @@ List 
makeJobUpdateTasks(Iterable>> sourceToStor .filter(fs -> fs.getSource().equals(source)) .collect(Collectors.toSet())); - jobTasks.add(CreateJobTask.builder().setJob(job).setJobManager(jobManager).build()); + jobTasks.add(new CreateJobTask(job, jobManager)); } // Record the job as required to safeguard it from getting stopped @@ -203,8 +205,7 @@ List makeJobUpdateTasks(Iterable>> sourceToStor getExtraJobs(activeJobs) .forEach( extraJob -> { - jobTasks.add( - TerminateJobTask.builder().setJob(extraJob).setJobManager(jobManager).build()); + jobTasks.add(new TerminateJobTask(extraJob, jobManager)); }); } diff --git a/core/src/main/java/feast/core/service/JobService.java b/core/src/main/java/feast/core/service/JobService.java index e4c2ea255ab..812f5c17200 100644 --- a/core/src/main/java/feast/core/service/JobService.java +++ b/core/src/main/java/feast/core/service/JobService.java @@ -20,12 +20,10 @@ import feast.core.dao.JobRepository; import feast.core.job.JobManager; import feast.core.job.Runner; -import feast.core.log.Action; -import feast.core.log.AuditLogger; -import feast.core.log.Resource; +import feast.core.job.task.RestartJobTask; +import feast.core.job.task.TerminateJobTask; import feast.core.model.Job; import feast.core.model.JobStatus; -import feast.proto.core.CoreServiceProto.ListFeatureSetsRequest; import feast.proto.core.CoreServiceProto.ListIngestionJobsRequest; import feast.proto.core.CoreServiceProto.ListIngestionJobsResponse; import feast.proto.core.CoreServiceProto.RestartIngestionJobRequest; @@ -54,14 +52,11 @@ @Service public class JobService { private final JobRepository jobRepository; - private final SpecService specService; private final Map jobManagers; @Autowired - public JobService( - JobRepository jobRepository, SpecService specService, List jobManagerList) { + public JobService(JobRepository jobRepository, List jobManagerList) { this.jobRepository = jobRepository; - this.specService = specService; this.jobManagers = new HashMap<>(); for 
(JobManager manager : jobManagerList) { @@ -77,12 +72,10 @@ public JobService( * * @param request list ingestion jobs request specifying which jobs to include * @throws IllegalArgumentException when given filter in a unsupported configuration - * @throws InvalidProtocolBufferException on error when constructing response protobuf * @return list ingestion jobs response */ @Transactional(readOnly = true) - public ListIngestionJobsResponse listJobs(ListIngestionJobsRequest request) - throws InvalidProtocolBufferException { + public ListIngestionJobsResponse listJobs(ListIngestionJobsRequest request) { Set matchingJobIds = new HashSet<>(); // check that filter specified and not empty @@ -142,7 +135,11 @@ public ListIngestionJobsResponse listJobs(ListIngestionJobsRequest request) if (job.getStatus() == JobStatus.ERROR) { continue; } - ingestJobs.add(job.toProto()); + try { + ingestJobs.add(job.toProto()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Unexpected failure to construct Protobuf", e); + } } // pack jobs into response @@ -176,14 +173,9 @@ public RestartIngestionJobResponse restartJob(RestartIngestionJobRequest request "Restarting a job with a transitional, terminal or unknown status is unsupported"); } - // restart job with job manager - JobManager jobManager = this.jobManagers.get(job.getRunner()); - job = jobManager.restartJob(job); - log.info( - String.format( - "Restarted job (id: %s, extId: %s runner: %s)", - job.getId(), job.getExtId(), job.getRunner())); - this.logStatusChange(job, status, job.getStatus()); + // restart job by running job task + new RestartJobTask(job, jobManagers.get(job.getRunner())).call(); + // update job model in job repository this.jobRepository.saveAndFlush(job); @@ -219,16 +211,10 @@ public StopIngestionJobResponse stopJob(StopIngestionJobRequest request) { throw new UnsupportedOperationException( "Stopping a job with a transitional or unknown status is unsupported"); } - this.logStatusChange(job, 
status, job.getStatus()); - // stop job with job manager - JobManager jobManager = this.jobManagers.get(job.getRunner()); - job = jobManager.abortJob(job); - log.info( - String.format( - "Aborted job (id: %s, extId: %s runner: %s)", - job.getId(), job.getExtId(), job.getRunner())); - this.logStatusChange(job, status, job.getStatus()); + // stop job with job task + new TerminateJobTask(job, jobManagers.get(job.getRunner())).call(); + // update job model in job repository this.jobRepository.saveAndFlush(job); @@ -248,31 +234,4 @@ private Set mergeResults(Set results, Collection newResults) { } return results; } - - /** converts feature set reference to a list feature set filter */ - private ListFeatureSetsRequest.Filter toListFeatureSetFilter(FeatureSetReference fsReference) { - // match featuresets using contents of featureset reference - String fsName = fsReference.getName(); - String fsProject = fsReference.getProject(); - - // construct list featureset request filter using feature set reference - // for proto3, default value for missing values: - // - numeric values (ie int) is zero - // - strings is empty string - return ListFeatureSetsRequest.Filter.newBuilder() - .setFeatureSetName(fsName.isEmpty() ? "*" : fsName) - .setProject(fsProject.isEmpty() ? 
"*" : fsProject) - .build(); - } - - /** log job status using job manager */ - private void logStatusChange(Job job, JobStatus oldStatus, JobStatus newStatus) { - AuditLogger.log( - Resource.JOB, - job.getId(), - Action.STATUS_CHANGE, - "Job status transition: changed from %s to %s", - oldStatus, - newStatus); - } } diff --git a/core/src/main/java/feast/core/service/AccessManagementService.java b/core/src/main/java/feast/core/service/ProjectService.java similarity index 52% rename from core/src/main/java/feast/core/service/AccessManagementService.java rename to core/src/main/java/feast/core/service/ProjectService.java index bd5eeed906b..308c79bccf6 100644 --- a/core/src/main/java/feast/core/service/AccessManagementService.java +++ b/core/src/main/java/feast/core/service/ProjectService.java @@ -16,53 +16,24 @@ */ package feast.core.service; -import feast.auth.authorization.AuthorizationProvider; -import feast.auth.authorization.AuthorizationResult; -import feast.auth.config.SecurityProperties; -import feast.core.config.FeastProperties; import feast.core.dao.ProjectRepository; import feast.core.model.Project; import java.util.List; import java.util.Optional; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.security.access.AccessDeniedException; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContext; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @Slf4j @Service -public class AccessManagementService { +public class ProjectService { - private SecurityProperties securityProperties; - - private AuthorizationProvider authorizationProvider; private ProjectRepository projectRepository; - public AccessManagementService( - FeastProperties feastProperties, - ProjectRepository projectRepository, - 
AuthorizationProvider authorizationProvider) { - this.projectRepository = projectRepository; - this.authorizationProvider = authorizationProvider; - this.securityProperties = feastProperties.getSecurity(); - } - @Autowired - public AccessManagementService( - FeastProperties feastProperties, - ProjectRepository projectRepository, - ObjectProvider authorizationProvider) { + public ProjectService(ProjectRepository projectRepository) { this.projectRepository = projectRepository; - // create default project if it does not yet exist. - if (!projectRepository.existsById(Project.DEFAULT_NAME)) { - this.createProject(Project.DEFAULT_NAME); - } - this.authorizationProvider = authorizationProvider.getIfUnique(); - this.securityProperties = feastProperties.getSecurity(); } /** @@ -107,22 +78,4 @@ public void archiveProject(String name) { public List listProjects() { return projectRepository.findAllByArchivedIsFalse(); } - - /** - * Determine whether a user belongs to a Project - * - * @param securityContext User's Spring Security Context. Used to identify user. - * @param projectId Id (name) of the project for which membership should be tested. 
- */ - public void checkIfProjectMember(SecurityContext securityContext, String projectId) { - Authentication authentication = securityContext.getAuthentication(); - if (!this.securityProperties.getAuthorization().isEnabled()) { - return; - } - AuthorizationResult result = - this.authorizationProvider.checkAccessToProject(projectId, authentication); - if (!result.isAllowed()) { - throw new AccessDeniedException(result.getFailureReason().orElse("AccessDenied")); - } - } } diff --git a/core/src/main/resources/application.yml b/core/src/main/resources/application.yml index fbdf6036328..69ed090a0f7 100644 --- a/core/src/main/resources/application.yml +++ b/core/src/main/resources/application.yml @@ -42,11 +42,13 @@ feast: options: project: my_gcp_project region: asia-east1 - zone: asia-east1-a + workerZone: asia-east1-a tempLocation: gs://bucket/tempLocation network: default subnetwork: regions/asia-east1/subnetworks/mysubnetwork maxNumWorkers: 1 + enableStreamingEngine: false + workerDiskType: compute.googleapis.com/projects/asia-east1-a/diskTypes/pd-ssd autoscalingAlgorithm: THROUGHPUT_BASED usePublicIps: false workerMachineType: n1-standard-1 @@ -94,6 +96,16 @@ feast: provider: http options: authorizationUrl: http://localhost:8082 + subjectClaim: email + + logging: + # Audit logging provides a machine readable structured JSON log that can give better + # insight into what is happening in Feast. + audit: + # Whether audit logging is enabled. 
+ enabled: true + # Whether to enable message level (ie request/response) audit logging + messageLoggingEnabled: false grpc: server: diff --git a/core/src/main/resources/log4j2.xml b/core/src/main/resources/log4j2.xml index efbf7d1f624..8781d668a84 100644 --- a/core/src/main/resources/log4j2.xml +++ b/core/src/main/resources/log4j2.xml @@ -21,26 +21,28 @@ %d{yyyy-MM-dd HH:mm:ss.SSS} %5p ${hostName} --- [%15.15t] %-40.40c{1.} : %m%n%ex - ${env:LOG_TYPE:-Console} - ${env:LOG_LEVEL:-info} + + {"time":"%d{yyyy-MM-dd'T'HH:mm:ssXXX}","hostname":"${hostName}","severity":"%p","message":%m}%n%ex + - - - + - + + - - + + + - - + + + diff --git a/core/src/test/java/feast/core/grpc/CoreServiceAuthTest.java b/core/src/test/java/feast/core/grpc/CoreServiceAuthTest.java index bd59510c673..39f80429fc7 100644 --- a/core/src/test/java/feast/core/grpc/CoreServiceAuthTest.java +++ b/core/src/test/java/feast/core/grpc/CoreServiceAuthTest.java @@ -27,14 +27,15 @@ import feast.auth.authorization.AuthorizationProvider; import feast.auth.authorization.AuthorizationResult; import feast.auth.config.SecurityProperties; +import feast.auth.service.AuthorizationService; import feast.core.config.FeastProperties; import feast.core.dao.ProjectRepository; import feast.core.model.Entity; import feast.core.model.Feature; import feast.core.model.FeatureSet; import feast.core.model.Source; -import feast.core.service.AccessManagementService; import feast.core.service.JobService; +import feast.core.service.ProjectService; import feast.core.service.SpecService; import feast.core.service.StatsService; import feast.proto.core.CoreServiceProto.ApplyFeatureSetRequest; @@ -45,6 +46,7 @@ import feast.proto.core.SourceProto.KafkaSourceConfig; import feast.proto.core.SourceProto.SourceType; import feast.proto.types.ValueProto.ValueType.Enum; +import io.grpc.StatusRuntimeException; import io.grpc.internal.testing.StreamRecorder; import java.sql.Date; import java.time.Instant; @@ -53,7 +55,6 @@ import 
org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -61,7 +62,7 @@ class CoreServiceAuthTest { private CoreServiceImpl coreService; - private AccessManagementService accessManagementService; + private ProjectService projectService; @Mock private SpecService specService; @Mock private ProjectRepository projectRepository; @@ -78,11 +79,12 @@ class CoreServiceAuthTest { sp.setAuthorization(authProp); FeastProperties feastProperties = new FeastProperties(); feastProperties.setSecurity(sp); - accessManagementService = - new AccessManagementService(feastProperties, projectRepository, authProvider); + projectService = new ProjectService(projectRepository); + AuthorizationService authService = + new AuthorizationService(feastProperties.getSecurity(), authProvider); coreService = new CoreServiceImpl( - specService, accessManagementService, statsService, jobService, feastProperties); + specService, projectService, statsService, jobService, feastProperties, authService); } @Test @@ -108,7 +110,11 @@ void cantApplyFeatureSetIfNotProjectMember() throws InvalidProtocolBufferExcepti ApplyFeatureSetRequest.newBuilder().setFeatureSet(spec).build(); assertThrows( - AccessDeniedException.class, () -> coreService.applyFeatureSet(request, responseObserver)); + StatusRuntimeException.class, + () -> { + coreService.applyFeatureSet(request, responseObserver); + throw responseObserver.getError(); + }); } @Test diff --git a/core/src/test/java/feast/core/job/dataflow/DataflowJobManagerTest.java b/core/src/test/java/feast/core/job/dataflow/DataflowJobManagerTest.java index 5b5c6a6340e..3250c1d42be 100644 --- a/core/src/test/java/feast/core/job/dataflow/DataflowJobManagerTest.java +++ 
b/core/src/test/java/feast/core/job/dataflow/DataflowJobManagerTest.java @@ -81,7 +81,7 @@ public void setUp() { Builder optionsBuilder = DataflowRunnerConfigOptions.newBuilder(); optionsBuilder.setProject("project"); optionsBuilder.setRegion("region"); - optionsBuilder.setZone("zone"); + optionsBuilder.setWorkerZone("zone"); optionsBuilder.setTempLocation("tempLocation"); optionsBuilder.setNetwork("network"); optionsBuilder.setSubnetwork("subnetwork"); @@ -213,7 +213,6 @@ public void shouldStartJobWithCorrectPipelineOptions() throws IOException { actualPipelineOptions.getSpecsStreamingUpdateConfigJson(), equalTo(printer.print(specsStreamingUpdateConfig))); assertThat(actual.getExtId(), equalTo(expectedExtJobId)); - assertThat(actual.getStatus(), equalTo(JobStatus.RUNNING)); } @Test diff --git a/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java b/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java index 925e48aec11..9c6b5a085c8 100644 --- a/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java +++ b/core/src/test/java/feast/core/job/dataflow/DataflowRunnerConfigTest.java @@ -33,7 +33,9 @@ public void shouldConvertToPipelineArgs() throws IllegalAccessException { DataflowRunnerConfigOptions.newBuilder() .setProject("my-project") .setRegion("asia-east1") - .setZone("asia-east1-a") + .setWorkerZone("asia-east1-a") + .setEnableStreamingEngine(true) + .setWorkerDiskType("pd-ssd") .setTempLocation("gs://bucket/tempLocation") .setNetwork("default") .setSubnetwork("regions/asia-east1/subnetworks/mysubnetwork") @@ -52,7 +54,7 @@ public void shouldConvertToPipelineArgs() throws IllegalAccessException { Arrays.asList( "--project=my-project", "--region=asia-east1", - "--zone=asia-east1-a", + "--workerZone=asia-east1-a", "--tempLocation=gs://bucket/tempLocation", "--network=default", "--subnetwork=regions/asia-east1/subnetworks/mysubnetwork", @@ -62,7 +64,9 @@ public void shouldConvertToPipelineArgs() throws 
IllegalAccessException { "--workerMachineType=n1-standard-1", "--deadLetterTableSpec=project_id:dataset_id.table_id", "--diskSizeGb=100", - "--labels={\"key\":\"value\"}") + "--labels={\"key\":\"value\"}", + "--enableStreamingEngine=true", + "--workerDiskType=pd-ssd") .toArray(String[]::new); assertThat(args.size(), equalTo(expectedArgs.length)); assertThat(args, containsInAnyOrder(expectedArgs)); @@ -74,7 +78,7 @@ public void shouldIgnoreOptionalArguments() throws IllegalAccessException { DataflowRunnerConfigOptions.newBuilder() .setProject("my-project") .setRegion("asia-east1") - .setZone("asia-east1-a") + .setWorkerZone("asia-east1-a") .setTempLocation("gs://bucket/tempLocation") .setNetwork("default") .setSubnetwork("regions/asia-east1/subnetworks/mysubnetwork") @@ -90,7 +94,7 @@ public void shouldIgnoreOptionalArguments() throws IllegalAccessException { Arrays.asList( "--project=my-project", "--region=asia-east1", - "--zone=asia-east1-a", + "--workerZone=asia-east1-a", "--tempLocation=gs://bucket/tempLocation", "--network=default", "--subnetwork=regions/asia-east1/subnetworks/mysubnetwork", @@ -98,7 +102,8 @@ public void shouldIgnoreOptionalArguments() throws IllegalAccessException { "--autoscalingAlgorithm=THROUGHPUT_BASED", "--usePublicIps=false", "--workerMachineType=n1-standard-1", - "--labels={}") + "--labels={}", + "--enableStreamingEngine=false") .toArray(String[]::new); assertThat(args.size(), equalTo(expectedArgs.length)); assertThat(args, containsInAnyOrder(expectedArgs)); diff --git a/core/src/test/java/feast/core/job/direct/DirectRunnerJobManagerTest.java b/core/src/test/java/feast/core/job/direct/DirectRunnerJobManagerTest.java index 5846cf54906..c2f0f2a4e1f 100644 --- a/core/src/test/java/feast/core/job/direct/DirectRunnerJobManagerTest.java +++ b/core/src/test/java/feast/core/job/direct/DirectRunnerJobManagerTest.java @@ -158,7 +158,6 @@ public void shouldStartDirectJobAndRegisterPipelineResult() throws IOException { verify(drJobManager, 
times(1)).runPipeline(pipelineOptionsCaptor.capture()); verify(directJobRegistry, times(1)).add(directJobCaptor.capture()); - assertThat(actual.getStatus(), equalTo(JobStatus.RUNNING)); ImportOptions actualPipelineOptions = pipelineOptionsCaptor.getValue(); DirectJob jobStarted = directJobCaptor.getValue(); @@ -201,6 +200,5 @@ public void shouldAbortJobThenRemoveFromRegistry() throws IOException { job = drJobManager.abortJob(job); verify(directJob, times(1)).abort(); verify(directJobRegistry, times(1)).remove("ext1"); - assertThat(job.getStatus(), equalTo(JobStatus.ABORTING)); } } diff --git a/core/src/test/java/feast/core/job/JobTasksTest.java b/core/src/test/java/feast/core/job/task/JobTasksTest.java similarity index 91% rename from core/src/test/java/feast/core/job/JobTasksTest.java rename to core/src/test/java/feast/core/job/task/JobTasksTest.java index d1e1b651c19..d463def0669 100644 --- a/core/src/test/java/feast/core/job/JobTasksTest.java +++ b/core/src/test/java/feast/core/job/task/JobTasksTest.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package feast.core.job; +package feast.core.job.task; import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertThat; @@ -24,6 +24,7 @@ import static org.mockito.MockitoAnnotations.initMocks; import com.google.common.collect.ImmutableSet; +import feast.core.job.*; import feast.core.model.*; import feast.core.util.TestUtil; import feast.proto.core.SourceProto; @@ -73,6 +74,8 @@ public void setUp() { .setBootstrapServers("servers:9092") .build()) .build()); + + TestUtil.setupAuditLogger(); } Job makeJob(String extId, List featureSets, JobStatus status) { @@ -90,19 +93,15 @@ Job makeJob(String extId, List featureSets, JobStatus status) { } CreateJobTask makeCreateTask(Job currentJob) { - return CreateJobTask.builder().setJob(currentJob).setJobManager(jobManager).build(); - } - - UpgradeJobTask makeUpgradeTask(Job currentJob) { - return UpgradeJobTask.builder().setJob(currentJob).setJobManager(jobManager).build(); + return new CreateJobTask(currentJob, jobManager); } UpdateJobStatusTask makeCheckStatusTask(Job currentJob) { - return UpdateJobStatusTask.builder().setJob(currentJob).setJobManager(jobManager).build(); + return new UpdateJobStatusTask(currentJob, jobManager); } TerminateJobTask makeTerminateTask(Job currentJob) { - return TerminateJobTask.builder().setJob(currentJob).setJobManager(jobManager).build(); + return new TerminateJobTask(currentJob, jobManager); } @Test diff --git a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java index 621590bd064..54a8482daa0 100644 --- a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java +++ b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java @@ -41,6 +41,7 @@ import feast.core.dao.JobRepository; import feast.core.dao.SourceRepository; import feast.core.job.*; +import feast.core.job.task.*; import feast.core.model.*; import feast.core.util.TestUtil; import 
feast.proto.core.CoreServiceProto.ListFeatureSetsRequest.Filter; @@ -91,6 +92,7 @@ public void setUp() { JobProperties jobProperties = new JobProperties(); jobProperties.setJobUpdateTimeoutSeconds(5); feastProperties.setJobs(jobProperties); + TestUtil.setupAuditLogger(); jcsWithConsolidation = new JobCoordinatorService( @@ -214,7 +216,7 @@ private Job newJob(String id, Store store, Source source, FeatureSet... featureS Job job = Job.builder() .setId(id) - .setExtId("") + .setExtId("extId") .setRunner(Runner.DATAFLOW) .setSource(source) .setFeatureSetJobStatuses(TestUtil.makeFeatureSetJobStatus(featureSets)) @@ -752,7 +754,8 @@ public void shouldCreateJobPerStore() throws InvalidProtocolBufferException { Job expected1 = newJob("", store1, source); Job expected2 = newJob("", store2, source); - when(jobManager.startJob(any())).thenReturn(new Job()); + when(jobManager.startJob(expected1)).thenReturn(expected1); + when(jobManager.startJob(expected2)).thenReturn(expected2); when(jobManager.getRunnerType()).thenReturn(Runner.DATAFLOW); jcsWithJobPerStore.Poll(); @@ -805,6 +808,11 @@ public void shouldCloneRunningJobOnUpgrade() throws InvalidProtocolBufferExcepti existingJob.setFeatureSetJobStatuses(new HashSet<>()); existingJob.setStatus(JobStatus.RUNNING); + Job spawnJob = newJob("some-other-id", store1, source); + existingJob.setExtId("extId2"); + existingJob.setFeatureSetJobStatuses(new HashSet<>()); + existingJob.setStatus(JobStatus.RUNNING); + when(jobRepository .findFirstBySourceTypeAndSourceConfigAndStoreNameAndStatusNotInOrderByLastUpdatedDesc( eq(source.getType()), @@ -814,6 +822,7 @@ public void shouldCloneRunningJobOnUpgrade() throws InvalidProtocolBufferExcepti .thenReturn(Optional.of(existingJob)); when(jobManager.getRunnerType()).thenReturn(Runner.DATAFLOW); + when(jobManager.startJob(any())).thenReturn(spawnJob); jcsWithConsolidation.Poll(); @@ -821,7 +830,6 @@ public void shouldCloneRunningJobOnUpgrade() throws InvalidProtocolBufferExcepti // not stopped 
yet verify(jobManager, never()).abortJob(any()); - verify(jobManager, times(1)).startJob(jobCaptor.capture()); Job actual = jobCaptor.getValue(); diff --git a/core/src/test/java/feast/core/service/JobServiceTest.java b/core/src/test/java/feast/core/service/JobServiceTest.java index 1e0bec76da3..ec09820b363 100644 --- a/core/src/test/java/feast/core/service/JobServiceTest.java +++ b/core/src/test/java/feast/core/service/JobServiceTest.java @@ -35,7 +35,6 @@ import feast.core.model.*; import feast.core.util.TestUtil; import feast.proto.core.CoreServiceProto.ListFeatureSetsRequest; -import feast.proto.core.CoreServiceProto.ListFeatureSetsResponse; import feast.proto.core.CoreServiceProto.ListIngestionJobsRequest; import feast.proto.core.CoreServiceProto.ListIngestionJobsResponse; import feast.proto.core.CoreServiceProto.RestartIngestionJobRequest; @@ -57,7 +56,6 @@ public class JobServiceTest { // mocks @Mock private JobRepository jobRepository; @Mock private JobManager jobManager; - @Mock private SpecService specService; // fake models private Source dataSource; private Store dataStore; @@ -97,28 +95,12 @@ public void setup() { this.listFilters = this.newDummyListRequestFilters(); // setup mock objects - this.setupSpecService(); this.setupJobRepository(); this.setupJobManager(); + TestUtil.setupAuditLogger(); // create test target - this.jobService = - new JobService(this.jobRepository, this.specService, Arrays.asList(this.jobManager)); - } - - public void setupSpecService() { - try { - ListFeatureSetsResponse response = - ListFeatureSetsResponse.newBuilder().addFeatureSets(this.featureSet.toProto()).build(); - - when(this.specService.listFeatureSets(this.listFilters.get(0))).thenReturn(response); - - when(this.specService.listFeatureSets(this.listFilters.get(1))).thenReturn(response); - - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - fail("Unexpected exception"); - } + this.jobService = new JobService(this.jobRepository, 
Arrays.asList(this.jobManager)); } public void setupJobRepository() { @@ -190,13 +172,7 @@ private List newDummyListRequestFilters() { private ListIngestionJobsResponse tryListJobs(ListIngestionJobsRequest request) { ListIngestionJobsResponse response = null; - try { - response = this.jobService.listJobs(request); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - fail("Caught Unexpected exception"); - } - + response = this.jobService.listJobs(request); return response; } diff --git a/core/src/test/java/feast/core/service/AccessManagementServiceTest.java b/core/src/test/java/feast/core/service/ProjectServiceTest.java similarity index 63% rename from core/src/test/java/feast/core/service/AccessManagementServiceTest.java rename to core/src/test/java/feast/core/service/ProjectServiceTest.java index fa69a7e7a83..a32a85c991a 100644 --- a/core/src/test/java/feast/core/service/AccessManagementServiceTest.java +++ b/core/src/test/java/feast/core/service/ProjectServiceTest.java @@ -16,16 +16,12 @@ */ package feast.core.service; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; -import feast.auth.authorization.AuthorizationProvider; -import feast.auth.config.SecurityProperties; -import feast.core.config.FeastProperties; import feast.core.dao.ProjectRepository; import feast.core.model.Project; import java.util.Arrays; @@ -38,70 +34,57 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; -public class AccessManagementServiceTest { +public class ProjectServiceTest { @Mock private ProjectRepository projectRepository; @Rule public final ExpectedException expectedException = ExpectedException.none(); - private AccessManagementService accessManagementService; + private ProjectService projectService; @Before public void setUp() { 
initMocks(this); projectRepository = mock(ProjectRepository.class); - SecurityProperties.AuthorizationProperties authProp = - new SecurityProperties.AuthorizationProperties(); - authProp.setEnabled(false); - SecurityProperties sp = new SecurityProperties(); - sp.setAuthorization(authProp); - FeastProperties feastProperties = new FeastProperties(); - feastProperties.setSecurity(sp); - accessManagementService = - new AccessManagementService( - feastProperties, projectRepository, mock(AuthorizationProvider.class)); - } - - @Test - public void testDefaultProjectCreateInConstructor() { - verify(this.projectRepository).saveAndFlush(new Project(Project.DEFAULT_NAME)); + projectService = new ProjectService(projectRepository); } @Test public void shouldCreateProjectIfItDoesntExist() { String projectName = "project1"; Project project = new Project(projectName); - when(projectRepository.saveAndFlush(any(Project.class))).thenReturn(project); - accessManagementService.createProject(projectName); - verify(projectRepository, times(1)).saveAndFlush(any()); + when(projectRepository.saveAndFlush(project)).thenReturn(project); + projectService.createProject(projectName); + verify(projectRepository, times(1)).saveAndFlush(project); } @Test(expected = IllegalArgumentException.class) public void shouldNotCreateProjectIfItExist() { String projectName = "project1"; when(projectRepository.existsById(projectName)).thenReturn(true); - accessManagementService.createProject(projectName); + projectService.createProject(projectName); } @Test public void shouldArchiveProjectIfItExists() { String projectName = "project1"; - when(projectRepository.findById(projectName)).thenReturn(Optional.of(new Project(projectName))); - accessManagementService.archiveProject(projectName); - verify(projectRepository, times(1)).saveAndFlush(any(Project.class)); + Project project = new Project(projectName); + when(projectRepository.findById(projectName)).thenReturn(Optional.of(project)); + 
projectService.archiveProject(projectName); + verify(projectRepository, times(1)).saveAndFlush(project); } @Test public void shouldNotArchiveDefaultProject() { expectedException.expect(IllegalArgumentException.class); - this.accessManagementService.archiveProject(Project.DEFAULT_NAME); + this.projectService.archiveProject(Project.DEFAULT_NAME); } @Test(expected = IllegalArgumentException.class) public void shouldNotArchiveProjectIfItIsAlreadyArchived() { String projectName = "project1"; when(projectRepository.findById(projectName)).thenReturn(Optional.empty()); - accessManagementService.archiveProject(projectName); + projectService.archiveProject(projectName); } @Test @@ -110,7 +93,7 @@ public void shouldListProjects() { Project project = new Project(projectName); List expected = Arrays.asList(project); when(projectRepository.findAllByArchivedIsFalse()).thenReturn(expected); - List actual = accessManagementService.listProjects(); + List actual = projectService.listProjects(); Assert.assertEquals(expected, actual); } } diff --git a/core/src/test/java/feast/core/util/TestUtil.java b/core/src/test/java/feast/core/util/TestUtil.java index b93aaa8cd42..0199d21d708 100644 --- a/core/src/test/java/feast/core/util/TestUtil.java +++ b/core/src/test/java/feast/core/util/TestUtil.java @@ -16,6 +16,12 @@ */ package feast.core.util; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import feast.common.logging.AuditLogger; +import feast.common.logging.config.LoggingProperties; +import feast.common.logging.config.LoggingProperties.AuditLogProperties; import feast.core.model.Entity; import feast.core.model.Feature; import feast.core.model.FeatureSet; @@ -41,6 +47,7 @@ import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.springframework.boot.info.BuildProperties; public class TestUtil { public static Set makeFeatureSetJobStatus(FeatureSet... 
featureSets) { @@ -145,4 +152,18 @@ public static FeatureSetJobStatus CreateFeatureSetJobStatusWithJob( return featureSetJobStatus; } + + /** Setup the audit logger. This call is required to use the audit logger when testing. */ + public static void setupAuditLogger() { + AuditLogProperties properties = new AuditLogProperties(); + properties.setEnabled(true); + LoggingProperties loggingProperties = new LoggingProperties(); + loggingProperties.setAudit(properties); + + BuildProperties buildProperties = mock(BuildProperties.class); + when(buildProperties.getArtifact()).thenReturn("feast-core"); + when(buildProperties.getVersion()).thenReturn("0.6"); + + new AuditLogger(loggingProperties, buildProperties); + } } diff --git a/datatypes/java/pom.xml b/datatypes/java/pom.xml index 5810a6db96a..dd2a162c01c 100644 --- a/datatypes/java/pom.xml +++ b/datatypes/java/pom.xml @@ -54,11 +54,11 @@ true - com.google.protobuf:protoc:${protocVersion}:exe:${os.detected.classifier} + com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:${grpcVersion}:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} diff --git a/infra/charts/feast/README.md b/infra/charts/feast/README.md index 6b0325884e7..1a11ce1e294 100644 --- a/infra/charts/feast/README.md +++ b/infra/charts/feast/README.md @@ -164,6 +164,7 @@ feast-batch-serving: staging_location: gs:///feast-staging-location initial_retry_delay_seconds: 3 total_timeout_seconds: 21600 + write_triggering_frequency_seconds: 600 subscriptions: - name: "*" project: "*" @@ -240,7 +241,7 @@ feast-core: options: project: region: - zone: + workerZone: tempLocation: network: subnetwork: @@ -280,6 +281,7 @@ feast-batch-serving: staging_location: gs:///feast-staging-location initial_retry_delay_seconds: 3 total_timeout_seconds: 21600 + write_triggering_frequency_seconds: 600 subscriptions: - name: "*" project: "*" diff --git 
a/infra/charts/feast/README.md.gotmpl b/infra/charts/feast/README.md.gotmpl index 75fc6661c80..56023730fde 100644 --- a/infra/charts/feast/README.md.gotmpl +++ b/infra/charts/feast/README.md.gotmpl @@ -137,6 +137,7 @@ feast-batch-serving: staging_location: gs:///feast-staging-location initial_retry_delay_seconds: 3 total_timeout_seconds: 21600 + write_triggering_frequency_seconds: 600 subscriptions: - name: "*" project: "*" @@ -213,7 +214,7 @@ feast-core: options: project: region: - zone: + workerZone: tempLocation: network: subnetwork: diff --git a/infra/charts/feast/values-batch-serving.yaml b/infra/charts/feast/values-batch-serving.yaml index 3ee35be1061..afa99f6d69b 100644 --- a/infra/charts/feast/values-batch-serving.yaml +++ b/infra/charts/feast/values-batch-serving.yaml @@ -20,6 +20,7 @@ feast-batch-serving: staging_location: gs:///feast-staging-location initial_retry_delay_seconds: 3 total_timeout_seconds: 21600 + write_triggering_frequency_seconds: 600 subscriptions: - name: "*" project: "*" diff --git a/infra/charts/feast/values-dataflow-runner.yaml b/infra/charts/feast/values-dataflow-runner.yaml index 0469a6349e2..56e51551970 100644 --- a/infra/charts/feast/values-dataflow-runner.yaml +++ b/infra/charts/feast/values-dataflow-runner.yaml @@ -19,10 +19,12 @@ feast-core: options: project: region: - zone: + workerZone: tempLocation: network: subnetwork: + enableStreamingEngine: false + workerDiskType: maxNumWorkers: 1 autoscalingAlgorithm: THROUGHPUT_BASED usePublicIps: false diff --git a/infra/docker-compose/docker-compose.online.yml b/infra/docker-compose/docker-compose.online.yml index b01d0882fb4..0e5a3cfaec6 100644 --- a/infra/docker-compose/docker-compose.online.yml +++ b/infra/docker-compose/docker-compose.online.yml @@ -5,11 +5,16 @@ services: image: ${FEAST_SERVING_IMAGE}:${FEAST_VERSION} volumes: - ./serving/${FEAST_ONLINE_SERVING_CONFIG}:/etc/feast/application.yml + # Required if authentication is enabled on core and + # provider is 'google'. 
GOOGLE_APPLICATION_CREDENTIALS is used for connecting to core. + - ./gcp-service-accounts/${FEAST_BATCH_SERVING_GCP_SERVICE_ACCOUNT_KEY}:/etc/gcloud/service-accounts/key.json depends_on: - redis ports: - 6566:6566 restart: on-failure + environment: + GOOGLE_APPLICATION_CREDENTIALS: /etc/gcloud/service-accounts/key.json command: - java - -jar diff --git a/infra/docker-compose/serving/batch-serving.yml b/infra/docker-compose/serving/batch-serving.yml index c34aba277c7..3feb81c84e1 100644 --- a/infra/docker-compose/serving/batch-serving.yml +++ b/infra/docker-compose/serving/batch-serving.yml @@ -12,6 +12,7 @@ feast: staging_location: gs://gcs_bucket/prefix initial_retry_delay_seconds: 1 total_timeout_seconds: 21600 + write_triggering_frequency_seconds: 600 subscriptions: - name: "*" project: "*" diff --git a/infra/docker/ci/Dockerfile b/infra/docker/ci/Dockerfile index 08da02ae202..4e7c383524e 100644 --- a/infra/docker/ci/Dockerfile +++ b/infra/docker/ci/Dockerfile @@ -30,7 +30,7 @@ ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH ENV PATH="$HOME/bin:${PATH}" # Install Protoc and Plugins -ENV PROTOC_VERSION 3.10.0 +ENV PROTOC_VERSION 3.12.2 RUN PROTOC_ZIP=protoc-${PROTOC_VERSION}-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/$PROTOC_ZIP && \ diff --git a/infra/scripts/test-end-to-end-batch.sh b/infra/scripts/test-end-to-end-batch.sh index 60f09cb4163..cf48a7e14a6 100755 --- a/infra/scripts/test-end-to-end-batch.sh +++ b/infra/scripts/test-end-to-end-batch.sh @@ -113,10 +113,6 @@ feast: tracing: enabled: false -grpc: - port: 6566 - enable-reflection: true - server: port: 8081 diff --git a/infra/scripts/test-end-to-end-redis-cluster.sh b/infra/scripts/test-end-to-end-redis-cluster.sh index 9094fc3a2e0..be38fe765bc 100755 --- a/infra/scripts/test-end-to-end-redis-cluster.sh +++ b/infra/scripts/test-end-to-end-redis-cluster.sh @@ -67,10 +67,6 @@ feast: tracing: enabled: false -grpc: - port: 6566 - 
enable-reflection: true - spring: main: web-environment: false diff --git a/infra/scripts/test-end-to-end.sh b/infra/scripts/test-end-to-end.sh index 75bacd3560a..84d65aebe25 100755 --- a/infra/scripts/test-end-to-end.sh +++ b/infra/scripts/test-end-to-end.sh @@ -2,10 +2,7 @@ set -e set -o pipefail -ENABLE_AUTH="False" -if [[ -n $1 ]]; then - ENABLE_AUTH=$1 -fi +[[ $1 == "True" ]] && ENABLE_AUTH="true" || ENABLE_AUTH="false" echo "Authenication enabled : ${ENABLE_AUTH}" test -z ${GOOGLE_APPLICATION_CREDENTIALS} && GOOGLE_APPLICATION_CREDENTIALS="/etc/gcloud/service-account.json" @@ -60,20 +57,13 @@ feast: authentication: enabled: true provider: jwt + options: + jwkEndpointURI: "https://www.googleapis.com/oauth2/v3/certs" authorization: enabled: false provider: none EOF -if [[ ${ENABLE_AUTH} = "True" ]]; - then - print_banner "Starting 'Feast core with auth'." - start_feast_core /tmp/core.warehouse.application.yml - else - print_banner "Starting 'Feast core without auth'." - start_feast_core -fi - cat < /tmp/serving.warehouse.application.yml feast: stores: @@ -86,8 +76,30 @@ feast: subscriptions: - name: "*" project: "*" + core-authentication: + enabled: $ENABLE_AUTH + provider: google + security: + authentication: + enabled: $ENABLE_AUTH + provider: jwt + authorization: + enabled: false + provider: none EOF +if [[ ${ENABLE_AUTH} = "true" ]]; + then + print_banner "Starting Feast core with auth" + start_feast_core /tmp/core.warehouse.application.yml + print_banner "Starting Feast Serving with auth" + else + print_banner "Starting Feast core without auth" + start_feast_core + print_banner "Starting Feast Serving without auth" +fi + + start_feast_serving /tmp/serving.warehouse.application.yml install_python_with_miniconda_and_feast_sdk diff --git a/infra/scripts/test-templates/values-end-to-end-batch-dataflow.yaml b/infra/scripts/test-templates/values-end-to-end-batch-dataflow.yaml index 48231face69..377fa7a0aee 100644 --- 
a/infra/scripts/test-templates/values-end-to-end-batch-dataflow.yaml +++ b/infra/scripts/test-templates/values-end-to-end-batch-dataflow.yaml @@ -27,7 +27,7 @@ feast-core: options: project: $GCLOUD_PROJECT region: $GCLOUD_REGION - zone: $GCLOUD_REGION-a + workerZone: $GCLOUD_REGION-a tempLocation: gs://$TEMP_BUCKET/tempLocation network: $GCLOUD_NETWORK subnetwork: regions/$GCLOUD_REGION/subnetworks/$GCLOUD_SUBNET diff --git a/ingestion/pom.xml b/ingestion/pom.xml index de50789a67b..d1fbf1d45b5 100644 --- a/ingestion/pom.xml +++ b/ingestion/pom.xml @@ -57,10 +57,6 @@ - - org.springframework - org.springframework.vendor - io.opencensus io.opencensus.vendor diff --git a/ingestion/src/test/java/feast/ingestion/ImportJobTest.java b/ingestion/src/test/java/feast/ingestion/ImportJobTest.java index 1cfd29b5415..f775bc31bbb 100644 --- a/ingestion/src/test/java/feast/ingestion/ImportJobTest.java +++ b/ingestion/src/test/java/feast/ingestion/ImportJobTest.java @@ -217,7 +217,9 @@ public void runPipeline_ShouldWriteToRedisCorrectlyGivenValidSpecAndFeatureRow() .map(FeatureSpec::getName) .collect(Collectors.toList()) .contains(field.getName())) - .map(field -> field.toBuilder().clearName().build()) + .map( + field -> + field.toBuilder().setName(TestUtil.hash(field.getName())).build()) .collect(Collectors.toList()); randomRow = randomRow diff --git a/ingestion/src/test/java/feast/ingestion/transform/specs/FeatureSetSpecReadAndWriteTest.java b/ingestion/src/test/java/feast/ingestion/transform/specs/FeatureSetSpecReadAndWriteTest.java index e73123a810e..340af58e8c5 100644 --- a/ingestion/src/test/java/feast/ingestion/transform/specs/FeatureSetSpecReadAndWriteTest.java +++ b/ingestion/src/test/java/feast/ingestion/transform/specs/FeatureSetSpecReadAndWriteTest.java @@ -34,7 +34,6 @@ import java.util.stream.Collectors; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.direct.DirectOptions; -import 
org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; @@ -44,7 +43,6 @@ import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.common.serialization.ByteArraySerializer; import org.apache.kafka.common.serialization.Deserializer; -import org.joda.time.Duration; import org.junit.*; public class FeatureSetSpecReadAndWriteTest { @@ -101,7 +99,7 @@ public static PipelineOptions makePipelineOptions() { } @Test - public void pipelineShouldReadSpecsAndAcknowledge() { + public void pipelineShouldReadSpecsAndAcknowledge() throws InterruptedException { SourceProto.Source source = SourceProto.Source.newBuilder() .setKafkaSourceConfig( @@ -153,8 +151,8 @@ public void pipelineShouldReadSpecsAndAcknowledge() { publishSpecToKafka("project", "fs", 3, source); publishSpecToKafka("project", "fs_2", 2, source); - PipelineResult run = p.run(); - run.waitUntilFinish(Duration.standardSeconds(10)); + p.run(); + Thread.sleep(10000); List acks = getFeatureSetSpecAcks(); @@ -178,7 +176,7 @@ public void pipelineShouldReadSpecsAndAcknowledge() { // in-flight update 1 publishSpecToKafka("project", "fs", 4, source); - run.waitUntilFinish(Duration.standardSeconds(5)); + Thread.sleep(5000); assertThat( getFeatureSetSpecAcks(), @@ -192,7 +190,7 @@ public void pipelineShouldReadSpecsAndAcknowledge() { // in-flight update 2 publishSpecToKafka("project", "fs_2", 3, source); - run.waitUntilFinish(Duration.standardSeconds(5)); + Thread.sleep(5000); assertThat( getFeatureSetSpecAcks(), diff --git a/ingestion/src/test/java/feast/test/TestUtil.java b/ingestion/src/test/java/feast/test/TestUtil.java index f3ae9f6a988..1fb8ea89ea3 100644 --- a/ingestion/src/test/java/feast/test/TestUtil.java +++ b/ingestion/src/test/java/feast/test/TestUtil.java @@ -19,10 +19,11 @@ import static feast.common.models.FeatureSet.getFeatureSetStringRef; import 
com.google.common.collect.ImmutableList; +import com.google.common.hash.Hashing; import com.google.common.io.Files; import com.google.protobuf.ByteString; import com.google.protobuf.Message; -import com.google.protobuf.util.Timestamps; +import com.google.protobuf.Timestamp; import feast.ingestion.transform.metrics.WriteSuccessMetricsTransform; import feast.proto.core.FeatureSetProto.FeatureSet; import feast.proto.core.FeatureSetProto.FeatureSetSpec; @@ -36,6 +37,7 @@ import java.net.DatagramSocket; import java.net.SocketException; import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -218,10 +220,15 @@ public static FeatureRow createRandomFeatureRow(FeatureSetSpec featureSetSpec) { */ public static FeatureRow createRandomFeatureRow( FeatureSetSpec featureSetSpec, int randomStringSize) { + + Instant time = Instant.now(); + Timestamp timestamp = + Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build(); + Builder builder = FeatureRow.newBuilder() .setFeatureSet(getFeatureSetStringRef(featureSetSpec)) - .setEventTimestamp(Timestamps.fromMillis(System.currentTimeMillis())); + .setEventTimestamp(timestamp); featureSetSpec .getEntitiesList() @@ -511,4 +518,8 @@ public static void waitUntilAllElementsAreWrittenToStore( } } } + + public static String hash(String input) { + return Hashing.murmur3_32().hashString(input, StandardCharsets.UTF_8).toString(); + } } diff --git a/pom.xml b/pom.xml index 5de3f1116a3..826bfa02ebe 100644 --- a/pom.xml +++ b/pom.xml @@ -47,23 +47,31 @@ UTF-8 UTF-8 - 1.17.1 - 3.10.0 - 3.10.0 - 2.0.9.RELEASE - 2.18.0 - 1.91.0 + 1.30.2 + 3.12.2 + 3.12.2 + 2.3.1.RELEASE + 5.3.0.RELEASE + 2.9.0.RELEASE + 2.22.0 + 1.111.1 0.8.0 1.9.10 1.3 - 5.3.6.Final - 2.3.0 + 5.4.18.Final + 2.5.0 2.28.2 - 0.21.0 + 0.26.0 2.12.1 - 5.2.4 + 6.0.8 + 2.9.9 + 2.0.2 + 2.5.0.RELEASE + + false + 1.6.6 @@ -153,45 +161,80 @@ io.grpc grpc-core - 
${grpcVersion} + ${grpc.version} + + + io.grpc + grpc-api + ${grpc.version} + + + io.grpc + grpc-context + ${grpc.version} + + + io.grpc + grpc-all + ${grpc.version} + + + io.grpc + grpc-okhttp + ${grpc.version} + + + io.grpc + grpc-auth + ${grpc.version} + + + io.grpc + grpc-grpclb + ${grpc.version} + + + io.grpc + grpc-alts + ${grpc.version} io.grpc grpc-netty - ${grpcVersion} + ${grpc.version} io.grpc grpc-netty-shaded - ${grpcVersion} + ${grpc.version} io.grpc grpc-protobuf - ${grpcVersion} + ${grpc.version} io.grpc grpc-services - ${grpcVersion} + ${grpc.version} io.grpc grpc-stub - ${grpcVersion} + ${grpc.version} io.grpc grpc-testing - ${grpcVersion} + ${grpc.version} test - io.github.lognet - grpc-spring-boot-starter - 3.0.2 + net.devh + grpc-server-spring-boot-starter + ${grpc.spring.boot.starter.version} @@ -207,6 +250,11 @@ + + joda-time + joda-time + ${joda.time.version} + com.datadoghq java-dogstatsd-client @@ -220,12 +268,12 @@ com.google.protobuf protobuf-java - ${protobufVersion} + ${protobuf.version} com.google.protobuf protobuf-java-util - ${protobufVersion} + ${protobuf.version} org.projectlombok @@ -233,6 +281,21 @@ 1.18.12 provided + + com.google.auto.value + auto-value-annotations + ${auto.value.version} + + + com.google.auto.value + auto-value + ${auto.value.version} + + + com.google.code.gson + gson + 2.8.5 + @@ -272,7 +335,7 @@ org.springframework.boot spring-boot-starter-web - ${springBootVersion} + ${spring.boot.version} org.springframework.boot @@ -316,7 +379,7 @@ org.springframework.boot spring-boot-dependencies - ${springBootVersion} + ${spring.boot.version} pom import @@ -643,7 +706,7 @@ org.springframework.boot spring-boot-maven-plugin - ${springBootVersion} + ${spring.boot.version} diff --git a/protos/feast/core/Runner.proto b/protos/feast/core/Runner.proto index 0684356f8d2..9bb4457d4f7 100644 --- a/protos/feast/core/Runner.proto +++ b/protos/feast/core/Runner.proto @@ -45,7 +45,7 @@ message DataflowRunnerConfigOptions { string 
region = 2; /* GCP availability zone for operations. */ - string zone = 3; + string workerZone = 3; /* Run the job as a specific service account, instead of the default GCE robot. */ string serviceAccount = 4; @@ -81,4 +81,9 @@ message DataflowRunnerConfigOptions { /* Disk size to use on each remote Compute Engine worker instance */ int32 diskSizeGb = 14; + /* Run job on Dataflow Streaming Engine instead of creating worker VMs */ + bool enableStreamingEngine = 15; + + /* Type of persistent disk to be used by workers */ + string workerDiskType = 16; } \ No newline at end of file diff --git a/protos/feast/core/Store.proto b/protos/feast/core/Store.proto index 3b4394150db..780d7a7db8b 100644 --- a/protos/feast/core/Store.proto +++ b/protos/feast/core/Store.proto @@ -108,7 +108,7 @@ message Store { int32 initial_backoff_ms = 3; // Optional. Maximum total number of retries for connecting to Redis. Default to zero retries. int32 max_retries = 4; - // Optional. how often flush data to redis + // Optional. How often flush data to redis int32 flush_frequency_seconds = 5; } @@ -118,6 +118,7 @@ message Store { string staging_location = 3; int32 initial_retry_delay_seconds = 4; int32 total_timeout_seconds = 5; + // Required. Frequency of running BQ load job and flushing all collected rows to BQ table int32 write_triggering_frequency_seconds = 6; } @@ -131,7 +132,7 @@ message Store { string connection_string = 1; int32 initial_backoff_ms = 2; int32 max_retries = 3; - // Optional. how often flush data to redis + // Optional. 
How often flush data to redis int32 flush_frequency_seconds = 4; } diff --git a/sdk/python/feast/client.py b/sdk/python/feast/client.py index 01c9ed83bea..86b9a2c57f6 100644 --- a/sdk/python/feast/client.py +++ b/sdk/python/feast/client.py @@ -32,10 +32,10 @@ from feast.config import Config from feast.constants import ( - CONFIG_CORE_ENABLE_AUTH_KEY, CONFIG_CORE_ENABLE_SSL_KEY, CONFIG_CORE_SERVER_SSL_CERT_KEY, CONFIG_CORE_URL_KEY, + CONFIG_ENABLE_AUTH_KEY, CONFIG_GRPC_CONNECTION_TIMEOUT_DEFAULT_KEY, CONFIG_PROJECT_KEY, CONFIG_SERVING_ENABLE_SSL_KEY, @@ -112,9 +112,9 @@ def __init__(self, options: Optional[Dict[str, str]] = None, **kwargs): project: Sets the active project. This field is optional. core_secure: Use client-side SSL/TLS for Core gRPC API serving_secure: Use client-side SSL/TLS for Serving gRPC API - core_enable_auth: Enable authentication and authorization - core_auth_provider: Authentication provider – "google" or "oauth" - if core_auth_provider is "oauth", the following fields are mandatory – + enable_auth: Enable authentication and authorization + auth_provider: Authentication provider – "google" or "oauth" + if auth_provider is "oauth", the following fields are mandatory – oauth_grant_type, oauth_client_id, oauth_client_secret, oauth_audience, oauth_token_request_url Args: @@ -132,7 +132,7 @@ def __init__(self, options: Optional[Dict[str, str]] = None, **kwargs): self._auth_metadata: Optional[grpc.AuthMetadataPlugin] = None # Configure Auth Metadata Plugin if auth is enabled - if self._config.getboolean(CONFIG_CORE_ENABLE_AUTH_KEY): + if self._config.getboolean(CONFIG_ENABLE_AUTH_KEY): self._auth_metadata = feast_auth.get_auth_metadata_plugin(self._config) @property @@ -146,7 +146,7 @@ def _core_service(self): channel = create_grpc_channel( url=self._config.get(CONFIG_CORE_URL_KEY), enable_ssl=self._config.getboolean(CONFIG_CORE_ENABLE_SSL_KEY), - enable_auth=self._config.getboolean(CONFIG_CORE_ENABLE_AUTH_KEY), + 
enable_auth=self._config.getboolean(CONFIG_ENABLE_AUTH_KEY), ssl_server_cert_path=self._config.get(CONFIG_CORE_SERVER_SSL_CERT_KEY), auth_metadata_plugin=self._auth_metadata, timeout=self._config.getint(CONFIG_GRPC_CONNECTION_TIMEOUT_DEFAULT_KEY), @@ -165,11 +165,11 @@ def _serving_service(self): channel = create_grpc_channel( url=self._config.get(CONFIG_SERVING_URL_KEY), enable_ssl=self._config.getboolean(CONFIG_SERVING_ENABLE_SSL_KEY), - enable_auth=False, + enable_auth=self._config.getboolean(CONFIG_ENABLE_AUTH_KEY), ssl_server_cert_path=self._config.get( CONFIG_SERVING_SERVER_SSL_CERT_KEY ), - auth_metadata_plugin=None, + auth_metadata_plugin=self._auth_metadata, timeout=self._config.getint(CONFIG_GRPC_CONNECTION_TIMEOUT_DEFAULT_KEY), ) self._serving_service_stub = ServingServiceStub(channel) @@ -271,6 +271,7 @@ def version(self): serving_version = self._serving_service.GetFeastServingInfo( GetFeastServingInfoRequest(), timeout=self._config.getint(CONFIG_GRPC_CONNECTION_TIMEOUT_DEFAULT_KEY), + metadata=self._get_grpc_metadata(), ).version result["serving"] = {"url": self.serving_url, "version": serving_version} @@ -522,7 +523,7 @@ def list_features_by_ref( ) feature_protos = self._core_service.ListFeatures( - ListFeaturesRequest(filter=filter) + ListFeaturesRequest(filter=filter), metadata=self._get_grpc_metadata(), ) # type: ListFeaturesResponse features_dict = {} @@ -619,6 +620,7 @@ def get_historical_features( serving_info = self._serving_service.GetFeastServingInfo( GetFeastServingInfoRequest(), timeout=self._config.getint(CONFIG_GRPC_CONNECTION_TIMEOUT_DEFAULT_KEY), + metadata=self._get_grpc_metadata(), ) # type: GetFeastServingInfoResponse if serving_info.type != FeastServingType.FEAST_SERVING_TYPE_BATCH: @@ -669,11 +671,17 @@ def get_historical_features( # Retrieve Feast Job object to manage life cycle of retrieval try: - response = self._serving_service.GetBatchFeatures(request) + response = self._serving_service.GetBatchFeatures( + request, 
metadata=self._get_grpc_metadata() + ) except grpc.RpcError as e: raise grpc.RpcError(e.details()) - return RetrievalJob(response.job, self._serving_service) + return RetrievalJob( + response.job, + self._serving_service, + auth_metadata_plugin=self._auth_metadata, + ) def get_online_features( self, @@ -722,7 +730,8 @@ def get_online_features( features=_build_feature_references(feature_ref_strs=feature_refs), entity_rows=_infer_online_entity_rows(entity_rows), project=project if project is not None else self.project, - ) + ), + metadata=self._get_grpc_metadata(), ) except grpc.RpcError as e: raise grpc.RpcError(e.details()) @@ -759,9 +768,9 @@ def list_ingest_jobs( ) request = ListIngestionJobsRequest(filter=list_filter) # make list request & unpack response - response = self._core_service.ListIngestionJobs(request) # type: ignore + response = self._core_service.ListIngestionJobs(request, metadata=self._get_grpc_metadata(),) # type: ignore ingest_jobs = [ - IngestJob(proto, self._core_service) for proto in response.jobs # type: ignore + IngestJob(proto, self._core_service, auth_metadata_plugin=self._auth_metadata) for proto in response.jobs # type: ignore ] return ingest_jobs @@ -778,7 +787,9 @@ def restart_ingest_job(self, job: IngestJob): """ request = RestartIngestionJobRequest(id=job.id) try: - self._core_service.RestartIngestionJob(request) # type: ignore + self._core_service.RestartIngestionJob( + request, metadata=self._get_grpc_metadata(), + ) # type: ignore except grpc.RpcError as e: raise grpc.RpcError(e.details()) @@ -794,7 +805,9 @@ def stop_ingest_job(self, job: IngestJob): """ request = StopIngestionJobRequest(id=job.id) try: - self._core_service.StopIngestionJob(request) # type: ignore + self._core_service.StopIngestionJob( + request, metadata=self._get_grpc_metadata(), + ) # type: ignore except grpc.RpcError as e: raise grpc.RpcError(e.details()) @@ -836,11 +849,33 @@ def ingest( Returns: str: ingestion id for this dataset + + Examples: + >>> from 
feast import Client + >>> + >>> client = Client(core_url="localhost:6565") + >>> fs_df = pd.DataFrame( + >>> { + >>> "datetime": [pd.datetime.now()], + >>> "driver": [1001], + >>> "rating": [4.3], + >>> } + >>> ) + >>> client.set_project("project1") + >>> client.ingest("driver", fs_df) + >>> + >>> driver_fs = client.get_feature_set(name="driver", project="project1") + >>> client.ingest(driver_fs, fs_df) """ if isinstance(feature_set, FeatureSet): name = feature_set.name + project = feature_set.project elif isinstance(feature_set, str): + if self.project is not None: + project = self.project + else: + project = "default" name = feature_set else: raise Exception("Feature set name must be provided") @@ -858,7 +893,9 @@ def ingest( while True: if timeout is not None and time.time() - current_time >= timeout: raise TimeoutError("Timed out waiting for feature set to be ready") - fetched_feature_set: Optional[FeatureSet] = self.get_feature_set(name) + fetched_feature_set: Optional[FeatureSet] = self.get_feature_set( + name, project + ) if ( fetched_feature_set is not None and fetched_feature_set.status == FeatureSetStatus.STATUS_READY @@ -996,7 +1033,7 @@ def _get_grpc_metadata(self): Returns: Tuple of metadata to attach to each gRPC call """ - if self._config.getboolean(CONFIG_CORE_ENABLE_AUTH_KEY) and self._auth_metadata: + if self._config.getboolean(CONFIG_ENABLE_AUTH_KEY) and self._auth_metadata: return self._auth_metadata.get_signed_meta() return () diff --git a/sdk/python/feast/constants.py b/sdk/python/feast/constants.py index 911432326a9..67f4808010e 100644 --- a/sdk/python/feast/constants.py +++ b/sdk/python/feast/constants.py @@ -42,8 +42,8 @@ class AuthProvider(Enum): CONFIG_PROJECT_KEY = "project" CONFIG_CORE_URL_KEY = "core_url" CONFIG_CORE_ENABLE_SSL_KEY = "core_enable_ssl" -CONFIG_CORE_ENABLE_AUTH_KEY = "core_enable_auth" -CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY = "core_auth_token" +CONFIG_ENABLE_AUTH_KEY = "enable_auth" +CONFIG_ENABLE_AUTH_TOKEN_KEY = 
"auth_token" CONFIG_CORE_SERVER_SSL_CERT_KEY = "core_server_ssl_cert" CONFIG_SERVING_URL_KEY = "serving_url" CONFIG_SERVING_ENABLE_SSL_KEY = "serving_enable_ssl" @@ -58,7 +58,7 @@ class AuthProvider(Enum): CONFIG_OAUTH_CLIENT_SECRET_KEY = "oauth_client_secret" CONFIG_OAUTH_AUDIENCE_KEY = "oauth_audience" CONFIG_OAUTH_TOKEN_REQUEST_URL_KEY = "oauth_token_request_url" -CONFIG_CORE_AUTH_PROVIDER = "core_auth_provider" +CONFIG_AUTH_PROVIDER = "auth_provider" CONFIG_TIMEOUT_KEY = "timeout" CONFIG_MAX_WAIT_INTERVAL_KEY = "max_wait_interval" @@ -72,7 +72,7 @@ class AuthProvider(Enum): # Enable or disable TLS/SSL to Feast Core CONFIG_CORE_ENABLE_SSL_KEY: "False", # Enable user authentication to Feast Core - CONFIG_CORE_ENABLE_AUTH_KEY: "False", + CONFIG_ENABLE_AUTH_KEY: "False", # Path to certificate(s) to secure connection to Feast Core CONFIG_CORE_SERVER_SSL_CERT_KEY: "", # Default Feast Serving URL @@ -91,5 +91,5 @@ class AuthProvider(Enum): CONFIG_TIMEOUT_KEY: "21600", CONFIG_MAX_WAIT_INTERVAL_KEY: "60", # Authentication Provider - Google OpenID/OAuth - CONFIG_CORE_AUTH_PROVIDER: "google", + CONFIG_AUTH_PROVIDER: "google", } diff --git a/sdk/python/feast/grpc/auth.py b/sdk/python/feast/grpc/auth.py index ab1de836311..9680607b8e3 100644 --- a/sdk/python/feast/grpc/auth.py +++ b/sdk/python/feast/grpc/auth.py @@ -19,8 +19,8 @@ from feast.config import Config from feast.constants import ( - CONFIG_CORE_AUTH_PROVIDER, - CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY, + CONFIG_AUTH_PROVIDER, + CONFIG_ENABLE_AUTH_TOKEN_KEY, CONFIG_OAUTH_AUDIENCE_KEY, CONFIG_OAUTH_CLIENT_ID_KEY, CONFIG_OAUTH_CLIENT_SECRET_KEY, @@ -44,9 +44,9 @@ def get_auth_metadata_plugin(config: Config) -> grpc.AuthMetadataPlugin: Args: config: Feast Configuration object """ - if AuthProvider(config.get(CONFIG_CORE_AUTH_PROVIDER)) == AuthProvider.GOOGLE: + if AuthProvider(config.get(CONFIG_AUTH_PROVIDER)) == AuthProvider.GOOGLE: return GoogleOpenIDAuthMetadataPlugin(config) - elif 
AuthProvider(config.get(CONFIG_CORE_AUTH_PROVIDER)) == AuthProvider.OAUTH: + elif AuthProvider(config.get(CONFIG_AUTH_PROVIDER)) == AuthProvider.OAUTH: return OAuthMetadataPlugin(config) else: raise RuntimeError( @@ -75,8 +75,8 @@ def __init__(self, config: Config): self._token = None # If provided, set a static token - if config.exists(CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY): - self._static_token = config.get(CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY) + if config.exists(CONFIG_ENABLE_AUTH_TOKEN_KEY): + self._static_token = config.get(CONFIG_ENABLE_AUTH_TOKEN_KEY) self._refresh_token(config) elif ( config.exists(CONFIG_OAUTH_GRANT_TYPE_KEY) @@ -171,8 +171,8 @@ def __init__(self, config: Config): self._token = None # If provided, set a static token - if config.exists(CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY): - self._static_token = config.get(CONFIG_CORE_ENABLE_AUTH_TOKEN_KEY) + if config.exists(CONFIG_ENABLE_AUTH_TOKEN_KEY): + self._static_token = config.get(CONFIG_ENABLE_AUTH_TOKEN_KEY) self._request = requests.Request() self._refresh_token() diff --git a/sdk/python/feast/job.py b/sdk/python/feast/job.py index 25396213e47..cda4b26d300 100644 --- a/sdk/python/feast/job.py +++ b/sdk/python/feast/job.py @@ -2,6 +2,7 @@ from urllib.parse import urlparse import fastavro +import grpc import pandas as pd from google.protobuf.json_format import MessageToJson @@ -39,15 +40,20 @@ class RetrievalJob: """ def __init__( - self, job_proto: JobProto, serving_stub: ServingServiceStub, + self, + job_proto: JobProto, + serving_stub: ServingServiceStub, + auth_metadata_plugin: grpc.AuthMetadataPlugin = None, ): """ Args: job_proto: Job proto object (wrapped by this job object) serving_stub: Stub for Feast serving service + auth_metadata_plugin: plugin to fetch auth metadata """ self.job_proto = job_proto self.serving_stub = serving_stub + self.auth_metadata = auth_metadata_plugin @property def id(self): @@ -68,7 +74,10 @@ def reload(self): Reload the latest job status Returns: None """ - self.job_proto 
= self.serving_stub.GetJob(GetJobRequest(job=self.job_proto)).job + self.job_proto = self.serving_stub.GetJob( + GetJobRequest(job=self.job_proto), + metadata=self.auth_metadata.get_signed_meta() if self.auth_metadata else (), + ).job def get_avro_files(self, timeout_sec: int = int(defaults[CONFIG_TIMEOUT_KEY])): """ @@ -218,16 +227,23 @@ class IngestJob: Defines a job for feature ingestion in feast. """ - def __init__(self, job_proto: IngestJobProto, core_stub: CoreServiceStub): + def __init__( + self, + job_proto: IngestJobProto, + core_stub: CoreServiceStub, + auth_metadata_plugin: grpc.AuthMetadataPlugin = None, + ): """ Construct a native ingest job from its protobuf version. Args: job_proto: Job proto object to construct from. core_stub: stub for Feast CoreService + auth_metadata_plugin: plugin to fetch auth metadata """ self.proto = job_proto self.core_svc = core_stub + self.auth_metadata = auth_metadata_plugin def reload(self): """ @@ -235,7 +251,10 @@ def reload(self): """ # pull latest proto from feast core response = self.core_svc.ListIngestionJobs( - ListIngestionJobsRequest(filter=ListIngestionJobsRequest.Filter(id=self.id)) + ListIngestionJobsRequest( + filter=ListIngestionJobsRequest.Filter(id=self.id) + ), + metadata=self.auth_metadata.get_signed_meta() if self.auth_metadata else (), ) self.proto = response.jobs[0] diff --git a/sdk/python/tests/grpc/test_auth.py b/sdk/python/tests/grpc/test_auth.py index 90896ee925f..7f023aabcfd 100644 --- a/sdk/python/tests/grpc/test_auth.py +++ b/sdk/python/tests/grpc/test_auth.py @@ -76,8 +76,8 @@ def refresh(self, request): def config_oauth(): config_dict = { "core_url": "localhost:50051", - "core_enable_auth": True, - "core_auth_provider": "oauth", + "enable_auth": True, + "auth_provider": "oauth", "oauth_grant_type": "client_credentials", "oauth_client_id": "fakeID", "oauth_client_secret": "fakeSecret", @@ -91,13 +91,8 @@ def config_oauth(): def config_google(): config_dict = { "core_url": "localhost:50051", - 
"core_enable_auth": True, - "core_auth_provider": "google", - "oauth_grant_type": "client_credentials", - "oauth_client_id": "fakeID", - "oauth_client_secret": "fakeSecret", - "oauth_audience": AUDIENCE, - "oauth_token_request_url": AUTH_URL, + "enable_auth": True, + "auth_provider": "google", } return Config(config_dict) @@ -106,8 +101,8 @@ def config_google(): def config_with_missing_variable(): config_dict = { "core_url": "localhost:50051", - "core_enable_auth": True, - "core_auth_provider": "oauth", + "enable_auth": True, + "auth_provider": "oauth", "oauth_grant_type": "client_credentials", "oauth_client_id": "fakeID", "oauth_client_secret": "fakeSecret", diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index 8712847b4bb..416d4b2dde8 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -24,6 +24,7 @@ import pytest from google.protobuf.duration_pb2 import Duration from mock import MagicMock, patch +from pytest_lazyfixture import lazy_fixture from pytz import timezone from feast.client import Client @@ -81,6 +82,7 @@ "TY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDI" "yfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" ) +AUTH_METADATA = (("authorization", f"Bearer {_FAKE_JWT_TOKEN}"),) class TestClient: @@ -103,6 +105,32 @@ def mock_client(self): client._serving_url = SERVING_URL return client + @pytest.fixture + def mock_client_with_auth(self): + client = Client( + core_url=CORE_URL, + serving_url=SERVING_URL, + enable_auth=True, + auth_token=_FAKE_JWT_TOKEN, + ) + client._core_url = CORE_URL + client._serving_url = SERVING_URL + return client + + @pytest.fixture + def secure_mock_client_with_auth(self): + client = Client( + core_url=CORE_URL, + serving_url=SERVING_URL, + core_enable_ssl=True, + serving_enable_ssl=True, + enable_auth=True, + auth_token=_FAKE_JWT_TOKEN, + ) + client._core_url = CORE_URL + client._serving_url = SERVING_URL + return client + @pytest.fixture def 
server_credentials(self): private_key = pkgutil.get_data(__name__, _PRIVATE_KEY_RESOURCE_PATH) @@ -216,8 +244,8 @@ def secure_core_client_with_auth(self, secure_core_server_with_auth): yield Client( core_url="localhost:50055", core_enable_ssl=True, - core_enable_auth=True, - core_auth_token=_FAKE_JWT_TOKEN, + enable_auth=True, + auth_token=_FAKE_JWT_TOKEN, ) @pytest.fixture @@ -226,7 +254,7 @@ def client(self, core_server, serving_server): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_version(self, mocked_client, mocker): mocked_client._core_service_stub = Core.CoreServiceStub( @@ -257,10 +285,21 @@ def test_version(self, mocked_client, mocker): ) @pytest.mark.parametrize( - "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + "mocked_client,auth_metadata", + [ + (lazy_fixture("mock_client"), ()), + (lazy_fixture("mock_client_with_auth"), (AUTH_METADATA)), + (lazy_fixture("secure_mock_client"), ()), + (lazy_fixture("secure_mock_client_with_auth"), (AUTH_METADATA)), + ], + ids=[ + "mock_client_without_auth", + "mock_client_with_auth", + "secure_mock_client_without_auth", + "secure_mock_client_with_auth", + ], ) - def test_get_online_features(self, mocked_client, mocker): + def test_get_online_features(self, mocked_client, auth_metadata, mocker): ROW_COUNT = 300 mocked_client._serving_service_stub = Serving.ServingServiceStub( @@ -312,7 +351,7 @@ def int_val(x): project="driver_project", ) # type: GetOnlineFeaturesResponse mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with( - request + request, metadata=auth_metadata ) got_fields = got_response.field_values[0].fields @@ -333,7 +372,7 @@ def int_val(x): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + 
[lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_get_feature_set(self, mocked_client, mocker): mocked_client._core_service_stub = Core.CoreServiceStub( @@ -397,7 +436,7 @@ def test_get_feature_set(self, mocked_client, mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_list_feature_sets(self, mocked_client, mocker): mocker.patch.object( @@ -458,7 +497,7 @@ def test_list_feature_sets(self, mocked_client, mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_list_features(self, mocked_client, mocker): mocker.patch.object( @@ -504,7 +543,7 @@ def test_list_features(self, mocked_client, mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_list_ingest_jobs(self, mocked_client, mocker): mocker.patch.object( @@ -560,7 +599,7 @@ def test_list_ingest_jobs(self, mocked_client, mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_restart_ingest_job(self, mocked_client, mocker): mocker.patch.object( @@ -583,7 +622,7 @@ def test_restart_ingest_job(self, mocked_client, mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")], ) def test_stop_ingest_job(self, mocked_client, mocker): mocker.patch.object( @@ -606,7 +645,12 @@ def test_stop_ingest_job(self, mocked_client, 
mocker): @pytest.mark.parametrize( "mocked_client", - [pytest.lazy_fixture("mock_client"), pytest.lazy_fixture("secure_mock_client")], + [ + lazy_fixture("mock_client"), + lazy_fixture("mock_client_with_auth"), + lazy_fixture("secure_mock_client"), + lazy_fixture("secure_mock_client_with_auth"), + ], ) def test_get_historical_features(self, mocked_client, mocker): @@ -725,8 +769,7 @@ def test_get_historical_features(self, mocked_client, mocker): assert actual_dataframe[["driver_id"]].equals(expected_dataframe[["driver_id"]]) @pytest.mark.parametrize( - "test_client", - [pytest.lazy_fixture("client"), pytest.lazy_fixture("secure_client")], + "test_client", [lazy_fixture("client"), lazy_fixture("secure_client")], ) def test_apply_feature_set_success(self, test_client): @@ -770,8 +813,8 @@ def test_apply_feature_set_success(self, test_client): @pytest.mark.parametrize( "dataframe,test_client", [ - (dataframes.GOOD, pytest.lazy_fixture("client")), - (dataframes.GOOD, pytest.lazy_fixture("secure_client")), + (dataframes.GOOD, lazy_fixture("client")), + (dataframes.GOOD, lazy_fixture("secure_client")), ], ) def test_feature_set_ingest_success(self, dataframe, test_client, mocker): @@ -802,7 +845,7 @@ def test_feature_set_ingest_success(self, dataframe, test_client, mocker): @pytest.mark.parametrize( "dataframe,test_client,exception", - [(dataframes.GOOD, pytest.lazy_fixture("client"), Exception)], + [(dataframes.GOOD, lazy_fixture("client"), Exception)], ) def test_feature_set_ingest_throws_exception_if_kafka_down( self, dataframe, test_client, exception, mocker @@ -835,8 +878,8 @@ def test_feature_set_ingest_throws_exception_if_kafka_down( @pytest.mark.parametrize( "dataframe,exception,test_client", [ - (dataframes.GOOD, TimeoutError, pytest.lazy_fixture("client")), - (dataframes.GOOD, TimeoutError, pytest.lazy_fixture("secure_client")), + (dataframes.GOOD, TimeoutError, lazy_fixture("client")), + (dataframes.GOOD, TimeoutError, lazy_fixture("secure_client")), ], ) def 
test_feature_set_ingest_fail_if_pending( @@ -872,26 +915,22 @@ def test_feature_set_ingest_fail_if_pending( @pytest.mark.parametrize( "dataframe,exception,test_client", [ - (dataframes.BAD_NO_DATETIME, Exception, pytest.lazy_fixture("client")), + (dataframes.BAD_NO_DATETIME, Exception, lazy_fixture("client")), ( dataframes.BAD_INCORRECT_DATETIME_TYPE, Exception, - pytest.lazy_fixture("client"), - ), - (dataframes.BAD_NO_ENTITY, Exception, pytest.lazy_fixture("client")), - (dataframes.NO_FEATURES, Exception, pytest.lazy_fixture("client")), - ( - dataframes.BAD_NO_DATETIME, - Exception, - pytest.lazy_fixture("secure_client"), + lazy_fixture("client"), ), + (dataframes.BAD_NO_ENTITY, Exception, lazy_fixture("client")), + (dataframes.NO_FEATURES, Exception, lazy_fixture("client")), + (dataframes.BAD_NO_DATETIME, Exception, lazy_fixture("secure_client"),), ( dataframes.BAD_INCORRECT_DATETIME_TYPE, Exception, - pytest.lazy_fixture("secure_client"), + lazy_fixture("secure_client"), ), - (dataframes.BAD_NO_ENTITY, Exception, pytest.lazy_fixture("secure_client")), - (dataframes.NO_FEATURES, Exception, pytest.lazy_fixture("secure_client")), + (dataframes.BAD_NO_ENTITY, Exception, lazy_fixture("secure_client")), + (dataframes.NO_FEATURES, Exception, lazy_fixture("secure_client")), ], ) def test_feature_set_ingest_failure(self, test_client, dataframe, exception): @@ -911,8 +950,8 @@ def test_feature_set_ingest_failure(self, test_client, dataframe, exception): @pytest.mark.parametrize( "dataframe,test_client", [ - (dataframes.ALL_TYPES, pytest.lazy_fixture("client")), - (dataframes.ALL_TYPES, pytest.lazy_fixture("secure_client")), + (dataframes.ALL_TYPES, lazy_fixture("client")), + (dataframes.ALL_TYPES, lazy_fixture("secure_client")), ], ) def test_feature_set_types_success(self, test_client, dataframe, mocker): @@ -1007,9 +1046,7 @@ def test_auth_success_with_insecure_channel_on_core_url( self, insecure_core_server_with_auth ): client = Client( - core_url="localhost:50056", - 
core_enable_auth=True, - core_auth_token=_FAKE_JWT_TOKEN, + core_url="localhost:50056", enable_auth=True, auth_token=_FAKE_JWT_TOKEN, ) client.list_feature_sets() diff --git a/serving/pom.xml b/serving/pom.xml index 66a840e0097..35b919968c5 100644 --- a/serving/pom.xml +++ b/serving/pom.xml @@ -121,6 +121,12 @@ feast-common ${project.version} + + + dev.feast + feast-auth + ${project.version} + @@ -134,7 +140,7 @@ true - + org.springframework.boot spring-boot-starter-web @@ -155,29 +161,23 @@ true - - - io.github.lognet - grpc-spring-boot-starter - - org.springframework.boot spring-boot-starter-actuator - + io.grpc grpc-services - + io.grpc grpc-stub - + com.google.protobuf protobuf-java-util @@ -249,7 +249,7 @@ - + io.grpc grpc-testing @@ -278,6 +278,80 @@ embedded-redis test + + jakarta.validation + jakarta.validation-api + ${jakarta.validation.api.version} + + + org.springframework.security + spring-security-core + ${spring.security.version} + + + org.springframework.security + spring-security-config + ${spring.security.version} + + + org.springframework.security.oauth + spring-security-oauth2 + ${spring.security.oauth2.version} + + + org.springframework.security + spring-security-oauth2-client + ${spring.security.version} + + + org.springframework.security + spring-security-web + ${spring.security.version} + + + org.springframework.security + spring-security-oauth2-jose + ${spring.security.version} + + + net.devh + grpc-server-spring-boot-starter + ${grpc.spring.boot.starter.version} + + + com.nimbusds + nimbus-jose-jwt + 8.2.1 + + + org.springframework.security + spring-security-oauth2-core + ${spring.security.version} + + + org.testcontainers + testcontainers + 1.14.3 + test + + + org.testcontainers + junit-jupiter + 1.14.3 + test + + + org.awaitility + awaitility + 3.0.0 + test + + + sh.ory.keto + keto-client + 0.4.4-alpha.1 + test + @@ -310,7 +384,6 @@ true - diff --git a/serving/src/main/java/feast/serving/config/FeastProperties.java 
b/serving/src/main/java/feast/serving/config/FeastProperties.java index f905f5f5c02..6a1d1a55171 100644 --- a/serving/src/main/java/feast/serving/config/FeastProperties.java +++ b/serving/src/main/java/feast/serving/config/FeastProperties.java @@ -25,17 +25,32 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import feast.auth.config.SecurityProperties; +import feast.auth.config.SecurityProperties.AuthenticationProperties; +import feast.auth.config.SecurityProperties.AuthorizationProperties; +import feast.auth.credentials.CoreAuthenticationProperties; +import feast.common.logging.config.LoggingProperties; import feast.proto.core.StoreProto; import java.util.*; import java.util.stream.Collectors; +import javax.annotation.PostConstruct; +import javax.validation.ConstraintViolation; +import javax.validation.ConstraintViolationException; +import javax.validation.Validation; +import javax.validation.Validator; +import javax.validation.ValidatorFactory; import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; import javax.validation.constraints.Positive; import org.apache.logging.log4j.core.config.plugins.validation.constraints.ValidHost; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.info.BuildProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; /** Feast Serving properties. */ +@ComponentScan("feast.common.logging") @ConfigurationProperties(prefix = "feast", ignoreInvalidFields = true) public class FeastProperties { @@ -61,6 +76,41 @@ public FeastProperties() {} /* Feast Core port to connect to. 
*/ @Positive private int coreGrpcPort; + private CoreAuthenticationProperties coreAuthentication; + + public CoreAuthenticationProperties getCoreAuthentication() { + return coreAuthentication; + } + + public void setCoreAuthentication(CoreAuthenticationProperties coreAuthentication) { + this.coreAuthentication = coreAuthentication; + } + + private SecurityProperties security; + + @Bean + SecurityProperties securityProperties() { + return this.getSecurity(); + } + + /** + * Getter for SecurityProperties + * + * @return Returns the {@link SecurityProperties} object. + */ + public SecurityProperties getSecurity() { + return security; + } + + /** + * Setter for SecurityProperties + * + * @param security :input {@link SecurityProperties} object + */ + public void setSecurity(SecurityProperties security) { + this.security = security; + } + /** * Finds and returns the active store * @@ -99,6 +149,14 @@ public void setActiveStore(String activeStore) { /* Metric tracing properties. */ private TracingProperties tracing; + /* Feast Audit Logging properties */ + @NotNull private LoggingProperties logging; + + @Bean + LoggingProperties loggingProperties() { + return getLogging(); + } + /** * Gets Serving store configuration as a list of {@link Store}. * @@ -427,6 +485,20 @@ public void setTracing(TracingProperties tracing) { this.tracing = tracing; } + /** + * Gets logging properties + * + * @return logging properties + */ + public LoggingProperties getLogging() { + return logging; + } + + /** Sets logging properties @@param logging the logging properties */ + public void setLogging(LoggingProperties logging) { + this.logging = logging; + } + /** The type Job store properties. */ public static class JobStoreProperties { @@ -539,4 +611,41 @@ public void setServiceName(String serviceName) { this.serviceName = serviceName; } } + + /** + * Validates all FeastProperties. 
This method runs after properties have been initialized and + * individually and conditionally validates each class. + */ + @PostConstruct + public void validate() { + ValidatorFactory factory = Validation.buildDefaultValidatorFactory(); + Validator validator = factory.getValidator(); + + // Validate root fields in FeastProperties + Set> violations = validator.validate(this); + if (!violations.isEmpty()) { + throw new ConstraintViolationException(violations); + } + + // Validate CoreAuthenticationProperties + Set> coreAuthenticationPropsViolations = + validator.validate(getCoreAuthentication()); + if (!coreAuthenticationPropsViolations.isEmpty()) { + throw new ConstraintViolationException(coreAuthenticationPropsViolations); + } + + // Validate AuthenticationProperties + Set> authenticationPropsViolations = + validator.validate(getSecurity().getAuthentication()); + if (!authenticationPropsViolations.isEmpty()) { + throw new ConstraintViolationException(authenticationPropsViolations); + } + + // Validate AuthorizationProperties + Set> authorizationPropsViolations = + validator.validate(getSecurity().getAuthorization()); + if (!authorizationPropsViolations.isEmpty()) { + throw new ConstraintViolationException(authorizationPropsViolations); + } + } } diff --git a/serving/src/main/java/feast/serving/config/ServingSecurityConfig.java b/serving/src/main/java/feast/serving/config/ServingSecurityConfig.java new file mode 100644 index 00000000000..2d0a46763a7 --- /dev/null +++ b/serving/src/main/java/feast/serving/config/ServingSecurityConfig.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.serving.config; + +import feast.auth.credentials.GoogleAuthCredentials; +import feast.auth.credentials.OAuthCredentials; +import io.grpc.CallCredentials; +import java.io.IOException; +import net.devh.boot.grpc.server.security.check.AccessPredicate; +import net.devh.boot.grpc.server.security.check.GrpcSecurityMetadataSource; +import net.devh.boot.grpc.server.security.check.ManualGrpcSecurityMetadataSource; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; + +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +@Configuration +@ComponentScan(basePackages = {"feast.auth.config", "feast.auth.service"}) +public class ServingSecurityConfig { + + private final FeastProperties feastProperties; + + public ServingSecurityConfig(FeastProperties feastProperties) { + this.feastProperties = feastProperties; + } + + /** + * Creates a SecurityMetadataSource when authentication is enabled. This allows for the + * configuration of endpoint level security rules. + * + * @return GrpcSecurityMetadataSource + */ + @Bean + @ConditionalOnProperty(prefix = "feast.security.authentication", name = "enabled") + GrpcSecurityMetadataSource grpcSecurityMetadataSource() { + final ManualGrpcSecurityMetadataSource source = new ManualGrpcSecurityMetadataSource(); + + // Authentication is enabled for all gRPC endpoints + source.setDefault(AccessPredicate.authenticated()); + return source; + } + + /** + * Creates a CallCredentials when authentication is enabled on core. This allows serving to + * connect to core with CallCredentials + * + * @return CallCredentials + */ + @Bean + @ConditionalOnProperty(prefix = "feast.core-authentication", name = "enabled") + CallCredentials CoreGrpcAuthenticationCredentials() throws IOException { + switch (feastProperties.getCoreAuthentication().getProvider()) { + case "google": + return new GoogleAuthCredentials(feastProperties.getCoreAuthentication().getOptions()); + case "oauth": + return new OAuthCredentials(feastProperties.getCoreAuthentication().getOptions()); + default: + throw new IllegalArgumentException( + "Please configure an Core Authentication Provider " + + "if you have enabled Authentication on core. 
" + + "Currently `google` and `oauth` are supported"); + } + } +} diff --git a/serving/src/main/java/feast/serving/config/SpecServiceConfig.java b/serving/src/main/java/feast/serving/config/SpecServiceConfig.java index 0a62557077f..75b77a29a03 100644 --- a/serving/src/main/java/feast/serving/config/SpecServiceConfig.java +++ b/serving/src/main/java/feast/serving/config/SpecServiceConfig.java @@ -21,10 +21,12 @@ import feast.proto.core.StoreProto; import feast.serving.specs.CachedSpecService; import feast.serving.specs.CoreSpecService; +import io.grpc.CallCredentials; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -58,9 +60,11 @@ public ScheduledExecutorService cachedSpecServiceScheduledExecutorService( } @Bean - public CachedSpecService specService(FeastProperties feastProperties) + public CachedSpecService specService( + FeastProperties feastProperties, ObjectProvider callCredentials) throws InvalidProtocolBufferException, JsonProcessingException { - CoreSpecService coreService = new CoreSpecService(feastCoreHost, feastCorePort); + CoreSpecService coreService = + new CoreSpecService(feastCoreHost, feastCorePort, callCredentials); StoreProto.Store storeProto = feastProperties.getActiveStore().toProto(); CachedSpecService cachedSpecStorage = new CachedSpecService(coreService, storeProto); try { diff --git a/serving/src/main/java/feast/serving/controller/HealthServiceController.java b/serving/src/main/java/feast/serving/controller/HealthServiceController.java index 0810429183e..5225a7ea2ed 100644 --- a/serving/src/main/java/feast/serving/controller/HealthServiceController.java +++ 
b/serving/src/main/java/feast/serving/controller/HealthServiceController.java @@ -26,12 +26,12 @@ import io.grpc.health.v1.HealthProto.HealthCheckResponse; import io.grpc.health.v1.HealthProto.HealthCheckResponse.ServingStatus; import io.grpc.stub.StreamObserver; -import org.lognet.springboot.grpc.GRpcService; +import net.devh.boot.grpc.server.service.GrpcService; import org.springframework.beans.factory.annotation.Autowired; // Reference: https://github.com/grpc/grpc/blob/master/doc/health-checking.md -@GRpcService(interceptors = {GrpcMonitoringInterceptor.class}) +@GrpcService(interceptors = {GrpcMonitoringInterceptor.class}) public class HealthServiceController extends HealthImplBase { private CachedSpecService specService; private ServingService servingService; diff --git a/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java b/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java index 3fae6ae65a7..e888f523164 100644 --- a/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java +++ b/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java @@ -16,6 +16,9 @@ */ package feast.serving.controller; +import feast.auth.service.AuthorizationService; +import feast.common.interceptors.GrpcMessageInterceptor; +import feast.proto.serving.ServingAPIProto.FeatureReference; import feast.proto.serving.ServingAPIProto.GetBatchFeaturesRequest; import feast.proto.serving.ServingAPIProto.GetBatchFeaturesResponse; import feast.proto.serving.ServingAPIProto.GetFeastServingInfoRequest; @@ -35,11 +38,16 @@ import io.opentracing.Scope; import io.opentracing.Span; import io.opentracing.Tracer; -import org.lognet.springboot.grpc.GRpcService; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import net.devh.boot.grpc.server.service.GrpcService; import org.slf4j.Logger; import org.springframework.beans.factory.annotation.Autowired; +import 
org.springframework.security.access.AccessDeniedException; +import org.springframework.security.core.context.SecurityContextHolder; -@GRpcService(interceptors = {GrpcMonitoringInterceptor.class}) +@GrpcService(interceptors = {GrpcMessageInterceptor.class, GrpcMonitoringInterceptor.class}) public class ServingServiceGRpcController extends ServingServiceImplBase { private static final Logger log = @@ -47,10 +55,15 @@ public class ServingServiceGRpcController extends ServingServiceImplBase { private final ServingService servingService; private final String version; private final Tracer tracer; + private final AuthorizationService authorizationService; @Autowired public ServingServiceGRpcController( - ServingService servingService, FeastProperties feastProperties, Tracer tracer) { + AuthorizationService authorizationService, + ServingService servingService, + FeastProperties feastProperties, + Tracer tracer) { + this.authorizationService = authorizationService; this.servingService = servingService; this.version = feastProperties.getVersion(); this.tracer = tracer; @@ -72,6 +85,16 @@ public void getOnlineFeatures( StreamObserver responseObserver) { Span span = tracer.buildSpan("getOnlineFeatures").start(); try (Scope scope = tracer.scopeManager().activate(span, false)) { + // authorize for the project in request object. 
+ if (request.getProject() != null && !request.getProject().isEmpty()) { + // project set at root level overrides the project set at feature set level + this.authorizationService.authorizeRequest( + SecurityContextHolder.getContext(), request.getProject()); + } else { + // authorize for projects set in feature list, backward compatibility for + // <=v0.5.X + this.checkProjectAccess(request.getFeaturesList()); + } RequestHelper.validateOnlineRequest(request); GetOnlineFeaturesResponse onlineFeatures = servingService.getOnlineFeatures(request); responseObserver.onNext(onlineFeatures); @@ -80,6 +103,13 @@ public void getOnlineFeatures( log.error("Failed to retrieve specs in SpecService", e); responseObserver.onError( Status.NOT_FOUND.withDescription(e.getMessage()).withCause(e).asException()); + } catch (AccessDeniedException e) { + log.info(String.format("User prevented from accessing one of the projects in request")); + responseObserver.onError( + Status.PERMISSION_DENIED + .withDescription(e.getMessage()) + .withCause(e) + .asRuntimeException()); } catch (Exception e) { log.warn("Failed to get Online Features", e); responseObserver.onError(e); @@ -92,6 +122,7 @@ public void getBatchFeatures( GetBatchFeaturesRequest request, StreamObserver responseObserver) { try { RequestHelper.validateBatchRequest(request); + this.checkProjectAccess(request.getFeaturesList()); GetBatchFeaturesResponse batchFeatures = servingService.getBatchFeatures(request); responseObserver.onNext(batchFeatures); responseObserver.onCompleted(); @@ -99,6 +130,13 @@ public void getBatchFeatures( log.error("Failed to retrieve specs in SpecService", e); responseObserver.onError( Status.NOT_FOUND.withDescription(e.getMessage()).withCause(e).asException()); + } catch (AccessDeniedException e) { + log.info(String.format("User prevented from accessing one of the projects in request")); + responseObserver.onError( + Status.PERMISSION_DENIED + .withDescription(e.getMessage()) + .withCause(e) + 
.asRuntimeException()); } catch (Exception e) { log.warn("Failed to get Batch Features", e); responseObserver.onError(e); @@ -116,4 +154,19 @@ public void getJob(GetJobRequest request, StreamObserver respons responseObserver.onError(e); } } + + private void checkProjectAccess(List featureList) { + Set projectList = + featureList.stream().map(FeatureReference::getProject).collect(Collectors.toSet()); + if (projectList.isEmpty()) { + authorizationService.authorizeRequest(SecurityContextHolder.getContext(), "default"); + } else { + projectList.stream() + .forEach( + project -> { + this.authorizationService.authorizeRequest( + SecurityContextHolder.getContext(), project); + }); + } + } } diff --git a/serving/src/main/java/feast/serving/service/OnlineServingService.java b/serving/src/main/java/feast/serving/service/OnlineServingService.java index a357904e32f..a7d9d284aa2 100644 --- a/serving/src/main/java/feast/serving/service/OnlineServingService.java +++ b/serving/src/main/java/feast/serving/service/OnlineServingService.java @@ -17,6 +17,7 @@ package feast.serving.service; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; import com.google.protobuf.Duration; import feast.common.models.Feature; import feast.common.models.FeatureSet; @@ -36,7 +37,6 @@ import io.opentracing.Tracer; import java.util.*; import java.util.stream.Collectors; -import org.apache.beam.vendor.grpc.v1p21p0.com.google.common.collect.Streams; import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; diff --git a/serving/src/main/java/feast/serving/specs/CoreSpecService.java b/serving/src/main/java/feast/serving/specs/CoreSpecService.java index e2feaebccb2..8dcfd0695eb 100644 --- a/serving/src/main/java/feast/serving/specs/CoreSpecService.java +++ b/serving/src/main/java/feast/serving/specs/CoreSpecService.java @@ -24,9 +24,11 @@ import feast.proto.core.CoreServiceProto.UpdateStoreRequest; import feast.proto.core.CoreServiceProto.UpdateStoreResponse; 
import feast.proto.core.StoreProto.Store; +import io.grpc.CallCredentials; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import org.slf4j.Logger; +import org.springframework.beans.factory.ObjectProvider; /** Client for interfacing with specs in Feast Core. */ public class CoreSpecService { @@ -34,10 +36,16 @@ public class CoreSpecService { private static final Logger log = org.slf4j.LoggerFactory.getLogger(CoreSpecService.class); private final CoreServiceGrpc.CoreServiceBlockingStub blockingStub; - public CoreSpecService(String feastCoreHost, int feastCorePort) { + public CoreSpecService( + String feastCoreHost, int feastCorePort, ObjectProvider callCredentials) { ManagedChannel channel = ManagedChannelBuilder.forAddress(feastCoreHost, feastCorePort).usePlaintext().build(); - blockingStub = CoreServiceGrpc.newBlockingStub(channel); + CallCredentials creds = callCredentials.getIfAvailable(); + if (creds != null) { + blockingStub = CoreServiceGrpc.newBlockingStub(channel).withCallCredentials(creds); + } else { + blockingStub = CoreServiceGrpc.newBlockingStub(channel); + } } public GetFeatureSetResponse getFeatureSet(GetFeatureSetRequest getFeatureSetRequest) { diff --git a/serving/src/main/resources/application.yml b/serving/src/main/resources/application.yml index 2399d132ef9..7e14cb58a89 100644 --- a/serving/src/main/resources/application.yml +++ b/serving/src/main/resources/application.yml @@ -3,10 +3,36 @@ feast: # Feast Serving requires connection to Feast Core to retrieve and reload Feast metadata (e.g. FeatureSpecs, Store information) core-host: ${FEAST_CORE_HOST:localhost} core-grpc-port: ${FEAST_CORE_GRPC_PORT:6565} + + core-authentication: + enabled: false # should be set to true if authentication is enabled on core. + provider: google # can be set to `oauth` or `google` + # if google, GOOGLE_APPLICATION_CREDENTIALS environment variable should be set. 
+ options: + #if provider is oauth following properties need to be set, else serving boot up will fail. + oauth_url: https://localhost/oauth/token #oauth token request url + grant_type: client_credentials #oauth grant type + client_id: #oauth client id which will be used for jwt token request + client_secret: #oauth client secret which will be used for jwt token request + audience: https://localhost #token audience. + jwkEndpointURI: #jwk endpoint uri, used for caching token till expiry. + # Indicates the active store. Only a single store in the list can be active at one time. In the future this key # will be deprecated in order to allow multiple stores to be served from a single serving instance active_store: online + + security: + authentication: + enabled: false + provider: jwt + options: + jwkEndpointURI: "https://www.googleapis.com/oauth2/v3/certs" + authorization: + enabled: false + provider: http + options: + basePath: http://localhost:3000 # List of store configurations stores: @@ -71,14 +97,25 @@ feast: # Redis port to connect to redis_port: 6379 + logging: + # Audit logging provides a machine readable structured JSON log that can give better + # insight into what is happening in Feast. + audit: + # Whether audit logging is enabled. 
+ enabled: true + # Whether to enable message level (ie request/response) audit logging + messageLoggingEnabled: false + grpc: - # The port number Feast Serving GRPC service should listen on - # It is set default to 6566 so it does not conflict with the GRPC server on Feast Core - # which defaults to port 6565 - port: ${GRPC_PORT:6566} - # This allows client to discover GRPC endpoints easily - # https://github.com/grpc/grpc-java/blob/master/documentation/server-reflection-tutorial.md - enable-reflection: ${GRPC_ENABLE_REFLECTION:true} + server: + # The port number Feast Serving GRPC service should listen on + # It is set default to 6566 so it does not conflict with the GRPC server on Feast Core + # which defaults to port 6565 + port: ${GRPC_PORT:6566} + security: + enabled: false + certificateChainPath: server.crt + privateKeyPath: server.key server: # The port number on which the Tomcat webserver that serves REST API endpoints should listen diff --git a/serving/src/main/resources/log4j2.xml b/serving/src/main/resources/log4j2.xml index 661c8e5061c..c75c2db13cc 100644 --- a/serving/src/main/resources/log4j2.xml +++ b/serving/src/main/resources/log4j2.xml @@ -16,28 +16,33 @@ ~ --> - + %d{yyyy-MM-dd HH:mm:ss.SSS} %5p ${hostName} --- [%15.15t] %-40.40c{1.} : %m%n%ex - ${env:LOG_TYPE:-Console} - ${env:LOG_LEVEL:-info} + + {"time":"%d{yyyy-MM-dd'T'HH:mm:ssXXX}","hostname":"${hostName}","severity":"%p","message":%m}%n%ex + + - + + - - - - - - + + + + + + + + - \ No newline at end of file + diff --git a/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java b/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java index 5c8308daea5..3577f098c1e 100644 --- a/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java +++ b/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java @@ -16,9 +16,20 @@ */ package feast.serving.controller; +import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; import com.google.protobuf.Timestamp; +import feast.auth.authorization.AuthorizationProvider; +import feast.auth.authorization.AuthorizationResult; +import feast.auth.config.SecurityProperties; +import feast.auth.config.SecurityProperties.AuthenticationProperties; +import feast.auth.config.SecurityProperties.AuthorizationProperties; +import feast.auth.service.AuthorizationService; import feast.proto.serving.ServingAPIProto.FeatureReference; import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; @@ -34,6 +45,9 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.Mockito; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; public class ServingServiceGRpcControllerTest { @@ -45,6 +59,10 @@ public class ServingServiceGRpcControllerTest { private ServingServiceGRpcController service; + @Mock private Authentication authentication; + + @Mock private AuthorizationProvider authProvider; + @Before public void setUp() { initMocks(this); @@ -59,23 +77,65 @@ public void setUp() { .putFields("entity1", Value.newBuilder().setInt64Val(1).build()) .putFields("entity2", Value.newBuilder().setInt64Val(1).build())) .build(); + } + private ServingServiceGRpcController getServingServiceGRpcController(boolean enableAuth) { Tracer tracer = Configuration.fromEnv("dummy").getTracer(); FeastProperties feastProperties = new FeastProperties(); - service = new ServingServiceGRpcController(mockServingService, feastProperties, tracer); + + AuthorizationProperties 
authorizationProps = new AuthorizationProperties(); + authorizationProps.setEnabled(enableAuth); + AuthenticationProperties authenticationProps = new AuthenticationProperties(); + authenticationProps.setEnabled(enableAuth); + SecurityProperties securityProperties = new SecurityProperties(); + securityProperties.setAuthentication(authenticationProps); + securityProperties.setAuthorization(authorizationProps); + feastProperties.setSecurity(securityProperties); + AuthorizationService authorizationservice = + new AuthorizationService(feastProperties.getSecurity(), authProvider); + return new ServingServiceGRpcController( + authorizationservice, mockServingService, feastProperties, tracer); } @Test public void shouldPassValidRequestAsIs() { + service = getServingServiceGRpcController(false); service.getOnlineFeatures(validRequest, mockStreamObserver); Mockito.verify(mockServingService).getOnlineFeatures(validRequest); } @Test public void shouldCallOnErrorIfEntityDatasetIsNotSet() { + service = getServingServiceGRpcController(false); GetOnlineFeaturesRequest missingEntityName = GetOnlineFeaturesRequest.newBuilder(validRequest).clearEntityRows().build(); service.getOnlineFeatures(missingEntityName, mockStreamObserver); Mockito.verify(mockStreamObserver).onError(Mockito.any(StatusRuntimeException.class)); } + + @Test + public void shouldPassValidRequestAsIsIfRequestIsAuthorized() { + service = getServingServiceGRpcController(true); + SecurityContext context = mock(SecurityContext.class); + SecurityContextHolder.setContext(context); + when(context.getAuthentication()).thenReturn(authentication); + doReturn(AuthorizationResult.success()) + .when(authProvider) + .checkAccessToProject(anyString(), any(Authentication.class)); + service.getOnlineFeatures(validRequest, mockStreamObserver); + Mockito.verify(mockServingService).getOnlineFeatures(validRequest); + } + + @Test + public void shouldThrowErrorOnValidRequestIfRequestIsUnauthorized() { + service = 
getServingServiceGRpcController(true); + SecurityContext context = mock(SecurityContext.class); + SecurityContextHolder.setContext(context); + when(context.getAuthentication()).thenReturn(authentication); + doReturn(AuthorizationResult.failed(null)) + .when(authProvider) + .checkAccessToProject(anyString(), any(Authentication.class)); + service.getOnlineFeatures(validRequest, mockStreamObserver); + Mockito.verify(mockStreamObserver).onError(Mockito.any(StatusRuntimeException.class)); + } } diff --git a/serving/src/test/java/feast/serving/it/AuthTestUtils.java b/serving/src/test/java/feast/serving/it/AuthTestUtils.java new file mode 100644 index 00000000000..5ec7298e988 --- /dev/null +++ b/serving/src/test/java/feast/serving/it/AuthTestUtils.java @@ -0,0 +1,283 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.serving.it; + +import static org.awaitility.Awaitility.waitAtMost; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.beans.HasPropertyWithValue.hasProperty; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.protobuf.Timestamp; +import feast.auth.credentials.OAuthCredentials; +import feast.proto.core.CoreServiceGrpc; +import feast.proto.core.FeatureSetProto; +import feast.proto.core.FeatureSetProto.FeatureSetStatus; +import feast.proto.core.SourceProto; +import feast.proto.serving.ServingAPIProto.FeatureReference; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; +import feast.proto.serving.ServingServiceGrpc; +import feast.proto.types.ValueProto; +import feast.proto.types.ValueProto.Value; +import io.grpc.CallCredentials; +import io.grpc.Channel; +import io.grpc.ManagedChannelBuilder; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.runners.model.InitializationError; +import sh.ory.keto.ApiClient; +import sh.ory.keto.ApiException; +import sh.ory.keto.Configuration; +import sh.ory.keto.api.EnginesApi; +import sh.ory.keto.model.OryAccessControlPolicy; +import sh.ory.keto.model.OryAccessControlPolicyRole; + +public class AuthTestUtils { + + private static final String DEFAULT_FLAVOR = "glob"; + + static SourceProto.Source defaultSource = + createSource("kafka:9092,localhost:9094", "feast-features"); + + public static SourceProto.Source 
getDefaultSource() { + return defaultSource; + } + + public static SourceProto.Source createSource(String server, String topic) { + return SourceProto.Source.newBuilder() + .setType(SourceProto.SourceType.KAFKA) + .setKafkaSourceConfig( + SourceProto.KafkaSourceConfig.newBuilder() + .setBootstrapServers(server) + .setTopic(topic) + .build()) + .build(); + } + + public static FeatureSetProto.FeatureSet createFeatureSet( + SourceProto.Source source, + String projectName, + String name, + List> entities, + List> features) { + return FeatureSetProto.FeatureSet.newBuilder() + .setSpec( + FeatureSetProto.FeatureSetSpec.newBuilder() + .setSource(source) + .setName(name) + .setProject(projectName) + .addAllEntities( + entities.stream() + .map( + pair -> + FeatureSetProto.EntitySpec.newBuilder() + .setName(pair.getLeft()) + .setValueType(pair.getRight()) + .build()) + .collect(Collectors.toList())) + .addAllFeatures( + features.stream() + .map( + pair -> + FeatureSetProto.FeatureSpec.newBuilder() + .setName(pair.getLeft()) + .setValueType(pair.getRight()) + .build()) + .collect(Collectors.toList())) + .build()) + .build(); + } + + public static GetOnlineFeaturesRequest createOnlineFeatureRequest( + String projectName, String featureName, String entityId, int entityValue) { + return GetOnlineFeaturesRequest.newBuilder() + .setProject(projectName) + .addFeatures(FeatureReference.newBuilder().setName(featureName).build()) + .addEntityRows( + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields(entityId, Value.newBuilder().setInt64Val(entityValue).build())) + .build(); + } + + public static void applyFeatureSet( + CoreSimpleAPIClient secureApiClient, + String projectName, + String entityId, + String featureName) { + List> entities = new ArrayList<>(); + entities.add(Pair.of(entityId, ValueProto.ValueType.Enum.INT64)); + List> features = new ArrayList<>(); + features.add(Pair.of(featureName, ValueProto.ValueType.Enum.INT64)); + String 
featureSetName = "test_1"; + FeatureSetProto.FeatureSet expectedFeatureSet = + AuthTestUtils.createFeatureSet( + AuthTestUtils.getDefaultSource(), projectName, featureSetName, entities, features); + secureApiClient.simpleApplyFeatureSet(expectedFeatureSet); + waitAtMost(2, TimeUnit.MINUTES) + .until( + () -> { + return secureApiClient.simpleGetFeatureSet(projectName, featureSetName).getMeta(); + }, + hasProperty("status", equalTo(FeatureSetStatus.STATUS_READY))); + FeatureSetProto.FeatureSet actualFeatureSet = + secureApiClient.simpleGetFeatureSet(projectName, featureSetName); + assertEquals( + expectedFeatureSet.getSpec().getProject(), actualFeatureSet.getSpec().getProject()); + assertEquals(expectedFeatureSet.getSpec().getName(), actualFeatureSet.getSpec().getName()); + assertEquals(expectedFeatureSet.getSpec().getSource(), actualFeatureSet.getSpec().getSource()); + assertEquals(FeatureSetStatus.STATUS_READY, actualFeatureSet.getMeta().getStatus()); + } + + public static CoreSimpleAPIClient getSecureApiClientForCore( + int feastCorePort, Map options) { + CallCredentials callCredentials = null; + callCredentials = new OAuthCredentials(options); + Channel secureChannel = + ManagedChannelBuilder.forAddress("localhost", feastCorePort).usePlaintext().build(); + + CoreServiceGrpc.CoreServiceBlockingStub secureCoreService = + CoreServiceGrpc.newBlockingStub(secureChannel).withCallCredentials(callCredentials); + + return new CoreSimpleAPIClient(secureCoreService); + } + + public static ServingServiceGrpc.ServingServiceBlockingStub getServingServiceStub( + boolean isSecure, int feastServingPort, Map options) { + Channel secureChannel = + ManagedChannelBuilder.forAddress("localhost", feastServingPort).usePlaintext().build(); + + if (isSecure) { + CallCredentials callCredentials = null; + callCredentials = new OAuthCredentials(options); + return ServingServiceGrpc.newBlockingStub(secureChannel).withCallCredentials(callCredentials); + } else { + return 
ServingServiceGrpc.newBlockingStub(secureChannel); + } + } + + public static void seedHydra( + String hydraExternalUrl, + String clientId, + String clientSecret, + String audience, + String grantType) + throws IOException, InitializationError { + + OkHttpClient httpClient = new OkHttpClient(); + String createClientEndpoint = String.format("%s/%s", hydraExternalUrl, "clients"); + JsonObject jsonObject = new JsonObject(); + JsonArray audienceArray = new JsonArray(); + audienceArray.add(audience); + JsonArray grantTypes = new JsonArray(); + grantTypes.add(grantType); + jsonObject.addProperty("client_id", clientId); + jsonObject.addProperty("client_secret", clientSecret); + jsonObject.addProperty("token_endpoint_auth_method", "client_secret_post"); + jsonObject.add("audience", audienceArray); + jsonObject.add("grant_types", grantTypes); + MediaType JSON = MediaType.parse("application/json; charset=utf-8"); + + RequestBody requestBody = RequestBody.create(JSON, jsonObject.toString()); + Request request = + new Request.Builder() + .url(createClientEndpoint) + .addHeader("Content-Type", "application/json") + .post(requestBody) + .build(); + Response response = httpClient.newCall(request).execute(); + if (!response.isSuccessful()) { + throw new InitializationError(response.message()); + } + } + + public static void seedKeto(String url, String project, String subjectInProject, String admin) + throws ApiException { + ApiClient ketoClient = Configuration.getDefaultApiClient(); + ketoClient.setBasePath(url); + EnginesApi enginesApi = new EnginesApi(ketoClient); + + // Add policies + OryAccessControlPolicy adminPolicy = getAdminPolicy(); + enginesApi.upsertOryAccessControlPolicy(DEFAULT_FLAVOR, adminPolicy); + + OryAccessControlPolicy projectPolicy = getMyProjectMemberPolicy(project); + enginesApi.upsertOryAccessControlPolicy(DEFAULT_FLAVOR, projectPolicy); + + // Add policy roles + OryAccessControlPolicyRole adminPolicyRole = getAdminPolicyRole(admin); + 
enginesApi.upsertOryAccessControlPolicyRole(DEFAULT_FLAVOR, adminPolicyRole); + + OryAccessControlPolicyRole myProjectMemberPolicyRole = + getMyProjectMemberPolicyRole(project, subjectInProject); + enginesApi.upsertOryAccessControlPolicyRole(DEFAULT_FLAVOR, myProjectMemberPolicyRole); + } + + private static OryAccessControlPolicyRole getMyProjectMemberPolicyRole( + String project, String subjectInProject) { + OryAccessControlPolicyRole role = new OryAccessControlPolicyRole(); + role.setId(String.format("roles:%s-project-members", project)); + role.setMembers(Collections.singletonList("users:" + subjectInProject)); + return role; + } + + private static OryAccessControlPolicyRole getAdminPolicyRole(String subjectIsAdmin) { + OryAccessControlPolicyRole role = new OryAccessControlPolicyRole(); + role.setId("roles:admin"); + role.setMembers(Collections.singletonList("users:" + subjectIsAdmin)); + return role; + } + + private static OryAccessControlPolicy getAdminPolicy() { + OryAccessControlPolicy policy = new OryAccessControlPolicy(); + policy.setId("policies:admin"); + policy.subjects(Collections.singletonList("roles:admin")); + policy.resources(Collections.singletonList("resources:**")); + policy.actions(Collections.singletonList("actions:**")); + policy.effect("allow"); + policy.conditions(null); + return policy; + } + + private static OryAccessControlPolicy getMyProjectMemberPolicy(String project) { + OryAccessControlPolicy policy = new OryAccessControlPolicy(); + policy.setId(String.format("policies:%s-project-members-policy", project)); + policy.subjects(Collections.singletonList(String.format("roles:%s-project-members", project))); + policy.resources( + Arrays.asList( + String.format("resources:projects:%s", project), + String.format("resources:projects:%s:**", project))); + policy.actions(Collections.singletonList("actions:**")); + policy.effect("allow"); + policy.conditions(null); + return policy; + } +} diff --git 
a/serving/src/test/java/feast/serving/it/BaseAuthIT.java b/serving/src/test/java/feast/serving/it/BaseAuthIT.java new file mode 100644 index 00000000000..bdbe432ed8e --- /dev/null +++ b/serving/src/test/java/feast/serving/it/BaseAuthIT.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.serving.it; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; + +@ActiveProfiles("it") +@SpringBootTest +public class BaseAuthIT { + + static final String FEATURE_NAME = "feature_1"; + static final String ENTITY_ID = "entity_id"; + static final String PROJECT_NAME = "project_1"; + static final int CORE_START_MAX_WAIT_TIME_IN_MINUTES = 3; + static final String CLIENT_ID = "client_id"; + static final String CLIENT_SECRET = "client_secret"; + static final String TOKEN_URL = "http://localhost:4444/oauth2/token"; + static final String JWK_URI = "http://localhost:4444/.well-known/jwks.json"; + + static final String GRANT_TYPE = "client_credentials"; + + static final String AUDIENCE = "https://localhost"; + + static final String CORE = "core_1"; + + static final String HYDRA = "hydra_1"; + 
static final int HYDRA_PORT = 4445; + + static CoreSimpleAPIClient insecureApiClient; + + static final int REDIS_PORT = 6379; + + static final int FEAST_CORE_PORT = 6565; + + @DynamicPropertySource + static void properties(DynamicPropertyRegistry registry) { + registry.add("feast.stores[0].name", () -> "online"); + registry.add("feast.stores[0].type", () -> "REDIS"); + // Redis needs to accessible by both core and serving, hence using host address + registry.add( + "feast.stores[0].config.host", + () -> { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (UnknownHostException e) { + e.printStackTrace(); + return ""; + } + }); + registry.add("feast.stores[0].config.port", () -> REDIS_PORT); + registry.add("feast.stores[0].subscriptions[0].name", () -> "*"); + registry.add("feast.stores[0].subscriptions[0].project", () -> "*"); + + registry.add("feast.core-authentication.options.oauth_url", () -> TOKEN_URL); + registry.add("feast.core-authentication.options.grant_type", () -> GRANT_TYPE); + registry.add("feast.core-authentication.options.client_id", () -> CLIENT_ID); + registry.add("feast.core-authentication.options.client_secret", () -> CLIENT_SECRET); + registry.add("feast.core-authentication.options.audience", () -> AUDIENCE); + registry.add("feast.core-authentication.options.jwkEndpointURI", () -> JWK_URI); + registry.add("feast.security.authentication.options.jwkEndpointURI", () -> JWK_URI); + } +} diff --git a/serving/src/test/java/feast/serving/it/CoreSimpleAPIClient.java b/serving/src/test/java/feast/serving/it/CoreSimpleAPIClient.java new file mode 100644 index 00000000000..7d9313150d7 --- /dev/null +++ b/serving/src/test/java/feast/serving/it/CoreSimpleAPIClient.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.serving.it; + +import feast.proto.core.CoreServiceGrpc; +import feast.proto.core.CoreServiceProto; +import feast.proto.core.FeatureSetProto; + +public class CoreSimpleAPIClient { + private CoreServiceGrpc.CoreServiceBlockingStub stub; + + public CoreSimpleAPIClient(CoreServiceGrpc.CoreServiceBlockingStub stub) { + this.stub = stub; + } + + public void simpleApplyFeatureSet(FeatureSetProto.FeatureSet featureSet) { + stub.applyFeatureSet( + CoreServiceProto.ApplyFeatureSetRequest.newBuilder().setFeatureSet(featureSet).build()); + } + + public FeatureSetProto.FeatureSet simpleGetFeatureSet(String projectName, String name) { + return stub.getFeatureSet( + CoreServiceProto.GetFeatureSetRequest.newBuilder() + .setName(name) + .setProject(projectName) + .build()) + .getFeatureSet(); + } +} diff --git a/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthenticationIT.java b/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthenticationIT.java new file mode 100644 index 00000000000..edd16c24a87 --- /dev/null +++ b/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthenticationIT.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.serving.it; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.testcontainers.containers.wait.strategy.Wait.forHttp; + +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse; +import feast.proto.serving.ServingServiceGrpc.ServingServiceBlockingStub; +import feast.proto.types.ValueProto.Value; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import org.junit.ClassRule; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.runners.model.InitializationError; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ActiveProfiles; +import org.testcontainers.containers.DockerComposeContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +@ActiveProfiles("it") +@SpringBootTest( + properties = { + "feast.core-authentication.enabled=true", + "feast.core-authentication.provider=oauth", + "feast.security.authentication.enabled=true", + "feast.security.authorization.enabled=false" + }) +@Testcontainers +public class 
ServingServiceOauthAuthenticationIT extends BaseAuthIT { + + static final Map options = new HashMap<>(); + + static final int FEAST_SERVING_PORT = 6566; + + @ClassRule @Container + public static DockerComposeContainer environment = + new DockerComposeContainer( + new File("src/test/resources/docker-compose/docker-compose-it-hydra.yml"), + new File("src/test/resources/docker-compose/docker-compose-it-core.yml")) + .withExposedService(HYDRA, HYDRA_PORT, forHttp("/health/alive").forStatusCode(200)) + .withExposedService( + CORE, + 6565, + Wait.forLogMessage(".*gRPC Server started.*\\n", 1) + .withStartupTimeout(Duration.ofMinutes(CORE_START_MAX_WAIT_TIME_IN_MINUTES))); + + @BeforeAll + static void globalSetup() throws IOException, InitializationError, InterruptedException { + String hydraExternalHost = environment.getServiceHost(HYDRA, HYDRA_PORT); + Integer hydraExternalPort = environment.getServicePort(HYDRA, HYDRA_PORT); + String hydraExternalUrl = String.format("http://%s:%s", hydraExternalHost, hydraExternalPort); + AuthTestUtils.seedHydra(hydraExternalUrl, CLIENT_ID, CLIENT_SECRET, AUDIENCE, GRANT_TYPE); + + // set up options for call credentials + options.put("oauth_url", TOKEN_URL); + options.put(CLIENT_ID, CLIENT_ID); + options.put(CLIENT_SECRET, CLIENT_SECRET); + options.put("jwkEndpointURI", JWK_URI); + options.put("audience", AUDIENCE); + options.put("grant_type", GRANT_TYPE); + } + + @Test + public void shouldNotAllowUnauthenticatedGetOnlineFeatures() { + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(false, FEAST_SERVING_PORT, null); + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + Exception exception = + assertThrows( + StatusRuntimeException.class, + () -> { + servingStub.getOnlineFeatures(onlineFeatureRequest); + }); + + String expectedMessage = "UNAUTHENTICATED: Authentication failed"; + String actualMessage = 
exception.getMessage(); + assertEquals(actualMessage, expectedMessage); + } + + @Test + void canGetOnlineFeaturesIfAuthenticated() { + // apply feature set + CoreSimpleAPIClient coreClient = + AuthTestUtils.getSecureApiClientForCore(FEAST_CORE_PORT, options); + AuthTestUtils.applyFeatureSet(coreClient, PROJECT_NAME, ENTITY_ID, FEATURE_NAME); + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(true, FEAST_SERVING_PORT, options); + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + GetOnlineFeaturesResponse featureResponse = servingStub.getOnlineFeatures(onlineFeatureRequest); + assertEquals(1, featureResponse.getFieldValuesCount()); + Map fieldsMap = featureResponse.getFieldValues(0).getFieldsMap(); + assertTrue(fieldsMap.containsKey(ENTITY_ID)); + assertTrue(fieldsMap.containsKey(FEATURE_NAME)); + ((ManagedChannel) servingStub.getChannel()).shutdown(); + } +} diff --git a/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthorizationIT.java b/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthorizationIT.java new file mode 100644 index 00000000000..aaee2321a5f --- /dev/null +++ b/serving/src/test/java/feast/serving/it/ServingServiceOauthAuthorizationIT.java @@ -0,0 +1,212 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.serving.it; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.testcontainers.containers.wait.strategy.Wait.forHttp; + +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse; +import feast.proto.serving.ServingServiceGrpc.ServingServiceBlockingStub; +import feast.proto.types.ValueProto.Value; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import org.junit.ClassRule; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.runners.model.InitializationError; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; +import org.testcontainers.containers.DockerComposeContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import sh.ory.keto.ApiException; + +@ActiveProfiles("it") +@SpringBootTest( + properties = { + "feast.core-authentication.enabled=true", + "feast.core-authentication.provider=oauth", + "feast.security.authentication.enabled=true", + "feast.security.authorization.enabled=true" + }) +@Testcontainers +public class ServingServiceOauthAuthorizationIT extends BaseAuthIT { + + static final Map adminCredentials = new HashMap<>(); + static final Map memberCredentials = new HashMap<>(); + static final String PROJECT_MEMBER_CLIENT_ID = "client_id_1"; + static final 
String NOT_PROJECT_MEMBER_CLIENT_ID = "client_id_2"; + private static int KETO_PORT = 4466; + private static int KETO_ADAPTOR_PORT = 8080; + static String subjectClaim = "sub"; + static CoreSimpleAPIClient coreClient; + static final int FEAST_SERVING_PORT = 6766; + + @ClassRule @Container + public static DockerComposeContainer environment = + new DockerComposeContainer( + new File("src/test/resources/docker-compose/docker-compose-it-hydra.yml"), + new File("src/test/resources/docker-compose/docker-compose-it-core.yml"), + new File("src/test/resources/docker-compose/docker-compose-it-keto.yml")) + .withExposedService(HYDRA, HYDRA_PORT, forHttp("/health/alive").forStatusCode(200)) + .withExposedService( + CORE, + 6565, + Wait.forLogMessage(".*gRPC Server started.*\\n", 1) + .withStartupTimeout(Duration.ofMinutes(CORE_START_MAX_WAIT_TIME_IN_MINUTES))) + .withExposedService("adaptor_1", KETO_ADAPTOR_PORT) + .withExposedService("keto_1", KETO_PORT, forHttp("/health/ready").forStatusCode(200)); + + @DynamicPropertySource + static void initialize(DynamicPropertyRegistry registry) { + + // Seed Keto with data + String ketoExternalHost = environment.getServiceHost("keto_1", KETO_PORT); + Integer ketoExternalPort = environment.getServicePort("keto_1", KETO_PORT); + String ketoExternalUrl = String.format("http://%s:%s", ketoExternalHost, ketoExternalPort); + try { + AuthTestUtils.seedKeto(ketoExternalUrl, PROJECT_NAME, PROJECT_MEMBER_CLIENT_ID, CLIENT_ID); + } catch (ApiException e) { + throw new RuntimeException(String.format("Could not seed Keto store %s", ketoExternalUrl)); + } + + // Get Keto Authorization Server (Adaptor) url + String ketoAdaptorHost = environment.getServiceHost("adaptor_1", KETO_ADAPTOR_PORT); + Integer ketoAdaptorPort = environment.getServicePort("adaptor_1", KETO_ADAPTOR_PORT); + String ketoAdaptorUrl = String.format("http://%s:%s", ketoAdaptorHost, ketoAdaptorPort); + + // Initialize dynamic properties + 
registry.add("feast.security.authorization.options.subjectClaim", () -> subjectClaim); + registry.add("feast.security.authentication.options.jwkEndpointURI", () -> JWK_URI); + registry.add("feast.security.authorization.options.authorizationUrl", () -> ketoAdaptorUrl); + registry.add("grpc.server.port", () -> FEAST_SERVING_PORT); + } + + @BeforeAll + static void globalSetup() throws IOException, InitializationError, InterruptedException { + String hydraExternalHost = environment.getServiceHost(HYDRA, HYDRA_PORT); + Integer hydraExternalPort = environment.getServicePort(HYDRA, HYDRA_PORT); + String hydraExternalUrl = String.format("http://%s:%s", hydraExternalHost, hydraExternalPort); + AuthTestUtils.seedHydra(hydraExternalUrl, CLIENT_ID, CLIENT_SECRET, AUDIENCE, GRANT_TYPE); + AuthTestUtils.seedHydra( + hydraExternalUrl, PROJECT_MEMBER_CLIENT_ID, CLIENT_SECRET, AUDIENCE, GRANT_TYPE); + AuthTestUtils.seedHydra( + hydraExternalUrl, NOT_PROJECT_MEMBER_CLIENT_ID, CLIENT_SECRET, AUDIENCE, GRANT_TYPE); + // set up options for call credentials + adminCredentials.put("oauth_url", TOKEN_URL); + adminCredentials.put(CLIENT_ID, CLIENT_ID); + adminCredentials.put(CLIENT_SECRET, CLIENT_SECRET); + adminCredentials.put("jwkEndpointURI", JWK_URI); + adminCredentials.put("audience", AUDIENCE); + adminCredentials.put("grant_type", GRANT_TYPE); + + coreClient = AuthTestUtils.getSecureApiClientForCore(FEAST_CORE_PORT, adminCredentials); + } + + @BeforeEach + public void setUp() { + // seed core + AuthTestUtils.applyFeatureSet(coreClient, PROJECT_NAME, ENTITY_ID, FEATURE_NAME); + } + + @Test + public void shouldNotAllowUnauthenticatedGetOnlineFeatures() { + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(false, FEAST_SERVING_PORT, null); + + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + Exception exception = + assertThrows( + StatusRuntimeException.class, + () -> { + 
servingStub.getOnlineFeatures(onlineFeatureRequest); + }); + + String expectedMessage = "UNAUTHENTICATED: Authentication failed"; + String actualMessage = exception.getMessage(); + assertEquals(actualMessage, expectedMessage); + ((ManagedChannel) servingStub.getChannel()).shutdown(); + } + + @Test + void canGetOnlineFeaturesIfAdmin() { + // apply feature set + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(true, FEAST_SERVING_PORT, adminCredentials); + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + GetOnlineFeaturesResponse featureResponse = servingStub.getOnlineFeatures(onlineFeatureRequest); + assertEquals(1, featureResponse.getFieldValuesCount()); + Map fieldsMap = featureResponse.getFieldValues(0).getFieldsMap(); + assertTrue(fieldsMap.containsKey(ENTITY_ID)); + assertTrue(fieldsMap.containsKey(FEATURE_NAME)); + ((ManagedChannel) servingStub.getChannel()).shutdown(); + } + + @Test + void canGetOnlineFeaturesIfProjectMember() { + Map memberCredsOptions = new HashMap<>(); + memberCredsOptions.putAll(adminCredentials); + memberCredsOptions.put(CLIENT_ID, PROJECT_MEMBER_CLIENT_ID); + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(true, FEAST_SERVING_PORT, memberCredsOptions); + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + GetOnlineFeaturesResponse featureResponse = servingStub.getOnlineFeatures(onlineFeatureRequest); + assertEquals(1, featureResponse.getFieldValuesCount()); + Map fieldsMap = featureResponse.getFieldValues(0).getFieldsMap(); + assertTrue(fieldsMap.containsKey(ENTITY_ID)); + assertTrue(fieldsMap.containsKey(FEATURE_NAME)); + ((ManagedChannel) servingStub.getChannel()).shutdown(); + } + + @Test + void cantGetOnlineFeaturesIfNotProjectMember() { + Map notMemberCredsOptions = new HashMap<>(); + 
notMemberCredsOptions.putAll(adminCredentials); + notMemberCredsOptions.put(CLIENT_ID, NOT_PROJECT_MEMBER_CLIENT_ID); + ServingServiceBlockingStub servingStub = + AuthTestUtils.getServingServiceStub(true, FEAST_SERVING_PORT, notMemberCredsOptions); + GetOnlineFeaturesRequest onlineFeatureRequest = + AuthTestUtils.createOnlineFeatureRequest(PROJECT_NAME, FEATURE_NAME, ENTITY_ID, 1); + StatusRuntimeException exception = + assertThrows( + StatusRuntimeException.class, + () -> servingStub.getOnlineFeatures(onlineFeatureRequest)); + + String expectedMessage = + String.format( + "PERMISSION_DENIED: Access denied to project %s for subject %s", + PROJECT_NAME, NOT_PROJECT_MEMBER_CLIENT_ID); + String actualMessage = exception.getMessage(); + assertEquals(actualMessage, expectedMessage); + ((ManagedChannel) servingStub.getChannel()).shutdown(); + } +} diff --git a/serving/src/test/resources/application-it.properties b/serving/src/test/resources/application-it.properties new file mode 100644 index 00000000000..000e512a680 --- /dev/null +++ b/serving/src/test/resources/application-it.properties @@ -0,0 +1,18 @@ +# +# Copyright 2018 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +feast.core-authentication.enabled=false +feast.security.authentication.enabled=false +feast.security.authorization.enabled=false \ No newline at end of file diff --git a/serving/src/test/resources/docker-compose/core/application-it.yml b/serving/src/test/resources/docker-compose/core/application-it.yml new file mode 100644 index 00000000000..35f2ee54631 --- /dev/null +++ b/serving/src/test/resources/docker-compose/core/application-it.yml @@ -0,0 +1,21 @@ +feast: + jobs: + polling_interval_milliseconds: 30000 + job_update_timeout_seconds: 240 + active_runner: direct + runners: + - name: direct + type: DirectRunner + options: {} + stream: + type: kafka + options: + topic: feast-features + bootstrapServers: "kafka:9092,localhost:9094" + + security: + authentication: + enabled: true + provider: jwt + options: + jwkEndpointURI: http://hydra:4444/.well-known/jwks.json \ No newline at end of file diff --git a/serving/src/test/resources/docker-compose/docker-compose-it-core.yml b/serving/src/test/resources/docker-compose/docker-compose-it-core.yml new file mode 100644 index 00000000000..bb7cdce8abb --- /dev/null +++ b/serving/src/test/resources/docker-compose/docker-compose-it-core.yml @@ -0,0 +1,53 @@ +version: '3' + +services: + core: + image: gcr.io/kf-feast/feast-core:latest + volumes: + - ./core/application-it.yml:/etc/feast/application.yml + environment: + DB_HOST: db + restart: on-failure + depends_on: + - db + - kafka + ports: + - 6565:6565 + command: + - java + - -jar + - /opt/feast/feast-core.jar + - --spring.config.location=classpath:/application.yml,file:/etc/feast/application.yml + + kafka: + image: confluentinc/cp-kafka:5.2.1 + environment: + KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181 + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 + KAFKA_ADVERTISED_LISTENERS: INSIDE://kafka:9092,OUTSIDE://localhost:9094 + KAFKA_LISTENERS: INSIDE://:9092,OUTSIDE://:9094 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INSIDE:PLAINTEXT,OUTSIDE:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: 
INSIDE + ports: + - "9092:9092" + - "9094:9094" + + depends_on: + - zookeeper + + zookeeper: + image: confluentinc/cp-zookeeper:5.2.1 + environment: + ZOOKEEPER_CLIENT_PORT: 2181 + + db: + image: postgres:12-alpine + environment: + POSTGRES_PASSWORD: password + ports: + - "5432:5432" + + redis: + image: redis:5-alpine + ports: + - "6379:6379" \ No newline at end of file diff --git a/serving/src/test/resources/docker-compose/docker-compose-it-hydra.yml b/serving/src/test/resources/docker-compose/docker-compose-it-hydra.yml new file mode 100644 index 00000000000..1c20610cc73 --- /dev/null +++ b/serving/src/test/resources/docker-compose/docker-compose-it-hydra.yml @@ -0,0 +1,54 @@ +version: '3' + +services: + hydra-migrate: + image: oryd/hydra:v1.6.0 + environment: + - DSN=postgres://hydra:secret@postgresd:5432/hydra?sslmode=disable&max_conns=20&max_idle_conns=4 + command: + migrate sql -e --yes + restart: on-failure + + hydra: + depends_on: + - hydra-migrate + environment: + - DSN=postgres://hydra:secret@postgresd:5432/hydra?sslmode=disable&max_conns=20&max_idle_conns=4 + + postgresd: + image: postgres:9.6 + ports: + - "54320:5432" + environment: + - POSTGRES_USER=hydra + - POSTGRES_PASSWORD=secret + - POSTGRES_DB=hydra + + hydra: + image: oryd/hydra:v1.6.0 + ports: + - "4444:4444" # Public port + - "4445:4445" # Admin port + #- "5555:5555" # Port for hydra token user + command: + serve all --dangerous-force-http + environment: + - URLS_SELF_ISSUER=http://hydra:4444 + - URLS_CONSENT=http://hydra:3000/consent + - URLS_LOGIN=http://hydra:3000/login + - URLS_LOGOUT=http://hydra:3000/logout + - DSN=memory + - SECRETS_SYSTEM=youReallyNeedToChangeThis + - OIDC_SUBJECT_IDENTIFIERS_SUPPORTED_TYPES=public,pairwise + - OIDC_SUBJECT_IDENTIFIERS_PAIRWISE_SALT=youReallyNeedToChangeThis + - OAUTH2_ACCESS_TOKEN_STRATEGY=jwt + - OIDC_SUBJECT_IDENTIFIERS_SUPPORTED_TYPES=public + restart: unless-stopped + + consent: + environment: + - HYDRA_ADMIN_URL=http://hydra:4445 + image: 
oryd/hydra-login-consent-node:v1.5.2 + ports: + - "3000:3000" + restart: unless-stopped diff --git a/serving/src/test/resources/docker-compose/docker-compose-it-keto.yml b/serving/src/test/resources/docker-compose/docker-compose-it-keto.yml new file mode 100644 index 00000000000..8ebf7f225e0 --- /dev/null +++ b/serving/src/test/resources/docker-compose/docker-compose-it-keto.yml @@ -0,0 +1,44 @@ +version: '3' +services: + keto: + depends_on: + - ketodb + - migrations + image: oryd/keto:v0.4.3-alpha.2 + environment: + - DSN=postgres://keto:keto@ketodb:5432/keto?sslmode=disable + command: + - serve + ports: + - 4466 + + ketodb: + image: bitnami/postgresql:9.6 + environment: + - POSTGRESQL_USERNAME=keto + - POSTGRESQL_PASSWORD=keto + - POSTGRESQL_DATABASE=keto + ports: + - "54340:5432" + + migrations: + depends_on: + - ketodb + image: oryd/keto:v0.4.3-alpha.2 + environment: + - DSN=postgres://keto:keto@ketodb:5432/keto?sslmode=disable + command: + - migrate + - sql + - -e + + adaptor: + depends_on: + - keto + image: gcr.io/kf-feast/feast-keto-auth-server:latest + environment: + SERVER_PORT: 8080 + KETO_URL: http://keto:4466 + ports: + - 8080 + restart: on-failure \ No newline at end of file diff --git a/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java b/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java index 43a96e97efa..773abd57d61 100644 --- a/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java +++ b/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java @@ -16,14 +16,17 @@ */ package feast.storage.common.testing; +import com.google.common.hash.Hashing; import com.google.protobuf.ByteString; -import com.google.protobuf.util.Timestamps; +import com.google.protobuf.Timestamp; import feast.proto.core.FeatureSetProto.FeatureSet; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.types.FeatureRowProto.FeatureRow; import feast.proto.types.FeatureRowProto.FeatureRow.Builder; 
import feast.proto.types.FieldProto.Field; import feast.proto.types.ValueProto.*; +import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.concurrent.ThreadLocalRandom; import org.apache.commons.lang3.RandomStringUtils; @@ -53,10 +56,15 @@ public static FeatureRow createRandomFeatureRow(FeatureSet featureSet) { * @return {@link FeatureRow} */ public static FeatureRow createRandomFeatureRow(FeatureSet featureSet, int randomStringSize) { + + Instant time = Instant.now(); + Timestamp timestamp = + Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build(); + Builder builder = FeatureRow.newBuilder() .setFeatureSet(getFeatureSetReference(featureSet)) - .setEventTimestamp(Timestamps.fromMillis(System.currentTimeMillis())); + .setEventTimestamp(timestamp); featureSet .getSpec() @@ -185,4 +193,8 @@ public static Field field(String name, Object value, ValueType.Enum valueType) { throw new IllegalStateException("Unexpected valueType: " + value.getClass()); } } + + public static String hash(String input) { + return Hashing.murmur3_32().hashString(input, StandardCharsets.UTF_8).toString(); + } } diff --git a/storage/connectors/bigquery/pom.xml b/storage/connectors/bigquery/pom.xml index 32c6dda4810..1b97d57b2cb 100644 --- a/storage/connectors/bigquery/pom.xml +++ b/storage/connectors/bigquery/pom.xml @@ -96,5 +96,11 @@ hamcrest-library test + + org.mockito + mockito-core + ${mockito.version} + test + diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java index 1befae221b7..f9f74c15bf3 100644 --- a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java @@ -200,7 +200,7 @@ 
public FeatureRowsBatch withFeatureSetReference(String featureSetReference) { } public Row toRow() { - return Row.withSchema(schema).attachValues(values).build(); + return Row.withSchema(schema).attachValues(values); } public static FeatureRowsBatch fromRow(Row row) { diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java index ed5a7b020ee..2d55b308dd0 100644 --- a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java @@ -16,6 +16,8 @@ */ package feast.storage.connectors.bigquery.writer; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.value.AutoValue; import com.google.cloud.bigquery.*; @@ -62,6 +64,12 @@ public abstract class BigQueryFeatureSink implements FeatureSink { * @return {@link BigQueryFeatureSink.Builder} */ public static FeatureSink fromConfig(BigQueryConfig config) { + checkArgument( + config.getWriteTriggeringFrequencySeconds() > 0, + "Invalid configuration: " + + "write_triggering_frequency_seconds in BigQueryConfig must be positive integer. 
" + + "Please fix that in your serving configuration."); + return BigQueryFeatureSink.builder() .setDatasetId(config.getDatasetId()) .setProjectId(config.getProjectId()) diff --git a/storage/connectors/bigquery/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoadsWithResult.java b/storage/connectors/bigquery/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoadsWithResult.java index f6be75fe13e..1cfb4087460 100644 --- a/storage/connectors/bigquery/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoadsWithResult.java +++ b/storage/connectors/bigquery/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoadsWithResult.java @@ -1,7 +1,7 @@ package org.apache.beam.sdk.io.gcp.bigquery; +import static com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.resolveTempLocation; -import static org.apache.beam.vendor.grpc.v1p21p0.com.google.common.base.Preconditions.checkArgument; import com.google.api.services.bigquery.model.TableRow; import com.google.auto.value.AutoValue; @@ -306,6 +306,8 @@ PCollection> writeSinglePartitionWithResult( getIgnoreUnknownValues(), getKmsKey(), getRowWriterFactory().getSourceFormat(), + true, getSchemaUpdateOptions())); } + } diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml index 3aa863f6811..ca6e8d42ad6 100644 --- a/storage/connectors/redis/pom.xml +++ b/storage/connectors/redis/pom.xml @@ -45,7 +45,7 @@ org.mockito mockito-core - 2.23.0 + ${mockito.version} test @@ -89,6 +89,17 @@ 4.12 test + + org.apache.beam + beam-sdks-java-extensions-protobuf + ${org.apache.beam.version} + test + + + org.slf4j + slf4j-simple + 1.7.30 + test + - diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java index aad3147f710..d89e5373669 100644 --- 
a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java @@ -16,12 +16,17 @@ */ package feast.storage.connectors.redis.retriever; +import com.google.common.hash.Hashing; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.core.FeatureSetProto.FeatureSpec; import feast.proto.types.FeatureRowProto.FeatureRow; import feast.proto.types.FieldProto.Field; +import feast.proto.types.ValueProto.Value; +import feast.storage.connectors.redis.writer.RedisCustomIO; +import java.nio.charset.StandardCharsets; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -36,60 +41,102 @@ public FeatureRowDecoder(String featureSetRef, FeatureSetSpec spec) { } /** - * A feature row is considered encoded if the feature set and field names are not set. This method - * is required for backward compatibility purposes, to allow Feast serving to continue serving non - * encoded Feature Row ingested by an older version of Feast. + * Check if encoded feature row can be decoded by v1 Decoder. The v1 Decoder requires that the + Feature Row has both its feature set reference and field names unset. The number of + fields in the feature row should also match up with the number of fields in the Feature Set + spec. NOTE: This method is deprecated and will be removed in Feast v0.7. 
* * @param featureRow * @return boolean */ - public boolean isEncoded(FeatureRow featureRow) { + @Deprecated + private boolean isEncodedV1(FeatureRow featureRow) { return featureRow.getFeatureSet().isEmpty() - && featureRow.getFieldsList().stream().allMatch(field -> field.getName().isEmpty()); + && featureRow.getFieldsList().stream().allMatch(field -> field.getName().isEmpty()) + && featureRow.getFieldsList().size() == spec.getFeaturesList().size(); } /** - * Validates if an encoded feature row can be decoded without exception. + * Check if encoded feature row can be decoded by the v2 Decoder. The v2 Decoder requires that a Feature + * Row has both its feature set reference and field names set. * * @param featureRow Feature row * @return boolean */ - public boolean isEncodingValid(FeatureRow featureRow) { - return featureRow.getFieldsList().size() == spec.getFeaturesList().size(); + private boolean isEncodedV2(FeatureRow featureRow) { + return !featureRow.getFieldsList().stream().anyMatch(field -> field.getName().isEmpty()); } /** - * Decoding feature row by repopulating the field names based on the corresponding feature set - * spec. + * Decode feature row encoded by {@link RedisCustomIO}. NOTE: The v1 Decoder will be removed in + * Feast 0.7 * + * @throws IllegalArgumentException if unable to decode the given feature row + * @param encodedFeatureRow Feature row + * @return boolean */ public FeatureRow decode(FeatureRow encodedFeatureRow) { - final List fieldsWithoutName = encodedFeatureRow.getFieldsList(); + if (isEncodedV1(encodedFeatureRow)) { + // TODO: remove v1 feature row decoder in Feast 0.7 + // Decode Feature Rows using the v1 Decoder. 
+ final List fieldsWithoutName = encodedFeatureRow.getFieldsList(); + List featureNames = + spec.getFeaturesList().stream() + .sorted(Comparator.comparing(FeatureSpec::getName)) + .map(FeatureSpec::getName) + .collect(Collectors.toList()); - List featureNames = - spec.getFeaturesList().stream() - .sorted(Comparator.comparing(FeatureSpec::getName)) - .map(FeatureSpec::getName) - .collect(Collectors.toList()); - List fields = - IntStream.range(0, featureNames.size()) - .mapToObj( - featureNameIndex -> { - String featureName = featureNames.get(featureNameIndex); - return fieldsWithoutName - .get(featureNameIndex) - .toBuilder() - .setName(featureName) - .build(); - }) - .collect(Collectors.toList()); - return encodedFeatureRow - .toBuilder() - .clearFields() - .setFeatureSet(featureSetRef) - .addAllFields(fields) - .build(); + List fields = + IntStream.range(0, featureNames.size()) + .mapToObj( + featureNameIndex -> { + String featureName = featureNames.get(featureNameIndex); + return fieldsWithoutName + .get(featureNameIndex) + .toBuilder() + .setName(featureName) + .build(); + }) + .collect(Collectors.toList()); + + return encodedFeatureRow + .toBuilder() + .clearFields() + .setFeatureSet(featureSetRef) + .addAllFields(fields) + .build(); + } + if (isEncodedV2(encodedFeatureRow)) { + // Decode Feature Rows using the v2 Decoder. + // v2 Decoder input Feature Rows should use a hashed name as the field name and + // should not have feature set reference set. + // Decoding reverts the field name to a unhashed string and set feature set reference. 
+ Map nameHashValueMap = + encodedFeatureRow.getFieldsList().stream() + .collect(Collectors.toMap(field -> field.getName(), field -> field.getValue())); + + List featureNames = + spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList()); + + List fields = + featureNames.stream() + .map( + name -> { + String nameHash = + Hashing.murmur3_32().hashString(name, StandardCharsets.UTF_8).toString(); + Value value = + nameHashValueMap.getOrDefault(nameHash, Value.newBuilder().build()); + return Field.newBuilder().setName(name).setValue(value).build(); + }) + .collect(Collectors.toList()); + + return encodedFeatureRow + .toBuilder() + .clearFields() + .setFeatureSet(featureSetRef) + .addAllFields(fields) + .build(); + } + throw new IllegalArgumentException("Failed to decode FeatureRow row: Possible data corruption"); } } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterOnlineRetriever.java index c006149cd51..2146ec2f87b 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterOnlineRetriever.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterOnlineRetriever.java @@ -158,17 +158,11 @@ private List> getFeaturesFromRedis( // decode feature rows from data bytes using decoder. 
FeatureRow featureRow = FeatureRow.parseFrom(featureRowBytes); - if (decoder.isEncoded(featureRow)) { - if (decoder.isEncodingValid(featureRow)) { - featureRow = decoder.decode(featureRow); - } else { - // decoding feature row failed: data corruption could have occurred - throw Status.DATA_LOSS - .withDescription( - "Failed to decode FeatureRow from bytes retrieved from redis" - + ": Possible data corruption") - .asRuntimeException(); - } + try { + featureRow = decoder.decode(featureRow); + } catch (IllegalArgumentException e) { + // decoding feature row failed: data corruption could have occurred + throw Status.DATA_LOSS.withCause(e).withDescription(e.getMessage()).asRuntimeException(); } featureRows.add(Optional.of(featureRow)); } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java index ef80b06799b..049175879de 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java @@ -151,15 +151,11 @@ private List> getFeaturesFromRedis( // decode feature rows from data bytes using decoder. 
FeatureRow featureRow = FeatureRow.parseFrom(featureRowBytes); - if (decoder.isEncoded(featureRow) && decoder.isEncodingValid(featureRow)) { + try { featureRow = decoder.decode(featureRow); - } else { + } catch (IllegalArgumentException e) { // decoding feature row failed: data corruption could have occurred - throw Status.DATA_LOSS - .withDescription( - "Failed to decode FeatureRow from bytes retrieved from redis" - + ": Possible data corruption") - .asRuntimeException(); + throw Status.DATA_LOSS.withCause(e).withDescription(e.getMessage()).asRuntimeException(); } featureRows.add(Optional.of(featureRow)); } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java new file mode 100644 index 00000000000..d6c83c3a540 --- /dev/null +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.redis.writer; + +import feast.storage.common.retry.Retriable; +import io.lettuce.core.RedisException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.function.Function; +import org.apache.beam.sdk.transforms.DoFn; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for redis-related DoFns. Assumes that operations will be batched. Prepares redisClient + * on DoFn.Setup stage and close it on DoFn.Teardown stage. + * + * @param + * @param + */ +public class BatchDoFnWithRedis extends DoFn { + private static final Logger log = LoggerFactory.getLogger(BatchDoFnWithRedis.class); + + private final RedisIngestionClient redisIngestionClient; + + BatchDoFnWithRedis(RedisIngestionClient redisIngestionClient) { + this.redisIngestionClient = redisIngestionClient; + } + + @Setup + public void setup() { + this.redisIngestionClient.setup(); + } + + @StartBundle + public void startBundle() { + try { + redisIngestionClient.connect(); + } catch (RedisException e) { + log.error("Connection to redis cannot be established: %s", e); + } + } + + void executeBatch(Function>> executor) + throws Exception { + this.redisIngestionClient + .getBackOffExecutor() + .execute( + new Retriable() { + @Override + public void execute() throws ExecutionException, InterruptedException { + if (!redisIngestionClient.isConnected()) { + redisIngestionClient.connect(); + } + + Iterable> futures = executor.apply(redisIngestionClient); + redisIngestionClient.sync(futures); + } + + @Override + public Boolean isExceptionRetriable(Exception e) { + return e instanceof RedisException; + } + + @Override + public void cleanUpAfterFailure() {} + }); + } + + @Teardown + public void teardown() { + redisIngestionClient.shutdown(); + } +} diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java 
b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java index 389db4be3ad..f36d70563e1 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java @@ -20,7 +20,6 @@ import feast.proto.core.StoreProto; import feast.storage.common.retry.BackOffExecutor; import io.lettuce.core.LettuceFutures; -import io.lettuce.core.RedisFuture; import io.lettuce.core.RedisURI; import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; @@ -28,6 +27,8 @@ import io.lettuce.core.codec.ByteArrayCodec; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.joda.time.Duration; @@ -39,7 +40,6 @@ public class RedisClusterIngestionClient implements RedisIngestionClient { private transient RedisClusterClient clusterClient; private StatefulRedisClusterConnection connection; private RedisAdvancedClusterAsyncCommands commands; - private List futures = Lists.newArrayList(); public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClusterConfig) { this.uriList = @@ -55,7 +55,6 @@ public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClus redisClusterConfig.getInitialBackoffMs() > 0 ? 
redisClusterConfig.getInitialBackoffMs() : 1; this.backOffExecutor = new BackOffExecutor(redisClusterConfig.getMaxRetries(), Duration.millis(backoffMs)); - this.clusterClient = RedisClusterClient.create(uriList); } @Override @@ -78,6 +77,10 @@ public void connect() { if (!isConnected()) { this.connection = clusterClient.connect(new ByteArrayCodec()); this.commands = connection.async(); + + // despite we're using async API client still flushes after each command by default + // which we don't want since we produce all commands in batches + this.commands.setAutoFlushCommands(false); } } @@ -87,46 +90,20 @@ public boolean isConnected() { } @Override - public void sync() { - try { - LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0])); - } finally { - futures.clear(); - } - } - - @Override - public void pexpire(byte[] key, Long expiryMillis) { - futures.add(commands.pexpire(key, expiryMillis)); - } - - @Override - public void append(byte[] key, byte[] value) { - futures.add(commands.append(key, value)); - } - - @Override - public void set(byte[] key, byte[] value) { - futures.add(commands.set(key, value)); - } + public void sync(Iterable> futures) { + this.connection.flushCommands(); - @Override - public void lpush(byte[] key, byte[] value) { - futures.add(commands.lpush(key, value)); - } - - @Override - public void rpush(byte[] key, byte[] value) { - futures.add(commands.rpush(key, value)); + LettuceFutures.awaitAll( + 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0])); } @Override - public void sadd(byte[] key, byte[] value) { - futures.add(commands.sadd(key, value)); + public CompletableFuture set(byte[] key, byte[] value) { + return commands.set(key, value).toCompletableFuture(); } @Override - public void zadd(byte[] key, Long score, byte[] value) { - futures.add(commands.zadd(key, score, value)); + public CompletableFuture get(byte[] key) { + return commands.get(key).toCompletableFuture(); } } diff --git 
a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java index dcd2e5bfda1..c42cff7bd0f 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java @@ -17,7 +17,9 @@ package feast.storage.connectors.redis.writer; import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; +import com.google.common.collect.Streams; +import com.google.common.hash.Hashing; +import com.google.protobuf.InvalidProtocolBufferException; import feast.proto.core.FeatureSetProto.EntitySpec; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.core.FeatureSetProto.FeatureSpec; @@ -28,18 +30,20 @@ import feast.proto.types.ValueProto; import feast.storage.api.writer.FailedElement; import feast.storage.api.writer.WriteResult; -import feast.storage.common.retry.Retriable; -import io.lettuce.core.RedisException; +import feast.storage.connectors.redis.retriever.FeatureRowDecoder; +import java.nio.charset.StandardCharsets; +import java.util.*; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; +import java.util.function.BinaryOperator; import java.util.stream.Collectors; import org.apache.beam.sdk.transforms.*; import org.apache.beam.sdk.transforms.windowing.*; import org.apache.beam.sdk.values.*; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.joda.time.DateTime; import org.joda.time.Duration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,63 +117,23 @@ public void process(ProcessContext c) { redisWrite.get(failedInsertsTupleTag)); } - public static class WriteDoFn extends DoFn, FeatureRow> { - private 
PCollectionView>> featureSetSpecsView; - private RedisIngestionClient redisIngestionClient; + /** + * Writes batch of {@link FeatureRow} to Redis. Only latest values should be written. In order + * to guarantee that we first fetch all existing values (first batch operation), compare with + * current batch by eventTimestamp, and send to redis values (second batch operation) that were + * confirmed to be most recent. + */ + public static class WriteDoFn extends BatchDoFnWithRedis, FeatureRow> { + private final PCollectionView>> featureSetSpecsView; WriteDoFn( RedisIngestionClient redisIngestionClient, PCollectionView>> featureSetSpecsView) { - this.redisIngestionClient = redisIngestionClient; + super(redisIngestionClient); this.featureSetSpecsView = featureSetSpecsView; } - @Setup - public void setup() { - this.redisIngestionClient.setup(); - } - - @StartBundle - public void startBundle() { - try { - redisIngestionClient.connect(); - } catch (RedisException e) { - log.error("Connection to redis cannot be established ", e); - } - } - - private void executeBatch( - Iterable featureRows, Map featureSetSpecs) - throws Exception { - this.redisIngestionClient - .getBackOffExecutor() - .execute( - new Retriable() { - @Override - public void execute() throws ExecutionException, InterruptedException { - if (!redisIngestionClient.isConnected()) { - redisIngestionClient.connect(); - } - featureRows.forEach( - row -> { - redisIngestionClient.set( - getKey(row, featureSetSpecs.get(row.getFeatureSet())), - getValue(row, featureSetSpecs.get(row.getFeatureSet()))); - }); - redisIngestionClient.sync(); - } - - @Override - public Boolean isExceptionRetriable(Exception e) { - return e instanceof RedisException; - } - - @Override - public void cleanUpAfterFailure() {} - }); - } - private FailedElement toFailedElement( FeatureRow featureRow, Exception exception, String jobName) { return FailedElement.newBuilder() @@ -181,7 +145,7 @@ private FailedElement toFailedElement( .build(); } - 
private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) { + private RedisKey getKey(FeatureRow featureRow, FeatureSetSpec spec) { List entityNames = spec.getEntitiesList().stream() .map(EntitySpec::getName) @@ -200,65 +164,148 @@ private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) { for (String entityName : entityNames) { redisKeyBuilder.addEntities(entityFields.get(entityName)); } - return redisKeyBuilder.build().toByteArray(); + return redisKeyBuilder.build(); } - private byte[] getValue(FeatureRow featureRow, FeatureSetSpec spec) { + /** + * Encode the Feature Row as bytes to store in Redis in encoded Feature Row encoding. To + reduce storage space consumption in redis, feature rows are "encoded" by hashing the field + names and not setting the feature set reference. {@link FeatureRowDecoder} is + responsible for reversing this "encoding" step. + */ + private FeatureRow getValue(FeatureRow featureRow, FeatureSetSpec spec) { List featureNames = spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList()); - Map fieldValueOnlyMap = + + Map fieldValueOnlyMap = featureRow.getFieldsList().stream() .filter(field -> featureNames.contains(field.getName())) .distinct() .collect( Collectors.toMap( - Field::getName, - field -> Field.newBuilder().setValue(field.getValue()).build())); + Field::getName, field -> Field.newBuilder().setValue(field.getValue()))); List values = featureNames.stream() .sorted() .map( - featureName -> - fieldValueOnlyMap.getOrDefault( - featureName, - Field.newBuilder() - .setValue(ValueProto.Value.getDefaultInstance()) - .build())) + featureName -> { + Field.Builder field = + fieldValueOnlyMap.getOrDefault( + featureName, + Field.newBuilder().setValue(ValueProto.Value.getDefaultInstance())); + + // Encode the name of the field as the hash of the field name. + // Use hash of name instead of the name itself to reduce redis storage consumption + // per feature row stored. 
+ String nameHash = + Hashing.murmur3_32() + .hashString(featureName, StandardCharsets.UTF_8) + .toString(); + field.setName(nameHash); + + return field.build(); + }) .collect(Collectors.toList()); return FeatureRow.newBuilder() .setEventTimestamp(featureRow.getEventTimestamp()) .addAllFields(values) - .build() - .toByteArray(); + .build(); } @ProcessElement public void processElement(ProcessContext context) { - List featureRows = Lists.newArrayList(context.element().iterator()); - + List filteredFeatureRows = Collections.synchronizedList(new ArrayList<>()); Map latestSpecs = - context.sideInput(featureSetSpecsView).entrySet().stream() - .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator()))) - .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight)); + getLatestSpecs(context.sideInput(featureSetSpecsView)); + + Map deduplicatedRows = + deduplicateRows(context.element(), latestSpecs); try { - executeBatch(featureRows, latestSpecs); - featureRows.forEach(row -> context.output(successfulInsertsTag, row)); + executeBatch( + (redisIngestionClient) -> + deduplicatedRows.entrySet().stream() + .map( + entry -> + redisIngestionClient + .get(entry.getKey().toByteArray()) + .thenAccept( + currentValue -> { + FeatureRow newRow = entry.getValue(); + if (rowShouldBeWritten(newRow, currentValue)) { + filteredFeatureRows.add(newRow); + } + })) + .collect(Collectors.toList())); + + executeBatch( + redisIngestionClient -> + filteredFeatureRows.stream() + .map( + row -> + redisIngestionClient.set( + getKey(row, latestSpecs.get(row.getFeatureSet())).toByteArray(), + getValue(row, latestSpecs.get(row.getFeatureSet())) + .toByteArray())) + .collect(Collectors.toList())); + + filteredFeatureRows.forEach(row -> context.output(successfulInsertsTag, row)); } catch (Exception e) { - featureRows.forEach( - failedMutation -> { - FailedElement failedElement = - toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName()); - 
context.output(failedInsertsTupleTag, failedElement); - }); + deduplicatedRows + .values() + .forEach( + failedMutation -> { + FailedElement failedElement = + toFailedElement( + failedMutation, e, context.getPipelineOptions().getJobName()); + context.output(failedInsertsTupleTag, failedElement); + }); } } - @Teardown - public void teardown() { - redisIngestionClient.shutdown(); + boolean rowShouldBeWritten(FeatureRow newRow, byte[] currentValue) { + if (currentValue == null) { + // nothing to compare with + return true; + } + FeatureRow currentRow; + try { + currentRow = FeatureRow.parseFrom(currentValue); + } catch (InvalidProtocolBufferException e) { + // definitely need to replace current value + return true; + } + + // check whether new row has later eventTimestamp + return new DateTime(currentRow.getEventTimestamp().getSeconds() * 1000L) + .isBefore(new DateTime(newRow.getEventTimestamp().getSeconds() * 1000L)); + } + + /** Deduplicate rows by key within batch. Keep only latest eventTimestamp */ + Map deduplicateRows( + Iterable rows, Map latestSpecs) { + Comparator byEventTimestamp = + Comparator.comparing(r -> r.getEventTimestamp().getSeconds()); + + FeatureRow identity = + FeatureRow.newBuilder() + .setEventTimestamp( + com.google.protobuf.Timestamp.newBuilder().setSeconds(-1).build()) + .build(); + + return Streams.stream(rows) + .collect( + Collectors.groupingBy( + row -> getKey(row, latestSpecs.get(row.getFeatureSet())), + Collectors.reducing(identity, BinaryOperator.maxBy(byEventTimestamp)))); + } + + Map getLatestSpecs(Map> specs) { + return specs.entrySet().stream() + .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator()))) + .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight)); } } } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java 
index 6616a79aaca..e9b1a5dc445 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java @@ -18,6 +18,8 @@ import feast.storage.common.retry.BackOffExecutor; import java.io.Serializable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; public interface RedisIngestionClient extends Serializable { @@ -31,19 +33,9 @@ public interface RedisIngestionClient extends Serializable { boolean isConnected(); - void sync(); + void sync(Iterable> futures); - void pexpire(byte[] key, Long expiryMillis); + CompletableFuture set(byte[] key, byte[] value); - void append(byte[] key, byte[] value); - - void set(byte[] key, byte[] value); - - void lpush(byte[] key, byte[] value); - - void rpush(byte[] key, byte[] value); - - void sadd(byte[] key, byte[] value); - - void zadd(byte[] key, Long score, byte[] value); + CompletableFuture get(byte[] key); } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java index 24591a1dc0f..f0a2054b9bd 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java @@ -21,12 +21,12 @@ import feast.storage.common.retry.BackOffExecutor; import io.lettuce.core.LettuceFutures; import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisFuture; import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; -import java.util.List; +import 
java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.joda.time.Duration; @@ -38,7 +38,6 @@ public class RedisStandaloneIngestionClient implements RedisIngestionClient { private static final int DEFAULT_TIMEOUT = 2000; private StatefulRedisConnection connection; private RedisAsyncCommands commands; - private List futures = Lists.newArrayList(); public RedisStandaloneIngestionClient(StoreProto.Store.RedisConfig redisConfig) { this.host = redisConfig.getHost(); @@ -69,6 +68,9 @@ public void connect() { if (!isConnected()) { this.connection = this.redisclient.connect(new ByteArrayCodec()); this.commands = connection.async(); + + // enable pipelining of commands + this.commands.setAutoFlushCommands(false); } } @@ -78,48 +80,20 @@ public boolean isConnected() { } @Override - public void sync() { - // Wait for some time for futures to complete - // TODO: should this be configurable? - try { - LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0])); - } finally { - futures.clear(); - } - } - - @Override - public void pexpire(byte[] key, Long expiryMillis) { - commands.pexpire(key, expiryMillis); - } - - @Override - public void append(byte[] key, byte[] value) { - futures.add(commands.append(key, value)); - } - - @Override - public void set(byte[] key, byte[] value) { - futures.add(commands.set(key, value)); - } + public void sync(Iterable> futures) { + this.connection.flushCommands(); - @Override - public void lpush(byte[] key, byte[] value) { - futures.add(commands.lpush(key, value)); - } - - @Override - public void rpush(byte[] key, byte[] value) { - futures.add(commands.rpush(key, value)); + LettuceFutures.awaitAll( + 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0])); } @Override - public void sadd(byte[] key, byte[] value) { - futures.add(commands.sadd(key, value)); + public CompletableFuture set(byte[] key, byte[] value) { + return 
commands.set(key, value).toCompletableFuture(); } @Override - public void zadd(byte[] key, Long score, byte[] value) { - futures.add(commands.zadd(key, score, value)); + public CompletableFuture get(byte[] key) { + return commands.get(key).toCompletableFuture(); } } diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java index 63ad7aa26de..c843d311274 100644 --- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.*; +import com.google.common.hash.Hashing; import com.google.protobuf.Timestamp; import feast.proto.core.FeatureSetProto; import feast.proto.core.FeatureSetProto.FeatureSetSpec; @@ -25,6 +26,7 @@ import feast.proto.types.FieldProto.Field; import feast.proto.types.ValueProto.Value; import feast.proto.types.ValueProto.ValueType; +import java.nio.charset.StandardCharsets; import java.util.Collections; import org.junit.Test; @@ -48,10 +50,29 @@ public class FeatureRowDecoderTest { .build(); @Test - public void featureRowWithFieldNamesIsNotConsideredAsEncoded() { - + public void shouldDecodeValidEncodedFeatureRowV2() { FeatureRowDecoder decoder = new FeatureRowDecoder("feature_set_ref", spec); - FeatureRowProto.FeatureRow nonEncodedFeatureRow = + + FeatureRowProto.FeatureRow encodedFeatureRow = + FeatureRowProto.FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature1", StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setInt32Val(2))) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature2", 
StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setFloatVal(1.0f))) + .build(); + + FeatureRowProto.FeatureRow expectedFeatureRow = FeatureRowProto.FeatureRow.newBuilder() .setFeatureSet("feature_set_ref") .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) @@ -62,26 +83,88 @@ public void featureRowWithFieldNamesIsNotConsideredAsEncoded() { .setName("feature2") .setValue(Value.newBuilder().setFloatVal(1.0f))) .build(); - assertFalse(decoder.isEncoded(nonEncodedFeatureRow)); + + assertEquals(expectedFeatureRow, decoder.decode(encodedFeatureRow)); } @Test - public void encodingIsInvalidIfNumberOfFeaturesInSpecDiffersFromFeatureRow() { - + public void shouldDecodeValidFeatureRowV2WithIncompleteFields() { FeatureRowDecoder decoder = new FeatureRowDecoder("feature_set_ref", spec); FeatureRowProto.FeatureRow encodedFeatureRow = FeatureRowProto.FeatureRow.newBuilder() .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt32Val(2))) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature1", StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setInt32Val(2))) + .build(); + + // should decode missing fields as fields with unset value. 
+ FeatureRowProto.FeatureRow expectedFeatureRow = + FeatureRowProto.FeatureRow.newBuilder() + .setFeatureSet("feature_set_ref") + .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) + .addFields( + Field.newBuilder().setName("feature1").setValue(Value.newBuilder().setInt32Val(2))) + .addFields(Field.newBuilder().setName("feature2").setValue(Value.newBuilder().build())) .build(); - assertFalse(decoder.isEncodingValid(encodedFeatureRow)); + assertEquals(expectedFeatureRow, decoder.decode(encodedFeatureRow)); } @Test - public void shouldDecodeValidEncodedFeatureRow() { + public void shouldDecodeValidFeatureRowV2AndIgnoreExtraFields() { + FeatureRowDecoder decoder = new FeatureRowDecoder("feature_set_ref", spec); + FeatureRowProto.FeatureRow encodedFeatureRow = + FeatureRowProto.FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature1", StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setInt32Val(2))) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature2", StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setFloatVal(1.0f))) + .addFields( + Field.newBuilder() + .setName( + Hashing.murmur3_32() + .hashString("feature3", StandardCharsets.UTF_8) + .toString()) + .setValue(Value.newBuilder().setStringVal("data"))) + .build(); + + // should decode missing fields as fields with unset value. 
+ FeatureRowProto.FeatureRow expectedFeatureRow = + FeatureRowProto.FeatureRow.newBuilder() + .setFeatureSet("feature_set_ref") + .setEventTimestamp(Timestamp.newBuilder().setNanos(1000)) + .addFields( + Field.newBuilder().setName("feature1").setValue(Value.newBuilder().setInt32Val(2))) + .addFields( + Field.newBuilder() + .setName("feature2") + .setValue(Value.newBuilder().setFloatVal(1.0f))) + .build(); + + assertEquals(expectedFeatureRow, decoder.decode(encodedFeatureRow)); + } + + // TODO: remove this test in Feast 0.7 when support for Feature Row v1 encoding is removed + @Test + public void shouldDecodeValidEncodedFeatureRowV1() { FeatureRowDecoder decoder = new FeatureRowDecoder("feature_set_ref", spec); FeatureRowProto.FeatureRow encodedFeatureRow = @@ -103,8 +186,6 @@ public void shouldDecodeValidEncodedFeatureRow() { .setValue(Value.newBuilder().setFloatVal(1.0f))) .build(); - assertTrue(decoder.isEncoded(encodedFeatureRow)); - assertTrue(decoder.isEncodingValid(encodedFeatureRow)); assertEquals(expectedFeatureRow, decoder.decode(encodedFeatureRow)); } } diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java deleted file mode 100644 index 2adf0cec47f..00000000000 --- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java +++ /dev/null @@ -1,507 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.storage.connectors.redis.writer; - -import static feast.storage.common.testing.TestUtil.field; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Timestamp; -import feast.common.models.FeatureSetReference; -import feast.proto.core.FeatureSetProto.EntitySpec; -import feast.proto.core.FeatureSetProto.FeatureSetSpec; -import feast.proto.core.FeatureSetProto.FeatureSpec; -import feast.proto.core.StoreProto.Store.RedisClusterConfig; -import feast.proto.storage.RedisProto.RedisKey; -import feast.proto.types.FeatureRowProto.FeatureRow; -import feast.proto.types.FieldProto.Field; -import feast.proto.types.ValueProto.Value; -import feast.proto.types.ValueProto.ValueType.Enum; -import io.lettuce.core.RedisURI; -import io.lettuce.core.cluster.RedisClusterClient; -import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; -import io.lettuce.core.cluster.api.sync.RedisClusterCommands; -import io.lettuce.core.codec.ByteArrayCodec; -import java.io.File; -import java.io.IOException; -import java.nio.file.Paths; -import java.util.*; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import net.ishiis.redis.unit.RedisCluster; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Count; 
-import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.windowing.*; -import org.apache.beam.sdk.values.PCollection; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; - -public class RedisClusterFeatureSinkTest { - @Rule public transient TestPipeline p = TestPipeline.create(); - - private static String REDIS_CLUSTER_HOST = "localhost"; - private static int REDIS_CLUSTER_PORT1 = 6380; - private static int REDIS_CLUSTER_PORT2 = 6381; - private static int REDIS_CLUSTER_PORT3 = 6382; - private static String CONNECTION_STRING = "localhost:6380,localhost:6381,localhost:6382"; - private RedisCluster redisCluster; - private RedisClusterClient redisClusterClient; - private RedisClusterCommands redisClusterCommands; - - private RedisFeatureSink redisClusterFeatureSink; - - @Before - public void setUp() throws IOException { - redisCluster = new RedisCluster(REDIS_CLUSTER_PORT1, REDIS_CLUSTER_PORT2, REDIS_CLUSTER_PORT3); - redisCluster.start(); - redisClusterClient = - RedisClusterClient.create( - Arrays.asList( - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT1), - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT2), - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT3))); - StatefulRedisClusterConnection connection = - redisClusterClient.connect(new ByteArrayCodec()); - redisClusterCommands = connection.sync(); - redisClusterCommands.setTimeout(java.time.Duration.ofMillis(600000)); - - FeatureSetSpec spec1 = - FeatureSetSpec.newBuilder() - .setName("fs") - .setProject("myproject") - .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build()) - .build(); - - FeatureSetSpec spec2 = - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setProject("myproject") - .addEntities( - EntitySpec.newBuilder() - 
.setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) - .build(); - - Map specMap = - ImmutableMap.of( - FeatureSetReference.of("myproject", "fs", 1), spec1, - FeatureSetReference.of("myproject", "feature_set", 1), spec2); - RedisClusterConfig redisClusterConfig = - RedisClusterConfig.newBuilder() - .setConnectionString(CONNECTION_STRING) - .setInitialBackoffMs(2000) - .setMaxRetries(4) - .build(); - - redisClusterFeatureSink = - RedisFeatureSink.builder().setRedisClusterConfig(redisClusterConfig).build(); - redisClusterFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap))); - } - - static boolean deleteDirectory(File directoryToBeDeleted) { - File[] allContents = directoryToBeDeleted.listFiles(); - if (allContents != null) { - for (File file : allContents) { - deleteDirectory(file); - } - } - return directoryToBeDeleted.delete(); - } - - @After - public void teardown() { - redisCluster.stop(); - redisClusterClient.shutdown(); - deleteDirectory(new File(String.valueOf(Paths.get(System.getProperty("user.dir"), ".redis")))); - } - - @Test - public void shouldWriteToRedis() { - - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) - .build()); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 2, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - 
.addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("two"))) - .build()); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 2, Enum.INT64)) - .addFields(field("feature", "two", Enum.STRING)) - .build()); - - p.apply(Create.of(featureRows)).apply(redisClusterFeatureSink.writer()); - p.run(); - - kvs.forEach( - (key, value) -> { - byte[] actual = redisClusterCommands.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test(timeout = 15000) - public void shouldRetryFailConnection() throws InterruptedException { - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) - .build()); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - PCollection failedElementCount = - p.apply(Create.of(featureRows)) - .apply(redisClusterFeatureSink.writer()) - .getFailedInserts() - .apply(Count.globally()); - - redisCluster.stop(); - final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); - ScheduledFuture scheduledRedisRestart = - redisRestartExecutor.schedule( - () -> { - redisCluster.start(); - }, - 3, - TimeUnit.SECONDS); - - PAssert.that(failedElementCount).containsInAnyOrder(0L); - p.run(); - scheduledRedisRestart.cancel(true); - - kvs.forEach( - (key, value) -> { - byte[] actual = 
redisClusterCommands.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test - public void shouldProduceFailedElementIfRetryExceeded() { - RedisClusterConfig redisClusterConfig = - RedisClusterConfig.newBuilder() - .setConnectionString(CONNECTION_STRING) - .setInitialBackoffMs(2000) - .setMaxRetries(1) - .build(); - - FeatureSetSpec spec1 = - FeatureSetSpec.newBuilder() - .setName("fs") - .setProject("myproject") - .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build()) - .build(); - Map specMap = ImmutableMap.of("myproject/fs", spec1); - redisClusterFeatureSink = - RedisFeatureSink.builder() - .setRedisClusterConfig(redisClusterConfig) - .build() - .withSpecsView(p.apply("Specs-2", Create.of(specMap)).apply("View", View.asMultimap())); - - redisCluster.stop(); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - PCollection failedElementCount = - p.apply(Create.of(featureRows)) - .apply("modifiedSink", redisClusterFeatureSink.writer()) - .getFailedInserts() - .apply(Count.globally()); - - PAssert.that(failedElementCount).containsInAnyOrder(1L); - p.run(); - } - - @Test - public void shouldConvertRowWithDuplicateEntitiesToValidKey() { - - FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(2))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - 
.addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { - FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - 
.setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - List expectedFields = - Arrays.asList( - Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")).build(), - Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)).build()); - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addAllFields(expectedFields) - .build(); - - p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldMergeDuplicateFeatureFields() { - FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - 
FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - p.apply(Create.of(featureRowWithDuplicatedFeatureFields)) - .apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { - FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.getDefaultInstance())) - .build(); - - p.apply(Create.of(featureRowWithDuplicatedFeatureFields)) - .apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, 
equalTo(expectedValue.toByteArray())); - } -} diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java index 63ec136c5d0..12377fd1d1b 100644 --- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java @@ -17,65 +17,115 @@ package feast.storage.connectors.redis.writer; import static feast.storage.common.testing.TestUtil.field; +import static feast.storage.common.testing.TestUtil.hash; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsCollectionWithSize.hasSize; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.protobuf.Message; import com.google.protobuf.Timestamp; import feast.common.models.FeatureSetReference; import feast.proto.core.FeatureSetProto.EntitySpec; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.core.FeatureSetProto.FeatureSpec; import feast.proto.core.StoreProto; +import feast.proto.core.StoreProto.Store.RedisClusterConfig; import feast.proto.core.StoreProto.Store.RedisConfig; import feast.proto.storage.RedisProto.RedisKey; import feast.proto.types.FeatureRowProto.FeatureRow; import feast.proto.types.FieldProto.Field; import feast.proto.types.ValueProto.Value; import feast.proto.types.ValueProto.ValueType.Enum; +import io.lettuce.core.AbstractRedisClient; import io.lettuce.core.RedisClient; import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.sync.RedisStringCommands; +import io.lettuce.core.cluster.RedisClusterClient; import 
io.lettuce.core.codec.ByteArrayCodec; -import java.io.IOException; import java.util.*; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import net.ishiis.redis.unit.Redis; +import net.ishiis.redis.unit.RedisCluster; +import net.ishiis.redis.unit.RedisServer; +import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import redis.embedded.Redis; -import redis.embedded.RedisServer; +import org.junit.*; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class RedisFeatureSinkTest { @Rule public transient TestPipeline p = TestPipeline.create(); private static String REDIS_HOST = "localhost"; - private static int REDIS_PORT = 51234; - private Redis redis; - private RedisClient redisClient; - private RedisStringCommands sync; + private static int REDIS_PORT = 51233; + private static Integer[] REDIS_CLUSTER_PORTS = {6380, 6381, 6382}; + private RedisStringCommands sync; private RedisFeatureSink redisFeatureSink; private Map specMap; - @Before - public void setUp() throws IOException { - redis = new RedisServer(REDIS_PORT); - redis.start(); - redisClient = + @Parameterized.Parameters + public static Iterable backends() { + Redis redis = new RedisServer(REDIS_PORT); + RedisClient client = RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000))); - StatefulRedisConnection connection = redisClient.connect(new ByteArrayCodec()); - sync = 
connection.sync(); + + Redis redisCluster = new RedisCluster(REDIS_CLUSTER_PORTS); + RedisClusterClient clientCluster = + RedisClusterClient.create( + Lists.newArrayList(REDIS_CLUSTER_PORTS).stream() + .map(port -> RedisURI.create(REDIS_HOST, port)) + .collect(Collectors.toList())); + + StoreProto.Store.RedisConfig redisConfig = + StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); + + StoreProto.Store.RedisClusterConfig redisClusterConfig = + StoreProto.Store.RedisClusterConfig.newBuilder() + .setConnectionString( + Lists.newArrayList(REDIS_CLUSTER_PORTS).stream() + .map(port -> String.format("%s:%d", REDIS_HOST, port)) + .collect(Collectors.joining(","))) + .setInitialBackoffMs(2000) + .setMaxRetries(4) + .build(); + + return Arrays.asList( + new Object[] {redis, client, redisConfig}, + new Object[] {redisCluster, clientCluster, redisClusterConfig}); + } + + @Parameterized.Parameter(0) + public Redis redisServer; + + @Parameterized.Parameter(1) + public AbstractRedisClient redisClient; + + @Parameterized.Parameter(2) + public Message redisConfig; + + @Before + public void setUp() { + redisServer.start(); + + if (redisClient instanceof RedisClient) { + sync = ((RedisClient) redisClient).connect(new ByteArrayCodec()).sync(); + } else { + sync = ((RedisClusterClient) redisClient).connect(new ByteArrayCodec()).sync(); + } FeatureSetSpec spec1 = FeatureSetSpec.newBuilder() @@ -110,17 +160,42 @@ public void setUp() throws IOException { ImmutableMap.of( FeatureSetReference.of("myproject", "fs", 1), spec1, FeatureSetReference.of("myproject", "feature_set", 1), spec2); - StoreProto.Store.RedisConfig redisConfig = - StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); - redisFeatureSink = RedisFeatureSink.builder().setRedisConfig(redisConfig).build(); + RedisFeatureSink.Builder builder = RedisFeatureSink.builder(); + if (redisConfig instanceof RedisConfig) { + builder = 
builder.setRedisConfig((RedisConfig) redisConfig); + } else { + builder = builder.setRedisClusterConfig((RedisClusterConfig) redisConfig); + } + redisFeatureSink = builder.build(); redisFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap))); } @After - public void teardown() { - redisClient.shutdown(); - redis.stop(); + public void tearDown() { + if (redisServer.isActive()) { + redisServer.stop(); + } + } + + private RedisKey createRedisKey(String featureSetRef, Field... fields) { + return RedisKey.newBuilder() + .setFeatureSet(featureSetRef) + .addAllEntities(Lists.newArrayList(fields)) + .build(); + } + + private FeatureRow createFeatureRow(String featureSetRef, Timestamp timestamp, Field... fields) { + FeatureRow.Builder builder = FeatureRow.newBuilder(); + if (featureSetRef != null) { + builder.setFeatureSet(featureSetRef); + } + + if (timestamp != null) { + builder.setEventTimestamp(timestamp); + } + + return builder.addAllFields(Lists.newArrayList(fields)).build(); } @Test @@ -128,36 +203,26 @@ public void shouldWriteToRedis() { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + null, Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 2, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("two"))) - .build()); + createRedisKey("myproject/fs", field("entity", 2, Enum.INT64)), + createFeatureRow( + null, Timestamp.getDefaultInstance(), 
field(hash("feature"), "two", Enum.STRING))); List featureRows = ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 2, Enum.INT64)) - .addFields(field("feature", "two", Enum.STRING)) - .build()); + createFeatureRow( + "myproject/fs", + null, + field("entity", 1, Enum.INT64), + field("feature", "one", Enum.STRING)), + createFeatureRow( + "myproject/fs", + null, + field("entity", 2, Enum.INT64), + field("feature", "two", Enum.STRING))); p.apply(Create.of(featureRows)).apply(redisFeatureSink.writer()); p.run(); @@ -169,7 +234,7 @@ public void shouldWriteToRedis() { }); } - @Test(timeout = 10000) + @Test(timeout = 30000) public void shouldRetryFailConnection() throws InterruptedException { RedisConfig redisConfig = RedisConfig.newBuilder() @@ -187,22 +252,17 @@ public void shouldRetryFailConnection() throws InterruptedException { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); List featureRows = ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); + createFeatureRow( + "myproject/fs", + null, + field("entity", 1, Enum.INT64), + field("feature", "one", Enum.STRING))); PCollection failedElementCount = p.apply(Create.of(featureRows)) @@ -210,12 +270,12 @@ public void 
shouldRetryFailConnection() throws InterruptedException { .getFailedInserts() .apply(Count.globally()); - redis.stop(); + redisServer.stop(); final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); ScheduledFuture scheduledRedisRestart = redisRestartExecutor.schedule( () -> { - redis.start(); + redisServer.start(); }, 3, TimeUnit.SECONDS); @@ -245,14 +305,9 @@ public void shouldProduceFailedElementIfRetryExceeded() { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); List featureRows = ImmutableList.of( @@ -268,7 +323,7 @@ public void shouldProduceFailedElementIfRetryExceeded() { .getFailedInserts() .apply(Count.globally()); - redis.stop(); + redisServer.stop(); PAssert.that(failedElementCount).containsInAnyOrder(1L); p.run(); } @@ -277,50 +332,27 @@ public void shouldProduceFailedElementIfRetryExceeded() { public void shouldConvertRowWithDuplicateEntitiesToValidKey() { FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(2))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - 
.addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING), + field("feature_2", 1001, Enum.INT64)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); @@ -333,49 +365,26 @@ public void shouldConvertRowWithDuplicateEntitiesToValidKey() { @Test public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - 
.setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_secondary", "a", Enum.STRING), + field("entity_id_primary", 1, Enum.INT32), + field("feature_2", 1001, Enum.INT64), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); - List expectedFields = - Arrays.asList( - Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1")).build(), - Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001)).build()); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addAllFields(expectedFields) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); @@ -388,50 +397,27 @@ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { @Test public void shouldMergeDuplicateFeatureFields() { FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - 
Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 1001, Enum.INT64), + field("feature_1", "strValue1", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); @@ -444,42 +430,28 @@ public void 
shouldMergeDuplicateFeatureFields() { @Test public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.getDefaultInstance())) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + Field.newBuilder() + .setName(hash("feature_2")) + .setValue(Value.getDefaultInstance()) + .build()); p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); @@ -488,4 
+460,206 @@ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { byte[] actual = sync.get(expectedKey.toByteArray()); assertThat(actual, equalTo(expectedValue.toByteArray())); } + + @Test + public void shouldDeduplicateRowsWithinBatch() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 222, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.getDefaultInstance(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 333, Enum.INT32))) + .advanceWatermarkToInfinity(); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + RedisKey expectedKey = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + FeatureRow expectedValue = + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + Field.newBuilder() + .setName(hash("feature_1")) + .setValue(Value.getDefaultInstance()) + .build(), + field(hash("feature_2"), 111, Enum.INT32)); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); + } + + @Test + public void shouldWriteWithLatterTimestamp() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + 
field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "b", Enum.STRING), + field("feature_2", 222, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 3, Enum.INT32), + field("entity_id_secondary", "c", Enum.STRING), + field("feature_2", 333, Enum.INT32))) + .advanceWatermarkToInfinity(); + + RedisKey keyA = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + RedisKey keyB = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "b", Enum.STRING)); + + RedisKey keyC = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 3, Enum.INT32), + field("entity_id_secondary", "c", Enum.STRING)); + + sync.set( + keyA.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray()); + + sync.set( + keyB.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray()); + + sync.set( + keyC.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray()); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + assertThat( + sync.get(keyA.toByteArray()), + equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray())); + + assertThat( + sync.get(keyB.toByteArray()), + equalTo( + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + Field.newBuilder() + .setName(hash("feature_1")) + .setValue(Value.getDefaultInstance()) + .build(), + field(hash("feature_2"), 222, 
Enum.INT32)) + .toByteArray())); + + assertThat( + sync.get(keyC.toByteArray()), + equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray())); + } + + @Test + public void shouldOverwriteInvalidRows() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "text", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .advanceWatermarkToInfinity(); + + RedisKey expectedKey = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + sync.set(expectedKey.toByteArray(), "some-invalid-data".getBytes()); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + FeatureRow expectedValue = + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + field(hash("feature_1"), "text", Enum.STRING), + field(hash("feature_2"), 111, Enum.INT32)); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); + } + + @Test + public void loadTest() { + List rows = + IntStream.range(0, 10000) + .mapToObj( + i -> + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", i, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "text", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .collect(Collectors.toList()); + + p.apply(Create.of(rows)).apply(redisFeatureSink.writer()); + p.run(); + + List outcome = + IntStream.range(0, 10000) + .mapToObj( + i -> + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", i, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)) + .toByteArray()) + 
.map(sync::get) + .collect(Collectors.toList()); + + assertThat(outcome, hasSize(10000)); + assertThat("All rows were stored", outcome.stream().allMatch(Objects::nonNull)); + } } diff --git a/tests/e2e/redis/basic-ingest-redis-serving.py b/tests/e2e/redis/basic-ingest-redis-serving.py index 5bb79bc8e8f..c1e25508d44 100644 --- a/tests/e2e/redis/basic-ingest-redis-serving.py +++ b/tests/e2e/redis/basic-ingest-redis-serving.py @@ -4,7 +4,7 @@ import tempfile import time import uuid -from datetime import datetime +from datetime import datetime, timedelta import grpc import numpy as np @@ -14,6 +14,8 @@ from google.protobuf.duration_pb2 import Duration from feast.client import Client +from feast.config import Config +from feast.constants import CONFIG_AUTH_PROVIDER from feast.core import CoreService_pb2 from feast.core.CoreService_pb2 import ApplyFeatureSetResponse, GetFeatureSetResponse from feast.core.CoreService_pb2_grpc import CoreServiceStub @@ -21,6 +23,7 @@ from feast.entity import Entity from feast.feature import Feature from feast.feature_set import FeatureSet, FeatureSetRef +from feast.grpc.auth import get_auth_metadata_plugin from feast.serving.ServingService_pb2 import ( GetOnlineFeaturesRequest, GetOnlineFeaturesResponse, @@ -34,6 +37,7 @@ FLOAT_TOLERANCE = 0.00001 PROJECT_NAME = "basic_" + uuid.uuid4().hex.upper()[0:6] DIR_PATH = os.path.dirname(os.path.realpath(__file__)) +AUTH_PROVIDER = "google" def basic_dataframe(entities, features, ingest_time, n_size, null_features=[]): @@ -103,7 +107,7 @@ def allow_dirty(pytestconfig): @pytest.fixture(scope="module") def enable_auth(pytestconfig): - return pytestconfig.getoption("enable_auth") + return True if pytestconfig.getoption("enable_auth").lower() == "true" else False @pytest.fixture(scope="module") @@ -114,8 +118,8 @@ def client(core_url, serving_url, allow_dirty, enable_auth): client = Client( core_url=core_url, serving_url=serving_url, - core_enable_auth=enable_auth, - core_auth_provider="google", + 
enable_auth=enable_auth, + auth_provider=AUTH_PROVIDER, ) client.create_project(PROJECT_NAME) @@ -162,6 +166,13 @@ def test_version_returns_results(client): assert not version_info["serving"] == "not configured" +def test_list_feature_sets_when_auth_enabled_should_raise(enable_auth): + if enable_auth: + client = Client(core_url=core_url, serving_url=serving_url, enable_auth=False) + with pytest.raises(ConnectionError): + client.list_feature_sets() + + @pytest.mark.timeout(45) @pytest.mark.run(order=10) def test_basic_register_feature_set_success(client): @@ -558,6 +569,202 @@ def try_get_features2(): ) +@pytest.mark.timeout(600) +@pytest.mark.run(order=16) +def test_basic_ingest_retrieval_fs(client): + # Set to another project to test ingestion based on current project context + client.set_project(PROJECT_NAME + "_NS1") + driver_fs = FeatureSet( + name="driver_fs", + features=[ + Feature(name="driver_fs_rating", dtype=ValueType.FLOAT), + Feature(name="driver_fs_cost", dtype=ValueType.FLOAT), + ], + entities=[Entity("driver_fs_id", ValueType.INT64)], + max_age=Duration(seconds=3600), + ) + client.apply(driver_fs) + + N_ROWS = 2 + time_offset = datetime.utcnow().replace(tzinfo=pytz.utc) + driver_df = pd.DataFrame( + { + "datetime": [time_offset] * N_ROWS, + "driver_fs_id": [i for i in range(N_ROWS)], + "driver_fs_rating": [float(i) for i in range(N_ROWS)], + "driver_fs_cost": [float(i) + 0.5 for i in range(N_ROWS)], + } + ) + client.ingest(driver_fs, driver_df, timeout=600) + time.sleep(15) + + online_request_entity = [{"driver_fs_id": 0}, {"driver_fs_id": 1}] + online_request_features = ["driver_fs_rating", "driver_fs_cost"] + + def try_get_features(): + response = client.get_online_features( + entity_rows=online_request_entity, feature_refs=online_request_features + ) + return response, True + + online_features_actual = wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + 
online_features_expected = { + "driver_fs_id": [0, 1], + "driver_fs_rating": [0.0, 1.0], + "driver_fs_cost": [0.5, 1.5], + } + + assert online_features_actual.to_dict() == online_features_expected + + +@pytest.mark.timeout(600) +@pytest.mark.run(order=17) +def test_basic_ingest_retrieval_str(client): + # Set to another project to test ingestion based on current project context + client.set_project(PROJECT_NAME + "_NS1") + customer_fs = FeatureSet( + name="cust_fs", + features=[ + Feature(name="cust_rating", dtype=ValueType.INT64), + Feature(name="cust_cost", dtype=ValueType.FLOAT), + ], + entities=[Entity("cust_id", ValueType.INT64)], + max_age=Duration(seconds=3600), + ) + client.apply(customer_fs) + + N_ROWS = 2 + time_offset = datetime.utcnow().replace(tzinfo=pytz.utc) + cust_df = pd.DataFrame( + { + "datetime": [time_offset] * N_ROWS, + "cust_id": [i for i in range(N_ROWS)], + "cust_rating": [i for i in range(N_ROWS)], + "cust_cost": [float(i) + 0.5 for i in range(N_ROWS)], + } + ) + client.ingest("cust_fs", cust_df, timeout=600) + time.sleep(15) + + online_request_entity = [{"cust_id": 0}, {"cust_id": 1}] + online_request_features = ["cust_rating", "cust_cost"] + + def try_get_features(): + response = client.get_online_features( + entity_rows=online_request_entity, feature_refs=online_request_features + ) + return response, True + + online_features_actual = wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + online_features_expected = { + "cust_id": [0, 1], + "cust_rating": [0, 1], + "cust_cost": [0.5, 1.5], + } + + assert online_features_actual.to_dict() == online_features_expected + + +@pytest.mark.timeout(600) +@pytest.mark.run(order=18) +def test_basic_retrieve_feature_row_missing_fields(client, cust_trans_df): + feature_refs = ["daily_transactions", "total_transactions", "null_values"] + + # apply cust_trans_fs and ingest dataframe + client.set_project(PROJECT_NAME + 
"_basic_retrieve_missing_fields") + old_cust_trans_fs = FeatureSet.from_yaml(f"{DIR_PATH}/basic/cust_trans_fs.yaml") + client.apply(old_cust_trans_fs) + client.ingest(old_cust_trans_fs, cust_trans_df) + + # update cust_trans_fs with one additional feature. + # feature rows ingested before the feature set update will be missing a field. + new_cust_trans_fs = client.get_feature_set(name="customer_transactions") + new_cust_trans_fs.add(Feature("n_trips", ValueType.INT64)) + client.apply(new_cust_trans_fs) + # sleep to ensure feature set update is propagated + time.sleep(15) + + # attempt to retrieve features from feature rows with missing fields + def try_get_features(): + response = client.get_online_features( + entity_rows=[ + {"customer_id": np.int64(cust_trans_df.iloc[0]["customer_id"])} + ], + feature_refs=feature_refs + ["n_trips"], + ) # type: GetOnlineFeaturesResponse + # check if the ingested fields can be correctly retrieved. + is_ok = all( + [ + check_online_response(ref, cust_trans_df, response) + for ref in feature_refs + ] + ) + # should return null_value status for missing field n_trips + is_missing_ok = ( + response.field_values[0].statuses["n_trips"] + == GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE + ) + return response, is_ok and is_missing_ok + + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + +@pytest.mark.timeout(600) +@pytest.mark.run(order=19) +def test_basic_retrieve_feature_row_extra_fields(client, cust_trans_df): + feature_refs = ["daily_transactions", "total_transactions"] + # apply cust_trans_fs and ingest dataframe + client.set_project(PROJECT_NAME + "_basic_retrieve_missing_fields") + old_cust_trans_fs = FeatureSet.from_yaml(f"{DIR_PATH}/basic/cust_trans_fs.yaml") + client.apply(old_cust_trans_fs) + client.ingest(old_cust_trans_fs, cust_trans_df) + + # update cust_trans_fs with the null_values feature dropped. 
+ # feature rows ingested before the feature set update will have an extra field. + new_cust_trans_fs = client.get_feature_set(name="customer_transactions") + new_cust_trans_fs.drop("null_values") + client.apply(new_cust_trans_fs) + # sleep to ensure feature set update is propagated + time.sleep(15) + + # attempt to retrieve features from feature rows with extra fields + def try_get_features(): + response = client.get_online_features( + entity_rows=[ + {"customer_id": np.int64(cust_trans_df.iloc[0]["customer_id"])} + ], + feature_refs=feature_refs, + ) # type: GetOnlineFeaturesResponse + # check if the non dropped fields can be correctly retrieved. + is_ok = all( + [ + check_online_response(ref, cust_trans_df, response) + for ref in feature_refs + ] + ) + return response, is_ok + + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + @pytest.fixture(scope="module") def all_types_dataframe(): return pd.DataFrame( @@ -614,6 +821,8 @@ def all_types_dataframe(): @pytest.mark.timeout(45) @pytest.mark.run(order=20) def test_all_types_register_feature_set_success(client): + client.set_project(PROJECT_NAME) + all_types_fs_expected = FeatureSet( name="all_types", entities=[Entity(name="user_id", dtype=ValueType.INT64)], @@ -723,9 +932,11 @@ def try_get_features(): @pytest.mark.timeout(300) -@pytest.mark.run(order=29) +@pytest.mark.run(order=35) def test_all_types_ingest_jobs(client, all_types_dataframe): # list ingestion jobs given featureset + client.set_project(PROJECT_NAME) + all_types_fs = client.get_feature_set(name="all_types") ingest_jobs = client.list_ingest_jobs( feature_set_ref=FeatureSetRef.from_feature_set(all_types_fs) @@ -783,7 +994,7 @@ def large_volume_dataframe(): @pytest.mark.timeout(45) -@pytest.mark.run(order=30) +@pytest.mark.run(order=40) def test_large_volume_register_feature_set_success(client): cust_trans_fs_expected = FeatureSet.from_yaml( 
f"{DIR_PATH}/large_volume/cust_trans_large_fs.yaml" @@ -809,7 +1020,7 @@ def test_large_volume_register_feature_set_success(client): @pytest.mark.timeout(300) -@pytest.mark.run(order=31) +@pytest.mark.run(order=41) def test_large_volume_ingest_success(client, large_volume_dataframe): # Get large volume feature set cust_trans_fs = client.get_feature_set(name="customer_transactions_large") @@ -819,7 +1030,7 @@ def test_large_volume_ingest_success(client, large_volume_dataframe): @pytest.mark.timeout(90) -@pytest.mark.run(order=32) +@pytest.mark.run(order=42) def test_large_volume_retrieve_online_success(client, large_volume_dataframe): # Poll serving for feature values until the correct values are returned feature_refs = [ @@ -905,7 +1116,7 @@ def all_types_parquet_file(): @pytest.mark.timeout(300) -@pytest.mark.run(order=40) +@pytest.mark.run(order=50) def test_all_types_parquet_register_feature_set_success(client): # Load feature set from file all_types_parquet_expected = FeatureSet.from_yaml( @@ -933,7 +1144,7 @@ def test_all_types_parquet_register_feature_set_success(client): @pytest.mark.timeout(600) -@pytest.mark.run(order=41) +@pytest.mark.run(order=51) def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_file): # Get feature set all_types_fs = client.get_feature_set(name="all_types_parquet") @@ -943,7 +1154,7 @@ def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_ @pytest.mark.timeout(200) -@pytest.mark.run(order=50) +@pytest.mark.run(order=60) def test_list_entities_and_features(client): customer_entity = Entity("customer_id", ValueType.INT64) driver_entity = Entity("driver_id", ValueType.INT64) @@ -1018,7 +1229,7 @@ def test_list_entities_and_features(client): @pytest.mark.timeout(900) -@pytest.mark.run(order=60) +@pytest.mark.run(order=70) def test_sources_deduplicate_ingest_jobs(client): source = KafkaSource("localhost:9092", "feast-features") alt_source = KafkaSource("localhost:9092", "feast-data") 
@@ -1066,6 +1277,58 @@ def get_running_jobs(): time.sleep(1) +@pytest.mark.run(order=30) +def test_sink_writes_only_recent_rows(client): + client.set_project("default") + + feature_refs = ["driver:rating", "driver:cost"] + + later_df = basic_dataframe( + entities=["driver_id"], + features=["rating", "cost"], + ingest_time=datetime.utcnow(), + n_size=5, + ) + + earlier_df = basic_dataframe( + entities=["driver_id"], + features=["rating", "cost"], + ingest_time=datetime.utcnow() - timedelta(minutes=5), + n_size=5, + ) + + def try_get_features(): + response = client.get_online_features( + entity_rows=[ + GetOnlineFeaturesRequest.EntityRow( + fields={"driver_id": Value(int64_val=later_df.iloc[0]["driver_id"])} + ) + ], + feature_refs=feature_refs, + ) # type: GetOnlineFeaturesResponse + is_ok = all( + [check_online_response(ref, later_df, response) for ref in feature_refs] + ) + return response, is_ok + + # test compaction within batch + client.ingest("driver", pd.concat([earlier_df, later_df])) + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + # test read before write + client.ingest("driver", earlier_df) + time.sleep(10) + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + # TODO: rewrite these using python SDK once the labels are implemented there class TestsBasedOnGrpc: GRPC_CONNECTION_TIMEOUT = 3 @@ -1091,22 +1354,33 @@ def core_service_stub(self, core_url): core_service_stub = CoreServiceStub(core_channel) return core_service_stub - def apply_feature_set(self, core_service_stub, feature_set_proto): + @pytest.fixture(scope="module") + def auth_meta_data(self, enable_auth): + if not enable_auth: + return None + else: + metadata = {CONFIG_AUTH_PROVIDER: AUTH_PROVIDER} + metadata_plugin = get_auth_metadata_plugin(config=Config(metadata)) + return metadata_plugin.get_signed_meta() + + def 
apply_feature_set(self, core_service_stub, feature_set_proto, auth_meta_data): try: apply_fs_response = core_service_stub.ApplyFeatureSet( CoreService_pb2.ApplyFeatureSetRequest(feature_set=feature_set_proto), timeout=self.GRPC_CONNECTION_TIMEOUT, + metadata=auth_meta_data, ) # type: ApplyFeatureSetResponse except grpc.RpcError as e: raise grpc.RpcError(e.details()) return apply_fs_response.feature_set - def get_feature_set(self, core_service_stub, name, project): + def get_feature_set(self, core_service_stub, name, project, auth_meta_data): try: get_feature_set_response = core_service_stub.GetFeatureSet( CoreService_pb2.GetFeatureSetRequest( project=project, name=name.strip(), - ) + ), + metadata=auth_meta_data, ) # type: GetFeatureSetResponse except grpc.RpcError as e: raise grpc.RpcError(e.details()) @@ -1114,17 +1388,17 @@ def get_feature_set(self, core_service_stub, name, project): @pytest.mark.timeout(45) @pytest.mark.run(order=51) - def test_register_feature_set_with_labels(self, core_service_stub): + def test_register_feature_set_with_labels(self, core_service_stub, auth_meta_data): feature_set_name = "test_feature_set_labels" feature_set_proto = FeatureSet( name=feature_set_name, project=PROJECT_NAME, labels={self.LABEL_KEY: self.LABEL_VALUE}, ).to_proto() - self.apply_feature_set(core_service_stub, feature_set_proto) + self.apply_feature_set(core_service_stub, feature_set_proto, auth_meta_data) retrieved_feature_set = self.get_feature_set( - core_service_stub, feature_set_name, PROJECT_NAME + core_service_stub, feature_set_name, PROJECT_NAME, auth_meta_data ) assert self.LABEL_KEY in retrieved_feature_set.spec.labels @@ -1132,7 +1406,7 @@ def test_register_feature_set_with_labels(self, core_service_stub): @pytest.mark.timeout(45) @pytest.mark.run(order=52) - def test_register_feature_with_labels(self, core_service_stub): + def test_register_feature_with_labels(self, core_service_stub, auth_meta_data): feature_set_name = "test_feature_labels" 
feature_set_proto = FeatureSet( name=feature_set_name, @@ -1145,10 +1419,10 @@ def test_register_feature_with_labels(self, core_service_stub): ) ], ).to_proto() - self.apply_feature_set(core_service_stub, feature_set_proto) + self.apply_feature_set(core_service_stub, feature_set_proto, auth_meta_data) retrieved_feature_set = self.get_feature_set( - core_service_stub, feature_set_name, PROJECT_NAME + core_service_stub, feature_set_name, PROJECT_NAME, auth_meta_data ) retrieved_feature = retrieved_feature_set.spec.features[0]