Skip to content

Commit

Permalink
帧数据的输出修改为tensor而不是np数组
Browse files Browse the repository at this point in the history
  • Loading branch information
fangwei123456 committed Sep 14, 2020
1 parent 116afcb commit 332ccf0
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 27 deletions.
1 change: 0 additions & 1 deletion spikingjelly/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from torch.utils.data import Dataset
import os
import numpy as np
import torch
import threading
import zipfile
from torchvision.datasets import utils
Expand Down
15 changes: 4 additions & 11 deletions spikingjelly/datasets/asl_dvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import multiprocessing
import shutil
import scipy.io
from torchvision import transforms
import torch
labels_dict = {
'a': 0,
'b': 1,
Expand Down Expand Up @@ -82,7 +82,7 @@ def get_events_item(file_name):
@staticmethod
def get_frames_item(file_name):
base_name = os.path.basename(file_name)
return np.load(file_name)['arr_0'], labels_dict[base_name[0]]
return torch.from_numpy(np.load(file_name)['arr_0']).float(), labels_dict[base_name[0]]

@staticmethod
def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num: int, split_by: str, normalization: str or None):
Expand Down Expand Up @@ -111,7 +111,7 @@ def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num
thread_list[j].join()
print('thread', j, 'finished')

def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', normalization='max', transform=transforms.Resize((256, 256))):
def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, frames_num=10, split_by='number', normalization='max'):
'''
:param root: 保存数据集的根目录
:type root: str
Expand All @@ -130,8 +130,6 @@ def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, fram
为 ``'max'`` 则每一帧的数据除以每一帧中数据的最大值;
为 ``norm`` 则每一帧的数据减去每一帧中的均值,然后除以标准差
:type normalization: str or None
:param transform: 对帧数据的每一帧进行的变换,默认是将帧数据缩放到256*256
:type transform: callable
ASL-DVS数据集,出自 `Graph-Based Object Classification for Neuromorphic Vision Sensing <https://arxiv.org/abs/1908.06648>`_,
包含24个英文字母(从A到Y,排除J)的美国手语,American Sign Language (ASL)。更多信息参见 https://github.com/PIX2NVS/NVS2Graph,
Expand All @@ -141,7 +139,6 @@ def __init__(self, root: str, train: bool, split_ratio=0.9, use_frame=True, fram
'''
super().__init__()
self.train = train
self.transform = transform
events_root = os.path.join(root, 'events')
if os.path.exists(events_root):
# 如果root目录下存在events_root目录
Expand Down Expand Up @@ -181,11 +178,7 @@ def __len__(self):

def __getitem__(self, index):
if self.use_frame:
frame, label = self.get_frames_item(self.file_name[index] + '.npz')
if self.transform is not None:
for t in range(frame.shape[0]):
frame[t] = self.transform(frame[t])
return frame, label
return self.get_frames_item(self.file_name[index] + '.npz')
else:
return self.get_events_item(self.file_name[index] + '.mat')

Expand Down
3 changes: 2 additions & 1 deletion spikingjelly/datasets/cifar10_dvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import os
from torchvision.datasets import utils
import torch
labels_dict = {
'airplane': 0,
'automobile': 1,
Expand Down Expand Up @@ -187,7 +188,7 @@ def create_frames_dataset(events_data_dir: str, frames_data_dir: str, frames_num

@staticmethod
def get_frames_item(file_name):
return np.load(file_name)['arr_0'], labels_dict[file_name.split('_')[-2]]
return torch.from_numpy(np.load(file_name)['arr_0']).float(), labels_dict[file_name.split('_')[-2]]

@staticmethod
def get_events_item(file_name):
Expand Down
5 changes: 2 additions & 3 deletions spikingjelly/datasets/dvs128_gesture.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import spikingjelly.datasets
import zipfile
import os
import threading
import tqdm
import numpy as np
import struct
from torchvision.datasets import utils
import time
import multiprocessing
import torch
# https://www.research.ibm.com/dvsgesture/
# https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794

Expand Down Expand Up @@ -224,7 +223,7 @@ def get_events_item(file_name):

@staticmethod
def get_frames_item(file_name):
return np.load(file_name), int(os.path.basename(file_name).split('_')[-2]) - 1
return torch.from_numpy(np.load(file_name)).float(), int(os.path.basename(file_name).split('_')[-2]) - 1

def __init__(self, root: str, train: bool, use_frame=True, frames_num=10, split_by='number', normalization='max'):
'''
Expand Down
6 changes: 1 addition & 5 deletions spikingjelly/datasets/n_mnist.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import spikingjelly.datasets
import zipfile
import os
import threading
import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import utils

# https://www.garrickorchard.com/datasets/n-mnist
Expand Down Expand Up @@ -102,7 +98,7 @@ def get_events_item(file_name):

@staticmethod
def get_frames_item(file_name):
return np.load(file_name), int(os.path.dirname(file_name)[-1])
return torch.from_numpy(np.load(file_name)).float(), int(os.path.dirname(file_name)[-1])

def __init__(self, root: str, train: bool, use_frame=True, frames_num=10, split_by='number', normalization='max'):
'''
Expand Down
8 changes: 2 additions & 6 deletions spikingjelly/datasets/nav_gesture.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import spikingjelly
import zipfile
import os
import threading
import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import utils
import shutil
import loris
import torch

# url md5
resource = {
Expand Down Expand Up @@ -68,7 +64,7 @@ def get_events_item(file_name):
@staticmethod
def get_frames_item(file_name):
frames = np.load(file_name)
return frames, NAVGesture.get_label(file_name)
return torch.from_numpy(frames).float(), NAVGesture.get_label(file_name)
@staticmethod
def download_and_extract(download_root: str, extract_root: str):
dataset_name = 'walk'
Expand Down

0 comments on commit 332ccf0

Please sign in to comment.