Skip to content

Commit

Permalink
重写了数据集
Browse files Browse the repository at this point in the history
  • Loading branch information
fangwei123456 committed Sep 13, 2020
1 parent 0ac9bed commit ddefc8e
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 225 deletions.
64 changes: 6 additions & 58 deletions spikingjelly/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def integrate_events_to_frames(events, height, width, frames_num=10, split_by='t
'''
frames = np.zeros(shape=[frames_num, 2, height, width])

eps = 1e-5 # 涉及到除法的地方,被除数加上eps,防止出现除以0
if split_by == 'time':
# 按照脉冲的发生时刻进行等分
events['t'] -= events['t'][0]
Expand All @@ -57,9 +57,9 @@ def integrate_events_to_frames(events, height, width, frames_num=10, split_by='t
if normalization == 'frequency':
frames[i] /= dt # 表示脉冲发放的频率
elif normalization == 'max':
frames[i] /= frames[i].max()
frames[i] /= max(frames[i].max(), eps)
elif normalization == 'norm':
frames[i] = (frames[i] - frames[i].mean()) / np.sqrt((frames[i].var() + 1e-5))
frames[i] = (frames[i] - frames[i].mean()) / np.sqrt(max(frames[i].var(), eps))
elif normalization is None:
continue
else:
Expand All @@ -80,10 +80,10 @@ def integrate_events_to_frames(events, height, width, frames_num=10, split_by='t
if normalization == 'frequency':
frames[i] /= dt # 表示脉冲发放的频率
elif normalization == 'max':
frames[i] /= max(frames[i].max(), 1e-5)
frames[i] /= max(frames[i].max(), eps)

elif normalization == 'norm':
frames[i] = (frames[i] - frames[i].mean()) / np.sqrt((frames[i].var() + 1e-5))
frames[i] = (frames[i] - frames[i].mean()) / np.sqrt(max(frames[i].var(), eps))
elif normalization is None:
continue
else:
Expand Down Expand Up @@ -220,56 +220,4 @@ def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num
将 ``events_data_dir`` 文件夹下的脉冲数据全部转换成帧数据,并保存在 ``frames_data_dir``。
转换参数的详细含义,参见 ``integrate_events_to_frames`` 函数。
'''
raise NotImplementedError

class SubDirDataset(Dataset):
def __init__(self, root, train=True, split_ratio=0.9):
'''
:param root: 保存数据集的文件夹
:param train: 训练还是测试
:param split_ratio: 训练集占数据集的比例,对于每一类的数据,按文件夹内部数据文件的命名排序,取前 ``split_ratio`` 的数据作为训练集,
其余数据作为测试集
适用于包含多个子文件夹,每个子文件夹名称为类别名,子文件夹内部是npz格式的数据的数据集基类。文件结构类似如下所示:
.. code-block:: bash
dvs_cifar10_npz/
|-- airplane
| |-- 0.npz
| |-- ...
|-- automobile
|-- bird
|-- cat
|-- deer
|-- dog
|-- frog
|-- horse
|-- ship
`-- truck
'''

self.root = root
self.label_name = os.listdir(self.root)
self.file_path = []
self.label = []

for i in range(self.label_name.__len__()):
sub_dir_path = os.path.join(self.root, self.label_name[i])
file_names = os.listdir(sub_dir_path)
split_boundary = int(file_names.__len__() * split_ratio)
if train:
for j in range(0, split_boundary):
self.file_path.append(os.path.join(sub_dir_path, file_names[j]))
self.label.append(i)
else:
for j in range(split_boundary, file_names.__len__()):
self.file_path.append(os.path.join(sub_dir_path, file_names[j]))
self.label.append(i)

def __len__(self):
return self.file_path.__len__()

def __getitem__(self, index):
frame = torch.from_numpy(np.load(self.file_path[index])['arr_0']).float()
return frame, self.label[index]
raise NotImplementedError

0 comments on commit ddefc8e

Please sign in to comment.