In [47]:
from torch.utils.data import DataLoader

import mapd
import torchvision
from torchvision import transforms

from mapd.probes.make_probe_suites import make_probe_suites
from mapd.probes.utils.idx_dataset import IDXDataset
from mapd.utils.make_dataloaders import make_dataloaders
from mapd.classifiers.make_mapd_classifier import make_mapd_classifier
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
# Directory holding the EMNIST data (named MNIST_ROOT; the dataset cells below
# load from it with split="letters").
MNIST_ROOT = "data-emnist"

# Make sure the EMNIST "letters" split is present on disk; the dataset repr is
# left as the cell's displayed output.
torchvision.datasets.EMNIST(split="letters", download=True, root=MNIST_ROOT)

Downloading https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to data-emnist/EMNIST/raw/gzip.zip


100%|██████████| 561753746/561753746 [00:45<00:00, 12385864.15it/s]


Extracting data-emnist/EMNIST/raw/gzip.zip to data-emnist/EMNIST/raw


Dataset EMNIST
    Number of datapoints: 124800
    Root location: data-emnist
    Split: Train

In [62]:
from torchvision.datasets import MNIST

In [78]:
from torch import nn
from torch.nn import functional as F


class Net(nn.Module):
    """Small LeNet-style CNN for single-channel images.

    Two conv(5x5) + max-pool(2) stages, with channel dropout after the second
    convolution, followed by two fully connected layers with dropout between them.
    The hard-coded 320-wide first linear layer means inputs must be 1x28x28.

    Args:
        num_labels: Number of output classes (width of the returned logits).
    """

    def __init__(self, num_labels: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 320 = 20 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4
        # through the two conv(5) + pool(2) stages.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_labels)

    def forward(self, x):
        """Return unnormalised class logits of shape (N, num_labels) for x of shape (N, 1, 28, 28)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample, keeping the batch dimension. `view(-1, 320)` would
        # silently regroup features across samples for a wrongly sized input
        # instead of failing with a shape error in fc1.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return self.fc2(x)


# EMNIST "letters" has 26 classes.
model = Net(num_labels=26)

In [88]:
import numpy as np
import lightning as L
from torch.nn import functional as F
from torch.optim import SGD
import torch


class MNISTModule(mapd.MAPDModule):
    """``mapd.MAPDModule`` that trains a small CNN classifier with SGD.

    Both ``training_step`` and ``validation_step`` hand the raw logits and targets
    to ``mapd_log``. Presumably the base class then calls ``batch_loss`` /
    ``batch_proxy_metric`` to get per-sample values; verify this in mapd.

    Args:
        max_epochs: Saved as a hyperparameter only. It is not read anywhere in this
            class (the notebook's Trainer gets its own ``max_epochs``). It may be
            consumed by ``mapd.MAPDModule``; confirm.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay (L2 penalty).
        model: Network to train. When omitted, a fresh ``Net(num_labels=26)`` is
            built, so each module instance owns its own weights instead of sharing
            one notebook-global network.
    """

    def __init__(
            self,
            max_epochs: int = 10,
            lr: float = 0.05,
            momentum: float = 0.9,
            weight_decay: float = 0.0005,
            model: "nn.Module | None" = None,
    ):
        super().__init__()
        # Do not capture the notebook-global `model`. Re-running the cell that
        # builds this module would otherwise reuse whatever state that shared
        # object has accumulated (e.g. weights from an earlier fit).
        self.model = model if model is not None else Net(num_labels=26)

        self.max_epochs = max_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # `model` is an nn.Module, not a scalar hyperparameter, so keep it out of hparams.
        self.save_hyperparameters(ignore=["model"])

    def mapd_settings(self):
        """Describe the proxy metric to the MAPD base class."""
        return {
            "proxy_metric": "loss",
            "proxy_metric_direction": "minimize",
        }

    def forward(self, x):
        """Return the wrapped network's logits for ``x``."""
        return self.model(x)

    def batch_loss(self, logits, y) -> torch.Tensor:
        """Per-sample cross-entropy (``reduction="none"``), one value per row of ``logits``."""
        return F.cross_entropy(logits, y, reduction="none")

    def batch_proxy_metric(self, logits, y) -> torch.Tensor:
        """Per-sample proxy metric: the negated per-sample cross-entropy loss.

        NOTE(review): ``mapd_settings`` declares the proxy metric as "loss" with
        direction "minimize", yet this returns ``-loss``. Confirm which sign
        convention mapd expects.
        """
        return -self.batch_loss(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute the mean training loss and forward logits/targets to ``mapd_log``."""
        x, y = batch

        logits = self.forward(x)
        loss = self.batch_loss(logits, y).mean()
        self.mapd_log(logits, y)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
        """Compute the mean validation loss and forward logits/targets to ``mapd_log``."""
        x, y = batch

        logits = self.forward(x)
        # Same value as F.cross_entropy(logits, y); reuses the shared per-sample loss.
        loss = self.batch_loss(logits, y).mean()
        self.mapd_log(logits, y)

        return loss

    def configure_optimizers(self):
        """SGD configured with every optimizer hyperparameter given to ``__init__``."""
        optimizer = SGD(
            self.parameters(),
            lr=self.lr,
            # Pass these through so the hyperparameters recorded by
            # save_hyperparameters match the optimizer that actually runs.
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

        return {"optimizer": optimizer}

In [None]:
from torch.utils.data import random_split

# Seed for the train/val split so the same samples land in each subset on every
# fresh kernel.
SPLIT_SEED = 42

module = MNISTModule()

# NOTE(review): (0.1307,)/(0.3081,) are the standard MNIST mean/std. EMNIST
# "letters" statistics differ somewhat, so confirm these are intended.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# EMNIST "letters" targets are 1-based; shift them to 0-based class indices.
target_transform = transforms.Compose([transforms.Lambda(lambda x: x - 1)])


def load_emnist_letters(train: bool):
    """Load the EMNIST "letters" split from MNIST_ROOT with the shared transforms."""
    return torchvision.datasets.EMNIST(
        MNIST_ROOT,
        train=train,
        split="letters",
        transform=transform,
        target_transform=target_transform,
    )


mnist_test = load_emnist_letters(train=False)
mnist_predict = load_emnist_letters(train=False)
mnist_full = load_emnist_letters(train=True)
# A seeded generator makes the split reproducible. The default global RNG gives a
# different split per kernel, so any saved results keyed by sample index would
# not line up across sessions.
mnist_train, mnist_val = random_split(
    mnist_full, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
mnist_train = IDXDataset(mnist_train)

dl = DataLoader(mnist_train, batch_size=512, shuffle=True, num_workers=4, prefetch_factor=2)

torch.set_float32_matmul_precision('medium')

trainer_proxy = L.Trainer(accelerator="cpu", max_epochs=50)
trainer_probes = L.Trainer(accelerator="cpu", max_epochs=50)

# Proxy
trainer_proxy.fit(module.as_proxies(), train_dataloaders=dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 22.7 K
-------------------------------
22.7 K    Trainable params
0         Non-trainable params
22.7 K    Total params
0.091     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [None]:
# Build the probe suites over the training set; "proxies" presumably names the
# proxy run fit above.
# NOTE(review): the positional 10 matches MNIST's class count, but EMNIST
# "letters" has 26 classes. Confirm what this argument means in
# make_probe_suites and whether it should be 26 here.
probe_suite_ds = make_probe_suites(
    mnist_train,
    10,
    "proxies",
    num_probes=500,
)

In [None]:
dl_probes = DataLoader(mnist_train, batch_size=512, shuffle=True, num_workers=4, prefetch_factor=2)

# Validation needs no shuffling: keep the evaluation order deterministic
# (Lightning also warns about shuffled val/test dataloaders).
val_dataloader = DataLoader(IDXDataset(mnist_val), batch_size=512, shuffle=False, num_workers=4, prefetch_factor=2)
val_dataloaders = make_dataloaders([val_dataloader], probe_suite_ds)

# NOTE(review): `module` was already fit by the proxy run. Confirm that
# as_probes() starts from fresh weights rather than the proxy-trained ones.
trainer_probes.fit(module.as_probes(probe_suite_ds), train_dataloaders=dl_probes, val_dataloaders=val_dataloaders)

In [None]:
# Fit the MAPD classifier on the "probes" run; clf="xgboost_rf" selects the
# classifier backend. Yields the classifier and its label encoder.
clf, label_encoder = make_mapd_classifier(
    "probes",
    probe_suite_ds,
    clf="xgboost_rf",
)

In [None]:
from mapd.classifiers.make_predictions import make_predictions

# Predictions from the fitted classifier for the "probes" run. The values of
# `preds` are unpacked further down as (probe_suite, <other>) pairs.
preds = make_predictions(
    "probes",
    clf,
    label_encoder,
)

In [None]:
from mapd.visualization.surface_predictions import make_surface_predictions
import matplotlib.pyplot as plt

# One figure per probe suite of interest.
for probe_suite in ("typical", "atypical"):
    fig = make_surface_predictions(preds, mnist_train, probe_suite=probe_suite)
    # plt.show() takes no figure argument. Under the inline backend a positional
    # argument lands in its `close` parameter, and other backends reject it with
    # a TypeError. Called bare, it shows every open pyplot figure.
    # NOTE(review): assumes make_surface_predictions creates its figure via pyplot.
    plt.show()

In [None]:
# Count how many samples are predicted for each probe suite
from collections import Counter

# Each value of `preds` is a pair whose first element is the predicted suite.
counts = Counter(suite for suite, _ in preds.values())

In [None]:
# Show the per-probe-suite tallies (rich display of the Counter).
counts

In [60]:
# Number of entries in `preds`.
# NOTE(review): the saved output for this cell was produced at execution count 60,
# before the EMNIST download cell ran (In [61]), so it is stale. Re-run after the
# pipeline to refresh it.
len(preds)

55000