
Commit

Update docstrings
tjkessler committed Apr 7, 2023
1 parent c47cd93 commit 0fbe482
Showing 2 changed files with 14 additions and 15 deletions.
18 changes: 9 additions & 9 deletions graphchem/nn/gcn.py
@@ -21,26 +21,26 @@ def __init__(self, atom_vocab_size: int, bond_vocab_size: int,
         Molecule graphs are first embedded (torch.nn.Embedding), then each
         message passing operation consists of:
-            bond_embedding -> EdgeConv -> updated bond_embedding
-            atom_embedding + bond_embedding -> GeneralConv -> updated
+            bond_embedding > EdgeConv > updated bond_embedding
+            atom_embedding + bond_embedding > GeneralConv > updated
                 atom_embedding
         The sum of all atom states is then passed through a series of fully-
         connected readout layers to regress on a variable:
-            atom_embedding -> fully-connected readout layers -> target variable
+            atom_embedding > fully-connected readout layers > target variable
         Args:
             atom_vocab_size (int): num features (MoleculeEncoder.vocab_sizes)
             bond_vocab_size (int): num features (MoleculeEncoder.vocab_sizes)
             output_dim (int): number of target values per compound
-            embedding_dim (int, default=64): number of embedded features for
-                atoms and bonds
-            n_messages (int, default=2): number of message passes between atoms
-            n_readout (int, default=2): number of feed-forward post-readout
+            embedding_dim (int): number of embedded features for atoms and
+                bonds
+            n_messages (int): number of message passes between atoms
+            n_readout (int): number of feed-forward post-readout
                 layers (think standard NN/MLP)
-            readout_dim (int, default=64): number of neurons in readout layers
-            dropout (float, default=0.0): random neuron dropout during training
+            readout_dim (int): number of neurons in readout layers
+            dropout (float): random neuron dropout during training
         """

         super(MoleculeGCN, self).__init__()
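
For reference, a minimal construction sketch based on the arguments documented in this docstring (not part of the commit). The values are arbitrary, the import path is assumed from the file location graphchem/nn/gcn.py, and the forward-pass signature is not shown in this diff:

    from graphchem.nn.gcn import MoleculeGCN

    # Hypothetical vocabulary sizes; in practice these come from
    # MoleculeEncoder.vocab_sizes after encoding a dataset.
    model = MoleculeGCN(
        atom_vocab_size=128,
        bond_vocab_size=32,
        output_dim=1,        # one regression target per compound
        embedding_dim=64,
        n_messages=2,
        n_readout=2,
        readout_dim=64,
        dropout=0.0
    )
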
11 changes: 5 additions & 6 deletions graphchem/preprocessing/features.py
@@ -13,7 +13,7 @@ def get_ring_size(obj: Union['rdkit.Chem.Atom', 'rdkit.Chem.Bond'],
     Args:
         obj (Union[rdkit.Chem.Atom, rdkit.Chem.Bond]): atom or bond
-        max_size (int, default=12): maximum ring size to consider
+        max_size (int): maximum ring size to consider
     """

     if not obj.IsInRing():
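
As a usage illustration (not part of this commit), get_ring_size can be called on an RDKit atom or bond; the sketch below assumes it returns an integer ring size, with 0 for atoms outside any ring:

    import rdkit.Chem
    from graphchem.preprocessing.features import get_ring_size

    mol = rdkit.Chem.MolFromSmiles('C1CCCCC1CO')    # cyclohexane with a -CH2OH side chain
    ring_atom = mol.GetAtomWithIdx(0)               # carbon inside the six-membered ring
    chain_atom = mol.GetAtomWithIdx(7)              # hydroxyl oxygen, outside any ring
    print(get_ring_size(ring_atom, max_size=12))    # expected: 6
    print(get_ring_size(chain_atom, max_size=12))   # expected: 0 (not in a ring)
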
@@ -172,9 +172,8 @@ def encode_many(self, smiles: List[str]) -> List[Tuple['torch.tensor']]:
             smiles (List[str]): list of SMILES strings
         Returns:
-            List[Tuple[torch.tensor, torch.tensor, torch.tensor]]: List of:
-                (atom encoding, bond encoding, connectivity matrix) for each
-                compound
+            List[Tuple[torch.tensor]]: List of: (atom encoding, bond encoding,
+                connectivity matrix) for each compound
         """

         encoded_compounds = []
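
A brief usage sketch for encode_many (not part of this commit). The MoleculeEncoder constructor is not shown in this diff, so an already-built encoder instance is assumed:

    # `encoder` is assumed to be a graphchem MoleculeEncoder instance
    smiles_list = ['CCO', 'c1ccccc1', 'CC(=O)O']
    encoded = encoder.encode_many(smiles_list)
    for atom_enc, bond_enc, connectivity in encoded:
        # one (atom encoding, bond encoding, connectivity matrix) tuple per compound
        print(atom_enc.shape, bond_enc.shape, connectivity.shape)
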
@@ -189,8 +188,8 @@ def encode(self, smiles: str) -> Tuple['torch.tensor']:
             smiles (str): molecule's SMILES string
         Returns:
-            Tuple[torch.tensor, torch.tensor, torch.tensor]: (encoded atom
-                features, encoded bond features, molecule connectivity matrix)
+            Tuple[torch.tensor]: (encoded atom features, encoded bond features,
+                molecule connectivity matrix)
         """

         mol = rdkit.Chem.MolFromSmiles(smiles)
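
And for a single compound, encode returns the three tensors directly (again assuming a built encoder instance, as above):

    # unpack the documented return tuple for one molecule
    atom_enc, bond_enc, connectivity = encoder.encode('CCO')
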
