
s3 integration, custom snippet func (#9)
* Add s3 integration everywhere (spark may still need some work)
* Option for user supplied custom func
birdsarah committed Apr 26, 2019
1 parent f31ba31 commit f67c034
Showing 2 changed files with 54 additions and 30 deletions.
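For orientation, here is a minimal usage sketch of the S3-backed flow this commit enables. The class name, import path, and config file below are assumptions; only the USE_AWS, INPUT_PARQUET_LOCATION, and DYESCORE_RESULTS_DIR config keys and the s3_storage_options property are visible in the diff.

from dye_score import DyeScore  # assumed import path for the class defined in dye_score.py

# config.yaml (assumed) sets USE_AWS: true, the S3 credentials consumed by
# s3_storage_options, and s3:// paths for INPUT_PARQUET_LOCATION and DYESCORE_RESULTS_DIR.
ds = DyeScore('config.yaml')

ds.validate_input_data()                                # reads the input parquet via from_parquet_opts
snippet_path = ds.build_raw_snippet_df(override=True)   # writes raw snippets back to S3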
83 changes: 53 additions & 30 deletions dye_score/dye_score.py
@@ -14,6 +14,7 @@
read_csv as pd_read_csv,
)
from pprint import pprint
from s3fs import S3FileSystem, S3Map
from xarray import (
apply_ufunc,
DataArray,
@@ -105,6 +106,10 @@ def __init__(self, config_file_path, validate_config=True, print_config=True):
pprint(self.__conf)
if validate_config is True:
self.validate_config()
if self.config('USE_AWS') is True:
self.s3 = S3FileSystem(**self.s3_storage_options)
else:
self.s3 = None

@property
def s3_storage_options(self):
@@ -130,6 +135,19 @@ def to_parquet_opts(self):
compression='snappy', engine='pyarrow', storage_options=self.s3_storage_options
)

@property
def from_parquet_opts(self):
"""Options used when saving to parquet."""
return dict(
engine='pyarrow', storage_options=self.s3_storage_options
)

def get_zarr_store(self, file_path):
"""Return a zarr store for file_path: an S3Map when USE_AWS is enabled, otherwise the path itself."""
if self.config('USE_AWS') is True:
return S3Map(root=file_path, s3=self.s3)
else:
return file_path
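get_zarr_store is what lets the zarr outputs below live on S3: S3Map exposes an S3 prefix as the MutableMapping interface that xarray's to_zarr and open_zarr expect, while the local case stays a plain path. A standalone sketch of the pattern, with a hypothetical bucket and prefix:

from s3fs import S3FileSystem, S3Map
from xarray import open_zarr

s3 = S3FileSystem(anon=False)  # credentials resolved from the environment
store = S3Map(root='example-bucket/snippets.zarr', s3=s3)  # hypothetical bucket/prefix

# some_dataset.to_zarr(store=store)    # write a dataset directly to S3
arr = open_zarr(store=store)['data']   # read it back as a DataArray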

def config(self, option):
"""Method to retrieve config values
@@ -169,7 +187,7 @@ def dye_score_data_file(self, filename):
def validate_input_data(self):
"""Checks for expected columns and types in input data."""
in_file = self.config('INPUT_PARQUET_LOCATION')
df = read_parquet(in_file, engine='pyarrow')
df = read_parquet(in_file, **self.from_parquet_opts)
for column in self.dye_score_columns:
assert column in df.columns, f'{column} missing from df.columns ({df.columns})'
assert df[column].dtype == 'object', f'{column} does not have dtype `object`'
@@ -186,7 +204,7 @@ def get_input_df(self, columns=None):
if not columns:
columns = self.dye_score_columns
in_file = self.config('INPUT_PARQUET_LOCATION')
df = read_parquet(in_file, columns=columns, engine='pyarrow')
df = read_parquet(in_file, columns=columns, **self.from_parquet_opts)
return df

##
@@ -196,32 +214,42 @@ def get_input_df(self, columns=None):
# whenever we need to leverage its superior performance handling strings.
##

@staticmethod
def file_in_validation(inpath):
def file_in_validation(self, inpath):
"""Check path exists.
Raises ValueError if not. Used for input files, as these must exist to proceed.
Args:
inpath (str): Path of input file
"""
if not os.path.exists(inpath):
if self.config('USE_AWS') is True:
exists = self.s3.exists(inpath)
else:
exists = os.path.exists(inpath)
if not exists:
raise ValueError(f'File {inpath} does not exist. Cannot proceed.')

@staticmethod
def file_out_validation(outpath, override):
def file_out_validation(self, outpath, override):
"""Check path exists.
Raises ValueError if override is False. Otherwise removes the existing file.
Args:
outpath (str): Path of output file.
override (bool): Whether to raise an error or remove existing data.
"""
if os.path.exists(outpath) and override is False:
if self.config('USE_AWS') is True:
exists = self.s3.exists(outpath)
else:
exists = os.path.exists(outpath)

if exists and override is False:
raise ValueError(f'File {outpath} already exists. Use `override=True` to remove and replace.')
if os.path.exists(outpath) and override is True:
if exists and override is True:
print(f'Removing existing file {outpath}')
shutil.rmtree(outpath)
if self.config('USE_AWS') is True:
self.s3.rm(outpath, recursive=True)
else:
shutil.rmtree(outpath)

def build_raw_snippet_df(self, override=False):
def build_raw_snippet_df(self, override=False, snippet_func=None):
"""Builds raw_snippets from input data
The default snippet function is ``script_url.netloc||script_url.path_end||func_name``
@@ -233,16 +261,16 @@
Returns:
str. The file path where output is saved
"""
# TODO Add add an issue to supply user generated snippet code

# File setup
inpath = self.config('INPUT_PARQUET_LOCATION')
outpath = self.dye_score_data_file('raw_snippet_call_df')
self.file_in_validation(inpath)
self.file_out_validation(outpath, override)
# Process
df = read_parquet(inpath, columns=self.dye_score_columns, engine='pyarrow')
df['raw_snippet'] = df.apply(get_raw_snippet_from_row, axis=1, meta='O')
if not snippet_func:
snippet_func = get_raw_snippet_from_row
df = read_parquet(inpath, columns=self.dye_score_columns, **self.from_parquet_opts)
df['raw_snippet'] = df.apply(snippet_func, axis=1, meta='O')
df['called'] = 1
print(df.head())
df.to_parquet(outpath, **self.to_parquet_opts)
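The new snippet_func parameter accepts any callable that maps a row of the input frame to a string, in place of the default get_raw_snippet_from_row. A minimal sketch, assuming the input data carries the script_url and func_name columns referenced by the default snippet format; the function below is hypothetical:

from urllib.parse import urlparse

def netloc_and_func_snippet(row):
    # Hypothetical custom snippet: identify code by script host and function
    # name only, dropping the path component used by the default snippet.
    return f'{urlparse(row.script_url).netloc}||{row.func_name}'

# ds.build_raw_snippet_df(override=True, snippet_func=netloc_and_func_snippet)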
@@ -260,15 +288,13 @@ def build_snippet_map(self, override=False):
Returns:
str. The file path where output is saved
"""
# TODO File an issue - do we have a problem with duplicate snippets?

# File setup
inpath = self.dye_score_data_file('raw_snippet_call_df')
outpath = self.dye_score_data_file('raw_snippet_to_snippet_lookup')
self.file_in_validation(inpath)
self.file_out_validation(outpath, override)
# Process
df = read_parquet(inpath, columns=['raw_snippet'], engine='pyarrow')
df = read_parquet(inpath, columns=['raw_snippet'], **self.from_parquet_opts)
snippet_lookup = df.raw_snippet.unique().to_frame()
snippet_lookup['snippet'] = snippet_lookup.raw_snippet.apply(lambda x: hash(x), meta='int64')
print(snippet_lookup.head())
@@ -306,8 +332,6 @@ def build_snippets(self, spark, override=False):
Returns:
str. The file path where output is saved
"""
# TODO Get an issue to run everything on S3

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# File setup
@@ -345,7 +369,7 @@
}
)
print(row_normalize_array)
row_normalize_array.to_dataset(name='data').to_zarr(store=outpath)
row_normalize_array.to_dataset(name='data').to_zarr(store=self.get_zarr_store(outpath))
# Cleanup
shutil.rmtree(tmp)
return outpath
@@ -413,7 +437,7 @@ def compute_distances_for_dye_snippets(self, dye_snippets, filename_suffix='dye_
self.file_out_validation(outpath, override)

# Process distances
df = open_zarr(store=snippet_file)['data']
df = open_zarr(store=self.get_zarr_store(snippet_file))['data']
df = df.chunk({'symbol': -1})
df_c = df.chunk({'snippet': 10_000})

@@ -429,7 +453,7 @@
input_core_dims=[['symbol'], ['symbol']],
)
print(distance_array)
distance_array.to_dataset(name='data').to_zarr(store=outpath)
distance_array.to_dataset(name='data').to_zarr(store=self.get_zarr_store(outpath))
return outpath

def compute_snippets_scores_for_thresholds(self, thresholds, filename_suffix='dye_snippets', override=False):
@@ -450,9 +474,8 @@ def compute_snippets_scores_for_thresholds(self, thresholds, filename_suffix='dy
file_name = f'snippets_dye_distances_from_{filename_suffix}'
inpath = os.path.join(resultsdir, file_name)
self.file_in_validation(inpath)
distance_array = open_zarr(store=inpath)['data']
distance_array = open_zarr(store=self.get_zarr_store(inpath))['data']

# TODO Make issue to not hard code this
LEAKY_THRESHOLD = 0.2
n_sites = distance_array.shape[0]
N_LEAKY_THRESHOLD = LEAKY_THRESHOLD * n_sites
@@ -488,7 +511,7 @@ def compute_dye_scores_for_thresholds(self, thresholds, filename_suffix='dye_sni
list. Paths results were written to
"""
snippet_dyeing_map_file = self.dye_score_data_file('snippet_dyeing_map')
snippet_data = read_parquet(snippet_dyeing_map_file, engine='pyarrow')
snippet_data = read_parquet(snippet_dyeing_map_file, **self.from_parquet_opts)
resultsdir = self.config('DYESCORE_RESULTS_DIR')

outpaths = []
@@ -498,11 +521,11 @@ def compute_dye_scores_for_thresholds(self, thresholds, filename_suffix='dye_sni
self.file_in_validation(inpath)
self.file_out_validation(outpath, override)

site_counts_df = read_parquet(inpath, engine='pyarrow')
site_counts_df = read_parquet(inpath, **self.from_parquet_opts)
script_to_dye = snippet_data.merge(site_counts_df, on='snippet')
script_to_dye_max = script_to_dye[['clean_script', 'dye_count']].groupby('clean_script').max()
script_to_dye_max = script_to_dye_max.rename(columns={'dye_count': 'dye_score'})
script_to_dye_max.compute().to_csv(outpath, compression='gzip')
script_to_dye_max.compute().to_csv(outpath, compression='gzip', storage_options=self.s3_storage_options)
outpaths.append(outpath)
return outpaths

@@ -547,8 +570,8 @@ def build_plot_data_for_thresholds(self, compare_list, thresholds, filename_suff
outpath = os.path.join(resultsdir, f'dye_score_plot_data_from_{filename_suffix}_{threshold}.csv.gz')
self.file_in_validation(inpath)
self.file_out_validation(outpath, override)
dye_score_df = pd_read_csv(inpath)
dye_score_df = pd_read_csv(inpath, storage_options=self.s3_storage_options)
plot_df = self._build_plot_data_for_score_df(dye_score_df, compare_list)
plot_df.to_csv(outpath, compression='gzip', index=False)
plot_df.to_csv(outpath, compression='gzip', index=False, storage_options=self.s3_storage_options)
outpaths.append(outpath)
return outpaths
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
'pyyaml>=4.2b1',
'xarray>=0.12.0',
'zarr>=2.2.0',
's3fs>=0.2.1',
]

setup_requirements = ['pytest-runner', ]
