In [7]:
import zarr
import numpy as np
from collections import defaultdict
from pathlib import Path

import torch
import torch.utils.data as torch_data

from omtra.utils.zarr_utils import list_zarr_arrays
from omtra.utils.graph import build_lookup_table
from omtra.utils import omtra_root

# create a single array, in memory

we're createing an in-memory zarr store but we could just as easily create a directory store and have the zarr array stored on disk

In [2]:
# basic example of creating and writing data to a zarr array
n_rows, n_feats = 1000, 100

z = zarr.create(shape=(n_rows, n_feats), 
                chunks=(100, n_feats), 
                dtype="f4", 
                store=zarr.storage.MemoryStore())

# Assign data to the array
z[:, :] = np.random.random((n_rows, n_feats))
z.info

Type               : Array
Zarr format        : 3
Data type          : DataType.float32
Shape              : (1000, 100)
Chunk shape        : (100, 100)
Order              : C
Read-only          : False
Store type         : MemoryStore
Filters            : ()
Serializer         : BytesCodec(endian=<Endian.little: 'little'>)
Compressors        : (ZstdCodec(level=0, checksum=False),)
No. bytes          : 400000 (390.6K)

# writing a single molecule to a zarr store

In [9]:
# generate fake data for a single molecule
n_atoms = 5
n_bonds = 3
positions = np.random.randn(n_atoms,3)
atom_types = np.zeros(n_atoms, dtype=int)
bond_orders = np.zeros(n_bonds, dtype=int)
edge_idxs = np.random.randint(0, n_atoms, size=(n_bonds, 2))

# Create a Zarr MemoryStore
# in practice we'll use a DirectoryStore but for this educational example we'll use a MemoryStore
store = zarr.storage.MemoryStore()

# Create a root group
root = zarr.group(store=store)

# Store tensors under different keys with specified chunk sizes
root.create_array('positions', shape=positions.shape, chunks=(1, 3), dtype=positions.dtype)
root.create_array('atom_types', shape=atom_types.shape, chunks=(1,), dtype=atom_types.dtype)
root.create_array('bond_orders', shape=bond_orders.shape, chunks=(1,), dtype=bond_orders.dtype)
root.create_array('edge_idxs', shape=edge_idxs.shape, chunks=(1, 2), dtype=edge_idxs.dtype)

# write data to the arrays
root['positions'][:] = positions
root['atom_types'][:] = atom_types
root['bond_orders'][:] = bond_orders
root['edge_idxs'][:] = edge_idxs

In [8]:
# Access data (same as before)
print("Positions:", root['positions'][:])
print("Atom Types:", root['atom_types'][:])
print("Bond Orders:", root['bond_orders'][:])
print("Edge Indices:", root['edge_idxs'][:])

# Check chunk sizes
print("Positions chunk size:", root['positions'].chunks)
print("Atom Types chunk size:", root['atom_types'].chunks)
print("Bond Orders chunk size:", root['bond_orders'].chunks)
print("Edge Indices chunk size:", root['edge_idxs'].chunks)

Positions: [[-1.08603681 -1.71000467  0.40702887]
 [ 1.51526509  1.26331849  0.31723167]
 [ 0.98958225  0.93055011 -0.25018692]
 [ 0.35657579 -0.56742346  0.38875173]
 [-0.26366849 -0.09798329 -0.60565436]]
Atom Types: [0 0 0 0 0]
Bond Orders: [0 0 0]
Edge Indices: [[0 2]
 [0 3]
 [2 1]]
Positions chunk size: (1, 3)
Atom Types chunk size: (1,)
Bond Orders chunk size: (1,)
Edge Indices chunk size: (1, 2)


# writing batches of molecules to a zarr array

## create a small "dataset" of molecules in tensor format

In [6]:
len(unbatched_molecules['x'])

500

In [8]:
n_molecules = 3000
unbatched_molecules = defaultdict(list)
for _ in range(n_molecules):
    n_atoms = np.random.randint(5, 15)
    n_edges = np.random.randint(1, int(n_atoms*0.4))
    n_pharm_nodes = np.random.randint(4, 8)
    x = np.random.randn(n_atoms, 3) # positions
    a = np.random.randint(0, 5, size=n_atoms) # atom types
    c = np.random.randint(0, 5, size=n_atoms) # atom charges
    edge_idxs = np.random.randint(0, n_atoms, size=(n_edges, 2)) # edge indicies for bonds
    e = np.random.randint(1, 4, size=n_edges) # bond orders

    x_pharm = np.random.randn(n_pharm_nodes, 3) # positions
    a_pharm = np.random.randint(0, 5, size=n_pharm_nodes) # atom types

    unbatched_molecules['x'].append(x)
    unbatched_molecules['a'].append(a)
    unbatched_molecules['c'].append(c)
    unbatched_molecules['edge_index'].append(edge_idxs)
    unbatched_molecules['e'].append(e)
    unbatched_molecules['x_pharm'].append(x_pharm)
    unbatched_molecules['a_pharm'].append(a_pharm)


# now batch the molecules together! there are a few steps here

# first we need to record the number of nodes and edges in each molecule
batch_num_nodes = [x.shape[0] for x in unbatched_molecules['x']]
batch_num_pharm_nodes = [x.shape[0] for x in unbatched_molecules['x_pharm']]
batch_num_edges = [eidxs.shape[0] for eidxs in unbatched_molecules['edge_index']]

# convert batch data to numpy arrays
batch_num_nodes = np.array(batch_num_nodes)
batch_num_edges = np.array(batch_num_edges)
batch_num_pharm_nodes = np.array(batch_num_pharm_nodes)

# concatenate all the data together
x = np.concatenate(unbatched_molecules['x'], axis=0)
a = np.concatenate(unbatched_molecules['a'], axis=0)
c = np.concatenate(unbatched_molecules['c'], axis=0)
x_pharm = np.concatenate(unbatched_molecules['x_pharm'], axis=0)
a_pharm = np.concatenate(unbatched_molecules['a_pharm'], axis=0)

edge_index = np.concatenate(unbatched_molecules['edge_index'], axis=0)
e = np.concatenate(unbatched_molecules['e'], axis=0)


# create an array of indicies to keep track of the start_idx and end_idx of each molecule's node features
node_lookup = build_lookup_table(batch_num_nodes)

# create an array of indicies to keep track of the start_idx and end_idx of each molecule's edge features
edge_lookup = build_lookup_table(batch_num_edges)

# create an array of indicies to keep track of the start_idx and end_idx of each molecule's pharmacophore node features
pharm_node_lookup = build_lookup_table(batch_num_pharm_nodes)

# print("batch_num_nodes:", batch_num_nodes)
# print("batch_num_edges:", batch_num_edges)
print("Shape of x:", x.shape)
print("Shape of a:", a.shape)
print("Shape of e:", e.shape)
print("Shape of c:", c.shape)
print("Shape of x_pharm:", x_pharm.shape)
print("Shape of a_pharm:", a_pharm.shape)
print("Shape of edge_index:", edge_index.shape)
print("Shape of node_lookup:", node_lookup.shape)
print("Shape of edge_lookup:", edge_lookup.shape)
print("Shape of pharm_node_lookup:", pharm_node_lookup.shape)

Shape of x: (28764, 3)
Shape of a: (28764,)
Shape of e: (5117,)
Shape of c: (28764,)
Shape of x_pharm: (16529, 3)
Shape of a_pharm: (16529,)
Shape of edge_index: (5117, 2)
Shape of node_lookup: (3000, 2)
Shape of edge_lookup: (3000, 2)
Shape of pharm_node_lookup: (3000, 2)


## write the molecule dataset to a zarr file

In [9]:
# store = zarr.storage.MemoryStore()
graphs_per_chunk = 500 # very important parameter

# store_path = Path(omtra_root()) / 'data' / 'pharmit_dev' / 'train.zarr'
store_path = Path(omtra_root()) / 'data' / 'pharmit_dev' / 'val.zarr'

store = zarr.storage.LocalStore(str(store_path))

# Create a root group
root = zarr.group(store=store)

ntypes = ['lig', 'pharm']

ntype_groups = {}
for ntype in ntypes:
    ntype_groups[ntype] = root.create_group(ntype)


lig_node = ntype_groups['lig'].create_group('node')
lig_edge_data = ntype_groups['lig'].create_group('edge')
pharm_node_data = ntype_groups['pharm'].create_group('node')

# Store tensors under different keys with specified chunk sizes

# some simple heuristics to decide chunk sizes for node and edge data
mean_lig_nodes_per_graph = int(np.mean(batch_num_nodes))
mean_ll_edges_per_graph = int(np.mean(batch_num_edges))
mean_pharm_nodes_per_graph = int(np.mean([x.shape[0] for x in unbatched_molecules['x_pharm']]))
nodes_per_chunk = graphs_per_chunk * mean_lig_nodes_per_graph
ll_edges_per_chunk = graphs_per_chunk * mean_ll_edges_per_graph
pharm_nodes_per_chunk = graphs_per_chunk * mean_pharm_nodes_per_graph

# create arrays for node data
lig_node.create_array('x', shape=x.shape, chunks=(nodes_per_chunk, 3), dtype=x.dtype)
lig_node.create_array('a', shape=a.shape, chunks=(nodes_per_chunk,), dtype=a.dtype)
lig_node.create_array('c', shape=c.shape, chunks=(nodes_per_chunk,), dtype=c.dtype)

# create arrays for pharmacophore node data
pharm_node_data.create_array('x', shape=x_pharm.shape, chunks=(pharm_nodes_per_chunk, 3), dtype=x_pharm.dtype)
pharm_node_data.create_array('a', shape=a_pharm.shape, chunks=(pharm_nodes_per_chunk,), dtype=a_pharm.dtype)
pharm_node_data.create_array('graph_lookup', shape=pharm_node_lookup.shape, chunks=pharm_node_lookup.shape, dtype=pharm_node_lookup.dtype)

# create arrays for edge data
lig_edge_data.create_array('e', shape=e.shape, chunks=(ll_edges_per_chunk,), dtype=e.dtype)
lig_edge_data.create_array('edge_index', shape=edge_index.shape, chunks=(ll_edges_per_chunk, 2), dtype=edge_index.dtype)

# because node_lookup and edge_lookup are relatively small, we may get away with not chunking them
lig_node.create_array('graph_lookup', shape=node_lookup.shape, chunks=node_lookup.shape, dtype=node_lookup.dtype)
lig_edge_data.create_array('graph_lookup', shape=edge_lookup.shape, chunks=edge_lookup.shape, dtype=edge_lookup.dtype)

# write data to the arrays
lig_node['x'][:] = x
lig_node['a'][:] = a
lig_node['c'][:] = c
lig_edge_data['e'][:] = e
lig_edge_data['edge_index'][:] = edge_index
lig_node['graph_lookup'][:] = node_lookup
lig_edge_data['graph_lookup'][:] = edge_lookup
pharm_node_data['x'][:] = x_pharm
pharm_node_data['a'][:] = a_pharm
pharm_node_data['graph_lookup'][:] = pharm_node_lookup

visualize the structure of the zarr store that we just created

In [6]:
root.tree()

In [5]:
root['node_data/a'].info

Type               : Array
Zarr format        : 3
Data type          : DataType.int64
Shape              : (893,)
Chunk shape        : (80,)
Order              : C
Read-only          : False
Store type         : MemoryStore
Filters            : ()
Serializer         : BytesCodec(endian=<Endian.little: 'little'>)
Compressors        : (ZstdCodec(level=0, checksum=False),)
No. bytes          : 7144 (7.0K)

# for conceptual purposes, a simple torch map-style dataset on top of the zarr store

In [6]:
class ZarrDataset(torch_data.Dataset):
    def __init__(self, zarr_store):
        self.zarr_store = zarr_store

        self.n_graphs = self.zarr_store['node_data/node_lookup'].shape[0]

    def __len__(self):
        return self.n_graphs
    
    def __getitem__(self, idx):

        # get node and edge data groups from zarr store
        node_data = self.zarr_store['node_data']
        edge_data = self.zarr_store['edge_data']

        # lookup start and end indicies for node and edge data to pull just
        # one graph from the full dataset
        node_start_idx, node_end_idx = node_data['node_lookup'][idx]
        edge_start_idx, edge_end_idx = edge_data['edge_lookup'][idx]

        # pull out the data for the graph
        x = node_data['x'][node_start_idx:node_end_idx]
        a = node_data['a'][node_start_idx:node_end_idx]
        e = edge_data['e'][edge_start_idx:edge_end_idx]
        edge_idxs = edge_data['edge_index'][edge_start_idx:edge_end_idx]

        # TODO: convert to DGL graph

        return x, a, e, edge_idxs
    
dataset = ZarrDataset(root)
dataset[0]

(array([[-1.71753563,  0.1028616 , -0.19695899],
        [-0.22229265, -0.21909488,  1.32207747],
        [-0.76388048,  1.21722057,  1.51023126],
        [-0.42169209, -0.6559013 ,  0.39215927],
        [ 0.44029963,  0.87215712,  0.15028246],
        [ 0.51902108, -1.85362516, -1.09245339],
        [-0.58521468,  1.32320007, -0.05868108]]),
 array([3, 1, 4, 1, 2, 3, 1]),
 array([2, 1, 1, 2, 1, 1, 0, 1, 2, 2, 2]),
 array([[5, 2],
        [1, 4],
        [6, 5],
        [2, 1],
        [1, 0],
        [0, 4],
        [5, 0],
        [5, 4],
        [0, 3],
        [2, 4],
        [1, 6]]))

coming up next:
 - [ ] add a dataloader with a custom sampler so that we align our batches with chunks in the zarr store
 - [ ] make an adaptive dataloader (sampler) that will create batches with a max num nodes or edges

# testing zarr dataset with cached chunk reads

In [1]:
from omtra.dataset.pharmit import PharmitDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = PharmitDataset('test_ligand_dataset.zarr')

In [3]:
dataset[("denovo_ligand", 0)]

Graph(num_nodes={'lig': 14, 'pharm': 6, 'prot_atom': 0, 'prot_res': 0},
      num_edges={('lig', 'lig_to_lig', 'lig'): 182, ('lig', 'lig_to_pharm', 'pharm'): 0, ('lig', 'lig_to_prot_atom', 'prot_atom'): 0, ('lig', 'lig_to_prot_res', 'prot_res'): 0, ('pharm', 'pharm_to_lig', 'lig'): 0, ('pharm', 'pharm_to_pharm', 'pharm'): 0, ('pharm', 'pharm_to_prot_atom', 'prot_atom'): 0, ('pharm', 'pharm_to_prot_res', 'prot_res'): 0, ('prot_atom', 'prot_atom_to_lig', 'lig'): 0, ('prot_atom', 'prot_atom_to_pharm', 'pharm'): 0, ('prot_atom', 'prot_atom_to_prot_atom', 'prot_atom'): 0, ('prot_atom', 'prot_atom_to_prot_res', 'prot_res'): 0, ('prot_res', 'prot_res_to_lig', 'lig'): 0, ('prot_res', 'prot_res_to_pharm', 'pharm'): 0, ('prot_res', 'prot_res_to_prot_atom', 'prot_atom'): 0, ('prot_res', 'prot_res_to_prot_res', 'prot_res'): 0},
      metagraph=[('lig', 'lig', 'lig_to_lig'), ('lig', 'pharm', 'lig_to_pharm'), ('lig', 'prot_atom', 'lig_to_prot_atom'), ('lig', 'prot_res', 'lig_to_prot_res'), ('pharm',