Skip to content

Commit

Permalink
Fixed a bug in Reshape, where a slice of a slice cannot be reshaped. (#…
Browse files Browse the repository at this point in the history
…359)

This is fixed in two ways:
    1. If the input is a `View` then it is materialized. This
    requires more memory, but it is more worth it to allocate than to
    stress over calculating  how much extra overhead to allocate for
    sharing memories
    2. `ShallowClone` is fixed in package `tensor` to make sure that
    views are correctly shallow cloned as well
  • Loading branch information
chewxy committed Jan 4, 2020
1 parent 1e98951 commit dc9e605
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion op_tensor.go
Expand Up @@ -857,6 +857,7 @@ func (op sliceIncrOp) Hashcode() uint32 { return simpleHash(op) }
func (op sliceIncrOp) String() string { func (op sliceIncrOp) String() string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("T[") buf.WriteString("T[")

for i := 0; i < op.along; i++ { for i := 0; i < op.along; i++ {
buf.WriteString(":, ") buf.WriteString(":, ")
} }
Expand Down Expand Up @@ -1155,7 +1156,11 @@ func (op reshapeOp) Do(vals ...Value) (Value, error) {
switch vals[0].(type) { switch vals[0].(type) {
case tensor.Tensor: case tensor.Tensor:
if v, ok := vals[0].(*tensor.Dense); ok { if v, ok := vals[0].(*tensor.Dense); ok {
val = v.ShallowClone() if v.IsView() {
val = v.Materialize()
} else {
val = v.ShallowClone()
}
} else { } else {
if val, err = CloneValue(vals[0]); err != nil { if val, err = CloneValue(vals[0]); err != nil {
return nil, errors.Wrapf(err, cloneFail, vals[0]) return nil, errors.Wrapf(err, cloneFail, vals[0])
Expand All @@ -1164,6 +1169,7 @@ func (op reshapeOp) Do(vals ...Value) (Value, error) {
if !val.Shape().Eq(op.from) { if !val.Shape().Eq(op.from) {
return nil, errors.Errorf("Shape mismatch. Input shape is %v. Expected %v", val.Shape(), op.from) return nil, errors.Errorf("Shape mismatch. Input shape is %v. Expected %v", val.Shape(), op.from)
} }

if err := val.(tensor.Tensor).Reshape(op.to...); err != nil { if err := val.(tensor.Tensor).Reshape(op.to...); err != nil {
return nil, err return nil, err
} }
Expand Down

0 comments on commit dc9e605

Please sign in to comment.