In [None]:
import gc
import os
import pickle

import numpy as np
from dgl.data.utils import load_graphs
from tqdm.notebook import tqdm

In [None]:
def graph_to_image(graph, global_sample_num, image_size=(3, 64, 64), output_dir='./alldata'):
    """Rasterize one graph into a multi-channel image and pickle it to disk.

    Every entry of the graph's 'points' node data is written into a
    zero-initialised array of shape ``image_size``:

    * channel = ``type - 2``
    * row     = index derived from the second column of ``center`` (phi)
    * column  = index derived from the first column of ``center`` (eta)
    * value   = the point's ``E`` divided by the ``E`` of the 'global' node

    When several points map to the same pixel, the one that appears last in
    the node data overwrites the earlier ones (values are not summed).

    The result is saved as ``<output_dir>/sample_<NNNNNN>.pkl`` (sample number
    zero-padded to six digits) containing a dict with the keys
    ``'truthmatch'`` and ``'cell_image'``.

    Args:
        graph: Object exposing ``graph.nodes['global'].data`` with the keys
            'TruthMatch' and 'E', and ``graph.nodes['points'].data`` with the
            keys 'center', 'E' and 'type'.
        global_sample_num: Sample index used to build the output file name.
        image_size: ``(channels, rows, columns)`` of the generated image.
        output_dir: Directory that receives the pickle; created if missing.

    Returns:
        None. The image is only written to disk.
    """
    truthmatch = int(graph.nodes['global'].data['TruthMatch'])

    points = graph.nodes['points'].data
    cell_image = np.zeros(image_size)
    eta_arr = np.array(points['center'][:, 0])
    phi_arr = np.array(points['center'][:, 1])
    # Energies are stored relative to the global node's energy.
    E_arr = np.array(points['E']) / np.array(graph.nodes['global'].data['E'])
    types_arr = np.array(points['type']).astype(int)

    # Linear map x -> x * half + half sends [-1, 1] onto [0, size - 1].
    # NOTE(review): assumes the center coordinates are normalised to [-1, 1];
    # values outside that range yield wrapped (negative) or out-of-bounds
    # indices -- confirm against the dataset.
    eta_half = (image_size[2 - 1] - 1) / 2
    phi_half = (image_size[2] - 1) / 2
    eta_indexed_arr = np.floor(eta_arr * eta_half + eta_half).astype(int)
    phi_indexed_arr = np.floor(phi_arr * phi_half + phi_half).astype(int)

    # NOTE(review): the "- 2" offset presumably maps type codes 2, 3, 4 onto
    # channels 0, 1, 2 -- confirm against the dataset's type encoding.
    # An explicit loop keeps the "last point wins" order for repeated pixels.
    for channel, row, col, energy in zip(types_arr - 2, phi_indexed_arr, eta_indexed_arr, E_arr):
        cell_image[channel, row, col] = energy

    outfile_data = {}
    outfile_data['truthmatch'] = truthmatch
    outfile_data['cell_image'] = cell_image

    # Make sure the destination exists so the first write on a clean checkout
    # does not fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    outfile_name = os.path.join(output_dir, 'sample_' + str(global_sample_num).zfill(6) + '.pkl')
    with open(outfile_name, 'wb') as f:
        pickle.dump(outfile_data, f)
    # Force a collection so per-sample temporaries do not accumulate over a
    # long conversion run.
    gc.collect()

In [None]:
def generate_images(data_bin_file_location):
    """Convert every graph stored in a DGL binary file into a pickled image.

    The file is read with ``dgl.data.utils.load_graphs``; each graph is then
    handed to ``graph_to_image`` together with its position in the list,
    which becomes the sample number of the output file.

    Args:
        data_bin_file_location: Path of the graph file to read.

    Returns:
        None.
    """
    # The first item of load_graphs' result is the list of graphs.
    all_graphs = load_graphs(data_bin_file_location)[0]
    print("DONE READ")

    # One image (and one output file) per graph, numbered from zero.
    for sample_index, current_graph in enumerate(tqdm(all_graphs)):
        graph_to_image(current_graph, sample_index)
    gc.collect()

In [None]:
def load_pickle(data_dir, data_file):
    """Return the object unpickled from ``data_dir/data_file``.

    Args:
        data_dir: Directory containing the file (string).
        data_file: File name inside ``data_dir`` (string).

    Returns:
        Whatever object was pickled into the file.
    """
    pickle_path = '/'.join((data_dir, data_file))
    # NOTE: unpickling can execute arbitrary code -- only use on trusted files.
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)

In [None]:
def join_pickles(pickle_dir, output_file='all_samples.pkl'):
    """Merge every ``*.pkl`` file of a directory into one pickled list.

    The files are loaded in lexicographic order of their names and the
    resulting list of objects is written to ``output_file``.

    Only regular files ending in ``.pkl`` are read, so stray directory
    entries (for example ``.ipynb_checkpoints`` or ``.DS_Store``) do not
    abort the merge. If ``output_file`` itself lives inside ``pickle_dir`` it
    is skipped, so re-running does not nest an older merged file in the new one.

    NOTE: all samples are held in memory at once before being written.
    NOTE: unpickling can execute arbitrary code -- only use on trusted files.

    Args:
        pickle_dir: Directory holding the per-sample pickle files.
        output_file: Path of the combined pickle to write.

    Returns:
        None.
    """
    output_path = os.path.abspath(output_file)
    pickle_files_list = sorted(
        name
        for name in os.listdir(pickle_dir)
        if name.endswith('.pkl')
        and os.path.isfile(os.path.join(pickle_dir, name))
        and os.path.abspath(os.path.join(pickle_dir, name)) != output_path
    )

    pickle_list = [load_pickle(pickle_dir, filename) for filename in tqdm(pickle_files_list)]

    gc.collect()

    with open(output_file, 'wb') as fo:
        pickle.dump(pickle_list, fo)

    gc.collect()

In [None]:
# Driver cell: convert the graph dataset into per-event image pickles, then
# merge them into a single file.
data_bin_file_name = '../dataset.bin' # Path of the graph dataset file; change to your file name and location

generate_images(data_bin_file_name) # Writes one pickle per event into ./alldata
join_pickles('./alldata') # Combines all event pickles into one file (all_samples.pkl in the working directory)