# Reframe the MNIST dataset as a binary classification problem
We want to train a binary classifier on training data containing only 0s and 1s.
Then introduce varying amounts of out-of-sample examples (2-9 digits).
Then we can see the effect on accuracy and distance.

## 1. Brute force the existing dataloader

In [6]:
import numpy as np
import pandas as pd

from xai.data_handlers.mnist import load_mnist

In [11]:
# Load corpus and test inputs
batch_size = 1024
# batch_size_test = 1000

corpus_loader = load_mnist(subset_size=8192, train=True, batch_size=batch_size) # MNIST train loader
test_loader = load_mnist(subset_size=1024, train=False, batch_size=batch_size) # MNIST test loader
corpus_inputs, corpus_labels = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, test_labels = next(iter(test_loader)) # A set of inputs to explain

If we use a large enough batch size, are the examples boradly uniformly distributed?

In [25]:
idx = 1
for corpus_inputs, corpus_labels in corpus_loader:
    digit_to_count = 1
    count_df = pd.DataFrame(corpus_labels).value_counts()
    print(f"{idx}: {count_df.loc[digit_to_count].iloc[0]}")
    idx += 1
    

1: 126
2: 128
3: 121
4: 110
5: 133
6: 104
7: 113
8: 105


In [26]:
idx = 1
for corpus_inputs, corpus_labels in corpus_loader:
    digit_to_count = 0
    count_df = pd.DataFrame(corpus_labels).value_counts()
    print(f"{idx}: {count_df.loc[digit_to_count].iloc[0]}")
    idx += 1
    

1: 89
2: 98
3: 105
4: 103
5: 111
6: 97
7: 84
8: 92


In a pinch, we could load much larger batches and just filter down and hope there are enough

Are the corpuses the same over reruns? No

In [30]:
corpus_loader_a = load_mnist(subset_size=8192, train=True, batch_size=batch_size) # MNIST train loader
corpus_inputs_a, corpus_labels_a = next(iter(corpus_loader_a)) # A tensor of corpus inputs

corpus_loader_b = load_mnist(subset_size=8192, train=True, batch_size=batch_size) # MNIST train loader
corpus_inputs_b, corpus_labels_b = next(iter(corpus_loader_b)) # A tensor of corpus inputs


In [36]:
torch.equal(corpus_inputs_a, corpus_inputs_b)

False

In [37]:
torch.equal(corpus_labels_a, corpus_labels_b)

False

## 2. Use the torch data set

In [27]:
import torch
from torch.utils.data import DataLoader
import torchvision

from xai.constants import DATA_DIR


DEFAULT_MNIST_NORMALIZATION = (0.1307,), (0.3081,)

In [28]:
torchvision.datasets.MNIST?

[0;31mInit signature:[0m
[0mtorchvision[0m[0;34m.[0m[0mdatasets[0m[0;34m.[0m[0mMNIST[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mroot[0m[0;34m:[0m [0mstr[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtrain[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtransform[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mCallable[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtarget_transform[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mCallable[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdownload[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

Args:
    root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
        and  `

In [41]:
train = True
subset_size=None
shuffle = True
data_dir=DATA_DIR

In [92]:
    dataset = torchvision.datasets.MNIST(
        data_dir,
        train=train,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(*DEFAULT_MNIST_NORMALIZATION),
        ])
    )

In [93]:
len(dataset)

60000

In [94]:
# get 0s and 1s from dataset

In [110]:
mask_zeros = dataset.targets == 0
mask_ones = dataset.targets == 1
mask_zeros_ones = mask_zeros | mask_ones
idx_zeros_ones = mask_zeros_ones.nonzero().squeeze()

idx_zeros_ones

tensor([    1,     3,     6,  ..., 59984, 59987, 59994])

In [131]:
sum(mask_zeros)

tensor(5923)

In [128]:
import functools 
import operator

In [129]:
functools.reduce(operator.or_, [mask_zeros, mask_ones])

tensor([False,  True, False,  ..., False, False, False])

In [123]:
any(

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [112]:
ss = torch.utils.data.Subset(dataset, idx_zeros_ones)
ss

<torch.utils.data.dataset.Subset at 0x7f347467c310>

In [113]:
batch_size

1024

In [114]:
dl = DataLoader(ss, batch_size=64, shuffle=shuffle)
dl

<torch.utils.data.dataloader.DataLoader at 0x7f347467f410>

In [115]:
dl_inputs, dl_labels = next(iter(dl))

In [117]:
dl_inputs.shape

torch.Size([64, 1, 28, 28])

In [118]:
dataset.test_labels



tensor([5, 0, 4,  ..., 5, 6, 8])

In [119]:
dataset.train_labels



tensor([5, 0, 4,  ..., 5, 6, 8])

In [120]:
dataset.targets

tensor([5, 0, 4,  ..., 5, 6, 8])

In [58]:
corpus_loader.dataset.dataset.train_labels



tensor([5, 0, 4,  ..., 5, 6, 8])

In [61]:
len(corpus_loader.dataset.dataset)

60000

In [62]:
len(corpus_loader)

8

In [57]:
ss.dataset.tr

AttributeError: 'MNIST' object has no attribute 'tr'

In [48]:
dataset.train_labels.shape



torch.Size([60000])

In [45]:
dataset.test_labels.shape

torch.Size([60000])

In [None]:
labels_0_1

In [47]:
torch.utils.data.Subset?

[0;31mInit signature:[0m
[0mtorch[0m[0;34m.[0m[0mutils[0m[0;34m.[0m[0mdata[0m[0;34m.[0m[0mSubset[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mdataset[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mutils[0m[0;34m.[0m[0mdata[0m[0;34m.[0m[0mdataset[0m[0;34m.[0m[0mDataset[0m[0;34m[[0m[0;34m+[0m[0mT_co[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mindices[0m[0;34m:[0m [0mSequence[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Subset of a dataset at specified indices.

Args:
    dataset (Dataset): The whole Dataset
    indices (sequence): Indices in the whole set selected for subset
[0;31mFile:[0m           ~/anaconda3/envs/xai/lib/python3.11/site-packages/torch/utils/data/dataset.py
[0;31mType:[0m           type
[0;31mSubclasses:[0m     