## Setup

In [1]:
from pathlib import Path
from pprint import pprint

import ml_metadata as mlmd
from ml_metadata import errors as mlmd_errors
from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

In [2]:
# Create the Store using a local SQLite database.
# Start from a fresh file so the notebook survives Restart & Run All: a database
# left over from a previous run would make the empty-store assertion below fail
# and re-registering the example models collide with the existing contexts.
METADATA_DB_PATH = Path("metadata.sqlite.db")
METADATA_DB_PATH.unlink(missing_ok=True)

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = str(METADATA_DB_PATH)
# Open read-write and create the database file if it does not exist yet.
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
)
store = metadata_store.MetadataStore(connection_config)

## Create Custom Types

In [3]:
# Registered Model: a bare context type; each model's context name is the model name
reg_model_type = metadata_store_pb2.ContextType(name="odh.RegisteredModel")
reg_model_type_id = store.put_context_type(reg_model_type)
reg_model_type_id

10

In [4]:
# ModelVersion: context type carrying the owning model's name and the version label
model_version_type = metadata_store_pb2.ContextType(
    name="odh.ModelVersion",
    properties={
        "model_name": metadata_store_pb2.STRING,
        "version": metadata_store_pb2.STRING,
    },
)
model_version_type_id = store.put_context_type(model_version_type)
model_version_type_id

11

In [5]:
# ModelArtifact: artifact type for a model version's files (their location lives in `uri`)
model_artifact_type = metadata_store_pb2.ArtifactType(name="odh.ModelArtifact")
model_artifact_type_id = store.put_artifact_type(model_artifact_type)
model_artifact_type_id

12

In [6]:
# Every custom type must be retrievable by name and resolve to the id returned at registration
assert all(
    lookup(type_name=type_name).id == expected_id
    for lookup, type_name, expected_id in (
        (store.get_context_type, "odh.RegisteredModel", reg_model_type_id),
        (store.get_context_type, "odh.ModelVersion", model_version_type_id),
        (store.get_artifact_type, "odh.ModelArtifact", model_artifact_type_id),
    )
)

## Limitations

In [7]:
# The store must start out without any contexts
assert not store.get_contexts()

In [8]:
# Contexts are uniquely identified by <type, name>, therefore we cannot have multiple
# ModelVersion contexts with the same name: the second put_contexts call is rejected.
# The demonstration runs against a throwaway in-memory store so the main store used by
# the rest of the notebook stays untouched.
demo_config = metadata_store_pb2.ConnectionConfig()
demo_config.fake_database.SetInParent()
demo_store = metadata_store.MetadataStore(demo_config)
demo_version_type_id = demo_store.put_context_type(model_version_type)

version_1 = metadata_store_pb2.Context()
version_1.type_id = demo_version_type_id
version_1.name = "MyModel"
version_1.properties["version"].string_value = "v1"
demo_store.put_contexts([version_1])

version_2 = metadata_store_pb2.Context()
version_2.type_id = demo_version_type_id
version_2.name = "MyModel"
version_2.properties["version"].string_value = "v2"
try:
    demo_store.put_contexts([version_2])
except mlmd_errors.AlreadyExistsError as err:
    print(f"Rejected as expected: {err}")
else:
    raise AssertionError("a second context with the same <type, name> was accepted")

## Model Registry Mapping

### Service Layer

In [9]:
# Create Utilities


def create_registered_model(name):
    """Register a new model.

    The model is stored as an ``odh.RegisteredModel`` context whose context
    name is ``name``. Nothing is returned.
    """
    store.put_contexts([
        metadata_store_pb2.Context(type_id=reg_model_type_id, name=name),
    ])


def create_model_version(name, version, uri):
    """Create a new version of an existing registered model.

    A version is recorded as three linked MLMD entities:

    * an ``odh.ModelVersion`` context named ``"<name>:<version>"`` carrying
      ``model_name`` and ``version`` string properties,
    * an ``odh.ModelArtifact`` artifact named ``"<name>:<version>:model"``
      whose ``uri`` points at the model files,
    * an attribution attaching the artifact to the version context.

    Args:
        name: Name of the registered model; it must already exist.
        version: Version label, e.g. ``"v1"``.
        uri: Location of the model files, stored on the artifact.

    Raises:
        ValueError: If no ``odh.RegisteredModel`` called ``name`` exists.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently allow versions of unknown models.
    reg_model = store.get_context_by_type_and_name(type_name="odh.RegisteredModel", context_name=name)
    if reg_model is None:
        raise ValueError(f"Registered model {name!r} does not exist")

    # NOTE(review): "<name>:<version>" is ambiguous when either part contains ':'
    # (e.g. "a:b"/"c" and "a"/"b:c" produce the same context name).
    full_name = f"{name}:{version}"

    # NOTE(review): the three writes below are separate store calls, not one
    # transaction; a failure part-way through leaves a partial version behind.

    # create the version context
    model_v = metadata_store_pb2.Context()
    model_v.type_id = model_version_type_id
    model_v.name = full_name
    model_v.properties["model_name"].string_value = name
    model_v.properties["version"].string_value = version
    [model_v_id] = store.put_contexts([model_v])

    # create the artifact recording where the model files live
    model_artifact = metadata_store_pb2.Artifact()
    model_artifact.type_id = model_artifact_type_id
    model_artifact.uri = uri
    model_artifact.name = f"{full_name}:model"
    [art_id] = store.put_artifacts([model_artifact])

    # create the attribution (context-artifact association)
    attribution = metadata_store_pb2.Attribution()
    attribution.artifact_id = art_id
    attribution.context_id = model_v_id
    store.put_attributions_and_associations([attribution], [])
    


In [10]:
# Get Utilities


def map_artifact(artifact):
    """Project an MLMD artifact onto a plain dict holding its ``name`` and ``uri``."""
    return dict(name=artifact.name, uri=artifact.uri)


def map_model_version(model_version):
    """Convert an ``odh.ModelVersion`` context into a plain dict.

    The result holds the ``model_name`` and ``version`` properties, the
    creation and last-update timestamps, and every artifact attributed to the
    context, each converted with ``map_artifact``.
    """
    attributed = store.get_artifacts_by_context(context_id=model_version.id)
    props = model_version.properties
    return dict(
        model_name=props["model_name"].string_value,
        version=props["version"].string_value,
        create_time_since_epoch=model_version.create_time_since_epoch,
        last_update_time_since_epoch=model_version.last_update_time_since_epoch,
        artifacts=[map_artifact(art) for art in attributed],
    )
    

def get_registered_model(name):
    """Retrieve a single registered model, together with all its versions.

    Versions are found by filtering ``odh.ModelVersion`` contexts on their
    ``model_name`` property.

    Args:
        name: Name of the registered model.

    Returns:
        ``{"name": <model name>, "versions": [...]}`` where each version is the
        dict produced by ``map_model_version``.

    Raises:
        ValueError: If no ``odh.RegisteredModel`` called ``name`` exists.
    """
    reg_model = store.get_context_by_type_and_name(type_name="odh.RegisteredModel", context_name=name)
    if reg_model is None:
        # Previously this surfaced as an AttributeError on ``None.name``.
        raise ValueError(f"Registered model {name!r} does not exist")

    # Escape backslashes and double quotes so the name cannot terminate the
    # filter query's string literal early.
    quoted_name = str(name).replace("\\", "\\\\").replace('"', '\\"')
    query = f'type = "odh.ModelVersion" and properties.model_name.string_value = "{quoted_name}"'
    versions = store.get_contexts(list_options=mlmd.ListOptions(filter_query=query))

    return {
        "name": reg_model.name,
        "versions": [map_model_version(v) for v in versions],
    }


def get_model_version(name, version):
    """Retrieve a single registered model version.

    Args:
        name: Name of the registered model.
        version: Version label, matched against the ``version`` property.

    Returns:
        ``{"name": <model name>, "version": <dict from map_model_version>}``.

    Raises:
        ValueError: If the registered model does not exist, or if the query does
            not match exactly one ``odh.ModelVersion`` context.
    """

    def _quote(value):
        """Escape a value for use inside a double-quoted filter-query string literal."""
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    reg_model = store.get_context_by_type_and_name(type_name="odh.RegisteredModel", context_name=name)
    if reg_model is None:
        # Previously this surfaced as an AttributeError on ``None.name``.
        raise ValueError(f"Registered model {name!r} does not exist")

    query = (
        'type = "odh.ModelVersion"'
        f' and properties.model_name.string_value = "{_quote(name)}"'
        f' and properties.version.string_value = "{_quote(version)}"'
    )
    versions = store.get_contexts(list_options=mlmd.ListOptions(filter_query=query))
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if len(versions) != 1:
        raise ValueError(
            f"Expected exactly one version {version!r} of model {name!r}, found {len(versions)}"
        )

    return {
        "name": reg_model.name,
        "version": map_model_version(versions[0]),
    }

### Example

In [11]:
# Register the model that the following cells attach versions to
create_registered_model("PricingModel")

In [12]:
# One version per release, each pointing at its own model directory
for pricing_version in ("v1", "v2", "v3"):
    create_model_version(
        name="PricingModel",
        version=pricing_version,
        uri=f"/path/to/model/{pricing_version}",
    )

In [13]:
# The registered model with every version and its artifacts
model = get_registered_model("PricingModel")
pprint(model)

{'name': 'PricingModel',
 'versions': [{'artifacts': [{'name': 'PricingModel:v1:model',
                              'uri': '/path/to/model/v1'}],
               'create_time_since_epoch': 1696845946045,
               'last_update_time_since_epoch': 1696845946045,
               'model_name': 'PricingModel',
               'version': 'v1'},
              {'artifacts': [{'name': 'PricingModel:v2:model',
                              'uri': '/path/to/model/v2'}],
               'create_time_since_epoch': 1696845946072,
               'last_update_time_since_epoch': 1696845946072,
               'model_name': 'PricingModel',
               'version': 'v2'},
              {'artifacts': [{'name': 'PricingModel:v3:model',
                              'uri': '/path/to/model/v3'}],
               'create_time_since_epoch': 1696845946106,
               'last_update_time_since_epoch': 1696845946106,
               'model_name': 'PricingModel',
               'version': 'v3'}]}


In [14]:
# A single version, looked up by model name and version label
model_v1 = get_model_version("PricingModel", "v1")
pprint(model_v1)

{'name': 'PricingModel',
 'version': {'artifacts': [{'name': 'PricingModel:v1:model',
                            'uri': '/path/to/model/v1'}],
             'create_time_since_epoch': 1696845946045,
             'last_update_time_since_epoch': 1696845946045,
             'model_name': 'PricingModel',
             'version': 'v1'}}


In [15]:
# Same lookup for the second release
model_v2 = get_model_version("PricingModel", "v2")
pprint(model_v2)

{'name': 'PricingModel',
 'version': {'artifacts': [{'name': 'PricingModel:v2:model',
                            'uri': '/path/to/model/v2'}],
             'create_time_since_epoch': 1696845946072,
             'last_update_time_since_epoch': 1696845946072,
             'model_name': 'PricingModel',
             'version': 'v2'}}


In [16]:
# A second, independent registered model
create_registered_model("Forecasting")

In [17]:
# Semantic-style version labels; the URI prefixes each label with "v"
for forecasting_version in ("1.0", "1.1"):
    create_model_version(
        name="Forecasting",
        version=forecasting_version,
        uri=f"/path/to/forecasting/v{forecasting_version}",
    )

In [18]:
# Versions are scoped per model, so "1.0" resolves only within "Forecasting"
forecasting_v1 = get_model_version("Forecasting", "1.0")
pprint(forecasting_v1)

{'name': 'Forecasting',
 'version': {'artifacts': [{'name': 'Forecasting:1.0:model',
                            'uri': '/path/to/forecasting/v1.0'}],
             'create_time_since_epoch': 1696845946199,
             'last_update_time_since_epoch': 1696845946199,
             'model_name': 'Forecasting',
             'version': '1.0'}}
