# Split data into train, validation and test sets

In [1]:
import mlrun
import os
import numpy as np
# Point the MLRun client at the API service. An MLRUN_DBPATH value already set
# in the environment takes precedence; otherwise fall back to the in-cluster
# service URL this notebook was written against.
mlrun.mlconf.dbpath = os.environ.get('MLRUN_DBPATH') or 'http://mlrun-api:8080'

## parameters

**Please be sure to run the notebooks [1. remote archive to local parquet](1.%20remote%20archive%20to%20local%20parquet.ipynb) and [2. parquet to dask cluster](2.%20parquet%20to%20dask%20cluster.ipynb) before running this one.**

Since our data is already loaded into a Dask cluster, we use the cluster as our data source.

In [2]:
# Free-text summary of this notebook's task.
DESCRIPTION = 'split data into train, validation and test splits'

# Function runtime settings: container image, MLRun function kind, task name.
IMAGE = 'yjbds/mlrun-dask:dev'
JOB_KIND = 'job'
TASK_NAME = 'user-task-data-splits'

# Values forwarded to the handler as run parameters.
TARGET_PATH = '/User/repos/demos/dask/artifacts'
DASK_CLIENT = 'scheduler.json'
DASK_KEY = 'airlines'
LABEL_COLUMN = 'ArrDelay'
CATEGORIES = ['UniqueCarrier', 'Origin', 'Dest']

# insert run id ... from db here
MLRUN_DB_UID = '0bb6c8794695400ea79c28900737277b'

# Forwarded to the handler as 'random_state'.
RNG = 1

In [3]:
# Name of the handler the task asks the splitter function to run.
HANDLER = "splitter_labelencode"

## split the data

In [4]:
# Create an MLRun function from a local Python file, using the image and
# function kind configured in the parameters cell.
splitter = mlrun.new_function(command='/User/repos/demos/dask/code/splitter-labelencode.py',
                              image=IMAGE,
                              kind=JOB_KIND)

# Use the same image as the build image of the function spec.
splitter.spec.build.image = IMAGE

# Export the function spec to yaml; the commented line below reloads that same
# file instead of rebuilding the function from source.
splitter.export('/User/repos/demos/dask/yaml/splitter-labelencode.yaml')
# splitter = mlrun.import_function('/User/repos/demos/dask/yaml/splitter-labelencode.yaml')

# Apply the v3io mount to the function, then deploy it
# (skip_deployed=True, with_mlrun=False).
splitter.apply(mlrun.mount_v3io())
splitter.deploy(skip_deployed=True, with_mlrun=False)

# Describe the run: task name, handler to invoke, and the parameters it receives.
task_ = mlrun.NewTask(
    TASK_NAME,
    handler=HANDLER,
    params={
        'dask_client'   : DASK_CLIENT,
        'dask_key'      : DASK_KEY,
        'label_column'  : LABEL_COLUMN,
        'categories'    : CATEGORIES,
        'target_path'   : TARGET_PATH,
        'random_state'  : RNG,
    })

# Run the task. The handler comes from the HANDLER constant so the task
# definition and the run call cannot drift apart.
tsk2 = splitter.run(task_, handler=HANDLER)

[mlrun] 2020-02-13 06:31:08,960 function spec saved to path: /User/repos/demos/dask/yaml/splitter-labelencode.yaml
[mlrun] 2020-02-13 06:31:08,970 starting run user-task-data-splits uid=bcdc2146794046688d13960dd21e6868  -> http://mlrun-api:8080
[mlrun] 2020-02-13 06:31:09,035 Job is running in the background, pod: user-task-data-splits-8s246
[mlrun] 2020-02-13 06:32:53,882 log artifact header at /User/repos/demos/dask/artifacts/header.pkl, size: None, db: Y
[mlrun] 2020-02-13 06:37:24,113 log artifact test_set at /User/repos/demos/dask/artifacts/test_set, size: None, db: Y

[mlrun] 2020-02-13 06:37:24,196 run executed, status=completed
Intel(R) Data Analytics Acceleration Library (Intel(R) DAAL) solvers for sklearn enabled: https://intelpython.github.io/daal4py/sklearn.html
We're assuming that the indexes of each dataframes are 
 aligned. This assumption is not generally safe.
  "Concatenating dataframes with unknown divisions.\n"
final state: succeeded


  pd.set_option('display.max_colwidth', -1)


uid,iter,start,state,name,labels,inputs,parameters,results,artifacts
...1e6868,0,Feb 13 06:31:20,completed,splitter-labelencode,host=user-task-data-splits-8s246kind=jobowner=admin,,"categories=['UniqueCarrier', 'Origin', 'Dest']dask_client=scheduler.jsondask_key=airlineslabel_column=ArrDelayrandom_state=1target_path=/User/repos/demos/dask/artifacts",,headertest_set


to track results use .show() or .logs() or in CLI: 
!mlrun get run bcdc2146794046688d13960dd21e6868  , !mlrun logs bcdc2146794046688d13960dd21e6868 
[mlrun] 2020-02-13 06:37:30,168 run executed, status=completed


## tests

In [5]:
import dask
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

In [6]:
# Connect to the running Dask cluster through its scheduler file. The path is
# assembled from the notebook parameters (TARGET_PATH, DASK_CLIENT) rather than
# repeated as a literal, so it stays in sync with what the splitter run was given.
client = Client(scheduler_file=os.path.join(TARGET_PATH, DASK_CLIENT))


lz4
+--------------------------+---------+
|                          | version |
+--------------------------+---------+
| client                   | 3.0.2   |
| scheduler                | 2.2.1   |
| tcp://10.233.64.52:41923 | 2.2.1   |
| tcp://10.233.64.53:33822 | 2.2.1   |
| tcp://10.233.64.54:38139 | 2.2.1   |
| tcp://10.233.64.55:34344 | 2.2.1   |
| tcp://10.233.64.56:38055 | 2.2.1   |
| tcp://10.233.64.57:36040 | 2.2.1   |
| tcp://10.233.64.58:35049 | 2.2.1   |
| tcp://10.233.64.59:43970 | 2.2.1   |
+--------------------------+---------+

msgpack
+--------------------------+---------+
|                          | version |
+--------------------------+---------+
| client                   | 0.6.1   |
| scheduler                | 0.6.2   |
| tcp://10.233.64.52:41923 | 0.6.2   |
| tcp://10.233.64.53:33822 | 0.6.2   |
| tcp://10.233.64.54:38139 | 0.6.2   |
| tcp://10.233.64.55:34344 | 0.6.2   |
| tcp://10.233.64.56:38055 | 0.6.2   |
| tcp://10.233.64.57:36040 | 0.6.2   |
| tcp://10.

In [7]:
# The cluster should hold exactly the source dataset plus the train/validation
# splits (presumably published by the splitter run -- confirm against the
# handler). Compare as sets so the check does not depend on the container type
# or ordering that list_datasets() returns.
expected_datasets = {DASK_KEY, 'xtrain', 'xvalid', 'ytrain', 'yvalid'}
published_datasets = set(client.list_datasets())
assert published_datasets == expected_datasets, (
    'unexpected datasets on cluster: {} (expected {})'.format(
        sorted(published_datasets), sorted(expected_datasets)))

In [8]:
# Display the names of the datasets currently published on the cluster.
client.list_datasets()

('airlines', 'xtrain', 'xvalid', 'ytrain', 'yvalid')