diff --git a/src/mofdscribe/featurizers/text/mofdscriber.py b/src/mofdscribe/featurizers/text/mofdscriber.py index d76406c..f39c056 100644 --- a/src/mofdscribe/featurizers/text/mofdscriber.py +++ b/src/mofdscribe/featurizers/text/mofdscriber.py @@ -4,6 +4,7 @@ from collections import Counter from typing import Dict, Optional, Union +import numpy as np from moffragmentor import MOF as MOFFragmentorMOF # noqa: N811 from pymatgen.analysis.graphs import StructureGraph from pymatgen.core import IStructure, Structure @@ -39,6 +40,7 @@ def __init__( describer_kwargs: Optional[Dict] = None, incorporate_smiles: bool = True, describe_pores: bool = True, + describe_rcsr: bool = True, ) -> None: """Construct an instance of the MOFDescriber. @@ -50,12 +52,15 @@ def __init__( incorporate_smiles (bool): If True, describe building blocks. describe_pores (bool): If True, add description of the geometry of the MOF pores. + describe_rcsr (bool): If True, add RCSR code of the MOF + topology. """ describer_defaults = {"describe_oxidation_states": False, "describe_bond_lengths": True} self.condenser_kwargs = condenser_kwargs or {} self.describer_kwargs = {**describer_defaults, **(describer_kwargs or {})} self.incorporate_smiles = incorporate_smiles self.describe_pores = describe_pores + self.describe_rcsr = describe_rcsr def _get_bb_description(self, structure: Structure, structure_graph: StructureGraph) -> str: moffragmentor_mof = MOFFragmentorMOF(structure, structure_graph) @@ -65,7 +70,21 @@ def _get_bb_description(self, structure: Structure, structure_graph: StructureGr linker_smiles = " ,".join("{} {}".format(v, k) for k, v in linker_counter.items()) metal_smiles = " ,".join("{} {}".format(v, k) for k, v in metal_counter.items()) - return "Linkers: {}. Metal clusters: {}.".format(linker_smiles, metal_smiles) + bb_string = "Linkers: {}. Metal clusters: {}. ".format(linker_smiles, metal_smiles) + + rcsr_code = fragments.net_embedding.rcsr_code + if rcsr_code and len(rcsr_code) > 1: + rcsr_string = "RCSR code: {}. ".format(rcsr_code) + + output_string = "" + if self.incorporate_smiles: + output_string += bb_string + if self.describe_rcsr: + output_string += rcsr_string + + return output_string + + return output_string def _get_pore_description(self, structure): pore_featurizer = MOFMultipleFeaturizer( @@ -87,11 +106,15 @@ def _get_robocrys_description(self, structure): def _featurize(self, structure: Structure, structure_graph: StructureGraph): description = self._get_robocrys_description(structure) - if self.incorporate_smiles: - description += " " + self._get_bb_description(structure, structure_graph) + if self.incorporate_smiles or self.describe_rcsr: + if description[-1] != " ": + description += " " + description += self._get_bb_description(structure, structure_graph) if self.describe_pores: - description += " " + self._get_pore_description(structure) - return description + if description[-1] != " ": + description += " " + description += self._get_pore_description(structure) + return np.array([description]) def featurize(self, structure: Union[Structure, IStructure]): return self._featurize(structure, get_sg(structure))