In [None]:
%load_ext autoreload
%autoreload 2

from PIL import Image
from torch import tensor
import torch
import src.utils
from pathlib import Path
from IPython.display import display
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

# src.utils.download_mnist("data")

from torchvision import transforms
# Reusable torchvision transform instances: convert_tensor is applied to the
# PIL images opened below, image_from_tensor to tensors that should be shown.
convert_tensor = transforms.ToTensor()
image_from_tensor = transforms.ToPILImage()


def convert(a):
    """Recode a digit label as a binary target: 7 becomes 0, anything else 1."""
    if a == 7:
        return 0
    return 1

# The label mappings are iterated for their keys, and each key is later used as
# an image path below the data/ directory.
test_labels = src.utils.read_labels('data/t10k-labels-idx1-ubyte', 'test')
train_labels = src.utils.read_labels('data/train-labels-idx1-ubyte', 'train')
keys = list(train_labels)
test_keys = list(test_labels)

# Only the digits 3 and 7 take part in this binary task; pick their keys once
# per split so images and targets are guaranteed to line up.
train_keys_37 = [key for key in keys if train_labels[key] in [3, 7]]
test_keys_37 = [key for key in test_keys if test_labels[key] in [3, 7]]

# Training inputs: every selected image flattened to a row of 28*28 values.
train_x = torch.stack(
    [convert_tensor(Image.open(Path("data") / key)) for key in train_keys_37]
).squeeze().view(-1, 28*28)
# Training targets: raw digits first, then recoded by convert() (7 -> 0, else 1).
train_y = tensor([train_labels[key] for key in train_keys_37]).squeeze()
train_y = tensor([convert(digit) for digit in train_y])


# Held-out inputs and targets, built exactly like the training ones.
test_x = torch.stack(
    [convert_tensor(Image.open(Path("data") / key)) for key in test_keys_37]
).squeeze().view(-1, 28*28)
test_y = tensor([test_labels[key] for key in test_keys_37]).squeeze()
test_y = tensor([convert(digit) for digit in test_y])

# Pair each training row with its target so a DataLoader can batch them.
dset = list(zip(train_x, train_y))



# display(image_from_tensor(dset[7000][0].view(28,28)), dset[7000][1])

# Per-pixel mean image of each digit, as flattened vectors of 28*28 values.
# train_x already holds exactly these training images (same keys, same order),
# so select its rows by target instead of decoding every file a second time.
# Targets were recoded by convert(): 1 marks a 3, 0 marks a 7.
threes_mean = train_x[train_y == 1].mean(0)

sevens_mean = train_x[train_y == 0].mean(0)

# display(image_from_tensor(threes_mean.view(28,28)))
# display(image_from_tensor(sevens_mean.view(28,28)))

def mnist_distance(a, b):
    """Return the mean absolute element-wise difference between ``a`` and ``b``."""
    difference = a - b
    return difference.abs().mean()

# print(mnist_distance(dset[7000][0], threes_mean), mnist_distance(dset[3000][0], threes_mean))
# print(mnist_distance(dset[3000][0], threes_mean))

def is_3(a):
    """Tell whether ``a`` is closer to the mean 3 than to the mean 7.

    Closeness is measured with ``mnist_distance`` against the module-level
    ``threes_mean`` and ``sevens_mean``; a tie counts as "not a 3".
    """
    distance_to_three = mnist_distance(a, threes_mean)
    distance_to_seven = mnist_distance(a, sevens_mean)
    return distance_to_three < distance_to_seven

# print(is_3(dset[7000][0]), is_3(dset[3000][0]))

def mse(a, b):
    """Return the square root of the mean squared difference of ``a`` and ``b``.

    Note that, despite the name, the square root is applied, so the value is
    on the same scale as the inputs.
    """
    squared_error = (a - b) ** 2
    return squared_error.mean().sqrt()

# Trainable parameters of the linear model, drawn from a standard normal:
# one weight per pixel of a flattened 28x28 image, plus a single bias.
weights = torch.randn(28*28, requires_grad=True)
bias = torch.randn(1, requires_grad=True)

def mnist_loss(preds, truths):
    """Return the mean distance between sigmoid(preds) and the binary targets.

    Where the target is 1 the per-item loss is ``1 - sigmoid(pred)``;
    everywhere else it is ``sigmoid(pred)`` itself.
    """
    probabilities = preds.sigmoid()
    is_positive = truths == 1
    per_item = torch.where(is_positive, 1 - probabilities, probabilities)
    return per_item.mean()

# Step size applied to every gradient update in train_epoch().
lr = 0.01
# Shuffled mini-batches of ten (image, target) pairs from the training set.
dataloader = DataLoader(dataset=dset, batch_size=10, shuffle=True)

def accuracy(weights, bias):
    """Return the fraction of held-out examples the linear model gets right.

    The model output ``sigmoid(test_x @ weights + bias)`` is thresholded at
    0.5 and compared with the module-level ``test_y`` targets.

    Args:
        weights: Weight vector applied to each row of ``test_x``.
        bias: Bias added to every prediction.

    Returns:
        A scalar float tensor holding the mean of the per-example matches.
    """
    # Evaluation only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        preds = (test_x @ weights + bias).sigmoid()
    return ((preds > 0.5) == test_y).float().mean()
            
def train_epoch(dl, weights, bias, pbar):
    """Run one pass of mini-batch gradient descent over ``dl``.

    ``weights`` and ``bias`` are updated in place using the module-level
    learning rate ``lr`` and the ``mnist_loss`` loss function.

    Args:
        dl: Iterable yielding ``(x, y)`` batches.
        weights: Tensor tracked by autograd; modified in place.
        bias: Tensor tracked by autograd; modified in place.
        pbar: Progress bar owned by the caller. Currently unused; kept so the
            loss can be reported through ``pbar.write`` when debugging.
    """
    for x, y in dl:
        preds = x @ weights + bias
        loss = mnist_loss(preds, y)
        loss.backward()
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            weights -= lr * weights.grad
            bias -= lr * bias.grad
            # backward() accumulates into .grad, so clear it for the next batch.
            weights.grad.zero_()
            bias.grad.zero_()
    #pbar.write(f'loss {loss}')

# Baseline accuracy of the untrained parameters.
print(f'accuracy before {accuracy(weights, bias):.4f}')

# Train for a fixed number of epochs, advancing the progress bar once per epoch.
n_epochs = 20
with tqdm(total=n_epochs, unit='e') as pbar:
    for i in range(n_epochs):
        train_epoch(dataloader, weights, bias, pbar)
        pbar.update(1)

print(f'accuracy after {accuracy(weights, bias):.4f}')

def show_batch(dset):
    """Display one random batch of ten images and return labels and predictions.

    A fresh shuffled ``DataLoader`` is built over ``dset`` and its first batch
    is taken. The images are reshaped to 28x28 and shown with ``display``.
    Predictions are the raw linear outputs ``x @ weights + bias`` computed
    with the module-level parameters (no sigmoid applied).

    Args:
        dset: Dataset of ``(image, target)`` pairs with flattened images.

    Returns:
        Tuple ``(y, preds)`` with the batch targets and the raw predictions.
    """
    dataloader = DataLoader(dset, batch_size=10, shuffle=True)

    x, y = next(iter(dataloader))
    preds = x @ weights + bias
    x = x.view(-1, 28, 28)
    display(*map(image_from_tensor, x))
    # A bare ``y, preds`` expression is discarded inside a function; return it
    # so the caller (or the notebook cell) actually receives the values.
    return y, preds
