# Scikit-Learn course alapján
#### https://youtu.be/pqNCD_5r0IU?t=9151

# In [2]:
import mnist
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import confusion_matrix
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os

# In [3]:
# Load the MNIST training and test splits: images plus their digit labels.
# The tuple right-hand sides evaluate left to right, so the loader calls run
# in the same order as before.
xtrain, ytrain = mnist.train_images(), mnist.train_labels()
xtest, ytest = mnist.test_images(), mnist.test_labels()

# In [4]:
# Flatten every 28x28 image into one row of 28*28 pixel features so each
# split becomes a 2-D (n_samples, n_features) array for the classifier.
xtrain = xtrain.reshape(-1, 28 * 28)
xtest = xtest.reshape(-1, 28 * 28)

# In [5]:
# Scale pixel intensities down by 256 (presumably raw values are 0-255, giving
# floats in [0, 1)). The same divisor is used for the inference inputs later
# in this notebook, so it must stay 256 for consistency.
xtrain = xtrain / 256
xtest = xtest / 256

# In [6]:
# Multi-layer perceptron: two hidden layers of 64 ReLU units each, trained
# with the Adam optimizer on the flattened, scaled training images.
clf = MLPClassifier(
    hidden_layer_sizes=(64, 64),
    activation='relu',
    solver='adam',
)
clf.fit(xtrain, ytrain)

# In [7]:
# Predicted label for every flattened test image.
prediction = clf.predict(X=xtest)

# In [8]:
# NOTE: despite its name, `acc` holds the confusion matrix of true vs.
# predicted test labels, not an accuracy value.
acc = confusion_matrix(y_true=ytest, y_pred=prediction)

# In [9]:
def accuracy(cm):
    """Return the overall accuracy encoded by a confusion matrix.

    Accuracy is the number of correctly classified samples (the main
    diagonal) divided by the total number of samples in the matrix.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix, e.g. the output of
        ``sklearn.metrics.confusion_matrix``. Nested lists are accepted too.

    Returns
    -------
    numpy.float64
        ``trace(cm) / sum(cm)``.

    Raises
    ------
    ValueError
        If ``cm`` is not a square 2-D matrix or contains no samples.
    """
    # Accept plain nested lists as well as ndarrays.
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(
            f"expected a square 2-D confusion matrix, got shape {cm.shape}"
        )
    total = cm.sum()
    if total == 0:
        # Without this check 0/0 would silently produce nan (plus a RuntimeWarning).
        raise ValueError("confusion matrix contains no samples")
    return cm.trace() / total

# Report overall test-set accuracy (share of samples on the confusion-matrix diagonal).
print(accuracy(cm=acc))

# In [10]:
def ImagetoBin(img):
    """Turn a PIL image into a flat feature vector for the trained classifier.

    Pixels are read in ``getdata()`` order, inverted (``255 - v``) and divided
    by 256.0, matching the divisor used to scale the training data.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image. Any mode is accepted; non-grayscale images are first
        converted to 8-bit grayscale ('L'). The image is NOT resized, so it
        must already be 28x28 to match the 28*28 features the model was
        trained on.

    Returns
    -------
    numpy.ndarray
        1-D float64 array with one value per pixel.
    """
    # Convert RGB/RGBA/palette images to 8-bit grayscale first. Otherwise
    # getdata() yields tuples (so `255 - pixel` raises TypeError) or palette
    # indices instead of intensities.
    if img.mode != "L":
        img = img.convert("L")
    pixels = np.asarray(list(img.getdata()), dtype=np.float64)
    # Invert intensities. NOTE(review): presumably the input drawings are
    # dark-on-light while the training digits are light-on-dark; confirm.
    return (255.0 - pixels) / 256.0

# In [13]:
image_folder = "data"
images = []
predictions = []

# Load every PNG in the folder and classify it with the trained network.
# os.listdir() returns entries in arbitrary order; sorting makes the
# processing and plot order deterministic.
for filename in sorted(os.listdir(image_folder)):
    # Case-insensitive so that e.g. "digit.PNG" is not silently skipped.
    if not filename.lower().endswith(".png"):
        continue
    file_path = os.path.join(image_folder, filename)
    # Open inside a context manager so the file handle is released. copy()
    # forces the pixel data to load and detaches it from the closed file.
    with Image.open(file_path) as img_file:
        image = img_file.copy()
    binimg = ImagetoBin(image)
    p = clf.predict([binimg])
    images.append(image)
    predictions.append(p)

# Lay the results out on a 5-column grid whose row count follows the number
# of images. This replaces the hard-coded 2x5 grid, which raised IndexError
# for fewer than 10 images and dropped any beyond 10.
n_cols = 5
n_rows = max(1, (len(images) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
)

for i, ax in enumerate(axes.flat):
    if i >= len(images):
        ax.axis("off")  # hide grid cells that have no image
        continue
    ax.imshow(images[i], cmap='gray')
    # predict() returns a length-1 array for a single sample; show the label
    # itself rather than its array repr (e.g. "7" instead of "[7]").
    ax.set_title("Prediction: {}".format(predictions[i][0]))

plt.tight_layout()
plt.show()

# In [12]:
# Hand-entered grayscale pixel sample, presumably a drawing of the digit 0
# (per the variable name), flattened to one value per pixel.
# NOTE(review): it must hold 28*28 values to match the flattened training
# features. Its background sits around 100 rather than near 0, and unlike
# ImagetoBin() output it is not inverted; confirm this is intended.
zero = [100, 100, 100, 100, 101, 102, 101, 101, 101, 103, 102, 103, 101, 101, 101, 100, 99, 98, 99, 99, 100, 99, 99, 96, 99, 97, 99, 100, 100, 101, 100, 100, 101, 102, 102, 102, 101, 102, 101, 99, 100, 101, 101, 100, 98, 97, 98, 100, 98, 100, 98, 97, 97, 100, 100, 99, 100, 100, 100, 101, 101, 102, 103, 102, 102, 111, 155, 194, 166, 197, 161, 118, 100, 97, 97, 100, 100, 99, 97, 97, 96, 100, 100, 99, 99, 100, 101, 101, 101, 102, 103, 102, 120, 191, 121, 95, 99, 91, 113, 175, 127, 101, 99, 98, 99, 97, 99, 98, 97, 100, 100, 100, 99, 101, 102, 101, 101, 101, 102, 110, 186, 104, 100, 102, 102, 101, 101, 113, 185, 107, 100, 99, 100, 97, 97, 97, 96, 99, 101, 98, 98, 100, 101, 101, 101, 102, 100, 152, 125, 100, 100, 100, 99, 101, 102, 151, 209, 148, 105, 99, 99, 98, 98, 97, 97, 98, 97, 98, 99, 98, 99, 100, 101, 100, 106, 180, 102, 99, 99, 99, 100, 101, 100, 99, 116, 182, 170, 107, 99, 98, 97, 96, 97, 97, 99, 101, 97, 98, 99, 100, 101, 101, 135, 130, 100, 100, 98, 99, 99, 101, 101, 102, 99, 113, 110, 171, 108, 99, 97, 96, 97, 98, 99, 100, 97, 98, 99, 100, 101, 100, 165, 108, 103, 99, 99, 100, 100, 101, 102, 102, 103, 104, 102, 122, 141, 99, 98, 97, 98, 98, 99, 99, 98, 99, 100, 100, 102, 103, 167, 101, 100, 100, 98, 98, 100, 99, 99, 101, 101, 102, 103, 100, 163, 105, 100, 97, 99, 99, 99, 99, 98, 98, 99, 100, 102, 117, 148, 101, 99, 100, 98, 99, 101, 99, 100, 100, 102, 101, 103, 102, 143, 118, 100, 99, 98, 98, 100, 100, 97, 99, 100, 101, 102, 128, 136, 101, 98, 98, 99, 100, 98, 99, 99, 102, 102, 102, 104, 104, 127, 129, 100, 99, 99, 98, 99, 101, 96, 100, 101, 102, 102, 130, 136, 100, 99, 100, 98, 99, 99, 99, 100, 102, 102, 103, 103, 103, 126, 134, 100, 99, 98, 98, 99, 100, 98, 100, 101, 102, 103, 130, 135, 100, 100, 99, 99, 99, 99, 100, 101, 101, 102, 103, 103, 103, 143, 123, 100, 97, 98, 98, 99, 100, 100, 100, 101, 101, 103, 125, 141, 99, 98, 98, 98, 99, 100, 101, 101, 101, 102, 103, 103, 101, 168, 104, 99, 98, 98, 99, 99, 101, 100, 101, 99, 102, 103, 114, 158, 100, 97, 100, 
99, 100, 100, 101, 101, 100, 102, 102, 103, 111, 164, 101, 98, 100, 99, 99, 101, 103, 100, 99, 100, 103, 103, 104, 177, 103, 100, 99, 99, 99, 101, 101, 101, 100, 102, 103, 104, 141, 129, 98, 98, 99, 99, 101, 101, 101, 99, 100, 99, 102, 104, 104, 161, 120, 101, 97, 100, 98, 101, 102, 99, 101, 103, 104, 108, 183, 101, 98, 98, 97, 99, 102, 101, 101, 100, 99, 100, 102, 105, 103, 118, 167, 102, 100, 99, 98, 100, 101, 101, 102, 102, 103, 167, 123, 98, 99, 98, 98, 99, 99, 101, 101, 100, 100, 100, 103, 103, 104, 99, 147, 148, 102, 99, 100, 100, 103, 101, 101, 100, 149, 143, 98, 98, 98, 99, 99, 99, 100, 101, 103, 101, 102, 102, 103, 102, 103, 103, 100, 156, 150, 102, 101, 101, 101, 97, 109, 170, 147, 99, 98, 99, 99, 99, 100, 101, 102, 101, 103, 100, 100, 101, 101, 102, 100, 102, 102, 101, 127, 189, 157, 137, 132, 166, 182, 116, 96, 98, 99, 100, 100, 100, 101, 101, 102, 103, 104, 100, 99, 101, 100, 101, 100, 101, 101, 102, 101, 98, 115, 129, 127, 109, 98, 99, 99, 101, 101, 101, 101, 100, 102, 102, 102, 103, 104, 99, 98, 100, 100, 99, 100, 101, 100, 102, 102, 101, 103, 102, 102, 102, 101, 100, 99, 101, 102, 101, 101, 103, 103, 102, 103, 103, 104, 100, 99, 100, 101, 100, 101, 101, 101, 101, 101, 101, 101, 101, 101, 102, 100, 101, 103, 102, 102, 103, 104, 102, 102, 103, 104, 103, 103, 101, 100, 102, 102, 101, 100, 102, 102, 102, 103, 102, 100, 102, 101, 102, 101, 103, 103, 103, 103, 104, 103, 103, 103, 104, 105, 104, 103, 100, 101, 101, 102, 103, 102, 102, 103, 102, 101, 102, 102, 102, 102, 101, 102, 103, 103, 103, 104, 103, 105, 104, 104, 104, 105, 105, 105, 100, 101, 102, 103, 102, 102, 103, 103, 102, 101, 103, 102, 103, 102, 102, 103, 102, 104, 104, 104, 105, 106, 105, 104, 104, 105, 105, 106]
# Scale the hand-entered sample with the same divisor as the training data,
# then classify it as a single-row batch of shape (1, n_features).
zero = np.asarray(zero) / 256
pr = clf.predict(zero.reshape(1, -1))
print(pr)