Skip to content

Commit

Permalink
Option to set root folder for chemprop web
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 21, 2020
1 parent 6b5cc17 commit 0ff6fd4
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 72 deletions.
12 changes: 9 additions & 3 deletions chemprop/web/app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Runs the web interface version of chemprop, allowing for training and predicting in a web browser."""
import os

from flask import Flask

from chemprop.web.utils import set_root_folder


app = Flask(__name__)
app.config.from_object('chemprop.web.config')

os.makedirs(app.config['CHECKPOINT_FOLDER'], exist_ok=True)
os.makedirs(app.config['DATA_FOLDER'], exist_ok=True)
set_root_folder(
app=app,
root_folder=os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
create_folders=False
)

from chemprop.web.app import views
8 changes: 7 additions & 1 deletion chemprop/web/app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@
from chemprop.web.app import app


DB_PATH = 'chemprop.sqlite3'


def init_app(app: Flask):
    """
    Registers the database teardown hook on the Flask app and records the configured database path.

    :param app: Flask app whose config provides the 'DB_PATH' entry.
    """
    # Rebind the module-level DB_PATH so that connections opened later use the app's configured location
    global DB_PATH

    # NOTE(review): close_db is defined elsewhere in this module - presumably it closes the per-context connection
    app.teardown_appcontext(close_db)
    DB_PATH = app.config['DB_PATH']


def init_db():
Expand All @@ -38,7 +44,7 @@ def get_db():
"""
if 'db' not in g:
g.db = sqlite3.connect(
'chemprop.sqlite3',
DB_PATH,
detect_types=sqlite3.PARSE_DECLTYPES
)
g.db.row_factory = sqlite3.Row
Expand Down
14 changes: 1 addition & 13 deletions chemprop/web/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,11 @@
These are accessible in a dictionary, with each line defining a key.
"""

import os
from tempfile import TemporaryDirectory

import torch

_TEMP_FOLDER_OBJECT = TemporaryDirectory()

DEFAULT_USER_ID = 1
if os.access(os.path.dirname(os.path.realpath(__file__)), os.W_OK):
ROOT_FOLDER = os.path.dirname(os.path.realpath(__file__))
elif os.access(os.getcwd(), os.W_OK):
ROOT_FOLDER = os.path.join(os.getcwd(), "chemprop_web_app")
else :
raise ValueError("Failed to find a writable ROOT_FOLDER for web app data and checkpoints.")
DATA_FOLDER = os.path.join(ROOT_FOLDER, 'app/web_data')
CHECKPOINT_FOLDER = os.path.join(ROOT_FOLDER, 'app/web_checkpoints')
TEMP_FOLDER = os.path.join(ROOT_FOLDER, _TEMP_FOLDER_OBJECT.name)

SMILES_FILENAME = 'smiles.csv'
PREDICTIONS_FILENAME = 'predictions.csv'
DB_FILENAME = 'chemprop.sqlite3'
Expand Down
12 changes: 11 additions & 1 deletion chemprop/web/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tap import Tap # pip install typed-argument-parser (https://github.com/swansonk14/typed-argument-parser)

from chemprop.web.app import app, db
from chemprop.web.utils import clear_temp_folder, set_root_folder


class WebArgs(Tap):
Expand All @@ -16,14 +17,23 @@ class WebArgs(Tap):
debug: bool = False # Whether to run in debug mode
demo: bool = False # Display only demo features
initdb: bool = False # Initialize Database
root_folder: str = None # Root folder where web data and checkpoints will be saved (defaults to chemprop/web/app)


def run_web(args: WebArgs) -> None:
app.config['DEMO'] = args.demo

# Set up root folder and subfolders
set_root_folder(
app=app,
root_folder=args.root_folder,
create_folders=True
)
clear_temp_folder(app=app)

db.init_app(app)

if args.initdb or not os.path.isfile(app.config['DB_FILENAME']):
if args.initdb or not os.path.isfile(app.config['DB_PATH']):
with app.app_context():
db.init_db()
print("-- INITIALIZED DATABASE --")
Expand Down
34 changes: 34 additions & 0 deletions chemprop/web/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Contains utility functions for the Flask web app."""

import os
import shutil

from flask import Flask


def set_root_folder(app: Flask, root_folder: str = None, create_folders: bool = True) -> None:
    """
    Points the app config at a root folder and derives the subfolder and database paths from it.

    :param app: Flask app whose config is updated in place.
    :param root_folder: Path to the root folder. If None, the currently configured folders are left unchanged.
    :param create_folders: Whether to create the root folder and its subfolders on disk.
    """
    config = app.config

    # Derive every path from the new root (skipped entirely when no root is given)
    if root_folder is not None:
        config['ROOT_FOLDER'] = root_folder
        subfolders = (
            ('DATA_FOLDER', 'app/web_data'),
            ('CHECKPOINT_FOLDER', 'app/web_checkpoints'),
            ('TEMP_FOLDER', 'app/temp'),
        )
        for config_key, relative_path in subfolders:
            config[config_key] = os.path.join(root_folder, relative_path)
        config['DB_PATH'] = os.path.join(root_folder, config['DB_FILENAME'])

    if not create_folders:
        return

    # Create the root first, then each subfolder; folders that already exist are left as they are
    for config_key in ('ROOT_FOLDER', 'DATA_FOLDER', 'CHECKPOINT_FOLDER', 'TEMP_FOLDER'):
        os.makedirs(config[config_key], exist_ok=True)


def clear_temp_folder(app: Flask) -> None:
    """
    Clears the temporary folder, leaving an empty folder in its place.

    A temporary folder that does not exist yet is simply created, so this is safe
    to call before the folder structure has been set up.

    :param app: Flask app whose config provides the 'TEMP_FOLDER' entry.
    """
    temp_folder = app.config['TEMP_FOLDER']

    try:
        shutil.rmtree(temp_folder)
    except FileNotFoundError:
        # Nothing to clear; other errors (e.g. permissions) still propagate to the caller
        pass

    os.makedirs(temp_folder, exist_ok=True)
9 changes: 9 additions & 0 deletions chemprop/web/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
Designed to be used for production only, along with Gunicorn.
"""
from chemprop.web.app import app, db
from chemprop.web.utils import clear_temp_folder, set_root_folder


def build_app(*args, **kwargs):
# Set up root folder and subfolders
set_root_folder(
app=app,
root_folder=kwargs.get('root_folder', None),
create_folders=True
)
clear_temp_folder(app=app)

db.init_app(app)
if 'init_db' in kwargs:
with app.app_context():
Expand Down
109 changes: 55 additions & 54 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,62 +458,63 @@ def test_interpret_single_task_regression(self,
self.fail(f'Interpretation failed with error: {e}')

def test_chemprop_web(self):
app = build_app(init_db=True)

app.config['TESTING'] = True

data_path = 'regression.csv'
test_path = 'regression_test_smiles.csv'
dataset_name = 'regression_data'
dataset_type = 'regression'
checkpoint_name = 'regression_ckpt'
ckpt_name = data_name = '1'
epochs = 3
ensemble_size = 1

with open(os.path.join(TEST_DATA_DIR, data_path)) as f:
train_data = BytesIO(f.read().encode('utf-8'))

with open(os.path.join(TEST_DATA_DIR, test_path)) as f:
test_smiles = f.read()

with app.test_client() as client:
response = client.get('/')
self.assertEqual(response.status_code, 200)

# Upload data
response = client.post(
url_for('upload_data', return_page='home'),
data={
'dataset': (train_data, data_path),
'datasetName': dataset_name
}
)
self.assertEqual(response.status_code, 302)
with TemporaryDirectory() as root_dir:
app = build_app(root_folder=root_dir, init_db=True)

# Train
response = client.post(
url_for('train'),
data={
'dataName': data_name,
'epochs': epochs,
'ensembleSize': ensemble_size,
'checkpointName': checkpoint_name,
'datasetType': dataset_type,
'useProgressBar': False
}
)
self.assertEqual(response.status_code, 200)
app.config['TESTING'] = True

# Predict
response = client.post(
url_for('predict'),
data={
'checkpointName': ckpt_name,
'textSmiles': test_smiles
}
)
self.assertEqual(response.status_code, 200)
data_path = 'regression.csv'
test_path = 'regression_test_smiles.csv'
dataset_name = 'regression_data'
dataset_type = 'regression'
checkpoint_name = 'regression_ckpt'
ckpt_name = data_name = '1'
epochs = 3
ensemble_size = 1

with open(os.path.join(TEST_DATA_DIR, data_path)) as f:
train_data = BytesIO(f.read().encode('utf-8'))

with open(os.path.join(TEST_DATA_DIR, test_path)) as f:
test_smiles = f.read()

with app.test_client() as client:
response = client.get('/')
self.assertEqual(response.status_code, 200)

# Upload data
response = client.post(
url_for('upload_data', return_page='home'),
data={
'dataset': (train_data, data_path),
'datasetName': dataset_name
}
)
self.assertEqual(response.status_code, 302)

# Train
response = client.post(
url_for('train'),
data={
'dataName': data_name,
'epochs': epochs,
'ensembleSize': ensemble_size,
'checkpointName': checkpoint_name,
'datasetType': dataset_type,
'useProgressBar': False
}
)
self.assertEqual(response.status_code, 200)

# Predict
response = client.post(
url_for('predict'),
data={
'checkpointName': ckpt_name,
'textSmiles': test_smiles
}
)
self.assertEqual(response.status_code, 200)


if __name__ == '__main__':
Expand Down

0 comments on commit 0ff6fd4

Please sign in to comment.