# Testing for perfect training accuracy on the iris dataset

In [3]:
import math

from rerf.rerfClassifier import rerfClassifier

# Import scikit-learn dataset library
from sklearn import datasets

# Import scikit-learn metrics module for accuracy calculation
from sklearn import metrics

# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html
from scipy.stats import wilcoxon

# testing sklearn random forest
from sklearn.ensemble import RandomForestClassifier

from tqdm import tqdm

In [4]:
# Load the iris dataset shipped with scikit-learn; its `.data` and `.target`
# attributes are what the classifiers below are fit on and scored against.
iris = datasets.load_iris()

In [5]:
def iris_pred_acc(cls, n_iter=10000):
    """Repeatedly fit an estimator on the iris data and score it on that same data.

    Parameters
    ----------
    cls : object
        Estimator exposing ``fit(X, y)`` and ``predict(X)``. ``fit`` is called
        again on this same object in every iteration.
    n_iter : int, default 10000
        Number of fit/predict rounds to run.

    Returns
    -------
    list
        One training-accuracy value per iteration (as returned by
        ``metrics.accuracy_score``), in iteration order.
    """
    y_train_acc_list = []
    for _ in tqdm(range(n_iter)):
        # Operate on the estimator that was passed in, not on a module-level
        # `clf`, so the result never depends on which global is bound.
        cls.fit(iris.data, iris.target)
        y_pred_train = cls.predict(iris.data)
        y_train_acc_list.append(metrics.accuracy_score(iris.target, y_pred_train))

    return y_train_acc_list

In [6]:
def print_pred_summ(acc_list):
    """Print a two-line summary of a collection of accuracy scores.

    Line 1: how many scores are equal to 1 according to ``math.isclose``.
    Line 2: the (up to) five smallest scores, in ascending order.
    """
    n_perfect = 0
    for acc in acc_list:
        if math.isclose(acc, 1):
            n_perfect += 1
    print(n_perfect)

    # Sort a copy so the caller's list is left untouched.
    ascending = list(acc_list)
    ascending.sort()
    print(ascending[:5])

In [8]:
# Each entry: (label printed above the summary, estimator class, constructor kwargs).
forest_specs = [
    ("RerF", rerfClassifier, {"n_estimators": 100, "projection_matrix": "RerF"}),
    ("RF", rerfClassifier, {"n_estimators": 100, "projection_matrix": "Base"}),
    ("sklearn", RandomForestClassifier, {"n_estimators": 100}),
]

train_acc_by_label = {}
for label, forest_cls, forest_kwargs in forest_specs:
    # Build the estimator right before its run and bind it to `clf`, so the
    # module-level name always refers to the forest currently being evaluated.
    clf = forest_cls(**forest_kwargs)
    train_acc_by_label[label] = iris_pred_acc(clf, 10000)
    print(label)
    print_pred_summ(train_acc_by_label[label])

# Expose the per-forest accuracy lists under the names later cells use.
rerf_acc = train_acc_by_label["RerF"]
rf_acc = train_acc_by_label["RF"]
sklearn_acc = train_acc_by_label["sklearn"]

100%|██████████| 10000/10000 [01:18<00:00, 127.03it/s]
  0%|          | 38/10000 [00:00<00:52, 188.92it/s]

RerF
10000
[1.0, 1.0, 1.0, 1.0, 1.0]


100%|██████████| 10000/10000 [00:53<00:00, 185.46it/s]
  0%|          | 2/10000 [00:00<10:49, 15.39it/s]

RF
9962
[0.9933333333333333, 0.9933333333333333, 0.9933333333333333, 0.9933333333333333, 0.9933333333333333]


100%|██████████| 10000/10000 [10:42<00:00, 15.64it/s]

sklearn
9982
[0.9933333333333333, 0.9933333333333333, 0.9933333333333333, 0.9933333333333333, 0.9933333333333333]





In [13]:
# Wilcoxon tests between the per-iteration training accuracies of each pair of
# forests; the last pair compares the sorted RF and sklearn accuracy lists.
# NOTE(review): the accuracy lists come from separate runs, so pairing element i
# with element i looks arbitrary -- confirm a paired test is what is intended.
accuracy_pairs = (
    (rerf_acc, rf_acc),
    (rerf_acc, sklearn_acc),
    (rf_acc, sklearn_acc),
    (sorted(rf_acc), sorted(sklearn_acc)),
)
for first_acc, second_acc in accuracy_pairs:
    print(wilcoxon(first_acc, second_acc))

WilcoxonResult(statistic=0.0, pvalue=7.074463098970675e-10)
WilcoxonResult(statistic=0.0, pvalue=2.209049699858536e-05)
WilcoxonResult(statistic=513.0, pvalue=0.007526315166457887)
WilcoxonResult(statistic=0.0, pvalue=7.74421643104407e-06)
