In [None]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
import dask_cudf
import argparse
import yaml
from pprint import pprint

In [None]:
# Project root on the notebook host.
cur_dir = "/home/jovyan/work/projects/COSME"
# Location of this notebook's YAML configuration, appended to the project root.
config_subdir = "/configs/convert_labels_to_species_config.yaml"

In [None]:
# Full path of the YAML configuration file. Slashes are normalised at the join
# so that a trailing "/" on cur_dir or a leading "/" on config_subdir does not
# produce a doubled separator ("...COSME//configs/...").
config_dir = f"{cur_dir.rstrip('/')}/{config_subdir.lstrip('/')}"

In [None]:
# Load the YAML configuration; later cells look settings up by key.
print("loading yaml file...")
# `with` closes the file handle; safe_load parses the text without
# instantiating arbitrary Python objects.
with open(config_dir, 'r') as config_file:
    config = yaml.safe_load(config_file)
# An empty or non-mapping YAML document would otherwise fail later with an
# unhelpful TypeError when the settings are indexed by key.
if not isinstance(config, dict):
    raise ValueError(f"expected a mapping at the top level of {config_dir}")
pprint(config)

In [None]:
# --- file locations ---
in_file = config["in_file"]                        # parquet input read by the pipeline
out_file = config["out_file"]                      # parquet output with the rewritten label column
unq_label_out_file = config["unq_label_out_file"]  # parquet output holding the unique labels

# --- label extraction ---
label_col_name = config["label_col_name"]          # column whose values are replaced
label_regex = config["label_regex"]                # pattern handed to .str.extract

# --- cluster / partitioning ---
CUDA_VISIBLE_DEVICES = config["CUDA_VISIBLE_DEVICES"]  # forwarded to LocalCUDACluster
do_cuda_vis_dev = config["do_cuda_vis_dev"]            # whether to forward CUDA_VISIBLE_DEVICES
partition_size = config["partition_size"]              # forwarded to dask_cudf.read_parquet

In [None]:
print("starting Dask GPU cluster...")
# Arguments common to both configurations.
cluster_kwargs = {
    "protocol": "ucx",
    "enable_tcp_over_ucx": True,
    "local_directory": '/tmp',
}
# Only pass an explicit device list when the configuration asks for it.
if do_cuda_vis_dev:
    cluster_kwargs["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
cluster = LocalCUDACluster(**cluster_kwargs)
client = Client(cluster)

In [None]:
print(f"reading file {in_file}")
# Create the Dask dataframe from the raw input file, forwarding the
# configured partition size to the reader.
df = dask_cudf.read_parquet(in_file, partition_size=partition_size)

In [None]:
def extract_labels(df):
    """Overwrite the label column with the text captured by ``label_regex``.

    Uses the notebook-level ``label_col_name`` and ``label_regex``. The
    column is assigned on the dataframe that was passed in, and that same
    object is returned.
    """
    captured = df[label_col_name].str.extract(label_regex)
    # Keep the column labelled 0 of the extract result (the first capture
    # group when the pattern's groups are unnamed).
    df[label_col_name] = captured.loc[:, 0]
    return df

In [None]:
# Apply the label extraction to every partition.
print("extracting labels...")
df = df.map_partitions(extract_labels)

# Final step: persist the cleaned data. The return value is bound to `_`
# so the cell shows no output.
print(f"saving data to {out_file}")
_ = df.to_parquet(out_file)

In [2]:
print("creating unique labels...")
# Discard the current dataframe and re-read the data just written to out_file.
del df
df = dask_cudf.read_parquet(out_file)
# Distinct label values as a single-column frame with the old index dropped.
unq_labels = df[label_col_name].unique()
unq_labs_df = unq_labels.to_frame().reset_index(drop=True)
print(f"saving data to {unq_label_out_file}")
_ = unq_labs_df.to_parquet(unq_label_out_file)

In [None]:
# Shut the Dask client down now that all outputs have been written.
print("shutting down Dask client")
client.shutdown()
print("finished")