In [7]:
from src import Value
from src import MLP
from src import training_data_points, training_targets, testing_data_points, testing_targets

In [3]:
# Define a Model
# The input size is taken from the first training sample so the model adapts
# to however many features the dataset provides.
number_of_inputs = len(training_data_points[0])
# One entry per layer: presumably two hidden layers of 4 neurons and a single
# output neuron (the training/testing code reads only mlp(x)[0]).
number_of_outputs_for_each_layer = [4, 4, 1]

# NOTE(review): MLP weight initialisation is presumably random; seed the RNG
# that src.MLP uses before this line so Restart & Run All is reproducible.
mlp = MLP(number_of_inputs, number_of_outputs_for_each_layer)

In [4]:
# Train a Model
# Full-batch gradient descent on the summed squared error over the training set.
number_of_epochs = 100
learning_rate = 0.001  # step size used for every parameter update below
for epoch in range(number_of_epochs):
    # Forward pass: the model returns a list of outputs; only the first is used.
    actual_ys = [mlp(x)[0] for x in training_data_points]
    loss: Value = sum(
        (expected_y - actual_y) ** 2
        for expected_y, actual_y in zip(training_targets, actual_ys)
    )
    print("loss: ", loss)

    # Fetch the parameter list once per epoch and reuse it for both the
    # gradient reset and the update step.
    parameters = mlp.parameters()

    # Zero grad: gradients must be cleared before backward() so the values
    # from the previous epoch do not carry over.
    for p in parameters:
        p.grad = 0.0
    # Backward pass
    loss.backward()

    # Gradient-descent step: move each parameter against its gradient.
    for p in parameters:
        p.data -= learning_rate * p.grad

loss:  Value(data=21.991445789392838)
loss:  Value(data=19.602365518489552)
loss:  Value(data=18.29554598198281)
loss:  Value(data=17.141987944638704)
loss:  Value(data=16.0847879478177)
loss:  Value(data=15.113943099348274)
loss:  Value(data=14.223179327269511)
loss:  Value(data=13.405912527006189)
loss:  Value(data=12.655447357386983)
loss:  Value(data=11.965423291282784)


In [6]:
# Test a Model
def test(mlp: MLP, testing_data_points, testing_targets):
    """Test if model guess iris specie correctly
    """
    results = [mlp(x)[0] for x in testing_data_points]
    
    actual = [0.0 if abs(result.data) < 0.33 else 0.5 if abs(result.data) < 0.66 else 1.0 for result in results]
    
    accuracy = (sum([1 if actual[i] == testing_targets[i] else 0 for i in range(len(testing_targets)) ]) / len(testing_targets)) * 100 

    for i, result in enumerate(results):
        print(f"#{i} | Target: {testing_targets[i]} | Actual: {actual[i]}")
    print(f"Total accuracy: {accuracy}")
        
# Evaluate the trained model on the held-out set; the trailing ';' keeps
# Jupyter from echoing a return value below the printed report.
test(mlp=mlp, testing_data_points=testing_data_points, testing_targets=testing_targets);

#0 | Target: 1.0 | Actual: 0.5
#1 | Target: 0.0 | Actual: 0.0
#2 | Target: 0.5 | Actual: 0.5
#3 | Target: 0.5 | Actual: 0.5
#4 | Target: 0.0 | Actual: 0.0
#5 | Target: 1.0 | Actual: 1.0
#6 | Target: 0.5 | Actual: 0.5
#7 | Target: 0.0 | Actual: 0.5
#8 | Target: 0.5 | Actual: 0.5
#9 | Target: 1.0 | Actual: 0.5
#10 | Target: 0.5 | Actual: 0.5
#11 | Target: 1.0 | Actual: 0.5
#12 | Target: 0.5 | Actual: 0.5
#13 | Target: 1.0 | Actual: 0.5
#14 | Target: 1.0 | Actual: 0.5
#15 | Target: 0.5 | Actual: 0.5
#16 | Target: 0.5 | Actual: 0.5
#17 | Target: 0.0 | Actual: 0.5
#18 | Target: 0.0 | Actual: 0.5
#19 | Target: 0.0 | Actual: 0.5
#20 | Target: 1.0 | Actual: 1.0
#21 | Target: 0.0 | Actual: 0.0
#22 | Target: 1.0 | Actual: 0.5
#23 | Target: 0.0 | Actual: 0.5
#24 | Target: 0.0 | Actual: 0.5
#25 | Target: 0.5 | Actual: 0.5
#26 | Target: 0.5 | Actual: 0.5
#27 | Target: 0.5 | Actual: 0.5
#28 | Target: 0.0 | Actual: 0.5
#29 | Target: 0.0 | Actual: 0.0
Total accuracy: 56.666666666666664
