In [11]:
# generate data
# list of points 
import numpy as np 
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
np.random.seed(2)

means = [[2, 2], [4, 2]]
cov = [[.3, .2], [.2, .3]]
N = 10
X0 = np.random.multivariate_normal(means[0], cov, N).T
X1 = np.random.multivariate_normal(means[1], cov, N).T
print(X0)
print(X1)

[[2.22096057 2.70132234 3.08493823 2.02701417 2.73223639 1.21171968
  2.22920603 1.8637762  1.74682699 2.37191737]
 [2.19579728 3.43487375 2.70849736 1.47010441 2.32571583 2.23682627
  1.72925457 1.59716548 2.27230351 2.37595358]]
[[4.47403369 4.09281249 4.22222334 4.58438569 4.74493118 3.6355797
  5.19217738 3.51075436 3.93784332 3.8787214 ]
 [2.4040742  1.65061706 2.11659863 2.05326933 2.67628604 2.63347726
  3.2425902  2.11880111 1.56029947 2.12126884]]


In [12]:
def h(w, x):    
    return np.sign(np.dot(w.T, x))

In [13]:
def has_converged(X, y, w):    
    return np.array_equal(h(w, X), y) 

In [14]:
def perceptron(X, y, w_init):
    w = [w_init]
    N = X.shape[1]
    d = X.shape[0]
    mis_points = []
    while True:
        # mix data 
        mix_id = np.random.permutation(N)
        for i in range(N):
            xi = X[:, mix_id[i]].reshape(d, 1)
            yi = y[0, mix_id[i]]
            if h(w[-1], xi)[0] != yi: # misclassified point
                mis_points.append(mix_id[i])
                w_new = w[-1] + yi*xi 
                w.append(w_new)
                
        if has_converged(X, y, w[-1]):
            break
    return (w, mis_points)

In [15]:
d = X.shape[0]
w_init = np.random.randn(d, 1)
(w, m) = perceptron(X, y, w_init)
print(w)

[array([[-0.3135082 ],
       [ 0.77101174],
       [-1.86809065]]), array([[ 0.6864918 ],
       [ 2.63478794],
       [-0.27092518]]), array([[-0.3135082 ],
       [-1.83924575],
       [-2.67499938]]), array([[ 0.6864918 ],
       [ 0.38171482],
       [-0.4792021 ]]), array([[-0.3135082 ],
       [-3.5561285 ],
       [-2.03950157]]), array([[ 0.6864918 ],
       [-1.52911433],
       [-0.56939716]]), array([[1.6864918 ],
       [0.7000917 ],
       [1.15985741]]), array([[ 0.6864918 ],
       [-4.04483948],
       [-1.51642863]]), array([[ 1.6864918 ],
       [-1.67292211],
       [ 0.85952495]]), array([[2.6864918 ],
       [0.19085409],
       [2.45669043]]), array([[ 1.6864918 ],
       [-3.31990027],
       [ 0.33788932]]), array([[ 2.6864918],
       [-1.0989397],
       [ 2.5336866]]), array([[ 1.6864918 ],
       [-4.9776611 ],
       [ 0.41241776]]), array([[ 2.6864918 ],
       [-2.95064692],
       [ 1.88252217]]), array([[3.6864918 ],
       [0.1342913 ],
       [4.5910