-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
936 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
input_data_path: "../../data/raw/heart_cleveland_upload.csv" | ||
input_model_path: "../../models/model.pkl" | ||
output_predictions_path: "../../data/predictions/predictions.csv" | ||
feature_params: | ||
categorical_features: | ||
- 'sex' | ||
- 'cp' | ||
- 'fbs' | ||
- 'restecg' | ||
- 'exang' | ||
- 'slope' | ||
- 'ca' | ||
- 'thal' | ||
numerical_features: | ||
- 'age' | ||
- 'trestbps' | ||
- 'chol' | ||
- 'thalach' | ||
- 'oldpeak' | ||
features_to_drop: | ||
target: 'condition' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
input_data_path: "../../data/raw/heart_cleveland_upload.csv" | ||
output_model_path: "../../models/model.pkl" | ||
splitting_params: | ||
val_size: 0.2 | ||
random_state: 42 | ||
model_params: | ||
model: "RF" | ||
n_estimators: 100 | ||
random_state: 42 | ||
feature_params: | ||
categorical_features: | ||
- 'sex' | ||
- 'cp' | ||
- 'fbs' | ||
- 'restecg' | ||
- 'exang' | ||
- 'slope' | ||
- 'ca' | ||
- 'thal' | ||
numerical_features: | ||
- 'age' | ||
- 'trestbps' | ||
- 'chol' | ||
- 'thalach' | ||
- 'oldpeak' | ||
features_to_drop: | ||
target: 'condition' |
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
antlr4-python3-runtime==4.9.3 | ||
click==8.1.3 | ||
hydra-core==1.2.0 | ||
joblib==1.2.0 | ||
numpy==1.23.4 | ||
omegaconf==2.2.3 | ||
packaging==21.3 | ||
pandas==1.5.1 | ||
pyparsing==3.0.9 | ||
python-dateutil==2.8.2 | ||
pytz==2022.5 | ||
PyYAML==6.0 | ||
scikit-learn==1.1.3 | ||
scipy==1.9.3 | ||
six==1.16.0 | ||
sklearn==0.0 | ||
threadpoolctl==3.1.0 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Scripts to download or generate data""" | ||
|
||
import pandas as pd | ||
from sklearn.model_selection import train_test_split | ||
from src.enities.splitting_params import SplittingParams | ||
|
||
|
||
def read_data(data_path: str) -> pd.DataFrame: | ||
"""Reads data from data_path and return pd.DataFrame""" | ||
data = pd.read_csv(data_path) | ||
return data | ||
|
||
|
||
def train_val_split( | ||
data: pd.DataFrame, split_params: SplittingParams | ||
) -> [pd.DataFrame, pd.DataFrame]: | ||
"""Function to split data according to splitting parameters""" | ||
train, val = train_test_split(data, test_size=split_params.val_size, | ||
random_state=split_params.random_state) | ||
|
||
return train, val |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
|
||
@dataclass() | ||
class FeatureParams: | ||
"""Dataclass to care feature parameters from configuration file""" | ||
categorical_features: List[str] | ||
features_to_drop: List[str] | ||
numerical_features: List[str] | ||
target: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass() | ||
class ModelParams: | ||
"""Dataclass of model parameters from configuration file""" | ||
model: str | ||
random_state: int = field(default=42) | ||
n_estimators: int = field(default=100) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Main dataclass to read predict configuration file""" | ||
|
||
from dataclasses import dataclass | ||
|
||
# import project modules | ||
from src.enities.feature_params import FeatureParams | ||
|
||
|
||
@dataclass() | ||
class PredictParams: | ||
"""Dataclass to care config params""" | ||
input_data_path: str | ||
input_model_path: str | ||
output_predictions_path: str | ||
feature_params: FeatureParams |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass() | ||
class SplittingParams: | ||
"""Dataclass of splitting params from configuration file""" | ||
val_size: float = field(default=0.2) | ||
random_state: int = field(default=42) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Main dataclass to read train configuration file""" | ||
|
||
from dataclasses import dataclass | ||
|
||
# import project modules | ||
from src.enities.splitting_params import SplittingParams | ||
from src.enities.model_params import ModelParams | ||
from src.enities.feature_params import FeatureParams | ||
|
||
|
||
@dataclass() | ||
class TrainingParams: | ||
"""Dataclass to care config params""" | ||
input_data_path: str | ||
output_model_path: str | ||
splitting_params: SplittingParams | ||
model_params: ModelParams | ||
feature_params: FeatureParams |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Scripts to turn raw data into features for modeling""" | ||
import pandas as pd | ||
import numpy as np | ||
from sklearn.preprocessing import OneHotEncoder | ||
from sklearn.impute import SimpleImputer | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
|
||
|
||
# import project modules | ||
from src.enities.feature_params import FeatureParams | ||
|
||
|
||
def transform_numerical_feature(numerical_data) -> pd.DataFrame: | ||
numerical_pipeline = Pipeline( | ||
[ | ||
("impute", SimpleImputer(missing_values=np.nan, strategy="most_frequent")), | ||
] | ||
) | ||
numerical_pipeline.fit(numerical_data) | ||
numerical_pipeline.transform(numerical_data) | ||
return numerical_data | ||
|
||
|
||
def transform_categorical_feature(categorical_data) -> pd.DataFrame: | ||
categorical_pipeline = Pipeline( | ||
[ | ||
("impute", SimpleImputer(missing_values=np.nan, strategy="most_frequent")), | ||
("ohe", OneHotEncoder()), | ||
] | ||
) | ||
categorical_pipeline.fit(categorical_data) | ||
categorical_pipeline.transform(categorical_data) | ||
return pd.DataFrame(categorical_data) | ||
|
||
|
||
class Transformer(BaseEstimator, TransformerMixin): | ||
|
||
def __init__(self, feature_params: FeatureParams): | ||
self.numerical_features = list(feature_params.numerical_features) | ||
self.categorical_features = list(feature_params.categorical_features) | ||
self.data = None | ||
|
||
def fit(self, data: pd.DataFrame): | ||
self.data = data | ||
return self | ||
|
||
def transform(self) -> pd.DataFrame: | ||
self.data[self.numerical_features] = \ | ||
transform_numerical_feature(self.data[self.numerical_features]) | ||
self.data[self.categorical_features] = \ | ||
transform_numerical_feature(self.data[self.categorical_features]) | ||
return self.data | ||
|
||
|
||
def extract_target(data: pd.DataFrame, features_params: FeatureParams) -> pd.DataFrame: | ||
"""Extracts target, returns pd.Series of target feature""" | ||
target = data[features_params.target] | ||
return target |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Scripts to use trained models to make predictions""" | ||
import pickle | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import sklearn | ||
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier | ||
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score | ||
from typing import Dict, Union | ||
|
||
# import project modules | ||
from src.enities.training_params import TrainingParams | ||
|
||
# define | ||
SklearnClassifierModel = Union[RandomForestClassifier, GradientBoostingClassifier] | ||
|
||
|
||
def train_model(data: pd.DataFrame, target: pd.DataFrame, train_params: TrainingParams) -> SklearnClassifierModel: | ||
if train_params.model_params.model == "RF": | ||
model = sklearn.ensemble.RandomForestClassifier( | ||
n_estimators=train_params.model_params.n_estimators, | ||
random_state=train_params.model_params.random_state | ||
) | ||
elif train_params.model_params.model == "GB": | ||
model = sklearn.ensemble.GradientBoostingClassifier( | ||
n_estimators=train_params.model_params.n_estimators, | ||
random_state=train_params.model_params.random_state | ||
) | ||
else: | ||
raise NotImplementedError() | ||
model.fit(data, target) | ||
return model | ||
|
||
|
||
def predict_model(model: SklearnClassifierModel, data: pd.DataFrame) -> np.ndarray: | ||
predicted = model.predict(data) | ||
return predicted | ||
|
||
|
||
def evaluate(target_predicted: np.ndarray, target: pd.DataFrame) -> Dict[str, float]: | ||
return { | ||
"f1": f1_score(target, target_predicted), | ||
"precision": precision_score(target, target_predicted), | ||
"recall": recall_score(target, target_predicted), | ||
"accuracy": accuracy_score(target, target_predicted) | ||
} | ||
|
||
|
||
def serialize_model(model: SklearnClassifierModel, output_path: str) -> None: | ||
with open(output_path, "wb") as f: | ||
pickle.dump(model, f) | ||
|
||
|
||
def load_model(model_path: str) -> SklearnClassifierModel: | ||
model = pickle.load(open(model_path, 'rb')) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Scripts to predict""" | ||
|
||
import logging | ||
import sys | ||
import hydra | ||
from hydra.core.config_store import ConfigStore | ||
import pandas as pd | ||
|
||
# Import project modules | ||
from src.enities.predict_params import PredictParams | ||
from src.data.make_dataset import read_data | ||
from src.features.build_features import Transformer | ||
from src.model.predict_model import predict_model, load_model | ||
|
||
|
||
# ConfigStore for caring parameters from configuration | ||
# file in TrainingParams dataclass | ||
cs = ConfigStore.instance() | ||
cs.store(name="predict", node=PredictParams) | ||
|
||
|
||
# setting logger | ||
logger = logging.getLogger(__name__) | ||
handler = logging.StreamHandler(sys.stdout) | ||
logger.setLevel(logging.INFO) | ||
logger.addHandler(handler) | ||
logger.propagate = False | ||
|
||
|
||
def run_predict_pipeline(predict_params: PredictParams): | ||
# read data | ||
logger.info(f"Reading data from {predict_params.input_data_path}...") | ||
data = read_data(predict_params.input_data_path) | ||
|
||
# load model | ||
logger.info(f"Loading model from {predict_params.input_model_path}...") | ||
model = load_model(predict_params.input_model_path) | ||
|
||
# transform data | ||
logger.info(f"Transforming data...") | ||
transformer = Transformer(predict_params.feature_params) | ||
transformer.fit(data) | ||
data = transformer.transform() | ||
|
||
# predict | ||
logger.info(f"Getting predictions...") | ||
predictions = predict_model(model, data) | ||
|
||
# save prediction | ||
logger.info(f"Setting prediction to {predict_params.output_predictions_path}...") | ||
pd.DataFrame(predictions).to_csv(predict_params.output_predictions_path) | ||
|
||
|
||
@hydra.main(version_base=None, config_path="../../configs/.", config_name="predict_config") | ||
def predict_pipeline(config_params: PredictParams) -> None: | ||
"""Function to read terminal arguments""" | ||
run_predict_pipeline(config_params) | ||
|
||
|
||
if __name__ == "__main__": | ||
predict_pipeline() |
Oops, something went wrong.