Skip to content

Commit

Permalink
DVS Gesture dataset测试通过
Browse files Browse the repository at this point in the history
  • Loading branch information
win10-pc committed Sep 11, 2020
1 parent 7473fc6 commit 8bb07f3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 32 deletions.
17 changes: 17 additions & 0 deletions spikingjelly/datasets/als_dvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import spikingjelly.datasets
import zipfile
import os
import threading
import tqdm
import numpy as np
import struct
from torchvision.datasets import utils
import time
import multiprocessing

# Label mapping for this dataset; it is left empty here.
# NOTE(review): the original comment refers to gesture_mapping.csv -- confirm
# that this is the intended source of the labels for this dataset.
labels_dict = dict()

# Download resource laid out as [url, md5]; no MD5 value is provided (None).
resource = [
    'https://www.dropbox.com/sh/ibq0jsicatn7l6r/AACNrNELV56rs1YInMWUs9CAa?dl=0',
    None,
]

89 changes: 57 additions & 32 deletions spikingjelly/datasets/dvs_gesture.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def read_bin(file_name: str):


@staticmethod
def convert_aedat_dir_to_npy_dir(aedat_data_dir: str, npy_data_dir: str):
def cvt_files_fun(aedat_file_list):
def convert_aedat_dir_to_npy_dir(aedat_data_dir: str, events_npy_train_root: str, events_npy_test_root: str):
def cvt_files_fun(aedat_file_list, output_dir):
for aedat_file in aedat_file_list:
base_name = aedat_file[0: -6]
events = DvsGesture.read_bin(os.path.join(aedat_data_dir, aedat_file))
Expand Down Expand Up @@ -142,7 +142,7 @@ def cvt_files_fun(aedat_file_list):
# [index_l, index_r)
j = 0
while True:
file_name = os.path.join(npy_data_dir, f'{base_name}_{label}_{j}.npy')
file_name = os.path.join(output_dir, f'{base_name}_{label}_{j}.npy')
# 由于不同线程执行的base_name一定不相同,因此这里不会出现多线程之间的数据复用造成的错误
if os.path.exists(file_name): # 防止同一个aedat里存在多个相同label的数据段
j += 1
Expand All @@ -155,33 +155,38 @@ def cvt_files_fun(aedat_file_list):
})
break

with open(os.path.join(aedat_data_dir, 'trials_to_train.txt')) as trials_to_train_txt, open(
os.path.join(aedat_data_dir, 'trials_to_test.txt')) as trials_to_test_txt:
train_list = []
for fname in trials_to_train_txt.readlines():
fname = fname.strip()
if fname.__len__() > 0:
train_list.append(fname)
test_list = []
for fname in trials_to_test_txt.readlines():
fname = fname.strip()
if fname.__len__() > 0:
test_list.append(fname)


# 将aedat_data_dir目录下的.aedat文件读取并转换成np保存的字典,保存在npy_data_dir目录
print('convert events data from aedat to numpy format.')
# 速度很慢,并行化

# 统计文件总数量
aedat_files = utils.list_files(aedat_data_dir, '.aedat')

npy_data_num = 0
for aedat_file in aedat_files:
csv_file = os.path.join(aedat_data_dir, aedat_file[0: -6] + '_labels.csv')
npy_data_num += np.loadtxt(csv_file, dtype=np.uint32, delimiter=',', skiprows=1).shape[0]

thread_num = multiprocessing.cpu_count()
block = aedat_files.__len__() // thread_num # 分成thread_num个子任务
npy_data_num = train_list.__len__() + test_list.__len__()
thread_num = max(multiprocessing.cpu_count(), 2)
block = train_list.__len__() // (thread_num - 1) # 训练集分成thread_num - 1个子任务
thread_list = []
for i in range(thread_num - 1):
thread_list.append(spikingjelly.datasets.FunctionThread(cvt_files_fun, aedat_files[i * block: (i + 1) * block]))

thread_list.append(spikingjelly.datasets.FunctionThread(cvt_files_fun, train_list[i * block: (i + 1) * block], events_npy_train_root))
print(f'thread {i} start')
thread_list[-1].start()

thread_list.append(spikingjelly.datasets.FunctionThread(cvt_files_fun, aedat_files[(thread_num - 1) * block:]))
# 测试集再单独作为一个线程
thread_list.append(spikingjelly.datasets.FunctionThread(cvt_files_fun, test_list, events_npy_test_root))
print(f'thread {thread_num - 1} start')
thread_list[-1].start()
# 主线程等待各个子线程
# for i in range(thread_list.__len__()):
# thread_list[i].join()
# print('thread', i, 'finished')

with tqdm.tqdm(total=npy_data_num) as pbar:
while True:
Expand All @@ -192,7 +197,7 @@ def cvt_files_fun(aedat_file_list):
working_thread.append(i)
else:
finished_thread.append(i)
pbar.update(utils.list_files(npy_data_dir, '.npy').__len__())
pbar.update(utils.list_files(events_npy_train_root, '.npy').__len__() + utils.list_files(events_npy_test_root, '.npy').__len__())
print('wroking thread:', working_thread)
print('finished thread:', finished_thread)
if finished_thread.__len__() == thread_list.__len__():
Expand Down Expand Up @@ -221,43 +226,63 @@ def get_events_item(file_name):
def get_frames_item(file_name):
    """Load one frames sample and derive its label from the file name.

    The label is the second-to-last ``'_'``-separated field of the file's
    base name, parsed as an integer and reduced by one.

    :param file_name: path of the ``.npy`` file to load
    :return: tuple ``(data returned by np.load, label)``
    """
    # Load first so a missing file is reported before any name parsing error.
    frames = np.load(file_name)
    name_fields = os.path.basename(file_name).split('_')
    label = int(name_fields[-2]) - 1
    return frames, label

def __init__(self, root: str, use_frame=True, frames_num=10, split_by='number', normalization='max'):
def __init__(self, root: str, train: bool, use_frame=True, frames_num=10, split_by='number', normalization='max'):
events_npy_root = os.path.join(root, 'events_npy')
if os.path.exists(events_npy_root):
print(f'npy format events data root {events_npy_root} already exists')
events_npy_train_root = os.path.join(events_npy_root, 'train')
events_npy_test_root = os.path.join(events_npy_root, 'test')
if os.path.exists(events_npy_train_root) and os.path.exists(events_npy_test_root):
print(f'npy format events data root {events_npy_train_root}, {events_npy_test_root} already exists')
else:

extracted_root = os.path.join(root, 'extracted')
if os.path.exists(extracted_root):
print(f'extracted root {extracted_root} already exists.')
else:
self.download_and_extract(root, extracted_root)
os.mkdir(events_npy_root)
print(f'mkdir {events_npy_root}')
if not os.path.exists(events_npy_root):
os.mkdir(events_npy_root)
print(f'mkdir {events_npy_root}')
os.mkdir(events_npy_train_root)
print(f'mkdir {events_npy_train_root}')
os.mkdir(events_npy_test_root)
print(f'mkdir {events_npy_test_root}')
print('read events data from *.aedat and save to *.npy...')
DvsGesture.convert_aedat_dir_to_npy_dir(os.path.join(extracted_root, 'DvsGesture'), events_npy_root)
DvsGesture.convert_aedat_dir_to_npy_dir(os.path.join(extracted_root, 'DvsGesture'), events_npy_train_root, events_npy_test_root)


self.file_name = [] # 保存数据文件的路径
self.use_frame = use_frame
self.data_dir = None
if use_frame:
frames_root = os.path.join(root, f'frames_num_{frames_num}_split_by_{split_by}_normalization_{normalization}')
frames_train_root = os.path.join(frames_root, 'train')
frames_test_root = os.path.join(frames_root, 'test')
if os.path.exists(frames_root):
# 如果root目录下存在frames_root目录,则认为数据集文件存在
print(f'frames data root {frames_root} already exists.')
else:
os.mkdir(frames_root)
print(f'mkdir {frames_root}.')
os.mkdir(frames_train_root)
os.mkdir(frames_test_root)
print(f'mkdir {frames_root}, {frames_train_root}, {frames_test_root}.')
print('creating frames data..')
DvsGesture.create_frames_dataset(events_npy_root, frames_root, frames_num, split_by, normalization)
DvsGesture.create_frames_dataset(events_npy_train_root, frames_train_root, frames_num, split_by, normalization)
DvsGesture.create_frames_dataset(events_npy_test_root, frames_test_root, frames_num, split_by, normalization)
if train:
self.data_dir = frames_train_root
else:
self.data_dir = frames_test_root

self.file_name = utils.list_files(frames_root, '.npy', True)
self.data_dir = frames_root

self.file_name = utils.list_files(self.data_dir, '.npy', True)
self.get_item_fun = DvsGesture.get_frames_item

else:
self.file_name = utils.list_files(events_npy_root, '.npy', True)
self.data_dir = events_npy_root
if train:
self.data_dir = events_npy_train_root
else:
self.data_dir = events_npy_test_root
self.file_name = utils.list_files(self.data_dir, '.npy', True)
self.get_item_fun = DvsGesture.get_events_item


Expand Down

0 comments on commit 8bb07f3

Please sign in to comment.