Skip to content

Commit

Permalink
Update to project tool usage per database creation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kessler authored and Kessler committed Jun 6, 2019
1 parent 2f2616e commit d048fe0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 29 deletions.
14 changes: 8 additions & 6 deletions docs/usage/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,21 @@ create_db(

## ECNet .prj file usage

Once an ECNet project has been created, the resulting .prj file can be used to predict properties for new molecules. A text file containing names or SMILES strings of new molecules, one per line, is required in addition to the .prj file.
Once an ECNet project has been created, the resulting .prj file can be used to predict properties for new molecules. Just supply SMILES strings, a pre-existing ECNet .prj file, and optionally a path to save the results to:

```python
from ecnet.tools.project import predict

# From a names txt file
predict('molecules.txt', 'results.csv', 'my_project.prj', form='name')
smiles = ['CCC', 'CCCC']

# From a SMILES txt file
predict('smiles.txt', 'results.csv', 'my_project.prj', form='smiles')
# obtain results, do not save to file
results = predict(smiles, 'my_project.prj')

# obtain results, save to file
results = predict(smiles, 'my_project.prj', 'results.csv')
```

Both Open Babel and the Java JRE are required for conversions.
Java JRE 6.0+ is required for conversions.

## Constructing parity plots

Expand Down
40 changes: 22 additions & 18 deletions ecnet/tools/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#

# Stdlib imports
from datetime import datetime
from os import remove
from shutil import rmtree

Expand All @@ -19,29 +20,32 @@
from ecnet.tools.database import create_db


def predict(smiles: list, prj_file: str, results_file: str=None,
            backend: str='padel') -> list:
    ''' predict: predicts values for supplied molecules (SMILES strings) using
    pre-existing ECNet project (.prj) file

    Args:
        smiles (list): SMILES strings for molecules
        prj_file (str): path to ECNet .prj file
        results_file (str): if not None, saves results to this CSV file
        backend (str): `padel` (default) or `alvadesc`, depending on the data
            your project was trained with

    Returns:
        list: predicted values
    '''

    sv = Server(prj_file=prj_file)

    # Scratch database for the new molecules; millisecond-resolution timestamp
    # keeps the name from clashing with files already in the working directory
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')[:-3]
    temp_db = '{}.csv'.format(timestamp)

    try:
        create_db(smiles, temp_db, backend=backend)
        new_data = DataFrame(temp_db)
        # Use the same input columns the project was trained with
        new_data.set_inputs(sv._df._input_names)
        new_data.create_sets()
        sv._df = new_data
        sv._sets = sv._df.package_sets()
        results = sv.use(output_filename=results_file)
    finally:
        # Clean up even when database creation or prediction fails, so no
        # scratch files are left behind
        try:
            remove(temp_db)
        except FileNotFoundError:
            # create_db failed before writing the scratch database
            pass
        # NOTE(review): presumably Server unpacks the project into a directory
        # named after the .prj file -- confirm against Server
        rmtree(prj_file.replace('.prj', ''))
    return results
10 changes: 5 additions & 5 deletions tests/tools/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def test_predict(self):
sv.train()
sv.save_project()

with open('smiles.smi', 'w') as smi_file:
smi_file.write('CCC')
smi_file.close()
predict('smiles.smi', 'results.csv', 'test_project', form='smiles')
results = predict(['CCC', 'CCCC'], 'test_project.prj', 'results.csv')
print(results)

self.assertEqual(len(results), 2)
with open('results.csv', 'r') as res_file:
self.assertGreater(len(res_file.read()), 0)
res_file.close()
remove('smiles.smi')

remove('results.csv')
remove('test_project.prj')
remove('config.yml')
Expand Down

0 comments on commit d048fe0

Please sign in to comment.