Skip to content

Commit

Permalink
Fixes #233 (#234)
Browse files Browse the repository at this point in the history
* Start to touch #233

* Fixes #233
  • Loading branch information
chewxy committed Sep 8, 2018
1 parent 488c8e7 commit 8490d29
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 66 deletions.
123 changes: 122 additions & 1 deletion known_issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"gorgonia.org/tensor"
)

func TestConstDeriv(t *testing.T) {
func TestIssue182(t *testing.T) {
// This test revolves around repeated calls to run a VM.
// Formerly, upon running the VM once, the derivation of the constant is set.
// This derivation value would get Add()ed to upon subsequent calls to run the VM.
Expand Down Expand Up @@ -46,3 +46,124 @@ func TestConstDeriv(t *testing.T) {
t.Fatalf("Expected constants to not have derivatives")
}
}

// func TestIssue217(t *testing.T) {
// //it works, cost = 22
// if err := issue217(tensor.Shape{2, 2}, tensor.Shape{2, 2}); err != nil {
// t.Fatal(err)
// }

// //panic: Node Σ[0](%2) :: float32, has 0 dimensions(Shape: ()). Input shape is (1, 1)...
// if err := issue217(tensor.Shape{2, 2}, tensor.Shape{2, 1}); err != nil {
// t.Fatal(err)
// }

// //panic: Node Σ[1](%2) :: float32, has 0 dimensions(Shape: ()). Input shape is (1, 1)...
// if err := issue217(tensor.Shape{1, 2}, tensor.Shape{2, 2}); err != nil {
// t.Fatal(err)
// }
// }

// func issue217(xS, yS tensor.Shape) error {

// g := NewGraph()
// x := NewMatrix(g, Float32, WithName("x"), WithShape(xS...), WithInit(RangedFrom(0)))
// y := NewMatrix(g, Float32, WithName("y"), WithShape(yS...), WithInit(RangedFrom(0)))

// z := Must(Mul(x, y))
// cost := Must(Sum(z))
// //cost := Must(Mean(z))

// _, err := Grad(cost, x, y)
// if err != nil {
// return errors.Wrap(err, "Grad")
// }

// m := NewTapeMachine(g)
// if err = m.RunAll(); err != nil {
// return errors.Wrap(err, "Run")
// }
// return nil
// }

// TestIssue233_F32 is a regression test for issue #233: a float32 Conv2d
// with a 3x3 all-ones kernel, unit stride, unit padding and unit dilation
// over a 5x5 input must produce a same-size (5x5) output whose entries are
// the sums of the zero-padded 3x3 neighbourhoods of the input.
func TestIssue233_F32(t *testing.T) {
	g := NewGraph()
	xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking([]float32{
		0, 0, 0, 0, 0,
		1, 1, 1, 1, 1,
		2, 2, 2, 2, 2,
		3, 3, 3, 3, 3,
		4, 4, 4, 4, 4,
	}))
	kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking([]float32{
		1, 1, 1,
		1, 1, 1,
		1, 1, 1,
	}))

	x := NewTensor(g, Float32, 4, WithShape(1, 1, 5, 5), WithValue(xV), WithName("x"))
	w := NewTensor(g, Float32, 4, WithShape(1, 1, 3, 3), WithValue(kernelV), WithName("w"))

	// Kernel shape {3, 3}; pad, stride and dilation are all {1, 1}.
	y, err := Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
	if err != nil {
		t.Fatal(err)
	}

	// Useful when debugging this test:
	// logger := log.New(os.Stderr, "", 0)
	// vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v"))
	vm := NewTapeMachine(g)
	if err := vm.RunAll(); err != nil {
		t.Fatal(err)
	}

	// Each output element is the sum of the zero-padded 3x3 neighbourhood
	// centred on the corresponding input element.
	correct := []float32{
		2, 3, 3, 3, 2,
		6, 9, 9, 9, 6,
		12, 18, 18, 18, 12,
		18, 27, 27, 27, 18,
		14, 21, 21, 21, 14,
	}
	assert.Equal(t, correct, y.Value().Data())
}

// TestIssue233_F64 exercises the float64 path of issue #233: convolving a
// 5x5 input with a 3x3 all-ones kernel (stride, padding and dilation all 1)
// must yield a 5x5 output of zero-padded neighbourhood sums.
func TestIssue233_F64(t *testing.T) {
	g := NewGraph()

	inputBacking := []float64{
		0, 0, 0, 0, 0,
		1, 1, 1, 1, 1,
		2, 2, 2, 2, 2,
		3, 3, 3, 3, 3,
		4, 4, 4, 4, 4,
	}
	kernelBacking := []float64{
		1, 1, 1,
		1, 1, 1,
		1, 1, 1,
	}
	xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking(inputBacking))
	kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking(kernelBacking))

	x := NewTensor(g, Float64, 4, WithShape(1, 1, 5, 5), WithValue(xV), WithName("x"))
	w := NewTensor(g, Float64, 4, WithShape(1, 1, 3, 3), WithValue(kernelV), WithName("w"))

	y, err := Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
	if err != nil {
		t.Fatal(err)
	}

	// logger := log.New(os.Stderr, "", 0)
	// vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v"))
	vm := NewTapeMachine(g)
	if err := vm.RunAll(); err != nil {
		t.Fatal(err)
	}

	correct := []float64{
		2, 3, 3, 3, 2,
		6, 9, 9, 9, 6,
		12, 18, 18, 18, 12,
		18, 27, 27, 27, 18,
		14, 21, 21, 21, 14,
	}

	assert.Equal(t, correct, y.Value().Data())
}
125 changes: 60 additions & 65 deletions op_nn.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,23 +297,28 @@ func (op im2colOp) calcShape(s tensor.Shape) (retVal tensor.Shape) {
}

// retHW computes the spatial output dimensions of the im2col/convolution
// using the standard convolution arithmetic:
//
//	out = (in + 2*pad - (dilation*(kernel-1) + 1))/stride + 1
//
// Note that dilation scales (kernel-1) — the gap count — not the kernel
// size itself; getting this wrong was part of issue #233.
func (op im2colOp) retHW(h, w int) (retHeight, retWidth int) {
	retHeight = (h+2*op.padH-(op.dilationH*(op.h-1)+1))/op.strideH + 1
	retWidth = (w+2*op.padW-(op.dilationW*(op.w-1)+1))/op.strideW + 1
	return
}

func (op im2colOp) do(prealloc, input Value) (retVal Value, err error) {
inputT := input.(*tensor.Dense)
outputT := prealloc.(*tensor.Dense)

// extract bchw - this bit can be expanded in the future, but for now we only support bchw
s := input.Shape()
s := inputT.Shape()
b := s[0]
c := s[1]
h := s[2]
w := s[3]

inputStrides := inputT.Strides()
retHeight, retWidth := op.retHW(h, w)
batchStrideIm := c * h * w
batchStrideCol := (op.w * op.h * c) * retHeight * retWidth
batchStrideIm := inputStrides[0]
batchStrideCol := outputT.Strides()[0]
chanStride := h * w
inRowStride := inputStrides[2]

var imStart, imEnd, colStart, colEnd int
imEnd = imStart + batchStrideIm
Expand All @@ -324,102 +329,92 @@ func (op im2colOp) do(prealloc, input Value) (retVal Value, err error) {
imData := input.Data().([]float64)
colData := prealloc.Data().([]float64)
for i := 0; i < b; i++ {
op.f64s(c, h, w, chanStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])

colStart += batchStrideCol
colEnd += batchStrideCol
imStart := i * batchStrideIm
colStart := i * batchStrideCol

imStart += batchStrideIm
imEnd += batchStrideIm

if imEnd > len(imData) {
if imEnd = imStart + batchStrideIm; imEnd >= len(imData) {
imEnd = len(imData)
}
if colEnd > len(colData) {
if colEnd = colStart + batchStrideCol; colEnd >= len(colData) {
colEnd = len(colData)
}

op.f64s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
}
case tensor.Float32:
imData := input.Data().([]float32)
colData := prealloc.Data().([]float32)
for i := 0; i < b; i++ {
op.f32s(c, h, w, chanStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
imStart := i * batchStrideIm
colStart := i * batchStrideCol

colStart += batchStrideCol
colEnd += batchStrideCol

imStart += batchStrideIm
imEnd += batchStrideIm

if imEnd > len(imData) {
if imEnd = imStart + batchStrideIm; imEnd >= len(imData) {
imEnd = len(imData)
}
if colEnd > len(colData) {
if colEnd = colStart + batchStrideCol; colEnd >= len(colData) {
colEnd = len(colData)
}

op.f32s(c, h, w, chanStride, inRowStride, retHeight, retWidth, imData[imStart:imEnd], colData[colStart:colEnd])
}
default:
return nil, errors.Errorf(nyiFail, "im2col", input.Dtype())
}
return prealloc, nil
}

func (op im2colOp) f64s(chans, height, width, chanStride, retHeight, retWidth int, im, col []float64) {
func (op im2colOp) f64s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float64) {
var colIdx int
for ch := chans; ch > 0; ch, im = ch-1, im[chanStride:] {
for kernelRow := 0; kernelRow < op.h; kernelRow++ {
for kernelCol := 0; kernelCol < op.w; kernelCol++ {
inRow := -op.padH + kernelRow*op.dilationH
for outRow := retHeight; outRow > 0; outRow-- {
if !(inRow >= 0 && inRow < height) {
for outCol := retWidth; outCol > 0; outCol-- {
col[colIdx] = 0
colIdx++
}
continue
}
inCol := -op.padW + kernelCol*op.dilationW
for outCol := retWidth; outCol > 0; outCol-- {
if inCol >= 0 && inCol < width {
col[colIdx] = im[inRow*width+inCol]
} else {
col[colIdx] = 0
for ch := 0; ch < chans; ch, im = ch+1, im[chanStride:] {
for r := 0; r < retHeight; r++ {
for c := 0; c < retWidth; c++ {
for kr := 0; kr < op.h; kr++ {
inRow := -op.padH + kr*op.dilationH + r*op.strideH
for kc := 0; kc < op.w; kc++ {
inCol := -op.padW + kc*op.dilationW + c*op.strideW
var val float64

switch {
case inRow < 0:
case inCol < 0:
case inRow*inRowStride+inCol >= len(im):
case inCol >= inRowStride:
default:
val = im[inRow*inRowStride+inCol]
}

col[colIdx] = val
colIdx++
inCol += op.strideW
}
inRow += op.strideH
}
}
}
}
}

func (op im2colOp) f32s(chans, height, width, chanStride, retHeight, retWidth int, im, col []float32) {
func (op im2colOp) f32s(chans, height, width, chanStride, inRowStride, retHeight, retWidth int, im, col []float32) {
var colIdx int
for ch := chans; ch > 0; ch, im = ch-1, im[chanStride:] {
for kernelRow := 0; kernelRow < op.h; kernelRow++ {
for kernelCol := 0; kernelCol < op.w; kernelCol++ {
inRow := -op.padH + kernelRow*op.dilationH
for outRow := retHeight; outRow > 0; outRow-- {
if !(inRow >= 0 && inRow < height) {
for outCol := retWidth; outCol > 0; outCol-- {
col[colIdx] = 0
colIdx++
}
continue
}
inCol := -op.padW + kernelCol*op.dilationW
for outCol := retWidth; outCol > 0; outCol-- {
if inCol >= 0 && inCol < width {
col[colIdx] = im[inRow*width+inCol]
} else {
col[colIdx] = 0
for ch := 0; ch < chans; ch, im = ch+1, im[chanStride:] {
for r := 0; r < retHeight; r++ {
for c := 0; c < retWidth; c++ {
for kr := 0; kr < op.h; kr++ {
inRow := -op.padH + kr*op.dilationH + r*op.strideH
for kc := 0; kc < op.w; kc++ {
inCol := -op.padW + kc*op.dilationW + c*op.strideW
var val float32

switch {
case inRow < 0:
case inCol < 0:
case inRow*inRowStride+inCol >= len(im):
case inCol >= inRowStride:
default:
val = im[inRow*inRowStride+inCol]
}

col[colIdx] = val
colIdx++
inCol += op.strideW
}
inRow += op.strideH
}
}
}
Expand Down

0 comments on commit 8490d29

Please sign in to comment.