## Classification using SVMs

1. Read saved feature matrix and corresponding labels

In [101]:
import pickle
from os.path import join

# Which saved dataset variant to load, and the directory holding its pickles.
region = 'curated'
pickle_path = '.'

# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself or otherwise trust.

# Feature matrix saved for the selected region (presumably ResNet50-derived,
# judging by the file name -- confirm against the notebook that wrote it).
feature_matrix_file = join(pickle_path, 'resnet50_feature_matrix_' + region + '.pkl')
with open(feature_matrix_file, 'rb') as f:
    resnet50_feature_matrix = pickle.load(f)

# Labels saved alongside the feature matrix for the same region.
labels_file = join(pickle_path, 'labels_' + region + '.pkl')
with open(labels_file, 'rb') as f:
    labels = pickle.load(f)

2. Split the data into training and validation sets

In [102]:
from sklearn import svm
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import unique_labels

In [103]:
# Hold out 33% of the samples for evaluation; the fixed random_state keeps
# the shuffle (and therefore the split) identical across runs.
features_train, features_test, labels_train, labels_test = train_test_split(
    resnet50_feature_matrix,
    labels,
    test_size=0.33,
    random_state=43,
)

3. Train a classifier on the training set

In [104]:
# Support vector classifier with a one-vs-rest decision function.
# probability=True enables predict_proba on the fitted model.
# NOTE(review): no random_state is passed; per the scikit-learn docs the
# probability calibration shuffles data internally, so probability estimates
# may differ between runs -- consider fixing a seed.
clf = svm.SVC(
    C=100,
    gamma='scale',
    decision_function_shape='ovr',
    probability=True,
)
clf.fit(features_train, labels_train)

SVC(C=100, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=True, random_state=None, shrinking=True, tol=0.001,
    verbose=False)

4. Predict labels on the validation set according to the classifier

In [130]:
# Class predictions of the fitted classifier for every held-out sample.
predicted_labels = clf.predict(features_test)

5. Calculate the confusion matrix (plotting it is still to do)

In [132]:
# Sorted set of classes that occur in either the true or the predicted labels.
# This is the same ordering confusion_matrix falls back to when `labels` is
# omitted; naming it makes the row/column order explicit for later plotting.
class_labels = unique_labels(labels_test, predicted_labels)

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(labels_test, predicted_labels, labels=class_labels)
print(cm)

[[17  0  0]
 [ 0  7  0]
 [ 2  0  6]]


In [75]:
# Per-class probability estimates for every held-out sample (available
# because the classifier was built with probability=True).
# NOTE(review): per the scikit-learn docs, SVC probability estimates can
# disagree with clf.predict on borderline samples.
pred_probas = clf.predict_proba(features_test)

# Print each probability vector next to the sample's true label. Iterating
# both sequences together keeps the loop independent of any other variable's
# length and works whether labels_test is a list, array, or Series.
for class_probas, true_label in zip(pred_probas, labels_test):
    print(class_probas, true_label)

[0.89986338 0.00837999 0.09175664] 0
[0.89033023 0.03628888 0.07338089] 0
[0.90051792 0.04183361 0.05764847] 0
[0.96626236 0.01001408 0.02372356] 0
[0.06135577 0.75839479 0.18024944] 1
[0.27708945 0.67796986 0.04494069] 1
[0.06874874 0.84765499 0.08359628] 1
[0.09206001 0.80910635 0.09883365] 1
[0.51292847 0.37310103 0.1139705 ] 0
[0.1786006  0.02395043 0.79744897] 2
[0.6964532  0.18571727 0.11782953] 0
[0.9219034  0.02168738 0.05640923] 0
[0.2721947  0.09265559 0.63514971] 2
[0.59463588 0.30523119 0.10013292] 0
[0.0592052  0.01061461 0.93018019] 2
[0.71731258 0.23392455 0.04876287] 0
[0.91635354 0.01524735 0.06839912] 0
[0.98610568 0.00670358 0.00719074] 0
[0.04432544 0.93136588 0.02430868] 1
[0.06431939 0.0258337  0.90984691] 2
[0.97024448 0.00759053 0.02216499] 0
[0.93516162 0.00911661 0.05572178] 0
[0.07389778 0.91993373 0.00616848] 1
[0.44354156 0.05972512 0.49673332] 2
[0.97099561 0.00886425 0.02014014] 0
[0.14034522 0.02270541 0.83694937] 2
[0.03357229 0.06528839 0.90113931] 2
[