In [None]:
import argparse
from glob import glob
from pprint import pprint

import dask_cudf
import numpy as np
import yaml
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

In [None]:
# Project root (absolute path on the notebook host) and the location of the
# YAML config file relative to that root.
cur_dir = '/home/jovyan/work/projects/COSME'
config_subdir = 'configs/make_kmers_config.yaml'

In [None]:
# Full path of the config file (a file path, despite the `_dir` suffix).
config_dir = f"{cur_dir}/{config_subdir}"

In [None]:
print(f"loading yaml file...")
# Parse the YAML into a dict: the cells below index `config` by key, which a raw
# text read cannot support. safe_load avoids constructing arbitrary Python objects,
# and the `with` block closes the file handle.
with open(config_dir, 'r') as config_file:
    config = yaml.safe_load(config_file)
pprint(config)

In [None]:
# --- input / output locations ---
in_dir = config['in_dir']
in_data_splits_config_file = config['in_data_splits_config_file']
out_dir = config['out_dir']
unique_classes_data_path = config['unique_classes_data_path']

# --- class selection and sampling ---
number_of_classes = config['number_of_classes']
size_per_class = config['size_per_class']
rand_seed = config['rand_seed']
do_rand_seed = config['do_rand_seed']

# --- column names and unknown-class handling ---
tgt_col = config['tgt_col']
inp_col = config['inp_col']
do_unknown_class = config['do_unknown_class']
name_for_unknown_class = config['name_for_unknown_class']

# --- Dask / GPU settings ---
CUDA_VISIBLE_DEVICES = config['CUDA_VISIBLE_DEVICES']
do_cuda_vis_dev = config['do_cuda_vis_dev']
partition_size = config['partition_size']

In [None]:
# Keep the configured per-class size so it can be restored for every split.
original_size_per_class = config['size_per_class']
# Each parquet entry directly under in_dir is treated as one data split.
split_files = glob(f"{in_dir}/*.parquet")
# File name of each split: the text after the last '/' of its path.
split_names = [path.rsplit('/', 1)[-1] for path in split_files]

In [None]:
# Parse the splits config into a dict; a raw text read would leave a str that
# cannot be indexed by key.
# NOTE(review): assumes this file is YAML like the main config - confirm.
with open(in_data_splits_config_file, 'r') as splits_file:
    dsplit_config = yaml.safe_load(splits_file)
# Per-split values, looked up later by split name (including 'train') to scale
# the per-class sample size.
splits = dsplit_config['splits']

In [None]:
print(f"train size per class: {size_per_class}")
# Drop the configured seed when seeding is switched off.
rand_seed = rand_seed if do_rand_seed else None

In [None]:
print(f"starting Dask GPU cluster...")
# Options common to both cluster configurations.
cluster_kwargs = {
    "protocol": "ucx",
    "enable_tcp_over_ucx": True,
    "local_directory": '/tmp/dask',
}
# Pass the configured device list only when the config asks for it.
if do_cuda_vis_dev:
    cluster_kwargs["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
cluster = LocalCUDACluster(**cluster_kwargs)
client = Client(cluster)

In [None]:
print(f"loading classes file...")
# Load the candidate class labels as a flat 1-D NumPy array. `.values` is a
# property (calling it raises TypeError), and np.random.choice only accepts 1-D
# input, so the frame's values are flattened.
# NOTE(review): assumes the file holds a single column of unique class labels - confirm.
selected_classes = (
    dask_cudf.read_parquet(unique_classes_data_path)
    .compute()
    .to_pandas()
    .values
    .ravel()
)
total_classes = len(selected_classes)
print(f"randomly selecting {number_of_classes} of {total_classes} classes")
# Seed NumPy's global RNG so the selection is repeatable when requested.
if do_rand_seed:
    np.random.seed(rand_seed)
# Draw without replacement so no class is picked twice.
selected_classes = np.random.choice(selected_classes, number_of_classes, replace=False)
# Keep an untouched copy; the per-split loop rebinds selected_classes.
original_selected_classes = selected_classes.copy()

In [None]:
def add_unknown_class(df):
    """Relabel rows whose target is not a selected class.

    Rows of ``df`` whose ``tgt_col`` value is not in the module-level
    ``selected_classes`` get ``name_for_unknown_class`` written into
    ``tgt_col``. The frame is modified in place and also returned.
    """
    not_selected = ~df[tgt_col].isin(selected_classes)
    df.loc[not_selected, tgt_col] = name_for_unknown_class
    return df

In [None]:
def random_select(df, size_per_class):
    """Randomly keep at most ``size_per_class`` rows of ``df``.

    Sampling is without replacement and uses the module-level ``rand_seed``
    as ``random_state``.

    Args:
        df: Frame to sample from (one partition when used via map_partitions).
        size_per_class: Upper bound on the number of rows to keep.

    Returns:
        The sampled frame; an empty ``df`` is returned unchanged.
    """
    temp_row_cnt = df.shape[0]
    # Empty frames would otherwise divide by zero when computing the fraction.
    if temp_row_cnt == 0:
        return df
    cur_sample_amt = min(size_per_class, temp_row_cnt)
    # Express the row budget as a fraction of this frame's rows.
    keep_frac = float(cur_sample_amt / temp_row_cnt)
    return df.sample(frac=keep_frac, replace=False, random_state=rand_seed)

In [None]:
# Build one down-sampled parquet dataset per input split file.
for in_file, split_file_name in zip(split_files, split_names):
    # The split name is the file name up to its first '.'.
    cur_split = split_file_name.split('.')[0]
    cur_perc = splits[cur_split]

    # Start every split from the originally selected classes and configured size.
    selected_classes = original_selected_classes.copy()
    size_per_class = original_size_per_class
    if cur_split != 'train':
        # Scale the size by this split's value relative to the train split's.
        size_per_class = int(round(size_per_class / splits['train'] * cur_perc))

    df = dask_cudf.read_parquet(in_file, partition_size=partition_size)

    if do_unknown_class:
        print(f"adding unknown class name to list")
        selected_classes = np.sort(np.append(name_for_unknown_class, selected_classes))
        df = df.map_partitions(add_unknown_class)
    else:
        print(f"sorting selected class names")
        selected_classes = np.sort(selected_classes)

    # Shuffle on the target column, then write the result to a scratch dataset
    # and read it back before sampling.
    df = df.shuffle(tgt_col, ignore_index=True)

    # overwrite=True clears the scratch directory first, so part files left by
    # a previous split (or a previous run) cannot be read back into this one.
    _ = df.to_parquet('/tmp/df.parquet', overwrite=True)

    # read_parquet is a module-level function, not a DataFrame method.
    df = dask_cudf.read_parquet('/tmp/df.parquet', partition_size=partition_size)

    # Sample each partition down, then discard the old index.
    df = df.map_partitions(random_select, size_per_class).reset_index(drop=True)

    _ = df.to_parquet(f"{out_dir}/{cur_split}.parquet")

    client.cancel(df)

In [None]:
# Stop the Dask client now that every split has been written.
print("shutting down Dask client")
client.shutdown()
print("finished")