Skip to content

Commit

Permalink
Merge pull request #388 from corochann/fix_mol_visualizer
Browse files Browse the repository at this point in the history
return svg text in MolVisualizer
  • Loading branch information
corochann committed Sep 10, 2019
2 parents ed79f52 + a1ec394 commit f707429
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
46 changes: 32 additions & 14 deletions chainer_chemistry/saliency/visualizer/mol_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,24 @@ def __init__(self, logger=None):

def visualize(self, saliency, mol, save_filepath=None,
visualize_ratio=1.0, color_fn=red_blue_cmap,
scaler=abs_max_scaler, legend=''):
scaler=abs_max_scaler, legend='', raise_import_error=False
):
"""Visualize or save `saliency` with molecule
returned value can be used for visualization.
.. admonition:: Example
>>> svg = visualizer.visualize(saliency, mol)
>>>
>>> # For a Jupyter user, it will show figure on notebook.
>>> from IPython.core.display import SVG
>>> SVG(svg.replace('svg:', ''))
>>>
>>> # For a user who want to save a file as png
>>> import cairosvg
>>> cairosvg.svg2png(bytestring=svg, write_to="foo.png")
Args:
saliency (numpy.ndarray): 1-dim saliency array (num_node,)
mol (Chem.Mol): mol instance of this saliency
Expand All @@ -59,6 +74,10 @@ def visualize(self, saliency, mol, save_filepath=None,
scaler (callable): function which takes `x` as input and outputs
scaled `x`, for plotting.
legend (str): legend for the plot
raise_import_error (bool): raise error when `ImportError` is raised
Returns:
svg (str): drawed svg text.
"""
rdDepictor.Compute2DCoords(mol)
Chem.SanitizeMol(mol)
Expand Down Expand Up @@ -115,35 +134,30 @@ def color_bond(bond):
try:
import cairosvg
cairosvg.svg2png(bytestring=svg, write_to=save_filepath)
except ImportError:
except ImportError as e:
self.logger.error(
'cairosvg is not installed! '
'Please install cairosvg to save by png format.\n'
'pip install cairosvg')
return None
if raise_import_error:
raise e
else:
raise ValueError(
'Unsupported extention {} for save_filepath {}'
.format(extention, save_filepath))
else:
try:
from IPython.core.display import SVG
return SVG(svg.replace('svg:', ''))
except ImportError:
self.logger.error(
'IPython module failed to import, '
'please install by "pip install ipython"')
return None
return svg


class SmilesVisualizer(MolVisualizer):

def visualize(self, saliency, smiles, save_filepath=None,
visualize_ratio=1.0, color_fn=red_blue_cmap,
scaler=abs_max_scaler, legend='', add_Hs=False,
use_canonical_smiles=True):
use_canonical_smiles=True, raise_import_error=False):
"""Visualize or save `saliency` with molecule
See parent `MolVisualizer` class for further usage.
Args:
saliency (numpy.ndarray): 1-dim saliency array (num_node,)
smiles (str): smiles of the molecule.
Expand All @@ -157,6 +171,10 @@ def visualize(self, saliency, smiles, save_filepath=None,
add_Hs (bool): Add explicit H or not
use_canonical_smiles (bool): If `True`, smiles are converted to
canonical smiles before constructing `mol`
raise_import_error (bool): raise error when `ImportError` is raised
Returns:
svg (str): drawed svg text.
"""
mol = Chem.MolFromSmiles(smiles)
if use_canonical_smiles:
Expand All @@ -167,4 +185,4 @@ def visualize(self, saliency, smiles, save_filepath=None,
return super(SmilesVisualizer, self).visualize(
saliency, mol, save_filepath=save_filepath,
visualize_ratio=visualize_ratio, color_fn=color_fn, scaler=scaler,
legend=legend)
legend=legend, raise_import_error=raise_import_error)
14 changes: 9 additions & 5 deletions tests/saliency_tests/visualizer_tests/test_mol_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def test_mol_visualizer(tmpdir):

# 1. test with setting save_filepath
save_filepath = os.path.join(str(tmpdir), 'tmp.svg')
visualizer.visualize(saliency, mol, save_filepath=save_filepath)
svg = visualizer.visualize(saliency, mol, save_filepath=save_filepath)
assert isinstance(svg, str)
assert os.path.exists(save_filepath)

# 2. test with `save_filepath=None` runs without error
visualizer.visualize(
svg = visualizer.visualize(
saliency, mol, save_filepath=None, visualize_ratio=0.5,)
assert isinstance(svg, str)


def test_smiles_visualizer(tmpdir):
Expand All @@ -33,9 +35,10 @@ def test_smiles_visualizer(tmpdir):

# 1. test with setting save_filepath
save_filepath = os.path.join(str(tmpdir), 'tmp.svg')
visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
svg = visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
assert os.path.exists(save_filepath)
assert isinstance(svg, str)
save_filepath = os.path.join(str(tmpdir), 'tmp.png')
visualizer.visualize(saliency, smiles, save_filepath=save_filepath,
add_Hs=False)
Expand All @@ -44,9 +47,10 @@ def test_smiles_visualizer(tmpdir):
# assert os.path.exists(save_filepath)

# 2. test with `save_filepath=None` runs without error
visualizer.visualize(
svg = visualizer.visualize(
saliency, smiles, save_filepath=None, visualize_ratio=0.5,
add_Hs=False, use_canonical_smiles=True)
assert isinstance(svg, str)


def test_mol_visualizer_assert_raises(tmpdir):
Expand Down

0 comments on commit f707429

Please sign in to comment.