Skip to content

Commit

Permalink
test: move loading graphs to setUpClass (#1484)
Browse files Browse the repository at this point in the history
* test: move loading graphs to setUpClass

Loading graphs is expensive.

* small fix

* fix a typo

* fix typo
  • Loading branch information
njzjz committed Feb 20, 2022
1 parent 17d1443 commit 45c2feb
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 48 deletions.
52 changes: 36 additions & 16 deletions source/tests/test_deepdipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
default_places = 10

class TestDeepDipolePBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deepdipole.pbtxt")), "deepdipole.pb")
self.dp = DeepDipole("deepdipole.pb")
cls.dp = DeepDipole("deepdipole.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -27,8 +30,10 @@ def setUp(self):
self.box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.])
self.expected_d = np.array([-9.274180565967479195e-01,2.698028341272042496e+00,2.521268387140979117e-01,2.927260638453461628e+00,-8.571926301526779923e-01,1.667785136187720063e+00])

def tearDown(self):
os.remove("deepdipole.pb")
@classmethod
def tearDownClass(cls):
os.remove("deepdipole.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down Expand Up @@ -61,9 +66,12 @@ def test_2frame_atm(self):


class TestDeepDipoleNoPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deepdipole.pbtxt")), "deepdipole.pb")
self.dp = DeepDipole("deepdipole.pb")
cls.dp = DeepDipole("deepdipole.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -74,8 +82,10 @@ def setUp(self):
self.box = np.array([20., 0., 0., 0., 20., 0., 0., 0., 20.])
self.expected_d = np.array([-1.982092647058316e+00, 8.303361089028074e-01, 1.499962003179265e+00, 2.927112547154802e+00, -8.572096473802318e-01, 1.667798310054391e+00])

def tearDown(self):
os.remove("deepdipole.pb")
@classmethod
def tearDownClass(cls):
os.remove("deepdipole.pb")
cls.dp = None

def test_1frame_atm(self):
dd = self.dp.eval(self.coords, None, self.atype)
Expand All @@ -101,9 +111,12 @@ def test_1frame_atm_large_box(self):
@unittest.skipIf(parse_version(tf.__version__) < parse_version("1.15"),
f"The current tf version {tf.__version__} is too low to run the new testing model.")
class TestDeepDipoleNewPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deepdipole_new.pbtxt")), "deepdipole_new.pb")
self.dp = DeepDipole("deepdipole_new.pb")
cls.dp = DeepDipole("deepdipole_new.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -119,8 +132,10 @@ def setUp(self):
self.expected_gt = self.expected_t.reshape(-1, self.nout).sum(0).reshape(-1)
self.expected_gv = self.expected_v.reshape(1, self.nout, 6, 9).sum(-2).reshape(-1)

def tearDown(self):
os.remove("deepdipole_new.pb")
@classmethod
def tearDownClass(cls):
os.remove("deepdipole_new.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down Expand Up @@ -260,9 +275,12 @@ def test_2frame_full_atm(self):
@unittest.skipIf(parse_version(tf.__version__) < parse_version("1.15"),
f"The current tf version {tf.__version__} is too low to run the new testing model.")
class TestDeepDipoleFakePBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deepdipole_fake.pbtxt")), "deepdipole_fake.pb")
self.dp = DeepDipole("deepdipole_fake.pb")
cls.dp = DeepDipole("deepdipole_fake.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -286,8 +304,10 @@ def setUp(self):
fake_target = fake_target - 13 * np.rint(fake_target / 13)
self.target_t = fake_target.reshape(-1)

def tearDown(self):
os.remove("deepdipole_fake.pb")
@classmethod
def tearDownClass(cls):
os.remove("deepdipole_fake.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down
37 changes: 26 additions & 11 deletions source/tests/test_deeppolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
default_places = 10

class TestDeepPolarPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppolar.pbtxt")), "deeppolar.pb")
self.dp = DeepPolar("deeppolar.pb")
cls.dp = DeepPolar("deeppolar.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -27,8 +30,10 @@ def setUp(self):
self.box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.])
self.expected_d = np.array([1.061407927405987051e-01,-3.569013342133873778e-01,-2.862108976089940138e-02,-3.569013342133875444e-01,1.304367268874677244e+00,1.037647501453442256e-01,-2.862108976089940138e-02,1.037647501453441284e-01,8.100521520762453409e-03,1.236797829492216616e+00,-3.717307430531632262e-01,7.371515676976750919e-01,-3.717307430531630041e-01,1.127222682121889058e-01,-2.239181552775717510e-01,7.371515676976746478e-01,-2.239181552775717787e-01,4.448255365635306879e-01])

def tearDown(self):
os.remove("deeppolar.pb")
@classmethod
def tearDownClass(cls):
os.remove("deeppolar.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down Expand Up @@ -62,9 +67,12 @@ def test_2frame_atm(self):


class TestDeepPolarNoPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppolar.pbtxt")), "deeppolar.pb")
self.dp = DeepPolar("deeppolar.pb")
cls.dp = DeepPolar("deeppolar.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -75,8 +83,10 @@ def setUp(self):
self.box = np.array([20., 0., 0., 0., 20., 0., 0., 0., 20.])
self.expected_d = np.array([5.601785462021734e-01, -2.346693909765864e-01, -4.239188998286720e-01, -2.346693909765862e-01, 9.830744757127260e-02, 1.775876472255247e-01, -4.239188998286717e-01, 1.775876472255248e-01, 3.208034917622381e-01, 1.302526099276315e+00, -3.784198124746947e-01, 7.548241853986054e-01, -3.784198124746949e-01, 1.098824690874320e-01, -2.194150345809899e-01, 7.548241853986057e-01, -2.194150345809898e-01, 4.382376148484938e-01])

def tearDown(self):
os.remove("deeppolar.pb")
@classmethod
def tearDownClass(cls):
os.remove("deeppolar.pb")
cls.dp = None

def test_1frame_atm(self):
dd = self.dp.eval(self.coords, None, self.atype)
Expand All @@ -102,9 +112,12 @@ def test_1frame_atm_large_box(self):
@unittest.skipIf(parse_version(tf.__version__) < parse_version("1.15"),
f"The current tf version {tf.__version__} is too low to run the new testing model.")
class TestDeepPolarNewPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppolar_new.pbtxt")), "deeppolar_new.pb")
self.dp = DeepPolar("deeppolar_new.pb")
cls.dp = DeepPolar("deeppolar_new.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -120,8 +133,10 @@ def setUp(self):
self.expected_gt = self.expected_t.reshape(-1, self.nout).sum(0).reshape(-1)
self.expected_gv = self.expected_v.reshape(1, self.nout, 6, 9).sum(-2).reshape(-1)

def tearDown(self):
@classmethod
def tearDownClass(cls):
os.remove("deeppolar_new.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down
33 changes: 24 additions & 9 deletions source/tests/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ def test(self):


class TestDeepPotAPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppot.pbtxt")), "deeppot.pb")
self.dp = DeepPot("deeppot.pb")
cls.dp = DeepPot("deeppot.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -89,8 +92,10 @@ def setUp(self):
self.expected_f = np.array([-3.034045420701179663e-01,8.405844663871177014e-01,7.696947487118485642e-02,7.662001266663505117e-01,-1.880601391333554251e-01,-6.183333871091722944e-01,-5.036172391059643427e-01,-6.529525836149027151e-01,5.432962643022043459e-01,6.382357912332115024e-01,-1.748518296794561167e-01,3.457363524891907125e-01,1.286482986991941552e-03,3.757251165286925043e-01,-5.972588700887541124e-01,-5.987006197104716154e-01,-2.004450304880958100e-01,2.495901655353461868e-01])
self.expected_v = np.array([-2.912234126853306959e-01,-3.800610846612756388e-02,2.776624987489437202e-01,-5.053761003913598976e-02,-3.152373041953385746e-01,1.060894290092162379e-01,2.826389131596073745e-01,1.039129970665329250e-01,-2.584378792325942586e-01,-3.121722367954994914e-01,8.483275876786681990e-02,2.524662342344257682e-01,4.142176771106586414e-02,-3.820285230785245428e-02,-2.727311173065460545e-02,2.668859789777112135e-01,-6.448243569420382404e-02,-2.121731470426218846e-01,-8.624335220278558922e-02,-1.809695356746038597e-01,1.529875294531883312e-01,-1.283658185172031341e-01,-1.992682279795223999e-01,1.409924999632362341e-01,1.398322735274434292e-01,1.804318474574856390e-01,-1.470309318999652726e-01,-2.593983661598450730e-01,-4.236536279233147489e-02,3.386387920184946720e-02,-4.174017537818433543e-02,-1.003500282164128260e-01,1.525690815194478966e-01,3.398976109910181037e-02,1.522253908435125536e-01,-2.349125581341701963e-01,9.515545977581392825e-04,-1.643218849228543846e-02,1.993234765412972564e-02,6.027265332209678569e-04,-9.563256398907417355e-02,1.510815124001868293e-01,-7.738094816888557714e-03,1.502832772532304295e-01,-2.380965783745832010e-01,-2.309456719810296654e-01,-6.666961081213038098e-02,7.955566551234216632e-02,-8.099093777937517447e-02,-3.386641099800401927e-02,4.447884755740908608e-02,1.008593228579038742e-01,4.556718179228393811e-02,-6.078081273849572641e-02])

def tearDown(self):
@classmethod
def tearDownClass(cls):
os.remove("deeppot.pb")
cls.dp = None

def test_attrs(self):
self.assertEqual(self.dp.get_ntypes(), 2)
Expand Down Expand Up @@ -160,9 +165,12 @@ def test_2frame_atm(self):


class TestDeepPotANoPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppot.pbtxt")), "deeppot.pb")
self.dp = DeepPot("deeppot.pb")
cls.dp = DeepPot("deeppot.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -175,8 +183,10 @@ def setUp(self):
self.expected_f = np.array([-2.161037360255332107e+00,9.052994347015581589e-01,1.635379623977007979e+00,2.161037360255332107e+00,-9.052994347015581589e-01,-1.635379623977007979e+00,-1.167128117249453811e-02,1.371975700096064992e-03,-1.575265180249604477e-03,6.226508593971802341e-01,-1.816734122009256991e-01,3.561766019664774907e-01,-1.406075393906316626e-02,3.789140061530929526e-01,-6.018777878642909140e-01,-5.969188242856223736e-01,-1.986125696522633155e-01,2.472764510780630642e-01])
self.expected_v = np.array([-7.042445481792056761e-01,2.950213647777754078e-01,5.329418202437231633e-01,2.950213647777752968e-01,-1.235900311906896754e-01,-2.232594111831812944e-01,5.329418202437232743e-01,-2.232594111831813499e-01,-4.033073234276823849e-01,-8.949230984097404917e-01,3.749002169013777030e-01,6.772391014992630298e-01,3.749002169013777586e-01,-1.570527935667933583e-01,-2.837082722496912512e-01,6.772391014992631408e-01,-2.837082722496912512e-01,-5.125052659994422388e-01,4.858210330291591605e-02,-6.902596153269104431e-03,6.682612642430500391e-03,-5.612247004554610057e-03,9.767795567660207592e-04,-9.773758942738038254e-04,5.638322117219018645e-03,-9.483806049779926932e-04,8.493873281881353637e-04,-2.941738570564985666e-01,-4.482529909499673171e-02,4.091569840186781021e-02,-4.509020615859140463e-02,-1.013919988807244071e-01,1.551440772665269030e-01,4.181857726606644232e-02,1.547200233064863484e-01,-2.398213304685777592e-01,-3.218625798524068354e-02,-1.012438450438508421e-02,1.271639330380921855e-02,3.072814938490859779e-03,-9.556241797915024372e-02,1.512251983492413077e-01,-8.277872384009607454e-03,1.505412040827929787e-01,-2.386150620881526407e-01,-2.312295470054945568e-01,-6.631490213524345034e-02,7.932427266386249398e-02,-8.053754366323923053e-02,-3.294595881137418747e-02,4.342495071150231922e-02,1.004599500126941436e-01,4.450400364869536163e-02,-5.951077548033092968e-02])

def tearDown(self):
@classmethod
def tearDownClass(cls):
os.remove("deeppot.pb")
cls.dp = None

def test_1frame(self):
ee, ff, vv = self.dp.eval(self.coords, self.box, self.atype, atomic = False)
Expand Down Expand Up @@ -238,9 +248,12 @@ def test_2frame_atm(self):


class TestDeepPotALargeBoxNoPBC(unittest.TestCase) :
def setUp(self):
@classmethod
def setUpClass(cls):
convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppot.pbtxt")), "deeppot.pb")
self.dp = DeepPot("deeppot.pb")
cls.dp = DeepPot("deeppot.pb")

def setUp(self):
self.coords = np.array([12.83, 2.56, 2.18,
12.09, 2.87, 2.74,
00.25, 3.32, 1.68,
Expand All @@ -253,8 +266,10 @@ def setUp(self):
self.expected_f = np.array([-2.161037360255332107e+00,9.052994347015581589e-01,1.635379623977007979e+00,2.161037360255332107e+00,-9.052994347015581589e-01,-1.635379623977007979e+00,-1.167128117249453811e-02,1.371975700096064992e-03,-1.575265180249604477e-03,6.226508593971802341e-01,-1.816734122009256991e-01,3.561766019664774907e-01,-1.406075393906316626e-02,3.789140061530929526e-01,-6.018777878642909140e-01,-5.969188242856223736e-01,-1.986125696522633155e-01,2.472764510780630642e-01])
self.expected_v = np.array([-7.042445481792056761e-01,2.950213647777754078e-01,5.329418202437231633e-01,2.950213647777752968e-01,-1.235900311906896754e-01,-2.232594111831812944e-01,5.329418202437232743e-01,-2.232594111831813499e-01,-4.033073234276823849e-01,-8.949230984097404917e-01,3.749002169013777030e-01,6.772391014992630298e-01,3.749002169013777586e-01,-1.570527935667933583e-01,-2.837082722496912512e-01,6.772391014992631408e-01,-2.837082722496912512e-01,-5.125052659994422388e-01,4.858210330291591605e-02,-6.902596153269104431e-03,6.682612642430500391e-03,-5.612247004554610057e-03,9.767795567660207592e-04,-9.773758942738038254e-04,5.638322117219018645e-03,-9.483806049779926932e-04,8.493873281881353637e-04,-2.941738570564985666e-01,-4.482529909499673171e-02,4.091569840186781021e-02,-4.509020615859140463e-02,-1.013919988807244071e-01,1.551440772665269030e-01,4.181857726606644232e-02,1.547200233064863484e-01,-2.398213304685777592e-01,-3.218625798524068354e-02,-1.012438450438508421e-02,1.271639330380921855e-02,3.072814938490859779e-03,-9.556241797915024372e-02,1.512251983492413077e-01,-8.277872384009607454e-03,1.505412040827929787e-01,-2.386150620881526407e-01,-2.312295470054945568e-01,-6.631490213524345034e-02,7.932427266386249398e-02,-8.053754366323923053e-02,-3.294595881137418747e-02,4.342495071150231922e-02,1.004599500126941436e-01,4.450400364869536163e-02,-5.951077548033092968e-02])

def tearDown(self):
@classmethod
def tearDownClass(cls):
os.remove("deeppot.pb")
cls.dp = None

def test_1frame(self):
ee, ff, vv = self.dp.eval(self.coords, self.box, self.atype, atomic = False)
Expand Down
Loading

0 comments on commit 45c2feb

Please sign in to comment.