Skip to content

Commit

Permalink
Merge pull request #294 from lilab-bcb/yiming
Browse files Browse the repository at this point in the history
Improve Scatter and Spatial plots
  • Loading branch information
yihming committed Apr 25, 2024
2 parents d9d30d7 + eb95d41 commit 8eef3c6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 12 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci-test.yml
Expand Up @@ -38,9 +38,7 @@ jobs:
- name: Fetch test data
run: |
sudo apt install -y git
eval `ssh-agent -s`
ssh-add - <<< '${{ secrets.PEGASUS_ACTIONS }}'
git clone git@github.com:lilab-bcb/pegasus-test-data.git ./tests/data
git clone https://github.com/lilab-bcb/pegasus-test-data.git ./tests/data
- name: Pipeline test
run: |
bash tests/run_pipeline.sh
Expand Down
63 changes: 54 additions & 9 deletions pegasus/plotting/plot_library.py
Expand Up @@ -62,6 +62,7 @@ def scatter(
hspace: Optional[float] = 0.15,
marker_size: Optional[float] = None,
scale_factor: Optional[float] = None,
aspect: Optional[str] = "auto",
return_fig: Optional[bool] = False,
dpi: Optional[float] = 300.0,
show_neg_for_sig: Optional[bool] = False,
Expand Down Expand Up @@ -121,6 +122,9 @@ def scatter(
Manually set the marker size in the plot. If ``None``, automatically adjust the marker size to the plot size.
scale_factor: ``float``, optional (default: ``None``)
Manually set the scale factor in the plot if it's not ``None``. This is used by generating the spatial plots for 10x Visium data.
aspect: ``str``, optional (default: ``auto``)
Set the aspect of the axis scaling, i.e. the ratio of y-unit to x-unit. Set ``auto`` to fill the position rectangle with data; ``equal`` for the same scaling for x and y.
It applies to all subplots.
return_fig: ``bool``, optional, default: ``False``
Return a ``Figure`` object if ``True``; return ``None`` otherwise.
dpi: ``float``, optional, default: 300.0
Expand Down Expand Up @@ -203,6 +207,7 @@ def scatter(
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect(aspect)
if i * ncols + j >= nfigs:
ax.set_frame_on(False)

Expand Down Expand Up @@ -633,12 +638,18 @@ def spatial(
basis: str = 'spatial',
resolution: str = 'hires',
cmaps: Optional[Union[str, List[str]]] = 'viridis',
restrictions: Optional[Union[str, List[str]]] = None,
show_background: Optional[bool] = False,
palettes: Optional[Union[str, List[str]]] = None,
vmin: Optional[Union[float, List[float]]] = None,
vmax: Optional[Union[float, List[float]]] = None,
alpha: Union[float, List[float]] = 1.0,
alpha_img: float = 1.0,
nrows: Optional[int] = None,
ncols: Optional[int] = None,
y_flip: bool = True,
margin_percent: float = 0.05,
aspect: Optional[str] = "equal",
dpi: float = 300.0,
return_fig: bool = False,
**kwargs,
Expand All @@ -658,8 +669,15 @@ def spatial(
resolution: ``str``, optional, default: ``hires``
Use the spatial image whose value is specified in ``data.img['image_id']`` to show in background.
For 10X Visium data, user can either specify ``hires`` or ``lowres`` to use High or Low resolution spatial images, respectively.
Alternatively, if ``data.img`` does not exist, then no spatial image will be shown.
cmaps: ``str`` or ``List[str]``, optional, default: ``viridis``
The colormap(s) for plotting numeric attributes. The default ``viridis`` colormap theme follows the spatial plot function in SCANPY (``scanpy.pl.spatial``).
restrictions: ``str`` or ``List[str]``, optional, default: None
A list of restrictions to subset data for plotting. There are two types of restrictions: global restriction and attribute-specific restriction. Global restriction appiles to all attributes in ``attrs`` and takes the format of 'key:value,value...', or 'key:~value,value...'. This restriction selects cells with the ``data.obs[key]`` values belong to 'value,value...' (or not belong to if '~' shows). Attribute-specific restriction takes the format of 'attr:key:value,value...', or 'attr:key:~value,value...'. It only applies to one attribute 'attr'. If 'attr' and 'key' are the same, one can use '.' to replace 'key' (e.g. ``cluster_labels:.:value1,value2``).
show_background: ``bool``, optional, default: False
Only applicable if `restrictions` is set. By default, only data points selected are shown. If show_background is True, data points that are not selected will also be shown.
palettes: ``str`` or ``List[str]``, optional, default: None
Used for setting colors for every categories in categorical attributes. Each string in ``palettes`` takes the format of 'attr:color1,color2,...,colorn'. 'attr' is the categorical attribute and 'color1' - 'colorn' are the colors for each category in 'attr' (e.g. 'cluster_labels:black,blue,red,...,yellow'). If there is only one categorical attribute in 'attrs', ``palletes`` can be set as a single string and the 'attr' keyword can be omitted (e.g. "blue,yellow,red").
vmin: ``float``, optional, default: ``None``
Minimum value to show on a numeric scatter plot (feature plot).
vmax: ``float``, optional, default: ``None``
Expand All @@ -672,6 +690,14 @@ def spatial(
Number of rows in the figure. If not set, pegasus will figure it out automatically.
ncols: ``int``, optional, default: ``None``
Number of columns in the figure. If not set, pegasus will figure it out automatically.
y_flip: ``bool``, optional, default: ``False``
Set to ``True`` if flipping the y axis is needed. This is for the case when y-coordinate origin starts from the top.
For 10x Visium data, if ``resolution`` is specified, this parameter is then ignored.
margin_percent: ``float``, optional, default: ``0.05``
The margin is set to ``margin_percent``*100% of the smaller edge of the image size in each of the 4 sides.
aspect:``str``, optional (default: ``equal``)
Set the aspect of the axis scaling, i.e. the ratio of y-unit to x-unit. Set ``auto`` to fill the position rectangle with data; ``equal`` for the same scaling for x and y.
It applies to all subplots.
dpi: ``float``, optional, default: ``300.0``
The resolution of the figure in dots-per-inch.
return_fig: ``bool``, optional, default: ``False``
Expand All @@ -688,8 +714,12 @@ def spatial(
>>> pg.spatial(data, attrs=['CD14', 'TRAC'], resolution='lowres')
"""
assert f"X_{basis}" in data.obsm.keys(), f"'X_{basis}' coordinates do not exist!"
assert hasattr(data, 'img'), "The spatial image data are missing!"
assert resolution in data.img['image_id'].values, f"'{resolution}' image does not exist!"

if data.img is None:
resolution = None
#assert data.img, "The spatial image data are missing!"
elif resolution:
assert resolution in data.img['image_id'].values, f"'{resolution}' image does not exist!"

if attrs is not None:
if not is_list_like(attrs):
Expand All @@ -699,10 +729,14 @@ def spatial(

nattrs = len(attrs) if attrs is not None else 1

image_item = data.img.loc[data.img['image_id']==resolution]
image_obj = image_item['data'].iat[0]
scale_factor = image_item['scale_factor'].iat[0]
spot_radius = image_item['spot_diameter'].iat[0] * 0.5
if resolution:
image_item = data.img.loc[data.img['image_id']==resolution]
image_obj = image_item['data'].iat[0]
scale_factor = image_item['scale_factor'].iat[0]
spot_radius = image_item['spot_diameter'].iat[0] * 0.5
else:
scale_factor = None
spot_radius = None

fig = scatter(
data=data,
Expand All @@ -711,27 +745,38 @@ def spatial(
marker_size=spot_radius,
scale_factor=scale_factor,
cmaps=cmaps,
restrictions=restrictions,
show_background=show_background,
palettes=palettes,
vmin=vmin,
vmax=vmax,
nrows=nrows,
ncols=ncols,
dpi=dpi,
alpha=alpha,
aspect=aspect,
return_fig=True,
)

if scale_factor is None:
scale_factor = 1.0

coord_x = (data.obsm[f"X_{basis}"][:, 0].min() * scale_factor,
data.obsm[f"X_{basis}"][:, 0].max() * scale_factor)
coord_y = (data.obsm[f"X_{basis}"][:, 1].min() * scale_factor,
data.obsm[f"X_{basis}"][:, 1].max() * scale_factor)

margin_offset = 50
margin_offset = min(np.abs(coord_x[1] - coord_x[0]), np.abs(coord_y[1] - coord_y[0])) * margin_percent

for i in range(nattrs):
ax = fig.axes[i]
ax.imshow(image_obj, alpha=alpha_img)
if resolution:
ax.imshow(image_obj, alpha=alpha_img)
ax.set_xlim(coord_x[0]-margin_offset, coord_x[1]+margin_offset)
ax.set_ylim(coord_y[1]+margin_offset, coord_y[0]-margin_offset)
if resolution or y_flip:
ax.set_ylim(coord_y[1]+margin_offset, coord_y[0]-margin_offset)
else:
ax.set_ylim(coord_y[0]-margin_offset, coord_y[1]+margin_offset)

return fig if return_fig else None

Expand Down

0 comments on commit 8eef3c6

Please sign in to comment.