Skip to content

Commit

Permalink
Merge pull request #223 from datamol-io/fix/parse-args-in-lasso
Browse files Browse the repository at this point in the history
Allow additional args for colors in lasso
  • Loading branch information
maclandrol committed Jan 19, 2024
2 parents 5d1cde1 + e56d083 commit 4fbb047
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 5 deletions.
57 changes: 52 additions & 5 deletions datamol/viz/_lasso_highlight.py
Expand Up @@ -6,7 +6,7 @@
# - possibility to do this for multiple target molecules at once
# - have the option to write to a file like to_image

from typing import List, Iterator, Tuple, Union, Optional, Any, cast
from typing import List, Dict, Iterator, Tuple, Union, Optional, Any, cast

from collections import defaultdict
from collections import namedtuple
Expand Down Expand Up @@ -400,6 +400,10 @@ def lasso_highlight_image(
line_width: int = 2,
scale_padding: float = 1.0,
verbose: bool = False,
highlight_atoms: Optional[List[List[int]]] = None,
highlight_bonds: Optional[List[List[int]]] = None,
highlight_atom_colors: Optional[List[Dict[int, DatamolColor]]] = None,
highlight_bond_colors: Optional[List[Dict[int, DatamolColor]]] = None,
**kwargs: Any,
):
"""Create an image of a list of molecules with substructure matches using lasso-based highlighting.
Expand All @@ -408,7 +412,7 @@ def lasso_highlight_image(
Args:
target_molecules: One or a list of molecules to be highlighted.
search_molecules: The substructure to be highlighted.
atom_indices: Atom indices to be highlighted substructure.
atom_indices: Atom indices to be highlighted as substructure using the lasso visualization.
legends: A string or a list of string as legend for every molecules.
n_cols: Number of molecules per column.
mol_size: The size of the image to be returned
Expand All @@ -421,6 +425,10 @@ def lasso_highlight_image(
line_width: width of drawn lines.
scale_padding: Padding around the molecule when drawing to scale.
verbose: Whether to print the verbose information.
highlight_atoms: The atoms to highlight, a list for each molecule. It's the `highlightAtoms` argument of the RDKit drawer object.
highlight_bonds: The bonds to highlight, a list for each molecule. It's the `highlightBonds` argument of the RDKit drawer object.
highlight_atom_colors: The colors to use for highlighting atoms, a list of dict mapping atom index to color for each molecule.
highlight_bond_colors: The colors to use for highlighting bonds, a list of dict mapping bond index to color for each molecule.
**kwargs: Additional arguments to pass to the drawing function. See RDKit
documentation related to `MolDrawOptions` for more details at
https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html.
Expand Down Expand Up @@ -551,9 +559,38 @@ def lasso_highlight_image(
# EN: the following is edge-case free after trying 6 different logics, but may break if RDKit changes the way it draws molecules
scaling_val = Point2D(scale_padding, scale_padding)

if isinstance(highlight_atoms, list) and isinstance(highlight_atoms[0], int):
highlight_atoms = [highlight_atoms] * len(target_molecules)
if isinstance(highlight_bonds, list) and isinstance(highlight_bonds[0], int):
highlight_bonds = [highlight_bonds] * len(target_molecules)
if isinstance(highlight_atom_colors, dict):
highlight_atom_colors = [highlight_atom_colors] * len(target_molecules)
if isinstance(highlight_bond_colors, dict):
highlight_bond_colors = [highlight_bond_colors] * len(target_molecules)

# make sure we are using rdkit colors
if highlight_atom_colors is not None:
highlight_atom_colors = [
{k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors
]
if highlight_bond_colors is not None:
highlight_bond_colors = [
{k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors
]

kwargs["highlightAtoms"] = highlight_atoms
kwargs["highlightBonds"] = highlight_bonds
kwargs["highlightAtomColors"] = highlight_atom_colors
kwargs["highlightBondColors"] = highlight_bond_colors

try:
drawer.DrawMolecules(mols_to_draw, legends=legends, **kwargs)
except Exception:
drawer.DrawMolecules(
mols_to_draw,
legends=legends,
**kwargs,
)
except Exception as e:
logger.error(e)
raise ValueError(
"Failed to draw molecules. Some arguments neither match expected MolDrawOptions, nor DrawMolecule inputs. Please check the input arguments."
)
Expand All @@ -567,8 +604,18 @@ def lasso_highlight_image(
h_pos, w_pos = np.unravel_index(ind, (n_rows, n_cols))
offset_x = int(w_pos * mol_size[0])
offset_y = int(h_pos * mol_size[1])

ind_kwargs = kwargs.copy()
if isinstance(ind_kwargs["highlightAtoms"], list):
ind_kwargs["highlightAtoms"] = ind_kwargs["highlightAtoms"][ind]
if isinstance(ind_kwargs["highlightAtomColors"], list):
ind_kwargs["highlightAtomColors"] = ind_kwargs["highlightAtomColors"][ind]
if isinstance(ind_kwargs["highlightBonds"], list):
ind_kwargs["highlightBonds"] = ind_kwargs["highlightBonds"][ind]
if isinstance(ind_kwargs["highlightBondColors"], list):
ind_kwargs["highlightBondColors"] = ind_kwargs["highlightBondColors"][ind]
drawer.SetOffset(offset_x, offset_y)
drawer.DrawMolecule(mol, legend=legends[ind], **kwargs)
drawer.DrawMolecule(mol, legend=legends[ind], **ind_kwargs)
offset = None
if draw_mols_same_scale:
offset = drawer.Offset()
Expand Down
6 changes: 6 additions & 0 deletions datamol/viz/utils.py
Expand Up @@ -141,6 +141,12 @@ def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]:
Args:
color: A datamol color: hex, rgb, rgba or None.
"""
if color is None:
return None

if isinstance(color, str):
return mcolors.to_rgba(color) # type: ignore
if isinstance(color, (tuple, list)) and len(color) in [3, 4] and any(x > 1 for x in color):
return tuple(x / 255 if i < 3 else x for i, x in enumerate(color))

return color
24 changes: 24 additions & 0 deletions tests/test_viz_lasso_highlight.py
Expand Up @@ -17,6 +17,30 @@ def test_from_mol():
assert dm.lasso_highlight_image(mol, smarts_list)


def test_with_highlight():
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
mol = dm.to_mol(smi)
smarts_list = "CONN"
highlight_atoms = [4, 5, 6]
highlight_bonds = [1, 2, 3, 4]
highlight_atom_colors = {4: (230, 230, 250), 5: (230, 230, 250), 6: (230, 230, 250)}
highlight_bond_colors = {
1: (230, 230, 250),
2: (230, 230, 250),
3: (230, 230, 250),
4: (230, 230, 250),
}
assert dm.lasso_highlight_image(
mol,
smarts_list,
highlight_atoms=highlight_atoms,
highlight_bonds=highlight_bonds,
highlight_atom_colors=highlight_atom_colors,
highlight_bond_colors=highlight_bond_colors,
continuousHighlight=False,
)


def test_original_working_solution_list_single_str():
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
smarts_list = ["CONN"]
Expand Down

0 comments on commit 4fbb047

Please sign in to comment.