From ae559ac93d614e9731171f1d831bd7f43879fff0 Mon Sep 17 00:00:00 2001 From: corochann Date: Tue, 10 Sep 2019 20:52:01 +0900 Subject: [PATCH 1/2] return svg text --- .../saliency/visualizer/mol_visualizer.py | 46 +++++++++++++------ .../visualizer_tests/test_mol_visualizer.py | 14 ++++-- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/chainer_chemistry/saliency/visualizer/mol_visualizer.py b/chainer_chemistry/saliency/visualizer/mol_visualizer.py index d84ac048..9199e865 100644 --- a/chainer_chemistry/saliency/visualizer/mol_visualizer.py +++ b/chainer_chemistry/saliency/visualizer/mol_visualizer.py @@ -46,9 +46,24 @@ def __init__(self, logger=None): def visualize(self, saliency, mol, save_filepath=None, visualize_ratio=1.0, color_fn=red_blue_cmap, - scaler=abs_max_scaler, legend=''): + scaler=abs_max_scaler, legend='', raise_import_error=True + ): """Visualize or save `saliency` with molecule + returned value can be used for visualization. + + .. admonition:: Example + + >>> svg = visualizer.visualize(saliency, mol) + >>> + >>> # For a Jupyter user, it will show figure on notebook. + >>> from IPython.core.display import SVG + >>> SVG(svg.replace('svg:', '')) + >>> + >>> # For a user who want to save a file as png + >>> import cairosvg + >>> cairosvg.svg2png(bytestring=svg, write_to="foo.png") + Args: saliency (numpy.ndarray): 1-dim saliency array (num_node,) mol (Chem.Mol): mol instance of this saliency @@ -59,6 +74,10 @@ def visualize(self, saliency, mol, save_filepath=None, scaler (callable): function which takes `x` as input and outputs scaled `x`, for plotting. legend (str): legend for the plot + raise_import_error (bool): raise error when `ImportError` is raised + + Returns: + svg (str): drawed svg text. """ rdDepictor.Compute2DCoords(mol) Chem.SanitizeMol(mol) @@ -115,25 +134,18 @@ def color_bond(bond): try: import cairosvg cairosvg.svg2png(bytestring=svg, write_to=save_filepath) - except ImportError: + except ImportError as e: self.logger.error( 'cairosvg is not installed! ' 'Please install cairosvg to save by png format.\n' 'pip install cairosvg') - return None + if raise_import_error: + raise e else: raise ValueError( 'Unsupported extention {} for save_filepath {}' .format(extention, save_filepath)) - else: - try: - from IPython.core.display import SVG - return SVG(svg.replace('svg:', '')) - except ImportError: - self.logger.error( - 'IPython module failed to import, ' - 'please install by "pip install ipython"') - return None + return svg class SmilesVisualizer(MolVisualizer): @@ -141,9 +153,11 @@ class SmilesVisualizer(MolVisualizer): def visualize(self, saliency, smiles, save_filepath=None, visualize_ratio=1.0, color_fn=red_blue_cmap, scaler=abs_max_scaler, legend='', add_Hs=False, - use_canonical_smiles=True): + use_canonical_smiles=True, raise_import_error=True): """Visualize or save `saliency` with molecule + See parent `MolVisualizer` class for further usage. + Args: saliency (numpy.ndarray): 1-dim saliency array (num_node,) smiles (str): smiles of the molecule. @@ -157,6 +171,10 @@ def visualize(self, saliency, smiles, save_filepath=None, add_Hs (bool): Add explicit H or not use_canonical_smiles (bool): If `True`, smiles are converted to canonical smiles before constructing `mol` + raise_import_error (bool): raise error when `ImportError` is raised + + Returns: + svg (str): drawed svg text. """ mol = Chem.MolFromSmiles(smiles) if use_canonical_smiles: @@ -167,4 +185,4 @@ def visualize(self, saliency, smiles, save_filepath=None, return super(SmilesVisualizer, self).visualize( saliency, mol, save_filepath=save_filepath, visualize_ratio=visualize_ratio, color_fn=color_fn, scaler=scaler, - legend=legend) + legend=legend, raise_import_error=raise_import_error) diff --git a/tests/saliency_tests/visualizer_tests/test_mol_visualizer.py b/tests/saliency_tests/visualizer_tests/test_mol_visualizer.py index 75401f38..f21175f8 100644 --- a/tests/saliency_tests/visualizer_tests/test_mol_visualizer.py +++ b/tests/saliency_tests/visualizer_tests/test_mol_visualizer.py @@ -17,12 +17,14 @@ def test_mol_visualizer(tmpdir): # 1. test with setting save_filepath save_filepath = os.path.join(str(tmpdir), 'tmp.svg') - visualizer.visualize(saliency, mol, save_filepath=save_filepath) + svg = visualizer.visualize(saliency, mol, save_filepath=save_filepath) + assert isinstance(svg, str) assert os.path.exists(save_filepath) # 2. test with `save_filepath=None` runs without error - visualizer.visualize( + svg = visualizer.visualize( saliency, mol, save_filepath=None, visualize_ratio=0.5,) + assert isinstance(svg, str) def test_smiles_visualizer(tmpdir): @@ -33,9 +35,10 @@ def test_smiles_visualizer(tmpdir): # 1. test with setting save_filepath save_filepath = os.path.join(str(tmpdir), 'tmp.svg') - visualizer.visualize(saliency, smiles, save_filepath=save_filepath, - add_Hs=False) + svg = visualizer.visualize(saliency, smiles, save_filepath=save_filepath, + add_Hs=False) assert os.path.exists(save_filepath) + assert isinstance(svg, str) save_filepath = os.path.join(str(tmpdir), 'tmp.png') visualizer.visualize(saliency, smiles, save_filepath=save_filepath, add_Hs=False) @@ -44,9 +47,10 @@ def test_smiles_visualizer(tmpdir): # assert os.path.exists(save_filepath) # 2. test with `save_filepath=None` runs without error - visualizer.visualize( + svg = visualizer.visualize( saliency, smiles, save_filepath=None, visualize_ratio=0.5, add_Hs=False, use_canonical_smiles=True) + assert isinstance(svg, str) def test_mol_visualizer_assert_raises(tmpdir): From a1ec3946420843078534f5cb38d66f8a7cffc9f6 Mon Sep 17 00:00:00 2001 From: corochann Date: Tue, 10 Sep 2019 21:12:14 +0900 Subject: [PATCH 2/2] change defalt value of raise_import_error --- chainer_chemistry/saliency/visualizer/mol_visualizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chainer_chemistry/saliency/visualizer/mol_visualizer.py b/chainer_chemistry/saliency/visualizer/mol_visualizer.py index 9199e865..512d5759 100644 --- a/chainer_chemistry/saliency/visualizer/mol_visualizer.py +++ b/chainer_chemistry/saliency/visualizer/mol_visualizer.py @@ -46,7 +46,7 @@ def __init__(self, logger=None): def visualize(self, saliency, mol, save_filepath=None, visualize_ratio=1.0, color_fn=red_blue_cmap, - scaler=abs_max_scaler, legend='', raise_import_error=True + scaler=abs_max_scaler, legend='', raise_import_error=False ): """Visualize or save `saliency` with molecule @@ -153,7 +153,7 @@ class SmilesVisualizer(MolVisualizer): def visualize(self, saliency, smiles, save_filepath=None, visualize_ratio=1.0, color_fn=red_blue_cmap, scaler=abs_max_scaler, legend='', add_Hs=False, - use_canonical_smiles=True, raise_import_error=True): + use_canonical_smiles=True, raise_import_error=False): """Visualize or save `saliency` with molecule See parent `MolVisualizer` class for further usage.