# Initialization

In [None]:
from adaboost import *
import scipy.io
import numpy as np
import matplotlib.pyplot as plt

def unwrap(data):
    """
    Simple "hack" for preparing data from *.mat files.

    Repeatedly replaces ``data`` by ``data[0]`` while it has length 1 and a
    non-empty ``shape``, i.e. strips singleton wrapping levels. If the result
    has named fields (``dtype.names``), each field is unwrapped recursively
    and written back in place (``data[key] = unwrap(data[key])``).

    The function is best-effort. Any ordinary exception aborts the remaining
    work, and the value reached so far is returned. Typical triggers are:

    * an object with no ``len``/``shape``/``dtype``;
    * ``len()`` of a 0-d array or numpy scalar;
    * ``dtype.names`` being ``None`` for a non-structured array.

    :param data: value returned by ``scipy.io.loadmat`` or one of its fields
    :return: the (possibly) unwrapped value
    """
    try:
        while (len(data) == 1) and (len(data.shape) > 0):
            data = data[0]
        for key in list(data.dtype.names):
            data[key] = unwrap(data[key])
    except Exception:
        # Leaves (scalars, plain arrays, strings, ...) legitimately raise here
        # and are returned as-is. Catching Exception rather than using a bare
        # ``except`` lets KeyboardInterrupt / SystemExit propagate instead of
        # being silently swallowed.
        pass
    return data

# Data preparation

In [None]:
data = scipy.io.loadmat('data_33rpz_cv07.mat')

# Each split is unwrapped from its .mat container. Element 0 is taken as the
# images and element 1 (squeezed) as the labels.
trn_data = unwrap(data['trn_data'])
trn_images, trn_labels = trn_data[0], np.squeeze(trn_data[1])

tst_data = unwrap(data['tst_data'])
tst_images, tst_labels = tst_data[0], np.squeeze(tst_data[1])

# Report the array shapes of both splits.
for split_heading, split_images, split_labels in (
    ('training data shapes:', trn_images, trn_labels),
    ('test data shapes:', tst_images, tst_labels),
):
    print(split_heading)
    print(split_images.shape)
    print(split_labels.shape)

In [None]:
# One-vs-rest setup: samples labelled `digit` become +1, all others -1.
digit = 6

N_trn = trn_labels.size
N_tst = tst_labels.size

# Flatten the images so that X has one column per sample (the sample axis is
# the last axis of the image arrays).
X_trn = np.reshape(trn_images, (-1, N_trn))
X_tst = np.reshape(tst_images, (-1, N_tst))

# Float +1/-1 target vectors of length N.
y_trn = np.where(np.ravel(trn_labels) == digit, 1.0, -1.0)
y_tst = np.where(np.ravel(tst_labels) == digit, 1.0, -1.0)

# AdaBoost training

In [None]:
# Number of AdaBoost training steps.
N_steps = 30

# adaboost() yields the trained classifier plus two per-step curves that are
# plotted later: the weak-classifier error and the error upper bound.
(classifier, wc_error, upper_bound) = adaboost(
    X_trn, y_trn, N_steps
)

## Compute errors and visualize

In [None]:
# Evaluate the classifier on the training split first, then on the test split.
trn_errors, tst_errors = (
    compute_error(classifier, X, y)
    for X, y in ((X_trn, y_trn), (X_tst, y_tst))
)

# Locate the (0-based) step with the lowest test error.
min_iter = np.argmin(tst_errors)
min_err = tst_errors[min_iter]
print('minimal test error {}, achieved at iteration #{}'.format(min_err, min_iter))

In [None]:
# Overlay all four error curves on the same axes, in a fixed order.
for curve, curve_label in (
    (upper_bound, 'Upper bound'),
    (wc_error, 'WC error'),
    (trn_errors, 'Training error'),
    (tst_errors, 'Test error'),
):
    plt.plot(curve, label=curve_label)

plt.xlabel('training step')
plt.ylabel('error')
plt.grid()
plt.legend()
plt.savefig('error_evolution.png')

## Classify images and visualize

In [None]:
# Apply the trained classifier to the flattened test set
# (presumably one prediction per test sample — confirm in adaboost).
classif = adaboost_classify(classifier, X_tst)

# Render the test images together with the predictions, then save the figure.
show_classification(tst_images, classif)
plt.savefig('classification.png')

In [None]:
# Show the weak classifiers over the positive training images, i.e. those
# samples (last axis) whose target in y_trn is +1.
show_classifiers(np.compress(y_trn == 1, trn_images, axis=-1), classifier)
plt.savefig('weak_classifiers.png')