In [1]:
from dataclasses import dataclass
from PIL import Image
import numpy as np

In [10]:
def parse_image(filename):
    img_arr = np.array(Image.open(filename, 'r').convert('L').getdata())
    return np.where(img_arr == 255, -1, 1)

@dataclass
class ImageEntity:
    array: np.ndarray
    weights: np.array
    input_signal: int = 1


def new_entity_container(image, input_signal=1):
    x0 = 1
    return ImageEntity(
        array=np.concatenate(([x0], image)),
        weights=np.zeros(len(image) + 1),
        input_signal=input_signal)


def sum_output_signal(x, weights):
    return np.sum(x * weights) + weights[0]


def train(container, learning_rate=1):
    i = 0
    while True:
        i += 1
        container.weights = container.weights * learning_rate + container.array * container.input_signal
        S = sum_output_signal(container.array, container.weights)
        if S < 0 if container.input_signal == -1 else S > 0:
            return i


def identify(image, trained):
    image = np.concatenate(([1], image))
    result = dict()
    for letter in trained:
        s = sum_output_signal(image, trained[letter].weights)
        result[letter] = s
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))

In [11]:
LETTERS_DIR = 'img/'

images = dict(
    a_valid=parse_image(LETTERS_DIR + 'a_valid.bmp'),
    a_invalid=parse_image(LETTERS_DIR + 'a_invalid.bmp'),
    d_valid=parse_image(LETTERS_DIR + 'd_valid.bmp'),
    d_invalid=parse_image(LETTERS_DIR + 'd_invalid.bmp'),
    x_valid=parse_image(LETTERS_DIR + 'x_valid.bmp'),
    x_invalid=parse_image(LETTERS_DIR + 'x_invalid.bmp'),
    o_valid=parse_image(LETTERS_DIR + 'o_valid.bmp'),
    o_invalid=parse_image(LETTERS_DIR + 'o_invalid.bmp'), )

TRAINED = dict(
    a=new_entity_container(image=images['a_valid']),
    d=new_entity_container(image=images['d_valid']),
    x=new_entity_container(image=images['x_valid']),
    o=new_entity_container(image=images['o_valid']), )

for letter in TRAINED:
    train(TRAINED[letter])

for key in images:
    letters = identify(image=images[key], trained=TRAINED)
    print(f'{key} (weights):')
    for letter in letters:
        print(f'\t{letter} = {letters[letter]}')
    print()

a_valid (weights):
	a = 2502.0
	d = 1082.0
	x = 756.0
	o = -284.0

a_invalid (weights):
	a = 1790.0
	d = 926.0
	x = 432.0
	o = -244.0

d_valid (weights):
	d = 2502.0
	a = 1082.0
	x = 812.0
	o = 160.0

d_invalid (weights):
	a = 1642.0
	d = 1314.0
	x = 668.0
	o = -32.0

x_valid (weights):
	x = 2502.0
	d = 812.0
	a = 756.0
	o = -254.0

x_invalid (weights):
	x = 1570.0
	d = 1112.0
	a = 932.0
	o = -10.0

o_valid (weights):
	o = 2502.0
	d = 160.0
	x = -254.0
	a = -284.0

o_invalid (weights):
	o = 1140.0
	d = 710.0
	a = 146.0
	x = 76.0

