diff --git a/go.mod b/go.mod
index 1be2396f6e..d5917cb33e 100644
--- a/go.mod
+++ b/go.mod
@@ -25,7 +25,8 @@ require (
golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect
golang.org/x/net v0.0.0-20200513185701-a91f0712d120
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 // indirect
- golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e // indirect
+ golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc // indirect
+ google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587 // indirect
google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.24.0 // indirect
gopkg.in/russross/blackfriday.v2 v2.0.0 // indirect
diff --git a/go.sum b/go.sum
index 1b2aa8bf63..90a517da7e 100644
--- a/go.sum
+++ b/go.sum
@@ -462,8 +462,10 @@ golang.org/x/tools v0.0.0-20200504022951-6b6965ac5dd1 h1:C8rdnd6KieI73Z2Av0sS0t4
golang.org/x/tools v0.0.0-20200504022951-6b6965ac5dd1/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200515220128-d3bf790afa53 h1:vmsb6v0zUdmUlXfwKaYrHPPRCV0lHq/IwNIf0ASGjyQ=
golang.org/x/tools v0.0.0-20200515220128-d3bf790afa53/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
-golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e h1:jTL1CJ2kmavapMVdBKy6oVrhBHByRCMfykS45+lEFQk=
-golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200519205726-57a9e4404bf7 h1:nm4zDh9WvH4jiuUpMY5RUsvOwrtTVVAsUaCdLW71hfY=
+golang.org/x/tools v0.0.0-20200519205726-57a9e4404bf7/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc h1:6m2YO+AmBApbUOmhsghW+IfRyZOY4My4UYvQQrEpHfY=
+golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/sdk/go/client.go b/sdk/go/client.go
index ef8ec8b8cf..1f0cbaa3b9 100644
--- a/sdk/go/client.go
+++ b/sdk/go/client.go
@@ -61,7 +61,7 @@ func (fc *GrpcClient) GetOnlineFeatures(ctx context.Context, req *OnlineFeatures
// strip projects from to projects
for _, fieldValue := range resp.GetFieldValues() {
- stripFields := make(map[string]*types.Value)
+ stripFields := make(map[string]*types.Value, len(fieldValue.Fields))
for refStr, value := range fieldValue.Fields {
_, isEntity := entityRefs[refStr]
if !isEntity { // is feature ref
diff --git a/sdk/go/client_test.go b/sdk/go/client_test.go
new file mode 100644
index 0000000000..1ad6629e9a
--- /dev/null
+++ b/sdk/go/client_test.go
@@ -0,0 +1,95 @@
+package feast
+
+import (
+ "context"
+ "testing"
+
+ "github.com/feast-dev/feast/sdk/go/mocks"
+ "github.com/feast-dev/feast/sdk/go/protos/feast/serving"
+ "github.com/feast-dev/feast/sdk/go/protos/feast/types"
+ "github.com/golang/mock/gomock"
+ "github.com/google/go-cmp/cmp"
+ "github.com/opentracing/opentracing-go"
+)
+
+func TestGetOnlineFeatures(t *testing.T) {
+ tt := []struct {
+ name string
+ req OnlineFeaturesRequest
+ recieve OnlineFeaturesResponse
+ want OnlineFeaturesResponse
+ wantErr bool
+ err error
+ }{
+ {
+ name: "Valid client Get Online Features call",
+ req: OnlineFeaturesRequest{
+ Features: []string{
+ "driver:rating",
+ "rating",
+ },
+ Entities: []Row{
+ {"driver_id": Int64Val(1)},
+ },
+ Project: "driver_project",
+ },
+ // check GetOnlineFeatures() should strip projects returned from serving
+ recieve: OnlineFeaturesResponse{
+ RawResponse: &serving.GetOnlineFeaturesResponse{
+ FieldValues: []*serving.GetOnlineFeaturesResponse_FieldValues{
+ {
+ Fields: map[string]*types.Value{
+ "driver_project/driver:rating": Int64Val(1),
+ "driver_project/rating": Int64Val(1),
+ },
+ },
+ },
+ },
+ },
+ want: OnlineFeaturesResponse{
+ RawResponse: &serving.GetOnlineFeaturesResponse{
+ FieldValues: []*serving.GetOnlineFeaturesResponse_FieldValues{
+ {
+ Fields: map[string]*types.Value{
+ "driver:rating": Int64Val(1),
+ "rating": Int64Val(1),
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+ // mock feast grpc client get online feature requestss
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ cli := mock_serving.NewMockServingServiceClient(ctrl)
+ ctx := context.Background()
+ _, traceCtx := opentracing.StartSpanFromContext(ctx, "get_online_features")
+ rawRequest, _ := tc.req.buildRequest()
+ resp := tc.recieve.RawResponse
+ cli.EXPECT().GetOnlineFeatures(traceCtx, rawRequest).Return(resp, nil).Times(1)
+
+ client := &GrpcClient{
+ cli: cli,
+ }
+ got, err := client.GetOnlineFeatures(ctx, &tc.req)
+
+ if err != nil && !tc.wantErr {
+ t.Errorf("error = %v, wantErr %v", err, tc.wantErr)
+ return
+ }
+ if tc.wantErr && err.Error() != tc.err.Error() {
+ t.Errorf("error = %v, expected err = %v", err, tc.err)
+ return
+ }
+ // TODO: compare directly once OnlineFeaturesResponse no longer embeds a rawResponse.
+ if !cmp.Equal(got.RawResponse.String(), tc.want.RawResponse.String()) {
+ t.Errorf("got: \n%v\nwant:\n%v", got.RawResponse.String(), tc.want.RawResponse.String())
+ }
+ })
+ }
+}
diff --git a/sdk/go/go.mod b/sdk/go/go.mod
index 06bd693616..656810354b 100644
--- a/sdk/go/go.mod
+++ b/sdk/go/go.mod
@@ -3,6 +3,7 @@ module github.com/feast-dev/feast/sdk/go
go 1.13
require (
+ github.com/golang/mock v1.4.3
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0
github.com/google/go-cmp v0.4.0
github.com/opentracing/opentracing-go v1.1.0
diff --git a/sdk/go/go.sum b/sdk/go/go.sum
index f08f302617..39e6ef923b 100644
--- a/sdk/go/go.sum
+++ b/sdk/go/go.sum
@@ -12,7 +12,10 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw=
+github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
@@ -63,6 +66,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd h1:r7DufRZuZbWB7j439YfAzP8RPDa9unLkpwQKUYbIMPI=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
@@ -71,6 +75,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135 h1:5Beo0mZN8dRzgrMMkDp0jc8YXQKx9DiJ2k1dkvGsn5A=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -100,3 +106,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
+rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
diff --git a/sdk/go/mocks/serving_mock.go b/sdk/go/mocks/serving_mock.go
new file mode 100644
index 0000000000..bea754c1bf
--- /dev/null
+++ b/sdk/go/mocks/serving_mock.go
@@ -0,0 +1,116 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/feast-dev/feast/sdk/go/protos/feast/serving (interfaces: ServingServiceClient)
+
+// Package mock_serving is a generated GoMock package.
+package mock_serving
+
+import (
+ context "context"
+ serving "github.com/feast-dev/feast/sdk/go/protos/feast/serving"
+ gomock "github.com/golang/mock/gomock"
+ grpc "google.golang.org/grpc"
+ reflect "reflect"
+)
+
+// MockServingServiceClient is a mock of ServingServiceClient interface
+type MockServingServiceClient struct {
+ ctrl *gomock.Controller
+ recorder *MockServingServiceClientMockRecorder
+}
+
+// MockServingServiceClientMockRecorder is the mock recorder for MockServingServiceClient
+type MockServingServiceClientMockRecorder struct {
+ mock *MockServingServiceClient
+}
+
+// NewMockServingServiceClient creates a new mock instance
+func NewMockServingServiceClient(ctrl *gomock.Controller) *MockServingServiceClient {
+ mock := &MockServingServiceClient{ctrl: ctrl}
+ mock.recorder = &MockServingServiceClientMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockServingServiceClient) EXPECT() *MockServingServiceClientMockRecorder {
+ return m.recorder
+}
+
+// GetBatchFeatures mocks base method
+func (m *MockServingServiceClient) GetBatchFeatures(arg0 context.Context, arg1 *serving.GetBatchFeaturesRequest, arg2 ...grpc.CallOption) (*serving.GetBatchFeaturesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetBatchFeatures", varargs...)
+ ret0, _ := ret[0].(*serving.GetBatchFeaturesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetBatchFeatures indicates an expected call of GetBatchFeatures
+func (mr *MockServingServiceClientMockRecorder) GetBatchFeatures(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatchFeatures", reflect.TypeOf((*MockServingServiceClient)(nil).GetBatchFeatures), varargs...)
+}
+
+// GetFeastServingInfo mocks base method
+func (m *MockServingServiceClient) GetFeastServingInfo(arg0 context.Context, arg1 *serving.GetFeastServingInfoRequest, arg2 ...grpc.CallOption) (*serving.GetFeastServingInfoResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetFeastServingInfo", varargs...)
+ ret0, _ := ret[0].(*serving.GetFeastServingInfoResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetFeastServingInfo indicates an expected call of GetFeastServingInfo
+func (mr *MockServingServiceClientMockRecorder) GetFeastServingInfo(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFeastServingInfo", reflect.TypeOf((*MockServingServiceClient)(nil).GetFeastServingInfo), varargs...)
+}
+
+// GetJob mocks base method
+func (m *MockServingServiceClient) GetJob(arg0 context.Context, arg1 *serving.GetJobRequest, arg2 ...grpc.CallOption) (*serving.GetJobResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetJob", varargs...)
+ ret0, _ := ret[0].(*serving.GetJobResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetJob indicates an expected call of GetJob
+func (mr *MockServingServiceClientMockRecorder) GetJob(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJob", reflect.TypeOf((*MockServingServiceClient)(nil).GetJob), varargs...)
+}
+
+// GetOnlineFeatures mocks base method
+func (m *MockServingServiceClient) GetOnlineFeatures(arg0 context.Context, arg1 *serving.GetOnlineFeaturesRequest, arg2 ...grpc.CallOption) (*serving.GetOnlineFeaturesResponse, error) {
+ m.ctrl.T.Helper()
+ varargs := []interface{}{arg0, arg1}
+ for _, a := range arg2 {
+ varargs = append(varargs, a)
+ }
+ ret := m.ctrl.Call(m, "GetOnlineFeatures", varargs...)
+ ret0, _ := ret[0].(*serving.GetOnlineFeaturesResponse)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetOnlineFeatures indicates an expected call of GetOnlineFeatures
+func (mr *MockServingServiceClientMockRecorder) GetOnlineFeatures(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ varargs := append([]interface{}{arg0, arg1}, arg2...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOnlineFeatures", reflect.TypeOf((*MockServingServiceClient)(nil).GetOnlineFeatures), varargs...)
+}
diff --git a/sdk/java/pom.xml b/sdk/java/pom.xml
index 75d82edcc0..7856afd47e 100644
--- a/sdk/java/pom.xml
+++ b/sdk/java/pom.xml
@@ -18,6 +18,7 @@
5.5.2
+ 2.28.2
@@ -40,6 +41,10 @@
io.grpc
grpc-stub
+
+ io.grpc
+ grpc-testing
+
com.google.protobuf
protobuf-java-util
@@ -55,7 +60,7 @@
slf4j-api
-
+
org.junit.jupiter
junit-jupiter-engine
@@ -80,7 +85,24 @@
3.6
compile
-
+
+ org.mockito
+ mockito-core
+ ${mockito.version}
+ test
+
+
+ org.mockito
+ mockito-inline
+ ${mockito.version}
+ test
+
+
+ org.junit.vintage
+ junit-vintage-engine
+ ${junit.version}
+ test
+
diff --git a/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java b/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java
new file mode 100644
index 0000000000..717792244e
--- /dev/null
+++ b/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java
@@ -0,0 +1,146 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.gojek.feast;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Mockito.mock;
+
+import com.google.protobuf.Timestamp;
+import feast.proto.serving.ServingAPIProto.FeatureReference;
+import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest;
+import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse;
+import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues;
+import feast.proto.serving.ServingServiceGrpc.ServingServiceImplBase;
+import feast.proto.types.ValueProto.Value;
+import io.grpc.ManagedChannel;
+import io.grpc.Status;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+
+public class FeastClientTest {
+ @Rule public GrpcCleanupRule grpcRule;
+ private ServingServiceImplBase servingMock =
+ mock(
+ ServingServiceImplBase.class,
+ delegatesTo(
+ new ServingServiceImplBase() {
+ @Override
+ public void getOnlineFeatures(
+ GetOnlineFeaturesRequest request,
+ StreamObserver responseObserver) {
+ if (!request.equals(FeastClientTest.getFakeRequest())) {
+ responseObserver.onError(Status.UNKNOWN.asRuntimeException());
+ }
+
+ responseObserver.onNext(FeastClientTest.getFakeResponse());
+ responseObserver.onCompleted();
+ }
+ }));
+ private FeastClient client;
+
+ @Before
+ public void setup() throws Exception {
+ this.grpcRule = new GrpcCleanupRule();
+ // setup fake serving service
+ String serverName = InProcessServerBuilder.generateName();
+ this.grpcRule.register(
+ InProcessServerBuilder.forName(serverName)
+ .directExecutor()
+ .addService(this.servingMock)
+ .build()
+ .start());
+
+ // setup test feast client target
+ ManagedChannel channel =
+ this.grpcRule.register(
+ InProcessChannelBuilder.forName(serverName).directExecutor().build());
+ this.client = new FeastClient(channel);
+ }
+
+ @Test
+ public void shouldGetOnlineFeatures() {
+ List rows =
+ this.client.getOnlineFeatures(
+ Arrays.asList("driver:name", "rating"),
+ Arrays.asList(
+ Row.create().set("driver_id", 1).setEntityTimestamp(Instant.ofEpochSecond(100))),
+ "driver_project");
+
+ assertEquals(
+ rows.get(0).getFields(),
+ new HashMap() {
+ {
+ put("driver_id", intValue(1));
+ put("driver:name", strValue("david"));
+ put("rating", intValue(3));
+ }
+ });
+ }
+
+ private static GetOnlineFeaturesRequest getFakeRequest() {
+ // setup mock serving service stub
+ return GetOnlineFeaturesRequest.newBuilder()
+ .addFeatures(
+ FeatureReference.newBuilder()
+ .setProject("driver_project")
+ .setFeatureSet("driver")
+ .setName("name")
+ .build())
+ .addFeatures(
+ FeatureReference.newBuilder().setProject("driver_project").setName("rating").build())
+ .addEntityRows(
+ EntityRow.newBuilder()
+ .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100))
+ .putFields("driver_id", intValue(1)))
+ .build();
+ }
+
+ private static GetOnlineFeaturesResponse getFakeResponse() {
+ return GetOnlineFeaturesResponse.newBuilder()
+ .addFieldValues(
+ FieldValues.newBuilder()
+ .putAllFields(
+ new HashMap() {
+ {
+ put("driver_id", intValue(1));
+ put("driver:name", strValue("david"));
+ put("rating", intValue(3));
+ }
+ })
+ .build())
+ .build();
+ }
+
+ private static Value strValue(String val) {
+ return Value.newBuilder().setStringVal(val).build();
+ }
+
+ private static Value intValue(int val) {
+ return Value.newBuilder().setInt32Val(val).build();
+ }
+}
diff --git a/sdk/python/feast/client.py b/sdk/python/feast/client.py
index 1c34213ad8..f10204c59d 100644
--- a/sdk/python/feast/client.py
+++ b/sdk/python/feast/client.py
@@ -658,14 +658,11 @@ def get_online_features(
# strip the project part the string feature references returned from serving
strip_fields = {}
for ref_str, value in field_value.fields.items():
- # find and ignore entities
- if ref_str in entity_refs:
- strip_fields[ref_str] = value
- else:
- strip_ref_str = repr(
+ if ref_str not in entity_refs:
+ ref_str = repr(
FeatureRef.from_str(ref_str, ignore_project=True)
)
- strip_fields[strip_ref_str] = value
+ strip_fields[ref_str] = value
strip_field_values.append(
GetOnlineFeaturesResponse.FieldValues(fields=strip_fields)
)
diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py
index 380557ce92..75d45a583f 100644
--- a/sdk/python/tests/test_client.py
+++ b/sdk/python/tests/test_client.py
@@ -47,11 +47,12 @@
from feast.core.Source_pb2 import KafkaSourceConfig, Source, SourceType
from feast.core.Store_pb2 import Store
from feast.entity import Entity
-from feast.feature_set import Feature, FeatureSet, FeatureSetRef
+from feast.feature import Feature
+from feast.feature_set import FeatureSet, FeatureSetRef
from feast.job import IngestJob
+from feast.serving.ServingService_pb2 import DataFormat, FeastServingType
+from feast.serving.ServingService_pb2 import FeatureReference as FeatureRefProto
from feast.serving.ServingService_pb2 import (
- DataFormat,
- FeastServingType,
GetBatchFeaturesResponse,
GetFeastServingInfoResponse,
GetJobResponse,
@@ -210,42 +211,50 @@ def test_get_online_features(self, mocked_client, mocker):
def int_val(x):
return ValueProto.Value(int64_val=x)
- # serving can return feature references with projects,
- # get_online_features() should strip the project part.
- field_values = GetOnlineFeaturesResponse.FieldValues(
- fields={
- "driver_project/driver:driver_id": int_val(1),
- "driver_project/driver_id": int_val(9),
- }
+ request = GetOnlineFeaturesRequest()
+ request.features.extend(
+ [
+ FeatureRefProto(
+ project="driver_project", feature_set="driver", name="age"
+ ),
+ FeatureRefProto(project="driver_project", name="rating"),
+ ]
)
-
- response = GetOnlineFeaturesResponse()
- entity_rows = []
+ recieve_response = GetOnlineFeaturesResponse()
for row_number in range(1, ROW_COUNT + 1):
- response.field_values.append(field_values)
- entity_rows.append(
+ request.entity_rows.append(
GetOnlineFeaturesRequest.EntityRow(
- fields={"customer_id": int_val(row_number)}
+ fields={"driver_id": int_val(row_number)}
)
+ ),
+ field_values = GetOnlineFeaturesResponse.FieldValues(
+ fields={
+ "driver_id": int_val(row_number),
+ "driver_project/driver:age": int_val(1),
+ "driver_project/rating": int_val(9),
+ }
)
+ recieve_response.field_values.append(field_values)
mocker.patch.object(
mocked_client._serving_service_stub,
"GetOnlineFeatures",
- return_value=response,
+ return_value=recieve_response,
)
-
- # NOTE: Feast Serving does not allow for feature references
- # that specify the same feature in the same request
- response = mocked_client.get_online_features(
- entity_rows=entity_rows,
- feature_refs=["driver:driver_id", "driver_id"],
+ got_response = mocked_client.get_online_features(
+ entity_rows=request.entity_rows,
+ feature_refs=["driver:age", "rating"],
project="driver_project",
) # type: GetOnlineFeaturesResponse
+ mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with(
+ request
+ )
+ got_fields = got_response.field_values[0].fields
assert (
- response.field_values[0].fields["driver:driver_id"].int64_val == 1
- and response.field_values[0].fields["driver_id"].int64_val == 9
+ got_fields["driver_id"] == int_val(1)
+ and got_fields["driver:age"] == int_val(1)
+ and got_fields["rating"] == int_val(9)
)
@pytest.mark.parametrize(
@@ -635,7 +644,7 @@ def test_feature_set_ingest_throws_exception_if_kafka_down(
)
with pytest.raises(exception):
- test_client.ingest("driver-feature-set", dataframe)
+ test_client.ingest("driver-feature-set", dataframe, timeout=1)
@pytest.mark.parametrize(
"dataframe,exception,test_client",