Skip to content

Commit

Permalink
增加针对尺寸的transform默认参数
Browse files Browse the repository at this point in the history
  • Loading branch information
fangwei123456 committed Sep 14, 2020
1 parent 6941481 commit 116afcb
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions spikingjelly/datasets/asl_dvs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import spikingjelly.datasets
import zipfile
import os
import threading
import tqdm
import numpy as np
import struct
from torchvision.datasets import utils
import time
import multiprocessing
import shutil
import scipy.io
from torchvision import transforms
labels_dict = {
'a': 0,
'b': 1,
Expand Down Expand Up @@ -114,7 +111,7 @@ def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num
thread_list[j].join()
print('thread', j, 'finished')

def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', normalization='max'):
def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', normalization='max', transform=transforms.Resize((256, 256))):
'''
:param root: 保存数据集的根目录
:type root: str
Expand All @@ -133,6 +130,8 @@ def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, fram
为 ``'max'`` 则每一帧的数据除以每一帧中数据的最大值;
为 ``norm`` 则每一帧的数据减去每一帧中的均值,然后除以标准差
:type normalization: str or None
:param transform: 对帧数据的每一帧进行的变换,默认是将帧数据缩放到256*256
:type transform: callable
ASL-DVS数据集,出自 `Graph-Based Object Classification for Neuromorphic Vision Sensing <https://arxiv.org/abs/1908.06648>`_,
包含24个英文字母(从A到Y,排除J)的美国手语,American Sign Language (ASL)。更多信息参见 https://github.com/PIX2NVS/NVS2Graph,
Expand All @@ -142,6 +141,7 @@ def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, fram
'''
super().__init__()
self.train = train
self.transform = transform
events_root = os.path.join(root, 'events')
if os.path.exists(events_root):
# 如果root目录下存在events_root目录
Expand Down Expand Up @@ -181,7 +181,11 @@ def __len__(self):

def __getitem__(self, index):
if self.use_frame:
return self.get_frames_item(self.file_name[index] + '.npz')
frame, label = self.get_frames_item(self.file_name[index] + '.npz')
if self.transform is not None:
for t in range(frame.shape[0]):
frame[t] = self.transform(frame[t])
return frame, label
else:
return self.get_events_item(self.file_name[index] + '.mat')

Expand Down

0 comments on commit 116afcb

Please sign in to comment.