feat(JAQPOT-127): jaqpotpy-datasets-refactoring #28

Merged · 52 commits · Jun 28, 2024

Commits
e77df82
fix: Deleted datasets
periklis91 Jun 10, 2024
a9a302f
fix: super minor changes
periklis91 Jun 12, 2024
198218d
fix: Old abstract dataset classes
periklis91 Jun 12, 2024
0c29099
feat: New abtract classes
periklis91 Jun 12, 2024
cc92c70
feat: old molecular datasets
periklis91 Jun 12, 2024
286eede
Merge branch 'main' into feat/JAQPOT-127/jaqpotpy-dataset-tests
periklis91 Jun 12, 2024
eedb2e8
feat: deleted Molecular dataset
periklis91 Jun 12, 2024
5c1d272
feat: Added error handling in csv reading
periklis91 Jun 12, 2024
a8c73d2
feat: added dataframe setter, getter, deleter
periklis91 Jun 12, 2024
259b96d
fix: Deleted redundant comments
periklis91 Jun 12, 2024
c3563e5
feat: Deleted SmilesDataset class
periklis91 Jun 12, 2024
0a55f99
feat: error handling for y_cols
periklis91 Jun 13, 2024
fc9ea78
feat: x_cols error handling
periklis91 Jun 13, 2024
ec768a7
feat: smiles_cols error handling
periklis91 Jun 13, 2024
75b55bd
feat: Handling more that one smiles column
periklis91 Jun 13, 2024
b5d67c6
feat: Finalized Jaqpotpy dataset class
periklis91 Jun 13, 2024
4d9bcdf
fix: Smiles parsing
periklis91 Jun 13, 2024
ddcfd0a
feat: Making sure that the user provided y_cols correctly
periklis91 Jun 13, 2024
801d42b
fix: delete redundant code
periklis91 Jun 13, 2024
4592233
fix: deleted unessecary code
periklis91 Jun 14, 2024
8470faa
feat: added extra features to test datasets
periklis91 Jun 14, 2024
1f32f0b
feat: Implemented __get__X and __get__Y methods
periklis91 Jun 14, 2024
3b92bb5
chore: removed module loading in dataset inits
periklis91 Jun 14, 2024
cc874be
Merge branch 'main' into feat/JAQPOT-127/jaqpotpy-dataset-tests
periklis91 Jun 14, 2024
4055ceb
chore: Minor changes in test datasets
periklis91 Jun 14, 2024
fca39fa
feat: Make sure there is no overlap between x_cols, y_cols and smiles…
periklis91 Jun 14, 2024
497309b
feat: Check that all user-specified cols exist
periklis91 Jun 14, 2024
34d928b
test: early-dataset-testing
periklis91 Jun 14, 2024
ac9583a
refactor: more robust error handling
periklis91 Jun 14, 2024
dab8e91
test: Added further dataset tests
periklis91 Jun 14, 2024
f185192
test: df and path test
periklis91 Jun 17, 2024
3339eef
chore: Introduce JapoqtpyDataset
periklis91 Jun 22, 2024
f378af5
fix: module loading issue
periklis91 Jun 22, 2024
49cab03
chore: changed MolecularDataset with JaqpotpyDataset
periklis91 Jun 23, 2024
411da7d
chore: minor dependency fix
periklis91 Jun 23, 2024
4f382a6
chore: minor dependency fix
periklis91 Jun 23, 2024
0fc6879
chore: minor dependency fix
periklis91 Jun 23, 2024
ee36678
chore: Restored dataset.__init__.py
periklis91 Jun 23, 2024
b9c61de
chore: removed wildcard import
periklis91 Jun 23, 2024
c505cc4
chore: renamed doa_m to doa_fitted
periklis91 Jun 23, 2024
3f9b5fc
fix: Compatibility between MolecularSKLearn and JaqpotpyDataset
periklis91 Jun 24, 2024
b0c5f56
chore: rename Preprocesses class to Preprocess
periklis91 Jun 25, 2024
809614b
chore: Rename Preprocesses to Preprocess in all files
periklis91 Jun 25, 2024
d915a6c
feat: Add a .copy() method in JaqpotpyDataset
periklis91 Jun 28, 2024
f6dda97
chore: Return pandas series instead of numpy with __get__X__ and y
periklis91 Jun 28, 2024
99c30e4
fix: Now the .copy() method works
periklis91 Jun 28, 2024
62f609f
chore: comment with tests to be added
periklis91 Jun 28, 2024
cf23a76
refactor: Preprocessing made an object from class, attribute
periklis91 Jun 28, 2024
e4e9fa4
Refactor: Major model refactor
periklis91 Jun 28, 2024
b4679a6
chore: removed references to Molecular SKLearn
periklis91 Jun 28, 2024
efdae02
fix: Fixed error in len(dataset)
periklis91 Jun 28, 2024
7e67761
Merge branch 'main' into feat/JAQPOT-127/jaqpotpy-dataset-tests
periklis91 Jun 28, 2024
10 changes: 5 additions & 5 deletions jaqpotpy/datasets/__init__.py
@@ -1,5 +1,5 @@
from jaqpotpy.datasets import *
from jaqpotpy.datasets.molecular_datasets import *
from jaqpotpy.datasets.image_datasets import TorchImageDataset
from jaqpotpy.datasets.molecular_datasets import MolecularDataset, MolecularTabularDataset, TorchGraphDataset
from jaqpotpy.datasets.material_datasets import CompositionDataset, StructureDataset
#from jaqpotpy.datasets import *
#from jaqpotpy.datasets.molecular_datasets import *
from .image_datasets import TorchImageDataset
from .molecular_datasets import JaqpotpyDataset, TorchGraphDataset
from .material_datasets import CompositionDataset, StructureDataset
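With the refactored __init__.py, downstream code imports the dataset classes through the explicit names above. A minimal usage sketch follows; the JaqpotpyDataset keyword arguments shown (df, smiles_cols, x_cols, y_cols, task) are assumptions inferred from the commit messages in this PR, not confirmed by this diff.

    # Sketch only: assumes the refactored import surface above; the
    # JaqpotpyDataset constructor arguments are hypothetical.
    import pandas as pd
    from jaqpotpy.datasets import JaqpotpyDataset

    df = pd.DataFrame({
        "SMILES": ["CCO", "c1ccccc1"],
        "mw": [46.07, 78.11],
        "activity": [0, 1],
    })
    dataset = JaqpotpyDataset(df=df, smiles_cols="SMILES", x_cols=["mw"],
                              y_cols="activity", task="classification")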
264 changes: 154 additions & 110 deletions jaqpotpy/datasets/dataset_base.py
@@ -1,36 +1,111 @@
"""
Dataset base classes
Dataset abstract classes
"""
from typing import Any
import inspect
from typing import Iterable
from abc import ABC, abstractmethod
import os
import pickle
from typing import Iterable, Optional
import pandas as pd


class BaseDataset(object):
class BaseDataset(ABC):
"""
Astract class for datasets
Abstract class for datasets. This class defines the common interface and basic functionality
for dataset manipulation and handling.

Attributes:
_df (pd.DataFrame): The underlying DataFrame holding the dataset.
x_cols (Optional[Iterable[str]]): The columns to be used as features.
y_cols (Optional[Iterable[str]]): The columns to be used as labels.
_task (str): The task type, either 'regression' or 'classification'.
_dataset_name (str): The name of the dataset.
_y (Iterable[str]): The labels of the dataset.
_x (Iterable[str]): The features of the dataset.
"""
def __init__(self, path=None, x_cols=None, y_cols=None) -> None:
self._Y = None
self._X = None

def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None,
y_cols: Iterable[str] = None,
x_cols: Optional[Iterable[str]] = None,
task: str = None) -> None:

if df is None and path is None:
raise TypeError("Either a DataFrame or a path to a file must be provided.")
elif (df is not None) and (path is not None):
raise TypeError("Either a DataFrame or a path to a file must be provided.")

if df is not None:
if not isinstance(df, pd.DataFrame):
raise TypeError("Provided 'df' must be a pandas DataFrame.")
else:
self._df = df
self.path = None
elif path is not None:
self.path = path
extension = os.path.splitext(self.path)[1]
if extension == '.csv':
self._df = pd.read_csv(path)
else:
raise ValueError("The provided file is not a valid CSV file.")

if not(isinstance(y_cols, str) or
(isinstance(y_cols, list) and all(isinstance(item, str) for item in y_cols))
):
raise TypeError("y_cols must be provided and should be either"
"a string or a list of strings")

if not(isinstance(x_cols, str) or
(isinstance(x_cols, list) and all(isinstance(item, str) for item in x_cols)) or
(isinstance(x_cols, list) and len(x_cols) == 0) or
(x_cols is None)):
raise TypeError("x_cols should be either a string, an empty list"
"a list of strings, or None")

#Find the length of each provided column name vector and put everything in lists
if isinstance(y_cols, str):
self.y_cols = [y_cols]
self.y_cols_len = 1
elif isinstance(y_cols, list) :
self.y_cols = y_cols
self.y_cols_len = len(y_cols)

if isinstance(x_cols, str):
self.x_cols = [x_cols]
self.x_cols_len = 1
elif isinstance(x_cols, list) :
self.x_cols = x_cols
self.x_cols_len = len(x_cols)
elif x_cols is None:
self.x_cols = []
self.x_cols_len = 0

self.task = task
self._dataset_name = None
self._df = None
self._x_cols_all = None
self.path = path
self.x_cols = x_cols
self.y_cols = y_cols
self._task = "regression"
self.featurizer = None
self._featurizer_name = None
self._external = None
self._y = None
self._x = None

@property
def df(self) -> pd.DataFrame:
return self._df

@df.setter
def df(self, value: pd.DataFrame):
if not isinstance(value, pd.DataFrame):
raise ValueError("The value must be a pandas DataFrame.")
self._df = value

@df.deleter
def df(self):
del self._df

@property
def task(self):
return self._task

@task.setter
def task(self, value):
if value is None:
Reviewer comment (Contributor): Suggesting to reduce to a single if statement:
if value is None or value.lower() not in ['regression', 'classification']:

raise ValueError("Task must be either 'regression' or 'classification'")
elif value.lower() not in ['regression', 'classification']:
raise ValueError("Task must be either 'regression' or 'classification'")
self._task = value
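# Not part of this diff: a sketch of the setter if the reviewer's suggestion
# above were applied, collapsing the two branches into a single check.
@task.setter
def task(self, value):
    # 'value is None' short-circuits before .lower() is called on None.
    if value is None or value.lower() not in ['regression', 'classification']:
        raise ValueError("Task must be either 'regression' or 'classification'")
    self._task = value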

@property
@@ -41,18 +116,6 @@ def dataset_name(self):
def dataset_name(self, value):
self._dataset_name = value

@property
def featurizer_name(self) -> Iterable[Any]:
return self.featurizer.__name__

@property
def x_colls_all(self) -> Iterable[str]:
return self._x_cols_all

@x_colls_all.setter
def x_colls_all(self, value):
self._x_cols_all = value

@property
def X(self) -> Iterable[str]:
return self._x
@@ -69,57 +132,6 @@ def y(self) -> Iterable[str]:
def y(self, value):
self._y = value

@property
def external(self) -> Iterable[str]:
return self._external

@external.setter
def external(self, value):
self._external = value

@property
def df(self) -> Any:
return self._df

@df.setter
def df(self, value):
self._df = value

@featurizer_name.setter
def featurizer_name(self, value):
self._featurizer_name = value

def create(self):
raise NotImplementedError("Need implementation")

def __repr__(self) -> str:
# args_spec = inspect.getfullargspec(self.__init__) # type: ignore
# args_names = [arg for arg in args_spec.args if arg != 'self']
# args_info = ''
# for arg_name in args_names:
# value = self.__dict__[arg_name]
# # for str
# if isinstance(value, str):
# value = "'" + value + "'"
# # for list
return self.__class__.__name__


class MolecularDataset(BaseDataset):
def __init__(self, path=None, smiles_col=None, x_cols=None, y_cols=None, smiles=None) -> None:
self.smiles = smiles
self._smiles_strings = None
self.smiles_col = smiles_col
super().__init__(path, x_cols, y_cols)

@property
def smiles_strings(self) -> Iterable[str]:
return self._smiles_strings

@smiles_strings.setter
def smiles_strings(self, value):
self._smiles_strings = value

def save(self):
if self._dataset_name:
with open(self._dataset_name + ".jdata", 'wb') as f:
@@ -133,14 +145,65 @@ def load(cls, filename):
with open(filename, 'rb') as f:
return pickle.load(f)

@abstractmethod
def create(self):
"""
Creates the dataset.
"""
raise NotImplementedError

@abstractmethod
def __get_X__(self):
"""
Returns the design matrix X.
"""
raise NotImplementedError

@abstractmethod
def __get_Y__(self):
"""
Returns the response Y.
"""
raise NotImplementedError

@abstractmethod
def __repr__(self) -> str:
"""
Returns a string representation of the dataset.
"""
raise NotImplementedError

@abstractmethod
def __len__(self):
"""
Returns the number of samples in the dataset.
"""
raise NotImplementedError

@abstractmethod
def __get__(self, instance, owner):
"""
Gets an attribute of the dataset.
"""
raise NotImplementedError

@abstractmethod
def __getitem__(self, idx):
"""
Gets a sample by index.
"""
raise NotImplementedError


class MaterialDataset(BaseDataset):
def __init__(self, path=None, materials_col=None, x_cols=None, y_cols=None, materials=None) -> None:
def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None,
y_cols: Iterable[str] = None,
x_cols: Optional[Iterable[str]] =None, materials_col=None,
materials=None) -> None:
super().__init__(df = df, path = path, y_cols = y_cols, x_cols = x_cols)
self.materials = materials
self._materials_strings = None
self.materials_col = materials_col
super().__init__(path, x_cols, y_cols)


@property
def materials_strings(self) -> Iterable[str]:
@@ -150,32 +213,13 @@ def materials_strings(self) -> Iterable[str]:
def materials_strings(self, value):
self._materials_strings = value

def save(self):
if self._dataset_name:
with open(self._dataset_name + ".jdata", 'wb') as f:
pickle.dump(self, f)
else:
with open("jaqpot_dataset" + ".jdata", 'wb') as f:
pickle.dump(self, f)

@classmethod
def load(cls, filename):
with open(filename, 'rb') as f:
return pickle.load(f)

class ImageDataset(BaseDataset):
def __init__(self, path=None, x_cols=None, y_cols=None) -> None:
super().__init__(path=path, x_cols=x_cols, y_cols=y_cols)
def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None,
y_cols: Iterable[str] = None,
x_cols: Optional[Iterable[str]] =None) -> None:
super().__init__(df = df, path = path, y_cols = y_cols, x_cols = x_cols)

def save(self):
if self._dataset_name:
with open(self._dataset_name + ".jdata", 'wb') as f:
pickle.dump(self, f)
else:
with open("jaqpot_dataset" + ".jdata", 'wb') as f:
pickle.dump(self, f)

@classmethod
def load(cls, filename):
with open(filename, 'rb') as f:
return pickle.load(f)
if __name__ == '__main__':
...
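Since BaseDataset is now an ABC, concrete dataset classes (such as JaqpotpyDataset in molecular_datasets) must implement the abstract hooks declared above: create, __get_X__, __get_Y__, __repr__, __len__, __get__, and __getitem__. The following minimal subclass is a sketch against the interface in this diff rather than code from the PR; ToyDataset and its column handling are hypothetical.

    # Hypothetical minimal subclass, only to illustrate the abstract interface.
    import pandas as pd
    from jaqpotpy.datasets.dataset_base import BaseDataset

    class ToyDataset(BaseDataset):
        """Toy dataset implementing every abstract method of BaseDataset."""

        def create(self):
            # Split the underlying DataFrame into features and labels.
            self._x = self._df[self.x_cols]
            self._y = self._df[self.y_cols]

        def __get_X__(self):
            return self._x

        def __get_Y__(self):
            return self._y

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}(task={self._task!r})"

        def __len__(self):
            return len(self._df)

        def __get__(self, instance, owner):
            return self

        def __getitem__(self, idx):
            return self._df.iloc[idx]

    # Usage sketch, assuming the constructor signature shown in this diff.
    toy = ToyDataset(df=pd.DataFrame({"mw": [46.07], "activity": [0]}),
                     x_cols=["mw"], y_cols="activity", task="regression")
    toy.create()
    print(len(toy), toy.__get_X__())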