In [None]:
import argparse
import os
from glob import glob
from pathlib import Path
from pprint import pprint

import numpy as np
import nvtabular as nvt
import yaml
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

In [None]:
# Project root. Set the COSME_DIR environment variable to point the notebook at a
# different checkout instead of editing this hard-coded default.
cur_dir = os.environ.get('COSME_DIR', '/home/jovyan/work/projects/COSME')
# Location of the preprocessing config, relative to the project root.
config_subdir = 'configs/make_nvtab_data_config.yaml'

In [None]:
# Full path to the YAML config: project root joined with the config's relative path.
config_dir = "/".join((cur_dir, config_subdir))

In [None]:
print("loading yaml file...")
# Parse the YAML into Python objects. The original read the raw text, which made
# every later config['...'] lookup fail. safe_load only builds plain data types.
with open(config_dir, 'r') as config_file:
    config = yaml.safe_load(config_file)
# Later cells index config by key, so anything other than a mapping (e.g. an
# empty file, which parses to None) is a configuration error.
if not isinstance(config, dict):
    raise ValueError(f"expected a YAML mapping in {config_dir}, got {type(config).__name__}")
pprint(config)

In [None]:
# Data / preprocessing settings, pulled out in a fixed order.
# A missing key raises KeyError at the first absent name.
(in_dir, out_dir, label_col_name,
 input_col_name, max_seq_len, row_group_size) = (
    config[key]
    for key in ('in_dir', 'out_dir', 'label_col_name',
                'input_col_name', 'max_seq_len', 'row_group_size')
)

# GPU selection settings for the Dask CUDA cluster.
CUDA_VISIBLE_DEVICES, do_cuda_vis_dev = (
    config['CUDA_VISIBLE_DEVICES'],
    config['do_cuda_vis_dev'],
)

In [None]:
# One parquet file per data split in in_dir. glob() returns files in arbitrary,
# filesystem-dependent order, so sort to make the processing order reproducible.
split_files = sorted(glob(f"{in_dir}/*.parquet"))
# Bare file names (the last '/'-separated component), aligned index-for-index with split_files.
split_names = [path.rsplit('/', 1)[-1] for path in split_files]

In [None]:
print("starting Dask GPU cluster...")
# Options shared by both launch modes: UCX protocol with TCP-over-UCX and a
# local spill/scratch directory.
cluster_kwargs = dict(
    protocol="ucx",
    enable_tcp_over_ucx=True,
    local_directory='/tmp/dask',
)
# Pin the worker GPUs only when the config asks for it. Otherwise the argument is
# omitted and LocalCUDACluster's own default applies.
if do_cuda_vis_dev:
    cluster_kwargs['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
cluster = LocalCUDACluster(**cluster_kwargs)
client = Client(cluster)

In [None]:
print("creating pipeline...")

# Input sequences: encode categories, then slice each list to max_seq_len, padding
# shorter lists. NOTE(review): pad_value=0 presumably overlaps the index Categorify
# reserves for nulls/padding in this NVTabular version — confirm.
cat_features = (
    [input_col_name]
    >> nvt.ops.Categorify()
    >> nvt.ops.ListSlice(0, end=max_seq_len, pad=True, pad_value=0.0)
)

# Labels are only category-encoded (no slicing).
lab_features = [label_col_name] >> nvt.ops.Categorify()

# Workflow output = feature columns plus the label column, executed on the Dask client.
output = cat_features + lab_features
workflow = nvt.Workflow(output, client=client)

# Shuffle rows within each output partition when writing parquet.
shuffle = nvt.io.Shuffle.PER_PARTITION

In [None]:
# Fitted workflow is saved here once, after fitting on the train split.
workflow_file = f"{out_dir}/workflow"

# The workflow must be fitted on 'train' before it transforms any other split, but
# glob order is arbitrary. Visit train first; sorted() is stable, so the remaining
# splits keep their existing relative order.
split_pairs = sorted(
    zip(split_names, split_files),
    key=lambda pair: pair[0].split('.')[0] != 'train',
)
if not any(name.split('.')[0] == 'train' for name, _ in split_pairs):
    raise FileNotFoundError(f"no 'train' parquet file found in {in_dir}: {split_names}")

for split_name, in_file in split_pairs:
    # Split name is the file name up to its first '.', e.g. "train.parquet" -> "train".
    cur_split = split_name.split('.')[0]
    cur_out_file = f"{out_dir}/{cur_split}.parquet"
    dataset = nvt.Dataset(in_file, engine='parquet', row_group_size=row_group_size)

    if cur_split == 'train':
        print("fitting nvtab workflow on training data...")
        workflow.fit(dataset)
        workflow.save(workflow_file)

    workflow.transform(dataset).to_parquet(
        output_path=cur_out_file,
        shuffle=shuffle,
        cats=[input_col_name],
        labels=[label_col_name],
    )

In [None]:
# Teardown. Per the Dask docs, Client.shutdown() also stops the scheduler and
# workers it is connected to.
print("shutting down Dask client")
client.shutdown()
print("finished")