# train.py
# Source file view metadata: 55 lines (44 loc), 1.47 KB.
from pathlib import Path
import hydra
import joblib
import pandas as pd
from helpers import load_data
from omegaconf import DictConfig
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
def create_pipeline() -> Pipeline:
    """Build the unfitted model: feature standardisation followed by an SVM classifier.

    Returns:
        A ``Pipeline`` with two named steps, ``"scaler"`` (a default
        ``StandardScaler``) and ``"svm"`` (a default ``SVC``). Following
        scikit-learn's ``<step>__<param>`` convention, grid-search keys for
        these steps are prefixed ``scaler__`` / ``svm__``.
    """
    steps = [
        ("scaler", StandardScaler()),
        ("svm", SVC()),
    ]
    return Pipeline(steps)
def train_model(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    pipeline: Pipeline,
    hyperparameters: dict,
    grid_params: dict,
) -> GridSearchCV:
    """Run an exhaustive grid search over ``pipeline`` and return the fitted search.

    Args:
        X_train: Training features.
        y_train: Training targets.
        pipeline: Unfitted estimator whose parameters are searched.
        hyperparameters: Mapping from parameter name to candidate values. It is
            copied (shallowly) into a plain ``dict`` and passed to
            ``GridSearchCV`` as ``param_grid``.
        grid_params: Extra keyword arguments forwarded verbatim to
            ``GridSearchCV`` (e.g. ``cv``, ``scoring``, ``n_jobs``).

    Returns:
        The ``GridSearchCV`` instance after ``fit`` has been called on it.
    """
    # Shallow copy so GridSearchCV receives a real dict even when the caller
    # passes another mapping type (such as an omegaconf DictConfig).
    param_grid = dict(hyperparameters)
    searcher = GridSearchCV(pipeline, param_grid, **grid_params)
    searcher.fit(X_train, y_train)
    return searcher
def save_model(model, path: str):
    """Serialize ``model`` to ``path`` with ``joblib.dump``.

    The destination's parent directory is created first, including any missing
    intermediate directories, so ``path`` may point into a directory tree that
    does not exist yet. An existing directory is not an error.

    Args:
        model: Object to persist; must be serializable by ``joblib``.
        path: Destination file path.
    """
    # parents=True: without it, mkdir raises FileNotFoundError whenever more
    # than one level of the parent directory is missing
    # (e.g. "models/svm/model.pkl" when "models/" does not exist yet).
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, path)
@hydra.main(config_path="../config", config_name="main", version_base="1.2")
def train(config: DictConfig) -> None:
    """Hydra entry point: grid-search the pipeline and persist the best estimator.

    Loads ``X_train.pkl`` and ``y_train.pkl`` from ``config.data.processed.dir``
    via ``load_data``, then calls ``train_model`` with:

    * the pipeline from ``create_pipeline``,
    * ``config.train.hyperparameters`` as the parameter grid,
    * ``config.train.grid_search`` as extra ``GridSearchCV`` options.

    Finally it saves the search's ``best_estimator_`` to ``config.model_path``.
    """
    print("Training the model...")

    processed_dir = config.data.processed.dir
    X_train = load_data(f"{processed_dir}/X_train.pkl")
    y_train = load_data(f"{processed_dir}/y_train.pkl")

    search = train_model(
        X_train,
        y_train,
        create_pipeline(),
        config.train.hyperparameters,
        config.train.grid_search,
    )
    # NOTE(review): best_estimator_ only exists when GridSearchCV refits
    # (scikit-learn's default); a grid_search config with refit set to False
    # would fail here with AttributeError.
    save_model(search.best_estimator_, config.model_path)
if __name__ == "__main__":
    # No arguments on purpose: the @hydra.main decorator composes the config
    # (applying any command-line overrides) and passes it in as ``config``.
    train()