Skip to content

Commit

Permalink
Simplify count functions with a decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
tavinathanson committed Nov 1, 2016
1 parent da968fb commit 2d2be79
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 110 deletions.
127 changes: 65 additions & 62 deletions cohorts/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,100 @@
from .variant_filters import no_filter, effect_expressed_filter
from .utils import first_not_none_param

from functools import wraps
import numpy as np
from varcode.effects import Substitution
from varcode.common import memoize

def snv_count(row, cohort, filter_fn=None,
normalized_per_mb=None, **kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
def cohort_defaults(func):
@wraps(func)
def wrapper(row, cohort,
filter_fn=None, normalized_per_mb=None,
**kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
return func(row=row,
cohort=cohort,
filter_fn=filter_fn,
normalized_per_mb=normalized_per_mb,
**kwargs)
return wrapper

def cohort_counter(func):
@wraps(func)
def wrapper(row, cohort, filter_fn=None, normalized_per_mb=None, **kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
per_patient_data = func(row=row,
cohort=cohort,
filter_fn=filter_fn,
normalized_per_mb=normalized_per_mb,
**kwargs)
patient_id = row["patient_id"]
if patient_id in per_patient_data:
count = len(per_patient_data[patient_id])
if normalized_per_mb:
count /= float(get_patient_to_mb(cohort)[patient_id])
return count
return np.nan
return wrapper

@memoize
def get_patient_to_mb(cohort):
patient_to_mb = dict(cohort.as_dataframe(join_with="ensembl_coverage")[["patient_id", "MB"]].to_dict("split")["data"])
return patient_to_mb

@cohort_counter
def snv_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
patient_id = row["patient_id"]
patient_variants = cohort.load_variants(
return cohort.load_variants(
patients=[cohort.patient_from_id(patient_id)],
filter_fn=filter_fn,
**kwargs)
if patient_id in patient_variants:
count = len(patient_variants[patient_id])
if normalized_per_mb:
count /= float(get_patient_to_mb(cohort)[patient_id])
return count
return np.nan

def nonsynonymous_snv_count(row, cohort, filter_fn=None,
normalized_per_mb=None, **kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
patient_id = row["patient_id"]
@cohort_counter
def nonsynonymous_snv_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
# This only loads one effect per variant.
patient_nonsynonymous_effects = cohort.load_effects(
patient_id = row["patient_id"]
return cohort.load_effects(
only_nonsynonymous=True,
patients=[cohort.patient_from_id(patient_id)],
filter_fn=filter_fn,
**kwargs)
if patient_id in patient_nonsynonymous_effects:
count = len(patient_nonsynonymous_effects[patient_id])
if normalized_per_mb:
count /= float(get_patient_to_mb(cohort)[patient_id])
return count
return np.nan

def missense_snv_count(row, cohort, filter_fn=None,
normalized_per_mb=None, **kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
patient_id = row["patient_id"]
@cohort_counter
def missense_snv_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
def missense_filter_fn(filterable_effect):
assert filter_fn is not None, "filter_fn should never be None, but it is."
return (type(filterable_effect.effect) == Substitution and
filter_fn(filterable_effect))
# This only loads one effect per variant.
patient_missense_effects = cohort.load_effects(
patient_id = row["patient_id"]
return cohort.load_effects(
only_nonsynonymous=True,
patients=[cohort.patient_from_id(patient_id)],
filter_fn=missense_filter_fn,
**kwargs)
if patient_id in patient_missense_effects:
count = len(patient_missense_effects[patient_id])
if normalized_per_mb:
count /= float(get_patient_to_mb(cohort)[patient_id])
return count
return np.nan

def neoantigen_count(row, cohort, filter_fn=None,
normalized_per_mb=None, **kwargs):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
patient_id = row["patient_id"]
@cohort_counter
def neoantigen_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
patient = cohort.patient_from_id(row["patient_id"])
patient_neoantigens = cohort.load_neoantigens(patients=[patient],
filter_fn=filter_fn,
**kwargs)
if patient_id in patient_neoantigens:
patient_neoantigens_df = patient_neoantigens[patient_id]
count = len(patient_neoantigens_df)
if normalized_per_mb:
count /= float(get_patient_to_mb(cohort)[patient_id])
return count
return np.nan
return cohort.load_neoantigens(patients=[patient],
filter_fn=filter_fn,
**kwargs)

def expressed_missense_snv_count(row, cohort, filter_fn=None,
normalized_per_mb=None):
filter_fn = first_not_none_param([filter_fn, cohort.filter_fn], no_filter)
normalized_per_mb = first_not_none_param([normalized_per_mb, cohort.normalized_per_mb], False)
@cohort_defaults
def expressed_missense_snv_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
def expressed_filter_fn(filterable_effect):
assert filter_fn is not None, "filter_fn should never be None, but it is."
return filter_fn(filterable_effect) and effect_expressed_filter(filterable_effect)
return missense_snv_count(row, cohort, filter_fn=expressed_filter_fn,
return missense_snv_count(row=row,
cohort=cohort,
filter_fn=expressed_filter_fn,
normalized_per_mb=normalized_per_mb)

def expressed_neoantigen_count(row, cohort, **kwargs):
return neoantigen_count(row, cohort, only_expressed=True, **kwargs)

@memoize
def get_patient_to_mb(cohort):
patient_to_mb = dict(cohort.as_dataframe(join_with="ensembl_coverage")[["patient_id", "MB"]].to_dict("split")["data"])
return patient_to_mb
@cohort_defaults
def expressed_neoantigen_count(row, cohort, filter_fn, normalized_per_mb, **kwargs):
return neoantigen_count(row=row,
cohort=cohort,
only_expressed=True, **kwargs)
40 changes: 0 additions & 40 deletions test/functions.py

This file was deleted.

8 changes: 5 additions & 3 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from . import data_path, generated_data_path, DATA_DIR
from .data_generate import generate_vcfs
from .functions import *

from cohorts import Cohort, Patient
from cohorts.utils import InvalidDataError
Expand All @@ -36,7 +35,8 @@ def make_simple_clinical_dataframe(
"deceased": [True, False, False] if deceased_list is None else deceased_list,
"progressed_or_deceased": [True, True, False] if progressed_or_deceased_list is None else progressed_or_deceased_list})

def make_simple_cohort(merge_type="union", **kwargs):
def make_simple_cohort(merge_type="union",
**kwargs):
clinical_dataframe = make_simple_clinical_dataframe(**kwargs)
patients = []
for i, row in clinical_dataframe.iterrows():
Expand All @@ -49,11 +49,13 @@ def make_simple_cohort(merge_type="union", **kwargs):
)
patients.append(patient)

return Cohort(
Cohort.normalized_per_mb = False
cohort = Cohort(
patients=patients,
responder_pfs_equals_os=True,
merge_type=merge_type,
cache_dir=generated_data_path("cache"))
return cohort

def test_pfs_equal_to_os():
# Should not error
Expand Down

0 comments on commit 2d2be79

Please sign in to comment.