/
data.py
80 lines (71 loc) · 2.81 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import tensorflow as tf
class Data:
    """
    Base class for data.

    By default, it assumes the data is a vector and subsamples data
    i.i.d. (independent and identically distributed). Use one of the
    derived classes for subsampling more complex data structures.

    Arguments
    ----------
    data: dict, tf.tensor, np.ndarray, optional
        Data whose type depends on the type of model it is fed into:
        Stan, TensorFlow, and NumPy/SciPy respectively. For a dict,
        only the entry under key ``'y'`` is subsampled; every other
        entry is passed through unchanged.
    shuffled: bool, optional
        Whether the data is shuffled. Passing ``False`` currently
        raises ``NotImplementedError`` because shuffling is not
        implemented yet.

    Attributes
    ----------
    data:
        The data object exactly as passed in.
    counter: int
        Index of the first data point of the next minibatch.
    N: int or None
        Number of data points along the first axis (for a dict, the
        length of ``data['y']``), or ``None`` when it is unknown.
    """

    def __init__(self, data=None, shuffled=True):
        self.data = data
        if not shuffled:
            # TODO
            # shuffle self.data
            raise NotImplementedError()

        self.counter = 0
        if self.data is None:
            self.N = None
        elif isinstance(self.data, np.ndarray):
            self.N = self.data.shape[0]
        elif isinstance(self.data, dict):
            # sample() subsamples the 'y' entry, so its length is the
            # data set size; unknown if the entry is absent.
            values = self.data.get('y')
            self.N = None if values is None else len(values)
        elif isinstance(self.data, tf.Tensor):
            # The leading dimension is either an object carrying
            # `.value` or already a plain int/None; accept both.
            dim = self.data.get_shape()[0]
            self.N = getattr(dim, 'value', dim)
        else:
            raise TypeError(
                "data must be None, a dict, a tf.Tensor or a np.ndarray; "
                "got %s" % type(self.data).__name__)

    def sample(self, n_data=None):
        """
        Return the next minibatch, cycling back to the start of the
        data once the end is reached.

        Arguments
        ----------
        n_data: int, optional
            Number of data points in the minibatch. If ``None``, the
            full data set is returned and the position is not advanced.

        Returns
        -------
        The full ``self.data`` if ``n_data`` is ``None`` or there is no
        data; otherwise an object of the same kind as ``self.data``
        holding the next ``n_data`` points. For a dict, a shallow copy
        whose ``'y'`` entry is replaced by the minibatch.

        Raises
        ------
        ValueError
            If the number of data points is unknown.
        TypeError
            If ``self.data`` is of an unsupported type.
        """
        # TODO scale gradient and printed loss by self.N / self.n_data
        if n_data is None or self.data is None:
            return self.data

        if self.N is None:
            raise ValueError(
                "cannot subsample: the number of data points is unknown")

        start = self.counter
        stop = start + n_data
        # When the window runs past the end, the minibatch is the tail
        # [start, N) followed by the head [0, stop - N).
        wrapped = stop > self.N
        if wrapped:
            stop -= self.N

        if isinstance(self.data, np.ndarray):
            if wrapped:
                minibatch = np.concatenate((self.data[start:],
                                            self.data[:stop]))
            else:
                minibatch = self.data[start:stop]
        elif isinstance(self.data, dict):
            # Shallow copy so the caller's dict keeps the full 'y'.
            minibatch = self.data.copy()
            values = minibatch['y']
            if not wrapped:
                minibatch['y'] = values[start:stop]
            elif isinstance(values, np.ndarray):
                # `+` on arrays adds element-wise; concatenate instead.
                minibatch['y'] = np.concatenate((values[start:],
                                                 values[:stop]))
            else:
                minibatch['y'] = values[start:] + values[:stop]
        elif isinstance(self.data, tf.Tensor):
            if wrapped:
                indices = list(range(start, self.N)) + list(range(stop))
            else:
                indices = list(range(start, stop))
            minibatch = tf.gather(self.data, indices)
        else:
            raise TypeError(
                "cannot subsample data of type %s"
                % type(self.data).__name__)

        self.counter = stop
        return minibatch