Skip to content

Commit

Permalink
Code cleanup & testing
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 10, 2023
1 parent 62cf925 commit dae432d
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 47 deletions.
108 changes: 78 additions & 30 deletions inferelator_velocity/plotting/program_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@
from matplotlib.ticker import FormatStrFormatter
import warnings

from inferelator_velocity.utils.keys import (
OBSM_PCA_KEY,
OBS_GROUP_KEY_KEY,
OBS_TIME_KEY_KEY,
OBSM_KEY_KEY,
N_COMP_SUBKEY,
CLUSTER_ORDER_SUBKEY,
SHORTEST_PATH_SUBKEY,
ASSIGNMENT_NAME_SUBKEY,
ASSIGNMENT_PATH_SUBKEY,
CENTROID_SUBKEY,
CLOSEST_ASSIGNMENT_SUBKEY,
ASSIGNMENT_CENTROID_SUBKEY
)

from inferelator_velocity.times import _wrap_time

DEFAULT_CMAP = 'plasma'


Expand Down Expand Up @@ -84,19 +101,22 @@ def program_time_summary(
if ax_key_pref is None:
ax_key_pref = ''

uns_key = f"program_{program}_pca"
uns_key = OBSM_PCA_KEY.format(prog=program)
if uns_key not in adata.uns:
raise RuntimeError(
f"Unable to find program {program} in .uns[{uns_key}]. "
"Run program_times() before calling plotter."
)

obs_group_key = adata.uns[uns_key]['obs_group_key']
obs_time_key = adata.uns[uns_key]['obs_time_key']
obsm_key = adata.uns[uns_key]['obsm_key']
uns = adata.uns[uns_key]

obs_group_key = uns[OBS_GROUP_KEY_KEY]
obs_time_key = uns[OBS_TIME_KEY_KEY]
obsm_key = uns[OBSM_KEY_KEY]

# GET CLUSTER-CLUSTER PATH LABELS ####
_panels = adata.uns[uns_key]['assignment_names']
_panels = uns[ASSIGNMENT_NAME_SUBKEY]
n_comps = uns[N_COMP_SUBKEY]
n = len(_panels)

if panel_tags is None:
Expand All @@ -108,7 +128,7 @@ def program_time_summary(

# SET UP COLORMAPS IF NOT PROVIDED ####
if cluster_order is None:
cluster_order = adata.uns[uns_key]['cluster_order']
cluster_order = uns[CLUSTER_ORDER_SUBKEY]

if cbar_cmap is None:
cbar_cmap = DEFAULT_CMAP
Expand All @@ -128,7 +148,12 @@ def program_time_summary(
# SET UP FIGURE IF NOT PROVIDED ####
if ax is None:

_layout = [['pca1'], ['pca2'], ['hist'], ['cbar']]
if n_comps >= 3:
_layout = [['pca1'], ['pca2'], ['hist'], ['cbar']]
elif n_comps == 2:
_layout = [['pca1'], ['.'], ['hist'], ['cbar']]
else:
_layout = [['.'], ['.'], ['hist'], ['cbar']]

# IF THERE ARE 4 OR FEWER CLUSTER-CLUSTER PAIRS, DRAW 4x2 ####
if n < 5:
Expand Down Expand Up @@ -193,15 +218,18 @@ def program_time_summary(
refs = {}

# IDENTIFY CLUSTER CENTROIDS ####
_centroids = [adata.uns[uns_key]['centroids'][x] for x in cluster_order]
_centroids = [
uns[CENTROID_SUBKEY][x]
for x in cluster_order
]

if wrap_time is not None:
_centroids = _centroids + [
adata.uns[uns_key]['centroids'][cluster_order[0]]
uns[CENTROID_SUBKEY][cluster_order[0]]
]

_var_exp = adata.uns[uns_key]['variance_ratio'].copy()
_var_exp /= np.sum(adata.uns[uns_key]['variance_ratio'])
_var_exp = uns['variance_ratio'].copy()
_var_exp /= np.sum(uns['variance_ratio'])
_var_exp *= 100

# BUILD PC1/PC2 PLOT ####
Expand All @@ -211,12 +239,12 @@ def program_time_summary(
ax[ax_key_pref + 'pca1'],
_color_vector,
centroid_indices=_centroids,
shortest_path=adata.uns[uns_key]['shortest_path'],
shortest_path=uns[SHORTEST_PATH_SUBKEY],
alpha=alpha
)

_n_comps = len(adata.uns[uns_key]['variance'])
_total_var = np.sum(adata.uns[uns_key]['variance_ratio']) * 100
_n_comps = len(uns['variance'])
_total_var = np.sum(uns['variance_ratio']) * 100
ax[ax_key_pref + 'pca1'].annotate(
f"{_n_comps} PCS ({_total_var:.1f}%)",
xy=(0, 0),
Expand All @@ -235,7 +263,7 @@ def program_time_summary(
ax[ax_key_pref + 'pca2'],
_color_vector,
centroid_indices=_centroids,
shortest_path=adata.uns[uns_key]['shortest_path'],
shortest_path=uns[SHORTEST_PATH_SUBKEY],
alpha=alpha
)

Expand All @@ -246,34 +274,49 @@ def program_time_summary(
for i, _pname in enumerate(_panels):

if ax_key_pref + _pname in ax:
_idx = adata.uns[uns_key]['closest_path_assignment'] == i
_idx = uns[CLOSEST_ASSIGNMENT_SUBKEY] == i

# REMOVE PADDING ON PATH ####
_path = adata.uns[uns_key]['assignment_path'][i]
_path = uns[ASSIGNMENT_PATH_SUBKEY][i]
_path = _path[_path != -1]

refs[ax_key_pref + 'group' + _pname] = _plot_pca(
adata.obsm[obsm_key][:, 0:2],
ax[ax_key_pref + _pname],
_color_vector,
bool_idx=_idx,
centroid_indices=adata.uns[uns_key]['assignment_centroids'][i],
shortest_path=_path,
alpha=alpha
)
if n_comps > 1:
refs[ax_key_pref + 'group' + _pname] = _plot_pca(
adata.obsm[obsm_key][:, 0:2],
ax[ax_key_pref + _pname],
_color_vector,
bool_idx=_idx,
centroid_indices=uns[ASSIGNMENT_CENTROID_SUBKEY][i],
shortest_path=_path,
alpha=alpha
)

ax[ax_key_pref + _pname].set_ylabel("PC2")

else:
refs[ax_key_pref + 'group' + _pname] = _plot_time_histogram(
adata.obsm[obsm_key][:, 0],
adata.obs[obs_group_key].values,
ax[ax_key_pref + _pname],
group_order=cluster_order,
group_colors=cluster_colors,
)

ax[ax_key_pref + _pname].set_ylabel("# Cells")

ax[ax_key_pref + _pname].set_title(_pname)
ax[ax_key_pref + _pname].set_xlabel("PC1")
ax[ax_key_pref + _pname].set_ylabel("PC2")

# BUILD TIME HISTOGRAM PLOT ####
if ax_key_pref + 'hist' in ax:

_times = adata.obs[obs_time_key].values

if wrap_time is not None:
_times[_times > wrap_time] = _times[_times > wrap_time] - wrap_time
_times[_times < 0] = _times[_times < 0] + wrap_time
_times = _wrap_time(
_times[_times > wrap_time],
wrap_time
)

# Just mask out unwanted times
# Moderately easier than sorting it out after binning
Expand Down Expand Up @@ -435,7 +478,12 @@ def _get_colors(values, color_dict):
return c


def _get_time_hist_data(time_data, group_data, bins, group_order=None):
def _get_time_hist_data(
time_data,
group_data,
bins,
group_order=None
):

if group_order is None:
group_order = np.unique(group_data)
Expand Down
25 changes: 17 additions & 8 deletions inferelator_velocity/program_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,22 @@

from inferelator_velocity.utils.misc import vprint

from .utils import copy_count_layer, is_iterable_arg
from .utils.keys import OBS_TIME_KEY, N_BINS, PROGRAM_KEY
from .metrics import mutual_information, make_array_discrete
from .utils import (
copy_count_layer,
is_iterable_arg
)

from .metrics import (
mutual_information,
make_array_discrete
)

from .utils.keys import (
OBS_TIME_KEY,
N_BINS,
PROGRAM_KEY,
get_program_ids
)


def assign_genes_to_programs(
Expand Down Expand Up @@ -59,11 +72,7 @@ def assign_genes_to_programs(
"""

if programs is None:
programs = [
p
for p in data.uns[PROGRAM_KEY]['program_names']
if p != '-1'
]
programs = get_program_ids(data)
elif is_iterable_arg(programs):
pass
else:
Expand Down
105 changes: 105 additions & 0 deletions inferelator_velocity/tests/test_times.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import unittest

import numpy as np
import anndata as ad

from inferelator_velocity.times import (
program_times,
calculate_times
)

from inferelator_velocity.plotting.program_times import (
program_time_summary
)

from inferelator_velocity.utils.keys import (
PROG_NAMES_SUBKEY,
PROGRAM_KEY
)

N = 10

DIST = np.tile(np.arange(N), (N, 1)).astype(float)
Expand Down Expand Up @@ -44,6 +54,101 @@ def test_times(self):
)


class TestProgramTimes(unittest.TestCase):
    """Tests for ``program_times`` and the ``program_time_summary`` plotter.

    Each test starts from a fresh AnnData object built from the module-level
    ``EXPR`` / ``LAB`` fixtures, with every variable assigned to a single
    program named ``'0'``.
    """

    def setUp(self) -> None:
        """Build a one-program AnnData fixture for each test."""

        self.adata = ad.AnnData(
            EXPR,
            dtype=EXPR.dtype
        )

        # Per-observation group labels and a single program ('0') that
        # every variable belongs to
        self.adata.obs['group'] = LAB
        self.adata.var[PROGRAM_KEY] = '0'
        self.adata.uns[PROGRAM_KEY] = {
            PROG_NAMES_SUBKEY: ['0']
        }

        return super().setUp()

    def test_program_times(self):
        """program_times() writes the expected times to .obs."""

        program_times(
            self.adata,
            {'0': 'group'},
            {'0': COL}
        )

        times = self.adata.obs['program_0_time']

        # Positions of the observations that are checked; the keys
        # presumably name the group each one belongs to -- verify against LAB
        check_positions = {'a': 2, 'b': 5, 'c': 9}

        # Use .iloc so the lookup is explicitly positional; integer keys
        # passed to Series.__getitem__ on a non-integer index rely on a
        # deprecated positional fallback
        self.assertListEqual(
            [0, 0.5, 1.],
            [times.iloc[pos] for pos in check_positions.values()]
        )

    def test_program_times_exceptions(self):
        """program_times() raises on inconsistent or missing program info."""

        # A program_var_key that is not present
        with self.assertRaises(ValueError):

            program_times(
                self.adata,
                {'0': 'group'},
                {'0': COL},
                program_var_key='abiubiosu'
            )

        # Second argument keyed by a program that is not '0'
        with self.assertRaises(ValueError):

            program_times(
                self.adata,
                {'1': 'group'},
                {'0': COL}
            )

        # Third argument keyed by a program that is not '0'
        with self.assertRaises(ValueError):

            program_times(
                self.adata,
                {'0': 'group'},
                {'1': COL}
            )

        # With the program metadata removed from .uns entirely
        del self.adata.uns[PROGRAM_KEY]

        with self.assertRaises(RuntimeError):

            program_times(
                self.adata,
                {'0': 'group'},
                {'0': COL}
            )

    def test_program_time_plots(self):
        """program_time_summary() needs program_times() to have run first."""

        # Plotting before program_times() has been run must fail; nothing
        # is returned, so there is nothing to bind here
        with self.assertRaises(RuntimeError):

            program_time_summary(
                self.adata,
                '0'
            )

        program_times(
            self.adata,
            {'0': 'group'},
            {'0': COL}
        )

        _, axes = program_time_summary(
            self.adata,
            '0'
        )

        # Second return value is presumably the collection of axes that
        # were drawn -- confirm against program_time_summary
        self.assertEqual(
            len(axes),
            4
        )


class TestTimeFunctions(unittest.TestCase):

def test_wrap_time(self):
Expand Down

0 comments on commit dae432d

Please sign in to comment.