In [109]:
import h5py
import numpy as np
from numpy.random import default_rng
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

from smml.kernel import GaussianKernel

In [90]:
PATH = '../../datasets/usps/usps.h5'
with h5py.File(PATH, 'r') as hf:
        train = hf.get('train')
        X_tr = train.get('data')[:]
        y_tr = train.get('target')[:]
        test = hf.get('test')
        X_te = test.get('data')[:]
        y_te = test.get('target')[:]

In [104]:
class Pegasos:
    def __init__(self, l=0.5, T=1000, K=GaussianKernel(), seed=42):
        self.l = l
        self.T = T
        self.K = K
        self.seed = seed

    def fit(self, X : np.ndarray, y : np.ndarray):
        if (self.l <= 0): 
            raise ValueError('Parameter lambda is not strictly positive')

        self.X_train = X
        self.y_train = y
        self.alphas = np.zeros(X.shape[0])
        rng = default_rng(self.seed)

        for t in tqdm(range(1, self.T + 1)):
            i = rng.integers(X.shape[0])
            s = np.sum(
                [self.alphas[j] * y[j] * self.K(X[j], X[i]) 
                 for j in range(X.shape[0])])
            if (y[i] / (self.l * t)) * s < 1:
                self.alphas[t] += 1

    def predict(self, X):
        #TODO: check if we can remove eta, sign should not change

        return np.array(
            [np.sign((1 / (self.l * self.T)) 
                     * np.sum([self.alphas[j] * self.y_train[j] * self.K(self.X_train[j], x) 
                               for j in range(self.X_train.shape[0])])) for x in X])

In [None]:
class MulticlassPegasos:
    pass
    # should train multiple pegasos etc

In [91]:
y_tr

array([6, 5, 4, ..., 3, 0, 1])

In [97]:
y_tr_0 = np.where(y_tr == 0, 1, -1)

In [105]:
pegasos = Pegasos()

In [106]:
pegasos.fit(X_tr, y_tr_0)

100%|██████████| 1000/1000 [01:17<00:00, 12.86it/s]


In [107]:
y_pred = pegasos.predict(X_te)

In [108]:
y_pred

array([-1., -1., -1., ..., -1.,  1., -1.])

In [110]:
y_te_0 = np.where(y_te == 0, 1, -1)

In [111]:
accuracy_score(y_pred, y_te_0)

0.9825610363726955