In [1]:
import numpy as np

In [78]:
class Perceptron():
    """Binary perceptron classifier that predicts the labels 1 and -1.

    Parameters
    ----------
    thresholds : float
        Value the net input must strictly exceed for the prediction to be 1.
    eta : float
        Learning rate; scales every weight correction.
    n_iter : int
        Number of full passes over the training samples.

    Attributes set by ``fit``
    -------------------------
    w_ : 1-D numpy array; ``w_[0]`` is the bias, ``w_[1:]`` the feature weights.
    errors_ : list holding the count of non-zero updates made in each pass.
    """

    def __init__(self, thresholds=0.0, eta=0.01, n_iter=10):
        self.thresholds, self.eta, self.n_iter = thresholds, eta, n_iter

    def fit(self, X, y):
        """Train on samples ``X`` (2-D array, one row per sample) and labels ``y``.

        Weights start at zero. After every pass the current weight vector is
        printed and the number of non-zero updates is appended to ``errors_``.
        Returns ``self`` so calls can be chained.
        """
        n_features = X.shape[1]
        self.w_ = np.zeros(n_features + 1)
        self.errors_ = []

        for _epoch in range(self.n_iter):
            n_updates = 0
            for sample, label in zip(X, y):
                # The step is zero whenever the sample is already classified
                # correctly, so only mistakes move the weights.
                step = self.eta * (label - self.predict(sample))
                self.w_[0] += step
                self.w_[1:] += step * sample
                n_updates += int(step != 0.0)
            self.errors_.append(n_updates)
            print(self.w_)

        return self

    def net_input(self, X):
        """Return the weighted sum of the features in ``X`` plus the bias."""
        bias = self.w_[0]
        feature_weights = self.w_[1:]
        return np.dot(X, feature_weights) + bias

    def predict(self, X):
        """Return 1 where the net input exceeds ``thresholds``, otherwise -1."""
        activation = self.net_input(X)
        return np.where(activation > self.thresholds, 1, -1)

In [79]:
# Truth table of a logical AND, with the two classes encoded as -1 / 1.
X = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])
y = np.array([-1, -1, -1, 1])

# fit() returns the estimator itself, so construction and training can chain.
ppn = Perceptron(thresholds=0, eta=0.01, n_iter=10).fit(X, y)
print(ppn.errors_)

[0.02 0.02 0.02]
[0.   0.04 0.02]
[-0.02  0.04  0.02]
[-0.02  0.04  0.04]
[-0.04  0.04  0.02]
[-0.04  0.04  0.02]
[-0.04  0.04  0.02]
[-0.04  0.04  0.02]
[-0.04  0.04  0.02]
[-0.04  0.04  0.02]
[1, 3, 3, 2, 1, 0, 0, 0, 0, 0]


In [80]:
# Walk through the net-input computation step by step: the learned weights,
# the inputs, the weighted sum without bias, the sum with bias added by hand,
# and finally the same quantity as computed by the model itself.
separator = '+' * 20
feature_weights, bias = ppn.w_[1:], ppn.w_[0]
weighted_sum = np.dot(X, feature_weights)

for shown in (ppn.w_, X, weighted_sum, weighted_sum + bias, ppn.net_input(X)):
    print(shown)
    print(separator)

[-0.04  0.04  0.02]
++++++++++++++++++++
[[0 0]
 [0 1]
 [1 0]
 [1 1]]
++++++++++++++++++++
[0.   0.02 0.04 0.06]
++++++++++++++++++++
[-0.04 -0.02  0.    0.02]
++++++++++++++++++++
[-0.04 -0.02  0.    0.02]
++++++++++++++++++++


In [81]:
# Show the boolean mask of positive net inputs, then the labels it maps to.
separator = '+' * 20
is_positive = ppn.net_input(X) > 0
print(is_positive)
print(separator)
print(np.where(is_positive, 1, -1))
print(separator)

[False False False  True]
++++++++++++++++++++
[-1 -1 -1  1]
++++++++++++++++++++


In [82]:
# Replay one inner-loop step of Perceptron.fit by hand for the first sample.
# NOTE: this writes to the already-trained ppn.w_ in place; it leaves the
# weights unchanged only while `update` is 0.0 (sample classified correctly).
errors = 0
xi, target = X[0], y[0]
print(xi, target)
print('+'*20)
# Learning-rule step: eta * (true label - predicted label).
update = ppn.eta * (target-ppn.predict(xi))
print(update)
print('+'*20)
ppn.w_[1:] += update*xi
# The bias must be updated in place, exactly as fit() does; a bare
# `ppn.w_[0] + update` expression would compute the sum and discard it.
ppn.w_[0] += update
print(ppn.w_[0], ppn.w_[1:])
print('+'*20)
# A non-zero update means the sample was misclassified.
errors += int(update!=0.0)
print(errors)
print('+'*20)

[0 0] -1
++++++++++++++++++++
0.0
++++++++++++++++++++
-0.04 [0.04 0.02]
++++++++++++++++++++
0
++++++++++++++++++++
