In [1]:
import h5py, os
import numpy as np

In [None]:
def combine_h5_files(path1, path2, output_path, concat_axes=None):
    """Merge two per-graph HDF5 simulation files into a single output file.

    Both inputs are expected to contain top-level groups named
    ``graph_0`` ... ``graph_{N-1}``. For every graph index, each dataset named
    in ``concat_axes`` is read from both inputs, concatenated (file 1 first,
    then file 2) along its configured axis, and written under the same
    group/dataset name in ``output_path``.

    Args:
        path1: First input HDF5 file; its data comes first in every result.
        path2: Second input HDF5 file.
        output_path: Output HDF5 file, opened with mode ``'w'`` (created or
            truncated).
        concat_axes: Optional mapping ``{dataset_name: axis}``. Defaults to
            ``liq_position`` and ``mesh_position`` on axis 1, and
            ``mesh_pose`` and ``particle_types`` on axis 0.

    Raises:
        ValueError: If the two inputs hold a different number of top-level
            groups. The check runs before ``output_path`` is opened, so an
            existing output file is not clobbered.
        KeyError: If an expected ``graph_{i}`` group or dataset is missing.
    """
    if concat_axes is None:
        concat_axes = {
            'liq_position': 1,
            'mesh_pose': 0,
            'mesh_position': 1,
            'particle_types': 0,
        }

    with h5py.File(path1, 'r') as hf1, h5py.File(path2, 'r') as hf2:
        num_graphs = len(hf2.keys())
        # Graphs are paired by index. On a count mismatch we would otherwise
        # either silently drop file 1's extra graphs or fail with a bare KeyError.
        if len(hf1.keys()) != num_graphs:
            raise ValueError(
                f"Cannot merge: {path1} has {len(hf1.keys())} graphs but "
                f"{path2} has {num_graphs}"
            )

        with h5py.File(output_path, 'w') as hf_out:
            for graph_idx in range(num_graphs):
                graph_key = f'graph_{graph_idx}'
                graph1 = hf1[graph_key]
                graph2 = hf2[graph_key]
                graph_out = hf_out.create_group(graph_key)

                for dset_name, axis in concat_axes.items():
                    merged = np.concatenate(
                        [graph1[dset_name][:], graph2[dset_name][:]], axis=axis
                    )
                    graph_out.create_dataset(dset_name, data=merged)

# Usage: set path1/path2/path3 in the next cell, then call combine_h5_files(path1, path2, path3).

In [45]:
# Input simulations to merge (path1, path2) and the merged output file (path3).
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
# Earlier pairing (kept for reference), all under
# /home/niteesh/Documents/source_codes/learning_to_simulate_pouring/datasets/:
#   Pouring_sdf_transRotate_new/test/simulation_0.h5
#   + Pouring_sdf_newTest/test/simulation_0.h5
#   -> Pouring_combined/test/simulation_1.h5
path1, path2, path3 = (
    '/home/niteesh/Documents/source_codes/PouringSim/data/simulation_output/simout_1312_lessPt/Pouring_sdf_transRotate_new2/test/simulation_2.h5',
    '/home/niteesh/Documents/source_codes/PouringSim/data/simulation_output/simout_MartiniBottle_2701_lessPt/Pouring_sdf_MartiniBottle_2701_lessPt1/test/simulation_0.h5',
    '/home/niteesh/Documents/source_codes/learning_to_simulate_pouring/datasets/Pouring_combined/test/simulation_2.h5',
)


In [46]:
# Summarize the first input file: print the number of top-level groups,
# then the name and shape of every dataset stored under graph 0.
first_graph_idx = 0
with h5py.File(path1, 'r') as h5_src:
    print(len(h5_src.keys()))

    first_graph = h5_src[f'graph_{first_graph_idx}']
    for dset_name, dset in first_graph.items():
        print(dset_name, dset.shape)

415
liq_position (7, 1047, 3)
mesh_pose (3, 7, 6)
mesh_position (7, 1271, 3)
particle_types (1271,)


In [47]:
combine_h5_files(path1, path2, path3)

# Summarize the merged file the same way: top-level group count, then the
# name and shape of every dataset under graph 0.
first_graph_idx = 0
with h5py.File(path3, 'r') as h5_merged:
    print(len(h5_merged.keys()))

    merged_first_graph = h5_merged[f'graph_{first_graph_idx}']
    for dset_name, dset in merged_first_graph.items():
        print(dset_name, dset.shape)

415
liq_position (7, 2094, 3)
mesh_pose (6, 7, 6)
mesh_position (7, 2479, 3)
particle_types (2479,)


In [1]:
import json, os
import numpy as np
import jax.numpy as jnp

from reading_utils_torch import MultiGraphDataset_torch, SingleDatasetBatchSampler
import jax.numpy as jnp
import jraph

In [9]:
# Training splits to combine. Order matters: per-dataset lists built later
# (e.g. mesh_pt_type_lists) follow this same order.
data_path = [
    'datasets/Pouring_sdf_MartiniBottle_2701_lessPt1/train',
    'datasets/Pouring_sdf_newTest/train',
]

In [30]:

def _read_metadata(data_path):
  with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
    return json.loads(fp.read())
 



In [31]:
# Combine the metadata of every dataset in data_path into dataset-wide maxima
# (used for padding) and per-dataset mesh point-type lists.
max_n_liq_node_dset, _max_edges_l_dset, _max_edges_m_dset = 0, 0, 0
mesh_pt_type_lists = []  # one entry per dataset, in data_path order
batch_size = 5
connectivity_radius = None
for path in data_path:
    # metadata.json sits one level above each split directory (e.g. .../train).
    metadata = _read_metadata(f'{path}/../')

    collision_mesh_info_list = metadata["collision_mesh"]
    # Element 1 of each collision_mesh record is the mesh point type (used for handling in v_o).
    mesh_pt_type_list = [z[1] for z in collision_mesh_info_list]
    mesh_pt_type_lists.append(mesh_pt_type_list)

    _max_n_liq_node_per_graph = int(metadata["max_n_liq_node"])  # could also be read from the positions
    _max_edges_l_per_graph = int(metadata["max_n_edge_l"])
    _max_edges_m_per_graph = int(metadata["max_n_edge_m"])

    max_n_liq_node_dset = max(max_n_liq_node_dset, _max_n_liq_node_per_graph)
    _max_edges_l_dset = max(_max_edges_l_dset, _max_edges_l_per_graph)
    _max_edges_m_dset = max(_max_edges_m_dset, _max_edges_m_per_graph)

    # Every dataset must share one connectivity radius. Enforce that here
    # instead of silently keeping whichever dataset happened to be read last.
    dataset_radius = metadata["default_connectivity_radius"]
    if connectivity_radius is not None and dataset_radius != connectivity_radius:
        raise ValueError(
            f"default_connectivity_radius mismatch: {connectivity_radius} vs "
            f"{dataset_radius} (from {path})"
        )
    connectivity_radius = dataset_radius

# Batch size followed by the node/edge maxima across all datasets.
max_nodes_edges_info = [batch_size, max_n_liq_node_dset, _max_edges_l_dset, _max_edges_m_dset]

In [32]:
# Display the per-dataset mesh point-type lists (one inner list per entry of data_path).
mesh_pt_type_lists

[[0, 1, 0], [0, 1, 0]]

In [33]:
# Build the combined dataset over all training splits. mesh_pt_type_lists is
# passed in the same order as data_path.
dataset = MultiGraphDataset_torch(data_path, mesh_pt_type_lists)

In [34]:
import itertools

# Take the first `batch_size` samples from the dataset and keep element 0 of
# each (the graph; presumably a jraph.GraphsTuple given the jraph.batch call
# later on - confirm). islice bounds how many samples are pulled, so a
# batch_size of 0 no longer consumes the entire dataset. A negative value
# raises ValueError.
# `u` and `d` remain bound to the last index/sample after the loop.
datas = []
for u, d in enumerate(itertools.islice(dataset, batch_size)):
    datas.append(d[0])


In [35]:
# Show the node-feature shapes of one collected graph (the last one). Reading
# from `datas` avoids depending on the previous cell's loop variable.
for feat_name, feat in datas[-1].nodes.items():
    print(feat_name, feat.shape)


liq_position (1047, 6, 3)
mesh_position (1, 1208, 6, 3)
mesh_pose (1, 3, 6, 6)
particle_type (1208,)
particle_type_obj (3,)
target (1047, 3)


In [36]:
# Concatenate the collected graphs into a single batched graph with jraph.batch.
batched_graphs = jraph.batch(datas)

In [37]:
# Node-feature shapes of the batched graph, followed by the per-graph node counts.
batched_nodes = batched_graphs.nodes
for feat_name in batched_nodes:
    print(feat_name, batched_nodes[feat_name].shape)
print(batched_graphs.n_node)

liq_position (5235, 6, 3)
mesh_pose (5, 3, 6, 6)
mesh_position (5, 1208, 6, 3)
particle_type (6040,)
particle_type_obj (15,)
target (5235, 3)
[1047 1047 1047 1047 1047]


In [38]:
# Per-graph padding budgets. 1047 matches the per-graph node count shown in
# batched_graphs.n_node above.
# TODO(review): consider max_n_liq_node_dset / the edge maxima from the metadata
# cell so every batch pads to the same shape across datasets.
PAD_NODES_PER_GRAPH = 1047
PAD_EDGES_PER_GRAPH = 20_000  # an int: these budgets are counts, not floats

padded_graphs = jraph.pad_with_graphs(
    batched_graphs,
    n_node=PAD_NODES_PER_GRAPH * len(datas) + 1,  # +1 node for the padding graph
    n_edge=PAD_EDGES_PER_GRAPH * len(datas),
    n_graph=len(datas) + 1,  # +1 graph that holds the padding
)

In [39]:
# Node-feature shapes after padding; each leading dimension grows by one for
# the padding graph (compare with the batched shapes above).
padded_nodes = padded_graphs.nodes
for feat_name in padded_nodes:
    print(feat_name, padded_nodes[feat_name].shape)

liq_position (5236, 6, 3)
mesh_pose (6, 3, 6, 6)
mesh_position (6, 1208, 6, 3)
particle_type (6041,)
particle_type_obj (16,)
target (5236, 3)


In [40]:
# Per-graph node counts before vs. after padding; padding appends one extra graph.
batched_graphs.n_node, padded_graphs.n_node

(Array([1047, 1047, 1047, 1047, 1047], dtype=int32),
 array([1047, 1047, 1047, 1047, 1047,    1], dtype=int32))