Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V2: improve CLI for multitask #647

Merged
merged 34 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7299ffc
Initializing tensor for loc and scale required for loading model
hwpang Feb 13, 2024
4f12a18
dimension should be 1 x number of tasks
hwpang Feb 13, 2024
89d15d0
Add test data for regression mol multitask
hwpang Feb 13, 2024
f3dc05b
Add example model for cli test
hwpang Feb 13, 2024
87eb0ad
Add cli test for regression mol multitask
hwpang Feb 13, 2024
70415f8
Add target columns as predict input to use as column header
hwpang Feb 14, 2024
0cc1235
Merge branch 'v2/dev' into v2/cli/multitask
hwpang Feb 22, 2024
5c9efb6
Also test case without target columns
hwpang Feb 22, 2024
0ad745c
Add multitask mve cli test
hwpang Feb 22, 2024
19dcd32
Scale all regression task (including regression-mve, etc)
hwpang Feb 22, 2024
c4e8079
Add test to train mve
hwpang Feb 22, 2024
5905b5b
Add test for training regression-mve model
hwpang Feb 22, 2024
3f4957c
Scale all regression task, use scaler=None if not
hwpang Feb 22, 2024
660cdfc
Evaluate on the mean, not var
hwpang Feb 22, 2024
61a38fb
Merge branch 'v2/mve' into v2/cli/multitask
hwpang Feb 22, 2024
63d3741
Merge branch 'v2/dev' into v2/cli/multitask
hwpang Feb 23, 2024
d5d4f32
Remove unnecessary to list
hwpang Feb 27, 2024
647cbb5
Remove mve for now as it is not in the goal of v2.0
hwpang Feb 27, 2024
9d2ea14
Remove mve
hwpang Feb 27, 2024
ebb5e29
Formatting
hwpang Feb 27, 2024
6035d6d
Fix dimension
hwpang Feb 27, 2024
61f9c14
Remove changes related to MVE as it's for v2.1
hwpang Feb 28, 2024
867b0fb
Change type check
hwpang Feb 28, 2024
0477271
Merge branch 'v2/dev' into v2/cli/multitask
hwpang Feb 28, 2024
95154aa
Remove mve related changes
hwpang Feb 28, 2024
bed4d46
Type check integer
hwpang Feb 28, 2024
d669910
Merge branch 'v2/dev' into v2/cli/multitask
hwpang Feb 29, 2024
0a74476
Formatting
hwpang Feb 29, 2024
8bc682b
Remove merge artifacts
hwpang Feb 29, 2024
7f91eac
Update example model file
hwpang Feb 29, 2024
b345795
Use float in type check
hwpang Feb 29, 2024
c85a814
loc and scale need . for python to recognize them as float
hwpang Feb 29, 2024
d819f36
Remove unused MveFFN
hwpang Feb 29, 2024
1027006
Merge branch 'v2/dev' into v2/cli/multitask
hwpang Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 16 additions & 1 deletion chemprop/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def add_predict_args(parser: ArgumentParser) -> ArgumentParser:
required=True,
help="Path to a pretrained model checkpoint (.ckpt) or a pretrained model file (.pt).",
)
parser.add_argument(
"--target-columns",
nargs="+",
help="Column names to save the predictions to. If not provided, the predictions will be saved to columns named 'pred_0', 'pred_1', etc.",
)

# TODO: add uncertainty and calibration
# unc_args = parser.add_argument_group("uncertainty and calibration args")
Expand Down Expand Up @@ -264,7 +269,17 @@ def main(args):
preds = torch.concat(predss, 0)
if isinstance(model.predictor, MulticlassClassificationFFN):
preds = torch.argmax(preds, dim=-1)
target_columns = [f"pred_{i}" for i in range(preds.shape[1])] # TODO: need to improve this

if args.target_columns is not None:
assert (
len(args.target_columns) == model.n_tasks
), "Number of target columns must match the number of tasks."
target_columns = args.target_columns
else:
target_columns = [
f"pred_{i}" for i in range(preds.shape[1])
] # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass

df_test[target_columns] = preds
if args.output.suffix == ".pkl":
df_test = df_test.reset_index(drop=True)
Expand Down
2 changes: 1 addition & 1 deletion chemprop/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def main(args):
if args.save_smiles_splits:
save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset)

if args.task_type == "regression":
if "regression" in args.task_type:
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)
logger.info(f"Train data: mean = {scaler.mean_} | std = {scaler.scale_}")
Expand Down
2 changes: 1 addition & 1 deletion chemprop/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from chemprop.data import TrainingBatch, BatchMolGraph
from chemprop.nn.metrics import Metric
from chemprop.nn import MessagePassing, Aggregation, Predictor, LossFunction
from chemprop.nn import MessagePassing, Aggregation, Predictor, LossFunction, MveFFN
hwpang marked this conversation as resolved.
Show resolved Hide resolved
from chemprop.schedulers import NoamLR


Expand Down
17 changes: 13 additions & 4 deletions chemprop/nn/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,22 @@ def __init__(
dropout: float = 0,
activation: str = "relu",
criterion: LossFunction | None = None,
loc: float | Tensor = 0,
scale: float | Tensor = 1,
loc: float | Tensor = 0.,
scale: float | Tensor = 1.,
):
super().__init__(n_tasks, input_dim, hidden_dim, n_layers, dropout, activation, criterion)

self.register_buffer("loc", torch.tensor(loc).view(1, -1))
self.register_buffer("scale", torch.tensor(scale).view(1, -1))
if isinstance(loc, float):
loc = torch.ones(1, self.n_tasks) * loc
else:
loc = torch.tensor(loc).view(1, -1)
self.register_buffer("loc", loc)

if isinstance(scale, float):
scale = torch.ones(1, self.n_tasks) * scale
else:
scale = torch.tensor(scale).view(1, -1)
self.register_buffer("scale", scale)

def forward(self, Z: Tensor) -> Tensor:
Y = super().forward(Z)
Expand Down
60 changes: 60 additions & 0 deletions tests/cli/test_cli_regression_mol_multitask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""This tests the CLI functionality of training and predicting a regression model on a single molecule.
"""

import pytest

from chemprop.cli.main import main

pytestmark = pytest.mark.CLI


@pytest.fixture
def data_path(data_dir):
    """Return, as a string, the path to the multitask regression CSV under ``data_dir``."""
    csv_file = data_dir / "regression" / "mol_multitask.csv"
    return str(csv_file)


@pytest.fixture
def model_path(data_dir):
    """Return, as a string, the path to the example multitask regression model under ``data_dir``."""
    model_file = data_dir / "example_model_v2_regression_mol_multitask.pt"
    return str(model_file)


def test_train_quick(monkeypatch, data_path):
    """Smoke-test ``chemprop train`` on the multitask regression data for one epoch."""
    argv = ["chemprop", "train", "-i", data_path]
    argv += ["--epochs", "1", "--num-workers", "0"]

    # Patch sys.argv only for the duration of the CLI call.
    with monkeypatch.context() as patched:
        patched.setattr("sys.argv", argv)
        main()


def test_predict_quick(monkeypatch, data_path, model_path):
    """Smoke-test ``chemprop predict`` with and without ``--target-columns``.

    The CLI is invoked twice on the same data and model: first with twelve
    explicit target column names, then with no target columns given.
    """
    base_argv = ["chemprop", "predict", "-i", data_path, "--model-path", model_path]
    task_names = (
        "mu",
        "alpha",
        "homo",
        "lumo",
        "gap",
        "r2",
        "zpve",
        "cv",
        "u0",
        "u298",
        "h298",
        "g298",
    )

    # Order matters for parity with the original: named columns first,
    # then the bare invocation.
    invocations = (base_argv + ["--target-columns", *task_names], base_argv)

    for argv in invocations:
        with monkeypatch.context() as patched:
            patched.setattr("sys.argv", argv)
            main()
Binary file not shown.