In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from typing import List, Set
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score
from numpy_nn import *

In [None]:
# Load both MNIST splits into ./mnist (downloading on first run), applying
# torchvision's ToTensor conversion to every image. Train is built first, then test.
Transform = transforms.ToTensor()
data_train, data_test = (
    datasets.MNIST(root="mnist", train=is_train_split, download=True, transform=Transform)
    for is_train_split in (True, False)
)

In [None]:
def make_reduced_dataset(dataset: Dataset, numbers: List[int]):
    """Lazily yield the samples of `dataset` whose label appears in `numbers`.

    Each kept label is replaced by its position in `numbers`; for example,
    ``numbers=[6, 9]`` relabels 6 -> 0 and 9 -> 1. Samples whose label is not
    listed are skipped. Order follows the iteration order of `dataset`.

    Args:
        dataset: iterable of ``(img, label)`` pairs.
        numbers: labels to keep; their order defines the new label ids.

    Yields:
        ``(img, new_label)`` tuples.
    """
    label_of = {}
    for position, digit in enumerate(numbers):
        label_of[digit] = position  # later duplicates win, like a dict comprehension
    for img, y_true in dataset:
        if y_true in label_of:
            yield img, label_of[y_true]
        

class ReducedMNIST(Dataset):
    """In-memory dataset of ``(image, label)`` samples, filled one at a time via `add`."""

    def __init__(self):
        # Two parallel lists: index i of each holds the i-th sample's image / label.
        self.__images = []
        self.__labels = []

    def add(self, img, y_true):
        """Append a single sample (image and its label) to the end of the dataset."""
        self.__images.append(img)
        self.__labels.append(y_true)

    def __len__(self):
        """Return how many samples have been added."""
        return len(self.__labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at `idx`."""
        image = self.__images[idx]
        label = self.__labels[idx]
        return image, label

# Digits kept for the binary task. Their position in this list becomes the
# label used for training (6 -> 0, 9 -> 1).
DIGITS = [6, 9]


def _build_reduced(dataset, numbers):
    """Collect the samples of `dataset` labelled with `numbers` (relabelled) into a ReducedMNIST."""
    reduced = ReducedMNIST()
    for img, label in make_reduced_dataset(dataset, numbers):
        reduced.add(img, label)
    return reduced


# Same digit selection for both splits, so train/test labels always agree.
rdata_train = _build_reduced(data_train, DIGITS)
rdata_test = _build_reduced(data_test, DIGITS)
assert len(rdata_train) > 0 and len(rdata_test) > 0, "no samples matched DIGITS"

In [None]:
BATCH_SIZE = 100

# Only the training data benefits from shuffling. ROC AUC and accuracy do not
# depend on sample order, so the test loader keeps a fixed, reproducible order.
train_ldr = DataLoader(rdata_train, batch_size=BATCH_SIZE, shuffle=True)
test_ldr = DataLoader(rdata_test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Fully connected 784 -> 128 -> 84 -> 1 network (784 = 28*28 flattened pixels).
# Every Linear layer is followed by a Sigmoid, so the network emits one sigmoid
# activation per sample. Layers are added in the same order as a hand-written list.
LAYER_WIDTHS = [28 * 28, 128, 84, 1]

nn = Network()
for fan_in, fan_out in zip(LAYER_WIDTHS[:-1], LAYER_WIDTHS[1:]):
    nn.add_layer(Linear(in_features=fan_in, out_features=fan_out))
    nn.add_layer(Sigmoid())

optimizer = SGD(lr=1e-2, momentum=0.9, parameters=nn.parameters)

In [None]:
N_EPOCHS = 24  # epochs are numbered 1..N_EPOCHS


def _to_numpy_batch(batch_img):
    """Flatten a batch of image tensors to (batch, features) and convert it to a numpy array."""
    return torch.flatten(batch_img, start_dim=1).numpy()


def _predict(network, loader):
    """Run `network` over every batch of `loader`; return (true labels, predictions) as flat lists."""
    labels, predictions = [], []
    for batch_img, batch_y_true in loader:
        labels.extend(batch_y_true.numpy().tolist())
        predictions.extend(network.forward(_to_numpy_batch(batch_img)).flatten().tolist())
    return labels, predictions


for epoch in range(1, N_EPOCHS + 1):
    batch_losses = []
    for batch_img, batch_y_true in tqdm(train_ldr, desc=f"epoch {epoch}"):
        batch_img = _to_numpy_batch(batch_img)
        batch_y_true = batch_y_true.float().numpy()
        optimizer.zero_grad()
        raw_output = nn.forward(batch_img)
        batch_predictions = raw_output.flatten()
        # A fresh loss object per batch; presumably it caches its inputs for backward() — confirm.
        loss = LogLoss()
        loss_value = loss.loss(batch_predictions, batch_y_true)
        batch_losses.append(np.mean(loss_value))
        # The loss gradient is taken w.r.t. the *flattened* predictions. Reshape it back to
        # the network's raw output shape, e.g. (batch,) -> (batch, 1). Otherwise the
        # backward pass would broadcast (batch,) against (batch, 1) into (batch, batch).
        # If the shapes already match, the reshape is a no-op.
        error = np.reshape(loss.backward(), np.shape(raw_output))
        # NOTE(review): an earlier experiment used
        #   error = (batch_y_true - batch_predictions)[:, np.newaxis]
        # which has the opposite sign of the usual (p - y) gradient and bypasses LogLoss
        # entirely — TODO decide and remove this note.
        nn.backward(error)
        optimizer.step()
    y_true, y_pred = _predict(nn, test_ldr)
    print(f"Epoch {epoch}: mean train loss = {np.mean(batch_losses):.4f}")
    print(f"ROC AUC = {roc_auc_score(y_true, y_pred)}")
    print(f"Accuracy = {accuracy_score(y_true, np.array(y_pred).round())}")
