## Split data into training, validation, and test sets

In [1]:
import mlrun
import os
import numpy as np
# Point the MLRun client at the API service. An MLRUN_DBPATH environment
# variable, when set and non-empty, takes precedence so this cell does not
# clobber an address the environment already configured; otherwise fall back
# to the in-cluster service address this notebook was written against.
mlrun.mlconf.dbpath = os.environ.get('MLRUN_DBPATH') or 'http://mlrun-api:8080'

## Parameters

**Run the [arc_to_parquet](arc_to_parquet.ipynb) notebook before this one; it is expected to produce the `higgs.pqt` source file that is split below.**

In [3]:
# --- Notebook configuration: edit these to match your environment ---

# Seed passed to the splitter job as `random_state`.
RNG = 1

# Source parquet file name and the dataset key (KEY is not used by the cells below).
SRC_FILE = 'higgs.pqt'
KEY = 'higgs'

# Root of the `functions` repository containing datagen/splitters/train_valid_test.py.
CODE_BASE = '/User/repos/functions'

# Passed to the splitter as `target_path`; the split artifacts are logged here,
# and the tests below read the source file from this directory as well.
TARGET_DATA_PATH = '/User/mlrun/models'

## Split the data

In [53]:
# Wrap the splitter script as an MLRun job function and export its spec as YAML.
# The spec is written next to the source file, which is where the next cell's
# `import_function` call loads it from; `yaml_name` is defined here so the cell
# runs on a fresh kernel.
splitter_dir = os.path.join(CODE_BASE, 'datagen/splitters')
yaml_name = os.path.join(splitter_dir, 'train_valid_test.yaml')

testfn = mlrun.code_to_function(
    kind='job',
    filename=os.path.join(splitter_dir, 'train_valid_test.py'))
# NOTE(review): ':latest' is not reproducible -- pin a specific image tag/digest when known.
testfn.build_config(base_image='yjbds/mlrun-ds:latest', commands=[])
testfn.export(yaml_name)

[mlrun] 2020-01-26 19:13:33,873 function spec saved to path: /User/repos/functions/datagen/splitters/train_valid_test.yaml


In [54]:
# Load the exported function spec back in and apply mlrun's v3io mount modifier to it.
splitter_yaml = os.path.join(CODE_BASE, 'datagen/splitters', 'train_valid_test.yaml')
splitter = mlrun.import_function(splitter_yaml)
splitter = splitter.apply(mlrun.mount_v3io())

In [55]:
# Build/deploy the splitter function. skip_deployed=True presumably reuses an
# existing build rather than rebuilding; with_mlrun=False presumably relies on the
# base image already providing mlrun -- TODO confirm against the image.
splitter.deploy(with_mlrun=False, skip_deployed=True)

'ready'

In [56]:
# Configure and launch the split job. All inputs come from the configuration
# cell (SRC_FILE, TARGET_DATA_PATH, RNG), so that cell is the single place to
# change them.
task2 = mlrun.NewTask()
task2.with_params(
    src_file=SRC_FILE,
    target_path=TARGET_DATA_PATH,
    random_state=RNG)

# `tsk2` holds the run result; its `.outputs` maps artifact keys to paths.
tsk2 = splitter.run(task2, handler='train_valid_test_splitter')

[mlrun] 2020-01-26 19:13:38,909 starting run train_valid_test_splitter uid=2238f0c5856e4359a068cd881dc0c62b  -> http://mlrun-api:8080
[mlrun] 2020-01-26 19:13:39,003 Job is running in the background, pod: train-valid-test-splitter-zlfcj
[mlrun] 2020-01-26 19:14:07,630 log artifact header at /User/mlrun/models/header.pkl, size: None, db: Y
[mlrun] 2020-01-26 19:14:17,951 log artifact xtrain at /User/mlrun/models/xtrain.pqt, size: None, db: Y
[mlrun] 2020-01-26 19:14:21,585 log artifact xvalid at /User/mlrun/models/xvalid.pqt, size: None, db: Y
[mlrun] 2020-01-26 19:14:23,245 log artifact xtest at /User/mlrun/models/xtest.pqt, size: None, db: Y
[mlrun] 2020-01-26 19:14:24,139 log artifact ytrain at /User/mlrun/models/ytrain.pqt, size: None, db: Y
[mlrun] 2020-01-26 19:14:24,519 log artifact yvalid at /User/mlrun/models/yvalid.pqt, size: None, db: Y
[mlrun] 2020-01-26 19:14:24,755 log artifact ytest at /User/mlrun/models/ytest.pqt, size: None, db: Y

[mlrun] 2020-01-26 19:14:25,528 run ex

uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
...c0c62b,0,Jan 26 19:13:46,completed,train-valid-test,host=train-valid-test-splitter-zlfcjkind=jobowner=admin,,random_state=1src_file=higgs.pqttarget_path=/User/mlrun/models,,headerxtrainxvalidxtestytrainyvalidytest


to track results use .show() or .logs() or in CLI: 
!mlrun get run 2238f0c5856e4359a068cd881dc0c62b  , !mlrun logs 2238f0c5856e4359a068cd881dc0c62b 
[mlrun] 2020-01-26 19:14:28,394 run executed, status=completed


In [57]:
# Artifact key -> path for everything the split run logged.
split_outputs = tsk2.outputs
split_outputs

{'header': '/User/mlrun/models/header.pkl',
 'xtrain': '/User/mlrun/models/xtrain.pqt',
 'xvalid': '/User/mlrun/models/xvalid.pqt',
 'xtest': '/User/mlrun/models/xtest.pqt',
 'ytrain': '/User/mlrun/models/ytrain.pqt',
 'yvalid': '/User/mlrun/models/yvalid.pqt',
 'ytest': '/User/mlrun/models/ytest.pqt'}

## Tests

Sanity-check the shapes of the split artifacts against the dimensions of the source file.

In [12]:
import pandas as pd

In [31]:
# Reference dimensions of the unsplit source frame; the expected split sizes
# below are derived from these.
# NOTE(review): this loads the whole file just to read its shape; parquet
# metadata would give the same numbers without loading the data.
source_path = os.path.join(TARGET_DATA_PATH, SRC_FILE)
n_samples, n_features = pd.read_parquet(source_path, engine='pyarrow').shape

In [33]:
# Shape of the training features written by the split job.
xtrain_path = tsk2.outputs['xtrain']
xtrain_shape = pd.read_parquet(xtrain_path, engine='pyarrow').shape

In [42]:
# Validate the training split against the source dimensions.
# ytrain is read here so this cell does not depend on any other cell having loaded it.
ytrain_shape = pd.read_parquet(tsk2.outputs['ytrain'], engine='pyarrow').shape

# NOTE(review): the expected row count assumes the splitter's default fractions
# (10% held out for test, then 25% of the remainder for validation) -- confirm
# against train_valid_test.py.
expected_train_rows = int(n_samples*.75*.9)
row_tolerance = 1  # fractional split sizes can round either way in the splitter
assert abs(xtrain_shape[0] - expected_train_rows) <= row_tolerance, "xtrain doesn't have the expected shape"
assert ytrain_shape[0] == xtrain_shape[0], "ytrain and xtrain have different shapes"
assert ytrain_shape[1] == 1, "ytrain (labels) has more than 1 column"
# Assumes the splitter keeps every source column, putting each in either x or y -- TODO confirm.
assert xtrain_shape[1] + ytrain_shape[1] == n_features, "xtrain and ytrain columns don't add up to the source columns"

In [44]:
# Validate the held-out test split against the source dimensions.
xtest_shape = pd.read_parquet(tsk2.outputs['xtest'], engine='pyarrow').shape
ytest_shape = pd.read_parquet(tsk2.outputs['ytest'], engine='pyarrow').shape

# NOTE(review): assumes the splitter holds out 10% of rows for test -- confirm
# against train_valid_test.py.
assert xtest_shape[0] == int(n_samples*.1), "xtest doesn't have the expected shape"
assert ytest_shape[0] == xtest_shape[0], "ytest and xtest have different shapes"
assert ytest_shape[1] == 1, "ytest (test labels) has more than 1 column"
# Assumes the splitter keeps every source column, putting each in either x or y -- TODO confirm.
assert xtest_shape[1] + ytest_shape[1] == n_features, "xtest and ytest columns don't add up to the source columns"

In [None]:
from cloudpickle import load

In [None]:
# The logged header artifact should have one entry per source column.
# NOTE(review): when this notebook was last run, the loaded header had length 0
# (see the next cell's output), so this assertion fails. Check what
# train_valid_test.py writes to header.pkl and whether it should include the label column.
# Unpickling can execute code; only load artifacts produced by this pipeline.
with open(tsk2.outputs['header'], 'rb') as header_file:
    header = load(header_file)
assert len(header) == n_features, "header length doesn't match the number of source columns"

In [48]:
# Number of entries in the logged header artifact (the file is closed after loading).
with open(tsk2.outputs['header'], 'rb') as header_file:
    header_len = len(load(header_file))
header_len

0