In [21]:
import pandas as pd
import numpy as np
from sklearn import model_selection
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

In [28]:
class Trainer:

    def __init__(self, datapath:str, file_type:str ,feature_columns:str, labels_columns:str, na_strategy:str, na_columns:str, num_attri:str, cat_attri:str, model:str):
        self.datapath = datapath

        self.feature_columns = feature_columns

        self.labels_columns = labels_columns

        if file_type=="csv":
            self.file_type = file_type
        else:
            raise ValueError("file_type parameter should be csv")

        self.na_strategy = na_strategy
        self.na_columns = na_columns

        self.num_attri = num_attri
        self.cat_attri = cat_attri
        self.model = model

    def read_data(self):
        print("read_data")
        if self.file_type == "csv":
            df = pd.read_csv(self.datapath)
            return df
        raise ValueError("this code is only supports csv file as of now")

    def separate_data(self, df):
        print("read_data2")
        return df[self.feature_columns], df[self.labels_columns]

    def split_train_test(features, labels, test_size=0.2, random_state=62):
        print("read_data3")
        X_train, X_test, y_train, y_test = model_selection.train_test_split(
            features, labels, test_size=test_size, random_state=random_state
        )

        return X_train, X_test, y_train, y_test


    def data_cleaning(self,data:pd.DataFrame):
        print("read_data4")
        if self.na_strategy == "drop":
            if self.na_columns:
                data.dropna(subset=self.na_columns.split(","))
        elif self.na_strategy == "dropna":
            data.drop(self.na_columns, axis=1)

        return data

    def pipeline(self, data:pd.DataFrame, num_attribs:list, cat_attribs:list):
        print("read_data5")
        int_pipeline = []
        if self.na_strategy not in ['dropna', 'drop']:
            int_pipeline.append(('imputer', SimpleImputer(strategy=self.na_strategy)))

        int_pipeline.append(('std_scaler', StandardScaler()))

        num_pipeline = Pipeline(int_pipeline)


        full_pipeline = ColumnTransformer([
            ("num", num_pipeline, num_attribs),
            ("cat", OneHotEncoder(), cat_attribs),
        ])

        return full_pipeline.fit_transform(data)

    def prepare_data(self):
        print("read_data6")
        df = self.read_data()
        df = self.data_cleaning(df)

        data = self.pipeline(df, self.num_attri, self.cat_attri)

        x,y = self.separate_data(data)

        X_train, X_test, y_train, y_test = self.split_train_test(x.y)

        return X_train, X_test, y_train, y_test


    def model(self,):
        print("read_data7")
        if self.model == "linear_regression":
            return LinearRegression()

    def eval(self, model, x,y):
        print("read_data8")
        pred = model.predict(x)
        mse = mean_squared_error(y, pred)
        return np.sqrt(mse)

    def train(self):
        print("read_data9")
        X_train, X_test, y_train, y_test = self.prepare_data()
        model = self.model()
        model.fit(X_train, y_train)
        score = self.eval(model, X_test, y_test)

        print(score)








In [29]:
t = Trainer("/content/sample_data/california_housing_train.csv",\
             'csv', \
             ["housing_median_age","total_rooms","total_bedrooms","population","households","median_income"],\
             ["median_house_value"], \
             "median",\
             "",\
            ["housing_median_age,total_rooms,total_bedrooms,population,households,median_income", "median_house_value"],\
             [],\
             "linear_regression")



In [30]:
t.train()

read_data9
read_data6
read_data
read_data4
read_data5


ValueError: A given column is not a column of the dataframe