Skip to content

Commit

Permalink
print exception in threadpool to avoid errors that cannot be reported…
Browse files Browse the repository at this point in the history
… to users, as described in #370
  • Loading branch information
fangwei123456 committed May 30, 2023
1 parent 2d2ec2f commit 6dca147
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 51 deletions.
18 changes: 15 additions & 3 deletions spikingjelly/datasets/asl_dvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ def extract_downloaded_files(download_root: str, extract_root: str):
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'ICCV2019_DVS_dataset.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
sub_threads = []
for zip_file in os.listdir(temp_ext_dir):
if os.path.splitext(zip_file)[1] == '.zip':
zip_file = os.path.join(temp_ext_dir, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)
sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))
for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)


shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')
Expand Down Expand Up @@ -129,6 +135,7 @@ def create_events_np_files(extract_root: str, events_np_root: str):
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
sub_threads = []
for class_name in os.listdir(extract_root):
mat_dir = os.path.join(extract_root, class_name)
np_dir = os.path.join(events_np_root, class_name)
Expand All @@ -138,8 +145,13 @@ def create_events_np_files(extract_root: str, events_np_root: str):
source_file = os.path.join(mat_dir, bin_file)
target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(ASLDVS.read_mat_save_to_np, source_file,
target_file)
sub_threads.append(tpe.submit(ASLDVS.read_mat_save_to_np, source_file,
target_file))
for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)



print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
19 changes: 16 additions & 3 deletions spikingjelly/datasets/cifar10_dvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,17 @@ def extract_downloaded_files(download_root: str, extract_root: str):
This function defines how to extract download files.
'''
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 10)) as tpe:
sub_threads = []
for zip_file in os.listdir(download_root):
zip_file = os.path.join(download_root, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)
sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)



@staticmethod
Expand Down Expand Up @@ -227,6 +234,7 @@ def create_events_np_files(extract_root: str, events_np_root: str):
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
sub_threads = []
for class_name in os.listdir(extract_root):
aedat_dir = os.path.join(extract_root, class_name)
np_dir = os.path.join(events_np_root, class_name)
Expand All @@ -236,6 +244,11 @@ def create_events_np_files(extract_root: str, events_np_root: str):
source_file = os.path.join(aedat_dir, bin_file)
target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(CIFAR10DVS.read_aedat_save_to_np, source_file,
target_file)
sub_threads.append(tpe.submit(CIFAR10DVS.read_aedat_save_to_np, source_file,
target_file))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)
print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
15 changes: 12 additions & 3 deletions spikingjelly/datasets/dvs128_gesture.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,22 +284,31 @@ def create_events_np_files(extract_root: str, events_np_root: str):
# use multi-thread to accelerate
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
sub_threads = []
print(f'Start the ThreadPoolExecutor with max workers = [{tpe._max_workers}].')


for fname in trials_to_train_txt.readlines():
fname = fname.strip()
if fname.__len__() > 0:
aedat_file = os.path.join(aedat_dir, fname)
fname = os.path.splitext(fname)[0]
tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file, os.path.join(aedat_dir, fname + '_labels.csv'), train_dir)
sub_threads.append(tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file, os.path.join(aedat_dir, fname + '_labels.csv'), train_dir))


for fname in trials_to_test_txt.readlines():
fname = fname.strip()
if fname.__len__() > 0:
aedat_file = os.path.join(aedat_dir, fname)
fname = os.path.splitext(fname)[0]
tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file,
os.path.join(aedat_dir, fname + '_labels.csv'), test_dir)
sub_threads.append(tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file,
os.path.join(aedat_dir, fname + '_labels.csv'), test_dir))


for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)

print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
print(f'All aedat files have been split to samples and saved into [{train_dir, test_dir}].')
Expand Down
8 changes: 7 additions & 1 deletion spikingjelly/datasets/hardvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,17 @@ def extract_downloaded_files(download_root: str, extract_root: str):
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'MINI_HARDVS_files.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
sub_threads = []
for i in range(1, 301):
zip_file = os.path.join(temp_ext_dir, 'MINI_HARDVS_files', 'action_' + str(i).zfill(3) + '.zip')
target_dir = os.path.join(extract_root, 'action_' + str(i).zfill(3))
print(f'Extract [{zip_file}] to [{target_dir}].')
tpe.submit(extract_archive, zip_file, target_dir)
sub_threads.append(tpe.submit(extract_archive, zip_file, target_dir))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)

shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')
Expand Down
9 changes: 7 additions & 2 deletions spikingjelly/datasets/n_caltech101.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def create_events_np_files(extract_root: str, events_np_root: str):
t_ckp = time.time()
extract_root = os.path.join(extract_root, 'Caltech101')
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
sub_threads = []
# too many threads will make the disk overload
for class_name in os.listdir(extract_root):
bin_dir = os.path.join(extract_root, class_name)
Expand All @@ -123,8 +124,12 @@ def create_events_np_files(extract_root: str, events_np_root: str):
source_file = os.path.join(bin_dir, bin_file)
target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(NCaltech101.read_bin_save_to_np, source_file,
target_file)
sub_threads.append(tpe.submit(NCaltech101.read_bin_save_to_np, source_file,
target_file))
for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)


print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
19 changes: 16 additions & 3 deletions spikingjelly/datasets/n_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ def extract_downloaded_files(download_root: str, extract_root: str):
This function defines how to extract download files.
'''
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
sub_threads = []
for zip_file in os.listdir(download_root):
zip_file = os.path.join(download_root, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)
sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)


@staticmethod
Expand Down Expand Up @@ -114,6 +120,7 @@ def create_events_np_files(extract_root: str, events_np_root: str):
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
sub_threads = []
# too many threads will make the disk overload
for train_test_dir in ['Train', 'Test']:
source_dir = os.path.join(extract_root, train_test_dir)
Expand All @@ -129,8 +136,14 @@ def create_events_np_files(extract_root: str, events_np_root: str):
source_file = os.path.join(bin_dir, bin_file)
target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(NMNIST.read_bin_save_to_np, source_file,
target_file)
sub_threads.append(tpe.submit(NMNIST.read_bin_save_to_np, source_file,
target_file))


for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)


print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
27 changes: 23 additions & 4 deletions spikingjelly/datasets/nav_gesture.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,17 @@ def extract_downloaded_files(download_root: str, extract_root: str):
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'navgesture-walk.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
sub_threads = []
for zip_file in os.listdir(temp_ext_dir):
if os.path.splitext(zip_file)[1] == '.zip':
zip_file = os.path.join(temp_ext_dir, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)
sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)

shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')
Expand Down Expand Up @@ -292,15 +298,22 @@ def create_events_np_files(extract_root: str, events_np_root: str):
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(),
configure.max_threads_number_for_datasets_preprocess)) as tpe:
for user_name in os.listdir(extract_root):
sub_threads = []
aedat_dir = os.path.join(extract_root, user_name)
for bin_file in os.listdir(aedat_dir):
base_name = os.path.splitext(bin_file)[0]
label = base_name.split('_')[1]
source_file = os.path.join(aedat_dir, bin_file)
target_file = os.path.join(np_dir_dict[label], base_name + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(NAVGestureWalk.read_aedat_save_to_np, source_file,
target_file)

sub_threads.append(tpe.submit(NAVGestureWalk.read_aedat_save_to_np, source_file,
target_file))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)
print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')


Expand Down Expand Up @@ -329,11 +342,17 @@ def extract_downloaded_files(download_root: str, extract_root: str):
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'navgesture-sit.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
sub_threads = []
for zip_file in os.listdir(temp_ext_dir):
if os.path.splitext(zip_file)[1] == '.zip':
zip_file = os.path.join(temp_ext_dir, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)
sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

for sub_thread in sub_threads:
if sub_thread.exception():
print(sub_thread.exception())
exit(-1)

shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')

0 comments on commit 6dca147

Please sign in to comment.