In [2]:
import multiprocessing
import pyvista as pv
import os, sys
import random
import numpy as np
import json
import time
from functools import partial

Test notebook for generating 2D projection images of STL meshes in parallel using Python's `multiprocessing` package.

In [2]:
def init_plotter(res, bg_color):
    '''
    Build a square off-screen PyVista plotter for silhouette rendering.

    The plotter is ``res`` x ``res`` pixels, uses ``bg_color`` as background,
    renders with parallel (orthographic) projection and has all lights removed.
    '''
    plotter = pv.Plotter(off_screen=True, window_size=[res, res])
    plotter.background_color = bg_color
    plotter.enable_parallel_projection()
    plotter.remove_all_lights()
    return plotter
def random_rotate(mesh):
    '''
    Return a rotated copy of ``mesh``; the input mesh is not modified.

    Three integer angles in [1, 359] degrees are drawn from NumPy's global RNG
    (``np.random.randint(1, 360)``) and applied in place to the copy as
    rotations about x, then y, then z.
    '''
    result = mesh.copy()
    angles = [np.random.randint(1, 360) for _ in range(3)]
    rotations = (result.rotate_x, result.rotate_y, result.rotate_z)
    for rotate, angle in zip(rotations, angles):
        rotate(angle, inplace=True)
    return result
def get_orthogonal_vector(v):
    '''
    Return an arbitrary unit vector orthogonal to the 3-vector ``v``.

    The result is the normalized cross product of ``v`` with the x axis or,
    when ``v`` is (numerically) parallel to the x axis, with the y axis.

    Parameters
    ----------
    v : array_like of shape (3,)
        Non-zero input vector (e.g. a camera viewing direction).

    Returns
    -------
    numpy.ndarray of shape (3,)
        Unit vector perpendicular to ``v``.

    Raises
    ------
    ValueError
        If ``v`` is the zero vector (no well-defined orthogonal direction).
    '''
    v = np.asarray(v, dtype=float)
    v_norm = np.linalg.norm(v)
    if v_norm == 0:
        raise ValueError('v must be a non-zero vector')
    v_orth = np.cross(v, [1.0, 0.0, 0.0])
    # If v is parallel to the x axis the cross product vanishes. Use a tolerance
    # relative to |v| so nearly-parallel inputs do not end up normalizing float noise.
    if np.linalg.norm(v_orth) <= 1e-12 * v_norm:
        v_orth = np.cross(v, [0.0, 1.0, 0.0])
    return v_orth / np.linalg.norm(v_orth)
# Render random-orientation projections of a single STL file
def _render_view(pl, mesh, savepath, color, opacity):
    '''Add ``mesh`` to plotter ``pl``, save a screenshot to ``savepath``, then remove the actor.'''
    actor = pl.add_mesh(mesh, show_edges=False, color=color, opacity=opacity)
    try:
        pl.render()
        pl.screenshot(savepath, return_img=False)
    finally:
        pl.remove_actor(actor)


def process_stl(file_path, save_dir, n_proj, res=224, bg_color='black',
                obj_color='white', opacity=1.0, theta_2ds=90, theta_phips=120,
                seed=None):
    '''
    Render ``n_proj`` randomly oriented projections of one STL mesh to PNG files.

    For each projection the mesh is randomly rotated (``random_rotate``) and
    rendered three times: the default view, and the same orientation rotated
    by ``theta_2ds`` and ``theta_phips`` degrees about an axis perpendicular to
    the viewing direction. Images are written to ``save_dir`` as
    ``ros-projection-{id}-{i:02d}-{default|2ds|phips}.png``, where ``id`` is the
    part of the file name after its last '-' (e.g. '054702' for
    'ros-test-054702.stl').

    Parameters
    ----------
    file_path : str
        Path to the STL file.
    save_dir : str
        Existing directory the PNG images are written to.
    n_proj : int
        Number of random orientations to render.
    res : int
        Width and height of each image in pixels.
    bg_color, obj_color : color-like
        Background and mesh colors passed to PyVista.
    opacity : float
        Mesh opacity.
    theta_2ds, theta_phips : float
        Rotation (degrees) of the 2DS and PHIPS views relative to the default view.
    seed : int or None
        Seed for NumPy's global RNG. None (default) reseeds from OS entropy so
        that pool workers forked from the same parent do not all draw the same
        rotation sequence. A fixed int makes every call draw identical rotations.
    '''
    # Workers forked by multiprocessing inherit identical copies of NumPy's global
    # RNG state; reseed so each call gets its own sequence of random rotations.
    np.random.seed(seed)
    mesh = pv.read(file_path)
    # Parse the id from the file name only, so dashes in directory names are ignored.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    sample_id = stem.rsplit('-', 1)[-1]

    def out_path(i, view):
        return os.path.join(save_dir, f'ros-projection-{sample_id}-{i:02d}-{view}.png')

    for i in range(n_proj):
        rotated_mesh = random_rotate(mesh)
        pl = init_plotter(res, bg_color)
        try:
            _render_view(pl, rotated_mesh, out_path(i, 'default'), obj_color, opacity)
            # Read the camera only after the first screenshot: PyVista may still
            # reposition an off-screen camera on the first render request. The axis
            # has to be perpendicular to the viewing direction (focal point - position),
            # not to the camera position vector, which only coincides with it when
            # the focal point is at the origin.
            view_dir = np.subtract(pl.camera.focal_point, pl.camera.position)
            axis_rotation = get_orthogonal_vector(view_dir)
            for view, theta in (('2ds', theta_2ds), ('phips', theta_phips)):
                view_mesh = rotated_mesh.rotate_vector(
                    axis_rotation, theta, point=rotated_mesh.center)
                _render_view(pl, view_mesh, out_path(i, view), obj_color, opacity)
        finally:
            # always release the render window, even if a screenshot fails
            pl.close()

In [3]:
# %%time
# # 1 core (serial)
# def main():
#     save_dir = '/glade/derecho/scratch/joko/synth-ros/test-mp'
#     os.makedirs(save_dir, exist_ok=True)
#     # Load the JSON file
#     params_path = '/glade/u/home/joko/ice3d/output/params_200_50.json'
#     with open(params_path, 'rb') as file:
#         params = json.load(file)
#     # load list of STL filepaths
#     stl_paths_txt = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl/stl_relative_paths.txt'
#     with open(stl_paths_txt, 'r') as file:
#         basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
#         rel_paths = [line.strip().replace('./','') for line in file]
#         stl_paths = [os.path.join(basepath, i) for i in rel_paths]
#     n_samples = 100
#     n_proj = 100
#     stl_paths = stl_paths[:n_samples] # subset if needed
#     num_cores = 1 # set number of CPU cores e.g., multiprocessing.cpu_count()
#     process_stl_params = partial(process_stl, save_dir=save_dir, n_proj=n_proj)
#     # Use a multiprocessing Pool to process the STL files in parallel
#     with multiprocessing.Pool(processes=num_cores) as pool:
#         pool.map(process_stl_params, stl_paths)
# if __name__ == "__main__":
#     main()

Based on the serial (1-core) run, processing takes ~10 seconds per STL sample (`n_proj = 100` random orientations × 3 views each).  
- 10 seconds × 70,000 samples = 700,000 seconds of processing
- this is equivalent to ~194 core-hours (700,000 s / 3,600 s per hour), not including overhead

In [4]:
%%time
# Load the JSON file of sample parameters
params_path = '/glade/u/home/joko/ice3d/output/params_200_50.json'
with open(params_path, 'rb') as file:
    params = json.load(file)
# Load the list of STL file paths, stored relative to basepath (possibly prefixed with './')
basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
stl_paths_txt = os.path.join(basepath, 'stl_relative_paths.txt')
with open(stl_paths_txt, 'r') as file:
    # normpath drops a leading './' component without touching './' elsewhere in a
    # name (unlike str.replace); blank lines such as a trailing newline are skipped
    # rather than turning into a path to basepath itself
    rel_paths = [os.path.normpath(line.strip()) for line in file if line.strip()]
stl_paths = [os.path.join(basepath, rel) for rel in rel_paths]
stl_paths[0]

CPU times: user 1.16 s, sys: 149 ms, total: 1.31 s
Wall time: 1.41 s


'/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl/9/ros-test-054702.stl'

In [6]:
# Number of parameter sets in params_200_50.json
len(params)

70000

In [4]:
%%time
# 8 cores
def main(num_cores=8, n_samples=100, n_proj=100):
    '''
    Render projections for the first ``n_samples`` STL files with a process pool.

    Parameters
    ----------
    num_cores : int
        Maximum number of worker processes (capped at the number of files).
    n_samples : int or None
        Number of STL files to take from the list; None processes all of them.
    n_proj : int
        Random orientations rendered per STL file (3 images each).
    '''
    save_dir = '/glade/derecho/scratch/joko/synth-ros/test-mp'
    os.makedirs(save_dir, exist_ok=True)
    # load list of STL filepaths (stored relative to basepath, possibly prefixed with './')
    basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
    stl_paths_txt = os.path.join(basepath, 'stl_relative_paths.txt')
    with open(stl_paths_txt, 'r') as file:
        # skip blank lines (e.g. a trailing newline), which would otherwise map to basepath
        rel_paths = [os.path.normpath(line.strip()) for line in file if line.strip()]
    stl_paths = [os.path.join(basepath, rel) for rel in rel_paths][:n_samples]
    if not stl_paths:
        return
    # more workers than files only adds process start-up cost
    processes = min(num_cores, len(stl_paths))
    process_stl_params = partial(process_stl, save_dir=save_dir, n_proj=n_proj)
    # NOTE: workers must see process_stl, which is defined in this notebook; this relies
    # on the 'fork' start method (the Linux default before Python 3.14).
    with multiprocessing.Pool(processes=processes) as pool:
        pool.map(process_stl_params, stl_paths)
if __name__ == "__main__":
    main()

CPU times: user 729 ms, sys: 207 ms, total: 936 ms
Wall time: 2min 42s


In [5]:
%%time
# 16 cores
def main(num_cores=16, n_samples=100, n_proj=100):
    '''
    Render projections for the first ``n_samples`` STL files with a process pool.

    Parameters
    ----------
    num_cores : int
        Maximum number of worker processes (capped at the number of files).
    n_samples : int or None
        Number of STL files to take from the list; None processes all of them.
    n_proj : int
        Random orientations rendered per STL file (3 images each).
    '''
    save_dir = '/glade/derecho/scratch/joko/synth-ros/test-mp'
    os.makedirs(save_dir, exist_ok=True)
    # load list of STL filepaths (stored relative to basepath, possibly prefixed with './')
    basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
    stl_paths_txt = os.path.join(basepath, 'stl_relative_paths.txt')
    with open(stl_paths_txt, 'r') as file:
        # skip blank lines (e.g. a trailing newline), which would otherwise map to basepath
        rel_paths = [os.path.normpath(line.strip()) for line in file if line.strip()]
    stl_paths = [os.path.join(basepath, rel) for rel in rel_paths][:n_samples]
    if not stl_paths:
        return
    # more workers than files only adds process start-up cost
    processes = min(num_cores, len(stl_paths))
    process_stl_params = partial(process_stl, save_dir=save_dir, n_proj=n_proj)
    # NOTE: workers must see process_stl, which is defined in this notebook; this relies
    # on the 'fork' start method (the Linux default before Python 3.14).
    with multiprocessing.Pool(processes=processes) as pool:
        pool.map(process_stl_params, stl_paths)
if __name__ == "__main__":
    main()

CPU times: user 724 ms, sys: 174 ms, total: 898 ms
Wall time: 1min 46s


In [6]:
%%time
# 32 cores
def main(num_cores=32, n_samples=100, n_proj=100):
    '''
    Render projections for the first ``n_samples`` STL files with a process pool.

    Parameters
    ----------
    num_cores : int
        Maximum number of worker processes (capped at the number of files).
    n_samples : int or None
        Number of STL files to take from the list; None processes all of them.
    n_proj : int
        Random orientations rendered per STL file (3 images each).
    '''
    save_dir = '/glade/derecho/scratch/joko/synth-ros/test-mp'
    os.makedirs(save_dir, exist_ok=True)
    # load list of STL filepaths (stored relative to basepath, possibly prefixed with './')
    basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
    stl_paths_txt = os.path.join(basepath, 'stl_relative_paths.txt')
    with open(stl_paths_txt, 'r') as file:
        # skip blank lines (e.g. a trailing newline), which would otherwise map to basepath
        rel_paths = [os.path.normpath(line.strip()) for line in file if line.strip()]
    stl_paths = [os.path.join(basepath, rel) for rel in rel_paths][:n_samples]
    if not stl_paths:
        return
    # more workers than files only adds process start-up cost
    processes = min(num_cores, len(stl_paths))
    process_stl_params = partial(process_stl, save_dir=save_dir, n_proj=n_proj)
    # NOTE: workers must see process_stl, which is defined in this notebook; this relies
    # on the 'fork' start method (the Linux default before Python 3.14).
    with multiprocessing.Pool(processes=processes) as pool:
        pool.map(process_stl_params, stl_paths)
if __name__ == "__main__":
    main()

CPU times: user 840 ms, sys: 232 ms, total: 1.07 s
Wall time: 1min 38s


In [7]:
%%time
# 64 cores
def main(num_cores=64, n_samples=100, n_proj=100):
    '''
    Render projections for the first ``n_samples`` STL files with a process pool.

    Parameters
    ----------
    num_cores : int
        Maximum number of worker processes (capped at the number of files).
    n_samples : int or None
        Number of STL files to take from the list; None processes all of them.
    n_proj : int
        Random orientations rendered per STL file (3 images each).
    '''
    save_dir = '/glade/derecho/scratch/joko/synth-ros/test-mp'
    os.makedirs(save_dir, exist_ok=True)
    # load list of STL filepaths (stored relative to basepath, possibly prefixed with './')
    basepath = '/glade/derecho/scratch/joko/synth-ros/params_200_50_20250403/stl'
    stl_paths_txt = os.path.join(basepath, 'stl_relative_paths.txt')
    with open(stl_paths_txt, 'r') as file:
        # skip blank lines (e.g. a trailing newline), which would otherwise map to basepath
        rel_paths = [os.path.normpath(line.strip()) for line in file if line.strip()]
    stl_paths = [os.path.join(basepath, rel) for rel in rel_paths][:n_samples]
    if not stl_paths:
        return
    # more workers than files only adds process start-up cost
    processes = min(num_cores, len(stl_paths))
    process_stl_params = partial(process_stl, save_dir=save_dir, n_proj=n_proj)
    # NOTE: workers must see process_stl, which is defined in this notebook; this relies
    # on the 'fork' start method (the Linux default before Python 3.14).
    with multiprocessing.Pool(processes=processes) as pool:
        pool.map(process_stl_params, stl_paths)
if __name__ == "__main__":
    main()

CPU times: user 842 ms, sys: 299 ms, total: 1.14 s
Wall time: 1min 25s
