In [16]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch, Dataset
from torch_geometric.datasets import MNISTSuperpixels
import torch
from typing import Any, Iterator, Tuple
from torch_geometric.utils import to_dense_batch, to_dense_adj
from dataclasses import dataclass


@dataclass(frozen=True)
class DenseData:
    """Padded, dense tensors describing a batch of graphs.

    In this notebook ``densify`` builds instances from the outputs of
    ``to_dense_batch`` (``x``, ``mask``) and ``to_dense_adj`` (``adj``).
    """

    # Padded node features — presumably (batch, max_nodes, features); confirm.
    x: torch.Tensor
    # Dense adjacency — presumably (batch, max_nodes, max_nodes); confirm.
    adj: torch.Tensor
    # Real-node (non-padding) mask — presumably (batch, max_nodes); confirm.
    mask: torch.Tensor

    def __repr__(self):
        # Report shapes only; dumping full tensors would flood notebook output.
        field_names = ("x", "adj", "mask")
        rendered = ", ".join(
            f"{name}={tuple(getattr(self, name).shape)}" for name in field_names
        )
        return f"DenseData({rendered})"


@dataclass(frozen=True)
class DenseElement:
    """A densified graph batch paired with its targets.

    ``densify`` produces these from a PyG ``Batch``: ``data`` holds the dense
    tensors and ``y`` is the batch's ``y`` attribute, passed through unchanged.
    """

    # Dense node features, adjacency and padding mask for the batch.
    data: DenseData
    # Targets copied from the source batch's ``y``.
    y: torch.Tensor

    def __repr__(self):
        # Shapes only, so printing a batch stays compact.
        return f"DenseElement({self.data!r}, y={tuple(self.y.shape)})"


# Backward-compatible alias: the class was first published under this
# misspelled name, and existing code/annotations still refer to it.
DenseElemet = DenseElement


def densify(data: Batch) -> DenseElemet:
    """Convert a sparse PyG ``Batch`` into padded dense tensors.

    Node features and node positions are concatenated along the feature
    dimension and padded per graph with ``to_dense_batch``. The edge list is
    expanded into a dense adjacency tensor with ``to_dense_adj``.

    Args:
        data: a collated mini-batch; it must provide ``x``, ``pos``,
            ``edge_index``, ``batch`` and ``y``.

    Returns:
        A ``DenseElemet`` wrapping ``DenseData(x, adj, mask)`` together with
        ``data.y``.
    """
    x, mask = to_dense_batch(torch.cat((data.x, data.pos), dim=1), data.batch)
    # Pin the adjacency's node dimension to the padded size of ``x`` so the
    # alignment of ``adj`` with ``x``/``mask`` is explicit rather than relying
    # on both helpers inferring the same maximum graph size.
    adj = to_dense_adj(data.edge_index, data.batch, max_num_nodes=x.size(1))
    return DenseElemet(DenseData(x, adj, mask), data.y)


class DenseMNISTDataLoader(DataLoader):
    """PyG ``DataLoader`` whose iteration yields densified batches.

    Each mini-batch produced by the parent loader is passed through
    ``densify`` before being returned, so callers receive ``DenseElemet``
    objects rather than sparse ``Batch`` objects.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Any = None,
        exclude_keys: Any = None,
        **kwargs,
    ):
        # A caller-supplied ``collate_fn`` is discarded; batching is left to
        # the parent PyG loader.
        if "collate_fn" in kwargs:
            del kwargs["collate_fn"]

        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
            **kwargs,
        )

    def __iter__(self) -> Iterator[DenseElemet]:
        # Lazily map ``densify`` over the parent's stream of sparse batches.
        yield from map(densify, super().__iter__())


# Choose the accelerator once; the dataset is moved there before loading.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MNIST_PATH = "../datasets/MNISTSuperpixel"

# NOTE(review): this object is a PyG dataset, not a data module; the name is
# later rebound to ``MNISTSuperpixelDataModule`` — consider renaming.
data_module = MNISTSuperpixels(MNIST_PATH)
data_module.to(device)

# Smoke test: fetch a single dense batch and show its tensor shapes.
loader = DenseMNISTDataLoader(data_module, batch_size=5)
batch = next(iter(loader))
print(batch)

DenseElement(DenseData(x=(5, 75, 3), adj=(5, 75, 75), mask=(5, 75)), y=(5,))


In [3]:
from DataModules import MNISTSuperpixelDataModule

# Build the project data module and prepare it for the "fit" stage.
MNIST_PATH = "../datasets/MNISTSuperpixel"
data_module = MNISTSuperpixelDataModule(MNIST_PATH)
data_module.setup("fit")

# Pull one batch from the training loader for inspection below.
loader = data_module.train_dataloader()
train_iter = iter(loader)
a = next(train_iter)

In [11]:
# Unpack the dense tensors of the inspected batch (same access order as before).
x, adj, mask = (getattr(a.data, field) for field in ("x", "adj", "mask"))

In [19]:
# Presumably (batch, max_nodes, features) — confirm against the data module.
x.size()

torch.Size([32, 75, 3])