In [12]:
from daggerml import Dml, Error, Resource
from dml_util import funkify, S3Store, funk
import os
import pandas as pd

In [2]:
os.environ["DML_S3_BUCKET"] ="dml-examples"
os.environ["DML_S3_PREFIX"] ="clustering"
os.environ["DML_DEBUG"] = "1"

In [3]:
DOCKER_CONTEXT_DIR ="./dkr-context"
AWS_CREDS = os.path.expanduser("~/.aws/credentials")
DOCKER_FLAGS = ["-v", f"{AWS_CREDS}:/root/.aws/credentials:ro", "-e","AWS_SHARED_CREDENTIALS_FILE=/root/.aws/credentials"]

In [4]:
dml = Dml(repo="tutorial", branch="main")
dag = dml.new("ml-example-2")
s3 = S3Store()


In [5]:
excludes = [
    "tests/*.py",
    ".pytest_cache",
    ".ruff_cache",
    "__pycache__",
    "examples",
    ".venv",
    "**/.venv",
]

dag.tar = s3.tar(dml, DOCKER_CONTEXT_DIR, excludes=excludes)
dag.dkr = funk.dkr_build
dag.img = dag.dkr(
    dag.tar,
    ["--platform", "linux/amd64"],
    timeout=60_000,
)


In [6]:
@funkify(uri="docker",data={"image":dag.img.value(), "flags": DOCKER_FLAGS})
@funkify
def load_data(dag):
    from tempfile import NamedTemporaryFile
    from dml_util import S3Store
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    import pandas as pd
    s3 = S3Store()
    params = dag.argv[1].value()
    X, y = load_iris(as_frame=True, return_X_y=True)
    splits = train_test_split(X, y, random_state=params["random_state"])
    out = {}
    for name, spl in zip(["X_train", "X_test", "y_train", "y_test"], splits):
        with NamedTemporaryFile() as temp:
            if isinstance(spl,pd.Series):
                spl = spl.to_frame("class")
            spl.to_parquet(temp.name)
            temp.seek(0)
            out[name] = s3.put(filepath=temp.name, suffix=".parquet")

    return out 

dag.load_data = load_data
params = {"random_state": 2}
iris_data = dag.load_data(params, name="iris_data")
iris_data

DictNode(node/e7fc3349e17724d182f1d623f348a94a)

In [7]:
@funkify
def fit_model(dag):
    import pandas as pd
    from sklearn.cluster import KMeans
    from dml_util import S3Store
    import pickle
    from time import time
    t_0 = time()
    train = dag.argv[1].value()
    params = dag.argv[2].value()
    clusterer = KMeans(**params)
    iris_train = pd.read_parquet(train.uri,engine="fastparquet")
    fitted = clusterer.fit(iris_train)
    s3 = S3Store()
    t_n = time()
    dag.elapsed = t_n - t_0
    return s3.put(pickle.dumps(fitted), suffix=".pkl")
    
print(type(fit_model))
print(fit_model.uri)
print(fit_model.data["script"])

<class 'daggerml.core.Resource'>
script
#!/usr/bin/env python3
from dml_util import aws_fndag

def fit_model(dag):
    import pandas as pd
    from sklearn.cluster import KMeans
    from dml_util import S3Store
    import pickle
    from time import time
    t_0 = time()
    train = dag.argv[1].value()
    params = dag.argv[2].value()
    clusterer = KMeans(**params)
    iris_train = pd.read_parquet(train.uri,engine="fastparquet")
    fitted = clusterer.fit(iris_train)
    s3 = S3Store()
    t_n = time()
    dag.elapsed = t_n - t_0
    return s3.put(pickle.dumps(fitted), suffix=".pkl")

if __name__ == "__main__":
    with aws_fndag() as dag:
        res = fit_model(dag)
        if dag._ref is None:
            dag.result = res


In [8]:
dag.fit_model = fit_model
fitted = dag.fit_model(iris_data["X_train"], {"n_clusters": 3})

[23f18e0f] INFO dml_util.adapters.base: CloudWatch logging not enabled due to AWS access error: An error occurred (ResourceNotFoundException) when calling the DescribeLogStreams operation: The specified log group does not exist.
[23f18e0f] DEBUG dml_util.adapters.base: reading data from <_io.TextIOWrapper name='<stdin>' mode='r' encoding='utf-8'>
[23f18e0f] INFO dml_util.runners.base: getting info from 'LocalState'
[23f18e0f] DEBUG dml_util.runners.local: Submitting script to local runner
[23f18e0f] DEBUG dml_util.runners.local: Environment for script: {"DML_S3_BUCKET": "dml-examples", "DML_S3_PREFIX": "clustering", "DML_LOG_GROUP": "dml", "DML_RUN_ID": "23f18e0f", "DML_DEBUG": "1", "DML_INPUT_LOC": "/tmp/dml.1ufbyu0c/input.dump", "DML_OUTPUT_LOC": "/tmp/dml.1ufbyu0c/output.dump", "DML_LOG_STDOUT": "/run/23edab3846e6ad3c0339bb74dde2de12/stdout", "DML_LOG_STDERR": "/run/23edab3846e6ad3c0339bb74dde2de12/stderr"}
[23f18e0f] INFO dml_util.runners.local: Process 756820 started in /tmp/dml.1

In [9]:
@funkify
def predict(dag):
    import pandas as pd
    from dml_util import S3Store
    import pickle
    from tempfile import NamedTemporaryFile
    s3 = S3Store()

    model = pickle.loads(s3.get(dag.argv[1]))
    X_test = pd.read_parquet(dag.argv[2].value().uri, engine="fastparquet")
    predictions = model.transform(X_test)
    preds_df = pd.DataFrame(predictions,index=X_test.index,columns=[f"c{i}" for i in range(predictions.shape[1])])

    with NamedTemporaryFile() as temp:
        preds_df.to_parquet(temp.name)
        temp.seek(0)
        return s3.put(filepath=temp.name, suffix=".parquet")
    
dag.predict = predict

In [10]:
predictions = dag.predict(fitted, iris_data["X_test"])

[24681c21] INFO dml_util.adapters.base: CloudWatch logging not enabled due to AWS access error: An error occurred (ResourceNotFoundException) when calling the DescribeLogStreams operation: The specified log group does not exist.
[24681c21] DEBUG dml_util.adapters.base: reading data from <_io.TextIOWrapper name='<stdin>' mode='r' encoding='utf-8'>
[24681c21] INFO dml_util.runners.base: getting info from 'LocalState'
[24681c21] DEBUG dml_util.runners.local: Submitting script to local runner
[24681c21] DEBUG dml_util.runners.local: Environment for script: {"DML_S3_BUCKET": "dml-examples", "DML_S3_PREFIX": "clustering", "DML_LOG_GROUP": "dml", "DML_RUN_ID": "24681c21", "DML_DEBUG": "1", "DML_INPUT_LOC": "/tmp/dml.mbuh232y/input.dump", "DML_OUTPUT_LOC": "/tmp/dml.mbuh232y/output.dump", "DML_LOG_STDOUT": "/run/7c485a82b92b2922242b7e97fbe9386b/stdout", "DML_LOG_STDERR": "/run/7c485a82b92b2922242b7e97fbe9386b/stderr"}
[24681c21] INFO dml_util.runners.local: Process 757006 started in /tmp/dml.m

In [13]:
pd.read_parquet(predictions.value().uri)

Unnamed: 0,c0,c1,c2
6,3.503996,0.465283,5.225368
3,3.407649,0.600014,5.167937
113,0.898882,4.137848,1.49712
12,3.45118,0.570876,5.205401
24,3.049662,0.526939,4.7655
129,1.954124,5.065398,0.601138
25,3.178985,0.518499,4.922011
108,1.689291,5.008289,0.664355
128,1.485037,4.796079,0.555526
45,3.379836,0.556263,5.132418
