diff --git a/operations.go b/operations.go index 3cc704de..9ffd3619 100644 --- a/operations.go +++ b/operations.go @@ -632,13 +632,8 @@ func Reshape(n *Node, to tensor.Shape) (retVal *Node, err error) { } // Ravel flattens the given node and returns the new node -func Ravel(n *Node) (retVal *Node) { - retVal, err := Reshape(n, tensor.Shape{n.shape.TotalSize()}) - if err != nil { - panic(err) - } - - return retVal +func Ravel(n *Node) (retVal *Node, err error) { + return Reshape(n, tensor.Shape{n.shape.TotalSize()}) } /* Contraction related operations */ diff --git a/operations_test.go b/operations_test.go index e6dbe730..ad1ca075 100644 --- a/operations_test.go +++ b/operations_test.go @@ -1051,8 +1051,9 @@ func TestRavel(t *testing.T) { for i, rst := range ravelTests { g := NewGraph() t := NewTensor(g, Float64, len(rst.input), WithShape(rst.input...)) - t2 := Ravel(t) + t2, err := Ravel(t) + c.NoError(err) c.Equal(rst.output, t2.Shape(), "expected to be flatten in test case: %d", i) } }