In [13]:
import os

from ml_metadata.proto import metadata_store_pb2
from ml_metadata.metadata_store import metadata_store

In [14]:
def _make_default_sqlite_uri():
    return '/'.join([
        os.environ['HOME'],
        'tfx_metadata_sqlite.db',
    ])

def get_metadata_store(
    filename_uri='',
    connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.UNKNOWN,
    reset=False):
    """Returns a metadata_store.MetadataStore handle to a SQLITE backend.

    Args:
      filename_uri: Path of the SQLite database file. When empty, the default
        from _make_default_sqlite_uri() ($HOME/tfx_metadata_sqlite.db) is used.
      connection_mode: SqliteMetadataSourceConfig connection-mode enum value
        (e.g. READWRITE); copied into the connection config unchanged.
      reset: If True, delete any existing file at the resolved path before
        connecting.

    Returns:
      A metadata_store.MetadataStore connected to the resolved SQLite file.
    """
    db_path = filename_uri or _make_default_sqlite_uri()
    # `reset` used to be accepted but silently ignored. Honor it, and skip a
    # missing file so resetting a fresh path is a no-op rather than an error.
    if reset and os.path.exists(db_path):
        os.remove(db_path)
    config = metadata_store_pb2.ConnectionConfig()
    config.sqlite.filename_uri = db_path
    config.sqlite.connection_mode = connection_mode
    return metadata_store.MetadataStore(config)

def delete_sqlite_db(filename_uri=''):
    """Deletes the SQLite metadata file.

    Args:
      filename_uri: File to delete; when empty, the default path from
        _make_default_sqlite_uri() is used instead.

    Raises:
      OSError: propagated from os.remove (e.g. when the file does not exist).
    """
    target_path = filename_uri if filename_uri else _make_default_sqlite_uri()
    os.remove(target_path)
    
def update_airflow_db_airtifacts_uri(extracted_dir=''):
    """Rebases artifact URIs in an extracted Airflow TFX metadata database.

    Opens
    <extracted_dir>/data/tfx/pipelines/chicago_taxi_pipeline_local/metadata.db
    read-write. For every artifact whose URI contains "airflow", it replaces
    everything up to and including the FIRST "airflow" with `extracted_dir`.
    It then prints the new URI and writes the artifact back to the store.

    Example: if the archive was extracted as
        /usr/local/google/home/huimiao/airflow/
            data/
                taxi_data/
                tfx/
    then extracted_dir = '/usr/local/google/home/huimiao/airflow/'.

    Args:
      extracted_dir: Directory the Airflow data archive was extracted into
        (the parent of `data/`). A trailing '/' is optional for locating the
        database. The value is prepended verbatim to the rewritten URIs.
    """
    # os.path.join matches the old concatenation for '' and '.../', and also
    # works when the caller omits the trailing separator.
    db_path = os.path.join(
        extracted_dir,
        "data/tfx/pipelines/chicago_taxi_pipeline_local/metadata.db")
    store = get_metadata_store(
        filename_uri=db_path,
        connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE)
    for artifact in store.get_artifacts():
        # Split on the first "airflow" only. A URI that contains the word more
        # than once (e.g. a later directory named "airflow_x") must keep its
        # full tail instead of being truncated at the second occurrence.
        tokens = artifact.uri.split("airflow", 1)
        if len(tokens) > 1:
            new_uri = extracted_dir + tokens[1]
            artifact.uri = new_uri
            # Call form prints identically under Python 2 and Python 3.
            print(new_uri)
            store.put_artifacts([artifact])

In [15]:
# Constants for TFX Artifact types.
# Type-name strings for TFX artifacts. They are not referenced in the cells
# shown here; presumably later cells use them to look up artifact types in the
# ML Metadata store. NOTE(review): verify they match the type names recorded
# in the metadata.db being inspected.
TFX_ARTIFACT_EXAMPLES = 'ExamplesPath'
TFX_ARTIFACT_SCHEMA = 'SchemaPath'
TFX_ARTIFACT_EXAMPLE_STATS = 'ExampleStatisticsPath'
TFX_ARTIFACT_EXAMPLE_VALIDATION = 'ExampleValidationPath'
TFX_ARTIFACT_TRANSFORMED_EXAMPLES = 'TransformPath'
TFX_ARTIFACT_MODEL = 'ModelExportPath'
TFX_ARTIFACT_MODEL_EVAL = 'ModelEvalPath'

# Constants for TFX Execution types.
# Execution (component) type names, one per pipeline stage. Like the artifact
# names above, these are unused in the visible cells. TODO confirm they match
# the execution types registered by the pipeline that produced the database.
TFX_EXECUTION_EXAMPLE_GEN = 'examples_gen'
TFX_EXECUTION_STATISTICS_GEN = 'statistics_gen'
TFX_EXECUTION_SCHEMA_GEN = 'schema_gen'
TFX_EXECUTION_EXAMPLE_VALIDATION = 'example_validation'
TFX_EXECUTION_TRANSFORM = 'transform'
TFX_EXECUTION_TRAINER = 'trainer'
TFX_EXECUTION_EVALUATOR = 'evaluator'

In [16]:
import time
import os
# Fresh scratch directory for this run: $HOME/tmp/<Unix time in whole seconds>,
# so each launch stages its data in a new location.
airflow_data_dir = os.path.join(
    os.environ['HOME'],
    'tmp',
    '{}'.format(int(time.time())),
)
print(airflow_data_dir)

/usr/local/google/home/huimiao/tmp/1550197884


In [17]:
%%bash -s $airflow_data_dir

# Stage a fresh copy of the pre-recorded Airflow data archive in the scratch
# directory passed as $1 (airflow_data_dir) and extract it there. The artifact
# URIs inside the extracted metadata.db still point at the original location;
# they are rebased in a later cell.
set -euo pipefail

# ${1:?} aborts instead of running rm/mkdir/cp against an empty path.
tmp_data_dir="${1:?pass the scratch directory as the first argument}"
# Optional second argument overrides the archive location.
airflow_zip="${2:-/google/src/files/234062393/depot/google3/experimental/users/vemmadi/tfx_demo/airflow_data.zip}"

echo "$tmp_data_dir"
rm -rf "$tmp_data_dir"
mkdir -p "$tmp_data_dir"
cp "$airflow_zip" "$tmp_data_dir/airflow_data.zip"
unzip "$tmp_data_dir/airflow_data.zip" -d "$tmp_data_dir/"

/usr/local/google/home/huimiao/tmp/1550197884
Archive:  /usr/local/google/home/huimiao/tmp/1550197884/airflow_data.zip
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/chicago_taxi_pipeline_local/
  inflating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/chicago_taxi_pipeline_local/metadata.db  
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/chicago_taxi_pipeline_local/taxi-local/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/chicago_taxi_pipeline_local/taxi-local/evaluator/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/data/tfx/pipelines/chicago_taxi_pipeline_local/taxi-local/evaluator/output/
   creating: /usr/local/google/home/huimiao/tmp/1550197884/d

In [20]:
# Rewrite artifact URIs to point into the freshly extracted copy, then record
# where that copy's metadata database lives for the cells that follow.
update_airflow_db_airtifacts_uri("%s/" % airflow_data_dir)
default_metadata_db_path = "/".join([
    airflow_data_dir,
    "data",
    "tfx",
    "pipelines",
    "chicago_taxi_pipeline_local",
    "metadata.db",
])