Skip to content

Commit

Permalink
Improve core layer testing (kubeflow#85)
Browse files Browse the repository at this point in the history
* Improve core layer testing

* Treat ids as string on service layer

* Moved testutils inside internal package

* Adapt test to name prefix implementation
  • Loading branch information
lampajr committed Oct 30, 2023
1 parent 270521d commit a309537
Show file tree
Hide file tree
Showing 7 changed files with 1,384 additions and 196 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ test: gen
test-nocache: gen
go test ./internal/... -count=1

.PHONY: test-cover
test-cover: gen
go test ./internal/... -cover -count=1

.PHONY: run/migrate
run/migrate: gen
go run main.go migrate --logtostderr=true -m config/metadata-library
Expand Down
18 changes: 9 additions & 9 deletions internal/core/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ type ModelRegistryApi interface {
// approach used by MLMD gRPC api. If Id is provided update the entity otherwise create a new one.
UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error)

GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error)
GetRegisteredModelById(id string) (*openapi.RegisteredModel, error)
GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error)
GetRegisteredModels(listOptions ListOptions) (*openapi.RegisteredModelList, error)

// MODEL VERSION

// Create a new Model Version
// or update a Model Version associated to a specific RegisteredModel identified by parentResourceId parameter
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error)
UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error)

GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error)
GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error)
GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error)
GetModelVersionById(id string) (*openapi.ModelVersion, error)
GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error)
GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error)

// MODEL ARTIFACT

// Create a new Artifact
// or update an Artifact associated to a specific ModelVersion identified by parentResourceId parameter
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error)
UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error)

GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error)
GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error)
GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error)
GetModelArtifactById(id string) (*openapi.ModelArtifact, error)
GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error)
GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error)
}
144 changes: 101 additions & 43 deletions internal/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log"
"strconv"

"github.com/opendatahub-io/model-registry/internal/core/mapper"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
Expand Down Expand Up @@ -80,7 +79,11 @@ func NewModelRegistryService(cc grpc.ClientConnInterface) (ModelRegistryApi, err
// REGISTERED MODELS

func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
log.Printf("Creating or updating registered model for %s", *registeredModel.Name)
if registeredModel.Id == nil {
log.Printf("Creating registered model for %s", *registeredModel.Name)
} else {
log.Printf("Updating registered model %s for %s", *registeredModel.Id, *registeredModel.Name)
}

modelCtx, err := serv.mapper.MapFromRegisteredModel(registeredModel)
if err != nil {
Expand All @@ -96,27 +99,32 @@ func (serv *modelRegistryService) UpsertRegisteredModel(registeredModel *openapi
return nil, err
}

modelId := &modelCtxResp.ContextIds[0]
model, err := serv.GetRegisteredModelById((*BaseResourceId)(modelId))
idAsString := mapper.IdToString(modelCtxResp.ContextIds[0])
model, err := serv.GetRegisteredModelById(*idAsString)
if err != nil {
return nil, err
}

return model, nil
}

func (serv *modelRegistryService) GetRegisteredModelById(id *BaseResourceId) (*openapi.RegisteredModel, error) {
log.Printf("Getting registered model %d", *id)
func (serv *modelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
log.Printf("Getting registered model %s", id)

idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}

getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)},
ContextIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
}

if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for id %d", *id)
return nil, fmt.Errorf("multiple registered models found for id %s", id)
}

regModel, err := serv.mapper.MapToRegisteredModel(getByIdResp.Contexts[0])
Expand Down Expand Up @@ -191,10 +199,20 @@ func (serv *modelRegistryService) GetRegisteredModels(listOptions ListOptions) (

// MODEL VERSIONS

func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *BaseResourceId) (*openapi.ModelVersion, error) {
registeredModel, err := serv.GetRegisteredModelById(parentResourceId)
func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, parentResourceId *string) (*openapi.ModelVersion, error) {
if modelVersion.Id == nil {
log.Printf("Creating model version")
} else {
log.Printf("Updating model version %s", *modelVersion.Id)
}

if parentResourceId == nil {
return nil, fmt.Errorf("missing registered model id, cannot create model version without registered model")
}

registeredModel, err := serv.GetRegisteredModelById(*parentResourceId)
if err != nil {
return nil, fmt.Errorf("not a valid registered model id: %d", *parentResourceId)
return nil, fmt.Errorf("not a valid registered model id: %s", *parentResourceId)
}
registeredModelIdCtxID, err := mapper.IdToInt64(*registeredModel.Id)
if err != nil {
Expand All @@ -216,34 +234,42 @@ func (serv *modelRegistryService) UpsertModelVersion(modelVersion *openapi.Model
}

modelId := &modelCtxResp.ContextIds[0]
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
ParentContexts: []*proto.ParentContext{{
ChildId: modelId,
ParentId: registeredModelIdCtxID}},
TransactionOptions: &proto.TransactionOptions{},
})
if err != nil {
return nil, err
if modelVersion.Id == nil {
_, err = serv.mlmdClient.PutParentContexts(context.Background(), &proto.PutParentContextsRequest{
ParentContexts: []*proto.ParentContext{{
ChildId: modelId,
ParentId: registeredModelIdCtxID}},
TransactionOptions: &proto.TransactionOptions{},
})
if err != nil {
return nil, err
}
}

model, err := serv.GetModelVersionById((*BaseResourceId)(modelId))
idAsString := mapper.IdToString(*modelId)
model, err := serv.GetModelVersionById(*idAsString)
if err != nil {
return nil, err
}

return model, nil
}

func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*openapi.ModelVersion, error) {
func (serv *modelRegistryService) GetModelVersionById(id string) (*openapi.ModelVersion, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}

getByIdResp, err := serv.mlmdClient.GetContextsByID(context.Background(), &proto.GetContextsByIDRequest{
ContextIds: []int64{int64(*id)},
ContextIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
}

if len(getByIdResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple model versions found for id %d", *id)
return nil, fmt.Errorf("multiple model versions found for id %s", id)
}

modelVer, err := serv.mapper.MapToModelVersion(getByIdResp.Contexts[0])
Expand All @@ -254,10 +280,14 @@ func (serv *modelRegistryService) GetModelVersionById(id *BaseResourceId) (*open
return modelVer, nil
}

func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelVersion, error) {
func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, parentResourceId *string, externalId *string) (*openapi.ModelVersion, error) {
filterQuery := ""
if versionName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *versionName))
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *versionName))
} else if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
}
Expand All @@ -273,7 +303,7 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
}

if len(getByParamsResp.Contexts) != 1 {
return nil, fmt.Errorf("multiple registered models found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
return nil, fmt.Errorf("multiple model versions found for versionName=%v, parentResourceId=%v, externalId=%v", zeroIfNil(versionName), zeroIfNil(parentResourceId), zeroIfNil(externalId))
}

modelVer, err := serv.mapper.MapToModelVersion(getByParamsResp.Contexts[0])
Expand All @@ -283,14 +313,14 @@ func (serv *modelRegistryService) GetModelVersionByParams(versionName *string, p
return modelVer, nil
}

func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelVersionList, error) {
func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, parentResourceId *string) (*openapi.ModelVersionList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil {
return nil, err
}

if parentResourceId != nil {
queryParentCtxId := fmt.Sprintf("parent_contexts_a.type = %d", *parentResourceId)
queryParentCtxId := fmt.Sprintf("parent_contexts_a.id = %s", *parentResourceId)
listOperationOptions.FilterQuery = &queryParentCtxId
}

Expand Down Expand Up @@ -322,25 +352,36 @@ func (serv *modelRegistryService) GetModelVersions(listOptions ListOptions, pare

// MODEL ARTIFACTS

func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *BaseResourceId) (*openapi.ModelArtifact, error) {
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, (*int64)(parentResourceId))
func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, parentResourceId *string) (*openapi.ModelArtifact, error) {
if modelArtifact.Id == nil {
log.Printf("Creating model artifact")
} else {
log.Printf("Updating model artifact %s", *modelArtifact.Id)
}

idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifact := serv.mapper.MapFromModelArtifact(*modelArtifact, idAsInt)

artifactsResp, err := serv.mlmdClient.PutArtifacts(context.Background(), &proto.PutArtifactsRequest{
Artifacts: []*proto.Artifact{artifact},
})
if err != nil {
return nil, err
}
idString := strconv.FormatInt(artifactsResp.ArtifactIds[0], 10)
modelArtifact.Id = &idString

// add explicit association between artifacts and model version
if parentResourceId != nil {
modelVersionIdCtx := int64(*parentResourceId)
if parentResourceId != nil && modelArtifact.Id == nil {
modelVersionIdCtx, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
attributions := []*proto.Attribution{}
for _, a := range artifactsResp.ArtifactIds {
attributions = append(attributions, &proto.Attribution{
ContextId: &modelVersionIdCtx,
ContextId: modelVersionIdCtx,
ArtifactId: &a,
})
}
Expand All @@ -353,12 +394,22 @@ func (serv *modelRegistryService) UpsertModelArtifact(modelArtifact *openapi.Mod
}
}

return modelArtifact, nil
idAsString := mapper.IdToString(artifactsResp.ArtifactIds[0])
mapped, err := serv.GetModelArtifactById(*idAsString)
if err != nil {
return nil, err
}
return mapped, nil
}

func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*openapi.ModelArtifact, error) {
func (serv *modelRegistryService) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) {
idAsInt, err := mapper.IdToInt64(id)
if err != nil {
return nil, err
}

artifactsResp, err := serv.mlmdClient.GetArtifactsByID(context.Background(), &proto.GetArtifactsByIDRequest{
ArtifactIds: []int64{int64(*id)},
ArtifactIds: []int64{int64(*idAsInt)},
})
if err != nil {
return nil, err
Expand All @@ -372,14 +423,18 @@ func (serv *modelRegistryService) GetModelArtifactById(id *BaseResourceId) (*ope
return result, nil
}

func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *BaseResourceId, externalId *string) (*openapi.ModelArtifact, error) {
func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string, parentResourceId *string, externalId *string) (*openapi.ModelArtifact, error) {
var artifact0 *proto.Artifact

filterQuery := ""
if externalId != nil {
filterQuery = fmt.Sprintf("external_id = \"%s\"", *externalId)
} else if artifactName != nil && parentResourceId != nil {
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned((*int64)(parentResourceId), *artifactName))
idAsInt, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
filterQuery = fmt.Sprintf("name = \"%s\"", mapper.PrefixWhenOwned(idAsInt, *artifactName))
} else {
return nil, fmt.Errorf("invalid parameters call, supply either (artifactName and parentResourceId), or externalId")
}
Expand All @@ -406,7 +461,7 @@ func (serv *modelRegistryService) GetModelArtifactByParams(artifactName *string,
return result, nil
}

func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *BaseResourceId) (*openapi.ModelArtifactList, error) {
func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, parentResourceId *string) (*openapi.ModelArtifactList, error) {
listOperationOptions, err := BuildListOperationOptions(listOptions)
if err != nil {
return nil, err
Expand All @@ -415,9 +470,12 @@ func (serv *modelRegistryService) GetModelArtifacts(listOptions ListOptions, par
var artifacts []*proto.Artifact
var nextPageToken *string
if parentResourceId != nil {
ctxId := int64(*parentResourceId)
ctxId, err := mapper.IdToInt64(*parentResourceId)
if err != nil {
return nil, err
}
artifactsResp, err := serv.mlmdClient.GetArtifactsByContext(context.Background(), &proto.GetArtifactsByContextRequest{
ContextId: &ctxId,
ContextId: ctxId,
Options: listOperationOptions,
})
if err != nil {
Expand Down
Loading

0 comments on commit a309537

Please sign in to comment.