Skip to content

Commit

Permalink
Backport PR scverse#1922: Fix paga exact reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup authored and meeseeksmachine committed Nov 3, 2021
1 parent 78eaf31 commit c7f9743
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.8.2.rst
Expand Up @@ -7,6 +7,7 @@
- Fix ``use_raw=None`` using :attr:`anndata.AnnData.var_names` if :attr:`anndata.AnnData.raw`
is present in :func:`scanpy.tl.score_genes` :pr:`1999` :smaller:`M Klein`
- Fix compatibility with UMAP 0.5.2 :pr:`2028` :smaller:`L Mcinnes`
- Fixed non-determinism in :func:`scanpy.pl.paga` node positions :pr:`1922` :smaller:`I Virshup`

.. rubric:: Performance enhancements

Expand Down
15 changes: 11 additions & 4 deletions scanpy/plotting/_tools/paga.py
Expand Up @@ -14,6 +14,7 @@
from matplotlib.axes import Axes
from matplotlib.colors import is_color_like, Colormap
from scipy.sparse import issparse
from sklearn.utils import check_random_state

from .. import _utils
from .._utils import matrix, _IGraphLayout, _FontWeight, _FontSize
Expand Down Expand Up @@ -174,8 +175,11 @@ def _compute_pos(
root=0,
layout_kwds: Mapping[str, Any] = MappingProxyType({}),
):
import random
import networkx as nx

random_state = check_random_state(random_state)

nx_g_solid = nx.Graph(adjacency_solid)
if layout is None:
layout = 'fr'
Expand All @@ -190,9 +194,9 @@ def _compute_pos(
)
layout = 'fr'
if layout == 'fa':
np.random.seed(random_state)
# np.random.seed(random_state)
if init_pos is None:
init_coords = np.random.random((adjacency_solid.shape[0], 2))
init_coords = random_state.random_sample((adjacency_solid.shape[0], 2))
else:
init_coords = init_pos.copy()
forceatlas2 = ForceAtlas2(
Expand Down Expand Up @@ -233,6 +237,7 @@ def _compute_pos(
)
else:
# igraph layouts
random.seed(random_state.bytes(8))
g = _sc_utils.get_igraph_from_adjacency(adjacency_solid)
if 'rt' in layout:
g_tree = _sc_utils.get_igraph_from_adjacency(adj_tree)
Expand All @@ -243,9 +248,11 @@ def _compute_pos(
pos_list = g.layout(layout).coords
else:
# I don't know why this is necessary
np.random.seed(random_state)
# np.random.seed(random_state)
if init_pos is None:
init_coords = np.random.random((adjacency_solid.shape[0], 2)).tolist()
init_coords = random_state.random_sample(
(adjacency_solid.shape[0], 2)
).tolist()
else:
init_pos = init_pos.copy()
# this is a super-weird hack that is necessary as igraph’s
Expand Down
Binary file modified scanpy/tests/_images/master_paga.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_paga_continuous.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_paga_continuous_multiple.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_paga_continuous_obs.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified scanpy/tests/_images/master_paga_pie.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions scanpy/tests/test_paga.py
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

from matplotlib import cm
import numpy as np

import scanpy as sc

Expand Down Expand Up @@ -86,3 +87,21 @@ def test_paga_compare(image_comparer):
sc.pl.paga_compare(pbmc, basis="umap", show=False)

save_and_compare_images('master_paga_compare_pbmc3k')


def test_paga_positions_reproducible():
"""Check exact reproducibility and effect of random_state on paga positions"""
# https://github.com/theislab/scanpy/issues/1859
pbmc = sc.datasets.pbmc68k_reduced()
sc.tl.paga(pbmc, "bulk_labels")

a = pbmc.copy()
b = pbmc.copy()
c = pbmc.copy()

sc.pl.paga(a, show=False, random_state=42)
sc.pl.paga(b, show=False, random_state=42)
sc.pl.paga(c, show=False, random_state=13)

np.testing.assert_array_equal(a.uns["paga"]["pos"], b.uns["paga"]["pos"])
assert a.uns["paga"]["pos"].tolist() != c.uns["paga"]["pos"].tolist()

0 comments on commit c7f9743

Please sign in to comment.