Rewrite the dataset base class
fangwei123456 committed Sep 9, 2020
1 parent 857780b commit 340efc5
Showing 3 changed files with 188 additions and 36 deletions.
111 changes: 99 additions & 12 deletions spikingjelly/datasets/__init__.py
@@ -5,22 +5,23 @@
import torch
import threading
import zipfile

from torchvision.datasets import utils
class FunctionThread(threading.Thread):
def __init__(self, f, **kwargs):
def __init__(self, f, *args, **kwargs):
super().__init__()
self.f = f
self.args = args
self.kwargs = kwargs
def run(self):
self.f(**self.kwargs)
self.f(*self.args, **self.kwargs)
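
# Minimal usage sketch (not part of the commit): FunctionThread forwards the
# stored *args/**kwargs to f when the thread runs.
ft = FunctionThread(print, 'hello', 'world', sep=', ')
ft.start()
ft.join()  # prints "hello, world" from the worker thread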

def integrate_events_to_frames(events, height, weight, frames_num=10, split_by='time', normalization=None):
def integrate_events_to_frames(events, height, width, frames_num=10, split_by='time', normalization=None):
'''
:param events: a dict whose keys are {'t', 'x', 'y', 'p'} and whose values are np arrays
:param height: the height of the events data, e.g., 128 for DVS CIFAR10
:param weight: the width of the events data, e.g., 128 for DVS CIFAR10
:param height: the height of the events data, e.g., 128 for CIFAR10-DVS
:param width: the width of the events data, e.g., 128 for CIFAR10-DVS
:param frames_num: the number of frames after conversion
:param split_by: ``'time'`` or ``'number'``. ``'time'`` means the events are sliced over time; e.g., if the recorded ``t`` lies in
:param split_by: how events are accumulated into frames, ``'time'`` or ``'number'``. ``'time'`` means the events are sliced over time; e.g., if the recorded ``t`` lies in
[0, 105] and ``frames_num=10``, the 10 resulting frames are the accumulations of the events with ``t`` in [0, 10), [10, 20), ..., [90, 105),
respectively;
``'number'`` means the events are sliced by count; e.g., if there are 105 events in total and ``frames_num=10``, the resulting
@@ -29,10 +30,10 @@ def integrate_events_to_frames(events, height, weight, frames_num=10, split_by='
``'frequency'`` divides each frame by the number of raw events accumulated into that frame;
``'max'`` divides each frame by the maximum value within that frame;
``norm`` subtracts the per-frame mean from each frame and then divides by the per-frame standard deviation
:return: the converted frames, an np array with ``shape = [frames_num, 2, height, weight]``
:return: the converted frames, an np array with ``shape = [frames_num, 2, height, width]``
'''
frames = np.zeros(shape=[frames_num, 2, height, weight])
frames = np.zeros(shape=[frames_num, 2, height, width])

if split_by == 'time':
# split the events evenly by timestamp
@@ -64,8 +65,6 @@ def integrate_events_to_frames(events, height, weight, frames_num=10, split_by='
else:
raise NotImplementedError
return frames


elif split_by == 'number':
# split the events evenly by count
dt = events['t'].shape[0] // frames_num
@@ -76,7 +75,6 @@ def integrate_events_to_frames(events, height, weight, frames_num=10, split_by='
index_r = events['t'].shape[0]
else:
index_r = index_l + dt

frames[i, events['p'][index_l:index_r], events['y'][index_l:index_r], events['x'][index_l:index_r]] \
+= events['t'][index_l:index_r]
if normalization == 'frequency':
@@ -93,6 +91,13 @@ def integrate_events_to_frames(events, height, weight, frames_num=10, split_by='
else:
raise NotImplementedError
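
# Usage sketch (not part of the commit): calling integrate_events_to_frames on
# synthetic random events; the field layout follows the docstring above.
rng = np.random.default_rng(0)
n = 1000
events = {'t': np.sort(rng.integers(0, 10000, n)),  # ascending timestamps
          'x': rng.integers(0, 128, n),             # column indices
          'y': rng.integers(0, 128, n),             # row indices
          'p': rng.integers(0, 2, n)}               # polarities, 0 or 1
frames = integrate_events_to_frames(events, height=128, width=128,
                                    frames_num=10, split_by='time',
                                    normalization=None)
print(frames.shape)  # (10, 2, 128, 128) per the docstring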

def convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix, read_function, height, width,
frames_num=10, split_by='time', normalization=None):
# iterate over all events data files in events_data_dir and write frames data files to frames_data_dir
for events_file in utils.list_files(events_data_dir, suffix, True):
frames = integrate_events_to_frames(read_function(events_file), height, width, frames_num, split_by, normalization)
frames_file = os.path.join(frames_data_dir, os.path.basename(events_file)[0: -suffix.__len__()] + '.npy')
np.save(frames_file, frames)
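
# Usage sketch (not part of the commit): converting a directory of N-MNIST
# .bin files into .npy frame files; the paths and the NMNIST class name are
# hypothetical here.
# from spikingjelly.datasets.n_mnist import NMNIST
# convert_events_dir_to_frames_dir('./events/train', './frames/train', '.bin',
#                                  NMNIST.read_bin, height=34, width=34,
#                                  frames_num=10, split_by='time', normalization=None)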

def extract_zip_in_dir(source_dir, target_dir):
'''
@@ -108,6 +113,88 @@ def extract_zip_in_dir(source_dir, target_dir):
with zipfile.ZipFile(os.path.join(source_dir, file_name), 'r') as zip_file:
zip_file.extractall(os.path.join(target_dir, file_name[:-4]))

class EventsFramesDatasetBase(Dataset):
@staticmethod
def get_wh():
'''
:return: (width, height)
width: int
the width of the events or frames image
height: int
the height of the events or frames image
:rtype: tuple
'''
raise NotImplementedError

@staticmethod
def read_bin(file_name: str):
'''
:param file_name: the name of the events data file
:type file_name: str
:return: events
a dict whose keys are {'t', 'x', 'y', 'p'} and whose values are np arrays
:rtype: dict
'''
raise NotImplementedError

@staticmethod
def get_events_item(file_name):
'''
:param file_name: the name of the events data file
:type file_name: str
:return: (events, label)
events: dict
a dict whose keys are {'t', 'x', 'y', 'p'} and whose values are np arrays
label: int
the label of the sample
:rtype: tuple
'''
raise NotImplementedError

@staticmethod
def get_frames_item(file_name):
'''
:param file_name: the name of the frames data file
:type file_name: str
:return: (frames, label)
frames: np.ndarray
an np array with ``shape = [frames_num, 2, height, width]``
label: int
the label of the sample
:rtype: tuple
'''
raise NotImplementedError

@staticmethod
def download_and_extract(download_root: str, extract_root: str):
'''
:param download_root: the directory where the downloaded files are saved
:type download_root: str
:param extract_root: the directory where the extracted files are saved
:type extract_root: str
Download the dataset to ``download_root``, then extract it to ``extract_root``.
'''
raise NotImplementedError

@staticmethod
def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str, normalization: str or None):
'''
:param events_data_dir: the directory containing the events data; every file in it is an events data file
:type events_data_dir: str
:param frames_data_dir: the directory where the frames data are saved
:type frames_data_dir: str
:param frames_num: the number of frames after conversion
:type frames_num: int
:param split_by: how events are accumulated into frames
:type split_by: str
:param normalization: the normalization method
:type normalization: str or None
Convert all events data in ``events_data_dir`` into frames data and save them in ``frames_data_dir``.
See ``integrate_events_to_frames`` for the detailed meaning of the conversion parameters.
'''
raise NotImplementedError
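
# A hypothetical minimal subclass sketch (not part of the commit), illustrating
# how EventsFramesDatasetBase is meant to be specialized; MyDVSDataset and its
# constants are made up for this example.
class MyDVSDataset(EventsFramesDatasetBase):
    @staticmethod
    def get_wh():
        return 128, 128  # (width, height) of the sensor

    @staticmethod
    def get_frames_item(file_name):
        label = 0  # a real dataset would parse the label from file_name
        return np.load(file_name), label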

class SubDirDataset(Dataset):
def __init__(self, root, train=True, split_ratio=0.9):
2 changes: 1 addition & 1 deletion spikingjelly/datasets/n_mnist.py
@@ -56,7 +56,7 @@ def read_bin(file_name: str):
x = raw_data[0::5]
y = raw_data[1::5]
rd_2__5 = raw_data[2::5]
p = (rd_2__5 & 128) >> 7 # bit 7
p = (rd_2__5 & 128) >> 7
t = ((rd_2__5 & 127) << 16) | (raw_data[3::5] << 8) | (raw_data[4::5])
return {'t': t, 'x': x, 'y': y, 'p': p}
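
# Worked example (not part of the commit): decoding one 5-byte N-MNIST event.
# The bytes are [x, y, b2, b3, b4]; bit 7 of b2 is the polarity, and the
# remaining 7 bits of b2 plus b3 and b4 form a 23-bit timestamp.
b2, b3, b4 = 0b10000001, 0x00, 0x2A
p = (b2 & 128) >> 7                      # 1, i.e., an ON event
t = ((b2 & 127) << 16) | (b3 << 8) | b4  # (1 << 16) | 42 == 65578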

111 changes: 88 additions & 23 deletions spikingjelly/datasets/nav_gesture.py
@@ -8,17 +8,28 @@
from torch.utils.data import Dataset
from torchvision.datasets import utils
import shutil
import loris

# url md5
resource = {
'walk': ['https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-walk.zip',
'5d305266f13005401959e819abe206f0'],
'sit': ['https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-sit.zip', None]
'5d305266f13005401959e819abe206f0']
}
labels = {
'do': 0,
'up': 1,
'le': 2,
'ri': 3,
'se': 4,
'ho': 5
}


class NavGesture(spikingjelly.datasets.EventsFramesDatasetBase):
@staticmethod
def get_wh():
return 304, 240

class NavGesture(Dataset):
@staticmethod
def read_bin(file_name: str):
'''
@@ -35,27 +46,33 @@ def read_bin(file_name: str):
2 bits for polarity
13 bits padding
'''
with open(file_name, 'rb') as bin_f:
# `& 128` keeps the highest bit of an 8-bit value
# `& 127` keeps the remaining lower 7 bits
raw_data = np.uint64(np.fromfile(bin_f, dtype=np.uint8))
t = (raw_data[0::8] << 24) | (raw_data[1::8] << 16) | (raw_data[2::8] << 8) | raw_data[3::8]
rd_5__8 = raw_data[5::8]
x = (raw_data[4::8] << 8) | (rd_5__8 & 128 >> 7)
rd_6__8 = raw_data[6::8]
y = (rd_5__8 & 127 << 1) | (rd_6__8 & 128)
# 0b01110000 = 112
p = rd_6__8 & 112 >> 4
return {'t': t, 'x': x, 'y': y, 'p': p}
txyp = loris.read_file(file_name)['events']
# txyp.p is bool; convert it to int
return {'t': txyp.t, 'x': txyp.x, 'y': txyp.y, 'p': txyp.p.astype(int)}

@staticmethod
def get_label(file_name):
# 6 gestures: left, right, up, down, home, select.
# 10 subjects, holding the phone in one hand (selfie mode) while walking indoor and outdoor. It contains 339 clips.
# No train/test split, scores should be reported using average score with one-versus-all cross-validation.
# Files are named userID_classID_userclipID.dat, which identifies the user and the gesture class. For example, "user09_do_11.dat" is a "Down Swipe" gesture from user09. classID can be:
# do: down swipe ; up: up swipe ; le: left swipe ; ri: right swipe ; se: select ; ho: home
base_name = os.path.basename(file_name)
return labels[base_name.split('_')[1]]
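
# Example (not part of the commit): the label is parsed from the file name
# alone, so no file needs to exist on disk.
# NavGesture.get_label('user09_do_11.dat')  # -> labels['do'] == 0, a down swipe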

@staticmethod
def download_and_extract(dataset_name: str, download_root: str, extract_root=None):
assert dataset_name == 'walk' or dataset_name == 'sit'
def get_events_item(file_name):
events = NavGesture.read_bin(file_name)
return events, NavGesture.get_label(file_name)

@staticmethod
def get_frames_item(file_name):
frames = np.load(file_name)
return frames, NavGesture.get_label(file_name)
@staticmethod
def download_and_extract(download_root: str, extract_root: str):
dataset_name = 'walk'
file_name = os.path.basename(resource[dataset_name][0])
# utils.download_url(url=resource[dataset_name][0], root=download_root,
# filename=file_name, md5=resource[dataset_name][1])
if extract_root is None:
extract_root = os.path.join(download_root, 'extract')
temp_extract_root = os.path.join(extract_root, 'temp_extract')
utils.download_and_extract_archive(url=resource[dataset_name][0], download_root=download_root,
extract_root=temp_extract_root,
@@ -65,5 +82,53 @@ def download_and_extract(dataset_name: str, download_root: str, extract_root=Non
print(f'extract {zip_file} to {extract_root}')
utils.extract_archive(zip_file, extract_root)
shutil.rmtree(temp_extract_root)
print(f'dataset dir is {extract_root}')
return extract_root

@staticmethod
def create_frames_dataset(events_data_dir, frames_data_dir, frames_num=10, split_by='time', normalization=None):
width, height = NavGesture.get_wh()
thread_list = []
for source_dir in utils.list_dir(events_data_dir):
abs_source_dir = os.path.join(events_data_dir, source_dir)
abs_target_dir = os.path.join(frames_data_dir, source_dir)
if not os.path.exists(abs_target_dir):
os.mkdir(abs_target_dir)
print(f'mkdir {abs_target_dir}')
print(f'thread {thread_list.__len__()} convert events data in {abs_source_dir} to {abs_target_dir}')
thread_list.append(spikingjelly.datasets.FunctionThread(spikingjelly.datasets.convert_events_dir_to_frames_dir,
abs_source_dir, abs_target_dir, '.dat', NavGesture.read_bin, height, width, frames_num, split_by, normalization))
thread_list[-1].start()
for i in range(thread_list.__len__()):
thread_list[i].join()
print('thread', i, 'finished')

def __init__(self, root: str, use_frame=True, frames_num=10, split_by='number', normalization=None):
events_root = os.path.join(root, 'events')
if os.path.exists(events_root) and os.listdir(events_root).__len__() == 10:
# if the events_root directory exists under root and contains 10 subdirectories, the dataset files are assumed to already exist
print(f'events data root {events_root} already exists.')
else:
self.download_and_extract(root, events_root)
self.file_name = []  # paths of the data files
self.use_frame = use_frame
self.data_dir = None
if use_frame:
frames_root = os.path.join(root, 'frames')
NavGesture.create_frames_dataset(events_root, frames_root, frames_num, split_by, normalization)
for sub_dir in utils.list_dir(frames_root, True):
self.file_name.extend(utils.list_files(sub_dir, '.npy', True))
self.data_dir = frames_root
self.get_item_fun = NavGesture.get_frames_item

else:
for sub_dir in utils.list_dir(events_root, True):
self.file_name.extend(utils.list_files(sub_dir, '.dat', True))
self.data_dir = events_root
self.get_item_fun = NavGesture.get_events_item

def __len__(self):
return self.file_name.__len__()

def __getitem__(self, index):
return self.get_item_fun(self.file_name[index])
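
# Usage sketch (not part of the commit); the root path is hypothetical, and the
# first run downloads the archives and converts events to frames.
# dataset = NavGesture(root='./NavGesture', use_frame=True, frames_num=10,
#                      split_by='number', normalization=None)
# frames, label = dataset[0]  # frames: np array of shape [10, 2, 240, 304]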

