Skip to content

Commit

Permalink
Parameterize standardization
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Dec 11, 2023
1 parent db96ad0 commit 87b82c6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
3 changes: 2 additions & 1 deletion inferelator_velocity/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def program_select(
n_programs=2,
n_comps=None,
layer="X",
count_layer=None,
normalize=True,
mcv_loss_arr=None,
n_jobs=-1,
Expand Down Expand Up @@ -109,7 +110,7 @@ def program_select(

# CREATE A NEW DATA OBJECT FOR THIS ANALYSIS #

d = copy_count_layer(data, layer)
d = copy_count_layer(data, layer, counts_layer=count_layer)

# PREPROCESSING / NORMALIZATION #

Expand Down
6 changes: 4 additions & 2 deletions inferelator_velocity/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def program_times(
program_var_key=PROGRAM_KEY,
programs=None,
n_comps=None,
verbose=False
verbose=False,
standardize_count_data=True
):
"""
Calcuate times for each cell based on known cluster time values
Expand Down Expand Up @@ -148,7 +149,8 @@ def program_times(
return_components=True,
verbose=verbose,
n_comps=n_comps if n_comps is None else n_comps[prog],
wrap_time=wrap_time[prog] if wrap_time is not None else None
wrap_time=wrap_time[prog] if wrap_time is not None else None,
normalize_data=standardize_count_data
)

# Add keys to the .uns object
Expand Down
25 changes: 21 additions & 4 deletions inferelator_velocity/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_bins(data, n_bins=None, centers=None, width=None):
return centers, half_width


def copy_count_layer(data, layer):
def copy_count_layer(data, layer, counts_layer=None):

lref = data.X if layer == 'X' else data.layers[layer]

Expand All @@ -177,9 +177,26 @@ def copy_count_layer(data, layer):
"these results will be nonsense."
)

d = ad.AnnData(lref.astype(float))
d.layers['counts'] = lref.copy()
d.var = data.var.copy()
d = ad.AnnData(
lref,
var=data.var
)

if counts_layer is None:
d.layers['counts'] = lref.copy()
else:
d.layers['counts'] = data.layers[counts_layer]

if not pat.is_integer_dtype(d.X.dtype):
warnings.warn(
"Count data is expected, "
f"but {d.layers['counts'].dtype} data has been passed. "
"This data will be normalized and processed "
"as count data. If it is not count data, "
"these results will be nonsense."
)

d.X = d.X.astype(float)

return d

Expand Down

0 comments on commit 87b82c6

Please sign in to comment.