## Generic Imports

In [16]:
import os
from pprint import pprint

## Set Up ML Metadata (MLMD)

In [1]:
from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

In [5]:
# Open the SQLite-backed MLMD store in the current working directory and
# make sure the "dvc.Model" artifact type is registered, reusing it if a
# previous run already created it.
MODEL_TYPE_NAME = "dvc.Model"

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = os.path.join(os.getcwd(), "mlmddb")
# Use the named enum rather than the bare literal 3: read-write access,
# creating the database if it does not exist yet.
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
)
store = metadata_store.MetadataStore(connection_config)

try:
    artifact_type_id = store.get_artifact_type(MODEL_TYPE_NAME).id
    print("artifact_type_id was found in DB")
except metadata_store.errors.NotFoundError:
    print("artifact_type_id was not found in DB, creating...")
    new_type = metadata_store_pb2.ArtifactType()
    new_type.name = MODEL_TYPE_NAME
    artifact_type_id = store.put_artifact_type(new_type)

print(f"Using {artifact_type_id} for storing artifacts")

artifact_type_id was not found in DB, creating...
Using 11 for storing artifacts


## Load DVC Model

In [10]:
import dvc.api
import pickle

In [8]:
# Download the serialized model as raw bytes (mode='rb') from the example
# repository, pinned to a fixed tag so the notebook is reproducible.
modelpkl = dvc.api.read(
    'model.pkl',
    rev='text-classification@v1.2.0',
    mode='rb',
    repo='https://github.com/iterative/example-get-started',
)

In [13]:
# Deserialize the downloaded model and display it.
# SECURITY(review): pickle.loads can execute arbitrary code embedded in the
# payload. `modelpkl` was fetched from a remote Git/DVC repository, so only
# unpickle it when that repository and revision are trusted.
model = pickle.loads(modelpkl)
model

In [24]:
# Resolve where model.pkl is stored on the DVC remote for the same pinned tag;
# this URL becomes the artifact's `uri` in MLMD.
resource_url = dvc.api.get_url(
    'model.pkl',
    rev='text-classification@v1.2.0',
    repo='https://github.com/iterative/example-get-started',
)
pprint(resource_url)

'https://remote.dvc.org/get-started/files/md5/66/4cf1568870c9984e0836092587aa87'


In [18]:
# Load the DVC-tracked parameters recorded at the pinned tag.
repo_ref = {
    'repo': 'https://github.com/iterative/example-get-started',
    'rev': 'text-classification@v1.2.0',
}
params = dvc.api.params_show(**repo_ref)
pprint(params)

{'featurize': {'max_features': 200, 'ngrams': 2},
 'prepare': {'seed': 20170428, 'split': 0.2},
 'train': {'min_split': 0.01, 'n_est': 50, 'seed': 20170428}}


In [17]:
# Load the DVC-tracked evaluation metrics recorded at the pinned tag.
repo_ref = {
    'repo': 'https://github.com/iterative/example-get-started',
    'rev': 'text-classification@v1.2.0',
}
metrics = dvc.api.metrics_show(**repo_ref)
pprint(metrics)

{'avg_prec': {'test': 0.9249974999612706, 'train': 0.9743681430252835},
 'roc_auc': {'test': 0.9460213440787918, 'train': 0.9866678562450621}}


## Log the Model in MLMD

In [25]:
# Build the MLMD artifact record: typed as "dvc.Model" and pointing at the
# model file's location on the DVC remote.
artifact = metadata_store_pb2.Artifact(
    type_id=artifact_type_id,
    uri=resource_url,
)

In [26]:
# Attach each DVC param group to the artifact as a Struct-valued custom
# property.
# - Struct.update() requires a mapping, so non-dict values (lists, scalars)
#   are wrapped under their own key instead of crashing.
# - update() replaces list values rather than appending to them, so
#   re-running this cell does not duplicate entries.
# Note: Struct stores every number as a double (e.g. the seed is shown as
# 20170428.0 in the read-back below).
for k, v in params.items():
    payload = v if isinstance(v, dict) else {k: v}
    artifact.custom_properties[k].struct_value.update(payload)

In [27]:
# Attach each DVC metric group to the artifact as a Struct-valued custom
# property.
# - Struct.update() requires a mapping, so non-dict values (lists, scalars)
#   are wrapped under their own key instead of crashing.
# - update() replaces list values rather than appending to them, so
#   re-running this cell does not duplicate entries.
# NOTE(review): a metric key equal to a param key would merge into the same
# property; no such collision exists in the data shown above.
for k, v in metrics.items():
    payload = v if isinstance(v, dict) else {k: v}
    artifact.custom_properties[k].struct_value.update(payload)

In [29]:
# Persist the artifact; exactly one artifact goes in, so unpack its single id.
(model_artifact_id,) = store.put_artifacts([artifact])

In [30]:
# Read the stored record back by id to show exactly what MLMD persisted.
(my_mlmd_data,) = store.get_artifacts_by_id([model_artifact_id])
pprint(my_mlmd_data)

id: 1
type_id: 11
uri: "https://remote.dvc.org/get-started/files/md5/66/4cf1568870c9984e0836092587aa87"
custom_properties {
  key: "avg_prec"
  value {
    struct_value {
      fields {
        key: "test"
        value {
          number_value: 0.9249974999612706
        }
      }
      fields {
        key: "train"
        value {
          number_value: 0.9743681430252835
        }
      }
    }
  }
}
custom_properties {
  key: "featurize"
  value {
    struct_value {
      fields {
        key: "max_features"
        value {
          number_value: 200.0
        }
      }
      fields {
        key: "ngrams"
        value {
          number_value: 2.0
        }
      }
    }
  }
}
custom_properties {
  key: "prepare"
  value {
    struct_value {
      fields {
        key: "seed"
        value {
          number_value: 20170428.0
        }
      }
      fields {
        key: "split"
        value {
          number_value: 0.2
        }
      }
    }
  }
}
custom_properties {
  key: