Skip to content

Commit

Permalink
Slicing improvements (#1363)
Browse files Browse the repository at this point in the history
* better support for numpy-like slicing

* make strided_select in-place if selecting a range of the first n batches

* fix issue

* add slicing test

* fix test name
  • Loading branch information
msperber authored and neubig committed May 27, 2018
1 parent 740a962 commit 736c9ed
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 27 deletions.
8 changes: 7 additions & 1 deletion dynet/expr.cc
Expand Up @@ -189,7 +189,13 @@ Expression pickrange(const Expression& x, unsigned v, unsigned u) {
return Expression(x.pg, x.pg->add_function<PickRange>({x.i}, v, u, 0));
}

Expression strided_select(const Expression& x, const std::vector<int>& strides, const std::vector<int>& range_from, const std::vector<int>& range_to) { return Expression(x.pg, x.pg->add_function<StridedSelect>({x.i}, strides, range_from, range_to)); }
Expression strided_select(const Expression& x, const std::vector<int>& strides, const std::vector<int>& range_from, const std::vector<int>& range_to) {
bool inplaced = true;
for(unsigned d=0;d<strides.size();d++){ if(strides[d]!=1) inplaced = false; }
for(unsigned d=0;d<range_from.size();d++){ if(range_from[d]!=0) inplaced = false; }
for(unsigned d=0;d<range_to.size() && d<x.dim().nd;d++){ if(range_to[d]!=x.dim()[d]) inplaced = false; }
return Expression(x.pg, x.pg->add_function<StridedSelect>({x.i}, strides, range_from, range_to, inplaced));
}

Expression pickneglogsoftmax(const Expression& x, unsigned v) { return Expression(x.pg, x.pg->add_function<PickNegLogSoftmax>({x.i}, v)); }
Expression pickneglogsoftmax(const Expression& x, const vector<unsigned> & v) { return Expression(x.pg, x.pg->add_function<PickNegLogSoftmax>({x.i}, v)); }
Expand Down
8 changes: 7 additions & 1 deletion dynet/nodes-select.h
Expand Up @@ -82,7 +82,13 @@ struct PickBatchElements : public Node {
// y = (x)_{[*pval]}
struct StridedSelect : public Node {
explicit StridedSelect(const std::initializer_list<VariableIndex>& a, const std::vector<int>& strides,
const std::vector<int>& from, const std::vector<int>& to) : Node(a), strides(strides), from(from), to(to) {}
const std::vector<int>& from, const std::vector<int>& to, bool inplaced=false)
: Node(a), strides(strides), from(from), to(to) {
if(inplaced){
forward_inplace_state = INPLACE_TYPE::READ;
backward_inplace_state = INPLACE_TYPE::WRITE;
}
}
DYNET_NODE_DEFINE_DEV_IMPL()
virtual bool supports_multibatch() const override { return true; }
const std::vector<int> strides, from, to;
Expand Down
54 changes: 29 additions & 25 deletions python/_dynet.pyx
Expand Up @@ -654,7 +654,7 @@ cdef class Expression: #{{{
IndexError: If the indices are too large
ValueError: In case of improper slice or if step is used
"""
assert isinstance(index, (int, slice)), "Expression key must be int or slice: %s" % index
assert isinstance(index, (int, slice, tuple)), "Expression key must be int or slice or tuple of slices: %s" % index
cdef int rows = self.c().dim().rows()
cdef int i, j
if isinstance(index, int):
Expand All @@ -666,30 +666,34 @@ cdef class Expression: #{{{
if i < 0:
i += rows
return pick(self, i)
else:
i = 0
j = rows
if index.start is not None:
i = index.start
if i > rows - 1:
raise IndexError("Start index too large: %d > %d" % (i, rows - 1))
if i < -rows:
raise IndexError("Start index too small: %d < %d" % (i, -rows))
if i < 0:
i += rows
if index.stop is not None:
j = index.stop
if j > rows:
raise IndexError("Stop index too large: %d > %d" % (j, rows))
if j < -rows + 1:
raise IndexError("Stop index too small: %d < %d" % (j, -rows + 1))
if j < 0:
j += rows
if i >= j:
raise ValueError("Improper slice: start index must come strictly before stop index")
if index.step is not None:
raise ValueError("Step sizes not yet supported.")
return pick_range(self, i, j)
elif isinstance(index, slice):
return strided_select(self, [index.step] if index.step is not None else [],
[index.start] if index.start is not None else [],
[index.stop] if index.stop is not None else [])
elif isinstance(index, tuple):
steps = []
for slice_i in index:
if slice_i.step is None:
steps.append(1)
else:
if slice_i.step <= 0: raise IndexError("steps must be positive, got:", slice_i.step)
steps.append(slice_i.step)
starts = []
for slice_i in index:
if slice_i.start is None:
starts.append(0)
else:
starts.append(slice_i.start)
stops = []
for i, slice_i in enumerate(index):
if slice_i.stop is None:
if i == len(self.dim()[0]):
stops.append(self.dim()[1])
else:
stops.append(self.dim()[0][i])
else:
stops.append(slice_i.stop)
return strided_select(self, steps, starts, stops)

cpdef scalar_value(self, bool recalculate=False):
"""Returns value of an expression as a scalar
Expand Down
14 changes: 14 additions & 0 deletions tests/python/test.py
Expand Up @@ -459,6 +459,20 @@ def test_layer_norm(self):
self.assertTrue(np.allclose(y.npvalue(), y_np_value))


class TestSlicing(unittest.TestCase):

def test_slicing(self):
dy.renew_cg()
data = np.random.random((10,10,10))
self.assertTrue(np.allclose(dy.inputTensor(data)[:1,:2,:3].npvalue(), data[:1,:2,:3]))
self.assertTrue(np.allclose(dy.inputTensor(data, batched=True)[:1,:2,:3].npvalue(), data[:1,:2,:3]))
self.assertTrue(np.allclose(dy.inputTensor(data)[:,:,:3].npvalue(), data[:,:,:3]))
self.assertTrue(np.allclose(dy.inputTensor(data)[3:,:,:].npvalue(), data[3:,:,:]))
self.assertTrue(np.allclose(dy.inputTensor(data)[:,:,::1].npvalue(), data[:,:,::1]))
self.assertTrue(np.allclose(dy.inputTensor(data)[:,:,::3].npvalue(), data[:,:,::3]))
self.assertTrue(np.allclose(dy.inputTensor(data)[3:5,1:3,1:].npvalue(), data[3:5,1:3,1:]))


class TestSimpleRNN(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 736c9ed

Please sign in to comment.