Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
fix flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Chen committed Jul 27, 2020
1 parent 9a3332e commit c5afd99
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 17 deletions.
11 changes: 5 additions & 6 deletions megnet/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def __getitem__(self, index: int) -> tuple:
return inputs
else:
# get targets
it = itemgetter(*batch_index)
target_temp = itemgetter_list(self.targets, batch_index)
target_temp = np.atleast_2d(target_temp)

Expand Down Expand Up @@ -527,17 +526,17 @@ def process_bond_feature(self, x) -> np.ndarray:
return self.distance_converter.convert(x)


def itemgetter_list(l, indices: List) -> tuple:
def itemgetter_list(data_list: List, indices: List) -> tuple:
"""
Get indices of l and return a tuple
Get indices of data_list and return a tuple
Args:
l: (list)
data_list (list): data list
indices: (list) indices
Returns:
(tuple)
"""
it = itemgetter(*indices)
if np.size(indices) == 1:
return it(l),
return it(data_list),
else:
return it(l)
return it(data_list)
2 changes: 1 addition & 1 deletion megnet/data/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
except ImportError:
Chem = None

from typing import Sequence, Dict, Union, List
from typing import Dict, Union, List

__date__ = '12/01/2018'

Expand Down
30 changes: 23 additions & 7 deletions megnet/data/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,50 @@
"""
from monty.json import MSONable


ATOMNUM2TYPE = {"1": 1, "6": 2, "7": 4, "8": 6, "9": 8}


class AtomNumberToTypeConverter(MSONable):
"""
Convert atomic number Z into the atomic type in the QM9 dataset
Convert atomic number Z into the atomic type in the QM9 dataset.
This is specifically used for this problem, do not use it elsewhere.
The code is here for historical reasons.
"""
def __init__(self, mapping=ATOMNUM2TYPE):
"""
Atomic number to atomic type converter
Args:
mapping (dict): mapping dictionary
"""
self.mapping = mapping

def convert(self, l):
return [self.mapping[str(i)] for i in l]
def convert(self, z_list: list) -> list:
"""
Convert the atomic number list to atomic type list
Args:
z_list (list of integer): atomic number list
Returns: list of integer, atomic type list
"""
return [self.mapping[str(i)] for i in z_list]


def ring_to_vector(l):
def ring_to_vector(z_list: list) -> list:
"""
Convert the ring sizes vector to a fixed length vector
For example, l can be [3, 5, 5], meaning that the atom is involved
in 1 3-sized ring and 2 5-sized ring. This function will convert it into
[ 0, 0, 1, 0, 2, 0, 0, 0, 0, 0].
Args:
l: (list of integer) ring_sizes attributes
z_list: (list of integer) ring_sizes attributes
Returns:
(list of integer) fixed size list with the i-1 th element indicates number of
i-sized ring this atom is involved in.
"""
return_l = [0] * 9
if l:
for i in l:
if z_list:
for i in z_list:
return_l[i - 1] += 1
return return_l
3 changes: 2 additions & 1 deletion megnet/layers/graph/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def rho_e_v(self, e_p, inputs):
fr = tf.gather(atomwise1, index2, axis=1)

after_cfconv = atomwise1 + \
tf.transpose(a=tf.math.segment_sum(tf.transpose(a=fr * cfconv_out, perm=[1, 0, 2]), index1), perm=[1, 0, 2])
tf.transpose(a=tf.math.segment_sum(tf.transpose(
a=fr * cfconv_out, perm=[1, 0, 2]), index1), perm=[1, 0, 2])

atomwise2 = self.activation(self._mlp(after_cfconv, self.phi_v_weights[1], self.phi_v_biases[1]))
atomwise3 = self._mlp(atomwise2, self.phi_v_weights[2], self.phi_v_biases[2])
Expand Down
5 changes: 5 additions & 0 deletions megnet/layers/readout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@
"""
from .set2set import Set2Set
from .linear import LinearWithIndex

__all__ = [
"Set2Set",
"LinearWithIndex"
]
2 changes: 0 additions & 2 deletions megnet/utils/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def _repeat(x: tf.Tensor, n: tf.Tensor, axis: int = 1) -> tf.Tensor:
maxlen = tf.reduce_max(input_tensor=n)
x_shape = tf.shape(input=x)
x_dim = len(x.shape)
# get the length of x
xlen = tf.shape(input=n)[0]
# create a range with the length of x
shape = [1] * (x_dim + 1)
shape[axis + 1] = maxlen
Expand Down

0 comments on commit c5afd99

Please sign in to comment.