-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
122 lines (94 loc) · 3.39 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable
import numpy as np
import pipeline as pp
import scipy
from torch.utils.data import Dataset
from utils.constants import ch_cols
@dataclass
class VulpiData:
    """One recorded run: the IMU and pro arrays together with their metadata.

    Instances are created by ``RawVulpiDataset.__getitem__``. The field order
    is part of the interface because the generated ``__init__`` accepts the
    fields positionally.
    """

    # Value stored under the "imu" key of the run's IMU .mat file.
    imu: np.ndarray
    # Value stored under the "pro" key of the run's pro .mat file.
    pro: np.ndarray
    # Class name, i.e. the name of the directory containing the IMU file.
    label: str
    # Position of ``label`` in the sorted list of class directory names.
    label_id: int
    # Run number parsed from the IMU file name (``imu_<run_id>.mat``).
    run_id: int
    # Source files of the two arrays. RawVulpiDataset passes ``pathlib.Path``
    # objects (results of ``rglob``), so the fields are annotated as ``Path``
    # rather than ``str``.
    imu_path: Path
    pro_path: Path
class RawVulpiDataset(Dataset):
    """Dataset over raw runs stored as ``imu_*.mat`` / ``pro_*.mat`` files.

    Every immediate sub-directory of ``root_dir`` is treated as one class;
    class ids are the positions of the directory names in sorted order.
    IMU and pro files are collected recursively, sorted independently and
    paired by position in those sorted lists.
    """

    def __init__(self, root_dir: Path, transform: Optional[Callable] = None):
        """Index the class directories and .mat files below ``root_dir``.

        Args:
            root_dir: Directory whose sub-directories are the classes. A plain
                string is accepted and converted to ``Path``.
            transform: Optional callable applied to every ``VulpiData`` item
                before it is returned from ``__getitem__``.

        Raises:
            ValueError: If the number of IMU files differs from the number of
                pro files, because index-based pairing would then be wrong.
        """
        super().__init__()
        root_dir = Path(root_dir)
        self._root_dir = root_dir
        self._transform = transform

        self._class_paths = sorted(p for p in root_dir.iterdir() if p.is_dir())
        self._class_names = [p.name for p in self._class_paths]
        self._class_name_to_id = {
            name: index for index, name in enumerate(self._class_names)
        }
        self._id_to_class_name = {
            index: name for name, index in self._class_name_to_id.items()
        }

        self._imu_paths = sorted(root_dir.rglob("imu_*.mat"))
        self._pro_paths = sorted(root_dir.rglob("pro_*.mat"))

        # The two lists are paired purely by position, so a count mismatch
        # would silently attach a pro file to the wrong IMU file.
        if len(self._imu_paths) != len(self._pro_paths):
            raise ValueError(
                f"Found {len(self._imu_paths)} imu_*.mat files but "
                f"{len(self._pro_paths)} pro_*.mat files under {root_dir}"
            )

    def __len__(self):
        """Return the number of indexed IMU files (one item per run)."""
        return len(self._imu_paths)

    def __getitem__(self, idx):
        """Load run ``idx`` from disk and return it as a ``VulpiData``.

        The label is the name of the directory containing the IMU file and
        the run id is the integer between the first underscore and the first
        dot of the IMU file name. If a transform was given, its result is
        returned instead of the raw ``VulpiData``.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.io`` is loaded on every SciPy version.
        from scipy.io import loadmat

        imu_path = self._imu_paths[idx]
        pro_path = self._pro_paths[idx]
        imu = loadmat(imu_path)
        pro = loadmat(pro_path)

        label = imu_path.parent.name
        label_id = self._class_name_to_id[label]
        run_id = int(imu_path.name.split("_")[1].split(".")[0])

        data = VulpiData(
            imu=imu["imu"],
            pro=pro["pro"],
            label=label,
            label_id=label_id,
            run_id=run_id,
            imu_path=imu_path,
            pro_path=pro_path,
        )
        if self._transform is not None:
            data = self._transform(data)
        return data

    @property
    def class_to_id(self):
        """Mapping from class (directory) name to integer class id."""
        return self._class_name_to_id

    @property
    def id_to_class(self):
        """Mapping from integer class id back to class (directory) name."""
        return self._id_to_class_name
class TemporalDataset(Dataset):
    """Dataset of paired IMU / pro windows with a per-window label.

    ``data`` must provide ``data["imu"]`` and ``data["pro"]``, indexed in
    parallel. Each item is ``(dict(imu=..., pro=...), label)`` passed through
    ``transform``.
    """

    # Number of leading columns dropped from each window before it is
    # returned. NOTE(review): presumably bookkeeping columns (the label column
    # among them) -- confirm against the code that builds ``data``.
    _N_LEADING_COLS = 5

    def __init__(self, data, transform: Optional[Callable] = None):
        """Store ``data`` and the transform applied to every sample.

        Args:
            data: Mapping with "imu" and "pro" entries holding the windows.
            transform: Callable applied to each ``(signals, label)`` sample;
                ``pp.Identity()`` is used when omitted.
        """
        super().__init__()
        self.data = data
        self.transform = transform if transform is not None else pp.Identity()

    def __len__(self):
        """Return the number of IMU windows."""
        return len(self.data["imu"])

    def __getitem__(self, idx):
        """Return the transformed ``(dict(imu, pro), label)`` sample ``idx``."""
        imu = self.data["imu"][idx]
        pro = self.data["pro"][idx]

        # The label is read from the first row of the IMU window, in the
        # column given by ch_cols["terr_idx"].
        label = imu[:, ch_cols["terr_idx"]][0]

        imu_channels = imu[:, self._N_LEADING_COLS:]
        pro_channels = pro[:, self._N_LEADING_COLS:]

        sample = dict(imu=imu_channels, pro=pro_channels), label
        return self.transform(sample)
class MCSDataset(Dataset):
    """Dataset over a mapping with parallel "data" and "label" entries.

    Item ``idx`` is the tuple ``(mcs["data"][idx], mcs["label"][idx])`` passed
    through ``transform``.
    """

    def __init__(self, mcs, transform: Optional[Callable] = None):
        """Keep ``mcs`` and the per-sample transform.

        Args:
            mcs: Mapping providing "data" and "label", indexed in parallel.
            transform: Callable applied to each ``(data, label)`` tuple;
                ``pp.Identity()`` is used when omitted.
        """
        super().__init__()
        self.mcs = mcs
        if transform is None:
            transform = pp.Identity()
        self.transform = transform

    def __len__(self):
        """Return the number of entries in ``mcs["data"]``."""
        entries = self.mcs["data"]
        return len(entries)

    def __getitem__(self, idx):
        """Return the transformed ``(data, label)`` pair at ``idx``."""
        features = self.mcs["data"][idx]
        target = self.mcs["label"][idx]
        return self.transform((features, target))
class MambaDataset(Dataset):
    """Dataset over a mapping with parallel "imu", "pro" and "labels" entries.

    Item ``idx`` is ``({"imu": ..., "pro": ...}, label)`` passed through
    ``transform``.
    """

    def __init__(self, data, transform: Optional[Callable] = None):
        """Keep ``data`` and the per-sample transform.

        Args:
            data: Mapping providing "imu", "pro" and "labels", indexed in
                parallel.
            transform: Callable applied to each ``(signals, label)`` tuple;
                ``pp.Identity()`` is used when omitted.
        """
        super().__init__()
        self.data = data
        if transform is None:
            transform = pp.Identity()
        self.transform = transform

    def __len__(self):
        """Return the number of entries in ``data["imu"]``."""
        imu_entries = self.data["imu"]
        return len(imu_entries)

    def __getitem__(self, idx):
        """Return the transformed ``(signals, label)`` sample at ``idx``."""
        source = self.data
        signals = {"imu": source["imu"][idx], "pro": source["pro"][idx]}
        target = source["labels"][idx]
        return self.transform((signals, target))