In [1]:
import numpy as np
import numpy.linalg as lia
import pandas as pd
import matplotlib as plt

In [2]:
from sklearn import datasets
digits = datasets.load_digits()

In [3]:
print(len(digits.data))
print(digits.target.size)

1797
1797


In [4]:
from sklearn.datasets import fetch_openml
wine = fetch_openml(name='wine', version=1)

In [5]:
print(len(wine.data))
print(wine.target.size)

178
178


In [6]:
# normalization of wine data via division of each feature by its max value

wine_data_norm = []
for col in wine.data.T:
    col_norm = col/np.amax(col)
    wine_data_norm.append(col_norm)
    
wine.data = np.asarray(wine_data_norm).T

In [7]:
# 5-fold cross validation for digits dataset

digitsTrainingSetSize = int(np.ceil(0.8 * len(digits.data)))
digitsValidationSetSize = int(len(digits.data) - digitsTrainingSetSize)

xDigitsTrainingSets = []
yDigitsTrainingSets = []
xDigitsValidationSets = []
yDigitsValidationSets = []

for foldIndex in range(5):

    xValidationSet = []
    yValidationSet = []

    for index, data in enumerate(digits.data[foldIndex*digitsValidationSetSize:((foldIndex*digitsValidationSetSize)+digitsValidationSetSize)]):
        xValidationSet.append(data.tolist())
        yValidationSet.append(digits.target[index+(foldIndex*digitsValidationSetSize)])
    
    xTrainingSet = []
    yTrainingSet = []

    for index, data in enumerate(digits.data.tolist()):
        if data not in xValidationSet:
            xTrainingSet.append(data)
            yTrainingSet.append(digits.target[index])
            
    xDigitsTrainingSets.append(xTrainingSet)
    yDigitsTrainingSets.append(yTrainingSet)
    xDigitsValidationSets.append(xValidationSet)
    yDigitsValidationSets.append(yValidationSet)
    
# 5-fold cross validation for wine dataset

wineTrainingSetSize = int(np.ceil(0.8 * len(wine.data)))
wineValidationSetSize = int(len(wine.data) - wineTrainingSetSize)

xWineTrainingSets = []
yWineTrainingSets = []
xWineValidationSets = []
yWineValidationSets = []

for foldIndex in range(5):

    xValidationSet = []
    yValidationSet = []
    for index, data in enumerate(wine.data[foldIndex*wineValidationSetSize:((foldIndex*wineValidationSetSize)+wineValidationSetSize)]):
        xValidationSet.append(data.tolist())
        yValidationSet.append(wine.target[index+(foldIndex*wineValidationSetSize)])
    
    xTrainingSet = []
    yTrainingSet = []
    
    for index, data in enumerate(wine.data.tolist()):
        if data not in xValidationSet:
            xTrainingSet.append(data)
            yTrainingSet.append(wine.target[index])
            
    xWineTrainingSets.append(xTrainingSet)
    yWineTrainingSets.append(yTrainingSet)
    xWineValidationSets.append(xValidationSet)
    yWineValidationSets.append(yValidationSet)

In [8]:
# one-hot encoding of y for digits dataset

numberOfDigitsTargets = 10
numberOfWineTargets = 3

for index, fold in enumerate(yDigitsTrainingSets):
    encodedFold = []
    for i, y in enumerate(fold):
        encoding = np.zeros(numberOfDigitsTargets)
        encoding[y] = 1
        encodedFold.append(encoding.tolist())
    yDigitsTrainingSets[index] = encodedFold
    
for index, fold in enumerate(yDigitsValidationSets):
    encodedFold = []
    for i, y in enumerate(fold):
        encoding = np.zeros(numberOfDigitsTargets)
        encoding[y] = 1
        encodedFold.append(encoding.tolist())
    yDigitsValidationSets[index] = encodedFold

# one-hot encoding of y for wine dataset

for index, fold in enumerate(yWineTrainingSets):
    encodedFold = []
    for i, y in enumerate(fold):
        encoding = np.zeros(numberOfWineTargets)
        encoding[int(y)-1] = 1
        encodedFold.append(encoding.tolist())
    yWineTrainingSets[index] = encodedFold
    
for index, fold in enumerate(yWineValidationSets):
    encodedFold = []
    for i, y in enumerate(fold):
        encoding = np.zeros(numberOfWineTargets)
        encoding[int(y)-1] = 1
        encodedFold.append(encoding.tolist())
    yWineValidationSets[index] = encodedFold


In [9]:
def getRandomIndices(arr, batch_size):
    indices = []
    
    if batch_size > len(arr):
        print("Error: batch size larger than size of dataset.")
        return
    
    while batch_size > 0:
        x = np.floor(np.random.random() * len(arr))
        if x not in indices:
            indices.append(int(x))
            batch_size -= 1
    
    return indices

In [10]:
# gradient descent class
 
class GradientDescent:
    
    def __init__(self, batch_size, learning_rate=0.01, momentum=0.9, max_iters=20, epsilon=1e-8):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.batch_size = batch_size
        self.max_iters = max_iters
        self.epsilon = epsilon
        self.deltas = [] # delta w{t}
        
    def run(self, gradient_fn, x, y, w):
        # TODO: use epsilon
        grad = np.inf
        t = 1
        while np.linalg.norm(grad) > self.epsilon and t < self.max_iters:
            print(f'Grad descent step {t}')            
            gradients = gradient_fn(x, y, w, self.batch_size)
            for c in range(len(y[0])):
                wc = []
                if(c==0):
                    w[c] = w[c-1] - self.learning_rate * gradients[c]
                else:
                    delta_w = (self.momentum)*(self.deltas[-1]) + (1-self.momentum)*gradients[c]
                    w[c] = w[c-1] - self.learning_rate * delta_w
                if(c != 0):
                    w_prev = self.deltas[-1]
                self.deltas.append(w[c])
                print(f'=============\n{c} : {w[c]}\n')
            t += 1
            
        return w


In [11]:
# logistic regression

class LogisticRegression:
    def __init__(self, add_bias=True):
        self.add_bias = add_bias
        pass
            
    def fit(self, x, y, optimizer):
        # TODO: add bias
        
        def gradient(x, y, w, batch_size):
            gradients = np.zeros(len(w)).tolist()

            indices = getRandomIndices(x, batch_size)

            for index in indices:
                a = np.asarray(x[index])
                b = np.asarray(y[index])
                
#                 print("x:", a.astype(int))
#                 print("y:", b)

                for c in range(len(b)):
                    w_x =  w[c] @ a
                    num = np.exp(w_x)

                    den = 0
                    for i in range(len(b)):
                        w_x =  w[i] @ a
                        den += np.exp(w_x)

                    yh_c = num/den

                    y_c = b[c]
                    
                    # TODO: may change, see slide 27 of logistic slideshow
                    cost_c = np.dot(yh_c - y_c, a)
                
                    gradients[c] += cost_c

            return gradients
        
        w0 = []
        for c in range(len(y[0])):
            w0.append(np.zeros(len(x[0])))
            
        self.w = optimizer.run(gradient, x, y, w0)
        return self
    
    def predict(self, x):
        a = np.asarray(x)
        b = np.asarray(self.w)
        
#         if self.add_bias:
#             x = np.column_stack([x,np.ones(N)])

        yh=[]
        for i,x_c in enumerate(a):
            yh_x=[]
            for c in range(len(b)):
                w_x =  b[c] @ x_c
                num = np.exp(w_x)

                den = 0
                for i in range(len(b)):
                    w_x =  b[i] @ x_c
                    den += np.exp(w_x)

                yh_c = num/den
                yh_x.append(yh_c)
            yh.append(yh_x)
        return yh

In [12]:
def cost(yh, y):
    return y * np.log1p(np.exp(-yh)) + (1-yh) * np.log1p(np.exp(yh))

# TODO: grid-search to find lowest cost combination of model hyper-parameters

batch_size = 10
learning_rate = 0.01
momentum = 0.9

digits_cost = 0

for fold_index, fold in enumerate(xDigitsTrainingSets):
    gradientDescentModel = GradientDescent(batch_size, learning_rate, momentum)
    logisticRegressionModel = LogisticRegression(False)
    
    logisticRegressionModel.fit(fold, yDigitsTrainingSets[fold_index], gradientDescentModel)
    yh = logisticRegressionModel.predict(xDigitsValidationSets[fold_index])
    
    for sample_index, yh_x in enumerate(yh):
        c = np.argmax(yDigitsValidationSets[fold_index][sample_index])
        cst = cost(yh_x[c], yDigitsValidationSets[fold_index][sample_index][c])
        digits_cost += cst
        print(yh_x, yDigitsValidationSets[fold_index][sample_index])
        print(cst)

wine_cost = 0

for fold_index, fold in enumerate(xWineTrainingSets):
    gradientDescentModel = GradientDescent(batch_size, learning_rate, momentum)
    logisticRegressionModel = LogisticRegression(False)
    
    logisticRegressionModel.fit(fold, yWineTrainingSets[fold_index], gradientDescentModel)
    yh = logisticRegressionModel.predict(xWineValidationSets[fold_index])
    
    for sample_index, yh_x in enumerate(yh):
        c = np.argmax(yWineValidationSets[fold_index][sample_index])
        cst = cost(yh_x[c], yWineValidationSets[fold_index][sample_index][c])
        wine_cost += cst
        print(yh_x, yWineValidationSets[fold_index][sample_index])
        print(cst)

print("Model hyper-parameters:")
print("\tMini-batch size:", batch_size)
print("\tLearning rate:", learning_rate)
print("\tMomentum:", momentum)
print("Digits total cost:", digits_cost)
print("Wine total cost:", wine_cost)



#xWineTrainingSets = []
#yWineTrainingSets = []
#xWineValidationSets = []
#yWineValidationSets = []

Grad descent step 1
0 : [ 0.00000000e+00 -7.00000000e-03 -8.60000000e-02 -1.96000000e-01
 -2.78000000e-01 -2.09000000e-01 -8.90000000e-02 -1.60000000e-02
  0.00000000e+00 -5.20000000e-02 -1.38000000e-01 -2.44000000e-01
 -1.17000000e-01 -1.17000000e-01 -8.90000000e-02 -1.00000000e-02
  0.00000000e+00 -4.20000000e-02 -1.82000000e-01 -2.09000000e-01
 -9.20000000e-02 -5.00000000e-02 -2.00000000e-03 -4.00000000e-03
  0.00000000e+00 -8.00000000e-03 -1.44000000e-01 -2.15000000e-01
 -2.37000000e-01 -6.10000000e-02 -5.00000000e-03  0.00000000e+00
  0.00000000e+00 -2.20000000e-02 -8.30000000e-02 -2.82000000e-01
 -3.76000000e-01 -1.49000000e-01 -5.30000000e-02  0.00000000e+00
  0.00000000e+00 -1.70000000e-02 -4.70000000e-02 -2.63000000e-01
 -2.51000000e-01 -4.40000000e-02 -4.60000000e-02  0.00000000e+00
  0.00000000e+00  6.00000000e-03  3.99680289e-17 -1.81000000e-01
 -1.35000000e-01 -9.30000000e-02 -9.60000000e-02 -1.10000000e-02
  0.00000000e+00 -4.00000000e-03 -9.80000000e-02 -2.12000000e-01
 


1 : [ 0.00000000e+00 -1.02889670e-01 -8.17593895e-01 -4.78487591e-01
 -7.76460455e-01 -1.12609717e+00 -2.37655126e-01 -8.35835323e-03
  2.59150918e-05 -5.74561165e-01 -3.36724295e-01  1.27161893e-02
 -5.46431007e-02  1.36821363e-01 -3.94922003e-01 -7.91169823e-02
  4.40658014e-05  3.60648480e-03  5.84437294e-01 -3.30655021e-01
 -1.62687462e+00  6.63937519e-01  1.45997394e-01 -7.55726100e-02
  0.00000000e+00  4.31070830e-01  6.76680145e-01 -1.81366013e+00
 -2.52116736e+00 -1.39141229e-02  5.08381257e-01 -8.18168949e-03
  0.00000000e+00  6.76839949e-01  6.04946944e-01 -1.87899534e+00
 -2.33805397e+00 -5.68805362e-01  6.33184825e-01  0.00000000e+00
 -2.01141508e-05  3.24103017e-01  1.49195464e+00 -1.16254797e+00
 -1.09589998e+00  2.82858427e-01  2.75576053e-01 -2.05454387e-05
  4.78637099e-05 -6.53741823e-02  6.14623764e-01 -2.27042034e-01
  2.34932740e-03  4.91588785e-01 -4.05507170e-01 -3.70769370e-04
  0.00000000e+00 -9.37338064e-02 -1.02854526e+00 -2.92890041e-01
 -2.13741226e-02 -2.

0 : [ 0.00000000e+00 -6.76425438e-02 -4.78980530e-02  1.02183546e+00
  5.53723002e-01 -5.06167107e-01 -1.57586308e-01 -4.22593399e-03
  1.36389903e-05 -2.71077049e-01  1.11357365e+00  1.18227519e+00
  1.00846662e+00  1.31665635e+00 -9.36938685e-02 -4.17802996e-02
  2.31916229e-05  3.29546415e-01  1.69796311e+00  1.12508687e-01
 -1.02456592e+00  1.55940296e+00  5.76647414e-01 -3.97923399e-02
  0.00000000e+00  7.84452156e-01  1.66553324e+00 -1.09951415e+00
 -1.78231708e+00  6.86539472e-01  1.04256666e+00 -4.30598449e-03
  0.00000000e+00  1.06180463e+00  1.61192366e+00 -1.10004154e+00
 -1.59191707e+00  3.21444084e-01  1.08484684e+00  0.00000000e+00
 -8.56012074e-05  6.29650414e-01  2.26607454e+00 -6.55605179e-01
 -5.94554055e-01  1.27556595e+00  6.57540075e-01 -4.94977917e-03
  2.51904442e-05  9.38774071e-02  1.63696967e+00  9.73816538e-01
  1.12985474e+00  1.45906608e+00 -2.35215016e-01 -6.38049152e-03
  0.00000000e+00 -5.88377864e-02 -1.96400121e-01  1.12921541e+00
  1.00714163e+00 -3.1

[1.4758378666602187e-17, 2.1441650758165603e-12, 2.5586165172647567e-09, 1.324074286644029e-05, 0.0010898320785223683, 0.8527308905025275, 0.14616603411532297, 2.205393763949952e-31, 2.3156636272760005e-29, 3.647404480755481e-26] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
1.3862943611198906
[1.051905993804937e-17, 1.0511124740729565e-12, 3.289235162520803e-09, 1.2330616395291812e-05, 0.0006535136961073738, 0.8332503022259443, 0.16608385017126667, 1.2981205818487265e-30, 8.14181288015284e-29, 1.3745203313914371e-25] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
1.3862943611198906
[1.8320490104171723e-16, 2.0555931245321515e-11, 1.5514833399034207e-08, 1.1127456962363415e-05, 0.0011436911997888843, 0.9800407252074637, 0.018804440600395555, 5.06820214863796e-40, 1.3166244953983573e-38, 4.5505858469286863e-35] [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198904
[9.196370915343514e-24, 1.8816836971089212e-16, 2.898980024323401e-12, 1.4580477646631753e-07, 0.00

3 : [ 0.00000000e+00  2.38333783e-03  4.66437834e-01  9.52206652e-01
  8.06474451e-01  1.16033618e-01 -2.33809026e-02 -2.43862227e-03
  0.00000000e+00  2.25352537e-01  1.11065373e+00  8.52426951e-01
  6.98351289e-01  8.75736069e-01  3.39473575e-02 -9.15729532e-04
  0.00000000e+00  4.15801304e-01  1.06972744e+00  7.07065766e-02
  2.64451597e-02  8.99878701e-01  2.64572920e-01 -1.72218063e-05
 -6.36127898e-04  5.48759668e-01  8.75466266e-01 -1.00661625e-01
 -1.35378192e-01  5.80879549e-01  4.76523739e-01  0.00000000e+00
  0.00000000e+00  5.56423306e-01  8.74899662e-01 -9.06180262e-02
 -1.23441709e-01  5.51041863e-01  5.63651714e-01  0.00000000e+00
  0.00000000e+00  3.33781570e-01  1.01501754e+00 -6.69092265e-02
  1.43707170e-02  7.78964561e-01  4.95591741e-01 -1.34871272e-03
  0.00000000e+00  3.32393406e-02  1.07235137e+00  5.87270312e-01
  6.67860378e-01  1.00327608e+00  2.46466150e-01 -1.20269429e-03
  0.00000000e+00  2.58408907e-03  4.05283745e-01  1.04988965e+00
  9.36178644e-01  4.1

0 : [ 0.00000000e+00 -8.83021194e-02 -2.82732851e-01  1.12950898e-01
 -5.32889414e-02 -1.05603549e+00 -4.25099865e-01 -1.19988985e-01
  7.51683551e-05  1.78536402e-02  7.05531043e-01  4.28893433e-01
 -2.36301363e-01  3.39853577e-01 -5.21521836e-01 -9.82512178e-02
  3.75841776e-05  3.20862191e-01  8.75488527e-01 -5.22754719e-01
 -1.34444032e+00  3.00030818e-01 -5.68778751e-02 -1.39807483e-02
 -3.40899687e-04  6.62032489e-01  8.36511715e-01 -1.61928748e+00
 -2.37000166e+00 -1.48799302e-01  7.01270549e-01  0.00000000e+00
  0.00000000e+00  9.55367322e-01  8.28570939e-01 -1.72409746e+00
 -2.01528906e+00 -1.33613752e-01  8.30015550e-01  0.00000000e+00
  0.00000000e+00  4.26369933e-01  1.25569079e+00 -8.16347060e-01
 -1.42554382e+00  4.44654293e-01  6.49100151e-01 -4.28896971e-04
  0.00000000e+00 -3.02839542e-02  1.14839705e+00  4.73842951e-01
 -6.16100783e-02  4.25903206e-01 -1.12782348e-01 -2.86378283e-02
  0.00000000e+00 -7.48391413e-02 -3.25390780e-01  2.82801408e-01
 -5.55581135e-02 -4.9


3 : [ 0.00000000e+00 -8.97133090e-02 -1.01747396e+00 -1.49793551e+00
 -1.93029270e+00 -1.84385069e+00 -4.62777647e-01 -1.68744546e-01
  3.17406259e-05 -2.71534976e-01 -8.97164703e-01 -1.53082351e+00
 -1.08218349e+00 -5.66515323e-01 -5.67597449e-01 -8.76752513e-02
  2.13633840e-05  1.23792647e-01 -4.87694561e-01 -1.58155965e+00
 -2.14812128e+00 -1.78976779e-01 -1.26990034e-01 -8.24336012e-03
 -1.80143141e-04  5.15136556e-01 -6.60992914e-01 -2.55287313e+00
 -3.69385826e+00 -7.15643433e-01  5.22002573e-01 -6.53367332e-06
  0.00000000e+00  3.92506072e-01 -5.48173730e-01 -3.29471940e+00
 -4.02069390e+00 -1.53765616e+00  6.63218914e-01  0.00000000e+00
 -6.40082559e-06  1.14388949e-01 -3.01737623e-01 -2.40965965e+00
 -2.43597063e+00 -6.53215790e-01 -1.70492844e-01 -1.66008978e-04
  0.00000000e+00 -1.09717934e-01 -2.34112112e-01 -1.16963014e+00
 -1.19425422e+00 -2.92533681e-01 -1.16120291e+00 -3.45629271e-02
  0.00000000e+00 -6.42215680e-02 -1.15903527e+00 -1.18709088e+00
 -1.62784946e+00 -1.

1.386273988111109
[5.279167970174091e-40, 6.502093593799277e-37, 6.638259708070351e-33, 9.947543911556432e-27, 9.592598645590579e-23, 2.3766145334188087e-20, 5.432556629490722e-13, 7.183663062507746e-08, 0.00036643946575201116, 0.9996334886970741] [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[1.7288113516394162e-43, 2.738992226591962e-40, 7.758970966044482e-36, 1.7110200349088728e-29, 2.8634948545715e-24, 1.0522337505989027e-21, 3.8934190109760656e-14, 3.65979442519315e-08, 0.00040305940029051974, 0.9995969040017263] [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[1.0549716321973736e-42, 1.599685479580703e-39, 1.317319068707428e-35, 1.6587675997357108e-29, 5.432000463733935e-25, 2.5923271116856274e-22, 1.1669190295510024e-14, 1.3562800817762309e-08, 0.00020474633437192886, 0.9997952401028156] [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[1.7312858621888395e-34, 5.191036431112822e-32, 1.1462583235973902e-28, 1.3815


7 : [ 0.00000000e+00 -4.40831941e-02 -4.93622164e-01 -1.19678143e+00
 -1.59765365e+00 -6.38395350e-01 -2.93361780e-01 -4.46692881e-02
  1.51197107e-05 -2.50904949e-01 -9.96100097e-01 -1.73856609e+00
 -1.33062095e+00 -4.01137531e-01 -2.29588206e-01 -2.33074963e-03
  0.00000000e+00 -3.29791488e-01 -1.14689012e+00 -1.54407001e+00
 -1.38553162e+00 -2.64484463e-01 -8.48736239e-02  9.99999949e-04
  0.00000000e+00  1.11689118e-01 -1.00431244e+00 -2.00535838e+00
 -2.04443215e+00 -5.15878197e-01  1.22538808e-01  0.00000000e+00
  0.00000000e+00  1.29698715e-02 -7.76012809e-01 -2.19285913e+00
 -2.33116247e+00 -1.22416907e+00 -1.50042363e-01  0.00000000e+00
 -6.65543077e-04 -5.28963961e-02 -5.21126771e-01 -1.83356871e+00
 -1.33765330e+00 -1.00219251e+00 -4.11028312e-01 -7.30094685e-03
  0.00000000e+00 -5.80016192e-02 -2.21177085e-01 -1.45471718e+00
 -1.11914184e+00 -1.32587596e+00 -6.18996262e-01 -2.38522629e-02
  0.00000000e+00 -8.57514418e-02 -6.59547960e-01 -1.25306234e+00
 -1.79537729e+00 -1.


8 : [ 0.00000000e+00 -1.95192634e-02  1.23714504e-02  3.30207856e-01
  5.84450679e-03 -1.88485628e-01 -2.16334065e-01 -4.80026495e-02
  9.19590503e-06 -1.53807378e-02  4.60065994e-01  9.32125200e-02
  1.23619302e-01  8.01356797e-01 -1.13380832e-01 -1.43684282e-02
 -4.22701197e-05  2.09639522e-01  5.07778876e-01 -5.17155270e-01
 -8.33338697e-01  9.83932060e-01  2.27202848e-01 -8.50728001e-04
 -2.11350598e-05  5.67628011e-01  4.26765579e-01 -1.20115400e+00
 -1.42580869e+00  5.11471668e-01  6.45238511e-01  0.00000000e+00
  0.00000000e+00  5.22561137e-01  5.57613313e-01 -1.41123022e+00
 -1.58997095e+00  4.87498811e-02  5.82917690e-01  0.00000000e+00
 -4.04787569e-04  2.16011573e-01  8.95689178e-01 -9.45685397e-01
 -6.87604773e-01  2.89836405e-01  3.48668293e-01 -4.54088914e-03
  0.00000000e+00 -6.68847837e-03  9.17229614e-01 -3.00256439e-02
  8.83713613e-02  2.50009030e-01 -3.17938721e-02 -1.87770492e-02
  0.00000000e+00 -6.01813883e-02 -1.18578522e-01  3.20930662e-01
 -1.93278848e-02 -3.


9 : [ 0.00000000e+00 -2.69621070e-02  2.18701811e-01  1.40464236e+00
  1.04913034e+00 -7.26442184e-02 -1.87446694e-01 -1.82659245e-02
  5.15594605e-06  1.47720369e-02  1.35528953e+00  1.16056727e+00
  9.23367976e-01  1.59328207e+00 -6.19992923e-02 -5.41534141e-03
 -2.36999464e-05  5.28865187e-01  1.65913768e+00  1.96576356e-01
 -7.75718390e-01  1.86595123e+00  6.35263650e-01  1.03659261e-04
 -3.59008400e-05  1.04245012e+00  1.49973188e+00 -6.63901678e-01
 -1.34412898e+00  1.21991710e+00  1.18623702e+00  0.00000000e+00
  0.00000000e+00  1.09705425e+00  1.60001418e+00 -1.08919429e+00
 -1.43003970e+00  7.40058720e-01  1.11046388e+00  0.00000000e+00
 -2.26922545e-04  6.44333650e-01  2.04626901e+00 -8.14807840e-01
 -5.10045654e-01  1.36354825e+00  5.88820038e-01 -2.60642718e-03
 -4.72644347e-05  8.42216335e-02  1.93872355e+00  6.95036652e-01
  1.06958807e+00  1.47158172e+00  4.06964210e-04 -1.18140138e-02
 -1.57548116e-05 -4.60257546e-02  1.60276373e-01  1.37392653e+00
  1.10243022e+00 -2.

[3.0984403179179613e-44, 5.312390058054854e-36, 5.2661612523891545e-31, 1.9909441036707534e-27, 2.4674674397554927e-22, 1.2327261065556817e-17, 2.211450109958169e-15, 1.705489178350313e-07, 0.00020430948047879218, 0.9997955199706011] [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[9.251909383804542e-46, 3.174053591179286e-36, 1.2392548273397165e-31, 5.053238628235219e-28, 2.6183878183221294e-23, 1.3342764225049343e-18, 5.598092772611007e-16, 5.4608221938305356e-08, 0.00013599879883736536, 0.9998639465929401] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
0.31347693992954767
[5.921035248554067e-54, 2.9584076805373e-44, 1.4079623308864462e-38, 3.774051110692953e-34, 3.543862297715717e-28, 8.478318630823597e-22, 1.1862381366473398e-18, 3.818829565456092e-09, 3.962675915100696e-05, 0.9999603694220194] [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[7.713643225060354e-60, 7.288493289145445e-50, 1.2956996877730865e-42, 1.9745079356596735e-37,


9 : [ 0.00000000e+00 -2.94742079e-02 -4.46164757e-01 -8.04172015e-01
 -7.30171778e-01 -5.34645080e-01 -1.54164952e-01 -2.89549749e-03
  0.00000000e+00 -2.41238832e-01 -6.78545121e-01 -6.75033177e-01
 -3.10686368e-01 -3.62750827e-02 -3.94585487e-02 -4.29050369e-03
  0.00000000e+00  5.01729403e-02 -6.79731653e-01 -7.15232872e-01
 -5.72036977e-01  2.90778110e-01  2.52630116e-01 -9.02993702e-05
  0.00000000e+00  1.44599920e-01 -5.30787282e-01 -1.40646233e+00
 -1.58622795e+00 -1.98044332e-01  3.36327164e-01  0.00000000e+00
  0.00000000e+00  2.38366626e-01 -8.67746354e-02 -1.29018138e+00
 -1.60976674e+00 -6.76751826e-01  1.92700508e-01  0.00000000e+00
  0.00000000e+00  1.67152154e-01  2.04208094e-01 -1.00285744e+00
 -1.16748175e+00 -3.81762006e-01 -3.21067609e-01  0.00000000e+00
  0.00000000e+00  6.99983335e-04  1.70511848e-01 -5.62764115e-01
 -4.07211421e-01 -2.10704312e-01 -7.83396615e-01 -2.87788577e-02
  0.00000000e+00 -3.71873745e-02 -5.23964771e-01 -7.74250001e-01
 -8.55289431e-01 -9.


3 : [ 0.00000000e+00 -3.14532033e-02 -3.53344366e-02  2.68826847e-01
  1.24848928e-01 -2.34568510e-01 -1.43241873e-01 -2.23296842e-03
  0.00000000e+00 -8.95541016e-02  3.02397166e-01  3.06242589e-01
  4.70073400e-01  6.48303184e-01 -2.02937196e-02 -4.08400007e-03
  0.00000000e+00  2.80181460e-01  4.12656461e-01 -2.58645155e-01
 -3.18995026e-01  9.74552571e-01  3.65729184e-01  9.72660439e-05
  0.00000000e+00  4.24666356e-01  4.72539276e-01 -9.17144073e-01
 -1.17854929e+00  4.07989242e-01  6.58539346e-01  0.00000000e+00
  0.00000000e+00  5.62343899e-01  6.48706985e-01 -9.58898694e-01
 -1.18949670e+00  3.57671800e-03  6.37897879e-01  0.00000000e+00
 -6.59922340e-06  3.80120965e-01  1.02808828e+00 -6.35932670e-01
 -8.37063185e-01  3.66639391e-01  2.80196464e-01 -7.64368609e-06
 -7.12662559e-05  2.17059563e-02  9.99309891e-01  2.71293433e-01
  2.70716415e-01  8.36345624e-01 -2.12000208e-01 -1.76574503e-02
 -1.96697578e-05 -3.81711643e-02 -1.09446019e-01  2.91461865e-01
  3.61003684e-01 -1.


1 : [ 0.00000000e+00 -3.25465576e-02 -7.46951938e-02  3.85042107e-01
 -1.09479720e-03 -6.13913044e-01 -2.13351466e-01 -6.05081820e-03
  0.00000000e+00 -1.51424462e-01  6.14762800e-01  4.58106980e-01
  1.06444725e+00  9.24905994e-01 -7.43596547e-02 -1.20942635e-03
  0.00000000e+00  4.54874946e-01  1.04995445e+00 -3.99534689e-01
 -4.19572436e-01  1.23215977e+00  5.09723619e-01 -5.45334092e-05
  0.00000000e+00  7.26216966e-01  8.29631699e-01 -1.59811950e+00
 -1.86098369e+00  7.72559354e-01  1.13370250e+00 -8.60218905e-06
  0.00000000e+00  8.12095504e-01  8.06080553e-01 -1.70595700e+00
 -2.28812083e+00 -3.50486669e-03  1.08059306e+00  0.00000000e+00
 -3.80176925e-06  6.79076544e-01  1.39996507e+00 -1.23778969e+00
 -1.46639880e+00  5.80340085e-01  2.97187168e-01 -4.40347735e-06
 -4.10560219e-05  1.22330851e-01  1.52748594e+00 -2.29986139e-02
  5.22356169e-01  1.01270932e+00 -5.55004357e-01 -1.73895386e-02
 -1.13316183e-05 -2.44604550e-02 -6.14826325e-02  3.36734411e-01
  3.47733581e-01 -4.

[0.0002865739304312251, 0.00027438471624090916, 0.9994390413533274, 1.6827604260201073e-26, 8.833531810414397e-24, 1.552916038692224e-20, 3.697426181223886e-18, 3.3416692650609774e-17, 6.540909995679948e-17, 4.047310514760065e-16] [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3860957026738938
[0.0005328038905313298, 0.00033932204163680823, 0.9991278740678302, 2.6443991831327246e-24, 4.157383786279701e-22, 2.540153384353392e-19, 3.281866133282175e-17, 2.057378937663265e-16, 3.43672183677134e-16, 1.0308672648512151e-15] [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3859249786164738
[1.2448956437646213e-05, 6.641267082319983e-05, 0.9999211377567049, 4.174208521264503e-25, 7.846287437813656e-22, 4.2170514469753426e-17, 1.0506130623441227e-14, 2.7361031691389114e-12, 2.6517496440549442e-11, 5.867702690863609e-10] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
1.386294361117994
[2.018451366443605e-06, 1.0533454937850249e-05, 0.9999874480920181, 2.3306174855694322e-28, 1.104

3 : [ 0.00000000e+00 -5.68136343e-02 -7.90013081e-01 -1.85051677e+00
 -2.23126224e+00 -1.39074118e+00 -1.81863490e-01 -3.17995169e-03
  0.00000000e+00 -3.34863398e-01 -1.65404604e+00 -1.91680375e+00
 -1.75094802e+00 -1.58268037e+00 -5.94991938e-01 -5.84702202e-02
  0.00000000e+00 -3.83882574e-01 -1.69804846e+00 -1.35721589e+00
 -1.67368858e+00 -1.48413250e+00 -5.45116672e-01 -7.71239523e-02
  0.00000000e+00 -3.44674893e-01 -1.51181703e+00 -2.21341641e+00
 -2.42857151e+00 -1.50354391e+00 -4.38461719e-01 -8.44262642e-03
  0.00000000e+00 -1.69158263e-01 -1.22782398e+00 -2.03272932e+00
 -2.01530529e+00 -1.78915980e+00 -3.69362215e-01  0.00000000e+00
  0.00000000e+00 -6.06745217e-02 -6.73792192e-01 -1.14813044e+00
 -1.56555643e+00 -1.28545176e+00 -6.09882258e-01 -2.38507990e-03
  0.00000000e+00 -1.67152465e-01 -1.13532351e+00 -1.59656651e+00
 -1.51136280e+00 -1.26628030e+00 -6.39971719e-01 -8.74529295e-03
  0.00000000e+00 -6.18765305e-02 -9.81369708e-01 -1.97646344e+00
 -2.14394296e+00 -1.2


1 : [ 0.00000000e+00 -1.90188905e-02  6.64149286e-01  1.48866347e+00
  1.15369747e+00  1.37266199e-04 -1.08977929e-01 -2.44907931e-03
  0.00000000e+00  2.02982942e-02  1.94071415e+00  1.86284690e+00
  1.40369338e+00  1.55275725e+00 -1.43619985e-01 -3.23166088e-02
  1.90012935e-03  6.51245761e-01  2.13863058e+00  3.91861354e-01
 -6.27311375e-01  1.47148860e+00  5.16602087e-01 -4.28023504e-02
  9.50064673e-04  9.74484787e-01  1.87576789e+00 -9.10253093e-01
 -1.39672341e+00  8.61792502e-01  1.21946824e+00 -4.67668866e-03
  0.00000000e+00  1.18256899e+00  1.66632218e+00 -9.90267457e-01
 -1.18821416e+00  8.04679235e-01  1.33685868e+00  0.00000000e+00
  0.00000000e+00  7.79717856e-01  2.36009935e+00 -3.98387967e-01
 -5.97543306e-01  1.58953206e+00  9.20244300e-01 -1.51099157e-03
  0.00000000e+00  5.85304126e-02  2.18992533e+00  1.17751799e+00
  1.29651858e+00  2.05256761e+00  2.30229822e-01 -5.90076903e-03
  0.00000000e+00 -3.76509060e-02  4.14620376e-01  1.83138968e+00
  1.68227732e+00  4.


1 : [ 0.00000000e+00 -2.75168198e-02  4.05508333e-01  8.92198903e-01
  6.25189326e-01 -1.72156625e-01 -1.91248436e-01 -2.29529333e-02
  7.25014027e-05 -5.95927505e-03  1.22843457e+00  1.21514583e+00
  9.17866375e-01  1.10946668e+00 -1.35782347e-01 -4.12325352e-02
 -3.94793204e-05  4.65729032e-01  1.36627993e+00  5.78086533e-02
 -6.61576640e-01  1.14518253e+00  4.43003102e-01 -3.21100028e-02
 -4.03833316e-05  7.38013516e-01  1.02454091e+00 -1.15011436e+00
 -1.42206607e+00  5.83403185e-01  9.54270083e-01 -2.43916726e-03
  0.00000000e+00  9.20472669e-01  1.06450341e+00 -1.02158364e+00
 -1.18144749e+00  3.78330395e-01  1.16042490e+00  0.00000000e+00
  0.00000000e+00  6.45348554e-01  1.72380523e+00 -5.86109027e-01
 -7.38239277e-01  9.08209593e-01  7.74351828e-01 -7.91705951e-04
  0.00000000e+00  5.25095394e-02  1.58165827e+00  5.97695084e-01
  5.50243490e-01  1.40908033e+00  1.03181291e-01 -7.69067440e-03
  0.00000000e+00 -3.73939095e-02  2.11686210e-01  1.18934410e+00
  1.09100832e+00  2.

[1.520863469348181e-45, 3.332811775321965e-38, 1.0781328477495317e-31, 2.296358433997226e-25, 2.1036815399460593e-18, 3.4347677670251105e-13, 7.741503669600027e-07, 0.031210648453203736, 0.9687885773960858, 8.854589586824533e-32] [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[3.354059975847252e-48, 3.0114751657808244e-40, 5.274437800995714e-32, 1.446387592172649e-25, 5.941833290528377e-20, 2.00380595437287e-14, 9.160832094368626e-07, 0.037782237829157454, 0.962216846087613, 4.0739784145854074e-35] [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[3.951176121123587e-43, 2.7954731882801663e-36, 9.021637271106986e-30, 1.975419869837265e-23, 7.619517380740939e-18, 9.224651189951112e-13, 1.9339000428234166e-06, 0.05180758758007098, 0.9481904785189638, 1.8621633054667806e-31] [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[3.744714331172867e-46, 1.863099711524777e-38, 3.0765378288988045e-32, 6.020074187421028e-27, 1.64046710

[3.839588208817798e-53, 1.1231806366229816e-43, 7.393116257979214e-36, 3.1264926740663445e-29, 4.659728546127715e-22, 6.923582846549794e-16, 7.263312665478556e-08, 0.025745932763788056, 0.9742539946030846, 1.0744146667152735e-38] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
0.35357961501876956
[1.3949900886440454e-51, 1.3249257593987595e-42, 6.061456486572294e-35, 3.461823937580664e-28, 5.1945999445046236e-21, 2.7028680333870015e-15, 7.127512826448036e-08, 0.022782477689298464, 0.9772174510355706, 1.5781161782664707e-36] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
0.3489812257503927
[9.734687100359703e-48, 9.238982621659906e-40, 1.1381878480161468e-33, 2.1327159942862606e-28, 4.675480487129277e-20, 4.129566014329075e-15, 8.404426516203409e-07, 0.03845557616216051, 0.9615435833951836, 5.591608181984875e-35] [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1.3862943611198906
[1.4979915679053867e-45, 1.6894591582171927e-38, 7.554408749548525e-32, 3.573158788978129e-25, 7.45

0 : [-0.25667805 -0.21972047 -0.24311186 -0.34849525 -0.17450867 -0.06076851
  0.05852717 -0.31964899 -0.07499467 -0.10651703 -0.11222047 -0.09664
  0.10334853]

1 : [-0.2506738  -0.2166387  -0.23811489 -0.34208844 -0.17031477 -0.05755096
  0.06008119 -0.31429313 -0.07155673 -0.10609133 -0.10778085 -0.09142874
  0.10273357]

2 : [-0.25009129 -0.2141045  -0.23696247 -0.33969247 -0.17008857 -0.05913948
  0.05698827 -0.31164983 -0.07288288 -0.1040966  -0.10914491 -0.09368187
  0.10102895]

Grad descent step 18
0 : [-0.24302881 -0.22112356 -0.23176773 -0.34751416 -0.16201754 -0.04548036
  0.07351258 -0.31679474 -0.0659174  -0.10073502 -0.10017498 -0.07661466
  0.12467854]

1 : [-0.23981002 -0.21963992 -0.22861448 -0.34253136 -0.15999859 -0.04382609
  0.07426884 -0.31269139 -0.06399074 -0.10069013 -0.09742934 -0.07410358
  0.12267165]

2 : [-0.2393895  -0.21645478 -0.22814377 -0.34052158 -0.1599265  -0.04604252
  0.07053012 -0.31061487 -0.06544478 -0.09925835 -0.09929354 -0.07696492
  0.120

1.12264682959737
[0.34146711435441307, 0.3334188082465569, 0.3251140773990299] [0.0, 1.0, 0.0]
1.1226543533688156
[0.34088470337938076, 0.33330483654298426, 0.32581046007763503] [0.0, 1.0, 0.0]
1.122757236789477
[0.3414596851111982, 0.33334118534028784, 0.32519912954851393] [0.0, 1.0, 0.0]
1.1227244254309312
[0.3410598737358026, 0.3332603766697205, 0.3256797495944768] [0.0, 1.0, 0.0]
1.1227973684950108
[0.34118742610464875, 0.33337331066284326, 0.325439263232508] [0.0, 1.0, 0.0]
1.1226954256858825
[0.34044274992551254, 0.3332624393508873, 0.32629481072360017] [0.0, 1.0, 0.0]
1.1227955066490347
[0.3405641305896382, 0.3332431819587262, 0.32619268745163565] [0.0, 1.0, 0.0]
1.1228128888985134
[0.3402489603416236, 0.3332088366910516, 0.3265422029673248] [0.0, 1.0, 0.0]
1.1228438891782138
[0.3415176748438973, 0.3334373798780375, 0.32504494527806516] [0.0, 1.0, 0.0]
1.1226375876321906
[0.3415690787104311, 0.3334274315235876, 0.3250034897659813] [0.0, 1.0, 0.0]
1.1226465686469216
[0.3431157077


1 : [0.21721417 0.07704717 0.1740811  0.09205708 0.15742341 0.19677516
 0.17872062 0.07064402 0.13695577 0.12854142 0.12838604 0.2007821
 0.22115315]

2 : [0.20834906 0.07379116 0.16670191 0.08622518 0.15095504 0.1896095
 0.1727356  0.06614449 0.13160329 0.12492999 0.12211521 0.19268003
 0.21515602]

Grad descent step 7
0 : [0.23047637 0.07566386 0.18122581 0.08768491 0.16330687 0.21588012
 0.20092434 0.06558555 0.14872922 0.13492114 0.13725476 0.21874554
 0.24708629]

1 : [0.22934252 0.07461712 0.18069569 0.08844223 0.16274035 0.21443756
 0.19941341 0.06593964 0.14781709 0.13282886 0.13780368 0.21821039
 0.24404378]

2 : [0.22412527 0.07412406 0.17651613 0.08595379 0.15913726 0.2093802
 0.19450242 0.06445772 0.14434771 0.13151227 0.13326528 0.21220638
 0.23947309]

Grad descent step 8
0 : [0.22933974 0.07174996 0.17837029 0.08095327 0.16334923 0.21685536
 0.20947341 0.05369319 0.15090681 0.14114982 0.1365675  0.2225318
 0.26379287]

1 : [0.23342841 0.07323778 0.18201561 0.08552027 0.