In [5]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL.Image as Image

Get data

In [6]:
IMAGE_SIZE = (28, 28)
EPOCH = 3
LEARNING_RATE = 0.001

# Import data
data = pd.read_csv(r"../digit-recognizer/train.csv")
data_test = pd.read_csv(r"../digit-recognizer/test.csv")    

# Get labels and image array from data
labels : np.ndarray = data.values[:, 0]
images : np.ndarray = data.values[:, 1:].astype('uint8')

images_test : np.ndarray = data_test.values.astype('uint8')

Process data

In [7]:
class Model(nn.Module):
    def __init__(self, input, hidden, output):
        super().__init__()
        self.layer1 = nn.Linear(input, hidden)
        self.layer2 = nn.Linear(hidden, output)

    def forward(self, x):
        x = self.layer1(x)
        x = F.sigmoid(self.layer2(x))
        # torch.sigmoid(x)

        return x

In [8]:
model = Model(IMAGE_SIZE[0] * IMAGE_SIZE[1], 1024, 10)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
for i in range(EPOCH):
    wrong = 0
    for j in range(len(images)):
        optimizer.zero_grad()
        res = model(torch.tensor(images[j], dtype=torch.float32))
        # print(res)
        if (int(torch.argmax(res)) != labels[j]):
            wrong += 1

        # Create label for calculating loss
        label_tensor = np.zeros(10)
        label_tensor[labels[j]] = 1
        label_tensor = torch.tensor(label_tensor, dtype=torch.float32)
        # print(label_tensor)
        loss = criterion(res, label_tensor)
        loss.backward()
        optimizer.step()
        
        if ((j+1) % 2000 == 0):
            print(f"Data {j+1}: Wrong = {wrong}, Accuracy: {100-wrong/j*100}%")
        
    print(f"Epoch: {i} --> Wrong: {wrong}, Accuracy: {100-wrong / labels.size * 100}%\n")    

# Test model
wrong = 0
for i in range(100):
    res = torch.argmax(model(torch.tensor(images_test[i], dtype=torch.float32)))
    img = Image.fromarray(images_test[i].reshape(IMAGE_SIZE))
    print(f"Image {i+1}: {res}")
    img.save(f"./test/{str(i+1)}_ans={res}.jpg")




Data 2000: Wrong = 1511, Accuracy: 24.41220610305153%
Data 4000: Wrong = 2818, Accuracy: 29.532383095773937%
Data 6000: Wrong = 4142, Accuracy: 30.955159193198867%
Data 8000: Wrong = 5429, Accuracy: 32.12901612701587%
Data 10000: Wrong = 6667, Accuracy: 33.323332333233324%
Data 12000: Wrong = 7850, Accuracy: 34.57788149012417%
Data 14000: Wrong = 9074, Accuracy: 35.18108436316881%
Data 16000: Wrong = 10295, Accuracy: 35.65222826426651%
Data 18000: Wrong = 11482, Accuracy: 36.20756708706039%
Data 20000: Wrong = 12592, Accuracy: 37.036851842592135%
Data 22000: Wrong = 13755, Accuracy: 37.4744306559389%
Data 24000: Wrong = 14896, Accuracy: 37.9307471144631%
Data 26000: Wrong = 16035, Accuracy: 38.32455094426709%


KeyboardInterrupt: 