In [1]:
import pandas as pd

# Load the tab-separated toy dataset (feature columns x1, x2 and a label column)
# and display it; read_table is read_csv specialised for delimited tables.
df = pd.read_table("perceptron_toydata-truncated.txt", sep="\t")
df

Unnamed: 0,x1,x2,label
0,0.77,-1.14,0
1,-0.33,1.44,0
2,0.91,-3.07,0
3,-0.37,-1.91,0
4,-0.63,-1.53,0
5,0.39,-1.99,0
6,-0.49,-2.74,0
7,-0.68,-1.52,0
8,-0.1,-3.43,0
9,-0.05,-1.95,0


In [2]:
# Feature matrix (one row per sample, columns x1 and x2) and label vector.
# .to_numpy() is pandas' recommended accessor: it always returns a plain NumPy
# ndarray, whereas .values may hand back an ExtensionArray for extension dtypes,
# which torch.tensor cannot consume.
X_train = df[["x1", "x2"]].to_numpy()
y_train = df["label"].to_numpy()

In [3]:
import torch

# Convert both NumPy arrays to float32 tensors. Features and labels share one
# dtype, and float labels compare directly with the perceptron's 0./1. outputs.
X_train, y_train = (
    torch.tensor(array, dtype=torch.float32) for array in (X_train, y_train)
)

In [4]:
class Perceptron:
    """Binary perceptron classifier with a step activation.

    Predicts 1.0 when ``x @ weights + bias > 0`` and 0.0 otherwise. ``update``
    applies the classic perceptron learning rule (unit learning rate) to one
    sample at a time.
    """

    def __init__(self, num_features):
        """Create a model whose weights and bias all start at zero.

        Args:
            num_features: number of input features per sample.
        """
        self.weights = torch.zeros(num_features)
        self.bias = torch.tensor(0.)

    def forward(self, x):
        """Return the thresholded prediction(s) for ``x``.

        Args:
            x: a 1-D tensor with ``num_features`` entries (one sample), or a
               2-D tensor of shape ``(n_samples, num_features)`` (a batch).
               Its dtype must match ``self.weights``.

        Returns:
            A 0-dim tensor for a single sample, or a 1-D tensor of length
            ``n_samples`` for a batch, holding 1.0 where the weighted sum is
            strictly positive and 0.0 elsewhere.
        """
        # `@` is a dot product for 1-D `x` and a matrix-vector product for 2-D
        # `x`, so the same expression scores one sample or a whole batch.
        weighted_sum = x @ self.weights + self.bias
        # Strict `> 0`: a weighted sum of exactly 0 (as with a freshly
        # initialised model) predicts class 0.
        return torch.where(weighted_sum > 0.0, torch.tensor(1.0), torch.tensor(0.0))

    def update(self, x, y):
        """Apply one perceptron learning step for a single sample.

        Args:
            x: 1-D feature tensor for one sample.
            y: target label, e.g. a 0-dim tensor holding 0.0 or 1.0.

        Returns:
            The error ``y - prediction``. For labels in {0, 1} it is 0 when
            the sample was already classified correctly, otherwise +1 or -1.
        """
        y_pred = self.forward(x)
        error = y - y_pred
        # A zero error leaves the parameters unchanged; `+=` mutates the
        # existing weight/bias tensors in place.
        self.bias += error
        self.weights += error * x
        return error

In [5]:
def compute_accuracy(mode1, all_X, all_y):
    """Return the fraction of samples that the model classifies correctly.

    Args:
        mode1: the model (name kept as-is for keyword callers); must provide
            ``forward(x)`` returning a prediction comparable to a label via ``==``.
        all_X: iterable of feature vectors.
        all_y: labels aligned with ``all_X``; its length is the denominator.

    Returns:
        Accuracy as a float. Raises ZeroDivisionError when ``all_y`` is empty.
    """
    hits = sum(
        1
        for features, target in zip(all_X, all_y)
        if mode1.forward(features) == target
    )
    return float(hits) / len(all_y)



In [6]:
# Fresh perceptron for the two input features x1 and x2 (weights and bias start at zero).
ppn = Perceptron(num_features=2)


In [10]:
# Sanity-check forward() on a single two-feature sample.
# NOTE(review): for a freshly constructed Perceptron (zero weights and bias) the
# weighted sum is exactly 0, which is not > 0, so this returns tensor(0.). The
# saved output tensor(1.) suggests `ppn` had already been updated by cells that
# are not shown (execution counts jump from 6 to 10); re-run from a fresh kernel.
x= torch.tensor([1.1, 2.1])
ppn.forward(x)

tensor(1.)

In [12]:
# One learning step on `x` with target label 1.0; the displayed value is the error y - prediction.
# NOTE(review): on a fresh model forward(x) is tensor(0.), so this would return
# tensor(1.) and shift the weights by +x; the saved tensor(0.) output likewise
# reflects state left over from earlier, unseen executions.
ppn.update(x, y=torch.tensor(1.0))

tensor(0.)

In [13]:
def train(model, all_X, all_y, num_epochs):
    """Train ``model`` online, one sample at a time, for ``num_epochs`` passes.

    Each epoch walks the (x, y) pairs of ``all_X``/``all_y`` in order (``zip``
    stops at the shorter of the two) and calls ``model.update(x, y)``. The sum
    of absolute errors for the epoch is printed as ``Epoch: <n>, Error: <total>``.
    For 0/1 labels and 0/1 predictions, that sum is the number of
    misclassified samples.

    Args:
        model: object whose ``update(x, y)`` returns a scalar error (a number
            or a single-element tensor).
        all_X: iterable of feature vectors.
        all_y: iterable of labels aligned with ``all_X``.
        num_epochs: number of full passes over the data.

    Returns:
        A list with one float per epoch holding that epoch's error total, so
        callers can inspect or plot convergence.
    """
    history = []
    for epoch in range(num_epochs):
        error_count = 0.0
        for x, y in zip(all_X, all_y):
            error = model.update(x, y)
            # float() turns a 0-dim tensor error into a plain number, so the
            # running total stays a Python float instead of becoming a tensor.
            error_count += float(abs(error))
        print(f"Epoch: {epoch+1}, Error: {error_count}")
        history.append(error_count)
    return history
        

In [14]:
# Re-create the model so training starts from zero weights and bias, not from
# whatever state the forward/update demo cells above left behind.
ppn = Perceptron(num_features=2)

In [15]:
# Ten online passes over the training data. Each epoch prints the sum of
# absolute errors (the misclassification count for 0/1 labels).
train(ppn, X_train, y_train, num_epochs=10)

Epoch: 1, Error: 1.0
Epoch: 2, Error: 3.0
Epoch: 3, Error: 1.0
Epoch: 4, Error: 0.0
Epoch: 5, Error: 0.0
Epoch: 6, Error: 0.0
Epoch: 7, Error: 0.0
Epoch: 8, Error: 0.0
Epoch: 9, Error: 0.0
Epoch: 10, Error: 0.0
