# Perceptron Algorithm Python Class

In [2]:
import numpy as np
from numpy.random import seed
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib import rcParams
# Default figure size (width, height) in inches for every matplotlib plot in this notebook.
rcParams["figure.figsize"] = 10,5
# IPython/Jupyter magic: render figures inline below each cell (not valid in a plain .py script).
%matplotlib inline

In [3]:
class Perceptron(object):
    """
    Perceptron classifier (Rosenblatt's perceptron learning rule).

    Parameters
    ----------
    eta : float
        Learning rate (between 0.0 and 1.0)
    n_iter : int
        Epochs (full passes) over the training set

    Attributes
    ----------
    w_ : 1d-array
        Weights after fitting; ``w_[0]`` is the bias term and ``w_[1:]``
        holds one weight per feature.
    errors_ : list
        Number of misclassifications (non-zero updates) in every epoch
    """
    def __init__(self, eta=0.01, n_iter=10):
        self.eta = eta
        # Stored as ``n_iter`` because ``fit`` reads ``self.n_iter``.
        self.n_iter = n_iter

    def fit(self, X, y):
        """
        Fit the weights to the training data.

        Parameters
        ----------
        X : {array-like}, shape = [n_samples, n_features]
            Training vectors, where 'n_samples' is the number
            of samples and 'n_features' is the number of
            features.
        y : {array-like}, shape = [n_samples]
            Target values; should be the labels 1 / -1 produced by
            ``predict``.

        Returns
        -------
        self : object
        """
        # Convert array-likes (lists, DataFrames, Series) to ndarrays so that
        # ``X.shape`` works and ``zip`` walks over rows rather than column labels.
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.w_ = np.zeros(1 + X.shape[1])
        self.errors_ = []

        for _ in range(self.n_iter):
            errors = 0
            for xi, target in zip(X, y):
                # Zero when the sample is already classified correctly;
                # otherwise (with labels 1 / -1) +/- 2*eta, which moves the
                # decision boundary toward the correct side for this sample.
                update = self.eta * (target - self.predict(xi))
                self.w_[1:] += update * xi
                self.w_[0] += update
                errors += int(update != 0.0)
            self.errors_.append(errors)
        return self

    def net_input(self, X):
        """
        Calculate the net input ``X . w_[1:] + w_[0]``.
        """
        # The bias is added once after the dot product, not to every weight.
        return np.dot(X, self.w_[1:]) + self.w_[0]

    def predict(self, X):
        """
        Return class label after unit step: 1 where the net input is
        >= 0.0, otherwise -1.
        """
        return np.where(self.net_input(X) >= 0.0, 1, -1)

## Iris dataset

In [5]:
# Location of the Iris data set in the UCI Machine Learning Repository.
csv_name = (
    "https://archive.ics.uci.edu/ml/"
    "machine-learning-databases/iris/iris.data"
)
# header=None: treat the first line as data, so pandas labels the columns 0..n-1.
header = None
# Download and parse the CSV into a pandas DataFrame.
df = pd.read_csv(csv_name, header=header)

URLError: <urlopen error EOF occurred in violation of protocol (_ssl.c:841)>