# split data into train, validation and test sets

In [1]:
import mlrun
import os
import numpy as np
# Point the MLRun client at the API service. An address that is already
# configured is kept (presumably supplied through the MLRUN_DBPATH environment
# variable -- confirm); the in-cluster default is used only when none is set.
mlrun.mlconf.dbpath = mlrun.mlconf.dbpath or 'http://mlrun-api:8080'

## parameters

**Please be sure to run the notebooks [1. remote archive to local parquet](1.%20remote%20archive%20to%20local%20parquet.ipynb) and [2. parquet to dask cluster](2.%20parquet%20to%20dask%20cluster.ipynb) before running this one.**

Since our data is already loaded into a Dask cluster, we use that as our source.

In [2]:
# --- function identity ---
FUNCTION = 'train_valid_test'
DESCRIPTION = 'split data into train, validation and test splits'

# --- runtime: container image, MLRun function kind and task name ---
IMAGE = 'yjbds/mlrun-dask:dev'
JOB_KIND = 'job'
TASK_NAME = 'user-task-data-splits'

# --- data: values forwarded to the splitter as task parameters ---
TARGET_PATH = '/User/repos/demos/dask/artifacts'
DASK_CLIENT = 'scheduler.json'
DASK_KEY = 'airlines'
LABEL_COLUMN = 'ArrDelay'
CATEGORIES = ['UniqueCarrier', 'Origin', 'Dest']

# replace with the uid of the relevant run from the MLRun db
MLRUN_DB_UID = '338465d3d0d940e181da7268404db66b'

# forwarded to the splitter as 'random_state'
RNG = 1

## split the data

In [None]:
# Create the MLRun function from the local splitter script.
splitter = mlrun.new_function(
    kind=JOB_KIND,
    image=IMAGE,
    command='/User/repos/demos/dask/code/train-valid-test-splitter.py',
)
splitter.spec.build.image = IMAGE

# Write the function definition to yaml. A previously exported definition
# can be loaded instead of creating the function above, e.g.:
# splitter = mlrun.import_function('/User/repos/dask/yaml/train_valid_test_splitter-airlines.yaml')
splitter.export('/User/repos/demos/dask/yaml/train-valid-test-splitter.yaml')

# Apply the v3io mount to the function, then deploy it.
# NOTE(review): skip_deployed=True presumably leaves an already deployed
# function untouched -- confirm against the MLRun docs.
splitter.apply(mlrun.mount_v3io())
splitter.deploy(skip_deployed=True, with_mlrun=False)

# Task definition: every parameter comes from the parameters cell.
task_ = mlrun.NewTask(
    TASK_NAME,
    params=dict(
        dask_client=DASK_CLIENT,
        dask_key=DASK_KEY,
        label_column=LABEL_COLUMN,
        categories=CATEGORIES,
        target_path=TARGET_PATH,
        random_state=RNG,
    ),
)

# Run the task through the script's 'train_valid_test_splitter' handler.
tsk2 = splitter.run(task_, handler='train_valid_test_splitter')

## tests

In [5]:
import dask
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

In [6]:
# Connect to the Dask cluster through its scheduler file. The path is built from
# TARGET_PATH and DASK_CLIENT (parameters cell) -- the same values handed to the
# splitter task as 'target_path' and 'dask_client' -- instead of repeating the
# literal, so the two cannot drift apart.
client = Client(scheduler_file=os.path.join(TARGET_PATH, DASK_CLIENT))

In [7]:
# fetch the dataset published on the cluster under the name 'ytrain'
df = client.get_dataset('ytrain')

In [8]:
# names of the datasets currently published on the cluster
published_datasets = client.list_datasets()
published_datasets

('airlines', 'xtrain', 'xvalid', 'ytrain', 'yvalid')

In [9]:
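# row count of the dataset, left disabled
# NOTE(review): presumably disabled because .compute() has to run the whole task graph -- confirm
# df.shape[0].compute()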

In [10]:
# array view of the 'ytrain' dataset; the output is a lazy summary (shape is reported as unknown)
df.values

Unnamed: 0,Array,Chunk
Bytes,unknown,unknown
Shape,"(nan,)","(nan,)"
Count,123540 Tasks,12354 Chunks
Type,bool,numpy.ndarray
"Array Chunk Bytes unknown unknown Shape (nan,) (nan,) Count 123540 Tasks 12354 Chunks Type bool numpy.ndarray",,

Unnamed: 0,Array,Chunk
Bytes,unknown,unknown
Shape,"(nan,)","(nan,)"
Count,123540 Tasks,12354 Chunks
Type,bool,numpy.ndarray
