In [None]:
# Start the JVM that python-weka-wrapper uses for every Weka call in this
# notebook; the final cell shuts it down again with jvm.stop().
from weka.core import jvm

jvm.start()

In [None]:
import numpy as np

from weka.classifiers import Classifier, Evaluation
from weka.core.converters import Loader
from weka.core.classes import Random

from weka.filters import Filter
from weka.classifiers import FilteredClassifier

import weka.plot.classifiers as plcls
import weka.plot.graph as graph  # NB: pygraphviz and PIL are required

import seaborn as sns
import matplotlib.pyplot as plt

## Load dataset

In [None]:
# Read the UCI breast-w ARFF file; its final attribute is used as the class.
DATASET_PATH = 'datasets-UCI/UCI/breast-w.arff'

loader = Loader(classname="weka.core.converters.ArffLoader")
data = loader.load_file(DATASET_PATH)
data.class_is_last()

## Data splitting (70% training/validation, 30% test)

In [None]:
# Split the data 70% / 30% into a training/validation set and a test set.
# Both runs use the same seed and percentage without replacement; the second
# run adds -V (invert selection), so it returns the instances the first run
# left out. Per Weka's Resample documentation, -V only applies together with
# -no-replacement.
# NOTE(review): the unsupervised Resample filter does not stratify by class;
# weka.filters.supervised.instance.Resample can preserve class proportions if
# that matters for this dataset.
SPLIT_SEED = 1
TRAIN_VAL_PERCENT = 70


def resample_split(dataset, percent=TRAIN_VAL_PERCENT, seed=SPLIT_SEED, invert=False):
    """Return a seeded sample (or its complement) of a Weka dataset.

    Parameters
    ----------
    dataset : weka.core.dataset.Instances
        Data to sample from.
    percent : int
        Size of the sample as a percentage of ``dataset`` (Resample ``-Z``).
    seed : int
        Random seed (Resample ``-S``); use the same value for a sample and its
        complement.
    invert : bool
        If True, pass ``-V`` so the filter returns the instances *not* chosen
        by the sample with the same seed and percentage.

    Returns
    -------
    weka.core.dataset.Instances
        The filtered instances.
    """
    options = ['-S', str(seed), '-Z', str(percent), '-no-replacement']
    if invert:
        options.append('-V')
    resample = Filter(classname="weka.filters.unsupervised.instance.Resample",
                      options=options)
    resample.inputformat(dataset)
    return resample.filter(dataset)


train_val_set = resample_split(data)              # 70% for training/validation
test_set = resample_split(data, invert=True)      # remaining 30% for testing

In [None]:
# Sizes of the full dataset and of the two splits.
print(data.num_instances)
print(train_val_set.num_instances)
print(test_set.num_instances)

# The inverted Resample should return exactly the complement of the 70% sample,
# so the two splits must together account for every instance.
assert train_val_set.num_instances + test_set.num_instances == data.num_instances, \
    "training/validation and test splits do not partition the dataset"

### Weka decision tree (J48): tune `-C` and `-M` with 10-fold cross-validation

In [None]:
# Grid-search J48's pruning confidence (-C) and minimum instances per leaf (-M)
# with cross-validation on the training/validation split. The setting with the
# highest accuracy wins; on ties the first one found (smallest M, then C) is kept.
N_FOLDS = 10
CV_SEED = 1

top_mean = float('-inf')
best_params = None

for m in range(1, 20):
    for c in np.arange(0.05, 1.0, 0.025):
        # Round away float noise such as 0.07500000000000001 so the -C option
        # and the reported parameters are readable.
        conf = round(float(c), 3)
        cls = Classifier(classname="weka.classifiers.trees.J48")
        cls.options = ['-C', str(conf), '-M', str(m)]
        # crossvalidate_model trains its own copy on each fold, so the candidate
        # does not need to be built here.
        evl = Evaluation(train_val_set)
        evl.crossvalidate_model(cls, train_val_set, N_FOLDS, Random(CV_SEED))
        acc = evl.percent_correct
        if acc > top_mean:
            top_mean = acc
            best_params = [conf, m]

# Fit the winning configuration once on the whole training/validation split.
best_cls = Classifier(classname="weka.classifiers.trees.J48")
best_cls.options = ['-C', str(best_params[0]), '-M', str(best_params[1])]
best_cls.build_classifier(train_val_set)

# Print best cross-validated accuracy (%)
print(top_mean)
# Print chosen parameters [C, M]
print(best_params)

In [None]:
# Score the tuned tree on the held-out test set. Weka's Evaluation constructor
# takes the *training* instances to derive class priors. Priors therefore come
# from train_val_set, matching Weka's "supplied test set" mode, so statistics
# such as relative absolute error are not computed against the test labels
# themselves.
evl = Evaluation(train_val_set)
evl.test_model(best_cls, test_set)
evl.percent_correct

In [None]:
# Show the fitted J48 tree, then the overall and per-class statistics of the
# evaluation held in `evl`.
print(best_cls)
for report in (evl.summary, evl.class_details):
    print(report())

In [None]:
def plot_confusion(best_cls, data, norm=True, folds=None, seed=1):
    """Plot the confusion matrix of a Weka classifier as a seaborn heatmap.

    Parameters
    ----------
    best_cls : weka.classifiers.Classifier
        Classifier to evaluate. With the default ``folds=None`` it must already
        be built; it is applied to ``data`` unchanged, so the plot shows how
        that exact model performs on ``data``.
    data : weka.core.dataset.Instances
        Instances (class attribute set) to evaluate on.
    norm : bool, default True
        If True, divide each row by its total so every cell shows the fraction
        of that true class assigned to each predicted class.
    folds : int or None, default None
        If given, cross-validate copies of ``best_cls`` on ``data`` with this
        many folds instead of testing the built model (e.g. ``folds=10``).
    seed : int, default 1
        Seed for the cross-validation fold split; ignored when ``folds`` is None.
    """
    evl = Evaluation(data)
    if folds is None:
        evl.test_model(best_cls, data)
    else:
        evl.crossvalidate_model(best_cls, data, folds, Random(seed))

    # Weka returns the confusion matrix as doubles (counts may be weighted).
    cnf_matrix = np.asarray(evl.confusion_matrix, dtype=float)
    if norm:
        row_totals = cnf_matrix.sum(axis=1, keepdims=True)
        # A class with no instances in `data` has a zero row; leave it at 0
        # instead of producing 0/0 = NaN.
        cnf_matrix = np.divide(cnf_matrix, row_totals,
                               out=np.zeros_like(cnf_matrix),
                               where=row_totals != 0)
        fmt = ".2f"
    else:
        # The integer format code "d" raises on float cells; "g" prints whole
        # counts without a decimal part and keeps fractional (weighted) counts.
        fmt = "g"

    fig, ax = plt.subplots()
    sns.heatmap(cnf_matrix, annot=True, fmt=fmt, cmap="YlGnBu", ax=ax)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_title('Confusion Matrix')
    plt.show()

In [None]:
# Row-normalised confusion matrix for the held-out test set.
plot_confusion(best_cls=best_cls, data=test_set, norm=True)

In [None]:
# ROC curves for class-attribute indices 0 and 1, taken from the evaluation
# currently held in `evl`.
roc_class_indices = [0, 1]
plcls.plot_roc(evl, class_index=roc_class_indices, wait=True)

In [None]:
# Render the fitted J48 tree from its dot-format graph description.
tree_dot = best_cls.graph
graph.plot_dot_graph(tree_dot)

In [None]:
# Shut down the JVM started in the first cell. python-weka-wrapper cannot
# restart the JVM within the same Python process, so restart the kernel before
# re-running the notebook.
jvm.stop()