Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Integration of torch models in main #34

Open
wants to merge 104 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
ffb9af2
feat: Add Torch graph models class implementations
jpitoskas May 15, 2024
69d0c8f
feat: Move torch-related code in jaqpotpy_torch/
jpitoskas May 15, 2024
2972b63
feat: Add Smiles Graph Featurizer
jpitoskas May 15, 2024
2098f95
feat: Add Smiles Graph Dataset
jpitoskas May 15, 2024
35ed377
refactor: Rename subdirectories in jaqpotpy_torch
jpitoskas May 16, 2024
05a8791
feat: Add trainers package for torch models
jpitoskas May 16, 2024
6fd3be3
feat: Extended trainers package for torch models
jpitoskas May 20, 2024
6aa65ba
feat: Implement Binary & Regression Graph Trainers
jpitoskas May 20, 2024
72dc45a
feat: Implement SmilesGraphDatasetWithExternal
jpitoskas May 23, 2024
9bbd19e
refactor: Add abstract class Featurizer
jpitoskas May 23, 2024
b26555c
feat: Implement Fully Connected Network model
jpitoskas May 23, 2024
abf26a8
feat: Implement graph models with external feats
jpitoskas May 23, 2024
59895d8
fix: Circular import error and add super()
jpitoskas May 23, 2024
611430c
fix: Change "mu" to "mean" for consistency
jpitoskas May 23, 2024
4a4e5f8
feat: Deployment for torch models
jpitoskas Jun 3, 2024
044c87d
feat: Deploy for external and changed logic
jpitoskas Jun 4, 2024
686106e
fix: Fix log_filepath var name & rm unused libs
jpitoskas Jun 5, 2024
6b683aa
Merge remote-tracking branch 'origin/main' into feat/JAQPOT-62/torch-…
jpitoskas Jun 6, 2024
c3227d2
Merge branch 'main' into feat/JAQPOT-62/torch-graph-training
jpitoskas Jun 14, 2024
9d5a50b
Merge branch 'feat/JAQPOT-62/torch-graph-training' of https://github.…
jpitoskas Jun 14, 2024
dedceae
refactor: Change dir structure of trainers
jpitoskas Jun 14, 2024
e863a12
feat: Fully functional torch model upload
jpitoskas Jun 14, 2024
26fb527
feat: Method to get all installed packages in env
jpitoskas Jun 14, 2024
3a81ef6
feat: Implement FC Trainer for solely external
jpitoskas Jun 14, 2024
927d84d
feat: Add bond featurs as edge_attr
jpitoskas Jun 19, 2024
c61006c
feat: Make categorical values as str
jpitoskas Jun 19, 2024
7eee23f
feat: Implement Multiclass Trainer
jpitoskas Jun 19, 2024
5dbe355
fix: Model type of multiclass fc model trainer
jpitoskas Jun 19, 2024
0e68b16
fix: multiclass_fc_model_trainer.py rename typo
jpitoskas Jun 19, 2024
f026247
fix: Add zero_division in both precision & recall
jpitoskas Jun 19, 2024
4c7725f
fix: Fix args types of all networks/models
jpitoskas Jun 19, 2024
31cfe29
fix: Confusion Matrix return as matrix not vector
jpitoskas Jun 21, 2024
8c5333f
fix: Two minor fixes in metrics
jpitoskas Jun 21, 2024
5600dea
feat: Remove labels from conf_mat of binary
jpitoskas Jun 21, 2024
9177278
refactor: Add headers to all branch files
jpitoskas Jun 22, 2024
374e11e
refactor: Add documentation to models
jpitoskas Jun 22, 2024
e23581a
refactor: Fix in docs of models
jpitoskas Jun 22, 2024
250575a
refactor: Add documentation to trainers
jpitoskas Jun 22, 2024
d92c875
feat: Add scheduler to trainer
jpitoskas Jun 26, 2024
5cb1b00
refactor: Add docstring to smiles graph featurizer
jpitoskas Jun 26, 2024
1927cfb
fix: Forgot scheduler to regression_model_trainer
jpitoskas Jun 26, 2024
2482033
fix: Edge dim to transformer graph network
jpitoskas Jun 26, 2024
f5762ce
refactor: Add docstrings to Datasets
jpitoskas Jun 26, 2024
3f3259a
refactor: SmilesGraphFeaturizer code in docs
jpitoskas Jun 26, 2024
4d04f16
refactor: Refactor docs for featurizer and dataset
jpitoskas Jun 26, 2024
498bd7f
refactor: Add docs for models etc
jpitoskas Jun 28, 2024
53bdb32
fix: Move model to device in Trainer constructor
jpitoskas Jun 28, 2024
eb2a165
refactor: More fixes in docs
jpitoskas Jun 28, 2024
0599703
feat: Add Torch graph models class implementations
jpitoskas May 15, 2024
52a0a8b
feat: Move torch-related code in jaqpotpy_torch/
jpitoskas May 15, 2024
7e94fd1
feat: Add Smiles Graph Featurizer
jpitoskas May 15, 2024
6929bdc
feat: Add Smiles Graph Dataset
jpitoskas May 15, 2024
fee07df
refactor: Rename subdirectories in jaqpotpy_torch
jpitoskas May 16, 2024
db0764c
feat: Add trainers package for torch models
jpitoskas May 16, 2024
34e6f8f
feat: Extended trainers package for torch models
jpitoskas May 20, 2024
fee812e
feat: Implement Binary & Regression Graph Trainers
jpitoskas May 20, 2024
1c8fe26
feat: Implement SmilesGraphDatasetWithExternal
jpitoskas May 23, 2024
f08cec6
refactor: Add abstract class Featurizer
jpitoskas May 23, 2024
3be83e0
feat: Implement Fully Connected Network model
jpitoskas May 23, 2024
1f459d7
feat: Implement graph models with external feats
jpitoskas May 23, 2024
dc403a9
fix: Circular import error and add super()
jpitoskas May 23, 2024
83f5031
fix: Change "mu" to "mean" for consistency
jpitoskas May 23, 2024
7d5a3da
feat: Deployment for torch models
jpitoskas Jun 3, 2024
73c13d0
feat: Deploy for external and changed logic
jpitoskas Jun 4, 2024
e55c4a7
fix: Fix log_filepath var name & rm unused libs
jpitoskas Jun 5, 2024
2c54bd6
refactor: Change dir structure of trainers
jpitoskas Jun 14, 2024
faefa2a
feat: Fully functional torch model upload
jpitoskas Jun 14, 2024
89a071d
feat: Method to get all installed packages in env
jpitoskas Jun 14, 2024
126a53f
feat: Implement FC Trainer for solely external
jpitoskas Jun 14, 2024
35e45a1
feat: Add bond featurs as edge_attr
jpitoskas Jun 19, 2024
d30a668
feat: Make categorical values as str
jpitoskas Jun 19, 2024
a31a98c
feat: Implement Multiclass Trainer
jpitoskas Jun 19, 2024
b192d9b
fix: Model type of multiclass fc model trainer
jpitoskas Jun 19, 2024
b6c4c5c
fix: multiclass_fc_model_trainer.py rename typo
jpitoskas Jun 19, 2024
4510307
fix: Add zero_division in both precision & recall
jpitoskas Jun 19, 2024
3dbabfe
fix: Fix args types of all networks/models
jpitoskas Jun 19, 2024
08ff203
fix: Confusion Matrix return as matrix not vector
jpitoskas Jun 21, 2024
ddcd72b
fix: Two minor fixes in metrics
jpitoskas Jun 21, 2024
f217b38
feat: Remove labels from conf_mat of binary
jpitoskas Jun 21, 2024
a933dc1
refactor: Add headers to all branch files
jpitoskas Jun 22, 2024
2aabc6d
refactor: Add documentation to models
jpitoskas Jun 22, 2024
f344713
refactor: Fix in docs of models
jpitoskas Jun 22, 2024
4ee8c11
refactor: Add documentation to trainers
jpitoskas Jun 22, 2024
9fd1b39
feat: Add scheduler to trainer
jpitoskas Jun 26, 2024
16b9046
refactor: Add docstring to smiles graph featurizer
jpitoskas Jun 26, 2024
78b6d3e
fix: Forgot scheduler to regression_model_trainer
jpitoskas Jun 26, 2024
9828db8
fix: Edge dim to transformer graph network
jpitoskas Jun 26, 2024
d7c30f1
refactor: Add docstrings to Datasets
jpitoskas Jun 26, 2024
d2431fd
refactor: SmilesGraphFeaturizer code in docs
jpitoskas Jun 26, 2024
a1c4922
refactor: Refactor docs for featurizer and dataset
jpitoskas Jun 26, 2024
1f89d87
refactor: Add docs for models etc
jpitoskas Jun 28, 2024
b7a4462
fix: Move model to device in Trainer constructor
jpitoskas Jun 28, 2024
56a6c11
refactor: More fixes in docs
jpitoskas Jun 28, 2024
1fe1096
Merge branch 'feat/JAQPOT-62/torch-graph-training' of https://github.…
jpitoskas Jun 28, 2024
3bb3fcb
refactor: Change all 'Argument:' to 'Args' in docs
jpitoskas Jun 28, 2024
d5ce368
refactor: Fix some docs and add model_type names
jpitoskas Jun 28, 2024
0f3cebf
refactor: Change 'NUMERICAL' to 'FLOAT'
jpitoskas Jun 28, 2024
38d4226
chore: Add port 8002, to match jaqpotpy-inference
jpitoskas Jun 28, 2024
a67ec47
fix: __len__() in SmilesGraphDatasetWithExternal
jpitoskas Jun 28, 2024
ee19ef3
fix: Add model back to cpu before torch.jit.script
jpitoskas Jun 30, 2024
951d0ff
chore: Change format of creator in schemas
jpitoskas Jun 30, 2024
dd6a050
fix: Forgot to change a 'NUMERICAL' to 'FLOAT'
jpitoskas Jun 30, 2024
0db9db0
Merge branch 'main' into feat/JAQPOT-62/torch-graph-training
jpitoskas Jun 30, 2024
658536d
chore: Temporarily add docs_jaqpotpy_torch
jpitoskas Jun 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,9 @@ jaqpotpy/models/.DS_Store
jaqpotpy/models/.DS_Store
.DS_Store
jaqpotpy/models/tests/.DS_Store

*.pkl
*.json
venv-jaqpotpy/
*.txt
*.pt
8 changes: 8 additions & 0 deletions docs_jaqpotpy_torch/jaqpotpy_torch/datasets/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Module jaqpotpy_torch.datasets
==============================
Author: Ioannis Pitoskas (jpitoskas@gmail.com)

Sub-modules
-----------
* jaqpotpy_torch.datasets.smiles_graph_dataset
* jaqpotpy_torch.datasets.tabular_dataset
111 changes: 111 additions & 0 deletions docs_jaqpotpy_torch/jaqpotpy_torch/datasets/smiles_graph_dataset.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
Module jaqpotpy_torch.datasets.smiles_graph_dataset
===================================================
Author: Ioannis Pitoskas (jpitoskas@gmail.com)

Classes
-------

`SmilesGraphDataset(smiles, y=None, featurizer=None)`
: A PyTorch Dataset class for handling SMILES strings as graphs.
This class overrides `__getitem__` and `__len__` (check source code for methods' docstrings).

Attributes:
smiles (list): A list of SMILES strings.
y (list, optional): A list of target values.
featurizer (SmilesGraphFeaturizer): The object to transform SMILES strings into graph representations.
precomputed_features (list, optional): A list of precomputed features. If precompute_featurization() is not called, this attribute remains None.

The SmilesGraphDataset constructor.

Args:
smiles (list): A list of SMILES strings.
y (list, optional): A list of target values. Default is None.
featurizer (SmilesGraphFeaturizer, optional): A featurizer object for to create graph representations from SMILES strings.

Example:
```
>>> from jaqpotpy.jaqpotpy_torch.featurizers import SmilesGraphFeaturizer
>>> from rdkit import Chem
>>>
>>> smiles = ['C1=CN=CN1', 'CCCCCCC']
>>> y = [0, 1]
>>> featurizer = SmilesGraphFeaturizer()
>>> featurizer.add_atom_characteristic('symbol', ['C', 'O', 'Na', 'Cl', 'UNK'])
>>> featurizer.add_atom_characteristic('is_in_ring')
>>> featurizer.add_bond_characteristic('bond_type', [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE])
>>>
>>> dataset = SmilesGraphDataset(smiles, y=y, featurizer=featurizer)
>>> dataset[0]
Data(x=[5, 6], edge_index=[2, 10], edge_attr=[10, 3], y=0, smiles='C1=CN=CN1')
```

### Ancestors (in MRO)

* torch.utils.data.dataset.Dataset
* typing.Generic

### Descendants

* jaqpotpy_torch.datasets.smiles_graph_dataset.SmilesGraphDatasetWithExternal

### Methods

`get_atom_feature_labels(self)`
: Returns the atom feature labels.

Returns:
list: A list of atom feature labels.

`get_bond_feature_labels(self)`
: Returns the bond feature labels.

Returns:
list: A list of bond feature labels.

`get_num_edge_features(self)`
: Returns the number of edge features.

Returns:
int: Number of edge features.

`get_num_node_features(self)`
: Returns the number of node features.

Returns:
int: Number of node features.

`precompute_featurization(self)`
: Precomputes the featurization of the dataset.

`SmilesGraphDatasetWithExternal(smiles, external, y=None, featurizer=None)`
: A PyTorch Dataset class for handling SMILES strings as graphs, along with additional external features.
This class inherits from SmilesGraphDataset and overrides `__getitem__` and `__len__` (check source code for methods' docstrings).

Attributes:
smiles (list): A list of SMILES strings.
y (list, optional): A list of target values.
featurizer (SmilesGraphFeaturizer): The object to transform SMILES strings into graph representations.
precomputed_features (list, optional): A list of precomputed features. If precompute_featurization() is not called, this attribute remains None.
external (torch.Tensor): A 2D tensor containing the external features.

The SmilesGraphDatasetWithExternal constructor.

Args:
smiles (list): A list of SMILES strings.
external (numpy.ndarray or pandas.DataFrame): External feature data 2D matrix.
y (list, optional): A list of target values. Default is None.
featurizer (SmilesGraphFeaturizer, optional): A featurizer object for to create graph representations from SMILES strings. Default is None.

### Ancestors (in MRO)

* jaqpotpy_torch.datasets.smiles_graph_dataset.SmilesGraphDataset
* torch.utils.data.dataset.Dataset
* typing.Generic

### Methods

`get_num_external_features(self)`
: Returns the number of external features.

Returns:
int: Number of external features.
42 changes: 42 additions & 0 deletions docs_jaqpotpy_torch/jaqpotpy_torch/datasets/tabular_dataset.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
Module jaqpotpy_torch.datasets.tabular_dataset
==============================================
Author: Ioannis Pitoskas (jpitoskas@gmail.com)

Classes
-------

`TabularDataset(X, y=None)`
: A PyTorch Dataset class for handling tabular data.

Attributes:
X (torch.tensor): A 2D tensor containing the feature data.
y (torch.tensor): A 1D tensor containing the target data.

The TabularDataset constructor.

Args:
X (numpy.ndarray or pandas.DataFrame): Feature data matrix of shape (n_samples, n_features).
y (numpy.ndarray or pandas.DataFrame, optional): Target data of shape (n_samples,).

Example:
```
>>> import numpy as np
>>> X = np.random.rand(3, 2)
>>> y = np.random.rand(3, 2)
>>> dataset = TabularDataset(X, y=y)
>>> dataset[0]
(tensor([0.7778, 0.3400]), tensor(0.4730))
```

### Ancestors (in MRO)

* torch.utils.data.dataset.Dataset
* typing.Generic

### Methods

`get_num_features(self)`
: Returns the number of features in the dataset.

Returns:
int: Number of features.
25 changes: 25 additions & 0 deletions docs_jaqpotpy_torch/jaqpotpy_torch/featurizers/featurizer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Module jaqpotpy_torch.featurizers.featurizer
============================================
Author: Ioannis Pitoskas (jpitoskas@gmail.com)

Classes
-------

`Featurizer()`
: Abstract base class for featurizers.

### Ancestors (in MRO)

* abc.ABC

### Descendants

* jaqpotpy_torch.featurizers.smiles_graph_featurizer.SmilesGraphFeaturizer

### Methods

`featurize(self, *args, **kwargs)`
: Abstract method to featurize the input data.

Returns:
The featurized data.
8 changes: 8 additions & 0 deletions docs_jaqpotpy_torch/jaqpotpy_torch/featurizers/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Module jaqpotpy_torch.featurizers
=================================
Author: Ioannis Pitoskas (jpitoskas@gmail.com)

Sub-modules
-----------
* jaqpotpy_torch.featurizers.featurizer
* jaqpotpy_torch.featurizers.smiles_graph_featurizer
Loading
Loading