Skip to content

Commit

Permalink
Merge pull request #251 from lilab-bcb/boli
Browse files Browse the repository at this point in the history
Added normalize, log1p, renamed arcsinh_transform to arcsinh, and ena…
  • Loading branch information
yihming committed Jul 17, 2022
2 parents ca08470 + a3e4c71 commit 7c57015
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 122 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ jobs:
- name: Install dependencies
run: |
sudo apt -q update
sudo apt install -y libfftw3-dev default-jdk
sudo apt install -y libfftw3-dev default-jdk git
python -m pip install --upgrade pip
python -m pip install flake8 pytest setuptools wheel cython
python -m pip install zarr
python -m pip install git+https://github.com/lilab-bcb/pegasusio@master
python -m pip install -e .[all]
- name: Lint with flake8
run: |
Expand Down
4 changes: 3 additions & 1 deletion pegasus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
filter_data,
identify_robust_genes,
log_norm,
arcsinh_transform,
normalize,
log1p,
arcsinh,
select_features,
pca,
pc_transform,
Expand Down
6 changes: 3 additions & 3 deletions pegasus/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
append_df = append_df.copy()
append_df.index = append_df.index.map(lambda x: f"Ab-{x}")

rawX = hstack([unidata.get_matrix("raw.X"), Z], format = "csr")
rawX = hstack([unidata.get_matrix("counts"), Z], format = "csr")

Zt = Z.astype(np.float32)
if not kwargs["citeseq"]:
Expand All @@ -337,7 +337,7 @@ def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
else:
Zt.data = np.arcsinh(Zt.data / 5.0, dtype = np.float32)

X = hstack([unidata.get_matrix("X"), Zt], format = "csr")
X = hstack([unidata.get_matrix(unidata.current_matrix()), Zt], format = "csr")

new_genome = unidata.get_genome()
if new_genome != append_data.get_genome():
Expand All @@ -346,7 +346,7 @@ def analyze_one_modality(unidata: UnimodalData, output_name: str, is_raw: bool,
feature_metadata = pd.concat([unidata.feature_metadata, append_df], axis = 0)
feature_metadata.reset_index(inplace = True)
_fillna(feature_metadata)
unidata = UnimodalData(unidata.barcode_metadata, feature_metadata, {"X": X, "raw.X": rawX}, unidata.uns.mapping, unidata.obsm.mapping, unidata.varm.mapping) # uns.mapping, obsm.mapping and varm.mapping are passed by reference
unidata = UnimodalData(unidata.barcode_metadata, feature_metadata, {unidata.current_matrix(): X, "counts": rawX}, unidata.uns.mapping, unidata.obsm.mapping, unidata.varm.mapping) # uns.mapping, obsm.mapping and varm.mapping are passed by reference
unidata.uns["genome"] = new_genome

if kwargs["citeseq"] and kwargs["citeseq_umap"]:
Expand Down
23 changes: 15 additions & 8 deletions pegasus/plotting/plot_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,7 @@ def _gen_ticklabels(ticks, max_value):
def ridgeplot(
data: Union[MultimodalData, UnimodalData],
features: Union[str, List[str]],
matrix_key: Optional[str] = None,
donor_attr: Optional[str] = None,
qc_attr: Optional[str] = None,
overlap: Optional[float] = 0.5,
Expand All @@ -1987,9 +1988,11 @@ def ridgeplot(
----------
data : `UnimodalData` or `MultimodalData` object
CITE-Seq or Cyto data.
Data matrix.
features : `str` or `List[str]`
One or more features to display.
matrix_key: `str`, optional, default None
Which matrix to search features for. If None, use the current matrix.
donor_attr: `str`, optional, default None
If not None, `features` must contain only one feature, plot this feature by donor indicated as `donor_attr`.
qc_attr: `str`, optional, default None
Expand All @@ -2013,8 +2016,8 @@ def ridgeplot(
Examples
--------
>>> fig = pg.ridgeplot(data, features = ['CD8', 'CD4', 'CD3'], show = False, dpi = 500)
>>> fig = pg.ridgeplot(data, features = 'CD3', donor_attr = 'assignment', show = False, dpi = 500)
>>> fig = pg.ridgeplot(data, features = ['CD8', 'CD4', 'CD3'], dpi = 500)
>>> fig = pg.ridgeplot(data, features = 'CD3', donor_attr = 'assignment', dpi = 500)
"""
sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

Expand All @@ -2026,6 +2029,9 @@ def ridgeplot(
logger.warning("At most 8 features are allowed to be plotted together!")
return None

if matrix_key == None:
matrix_key = data.current_matrix()

df = None
if donor_attr is None:
exprs = []
Expand All @@ -2034,7 +2040,7 @@ def ridgeplot(
size = idx.sum()
for feature in features:
fid = data.var_names.get_loc(feature)
exprs.append(slicing(data.get_matrix("arcsinh.transformed"), idx, fid))
exprs.append(slicing(data.get_matrix(matrix_key), idx, fid))
feats.append(np.repeat(feature, size))

df = pd.DataFrame({"expression": np.concatenate(exprs), "feature": np.concatenate(feats)})
Expand All @@ -2054,12 +2060,12 @@ def ridgeplot(

g = sns.FacetGrid(df, row="feature", hue="feature", aspect=8, height=1.0)
try:
g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=1.5)
g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=1)
g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=0)
g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=0.5)
except RuntimeError as re:
if str(re).startswith("Selected KDE bandwidth is 0. Cannot estimate density."):
g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=1.5, bw=0.1)
g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=1, bw=0.1)
g.map(sns.kdeplot, "expression", clip_on=False, shade=True, alpha=1, lw=0, bw=0.1)
g.map(sns.kdeplot, "expression", clip_on=False, color="k", lw=0.5, bw=0.1)
else:
raise re
g.map(plt.axhline, y=0, lw=1, clip_on=False)
Expand All @@ -2074,6 +2080,7 @@ def _set_label(value, color, label):

g.set_titles("")
g.set_xlabels("")
g.set_ylabels("")
g.set(yticks=[])
g.despine(bottom=True, left=True)

Expand Down
5 changes: 3 additions & 2 deletions pegasus/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
calc_mean_and_var,
calc_expm1,
calc_stat_per_batch,
normalize_by_count,
calc_sig_background,
simulate_doublets,
check_batch_key,
Expand All @@ -25,7 +24,9 @@
identify_robust_genes,
_run_filter_data,
log_norm,
arcsinh_transform,
normalize,
log1p,
arcsinh,
select_features,
pca,
pc_transform,
Expand Down
22 changes: 14 additions & 8 deletions pegasus/tools/doublet_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _calc_expected_doublet_rate(ncells):
@timer(logger=logger)
def _run_scrublet(
data: Union[MultimodalData, UnimodalData],
raw_mat_key: Optional[str] = 'counts',
name: Optional[str] = '',
expected_doublet_rate: Optional[float] = None,
sim_doublet_ratio: Optional[float] = 2.0,
Expand All @@ -236,7 +237,10 @@ def _run_scrublet(
Parameters
-----------
data: ``Union[MultimodalData, UnimodalData]`` object.
Annotated data matrix with rows for cells and columns for genes. Data must be low quality cell and gene filtered and log-transformed. Assume 'raw.X' stores the raw count matrix.
Annotated data matrix with rows for cells and columns for genes. Data must be low quality cell and gene filtered and log-transformed.
raw_mat_key: ``str``, optional, default: ``counts``
Matrix key for the raw count matrix.
name: ``str``, optional, default: ``''``
Name of the sample.
Expand Down Expand Up @@ -294,7 +298,7 @@ def _run_scrublet(
rho = expected_doublet_rate

# subset the raw count matrix
rawX = data.get_matrix("raw.X")
rawX = data.get_matrix(raw_mat_key)
obs_umis = rawX.sum(axis = 1, dtype = np.int32).A1
rawX = rawX[:, data.var["highly_variable_features"].values]
# Simulate synthetic doublets
Expand Down Expand Up @@ -467,6 +471,7 @@ def infer_doublets(
data: MultimodalData,
channel_attr: Optional[str] = None,
clust_attr: Optional[str] = None,
raw_mat_key: Optional[str] = 'counts',
min_cell: Optional[int] = 100,
expected_doublet_rate: Optional[float] = None,
sim_doublet_ratio: Optional[float] = 2.0,
Expand Down Expand Up @@ -538,9 +543,9 @@ def infer_doublets(
"""
assert data.get_modality() == "rna"
try:
rawX = data.get_matrix("raw.X")
rawX = data.get_matrix(raw_mat_key)
except ValueError:
raise ValueError("Cannot detect the raw count matrix raw.X; stop inferring doublets!")
raise ValueError(f"Cannot detect the raw count matrix {raw_mat_key}; stop inferring doublets!")

if_plot = plot_hist is not None

Expand All @@ -552,7 +557,7 @@ def infer_doublets(

if channel_attr is None:
if data.shape[0] >= min_cell:
fig = _run_scrublet(data, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
fig = _run_scrublet(data, raw_mat_key, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
n_prin_comps = n_prin_comps, k = k, n_jobs = n_jobs, random_state = random_state, plot_hist = if_plot, manual_correction = mancor.get('', None))
if if_plot:
fig.savefig(f"{plot_hist}.dbl.png")
Expand All @@ -578,14 +583,15 @@ def infer_doublets(
if idx.size >= min_cell:
unidata = UnimodalData({"barcodekey": data.obs_names[idx]},
{"featurekey": data.var_names},
{"X": rawX[idx]},
{"genome": genome, "modality": modality})
{"counts": rawX[idx]},
{"genome": genome, "modality": modality},
cur_matrix = "counts")
# Identify robust genes, count and log normalized and select top 2,000 highly variable features
identify_robust_genes(unidata)
log_norm(unidata)
highly_variable_features(unidata)
# Run _run_scrublet
fig = _run_scrublet(unidata, name = channel, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
fig = _run_scrublet(unidata, raw_mat_key, name = channel, expected_doublet_rate = expected_doublet_rate, sim_doublet_ratio = sim_doublet_ratio, \
n_prin_comps = n_prin_comps, k = k, n_jobs = n_jobs, random_state = random_state, plot_hist = if_plot, manual_correction = mancor.get(channel, None))
if if_plot:
fig.savefig(f"{plot_hist}.{channel}.dbl.png")
Expand Down

0 comments on commit 7c57015

Please sign in to comment.