diff --git a/tdc/test/test_oracles.py b/tdc/test/test_oracles.py index bb2c3a50..8d5ed926 100644 --- a/tdc/test/test_oracles.py +++ b/tdc/test/test_oracles.py @@ -38,6 +38,26 @@ def test_jnk3(self): x = oracle('C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O') assert abs(x - 0.01) < 0.0001 + def test_list_single(self): + from tdc import Oracle + + oracle = Oracle(name='GSK3B') + + x = oracle(['CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1']) + assert abs(x[0] - 0.03) < 0.0001 + + def test_list_multi(self): + from tdc import Oracle + + oracle = Oracle(name='JNK3') + + x = oracle(['CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1', \ + 'CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1', \ + 'C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O']) + assert abs(x[0] - 0.01) < 0.0001 + assert abs(x[1] - 0.0) < 0.0001 + assert abs(x[2] - 0.01) < 0.0001 + # def tearDown(self): # print(os.getcwd()) # shutil.rmtree(os.path.join(os.getcwd(), "data"))