In [30]:
import numpy as np

import mapd
import torchvision
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds

In [31]:
# Directory handed to torchvision as the dataset root.
MNIST_ROOT = "data"
# download=True requests the MNIST files under MNIST_ROOT. The call is the cell's
# trailing expression, so the notebook displays the repr of the returned dataset.
torchvision.datasets.MNIST(root=MNIST_ROOT, download=True)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train

In [32]:
from torchvision.datasets import MNIST

In [33]:
from torch.utils.data import Dataset


class IDXDataset(Dataset):
    """Dataset wrapper that pairs every fetched item with the index used to fetch it.

    ``wrapped[i]`` evaluates to ``(dataset[i], i)``; the length is delegated to
    the wrapped dataset unchanged.
    """

    def __init__(self, dataset: Dataset):
        # Keep a reference only; the wrapped dataset is neither copied nor modified.
        self.dataset = dataset

    def __len__(self):
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, index):
        # The index is echoed back exactly as received (no remapping).
        item = self.dataset[index]
        return item, index

In [34]:
from abc import ABCMeta, abstractmethod

import torch
from lightning import LightningModule
from typing import Any, Optional, List, Dict
from torch import Tensor
from torch.utils.data import DataLoader


class MAPDModule(LightningModule, metaclass=ABCMeta):
    """LightningModule base class that records per-sample proxy metrics while training.

    Expected usage, as shown by the hooks below:

    * the training dataloader yields ``(inner_batch, indices)`` pairs (for
      example by wrapping the dataset in ``IDXDataset``);
    * the subclass calls ``self.mapd_log(logits, y)`` from ``training_step``;
    * after ``as_proxies()`` has been called, each training epoch ends by
      writing the collected ``(sample_index, proxy_metric, epoch)`` rows to a
      Parquet dataset in the ``"proxies"`` directory.

    Subclasses must implement ``batch_loss``.
    """

    # Class-level declarations. ``__init__`` rebinds every one of these on the
    # instance, so the mutable list defaults are never shared between instances.
    mapd_current_indices_: torch.Tensor = torch.empty(0, dtype=torch.int64)
    mapd_indices_: List[Tensor] = []

    mapd_losses_: Tensor = torch.empty(0, dtype=torch.float32)
    mapd_proxy_metrics_: List[Tensor] = []

    as_proxies_: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Single source of truth for the initial (empty) per-epoch buffers.
        self._reset_mapd_attrs()
        self.as_proxies_ = False

    @abstractmethod
    def batch_loss(self, logits: Any, y: Any) -> Tensor:
        """Return the loss for a batch; must be implemented by subclasses.

        Declared as a plain abstract instance method: it takes ``self`` and
        subclasses override it as an instance method.

        Raises:
            NotImplementedError: always, when the base implementation is called.
        """
        # ``NotImplemented`` is a non-callable constant; the exception type is
        # ``NotImplementedError``.
        raise NotImplementedError("batch_loss method not implemented")

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Split ``(inner_batch, indices)`` and remember the indices for ``mapd_log``.

        NOTE(review): the unpacking is applied to every batch that reaches this
        hook, so any dataloader whose items are not ``(inner_batch, indices)``
        pairs would be unpacked incorrectly - confirm all loaders are wrapped.
        """
        batch, indices = batch
        self.mapd_current_indices_ = indices
        return batch

    def batch_proxy_metric(self, logits: Any, y: Any) -> Tensor:
        """Return, for each row of ``logits``, the softmax probability at column ``y[row]``."""
        softmax = torch.softmax(logits, dim=1)

        # softmax confidence on correct class
        return softmax[torch.arange(softmax.shape[0]), y]

    def _mapd_log_proxy(self, logits: Any, y: Any):
        """Append the detached proxy metric of the current batch to the epoch buffer."""
        proxy_metrics = self.batch_proxy_metric(logits, y).detach()
        self.mapd_proxy_metrics_.append(proxy_metrics)

    def _mapd_log_probes(self, logits: Any, y: Any):
        """Placeholder for probe logging; currently does nothing."""
        pass

    def mapd_log(self, logits: Any, y: Any) -> "MAPDModule":
        """Record the current batch's sample indices and, in proxy mode, its proxy metric.

        Returns ``self`` so calls can be chained.
        """
        self.mapd_indices_.append(self.mapd_current_indices_)

        if self.as_proxies_:
            self._mapd_log_proxy(logits, y)
            return self

        self._mapd_log_probes(logits, y)

        return self

    def _reset_mapd_attrs(self) -> "MAPDModule":
        """Reset all per-epoch buffers to empty; leaves ``as_proxies_`` untouched."""
        self.mapd_current_indices_ = torch.empty(0, dtype=torch.int64)
        self.mapd_losses_ = torch.empty(0, dtype=torch.float32)
        self.mapd_proxy_metrics_ = []
        self.mapd_indices_ = []

        return self

    def as_proxies(self) -> "MAPDModule":
        """Switch the module into proxy-logging mode and return ``self``."""
        self.as_proxies_ = True

        return self

    def _write_proxies(self):
        """Write this epoch's buffered indices and proxy metrics as a Parquet dataset.

        Output goes to the ``"proxies"`` directory, partitioned by ``epoch``
        using the filename flavor; existing data is handled with
        ``"overwrite_or_ignore"``.
        """
        sample_indices = torch.cat(self.mapd_indices_).cpu().numpy()
        sample_proxy_metrics = torch.cat(self.mapd_proxy_metrics_).cpu().numpy()
        # Explicit int64 so the column matches the partitioning schema on every
        # platform, independent of NumPy's default integer width.
        epochs = np.full(sample_indices.shape, self.current_epoch, dtype=np.int64)

        table = pa.table(
            [
                pa.array(sample_indices),
                pa.array(sample_proxy_metrics),
                pa.array(epochs),
            ],
            names=["sample_index", "proxy_metric", "epoch"],
        )

        ds.write_dataset(table, "proxies",
                         partitioning=ds.partitioning(pa.schema([("epoch", pa.int64())]), flavor="filename"),
                         existing_data_behavior="overwrite_or_ignore", format="parquet")

    def on_train_epoch_end(self) -> None:
        """Flush proxy metrics (when any were collected) and clear the epoch buffers."""
        # Without this guard ``torch.cat`` is handed an empty list and raises
        # whenever proxy mode is off or ``mapd_log`` was never called.
        if self.as_proxies_ and self.mapd_proxy_metrics_:
            self._write_proxies()

        self._reset_mapd_attrs()

    def on_validation_epoch_end(self):
        """Placeholder hook; probe loss logging is not implemented yet."""
        # Run loss logging for probes
        #dataloader = self.probe_suite_dataloader()
        pass

In [35]:
from torch import nn


class Net(nn.Module):
    """Small CNN: two 5x5 conv + max-pool stages followed by two linear layers.

    Takes single-channel images and returns 10 raw logits per sample (no
    softmax is applied). The flatten step assumes 320 features per sample,
    i.e. 20 channels x 4 x 4, which is what a 1x28x28 input yields after the
    two conv/pool stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return the logits for a batch of images ``x``."""
        # Reach the functional API through the ``nn`` import of this cell
        # instead of the ``F`` alias, which is only imported in a later cell.
        fn = nn.functional

        x = fn.relu(fn.max_pool2d(self.conv1(x), 2))
        x = fn.relu(fn.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = fn.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly; it is a no-op in eval().
        x = fn.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


# Single notebook-level network instance. ResNet18.__init__ assigns this same
# object to self.model, so every ResNet18 instance shares these weights.
model = Net()

In [36]:
import lightning as L
from torch.nn import functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch


class ResNet18(MAPDModule):
    """MAPD-enabled Lightning module that trains the notebook-level ``model`` with SGD.

    NOTE(review): despite the class name, the wrapped network is whatever the
    global ``model`` refers to (a ``Net`` instance in this notebook).

    Args:
        max_epochs: stored and saved as a hyperparameter; not used by the
            optimiser configuration below.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay (L2 penalty).
    """

    def __init__(
            self,
            max_epochs: int = 10,
            lr: float = 0.05,
            momentum: float = 0.9,
            weight_decay: float = 0.0005
    ):
        super().__init__()
        # Binds the module-level ``model`` object; it is not copied.
        self.model = model

        self.max_epochs = max_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.save_hyperparameters(ignore=["model"])

    def mapd_settings(self):
        """Return MAPD configuration; currently an empty dict."""
        return {}

    def forward(self, x):
        """Delegate to the wrapped network and return its output."""
        return self.model(x)

    def batch_loss(self, logits, y) -> torch.Tensor:
        """Return the unreduced (per-sample) cross-entropy loss."""
        return F.cross_entropy(logits, y, reduction="none")

    def training_step(self, batch, batch_idx):
        """Compute the mean cross-entropy loss and log MAPD data for the batch."""
        x, y = batch

        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.mapd_log(logits, y)

        return loss

    def configure_optimizers(self):
        """Build the SGD optimiser from the stored hyperparameters."""
        # momentum and weight_decay were previously accepted and saved as
        # hyperparameters but never handed to SGD, so the logged configuration
        # did not describe the optimiser actually used.
        optimizer = SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

        return {"optimizer": optimizer}

In [37]:
from torch.utils.data import random_split, DataLoader
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
    """Lightning data module for MNIST with an index-returning training set.

    Args:
        data_dir: root directory passed to ``MNIST``.
        batch_size: batch size used by every dataloader.
        num_workers: worker processes per dataloader; ``0`` loads in the main
            process.
        split_seed: seed of the generator used for the train/validation split,
            so the split is the same on every ``setup`` call and every run.
    """

    def __init__(self, data_dir: str = "data", batch_size: int = 32, num_workers=16, split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.num_workers = num_workers
        self.split_seed = split_seed

    def setup(self, stage: str):
        """Create the test/predict sets and a 55000/5000 train/validation split.

        ``stage`` is ignored: every dataset is rebuilt on each call.
        """
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        # A dedicated, freshly seeded generator keeps the split reproducible.
        # The indices emitted by IDXDataset are positions inside the training
        # subset, so they only identify the same samples across runs when the
        # split itself is deterministic.
        split_rng = torch.Generator().manual_seed(self.split_seed)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000], generator=split_rng)
        self.mnist_train = IDXDataset(self.mnist_train)

    def _eval_loader(self, dataset):
        """Build the non-shuffling loader shared by validation, test and predict."""
        # DataLoader rejects ``prefetch_factor`` when ``num_workers`` is 0, so
        # the option is only passed when worker processes are in use.
        extra = {"prefetch_factor": 8} if self.num_workers > 0 else {}
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, **extra)

    def train_dataloader(self):
        """Return the shuffling loader over the index-wrapped training subset."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the loader over the validation subset."""
        return self._eval_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the loader over the MNIST test set."""
        return self._eval_loader(self.mnist_test)

    def predict_dataloader(self):
        """Return the loader over the MNIST test set used for prediction."""
        return self._eval_loader(self.mnist_predict)

In [38]:
# Build the Lightning module and the data module for the proxy run.
# NOTE(review): ResNet18 keeps its default max_epochs=10 while the proxy trainer
# below runs for max_epochs=20 - confirm which value is intended.
module = ResNet18()
dm = MNISTDataModule(batch_size=512, num_workers=16)

torch.set_float32_matmul_precision('medium')

# One trainer per MAPD phase: the proxy run on GPU, the probe run on CPU.
# trainer_probes is constructed here but not used by any active line in this cell.
trainer_proxy = L.Trainer(accelerator="gpu", max_epochs=20)
trainer_probes = L.Trainer(accelerator="cpu", max_epochs=1)

# Proxy
# as_proxies() switches the module into proxy-logging mode and returns the module itself.
trainer_proxy.fit(module.as_proxies(), datamodule=dm)

# Probes
# TODO: disabled - no `as_probes` method is defined on MAPDModule in the code shown above.
#trainer_probes.fit(module.as_probes(), datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 21.8 K
-------------------------------
21.8 K    Trainable params
0         Non-trainable params
21.8 K    Total params
0.087     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
