In [1]:
import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Dataset as pygDataset
from torch.utils.data import Dataset
from torch_geometric.data import Dataset as pygDataset
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
import torch as th
from dataclasses import dataclass
from typing import Dict, Tuple, List, Sequence
from dataclasses import dataclass
from torch_geometric.datasets import MNISTSuperpixels
from torch.nn.functional import pad

# Root folder handed to MNISTSuperpixels when the dataset is loaded.
MNIST_PATH = "../datasets/MNISTSuperpixel/"

# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



In [2]:
# Load MNISTSuperpixels from MNIST_PATH, move it to `device`, and keep only
# the first 10 graphs; the bare name on the last line displays the result.
dataset = MNISTSuperpixels(MNIST_PATH).to(device)[:10]
dataset

MNISTSuperpixels(10)

In [3]:
@dataclass
class DenseData:
    """Container for the dense representation of one graph or a batch of graphs.

    The three tensors are indexed and concatenated together, so they are
    expected to share their leading dimension(s).

    Attributes:
        x: Node features.
        adj: Dense adjacency tensor.
        mask: Node mask; in this notebook it is a boolean tensor that is True
            for real nodes and False for padding added by the collate function.
    """
    # `th.Tensor` is the tensor type; `th.tensor` is only the factory function.
    x: th.Tensor
    adj: th.Tensor
    mask: th.Tensor

    def __repr__(self):
        """Return a compact summary that shows shapes instead of tensor values."""
        return (
            f"DenseData("
            f"x={tuple(self.x.shape)}, "
            f"adj={tuple(self.adj.shape)}, "
            f"mask={tuple(self.mask.shape)})"
        )

    def __getitem__(self, index):
        """Apply the same `index` to `x`, `adj` and `mask`.

        Returns:
            A new DenseData built from the three indexed tensors.
        """
        return DenseData(
            self.x[index],
            self.adj[index],
            self.mask[index]
        )

    def __add__(self, other):
        """Concatenate two DenseData objects along dimension 0.

        Raises:
            ValueError: if `other` is not a DenseData instance.
        """
        if not isinstance(other, DenseData):
            raise ValueError("Both objects need to be of type DenseData")
        x_concat = th.cat((self.x, other.x), dim=0)
        adj_concat = th.cat((self.adj, other.adj), dim=0)
        mask_concat = th.cat((self.mask, other.mask), dim=0)
        return DenseData(x_concat, adj_concat, mask_concat)


class DenseGraphDataset(Dataset):
    """Map-style dataset holding a dense copy of every graph in a PyG dataset.

    All graphs are converted eagerly in `__init__`: each one becomes a
    `DenseData` (node features, dense adjacency, all-True node mask) stored in
    `self.data`, with its label stored at the same position in `self.targets`.
    """

    def __init__(self, pyg_dataset: pygDataset):
        """Convert every element of `pyg_dataset` to its dense form.

        Args:
            pyg_dataset: Iterable PyG dataset whose elements expose
                `to_dict()` with at least the keys 'x', 'edge_index' and 'y'
                ('edge_attr' is optional).
        """
        self._pyg_dataset = pyg_dataset

        self.data = []
        self.targets = []
        for el in self._pyg_dataset:
            el_dict = el.to_dict()
            x = el_dict.pop('x')
            num_nodes = x.shape[0]
            # Pass the node count explicitly: otherwise to_dense_adj infers it
            # from the largest index in edge_index, and trailing nodes without
            # edges would give an adjacency smaller than x and mask.
            adj = to_dense_adj(
                el_dict.pop('edge_index'),
                edge_attr=el_dict.pop('edge_attr', None),
                max_num_nodes=num_nodes
            ).squeeze(0)
            # No padding exists yet, so every node is marked as real.
            mask = th.ones(
                num_nodes,
                device=x.device,
                dtype=th.bool
            )

            y = el_dict.pop('y')

            # Concatenate the remaining attributes onto x along dim 1, in
            # sorted key order so the column layout is deterministic.
            for k in sorted(el_dict.keys()):
                x = th.cat((x, el_dict[k]), dim=1)

            self.data.append(
                DenseData(
                    x,
                    adj,
                    mask)
            )
            self.targets.append(y)

    def __len__(self):
        """Return the number of stored graphs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the `(DenseData, target)` pair stored at position `idx`."""
        return self.data[idx], self.targets[idx]

In [54]:
# Convert the 10-graph PyG subset to its dense form.
dense_dataset = DenseGraphDataset(dataset)
# NOTE: despite its name, `loader` here is a single (DenseData, target) pair,
# not a DataLoader; the name is rebound to a real DataLoader in a later cell.
loader = next(iter(dense_dataset))
print(loader)

(DenseData(x=(75, 3), adj=(75, 75), mask=(75,)), tensor([5], device='cuda:0'))


In [5]:
def dense_collate_fn(batch: List[Tuple[DenseData, th.Tensor]]) -> Tuple[DenseData, th.Tensor]:
    """Collate `(DenseData, target)` pairs into one zero-padded dense batch.

    Every graph is padded along its node dimension(s) up to the largest node
    count found in `batch`, then the graphs are stacked on a new leading axis.

    Args:
        batch: Samples as returned by the dataset's `__getitem__`.

    Returns:
        A `(DenseData, targets)` pair where each DenseData field has a new
        leading batch axis and `targets` is the per-sample targets
        concatenated along dim 0.
    """
    widest = max(sample[0].x.shape[0] for sample in batch)

    padded_x, padded_adj, padded_mask, targets = [], [], [], []
    for graph, target in batch:
        missing = widest - graph.x.shape[0]
        # `pad` reads its spec from the LAST dimension backwards, so leading
        # (0, 0) pairs leave trailing feature dims untouched and the final
        # pair(s) extend the node dimension(s).
        x_spec = (0, 0) * (graph.x.ndim - 1) + (0, missing)
        adj_spec = (0, 0) * (graph.adj.ndim - 2) + (0, missing, 0, missing)

        padded_x.append(pad(graph.x, x_spec))
        padded_adj.append(pad(graph.adj, adj_spec))
        padded_mask.append(pad(graph.mask, (0, missing)))
        targets.append(target)

    dense_batch = DenseData(
        th.stack(padded_x, dim=0),
        th.stack(padded_adj, dim=0),
        th.stack(padded_mask, dim=0),
    )
    return dense_batch, th.cat(targets, dim=0)


In [6]:
# Batch the dense dataset 5 graphs at a time with the padding collate function.
loader = DataLoader(dense_dataset, batch_size=5, collate_fn=dense_collate_fn)
# Pull the first batch and print the shapes of its fields and of the targets.
a, y = next(iter(loader))
print(f"{a},y={tuple(y.shape)}")

DenseData(x=(5, 75, 3), adj=(5, 75, 75), mask=(5, 75)),y=(5,)


In [7]:
# from DataModules import MNISTSuperpixelDataModule

# data_module = MNISTSuperpixelDataModule(MNIST_PATH,batch_size=20)
# data_module.setup("fit")
# a = data_module.train_dataloader()

In [8]:
# Dense version of the first 10 training graphs.
mnist_full = DenseGraphDataset(MNISTSuperpixels(MNIST_PATH, train=True)[:10])

# Keep only the graphs labelled 0 or 1.
# `mnist_full[1]` is a single (DenseData, target) sample, so comparing it with
# an int never selects anything; the labels live in `mnist_full.targets`.
to_take = [int(target) in (0, 1) for target in mnist_full.targets]
# A plain list of (DenseData, target) pairs is a valid map-style dataset for
# DataLoader and, unlike DenseGraphDataset, can be sliced for the split below.
mnist_binary = [mnist_full[i] for i, keep in enumerate(to_take) if keep]

# 90% / 10% train / validation split.
split = int(len(mnist_binary) * 0.9)
mnist_train = mnist_binary[:split]
mnist_val = mnist_binary[split:]

# NOTE(review): with drop_last=True and batch_size=5, a training split with
# fewer than 5 graphs yields no batches; this is likely with a 10-graph sample.
train = DataLoader(mnist_train, batch_size=5, drop_last=True, shuffle=True, pin_memory=True,
                   collate_fn=dense_collate_fn)



In [9]:
from typing import Tuple, List, Sequence
import torch as th
from torch_geometric.data import Dataset as pygDataset
from torch.utils.data import DataLoader, Dataset
from torch_geometric.utils import to_dense_adj
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import pad

# TODO: consider storing the dense representation on disk.
# TODO: it works only with a single feature channel (both on nodes and edges) -> should be true for most datasets
# TODO: find a way to access pygDataset attributes without breaking the multiprocessing dataloader (__get_attribute__ fails)
#  Probably the best practice is to write them explicitly


def dense_collate_fn_old(batch: List[Tuple[Sequence[th.Tensor], th.Tensor]]) -> Tuple[Sequence[th.Tensor], th.Tensor]:
    """Collate tuple-based dense graph samples into one zero-padded batch.

    Each sample is `(attrs, y)` where `attrs` is a sequence laid out as
    `(node_features, dense_adj, node_mask, *extra_attributes)`. Node-level
    tensors are zero-padded up to the largest node count in the batch and get
    a new leading batch axis; the results are concatenated along dim 0.

    Args:
        batch: List of `(attrs, y)` samples. The number of attributes is taken
            from the first sample.

    Returns:
        `(attrs_batch, y_batch)`, where `attrs_batch` is a tuple with one
        batched tensor per attribute position and `y_batch` is the targets
        concatenated along dim 0.

    Raises:
        ValueError: if an extra attribute has 3 or more dimensions (and a
            first dimension different from 1).
    """
    num_attr = len(batch[0][0])
    max_num_nodes = max(el[0][0].shape[0] for el in batch)
    zipped_padded_batch = [[] for _ in range(num_attr + 1)]  # +1 for y
    for x, y in batch:
        n_nodes = x[0].shape[0]
        missing = max_num_nodes - n_nodes
        for i, attr in enumerate(x):
            n_dim = attr.ndim
            # `pad` reads its spec from the last dimension backwards: leading
            # (0, 0) pairs skip trailing feature dims, the final pair(s) extend
            # the node dimension(s).
            if i == 0:
                # node features x: pad the first dimension
                padded_attr = pad(attr, (n_dim - 1) * (0, 0) + (0, missing)).unsqueeze(0)
            elif i == 1:
                # dense adjacency: pad the first two dimensions
                padded_attr = pad(attr, (n_dim - 2) * (0, 0) + 2 * (0, missing)).unsqueeze(0)
            elif i == 2:
                # node mask: pad its first (and only) dimension
                padded_attr = pad(attr, (0, missing)).unsqueeze(0)
            elif n_dim == 0:
                # 0-dim scalar: th.cat rejects zero-dimensional tensors, so
                # give it a length-1 axis to concatenate along.
                padded_attr = attr.reshape(1)
            elif attr.shape[0] == 1:
                # Already carries a length-1 leading axis: treated as a
                # per-graph value, no padding needed.
                # NOTE(review): a node attribute of a single-node graph also
                # lands here — confirm such graphs cannot occur.
                padded_attr = attr
            elif n_dim < 3:
                # node attribute: pad the first dimension
                padded_attr = pad(attr, (n_dim - 1) * (0, 0) + (0, missing)).unsqueeze(0)
            else:
                raise ValueError(f'Unsupported attribute with shape {attr.shape}')
            zipped_padded_batch[i].append(padded_attr)
        zipped_padded_batch[-1].append(y)  # for y

    stacked_batch = [th.cat(column, dim=0) for column in zipped_padded_batch]
    return tuple(stacked_batch[:-1]), stacked_batch[-1]


class DenseGraphDataset_old(Dataset):

    """
    Wrapper around a pygDataset that stores every graph in dense form.

    Each stored sample is a tuple laid out as
    (node_features, dense_adj, node_mask, *remaining_attributes), with the
    remaining attributes ordered by sorted key; the label of each graph is
    kept at the same position in `self.targets`.
    """

    def __init__(self, pyg_dataset: pygDataset):
        """Convert every element of `pyg_dataset` to its dense tuple form.

        Args:
            pyg_dataset: Iterable PyG dataset whose elements expose
                `to_dict()` with at least the keys 'x', 'edge_index' and 'y'
                ('edge_attr' is optional).
        """
        self._pyg_dataset = pyg_dataset

        self.data = []
        self.targets = []
        for el in self._pyg_dataset:
            el_dict = el.to_dict()
            x = el_dict.pop('x')
            n_nodes = x.shape[0]
            # Pass the node count explicitly: otherwise to_dense_adj infers it
            # from the largest index in edge_index, and trailing nodes without
            # edges would give an adjacency smaller than x and the mask.
            adj = to_dense_adj(
                el_dict.pop('edge_index'),
                edge_attr=el_dict.pop('edge_attr', None),
                max_num_nodes=n_nodes
            ).squeeze(0)
            # No padding exists yet, so the mask marks every node as real.
            x_tuple = [x, adj, th.ones(n_nodes, device=x.device, dtype=th.bool)]
            y = el_dict.pop('y')

            # Append whatever attributes are left, in sorted key order.
            for k in sorted(el_dict.keys()):
                x_tuple.append(el_dict[k])

            self.data.append(tuple(x_tuple))
            self.targets.append(y)

    def __len__(self):
        """Return the number of stored graphs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the `(attribute_tuple, target)` pair stored at position `idx`."""
        return self.data[idx], self.targets[idx]

In [10]:
# Load the training split, move it to `device`, keep the first 20 graphs and
# wrap them with the older tuple-based dense dataset.
mnist_full = MNISTSuperpixels(MNIST_PATH, train=True).to(device)[:20]
data = DenseGraphDataset_old(mnist_full)

# Batch them 5 at a time with the matching tuple-based collate function.
loader = DataLoader(data, batch_size=5, collate_fn=dense_collate_fn_old)