forked from negar-rostamzadeh/LSTM-Attention
-
Notifications
You must be signed in to change notification settings - Fork 2
/
svhn.py
36 lines (30 loc) · 1.27 KB
/
svhn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
from fuel.datasets.svhn import SVHN
from fuel.transformers import Mapping

import tasks
class DigitTask(tasks.Classification):
    """SVHN digit classification task (which_format=2: cropped digit images).

    The official "train" set is split into a 50000-example training
    subset and a held-out validation subset (the remainder); the
    official "test" set is used as-is.
    """

    name = "svhn_digit"

    def __init__(self, *args, **kwargs):
        """Run the base task setup, then fix class/channel counts."""
        super(DigitTask, self).__init__(*args, **kwargs)
        self.n_classes = 10   # digits 0-9
        # NOTE(review): format-2 SVHN images are typically RGB; confirm a
        # grayscale conversion happens upstream before trusting n_channels=1.
        self.n_channels = 1

    def load_datasets(self):
        """Return the train/valid/test ``SVHN`` datasets.

        The first 50000 examples of "train" form the training split;
        everything after index 50000 forms the validation split.
        """
        return dict(
            train=SVHN(which_sets=["train"], which_format=2,
                       subset=slice(None, 50000)),
            valid=SVHN(which_sets=["train"], which_format=2,
                       subset=slice(50000, None)),
            test=SVHN(which_sets=["test"], which_format=2))

    def get_stream_num_examples(self, which_set, monitor):
        """Cap monitoring passes over "train" at 10000 examples.

        All other (which_set, monitor) combinations defer to the base
        class.
        """
        if monitor and which_set == "train":
            return 10000
        return super(DigitTask, self).get_stream_num_examples(
            which_set, monitor)

    def get_stream(self, *args, **kwargs):
        """Wrap the base stream so targets are rewritten on the fly.

        ``fix_target_representation`` is defined elsewhere in this
        project; it is applied to every batch via a fuel ``Mapping``.
        """
        return Mapping(super(DigitTask, self).get_stream(*args, **kwargs),
                       mapping=fix_target_representation)

    def preprocess(self, data):
        """Cast an ``(x, y)`` batch to the dtypes the model expects.

        Returns a 3-tuple: ``x`` as float32, a per-example array of x's
        trailing (spatial) dimensions as float32, and ``y`` as uint8
        with SVHN's label encoding fixed (the dataset uses class 10 for
        the digit zero).
        """
        x, y = data
        # remove bogus singleton dimension
        y = y.flatten()
        # SVHN encodes digit zero as class 10; remap it to label 0
        y[y == 10] = 0
        # one row of x's trailing dims (x.shape[2:]) per example
        x_shape = np.tile([x.shape[2:]], (x.shape[0], 1))
        return (x.astype(np.float32),
                x_shape.astype(np.float32),
                y.astype(np.uint8))