In [111]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris, load_digits

# Replace softmax/onehot implementation with this for performance increase
from sklearn.preprocessing import OneHotEncoder
from scipy.special import softmax
onehot_encoder = OneHotEncoder(sparse=False)

In [84]:
def softmax_row(Z):
    return [np.exp(row)/np.sum(np.exp(row)) for row in Z]

In [144]:
def onehot_encode(Y):
    N = Y.shape[0]
    Y_oh = np.zeros((N, np.max(Y)+1))
    for i in range(N):
        Y_oh[i][Y[i]] = 1
    return Y_oh

In [103]:
def loss(X, Y, W):
    """
    Y: onehot encoded
    """
    Z = - X @ W
    N = X.shape[0]
    loss = 1/N * (np.trace(X @ W @ Y.T) + np.sum(np.log(np.sum(np.exp(Z), axis=1))))
    return loss

In [155]:
def gradient(X, Y, W, mu):
    """
    Y: onehot encoded 
    """
    Z = - X @ W
    P = softmax(Z, axis=1)
    #P = softmax_row(Z)
    N = X.shape[0]
    print((Y - P).shape)
    print((X.T @ (Y - P)).shape)
    print(W.shape)
    print()
    gd = 1/N * (X.T @ (Y - P)) + 2 * mu * W
    return gd

In [145]:
def gradient_descent(X, Y, max_iter=1000, eta=0.1, mu=0.01):
    """
    Very basic gradient descent algorithm with fixed eta and mu
    """
    #Y_onehot = onehot_encoder.fit_transform(Y.reshape(-1,1))
    Y_onehot = onehot_encode(Y)
    W = np.zeros((X.shape[1], Y_onehot.shape[1]))
    step = 0
    step_lst = [] 
    loss_lst = []
    W_lst = []
 
    while step < max_iter:
        step += 1
        W -= eta * gradient(X, Y_onehot, W, mu)
        step_lst.append(step)
        W_lst.append(W)
        loss_lst.append(loss(X, Y_onehot, W))

    df = pd.DataFrame({
        'step': step_lst, 
        'loss': loss_lst
    })
    return df, W

In [141]:
class Multiclass:
    def fit(self, X, Y):
        self.loss_steps, self.W = gradient_descent(X, Y)

    def loss_plot(self):
        return self.loss_steps.plot(
            x='step', 
            y='loss',
            xlabel='step',
            ylabel='loss'
        )

    def predict(self, H):
        Z = - H @ self.W
        P = softmax(Z, axis=1)
        #P = softmax_row(Z)
        return np.argmax(P, axis=1)


In [142]:
X = load_iris().data
Y = load_iris().target
print(X.shape)
print(Y.shape)

(150, 4)
(150,)


In [156]:
# fit model
model = Multiclass()
model.fit(X, Y)

[[1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0.

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)


(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)
(4, 3)

(150, 3)
(4, 3)


In [135]:
# predict 
model.predict(X)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [158]:
model.predict(X).shape

(150,)

In [147]:
# check the predicted value and the actual value
model.predict(X) == Y

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True, False,  True,
       False,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False, False,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

In [148]:
X2 = load_digits().data
Y2 = load_digits().target
print(X2.shape)
print(Y2.shape)

(1797, 64)
(1797,)


In [149]:
# fit model
model2 = Multiclass()
model2.fit(X2, Y2)

[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 1. 0.]]


In [150]:
# predict 
model2.predict(X2)

array([0, 1, 2, ..., 8, 9, 8])

In [151]:
# check the predicted value and the actual value
for r in model2.predict(X2) == Y2:
    print(r)

True
True
True
True
True
False
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True