Skip to content

Commit

Permalink
Backport PR scverse#2734 on branch 1.9.x (Make _validate_palette work…
Browse files Browse the repository at this point in the history
… with arrays) (scverse#2735)

Co-authored-by: Philipp A <flying-sheep@web.de>
  • Loading branch information
meeseeksmachine and flying-sheep committed Nov 7, 2023
1 parent 335596c commit d1fe8da
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.9.7.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

```{rubric} Bug fixes
```
- Fix handling of numpy array palettes (e.g. after write-read cycle) {pr}`2734` {smaller}`P Angerer`
7 changes: 4 additions & 3 deletions scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def default_palette(
return palette


def _validate_palette(adata, key):
def _validate_palette(adata: anndata.AnnData, key: str) -> None:
"""
checks if the list of colors in adata.uns[f'{key}_colors'] is valid
and updates the color list in adata.uns[f'{key}_colors'] if needed.
Expand Down Expand Up @@ -354,8 +354,9 @@ def _validate_palette(adata, key):
break
_palette.append(color)
# Don't modify if nothing changed
if _palette is not None and list(_palette) != list(adata.uns[color_key]):
adata.uns[color_key] = _palette
if _palette is None or np.equal(_palette, adata.uns[color_key]).all():
return
adata.uns[color_key] = _palette


def _set_colors_for_categorical_obs(
Expand Down
28 changes: 28 additions & 0 deletions scanpy/tests/test_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import cast
import numpy as np
import pytest

from anndata import AnnData
from matplotlib import colormaps
from matplotlib.colors import ListedColormap

from scanpy.plotting._utils import _validate_palette


viridis = cast(ListedColormap, colormaps["viridis"])


@pytest.mark.parametrize(
"palette",
[
pytest.param(viridis.colors, id="viridis"),
pytest.param(["b", "#cccccc", "r", "yellow", "lightblue"], id="named"),
pytest.param([(1, 0, 0, 1), (0, 0, 1, 1)], id="rgba"),
],
)
@pytest.mark.parametrize("typ", [np.asarray, list])
def test_validate_palette_no_mod(palette, typ):
palette = typ(palette)
adata = AnnData(uns=dict(test_colors=palette))
_validate_palette(adata, "test")
assert palette is adata.uns["test_colors"], "Palette should not be modified"

0 comments on commit d1fe8da

Please sign in to comment.