In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.base import clone

In [2]:
class OneR:
    """One Rule (1R) classifier for categorical features.

    ``fit`` derives, for every feature, a rule mapping each observed value to
    the label it co-occurs with most often, scores that rule on the training
    data, and keeps the feature whose rule is most accurate. ``predict`` then
    classifies rows using only the selected feature's rule.
    """

    def __init__(self):
        self.ideal_feature = None      # column label of the selected feature
        self.ideal_feature_index = 0   # positional index of the selected feature
        self.max_accuracy = 0          # training accuracy of the selected rule
        self.result = dict()           # str(feature) -> {feature value: predicted label}
        self.default_label = None      # most frequent training label; used for unseen values

    def fit(self, X, y):
        """Learn one rule per feature and select the most accurate one.

        Parameters
        ----------
        X : array-like or DataFrame of shape (n_samples, n_features)
            Feature values; each column is treated as categorical.
        y : array-like of length n_samples
            Target labels, matched to the rows of ``X`` by position.

        Returns
        -------
        list of dict
            One entry per feature with keys ``"feature"``, ``"accuracy"`` and
            ``"rules"``. Each entry is also printed.

        Raises
        ------
        ValueError
            If ``X`` has no rows or ``y`` has a different number of labels.
        """
        # Use a fresh positional index on both sides, so a labelled Series `y`
        # (e.g. a train_test_split slice) cannot misalign and become NaN.
        dfx = pd.DataFrame(X).reset_index(drop=True)
        labels = pd.Series(list(y), name="label")
        if len(dfx) == 0:
            raise ValueError("OneR.fit needs at least one sample")
        if len(labels) != len(dfx):
            raise ValueError(f"X has {len(dfx)} rows but y has {len(labels)} labels")

        # Reset state so a refit does not inherit the previous best score/rules.
        self.ideal_feature = None
        self.ideal_feature_index = 0
        self.max_accuracy = 0
        self.result = dict()
        modes = labels.mode()
        self.default_label = modes.iloc[0] if len(modes) else None

        response = list()
        for feature_index, feature in enumerate(dfx.columns):
            variable = dfx.iloc[:, feature_index].rename("variable")
            # Most frequent label per value; ties resolve to the first crosstab column.
            rules = dict(pd.crosstab(variable, labels).idxmax(axis=1))
            self.result[str(feature)] = rules

            # Vectorised accuracy: unseen/NaN values map to NaN and count as wrong.
            correct_answers = int((variable.map(rules) == labels).sum())
            accuracy = correct_answers / len(labels)

            # Strict '>' keeps the earliest feature when accuracies tie.
            if accuracy > self.max_accuracy:
                self.max_accuracy = accuracy
                self.ideal_feature = feature
                self.ideal_feature_index = feature_index

            result_feature = {"feature": str(feature), "accuracy": accuracy, "rules": rules}
            print(result_feature)
            response.append(result_feature)

        return response

    def predict(self, X):
        """Predict labels using the rule of the selected feature.

        Parameters
        ----------
        X : array-like or DataFrame of shape (n_samples, n_features)
            Rows with the same column positions as the data passed to ``fit``.

        Returns
        -------
        list
            One predicted label per row. Values not seen during ``fit`` are
            assigned ``default_label``.

        Raises
        ------
        ValueError
            If the classifier has not been fitted.
        """
        if self.ideal_feature is None:
            raise ValueError("OneR is not fitted yet; call fit() first")
        frame = pd.DataFrame(X)
        if len(frame) == 0:
            return []
        rules = self.result[str(self.ideal_feature)]
        # Select the column by position so DataFrames, arrays and lists of rows all work.
        column = frame.iloc[:, self.ideal_feature_index]
        return [rules.get(value, self.default_label) for value in column]

    def get_params(self, deep = False):
        """Return constructor parameters for ``sklearn.base.clone``; OneR has none."""
        return {}

    def __repr__(self):
        if self.ideal_feature is not None:
            message = "Most accurate feature is: " + str(self.ideal_feature)
        else:
            message = "Cannot choose most accurate feature"
        return message

In [3]:
data = pd.read_csv('data/mushrooms.csv')
# Split the frame into the categorical features and the target label.
x_mush = data.drop(columns="class")
y_mush = data['class']

In [4]:
def cross_validation(df, clf, target="class", test_size=0.3, random_state=77):
    """Score a classifier on a single hold-out split of ``df``.

    Despite its name, this performs one train/test split rather than k-fold
    cross-validation. An unfitted copy of ``clf`` (made with ``clone``) is
    trained on the training part and scored on the test part, so ``clf``
    itself is left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset containing the feature columns and the ``target`` column.
    clf : estimator
        Object providing ``get_params``, ``fit(X, y)`` and ``predict(X)``.
    target : str, default "class"
        Name of the label column.
    test_size : float, default 0.3
        Fraction of rows held out for testing (passed to ``train_test_split``).
    random_state : int, default 77
        Seed for the split, so the score is reproducible.

    Returns
    -------
    float
        Fraction of test rows whose predicted label equals the true label.

    Raises
    ------
    ValueError
        If ``predict`` returns a different number of labels than test rows.
    """
    clone_classifier = clone(clf)
    # Split the frame that was passed in (previously the global `data` was used).
    df_train, df_test = train_test_split(df, test_size=test_size, random_state=random_state)

    y_train = df_train[target].to_numpy()
    X_train = df_train.drop(columns=target).to_numpy()
    y_test = df_test[target].to_numpy()
    X_test = df_test.drop(columns=target).to_numpy()

    clone_classifier.fit(X_train, y_train)
    labels_predict = clone_classifier.predict(X_test)
    if len(labels_predict) != len(y_test):
        raise ValueError(
            f"predict returned {len(labels_predict)} labels for {len(y_test)} test rows"
        )
    n_correct = sum(pred == true for pred, true in zip(labels_predict, y_test))
    return n_correct / len(y_test)

In [5]:
# Score OneR on a 70/30 hold-out split of the mushroom data. fit() prints each
# feature's rule and training accuracy. cross_validation clones the estimator,
# so clf_mushrooms itself remains unfitted.
clf_mushrooms = OneR()
cross_val = cross_validation(data, clf_mushrooms)
print(f"Accuracy = {cross_val}")


{'feature': '0', 'accuracy': 0.5589166373549068, 'rules': {'b': 'e', 'c': 'p', 'f': 'e', 'k': 'p', 's': 'e', 'x': 'e'}}
{'feature': '1', 'accuracy': 0.5798452339078438, 'rules': {'f': 'e', 'g': 'p', 's': 'p', 'y': 'p'}}
{'feature': '2', 'accuracy': 0.5946183608863876, 'rules': {'b': 'p', 'c': 'e', 'e': 'p', 'g': 'e', 'n': 'e', 'p': 'p', 'r': 'e', 'u': 'e', 'w': 'e', 'y': 'p'}}
{'feature': '3', 'accuracy': 0.7449876890608512, 'rules': {'f': 'p', 't': 'e'}}
{'feature': '4', 'accuracy': 0.9843475202251143, 'rules': {'a': 'e', 'c': 'p', 'f': 'p', 'l': 'e', 'm': 'p', 'n': 'e', 'p': 'p', 's': 'p', 'y': 'p'}}
{'feature': '5', 'accuracy': 0.5114315863524446, 'rules': {'a': 'e', 'f': 'e'}}
{'feature': '6', 'accuracy': 0.6178332747098135, 'rules': {'c': 'p', 'w': 'e'}}
{'feature': '7', 'accuracy': 0.7511431586352445, 'rules': {'b': 'e', 'n': 'p'}}
{'feature': '8', 'accuracy': 0.803904326415758, 'rules': {'b': 'p', 'e': 'e', 'g': 'p', 'h': 'p', 'k': 'e', 'n': 'e', 'o': 'e', 'p': 'e', 'r': 'p', 'u