In [12]:
import zarr
import numpy as np
import math, random
from collections import defaultdict

import torch
import torch.utils.data as torch_data
import numpy as np


import numcodecs  # for Pickle or JSON


In [2]:
# Create some fake data

In [13]:
import numpy as np
from collections import defaultdict

n_molecules = 100

unbatched_molecules = defaultdict(list)
for _ in range(n_molecules):
    n_atoms = np.random.randint(5, 15)
    n_edges = np.random.randint(n_atoms // 2, n_atoms * 2)
    
    x = np.random.randn(n_atoms, 3)              # positions
    a = np.random.randint(0, 5, size=n_atoms)    # atom types
    edge_idxs = np.random.randint(0, n_atoms, size=(n_edges, 2))
    e = np.random.randint(0, 3, size=n_edges)    # bond orders

    unbatched_molecules['x'].append(x)
    unbatched_molecules['a'].append(a)
    unbatched_molecules['edge_index'].append(edge_idxs)
    unbatched_molecules['e'].append(e)


In [14]:
# Number of graphs
n_graphs = len(unbatched_molecules['x'])

# Count the number of nodes/edges per graph
batch_num_nodes = [arr.shape[0] for arr in unbatched_molecules['x']]
batch_num_edges = [arr.shape[0] for arr in unbatched_molecules['edge_index']]

batch_num_nodes = np.array(batch_num_nodes, dtype=np.int64)
batch_num_edges = np.array(batch_num_edges, dtype=np.int64)

# Concatenate all node data
x = np.concatenate(unbatched_molecules['x'], axis=0)  # shape (total_nodes, 3)
a = np.concatenate(unbatched_molecules['a'], axis=0)  # shape (total_nodes,)

# Concatenate all edge data
edge_index = np.concatenate(unbatched_molecules['edge_index'], axis=0)  # shape (total_edges, 2)
e = np.concatenate(unbatched_molecules['e'], axis=0)                    # shape (total_edges,)

# Build lookup for each graph's node range
node_lookup = np.zeros((n_graphs, 2), dtype=np.int64)
node_lookup[1:, 0] = np.cumsum(batch_num_nodes[:-1])
node_lookup[:, 1] = np.cumsum(batch_num_nodes)

# Build lookup for each graph's edge range
edge_lookup = np.zeros((n_graphs, 2), dtype=np.int64)
edge_lookup[1:, 0] = np.cumsum(batch_num_edges[:-1])
edge_lookup[:, 1] = np.cumsum(batch_num_edges)


In [16]:
import zarr

store = zarr.storage.MemoryStore()
root = zarr.group(store=store)

node_data_group = root.create_group('node_data')
edge_data_group = root.create_group('edge_data')

# Let's pick chunk sizes based on ~10 graphs worth of nodes/edges
graphs_per_chunk = 10
mean_nodes_per_graph = int(batch_num_nodes.mean())
mean_edges_per_graph = int(batch_num_edges.mean())

nodes_per_chunk = graphs_per_chunk * mean_nodes_per_graph
edges_per_chunk = graphs_per_chunk * mean_edges_per_graph

# Create node-level arrays (empty initially)
ds_x = node_data_group.create_dataset(
    'x',
    shape=x.shape,                # e.g. (total_nodes, 3)
    dtype=x.dtype,
    chunks=(nodes_per_chunk, x.shape[1])  # e.g. (some_number, 3)
)
ds_a = node_data_group.create_dataset(
    'a',
    shape=a.shape,                # e.g. (total_nodes,)
    dtype=a.dtype,
    chunks=(nodes_per_chunk,)
)
ds_node_lookup = node_data_group.create_dataset(
    'node_lookup',
    shape=node_lookup.shape,      # e.g. (n_graphs, 2)
    dtype=node_lookup.dtype,
    chunks=node_lookup.shape      # small enough, can store in 1 chunk
)

# Create edge-level arrays (empty initially)
ds_edge_index = edge_data_group.create_dataset(
    'edge_index',
    shape=edge_index.shape,       # e.g. (total_edges, 2)
    dtype=edge_index.dtype,
    chunks=(edges_per_chunk, edge_index.shape[1])
)
ds_e = edge_data_group.create_dataset(
    'e',
    shape=e.shape,                # e.g. (total_edges,)
    dtype=e.dtype,
    chunks=(edges_per_chunk,)
)
ds_edge_lookup = edge_data_group.create_dataset(
    'edge_lookup',
    shape=edge_lookup.shape,      # e.g. (n_graphs, 2)
    dtype=edge_lookup.dtype,
    chunks=edge_lookup.shape
)

# Now write the data to each dataset
ds_x[...] = x
ds_a[...] = a
ds_node_lookup[...] = node_lookup

ds_edge_index[...] = edge_index
ds_e[...] = e
ds_edge_lookup[...] = edge_lookup


  ds_x = node_data_group.create_dataset(
  ds_a = node_data_group.create_dataset(
  ds_node_lookup = node_data_group.create_dataset(
  ds_edge_index = edge_data_group.create_dataset(
  ds_e = edge_data_group.create_dataset(
  ds_edge_lookup = edge_data_group.create_dataset(


In [17]:
import functools

class MyStorage:
    """
    Single storage class for chunked numeric arrays (x, a, e, edge_index).
    Provides chunk-level caching, so repeated reads in the same chunk
    do not cause repeated disk I/O.
    """
    def __init__(self, root_group):
        node_grp = root_group['node_data']
        edge_grp = root_group['edge_data']
        self.zarr_arrays = {
            'x': node_grp['x'],
            'a': node_grp['a'],
            'e': edge_grp['e'],
            'edge_index': edge_grp['edge_index'],
        }

        # store chunk sizes
        self.chunk_sizes = {
            name: arr.chunks[0] for name, arr in self.zarr_arrays.items()
        }

    def load_item(self, array_name, i):
        """
        Return the single row i of array array_name (axis=0).
        """
        chunk_size = self.chunk_sizes[array_name]
        chunk_num, chunk_idx = divmod(i, chunk_size)
        chunk_data = self._load_chunk(array_name, chunk_num)
        return chunk_data[chunk_idx]

    @functools.lru_cache(None)
    def _load_chunk(self, array_name, chunk_num):
        arr = self.zarr_arrays[array_name]
        chunk_size = self.chunk_sizes[array_name]
        start = chunk_num * chunk_size
        end = start + chunk_size
        return arr[start:end]


In [18]:
import torch
from torch.utils.data import Dataset

class ZarrDataset(Dataset):
    """
    Returns (x, a, e, edge_index) for graph `idx`.
    Uses a single MyStorage for chunk-cached reads.
    """
    def __init__(self, root_group):
        node_grp = root_group['node_data']
        edge_grp = root_group['edge_data']

        # lookups are small, read into memory
        self.node_lookup = node_grp['node_lookup'][:]
        self.edge_lookup = edge_grp['edge_lookup'][:]
        self.n_graphs = self.node_lookup.shape[0]

        self.storage = MyStorage(root_group)

    def __len__(self):
        return self.n_graphs

    def __getitem__(self, idx):
        # Node range
        node_start, node_end = self.node_lookup[idx]
        # Edge range
        edge_start, edge_end = self.edge_lookup[idx]

        # read node-level rows
        x_list = [self.storage.load_item('x', i) for i in range(node_start, node_end)]
        a_list = [self.storage.load_item('a', i) for i in range(node_start, node_end)]

        # read edge-level rows
        e_list = [self.storage.load_item('e', j) for j in range(edge_start, edge_end)]
        edge_idx_list = [self.storage.load_item('edge_index', j) for j in range(edge_start, edge_end)]

        # convert to arrays or torch tensors
        x = np.vstack(x_list).astype(np.float32)      # shape (num_nodes, 3)
        a = np.array(a_list, dtype=np.int64)          # shape (num_nodes,)
        e = np.array(e_list, dtype=np.int64)          # shape (num_edges,)
        edge_index = np.vstack(edge_idx_list).astype(np.int64)  # shape (num_edges, 2)

        return (x, a, e, edge_index)


In [22]:

import math, random
from torch.utils.data import Sampler

class TwoLevelSampler(Sampler[list[int]]):
    """
    Sampler that:
      1) Splits the dataset's [0..n_graphs-1] indices into chunks of size `chunk_size`.
      2) Randomly shuffles the chunk order (outer random).
      3) For each chunk, optionally shuffle the item order inside it (inner random).
      4) Yields *mini-batches* of size `mini_batch_size` from that chunk.
    """
    def __init__(
        self,
        data_source: Dataset,
        chunk_size: int,
        mini_batch_size: int,
        shuffle_chunks: bool = True,
        shuffle_within_chunk: bool = True
    ):
        super().__init__(data_source)
        self.n_graphs = len(data_source)
        self.chunk_size = chunk_size
        self.mini_batch_size = mini_batch_size
        self.shuffle_chunks = shuffle_chunks
        self.shuffle_within_chunk = shuffle_within_chunk

        self.num_chunks = math.ceil(self.n_graphs / self.chunk_size)

        # Pre-build a list of chunks, each chunk is a list of graph indices
        self.chunks = []
        start = 0
        for _ in range(self.num_chunks):
            end = min(start + self.chunk_size, self.n_graphs)
            self.chunks.append(list(range(start, end)))
            start = end

    def __iter__(self):
        # Shuffle the chunk order if requested
        chunk_indices = list(range(self.num_chunks))
        if self.shuffle_chunks:
            random.shuffle(chunk_indices)

        # For each chunk in random order
        for cidx in chunk_indices:
            indices_in_chunk = self.chunks[cidx]

            # Shuffle inside chunk if requested
            if self.shuffle_within_chunk:
                random.shuffle(indices_in_chunk)

            # Now subdivide this chunk's indices into mini-batches
            for start_i in range(0, len(indices_in_chunk), self.mini_batch_size):
                mini_batch = indices_in_chunk[start_i : start_i + self.mini_batch_size]
                yield mini_batch

    def __len__(self):
        """
        This is the total number of mini-batches (not the number of chunks).
        """
        total_mb = 0
        for chunk_list in self.chunks:
            chunk_len = len(chunk_list)
            total_mb += math.ceil(chunk_len / self.mini_batch_size)
        return total_mb



In [24]:
from torch.utils.data import DataLoader

def graph_collate(batch_list):
    """
    batch_list is a list of (x, a, e, edge_index).
    Return as-is or do more sophisticated collation.
    """
    return batch_list

dataset = ZarrDataset(root)

sampler = ChunkSampler(dataset, chunk_size=10, shuffle=True)

loader = DataLoader(
    dataset,
    batch_sampler=sampler,
    collate_fn=graph_collate,
    num_workers=0,  # or more if desired
    pin_memory=False
)

for batch_idx, batch_graphs in enumerate(loader):
    print(f"Batch {batch_idx}: {len(batch_graphs)} graphs")
    for (x, a, e, edge_index) in batch_graphs:
        # x.shape => (num_nodes, 3)
        # a.shape => (num_nodes,)
        # e.shape => (num_edges,)
        # edge_index.shape => (num_edges, 2)
        pass


Batch 0: 10 graphs
Batch 1: 10 graphs
Batch 2: 10 graphs
Batch 3: 10 graphs
Batch 4: 10 graphs
Batch 5: 10 graphs
Batch 6: 10 graphs
Batch 7: 10 graphs
Batch 8: 10 graphs
Batch 9: 10 graphs




In [None]:
TODO: Check if yeild actually reduces memory fetches
Does having seperate stores mess this idea up?