diff --git a/go.mod b/go.mod index 1be2396f6e..d5917cb33e 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,8 @@ require ( golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect golang.org/x/net v0.0.0-20200513185701-a91f0712d120 golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 // indirect - golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e // indirect + golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc // indirect + google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587 // indirect google.golang.org/grpc v1.29.1 google.golang.org/protobuf v1.24.0 // indirect gopkg.in/russross/blackfriday.v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index 1b2aa8bf63..90a517da7e 100644 --- a/go.sum +++ b/go.sum @@ -462,8 +462,10 @@ golang.org/x/tools v0.0.0-20200504022951-6b6965ac5dd1 h1:C8rdnd6KieI73Z2Av0sS0t4 golang.org/x/tools v0.0.0-20200504022951-6b6965ac5dd1/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515220128-d3bf790afa53 h1:vmsb6v0zUdmUlXfwKaYrHPPRCV0lHq/IwNIf0ASGjyQ= golang.org/x/tools v0.0.0-20200515220128-d3bf790afa53/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e h1:jTL1CJ2kmavapMVdBKy6oVrhBHByRCMfykS45+lEFQk= -golang.org/x/tools v0.0.0-20200528185414-6be401e3f76e/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200519205726-57a9e4404bf7 h1:nm4zDh9WvH4jiuUpMY5RUsvOwrtTVVAsUaCdLW71hfY= +golang.org/x/tools v0.0.0-20200519205726-57a9e4404bf7/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc h1:6m2YO+AmBApbUOmhsghW+IfRyZOY4My4UYvQQrEpHfY= +golang.org/x/tools v0.0.0-20200521211927-2b542361a4fc/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/sdk/go/client.go b/sdk/go/client.go index ef8ec8b8cf..1f0cbaa3b9 100644 --- a/sdk/go/client.go +++ b/sdk/go/client.go @@ -61,7 +61,7 @@ func (fc *GrpcClient) GetOnlineFeatures(ctx context.Context, req *OnlineFeatures // strip projects from to projects for _, fieldValue := range resp.GetFieldValues() { - stripFields := make(map[string]*types.Value) + stripFields := make(map[string]*types.Value, len(fieldValue.Fields)) for refStr, value := range fieldValue.Fields { _, isEntity := entityRefs[refStr] if !isEntity { // is feature ref diff --git a/sdk/go/client_test.go b/sdk/go/client_test.go new file mode 100644 index 0000000000..1ad6629e9a --- /dev/null +++ b/sdk/go/client_test.go @@ -0,0 +1,95 @@ +package feast + +import ( + "context" + "testing" + + "github.com/feast-dev/feast/sdk/go/mocks" + "github.com/feast-dev/feast/sdk/go/protos/feast/serving" + "github.com/feast-dev/feast/sdk/go/protos/feast/types" + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + "github.com/opentracing/opentracing-go" +) + +func TestGetOnlineFeatures(t *testing.T) { + tt := []struct { + name string + req OnlineFeaturesRequest + recieve OnlineFeaturesResponse + want OnlineFeaturesResponse + wantErr bool + err error + }{ + { + name: "Valid client Get Online Features call", + req: OnlineFeaturesRequest{ + Features: []string{ + "driver:rating", + "rating", + }, + Entities: []Row{ + {"driver_id": Int64Val(1)}, + }, + Project: "driver_project", + }, + // check GetOnlineFeatures() should strip projects returned from serving + recieve: OnlineFeaturesResponse{ + RawResponse: &serving.GetOnlineFeaturesResponse{ + FieldValues: []*serving.GetOnlineFeaturesResponse_FieldValues{ + { + Fields: map[string]*types.Value{ + "driver_project/driver:rating": Int64Val(1), + "driver_project/rating": Int64Val(1), + }, + }, + }, + }, + }, + want: OnlineFeaturesResponse{ + RawResponse: &serving.GetOnlineFeaturesResponse{ + FieldValues: []*serving.GetOnlineFeaturesResponse_FieldValues{ + { + Fields: map[string]*types.Value{ + "driver:rating": Int64Val(1), + "rating": Int64Val(1), + }, + }, + }, + }, + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + // mock feast grpc client get online feature requestss + ctrl := gomock.NewController(t) + defer ctrl.Finish() + cli := mock_serving.NewMockServingServiceClient(ctrl) + ctx := context.Background() + _, traceCtx := opentracing.StartSpanFromContext(ctx, "get_online_features") + rawRequest, _ := tc.req.buildRequest() + resp := tc.recieve.RawResponse + cli.EXPECT().GetOnlineFeatures(traceCtx, rawRequest).Return(resp, nil).Times(1) + + client := &GrpcClient{ + cli: cli, + } + got, err := client.GetOnlineFeatures(ctx, &tc.req) + + if err != nil && !tc.wantErr { + t.Errorf("error = %v, wantErr %v", err, tc.wantErr) + return + } + if tc.wantErr && err.Error() != tc.err.Error() { + t.Errorf("error = %v, expected err = %v", err, tc.err) + return + } + // TODO: compare directly once OnlineFeaturesResponse no longer embeds a rawResponse. + if !cmp.Equal(got.RawResponse.String(), tc.want.RawResponse.String()) { + t.Errorf("got: \n%v\nwant:\n%v", got.RawResponse.String(), tc.want.RawResponse.String()) + } + }) + } +} diff --git a/sdk/go/go.mod b/sdk/go/go.mod index 06bd693616..656810354b 100644 --- a/sdk/go/go.mod +++ b/sdk/go/go.mod @@ -3,6 +3,7 @@ module github.com/feast-dev/feast/sdk/go go 1.13 require ( + github.com/golang/mock v1.4.3 github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0 github.com/google/go-cmp v0.4.0 github.com/opentracing/opentracing-go v1.1.0 diff --git a/sdk/go/go.sum b/sdk/go/go.sum index f08f302617..39e6ef923b 100644 --- a/sdk/go/go.sum +++ b/sdk/go/go.sum @@ -12,7 +12,10 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= @@ -63,6 +66,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd h1:r7DufRZuZbWB7j439YfAzP8RPDa9unLkpwQKUYbIMPI= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -71,6 +75,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135 h1:5Beo0mZN8dRzgrMMkDp0jc8YXQKx9DiJ2k1dkvGsn5A= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -100,3 +106,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/sdk/go/mocks/serving_mock.go b/sdk/go/mocks/serving_mock.go new file mode 100644 index 0000000000..bea754c1bf --- /dev/null +++ b/sdk/go/mocks/serving_mock.go @@ -0,0 +1,116 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/feast-dev/feast/sdk/go/protos/feast/serving (interfaces: ServingServiceClient) + +// Package mock_serving is a generated GoMock package. +package mock_serving + +import ( + context "context" + serving "github.com/feast-dev/feast/sdk/go/protos/feast/serving" + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" + reflect "reflect" +) + +// MockServingServiceClient is a mock of ServingServiceClient interface +type MockServingServiceClient struct { + ctrl *gomock.Controller + recorder *MockServingServiceClientMockRecorder +} + +// MockServingServiceClientMockRecorder is the mock recorder for MockServingServiceClient +type MockServingServiceClientMockRecorder struct { + mock *MockServingServiceClient +} + +// NewMockServingServiceClient creates a new mock instance +func NewMockServingServiceClient(ctrl *gomock.Controller) *MockServingServiceClient { + mock := &MockServingServiceClient{ctrl: ctrl} + mock.recorder = &MockServingServiceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockServingServiceClient) EXPECT() *MockServingServiceClientMockRecorder { + return m.recorder +} + +// GetBatchFeatures mocks base method +func (m *MockServingServiceClient) GetBatchFeatures(arg0 context.Context, arg1 *serving.GetBatchFeaturesRequest, arg2 ...grpc.CallOption) (*serving.GetBatchFeaturesResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetBatchFeatures", varargs...) + ret0, _ := ret[0].(*serving.GetBatchFeaturesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBatchFeatures indicates an expected call of GetBatchFeatures +func (mr *MockServingServiceClientMockRecorder) GetBatchFeatures(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatchFeatures", reflect.TypeOf((*MockServingServiceClient)(nil).GetBatchFeatures), varargs...) +} + +// GetFeastServingInfo mocks base method +func (m *MockServingServiceClient) GetFeastServingInfo(arg0 context.Context, arg1 *serving.GetFeastServingInfoRequest, arg2 ...grpc.CallOption) (*serving.GetFeastServingInfoResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetFeastServingInfo", varargs...) + ret0, _ := ret[0].(*serving.GetFeastServingInfoResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetFeastServingInfo indicates an expected call of GetFeastServingInfo +func (mr *MockServingServiceClientMockRecorder) GetFeastServingInfo(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFeastServingInfo", reflect.TypeOf((*MockServingServiceClient)(nil).GetFeastServingInfo), varargs...) +} + +// GetJob mocks base method +func (m *MockServingServiceClient) GetJob(arg0 context.Context, arg1 *serving.GetJobRequest, arg2 ...grpc.CallOption) (*serving.GetJobResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetJob", varargs...) + ret0, _ := ret[0].(*serving.GetJobResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJob indicates an expected call of GetJob +func (mr *MockServingServiceClientMockRecorder) GetJob(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJob", reflect.TypeOf((*MockServingServiceClient)(nil).GetJob), varargs...) +} + +// GetOnlineFeatures mocks base method +func (m *MockServingServiceClient) GetOnlineFeatures(arg0 context.Context, arg1 *serving.GetOnlineFeaturesRequest, arg2 ...grpc.CallOption) (*serving.GetOnlineFeaturesResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetOnlineFeatures", varargs...) + ret0, _ := ret[0].(*serving.GetOnlineFeaturesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOnlineFeatures indicates an expected call of GetOnlineFeatures +func (mr *MockServingServiceClientMockRecorder) GetOnlineFeatures(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOnlineFeatures", reflect.TypeOf((*MockServingServiceClient)(nil).GetOnlineFeatures), varargs...) +} diff --git a/sdk/java/pom.xml b/sdk/java/pom.xml index 75d82edcc0..7856afd47e 100644 --- a/sdk/java/pom.xml +++ b/sdk/java/pom.xml @@ -18,6 +18,7 @@ 5.5.2 + 2.28.2 @@ -40,6 +41,10 @@ io.grpc grpc-stub + + io.grpc + grpc-testing + com.google.protobuf protobuf-java-util @@ -55,7 +60,7 @@ slf4j-api - + org.junit.jupiter junit-jupiter-engine @@ -80,7 +85,24 @@ 3.6 compile - + + org.mockito + mockito-core + ${mockito.version} + test + + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.junit.vintage + junit-vintage-engine + ${junit.version} + test + diff --git a/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java b/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java new file mode 100644 index 0000000000..717792244e --- /dev/null +++ b/sdk/java/src/test/java/com/gojek/feast/FeastClientTest.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.gojek.feast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Mockito.mock; + +import com.google.protobuf.Timestamp; +import feast.proto.serving.ServingAPIProto.FeatureReference; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse; +import feast.proto.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues; +import feast.proto.serving.ServingServiceGrpc.ServingServiceImplBase; +import feast.proto.types.ValueProto.Value; +import io.grpc.ManagedChannel; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +public class FeastClientTest { + @Rule public GrpcCleanupRule grpcRule; + private ServingServiceImplBase servingMock = + mock( + ServingServiceImplBase.class, + delegatesTo( + new ServingServiceImplBase() { + @Override + public void getOnlineFeatures( + GetOnlineFeaturesRequest request, + StreamObserver responseObserver) { + if (!request.equals(FeastClientTest.getFakeRequest())) { + responseObserver.onError(Status.UNKNOWN.asRuntimeException()); + } + + responseObserver.onNext(FeastClientTest.getFakeResponse()); + responseObserver.onCompleted(); + } + })); + private FeastClient client; + + @Before + public void setup() throws Exception { + this.grpcRule = new GrpcCleanupRule(); + // setup fake serving service + String serverName = InProcessServerBuilder.generateName(); + this.grpcRule.register( + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(this.servingMock) + .build() + .start()); + + // setup test feast client target + ManagedChannel channel = + this.grpcRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + this.client = new FeastClient(channel); + } + + @Test + public void shouldGetOnlineFeatures() { + List rows = + this.client.getOnlineFeatures( + Arrays.asList("driver:name", "rating"), + Arrays.asList( + Row.create().set("driver_id", 1).setEntityTimestamp(Instant.ofEpochSecond(100))), + "driver_project"); + + assertEquals( + rows.get(0).getFields(), + new HashMap() { + { + put("driver_id", intValue(1)); + put("driver:name", strValue("david")); + put("rating", intValue(3)); + } + }); + } + + private static GetOnlineFeaturesRequest getFakeRequest() { + // setup mock serving service stub + return GetOnlineFeaturesRequest.newBuilder() + .addFeatures( + FeatureReference.newBuilder() + .setProject("driver_project") + .setFeatureSet("driver") + .setName("name") + .build()) + .addFeatures( + FeatureReference.newBuilder().setProject("driver_project").setName("rating").build()) + .addEntityRows( + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields("driver_id", intValue(1))) + .build(); + } + + private static GetOnlineFeaturesResponse getFakeResponse() { + return GetOnlineFeaturesResponse.newBuilder() + .addFieldValues( + FieldValues.newBuilder() + .putAllFields( + new HashMap() { + { + put("driver_id", intValue(1)); + put("driver:name", strValue("david")); + put("rating", intValue(3)); + } + }) + .build()) + .build(); + } + + private static Value strValue(String val) { + return Value.newBuilder().setStringVal(val).build(); + } + + private static Value intValue(int val) { + return Value.newBuilder().setInt32Val(val).build(); + } +} diff --git a/sdk/python/feast/client.py b/sdk/python/feast/client.py index 1c34213ad8..f10204c59d 100644 --- a/sdk/python/feast/client.py +++ b/sdk/python/feast/client.py @@ -658,14 +658,11 @@ def get_online_features( # strip the project part the string feature references returned from serving strip_fields = {} for ref_str, value in field_value.fields.items(): - # find and ignore entities - if ref_str in entity_refs: - strip_fields[ref_str] = value - else: - strip_ref_str = repr( + if ref_str not in entity_refs: + ref_str = repr( FeatureRef.from_str(ref_str, ignore_project=True) ) - strip_fields[strip_ref_str] = value + strip_fields[ref_str] = value strip_field_values.append( GetOnlineFeaturesResponse.FieldValues(fields=strip_fields) ) diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index 380557ce92..75d45a583f 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -47,11 +47,12 @@ from feast.core.Source_pb2 import KafkaSourceConfig, Source, SourceType from feast.core.Store_pb2 import Store from feast.entity import Entity -from feast.feature_set import Feature, FeatureSet, FeatureSetRef +from feast.feature import Feature +from feast.feature_set import FeatureSet, FeatureSetRef from feast.job import IngestJob +from feast.serving.ServingService_pb2 import DataFormat, FeastServingType +from feast.serving.ServingService_pb2 import FeatureReference as FeatureRefProto from feast.serving.ServingService_pb2 import ( - DataFormat, - FeastServingType, GetBatchFeaturesResponse, GetFeastServingInfoResponse, GetJobResponse, @@ -210,42 +211,50 @@ def test_get_online_features(self, mocked_client, mocker): def int_val(x): return ValueProto.Value(int64_val=x) - # serving can return feature references with projects, - # get_online_features() should strip the project part. - field_values = GetOnlineFeaturesResponse.FieldValues( - fields={ - "driver_project/driver:driver_id": int_val(1), - "driver_project/driver_id": int_val(9), - } + request = GetOnlineFeaturesRequest() + request.features.extend( + [ + FeatureRefProto( + project="driver_project", feature_set="driver", name="age" + ), + FeatureRefProto(project="driver_project", name="rating"), + ] ) - - response = GetOnlineFeaturesResponse() - entity_rows = [] + recieve_response = GetOnlineFeaturesResponse() for row_number in range(1, ROW_COUNT + 1): - response.field_values.append(field_values) - entity_rows.append( + request.entity_rows.append( GetOnlineFeaturesRequest.EntityRow( - fields={"customer_id": int_val(row_number)} + fields={"driver_id": int_val(row_number)} ) + ), + field_values = GetOnlineFeaturesResponse.FieldValues( + fields={ + "driver_id": int_val(row_number), + "driver_project/driver:age": int_val(1), + "driver_project/rating": int_val(9), + } ) + recieve_response.field_values.append(field_values) mocker.patch.object( mocked_client._serving_service_stub, "GetOnlineFeatures", - return_value=response, + return_value=recieve_response, ) - - # NOTE: Feast Serving does not allow for feature references - # that specify the same feature in the same request - response = mocked_client.get_online_features( - entity_rows=entity_rows, - feature_refs=["driver:driver_id", "driver_id"], + got_response = mocked_client.get_online_features( + entity_rows=request.entity_rows, + feature_refs=["driver:age", "rating"], project="driver_project", ) # type: GetOnlineFeaturesResponse + mocked_client._serving_service_stub.GetOnlineFeatures.assert_called_with( + request + ) + got_fields = got_response.field_values[0].fields assert ( - response.field_values[0].fields["driver:driver_id"].int64_val == 1 - and response.field_values[0].fields["driver_id"].int64_val == 9 + got_fields["driver_id"] == int_val(1) + and got_fields["driver:age"] == int_val(1) + and got_fields["rating"] == int_val(9) ) @pytest.mark.parametrize( @@ -635,7 +644,7 @@ def test_feature_set_ingest_throws_exception_if_kafka_down( ) with pytest.raises(exception): - test_client.ingest("driver-feature-set", dataframe) + test_client.ingest("driver-feature-set", dataframe, timeout=1) @pytest.mark.parametrize( "dataframe,exception,test_client",