In [352]:
import numpy as np

import mapd
import torchvision
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds

In [353]:
MNIST_ROOT = "data"
torchvision.datasets.MNIST(root=MNIST_ROOT, download=True)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train

In [354]:
from torchvision.datasets import MNIST

In [355]:
from torch.utils.data import Dataset


class IDXDataset(Dataset):
    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index], index

In [356]:
from lightning.pytorch.trainer.states import TrainerFn
from copy import deepcopy
from abc import ABCMeta, abstractmethod

import torch
from lightning import LightningModule
from typing import Any, Optional, List, Dict
from torch import Tensor
from torch.utils.data import DataLoader


class MAPDModule(LightningModule, metaclass=ABCMeta):
    mapd_current_indices_: torch.Tensor = torch.empty(0, dtype=torch.int64)
    mapd_indices_: List[Tensor] = []

    mapd_losses_: List[Tensor] = []
    mapd_stages_: List[str] = []
    mapd_proxy_metrics_: List[Tensor] = []

    as_proxies_: bool = False
    as_probes_: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mapd_current_indices_ = torch.empty(0, dtype=torch.int64)
        self.mapd_losses_ = []
        self.mapd_stages_ = []
        self.mapd_proxy_metrics_ = []
        self.mapd_indices_ = []
        self.as_proxies_ = False
        self.as_probes_ = False
        self.is_val_probes_ = False

        self.probes_dataset = None

    @classmethod
    @abstractmethod
    def batch_loss(self, logits: Any, y: Any) -> Tensor:
        raise NotImplemented("batch_loss method not implemented")

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        batch, indices = batch
        self.mapd_current_indices_ = indices
        return batch

    def batch_proxy_metric(self, logits: Any, y: Any) -> Tensor:
        softmax = torch.softmax(logits, dim=1)

        # softmax confidence on correct class
        return softmax[torch.arange(softmax.shape[0]), y]

    def _mapd_log_proxy(self, logits: Any, y: Any):
        proxy_metrics = self.batch_proxy_metric(logits, y).detach()
        self.mapd_proxy_metrics_.append(proxy_metrics)

    def _mapd_log_probes(self, logits: Any, y: Any):
        batch_losses = self.batch_loss(logits, y).detach()
        self.mapd_losses_.append(batch_losses)
        self.mapd_stages_ += ["train" if self.training else "val"] * batch_losses.shape[0]

    def mapd_log(self, logits: Any, y: Any) -> "MAPDModule":
        if not self.trainer.sanity_checking:
            self.mapd_indices_.append(self.mapd_current_indices_)

            if self.as_proxies_:
                self._mapd_log_proxy(logits, y)
                return self

            if (self.training or (not self.training and self.is_val_probes_)):
                self._mapd_log_probes(logits, y)

        return self

    def _reset_mapd_attrs(self) -> "MAPDModule":
        self.mapd_current_indices_ = torch.empty(0, dtype=torch.int64)
        self.mapd_losses_ = []
        self.mapd_proxy_metrics_ = []
        self.mapd_indices_ = []
        self.mapd_stages_ = []

        return self

    def as_proxies(self) -> "MAPDModule":
        self.as_proxies_ = True
        self.as_probes_ = False

        return self

    def as_probes(self, probes_dataset: Dataset) -> "MAPDModule":
        self.as_probes_ = True
        self.as_proxies_ = False

        self.probes_dataset = deepcopy(probes_dataset)
        self.probes_dataset.only_probes = True

        return self

    def _write_proxies(self):
        sample_indices = torch.cat(self.mapd_indices_).cpu().numpy()
        sample_proxy_metrics = torch.cat(self.mapd_proxy_metrics_).cpu().numpy()
        epochs = np.full(sample_indices.shape, self.current_epoch)

        table = pa.table(
            [
                pa.array(sample_indices),
                pa.array(sample_proxy_metrics),
                pa.array(epochs),
            ],
            names=["sample_index", "proxy_metric", "epoch"],
        )

        ds.write_dataset(table, "proxies",
                         partitioning=ds.partitioning(pa.schema([("epoch", pa.int64())]), flavor="filename"),
                         existing_data_behavior="overwrite_or_ignore", format="parquet")

    def _write_probes(self):
        sample_indices = torch.cat(self.mapd_indices_).cpu().numpy()
        sample_losses = torch.cat(self.mapd_losses_).cpu().numpy()
        epochs = np.full(sample_indices.shape, self.current_epoch)

        table = pa.table(
            [
                pa.array(sample_indices),
                pa.array(sample_losses),
                pa.array(epochs),
                pa.array(self.mapd_stages_)
            ],
            names=["sample_index", "loss", "epoch", "stage"],
        )

        ds.write_dataset(table, "probes",
                         partitioning=ds.partitioning(pa.schema([("epoch", pa.int64()), ("stage", pa.string())]), flavor="filename"),
                         existing_data_behavior="overwrite_or_ignore", format="parquet")

    def on_train_epoch_end(self) -> None:
        if self.as_proxies_:
            self._write_proxies()

        if self.as_probes_:
            self._write_probes()

        self._reset_mapd_attrs()


    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        if self.as_proxies_:
            return

        self.is_val_probes_ = dataloader_idx == 0


In [357]:
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


model = Net()

In [358]:
import lightning as L
from torch.nn import functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch


class ResNet18(MAPDModule):
    def __init__(
            self,
            max_epochs: int = 10,
            lr: float = 0.05,
            momentum: float = 0.9,
            weight_decay: float = 0.0005
    ):
        super().__init__()
        self.model = model

        self.max_epochs = max_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.save_hyperparameters(ignore=["model"])

    def mapd_settings(self):
        return {}

    def forward(self, x):
        return self.model(x)

    def batch_loss(self, logits, y) -> torch.Tensor:
        return F.cross_entropy(logits, y, reduction="none")

    def training_step(self, batch, batch_idx):
        x, y = batch

        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.mapd_log(logits, y)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch

        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)

        if dataloader_idx == 0:
            self.mapd_log(logits, y)

        return loss

    def configure_optimizers(self):
        optimizer = SGD(
            self.parameters(),
            lr=self.lr
        )

        return {"optimizer": optimizer}

In [359]:
from torch.utils.data import random_split, DataLoader
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "data", batch_size: int = 32, num_workers=16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.num_workers = num_workers

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_train = IDXDataset(self.mnist_train)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=8)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=8)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=8)

In [360]:
module = ResNet18()
dm = MNISTDataModule(batch_size=512, num_workers=16)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_test = MNIST(MNIST_ROOT, train=False, transform=transform)
mnist_predict = MNIST(MNIST_ROOT, train=False, transform=transform)
mnist_full = MNIST(MNIST_ROOT, train=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])
mnist_train = IDXDataset(mnist_train)

dl = DataLoader(mnist_train, batch_size=512, shuffle=True, num_workers=16, prefetch_factor=8)

torch.set_float32_matmul_precision('medium')

trainer_proxy = L.Trainer(accelerator="gpu", max_epochs=100)
trainer_probes = L.Trainer(accelerator="gpu", max_epochs=100)

# Proxy
trainer_proxy.fit(module.as_proxies(), train_dataloaders=dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 21.8 K
-------------------------------
21.8 K    Trainable params
0         Non-trainable params
21.8 K    Total params
0.087     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [361]:
import os
from typing import Dict


class ProxyCalculator:
    def __init__(self, proxy_dataset_path: os.PathLike, proxy_name: str):
        self.proxy_df = None
        self.proxy_dataset_path = proxy_dataset_path
        self.proxy_dataset = ds.dataset(proxy_dataset_path, format="parquet", partitioning=ds.partitioning(pa.schema([("epoch", pa.int64())]), flavor="filename"))
        self.proxy_name = proxy_name

    def load(self, columns: list = None):
        self.proxy_df = self.proxy_dataset.to_table(columns=columns).to_pandas()
        self.proxy_df["epoch"] = self.proxy_df["epoch"].astype(int)

        return self.proxy_df

    def calculate_proxy_scores(self) -> Dict[int, float]:
        scores = self.proxy_df.groupby(["sample_index"]).agg({self.proxy_name: "sum"})

        scores[self.proxy_name] = (scores[self.proxy_name] - scores[self.proxy_name].min()) / (
            scores[self.proxy_name].max() - scores[self.proxy_name].min()
        )

        return scores[self.proxy_name].to_dict()


In [362]:
pc = ProxyCalculator("proxies", "proxy_metric")
pc.load()

pc.calculate_proxy_scores()

{0: 0.9850846529006958,
 1: 0.9776061177253723,
 2: 0.973325788974762,
 3: 0.8434043526649475,
 4: 0.8119612336158752,
 5: 0.9781427979469299,
 6: 0.9605704545974731,
 7: 0.9775375127792358,
 8: 0.8739024996757507,
 9: 0.9128909111022949,
 10: 0.9740433096885681,
 11: 0.9452816843986511,
 12: 0.9805165529251099,
 13: 0.8826597332954407,
 14: 0.9644473791122437,
 15: 0.9678345322608948,
 16: 0.961044430732727,
 17: 0.9716731309890747,
 18: 0.6625897288322449,
 19: 0.9695131778717041,
 20: 0.9626560211181641,
 21: 0.9147055149078369,
 22: 0.9414896368980408,
 23: 0.9732064604759216,
 24: 0.9828078150749207,
 25: 0.9794906973838806,
 26: 0.8485358357429504,
 27: 0.9805017709732056,
 28: 0.9790338277816772,
 29: 0.9798418879508972,
 30: 0.4323117733001709,
 31: 0.9609223008155823,
 32: 0.9870248436927795,
 33: 0.7986031770706177,
 34: 0.884515106678009,
 35: 0.9709193706512451,
 36: 0.8677396774291992,
 37: 0.956750214099884,
 38: 0.9426982998847961,
 39: 0.9036018252372742,
 40: 0.8129493

In [363]:
import os
import random
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch.utils.data import Dataset, Subset
from torchvision.transforms import transforms


class ProbeSuiteGenerator(Dataset):
    dataset: Dataset
    remaining_indices: list = []
    used_indices: list = []
    dataset_len: int
    label_count: int
    proxy_calculator: ProxyCalculator

    suites: Dict[int, Tuple[Tuple[torch.Tensor, int], int]] = {}
    index_to_suite: Dict[int, str] = {}

    only_probes: bool = False

    def __init__(
        self,
        dataset: Dataset,
        label_count: int,
        proxy_calculator: ProxyCalculator,
        num_probes: int = 500,
        corruption_module: Optional[Union[torch.nn.Module, transforms.Compose]] = None,
        only_probes: bool = False,
    ):
        self.dataset = dataset
        self.dataset_len = len(self.dataset)
        self.used_indices = []
        self.remaining_indices = list(range(self.dataset_len))
        self.label_count = label_count
        self.num_probes = num_probes

        self.corruption_module = corruption_module
        self.only_probes = only_probes

        self.proxy_calculator = proxy_calculator
        self.scores = proxy_calculator.calculate_proxy_scores()
        self.sorted_indices = list(
            dict(sorted(self.scores.items(), key=lambda x: x[1], reverse=True)).keys()
        )

        assert len(self.scores) == self.dataset_len

        self.suites = {}
        self.index_to_suite = {}

    def generate(self):
        self.generate_atypical()
        self.generate_typical()
        self.generate_random_outputs()
        self.generate_random_inputs_outputs()
        if self.corruption_module is not None:
            self.generate_corrupted()

        assert len(np.intersect1d(self.remaining_indices, self.used_indices)) == 0
        assert (
            len(np.unique(self.remaining_indices + self.used_indices))
            == self.dataset_len
        )
        assert (
            len(np.unique(self.remaining_indices)) + len(np.unique(self.used_indices))
            == self.dataset_len
        )

    def add_suite(
        self, name: str, suite: List[Tuple[torch.Tensor, int, int]]
    ) -> "ProbeSuiteGenerator":
        for (sample, target), idx in suite:
            self.index_to_suite[idx] = name
            self.suites[idx] = ((sample, target), idx)

        return self

    def generate_typical(self):
        subset = self.get_subset(indices=self.sorted_indices[: self.num_probes])
        suite = [((x, y), idx) for (x, y), idx in subset]

        self.add_suite("typical", suite)

    def generate_atypical(self):
        subset = self.get_subset(
            indices=self.sorted_indices[-self.num_probes :]  # noqa: E203
        )  # noqa: E203
        atypical = [((x, y), idx) for (x, y), idx in subset]

        self.add_suite("atypical", atypical)

    def generate_random_outputs(self):
        subset = self.get_subset()
        suite = [
            (
                (x, random.choice([i for i in range(self.label_count) if i != y])),
                idx,
            )
            for (x, y), idx in subset
        ]

        self.add_suite("random_outputs", suite)

    def generate_random_inputs_outputs(self):
        subset = self.get_subset()

        suite = [
            ((torch.rand_like(x), torch.randint(0, self.label_count, (1,)).item()), idx)
            for (x, y), idx in subset
        ]

        self.add_suite("random_inputs_outputs", suite)

    def generate_corrupted(self):
        subset = self.get_subset()

        suite = [
            ((self.corruption_module(x), y), idx)
            for (x, y), idx in subset
        ]

        self.add_suite("corrupted", suite)

    def get_subset(
        self,
        indices: Optional[list[int]] = None,
    ) -> Subset:
        if indices is None:
            subset_indices = np.random.choice(
                self.remaining_indices, self.num_probes, replace=False
            ).tolist()
        else:
            subset_indices = indices

        self.used_indices.extend(subset_indices)
        self.remaining_indices = [
            idx for idx in self.remaining_indices if idx not in subset_indices
        ]

        return Subset(self.dataset, subset_indices)

    def __getitem__(self, index) -> Tuple[Tuple[torch.Tensor, int], int]:
        if self.only_probes:
            keys = list(self.suites.keys())

            return self.suites[keys[index]]

        if index in self.used_indices:
            return self.suites[index]

        sample, target = self.dataset[index]

        return (sample, target), index

    def __len__(self):
        if self.only_probes:
            return len(self.index_to_suite)

        return self.dataset_len


def make_probe_suites(
    dataset: Dataset,
    label_count: int,
    proxy_calculator: ProxyCalculator,
    num_probes: int = 500
):
    probe_suite = ProbeSuiteGenerator(
        dataset,
        label_count,
        num_probes=num_probes,
        proxy_calculator=proxy_calculator,
    )
    probe_suite.generate()

    return probe_suite

In [364]:
dm.setup("train")
probe_suite_ds = make_probe_suites(dm.mnist_train, 10, pc, num_probes=500)

def make_dataloaders(validation_dataloaders: List[DataLoader], probe_suite_dataset: ProbeSuiteGenerator, dataloader_kwargs: Optional[dict] = None):
    default_dataloader_options = {
        "batch_size": 512,
        "num_workers": 1,
        "prefetch_factor": 1
    }

    if dataloader_kwargs is not None:
        default_dataloader_options.update(dataloader_kwargs)

    probe_suite_dataset = deepcopy(probe_suite_dataset)
    probe_suite_dataset.only_probes = True
    probe_suite_dataset.dataset = None

    return [DataLoader(probe_suite_dataset, **default_dataloader_options)] + validation_dataloaders

In [365]:
dl_probes = DataLoader(mnist_train, batch_size=512, shuffle=True, num_workers=16, prefetch_factor=8)

val_dataloader = DataLoader(IDXDataset(mnist_val), batch_size=512, shuffle=True, num_workers=16, prefetch_factor=8)
val_dataloaders = make_dataloaders([val_dataloader], probe_suite_ds)

trainer_probes.fit(module.as_probes(probe_suite_ds), train_dataloaders=dl_probes, val_dataloaders=val_dataloaders)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 21.8 K
-------------------------------
21.8 K    Trainable params
0         Non-trainable params
21.8 K    Total params
0.087     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [366]:
sum([len(l) for l in module.mapd_losses_])

0

In [367]:
probes = ds.dataset("probes", partitioning=ds.partitioning(pa.schema([("epoch", pa.int64()), ("stage", pa.string())]), flavor="filename"))

In [368]:
len(probes.to_table())

5700000