Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return svg text in MolVisualizer #388

Merged
merged 2 commits into from
Sep 10, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 32 additions & 14 deletions chainer_chemistry/saliency/visualizer/mol_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,24 @@ def __init__(self, logger=None):

def visualize(self, saliency, mol, save_filepath=None,
visualize_ratio=1.0, color_fn=red_blue_cmap,
scaler=abs_max_scaler, legend=''):
scaler=abs_max_scaler, legend='', raise_import_error=True

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think raise_import_error should be False by default if we follow the original behavior of saving png files.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay i will fix it to set False as default.

):
"""Visualize or save `saliency` with molecule

returned value can be used for visualization.

.. admonition:: Example

>>> svg = visualizer.visualize(saliency, mol)
>>>
>>> # For a Jupyter user, it will show figure on notebook.
>>> from IPython.core.display import SVG
>>> SVG(svg.replace('svg:', ''))
>>>
>>> # For a user who want to save a file as png
>>> import cairosvg
>>> cairosvg.svg2png(bytestring=svg, write_to="foo.png")

Args:
saliency (numpy.ndarray): 1-dim saliency array (num_node,)
mol (Chem.Mol): mol instance of this saliency
Expand All @@ -59,6 +74,10 @@ def visualize(self, saliency, mol, save_filepath=None,
scaler (callable): function which takes `x` as input and outputs
scaled `x`, for plotting.
legend (str): legend for the plot
raise_import_error (bool): raise error when `ImportError` is raised

Returns:
svg (str): drawed svg text.
"""
rdDepictor.Compute2DCoords(mol)
Chem.SanitizeMol(mol)
Expand Down Expand Up @@ -115,35 +134,30 @@ def color_bond(bond):
try:
import cairosvg
cairosvg.svg2png(bytestring=svg, write_to=save_filepath)
except ImportError:
except ImportError as e:
self.logger.error(
'cairosvg is not installed! '
'Please install cairosvg to save by png format.\n'
'pip install cairosvg')
return None
if raise_import_error:
raise e
else:
raise ValueError(
'Unsupported extention {} for save_filepath {}'
.format(extention, save_filepath))
else:
try:
from IPython.core.display import SVG
return SVG(svg.replace('svg:', ''))
except ImportError:
self.logger.error(
'IPython module failed to import, '
'please install by "pip install ipython"')
return None
return svg


class SmilesVisualizer(MolVisualizer):

def visualize(self, saliency, smiles, save_filepath=None,
visualize_ratio=1.0, color_fn=red_blue_cmap,
scaler=abs_max_scaler, legend='', add_Hs=False,
use_canonical_smiles=True):
use_canonical_smiles=True, raise_import_error=True):
"""Visualize or save `saliency` with molecule

See parent `MolVisualizer` class for further usage.

Args:
saliency (numpy.ndarray): 1-dim saliency array (num_node,)
smiles (str): smiles of the molecule.
Expand All @@ -157,6 +171,10 @@ def visualize(self, saliency, smiles, save_filepath=None,
add_Hs (bool): Add explicit H or not
use_canonical_smiles (bool): If `True`, smiles are converted to
canonical smiles before constructing `mol`
raise_import_error (bool): raise error when `ImportError` is raised

Returns:
svg (str): drawed svg text.
"""
mol = Chem.MolFromSmiles(smiles)
if use_canonical_smiles:
Expand All @@ -167,4 +185,4 @@ def visualize(self, saliency, smiles, save_filepath=None,
return super(SmilesVisualizer, self).visualize(
saliency, mol, save_filepath=save_filepath,
visualize_ratio=visualize_ratio, color_fn=color_fn, scaler=scaler,
legend=legend)
legend=legend, raise_import_error=raise_import_error)
14 changes: 9 additions & 5 deletions tests/saliency_tests/visualizer_tests/test_mol_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def test_mol_visualizer(tmpdir):

# 1. test with setting save_filepath
save_filepath = os.path.join(str(tmpdir), 'tmp.svg')
visualizer.visualize(saliency, mol, save_filepath=save_filepath)
svg = visualizer.visualize(saliency, mol, save_filepath=save_filepath)
assert isinstance(svg, str)
assert os.path.exists(save_filepath)

# 2. test with `save_filepath=None` runs without error
visualizer.visualize(
svg = visualizer.visualize(
saliency, mol, save_filepath=None, visualize_ratio=0.5,)
assert isinstance(svg, str)


def test_smiles_visualizer(tmpdir):
Expand All @@ -33,9 +35,10 @@ def test_smiles_visualizer(tmpdir):

# 1. test with setting save_filepath
save_filepath = os.path.join(str(tmpdir), 'tmp.svg')
visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
svg = visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
assert os.path.exists(save_filepath)
assert isinstance(svg, str)
save_filepath = os.path.join(str(tmpdir), 'tmp.png')
visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
Expand All @@ -44,9 +47,10 @@ def test_smiles_visualizer(tmpdir):
# assert os.path.exists(save_filepath)

# 2. test with `save_filepath=None` runs without error
visualizer.visualize(
svg = visualizer.visualize(
saliency, smiles, save_filepath=None, visualize_ratio=0.5,
add_Hs=False, use_canonical_smiles=True)
assert isinstance(svg, str)


def test_mol_visualizer_assert_raises(tmpdir):
Expand Down