diff --git a/cohorts/cohort.py b/cohorts/cohort.py index 8de412c..8ccd09e 100644 --- a/cohorts/cohort.py +++ b/cohorts/cohort.py @@ -48,7 +48,7 @@ from tqdm import tqdm from .dataframe_loader import DataFrameLoader -from .utils import DataFrameHolder, first_not_none_param, filter_not_null, InvalidDataError, strip_column_names as _strip_column_names, get_logger +from .utils import DataFrameHolder, first_not_none_param, filter_not_null, InvalidDataError, strip_column_names as _strip_column_names, get_logger, get_cache_dir from .provenance import compare_provenance from .survival import plot_kmf from .plot import mann_whitney_plot, fishers_exact_plot, roc_curve_plot, stripboxplot, CorrelationResults @@ -72,6 +72,10 @@ class Cohort(Collection): A list of `Patient`s for this cohort. cache_dir : str Path to store cached results, e.g. cached variant effects. + cache_root_dir : str + (optional) directory in which cache_dir should be created + cache_dir_kwargs : dict + (optional) dictionary of name=value data to use when formatting cache_dir str show_progress : bool Whether or not to show DataFrame application progress as an increasing percentage. kallisto_ensembl_version : int @@ -118,6 +122,8 @@ class Cohort(Collection): def __init__(self, patients, cache_dir, + cache_root_dir=None, + cache_dir_kwargs=dict(), show_progress=True, kallisto_ensembl_version=None, cache_results=True, @@ -146,7 +152,8 @@ def __init__(self, # this when patient-specific functions all live in Patient. for patient in patients: patient.cohort = self - self.cache_dir = cache_dir + self.cache_dir = get_cache_dir(cache_dir=cache_dir, cache_root_dir=cache_root_dir, **cache_dir_kwargs) + self.cache_root_dir = cache_root_dir self.show_progress = show_progress self.cache_results = cache_results self.kallisto_ensembl_version = kallisto_ensembl_version diff --git a/cohorts/utils.py b/cohorts/utils.py index 21da6eb..8e404a2 100644 --- a/cohorts/utils.py +++ b/cohorts/utils.py @@ -17,6 +17,29 @@ from collections import namedtuple import sys import logging +from os import path + +logger = logging.getLogger(__name__) + +def get_cache_dir(cache_dir, cache_root_dir=None, *args, **kwargs): + """ + Return full cache_dir, according to following logic: + - if cache_dir is a full path (per path.isabs), return that value + - if not and if cache_root_dir is not None, join two paths + - otherwise, log warnings and return None + Separately, if args or kwargs are given, format cache_dir using kwargs + """ + cache_dir = cache_dir.format(*args, **kwargs) + if path.isabs(cache_dir): + if cache_root_dir is not None: + logger.warning('cache_dir ({}) is a full path; ignoring cache_root_dir'.format(cache_dir)) + return cache_dir + if cache_root_dir is not None: + return path.join(cache_root_dir, cache_dir) + else: + logger.warning("cache dir is not full path & cache_root_dir not given. Caching may not work as expected!") + return None + class DataFrameHolder(namedtuple("DataFrameHolder", ["cols", "df"])): """Holds a DataFrame along with associated columns of interest."""