/
data_loader.py
executable file
·111 lines (87 loc) · 3.1 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from scipy import io
import numpy as np
import torch
from pdb import set_trace as breakpoint
import torch.utils.data as data
def data_iterator(train_x, train_att, batch_size=100, device=None):
    """Yield shuffled (attribute, visual) mini-batches forever.

    Each pass over the data draws a fresh random permutation of the sample
    indices and walks through it in consecutive slices of ``batch_size``.
    The last slice of a pass is shorter when the number of samples is not a
    multiple of ``batch_size``. The generator never terminates on its own;
    the caller decides how many batches to draw.

    Args:
        train_x: NumPy array of visual features, one sample per row.
        train_att: NumPy array aligned with ``train_x`` along the first axis.
        batch_size: Maximum number of samples per yielded batch.
        device: Target device for the yielded tensors. When ``None`` the
            batches go to CUDA if it is available and to the CPU otherwise.

    Yields:
        Tuple ``(att_batch, visual_batch)`` of float32 tensors on ``device``.
        Note the order: attributes first, visual features second.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``train_x`` is empty
            (either would otherwise make the generator spin without ever
            yielding).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    num_samples = len(train_x)
    if num_samples == 0:
        raise ValueError("train_x must contain at least one sample")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    while True:
        # Re-shuffle at the start of every pass so batches differ per epoch.
        order = np.arange(0, num_samples)
        np.random.shuffle(order)
        for start in range(0, num_samples, batch_size):
            rows = order[start:start + batch_size]
            visual_batch = train_x[rows].astype("float32")
            att_batch = train_att[rows]
            # Plain tensors are used directly: the autograd ``Variable``
            # wrapper is deprecated and was never imported in this module.
            att_batch = torch.from_numpy(att_batch).float().to(device)
            visual_batch = torch.from_numpy(visual_batch).float().to(device)
            yield att_batch, visual_batch
class data_loader(data.Dataset):
    """Dataset that returns one randomly sampled episode per item.

    Every ``__getitem__`` call ignores the requested index and draws a fresh
    episode: ``ways`` distinct classes are chosen at random, then ``shots``
    distinct samples are picked from each chosen class.

    Args:
        feats: Per-sample features, indexed along the first axis; stored as a
            float tensor.
        atts: Per-sample attributes aligned with ``feats``; stored as a float
            tensor.
        labels: NumPy array of per-sample class ids aligned with ``feats``.
        ways: Number of classes per episode.
        shots: Number of samples drawn from each class.
        size: Nominal number of episodes reported by ``len()``. Defaults to
            the number of entries in ``labels``.
    """

    def __init__(self, feats, atts, labels, ways=32, shots=4, size=None):
        self.ways = ways
        self.shots = shots
        self.feats = torch.tensor(feats).float()
        self.atts = torch.tensor(atts).float()
        self.labels = labels
        self.classes = np.unique(labels)
        # The length attribute used to be read without ever being assigned,
        # so len() raised AttributeError. Episodes are random, which makes
        # the length purely nominal; default to the sample count.
        self._size = len(labels) if size is None else size

    def __getitem__(self, index):
        """Sample one episode; ``index`` is ignored.

        Returns:
            Tuple ``(select_feats, select_atts, select_labels)`` with
            ``ways * shots`` rows each. Rows are grouped by class, and
            ``select_labels`` holds the episode-local class position
            (``0 .. ways - 1``), not the original class id.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when ``ways``
                exceeds the number of classes or ``shots`` exceeds the number
                of samples of a chosen class.
        """
        selected_classes = np.random.choice(list(self.classes), self.ways, False)
        rows = np.empty(self.ways * self.shots, dtype=np.int64)
        for i in range(self.ways):
            members = (self.labels == selected_classes[i]).nonzero()[0]
            rows[i * self.shots:(i + 1) * self.shots] = np.random.choice(
                members, self.shots, False
            )
        # One indexed gather replaces the per-sample torch.cat, which
        # re-copied the growing result on every iteration.
        rows = torch.from_numpy(rows)
        select_feats = self.feats[rows]
        select_atts = self.atts[rows]
        select_labels = torch.arange(self.ways, dtype=torch.long).repeat_interleave(
            self.shots
        )
        return select_feats, select_atts, select_labels

    def __len__(self):
        """Return the nominal number of episodes."""
        return self._size
class data_loader_wt_att(data.Dataset):
    """Dataset that returns one randomly sampled episode of features and labels.

    Every ``__getitem__`` call ignores the requested index and draws a fresh
    episode: ``ways`` distinct classes are chosen at random, then ``shots``
    distinct samples are picked from each chosen class. No attribute tensor
    is stored or returned.

    Args:
        feats: Per-sample features, indexed along the first axis; stored as a
            float tensor.
        labels: NumPy array of per-sample class ids aligned with ``feats``.
        ways: Number of classes per episode.
        shots: Number of samples drawn from each class.
        size: Nominal number of episodes reported by ``len()``. Defaults to
            the number of entries in ``labels``.
    """

    def __init__(self, feats, labels, ways=32, shots=4, size=None):
        self.ways = ways
        self.shots = shots
        self.feats = torch.tensor(feats).float()
        self.labels = labels
        self.classes = np.unique(labels)
        # The length attribute used to be read without ever being assigned,
        # so len() raised AttributeError. Episodes are random, which makes
        # the length purely nominal; default to the sample count.
        self._size = len(labels) if size is None else size

    def __getitem__(self, index):
        """Sample one episode; ``index`` is ignored.

        Returns:
            Tuple ``(select_feats, select_labels)`` with ``ways * shots`` rows
            each. Rows are grouped by class, and ``select_labels`` holds the
            original class id of every row as an int64 tensor.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when ``ways``
                exceeds the number of classes or ``shots`` exceeds the number
                of samples of a chosen class.
        """
        selected_classes = np.random.choice(list(self.classes), self.ways, False)
        rows = np.empty(self.ways * self.shots, dtype=np.int64)
        for i in range(self.ways):
            members = (self.labels == selected_classes[i]).nonzero()[0]
            rows[i * self.shots:(i + 1) * self.shots] = np.random.choice(
                members, self.shots, False
            )
        # One indexed gather replaces the per-sample torch.cat, which
        # re-copied the growing result on every iteration.
        select_feats = self.feats[torch.from_numpy(rows)]
        # Each chosen class id is repeated once per shot, matching row order.
        class_ids = torch.from_numpy(selected_classes.astype(np.int64))
        select_labels = class_ids.repeat_interleave(self.shots)
        return select_feats, select_labels

    def __len__(self):
        """Return the nominal number of episodes."""
        return self._size