In [1]:
import os
from mobilePriceRangePrediction.constants import *
from mobilePriceRangePrediction.utils.common import read_yaml, create_directories, save_object
from typing import List

In [2]:
# Move the working directory up one level (presumably from the notebook's folder
# to the project root, so relative config/artifact paths resolve — confirm).
# NOTE: the path is relative, so this cell is not idempotent: every re-run climbs
# one more directory. Run it exactly once per kernel session.
os.chdir("../")

In [3]:
%pwd

'e:\\DataScienceProjects\\mobile-price-range-prediction'

In [4]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings for the model-training stage.

    NOTE(review): ConfigurationManager passes the YAML values through unchanged,
    so at runtime these fields may be plain strings rather than ``Path``
    objects — confirm before relying on ``Path`` methods.
    """
    # Output directory; ModelTrainer writes the selected model to <root_dir>/model.pkl.
    root_dir: Path
    # Input directory; ModelTrainer reads train_arr.pkl and test_arr.pkl from here.
    data_path: Path

In [5]:
class ConfigurationManager:
    """Reads the project's YAML settings and builds per-stage config objects.

    Instantiating the manager loads both YAML files and passes the
    ``artifacts_root`` entry of the main config to ``create_directories``.
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        """Load the config and params files and set up the artifacts root.

        Args:
            config_filepath: Path of the main configuration YAML.
            params_filepath: Path of the parameters YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Return the ``ModelTrainerConfig`` described by the ``model_trainer`` section.

        Side effect: the section's ``root_dir`` is passed to
        ``create_directories`` before the config object is built.
        """
        section = self.config.model_trainer
        create_directories([section.root_dir])
        return ModelTrainerConfig(
            root_dir=section.root_dir,
            data_path=section.data_path,
        )

In [28]:
from mobilePriceRangePrediction.utils.common import save_object ,load_object
from sklearn.metrics import f1_score

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from xgboost import XGBClassifier
from catboost import CatBoostClassifier

In [29]:
class ModelTrainer:
    """Fits a fixed set of candidate classifiers and persists the best scorer."""

    def __init__(self, config: ModelTrainerConfig):
        """Store the stage configuration.

        Args:
            config: Supplies ``data_path`` (where the processed arrays are read
                from) and ``root_dir`` (where ``model.pkl`` is written).
        """
        self.config = config

    def train(self):
        """Train all candidates, select the highest micro-F1 scorer, and save it.

        Reads ``train_arr.pkl`` and ``test_arr.pkl`` from ``config.data_path``.
        Both are indexed as 2-D arrays whose last column is the target and
        whose remaining columns are the features. The winning, already-fitted
        estimator is written to ``<config.root_dir>/model.pkl``.

        NOTE(review): the model is selected on the same test split it is scored
        on, so the winning score is an optimistic estimate — consider a
        validation split or cross-validation for selection.
        """
        # os.path.join accepts both str and os.PathLike values, so this works
        # whether data_path is a plain string or the Path that the config class
        # declares (plain `data_path + '/...'` raises TypeError for a Path).
        train_arr = load_object(os.path.join(self.config.data_path, 'train_arr.pkl'))
        test_arr = load_object(os.path.join(self.config.data_path, 'test_arr.pkl'))

        # Last column is the label; everything before it is a feature.
        X_train, y_train = train_arr[:, :-1], train_arr[:, -1]
        X_test, y_test = test_arr[:, :-1], test_arr[:, -1]

        models = {
            'Logistic Regression': LogisticRegression(),
            'Decision Trees': DecisionTreeClassifier(),
            'Random Forest': RandomForestClassifier(),
            'Gradient Boosting Machines (GBM)': GradientBoostingClassifier(),
            'Support Vector Machines (SVM)': SVC(),
            'K-Nearest Neighbors (KNN)': KNeighborsClassifier(),
            'Naive Bayes': GaussianNB(),
            'AdaBoost': AdaBoostClassifier(),
            'Extra Trees Classifier': ExtraTreesClassifier(),
            'XGBoost': XGBClassifier(),
            'CatBoost': CatBoostClassifier(verbose=False)
        }

        model_report: dict = self.evaluate_models(
            X_train=X_train, y_train=y_train,
            X_test=X_test, y_test=y_test,
            models=models
        )

        # max() over the dict keys returns the first key with the top score,
        # i.e. ties resolve to the earliest entry in `models`.
        best_model_name = max(model_report, key=model_report.get)

        # evaluate_models fitted each estimator in place, so this instance is
        # already trained and can be saved directly.
        best_model = models[best_model_name]

        save_object(
            file_path=os.path.join(self.config.root_dir, 'model.pkl'),
            obj=best_model
        )

    def evaluate_models(self, X_train, y_train, X_test, y_test, models):
        """Fit each model on the training data and score it on the test data.

        Args:
            X_train: Training features.
            y_train: Training labels.
            X_test: Test features.
            y_test: Test labels.
            models: Mapping of display name to an unfitted estimator exposing
                ``fit`` and ``predict``. Each estimator is fitted in place.

        Returns:
            dict mapping each model name to its micro-averaged F1 score on
            the test data.
        """
        report = {}

        for name, model in models.items():
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            report[name] = f1_score(y_test, y_pred, average='micro')

        return report


In [30]:
# Run the model-training stage end to end: load settings, build the trainer, train.
# Exceptions propagate unchanged; the former `except Exception as e: raise e`
# wrapper re-raised the same error and added nothing.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()
# Give the trainer its own name so `model_trainer_config` keeps referring to the
# config object instead of being rebound to a ModelTrainer instance.
model_trainer = ModelTrainer(config=model_trainer_config)
model_trainer.train()

[2024-03-25 19:33:17,072: INFO: common: yaml file: config\config.yaml loaded successfully.]
[2024-03-25 19:33:17,073: INFO: common: yaml file: params.yaml loaded successfully.]
[2024-03-25 19:33:17,074: INFO: common: Created directory at: artifacts]
[2024-03-25 19:33:17,075: INFO: common: Created directory at: artifacts/model]
[1. 2. 0. ... 2. 3. 1.]
[1. 2. 0. ... 2. 3. 1.]
