In [1]:
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')

In [2]:
n = 20

from cdv.dataset import load_file
from cdv.utils import load_pytree
from cdv.databatch import CrystalGraphs
from flax.serialization import from_state_dict

raw_batches: list[CrystalGraphs] = []

for i in range(n):    
    fn = f'/home/nmiklaucic/cdv/precomputed/mptrj/batches/group_0000/{i:05}.mpk'
    raw_batches.append(load_pytree(fn))

jax.block_until_ready(raw_batches);

In [3]:
batches = [from_state_dict(CrystalGraphs.new_empty(1024, 16, 32), state) for state in raw_batches]
jax.block_until_ready(batches);

In [4]:
from collections.abc import Mapping
import zarr

batch = raw_batches[0]

def node_to_zarr(state, n=None, store=None):
    if n is None:
        n = zarr.group(store=store, overwrite=True)
    if isinstance(state, Mapping):        
        for k, v in state.items():
            if isinstance(v, Mapping):
                child = n.create_group(k)
                node_to_zarr(v, child)
            else:
                n.array(k, v)

    return n


zb = node_to_zarr(batch)
zb.tree()

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, name='edges', nodes=(Node(disabled=True, …

In [11]:
!rm -rf /tmp/tensortest/*

In [17]:
base_store = '/tmp/tensortest'

zarr_version = 2

for i, batch in enumerate(raw_batches):
    store = zarr.ZipStore(f'{base_store}/{i}.zip')
    node_to_zarr(batch, store=store)    
    store.close()    

NotImplementedError: 

In [18]:
opened_batches = []

for i, batch in enumerate(raw_batches):
    store = f'{base_store}/{i}.zip'     
    zb = zarr.open(store, mode='r', zarr_version=zarr_version)
    opened_batches.append(zb)

data = jax.tree_map(lambda x: None, raw_batches[0])

def set_path(data, path, value):
    if not hasattr(value, 'get_basic_selection'):
        return
    if isinstance(path, str):
        path = path.split('/')
    if len(path) == 1:        
        data[path[0]] = value.get_basic_selection()
    else:
        set_path(data[path[0]], path[1:], value)

redone_batches = []
for zb in opened_batches:
    data = jax.tree_map(lambda x: None, raw_batches[0])
    zb.visititems(lambda n, v, data=data: set_path(data, n, v))
    redone_batches.append(data)

jax.tree.map(lambda x, y: x.astype(float) - y.astype(float), redone_batches[0], raw_batches[0])

{'edges': {'receiver': array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]),
  'to_jimage': array([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         ...,
  
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
  
         [[0., 0., 0.],
          [0., 0.