In [20]:
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', as_frame=False)
oldX, oldY = mnist.data, mnist.target

In [21]:
import torch
import numpy as np

mask = (oldY == '0') | (oldY == '1')
X, Y = oldX[mask], oldY[mask].astype(np.int64)

split = int(len(X) * 6 / 7)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainXhalf = torch.from_numpy(X[:split]).float().to(device)
trainYhalf = torch.from_numpy(Y[:split]).float().view(-1, 1).to(device)
testXhalf = torch.from_numpy(X[split:]).float().to(device)
testYhalf = torch.from_numpy(Y[split:]).float().view(-1, 1).to(device)

trainPerm = torch.randperm(len(trainXhalf) * 2)
testPerm = torch.randperm(len(testXhalf) * 2)
trainX = torch.vstack((trainXhalf, 1 - trainXhalf))[trainPerm]
testX = torch.vstack((testXhalf, 1 - testXhalf))[testPerm]
trainY = torch.vstack((trainYhalf, trainYhalf))[trainPerm]
testY = torch.vstack((testYhalf, testYhalf))[testPerm]

trainX /= 255
testX /= 255

trainX.shape, trainY.shape, testX.shape, testY.shape

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
weight = torch.randn(len(trainX[0]), 1, requires_grad=True, device=device)
bias = torch.randn(1, requires_grad=True, device=device)
criterion = torch.nn.BCEWithLogitsLoss()
eta, epoch = 0.1, 15000
for i in range(epoch):
    loss = criterion(trainX @ weight + bias, trainY)
    loss.backward()
    with torch.no_grad():
        weight -= eta * weight.grad
        bias -= eta * bias.grad
    weight.grad.zero_()
    bias.grad.zero_()

ValueError: Target size (torch.Size([12668, 1])) must be the same as input size (torch.Size([25336, 1]))

In [None]:
correct = 0

for i in range(len(testX)):
    correct += torch.round(torch.sigmoid(testX[i] @ weight + bias)).item() == testY[i].item()

print(correct / len(testX))

0.9990530303030303


In [None]:
torch.save({'weight': weight, 'bias': bias}, 'saved_params/bin_ocr.pth')