In [None]:
import numpy
import pandas
import os
import hsml
from hsml.client.exceptions import RestAPIError

In [None]:
def setup_env():
    """Open an hsml connection and hand back its model registry.

    Returns:
        The model registry object obtained from ``hsml.connection()``.
    """
    return hsml.connection().get_model_registry()

In [None]:
# Shared model-registry handle used by every registry call in the cells below.
mr = setup_env()

In [None]:
# Verify model_tf's training metrics, description, model schema and input example

In [None]:
# Fetch model_tf without passing an explicit version.
# NOTE(review): hsml resolves a missing version to a default (possibly with a warning) —
# confirm which version the metadata assertions in the next cell are meant to target.
exported_tf_model = mr.get_model("model_tf")

In [None]:
# The fetched handle must be a TensorFlow model carrying the metadata it was exported with.
assert isinstance(exported_tf_model, hsml.tensorflow.model.Model), "expected a TensorFlow model"

# Check each training metric separately so a failure names the one that is missing.
assert 'accuracy' in exported_tf_model.training_metrics, "'accuracy' missing from training metrics"
assert 'loss' in exported_tf_model.training_metrics, "'loss' missing from training metrics"

assert exported_tf_model.description == "A test desc for this model", "description incorrect"

# Check model schema and input example
assert len(exported_tf_model.model_schema['input_schema']['columnar_schema']) == 3, "schema len incorrect"
assert exported_tf_model.model_schema['output_schema']['tensor_schema']['type'] == "float64", "schema type incorrect"
# The schema stores the tensor shape as a string, e.g. '(8,)'.
assert exported_tf_model.model_schema['output_schema']['tensor_schema']['shape'] == '(8,)', "schema shape incorrect"
assert len(exported_tf_model.input_example) == 3, "input example len incorrect"

In [None]:
# Explicit version lookup: the returned handle must report the version that was requested.
exported_tf_model_v3 = mr.get_model("model_tf", version=3)
assert exported_tf_model_v3.version == 3, "Model version should be 3"

In [None]:
# Looking up a model that does not exist must raise RestAPIError.
# The failure is raised in `else` rather than via `assert False` inside `try`, so the
# check still fires under `python -O` (which strips assert statements). The result is
# not bound to a name because nothing uses it.
try:
    mr.get_model("not_found")
except RestAPIError:
    pass  # expected outcome
else:
    raise AssertionError("should return RestAPIError")

In [None]:
# Fetch the sklearn model so the next cell can exercise deletion on it.
skl_model = mr.get_model("model_sklearn")

In [None]:
# Exercise model deletion. This is destructive: re-running the notebook against the same
# registry presumably fails at the previous cell, because the fetched model_sklearn
# version no longer exists — re-register it before a re-run (TODO confirm).
skl_model.delete()

In [None]:
# Retrieve every model registered under the name model_tf; the count is checked in the next cell.
tf_models = mr.get_models("model_tf")

In [None]:
# The registry is expected to hold exactly three model_tf versions (version 3 is fetched
# explicitly earlier). A message is given, as for the other assertions, so failures are self-describing.
assert len(tf_models) == 3, "Expected 3 model_tf versions in the registry"

In [None]:
# Ask the registry for the model_tf entry whose 'accuracy' metric is highest ("max" direction).
best_tf_model = mr.get_best_model("model_tf", "accuracy", "max")

In [None]:
# Version 2 was registered with the highest accuracy, so it must be the one selected.
assert best_tf_model.version == 2, "Highest accuracy should be version 2"

In [None]:
# Download the best model's artifacts; the returned value is the local directory listed in the next cell.
model_dir = best_tf_model.download()

In [None]:
# A TensorFlow SavedModel export should have saved_model.pb at the top of the downloaded directory.
assert 'saved_model.pb' in os.listdir(model_dir), "Model file should be in the downloaded model directory"