Skip to content

Commit

Permalink
Refactor the max/sum ops to share common code. Have the type/inferSha…
Browse files Browse the repository at this point in the history
…pe/Do methods behave in a consistent manner: (#346)

* Dimensions specified in the "along" parameter are reduced to size 1, but not removed. (Note: this caused TestRepeatOpDoDiff, but this version fixes it.  Perhaps we should make preserving the size-1 dimensions an option of the reduction op?)
* If all dimensions are included, the result will be a scalar.
* If all dimensions but 1 are included, the result is a vector, regardless of which dimension is left intact.

Tests verify that the resulting nodes have the expected shape.

Note: While here, fix a warning on Max's SymDiff where retVal[0] is set when retVal has not been initialized.  I wonder if this is related to #323 where SymDiff for StableSoftMax (which uses Max) was failing with a panic (probably not, as the error message there seems unrelated, but probably a good fix anyway).

Closes #326
  • Loading branch information
bdleitner authored and chewxy committed Nov 17, 2019
1 parent 6fd05db commit 592126c
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 155 deletions.
16 changes: 9 additions & 7 deletions go.mod
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ require (
github.com/gogo/protobuf v1.2.1 // indirect github.com/gogo/protobuf v1.2.1 // indirect
github.com/golang/protobuf v1.3.0 // indirect github.com/golang/protobuf v1.3.0 // indirect
github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac // indirect github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac // indirect
github.com/google/flatbuffers v1.10.0 // indirect
github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21
github.com/mattn/go-runewidth v0.0.4 // indirect github.com/mattn/go-runewidth v0.0.4 // indirect
github.com/pkg/errors v0.8.1 github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.3.0 github.com/stretchr/testify v1.4.0
github.com/xtgo/set v1.0.0 github.com/xtgo/set v1.0.0
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9 // indirect golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9 // indirect
gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6 gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee
gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.27 gopkg.in/cheggaaa/pb.v1 v1.0.27
gorgonia.org/cu v0.9.0-beta gorgonia.org/cu v0.9.0-beta
gorgonia.org/dawson v1.1.0 gorgonia.org/dawson v1.1.0
gorgonia.org/tensor v0.9.0-beta gorgonia.org/tensor v0.9.2
gorgonia.org/vecf32 v0.7.0 gorgonia.org/vecf32 v0.9.0
gorgonia.org/vecf64 v0.7.0 gorgonia.org/vecf64 v0.9.0
) )
33 changes: 33 additions & 0 deletions go.sum
Original file line number Original file line Diff line number Diff line change
@@ -1,3 +1,4 @@
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca h1:xwIXr1FpA2XBoohlpvgb11No/zbsh5Clm/98PWPcHVA= github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca h1:xwIXr1FpA2XBoohlpvgb11No/zbsh5Clm/98PWPcHVA=
github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs= github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
Expand All @@ -12,13 +13,22 @@ github.com/go-gota/gota v0.10.1 h1:BWci+R5dE28GnXoD1EWoQqe7WCQHAPJ996mK7LZrB4U=
github.com/go-gota/gota v0.10.1/go.mod h1:NZLQccXn0rABmkXjsaugRY6l+UH2dDZSgIgF8E2ipmA= github.com/go-gota/gota v0.10.1/go.mod h1:NZLQccXn0rABmkXjsaugRY6l+UH2dDZSgIgF8E2ipmA=
github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk= github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac h1:Q0Jsdxl5jbxouNs1TQYt0gxesYMU4VXRbsTlgDloZ50= github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac h1:Q0Jsdxl5jbxouNs1TQYt0gxesYMU4VXRbsTlgDloZ50=
github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac/go.mod h1:P32wAyui1PQ58Oce/KYkOqQv8cVw1zAapXOl+dRFGbc= github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac/go.mod h1:P32wAyui1PQ58Oce/KYkOqQv8cVw1zAapXOl+dRFGbc=
github.com/google/flatbuffers v1.10.0 h1:wHCM5N1xsJ3VwePcIpVqnmjAqRXlR44gv4hpGi+/LIw= github.com/google/flatbuffers v1.10.0 h1:wHCM5N1xsJ3VwePcIpVqnmjAqRXlR44gv4hpGi+/LIw=
github.com/google/flatbuffers v1.10.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/flatbuffers v1.10.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A=
github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 h1:O75p5GUdUfhJqNCMM1ntthjtJCOHVa1lzMSfh5Qsa0Y= github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 h1:O75p5GUdUfhJqNCMM1ntthjtJCOHVa1lzMSfh5Qsa0Y=
github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21/go.mod h1:N0SVk0uhy+E1PZ3C9ctsPRlvOPAFPkCNlcPBDkt0N3U= github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21/go.mod h1:N0SVk0uhy+E1PZ3C9ctsPRlvOPAFPkCNlcPBDkt0N3U=
Expand All @@ -29,24 +39,40 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9 h1:N26gncmS+iqc/W/SKhX3ElI5pkt72XYoRLgi5Z70LSc= golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9 h1:N26gncmS+iqc/W/SKhX3ElI5pkt72XYoRLgi5Z70LSc=
golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190226215855-775f8194d0f9/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6 h1:GDMcNvYihiCXcWFYiIL8Wx7580Px/E2/pT62tJXe2gY= gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6 h1:GDMcNvYihiCXcWFYiIL8Wx7580Px/E2/pT62tJXe2gY=
gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW55YmY9DMhajHcnkqVnEXmEtMyNI= gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW55YmY9DMhajHcnkqVnEXmEtMyNI=
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM=
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU=
gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d h1:m4zHh49Vwhwq5n7qC7NRl5SqRfTyT/6PP2ASGNGRB1E= gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d h1:m4zHh49Vwhwq5n7qC7NRl5SqRfTyT/6PP2ASGNGRB1E=
gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.27 h1:kJdccidYzt3CaHD1crCFTS1hxyhSi059NhOFUf03YFo= gopkg.in/cheggaaa/pb.v1 v1.0.27 h1:kJdccidYzt3CaHD1crCFTS1hxyhSi059NhOFUf03YFo=
gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gorgonia.org/cu v0.8.0 h1:XpTkl5IpMlTPNJl6pKQPEXVV/9TnEtiRB7j1gGkrzCI= gorgonia.org/cu v0.8.0 h1:XpTkl5IpMlTPNJl6pKQPEXVV/9TnEtiRB7j1gGkrzCI=
gorgonia.org/cu v0.8.0/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8= gorgonia.org/cu v0.8.0/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8=
gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc= gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc=
Expand All @@ -59,7 +85,14 @@ gorgonia.org/tensor v0.8.1 h1:PTJ81ku5uYs/qsZLMFq02q0DWI4YuJeu0ikieFkkh1o=
gorgonia.org/tensor v0.8.1/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w= gorgonia.org/tensor v0.8.1/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
gorgonia.org/tensor v0.9.0-beta h1:16QQufB1vbJxVbIOaB5TwkerdlBWtw+AAnZHUZ531ZE= gorgonia.org/tensor v0.9.0-beta h1:16QQufB1vbJxVbIOaB5TwkerdlBWtw+AAnZHUZ531ZE=
gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w= gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
gorgonia.org/tensor v0.9.2 h1:bVTWB68apbLfdrAlz5Ev3daGhfOhKuPkVFacMSNzpHs=
gorgonia.org/tensor v0.9.2/go.mod h1:603c/8huGtNc1APqh1nWqQu0fYgBvkwt55rvg4CWgZs=
gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0= gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0=
gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8= gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8=
gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg=
gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA=
gorgonia.org/vecf64 v0.7.0 h1:ZphOGJfnWlFfY7x8WAJAfO64IAtYqPPq9TEGem+ItZE= gorgonia.org/vecf64 v0.7.0 h1:ZphOGJfnWlFfY7x8WAJAfO64IAtYqPPq9TEGem+ItZE=
gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q= gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q=
gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
213 changes: 98 additions & 115 deletions op_reduction.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -11,12 +11,103 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"hash" "hash"
"strings"


"github.com/chewxy/hm" "github.com/chewxy/hm"
"github.com/pkg/errors" "github.com/pkg/errors"
"gorgonia.org/tensor" "gorgonia.org/tensor"
) )


func reductionType(d int, along []int) hm.Type {
a := hm.TypeVariable('a')
t := makeTensorType(d, a)

axes := make(map[int]bool)
for _, axis := range along {
if axis < d {
axes[axis] = true
}
}

if d == 1 || len(axes) == 0 || len(axes) == d {
// then it reduces down
return hm.NewFnType(t, a)
}

var retType hm.Type
if len(axes) == d-1 { // Only 1 non-reduced dim, so we can reduce to a vector as before.
retType = makeTensorType(1, a)
} else {
retType = t
}
return hm.NewFnType(t, retType)
}

func reductionInferShape(along []int, inputs ...DimSizer) (tensor.Shape, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("len(dimsizers)!=1")
}
if len(along) == 0 {
return tensor.ScalarShape(), nil
}
in := inputs[0].(tensor.Shape)
shape := make(tensor.Shape, len(in))
copy(shape, in)
for _, d := range along {
if d >= len(shape) {
return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
}
shape[d] = 1
}
// special cases: if all dimensions are 1 -> ScalarShape, if exactly one dimension is != 1 -> vector
vecD := 0
numNot1 := 0
for _, d := range shape {
if d != 1 {
vecD = d
numNot1++
if numNot1 > 1 {
return shape, nil
}
}
}
if numNot1 == 0 {
return tensor.ScalarShape(), nil
}
return tensor.Shape{vecD}, nil
}

func reductionDo(op Op, s string, f func(*tensor.Dense, ...int) (*tensor.Dense, error), along []int, inputs ...Value) (retVal Value, err error) {
if err = checkArity(op, len(inputs)); err != nil {
return
}
at := inputs[0].(tensor.Tensor)
switch t := at.(type) {
case *tensor.Dense:
var ret *tensor.Dense
if ret, err = f(t, along...); err == nil {
if ret.IsScalar() {
retVal, _ = anyToScalar(ret.ScalarValue())
} else {
// the tensor reduction ops remove collapsed dimensions, but here we preserve them except in special cases.
// so we reshape the return to ensure the dimensions match.
var sh tensor.Shape
if sh, err = reductionInferShape(along, t.Shape()); err == nil {
if err = ret.Reshape(sh...); err == nil {
retVal = ret
}
}
}
} else {
return nil, errors.Wrap(err, fmt.Sprintf("failed to apply *tensor.Dense.%s()", strings.Title(s)))
}
default:
return nil, errors.Errorf(nyiFail, fmt.Sprintf("%sOp.Do()", s), at)
}
return

}

type maxOp struct { type maxOp struct {
along axes along axes
d int d int
Expand All @@ -32,43 +123,12 @@ func newMaxOp(along axes, dim int) *maxOp {
func (op maxOp) Arity() int { return 1 } func (op maxOp) Arity() int { return 1 }


func (op maxOp) Type() hm.Type { func (op maxOp) Type() hm.Type {
a := hm.TypeVariable('a') return reductionType(op.d, op.along)
t := makeTensorType(op.d, a)

var retType hm.Type
if op.d == 1 || len(op.along) == 0 || len(op.along) == op.d {
// then it reduces down
return hm.NewFnType(t, a)
}
retType = makeTensorType(op.d-1, a)
return hm.NewFnType(t, retType)
} }


//func (op maxOp) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil } // TODO, THIS IS INCORRECT //func (op maxOp) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil } // TODO, THIS IS INCORRECT
func (op maxOp) InferShape(dimsizers ...DimSizer) (tensor.Shape, error) { func (op maxOp) InferShape(dimsizers ...DimSizer) (tensor.Shape, error) {
if len(dimsizers) != 1 { return reductionInferShape(op.along, dimsizers...)
return nil, fmt.Errorf("len(dimsizers)!=1")
}
s := make(tensor.Shape, op.d)
ds := dimsizers[0]
for d := 0; d < op.d; d++ {
dInAlong := false
for _, dim := range op.along {
if d == dim {
dInAlong = true
}
}
if dInAlong {
s[d] = 1
} else {
size, err := ds.DimSize(d)
if err != nil {
return s, err
}
s[d] = size
}
}
return s, nil
} }
func (op maxOp) DiffWRT(i int) []bool { return []bool{true} } func (op maxOp) DiffWRT(i int) []bool { return []bool{true} }


Expand Down Expand Up @@ -102,21 +162,15 @@ func (op maxOp) SymDiff(inputs Nodes, output, gradNode *Node) (retVal Nodes, err
if a2, b2, err = Broadcast(gradNode, eq, bcpat); err != nil { if a2, b2, err = Broadcast(gradNode, eq, bcpat); err != nil {
return nil, errors.Wrap(err, operationError) return nil, errors.Wrap(err, operationError)
} }
retVal = make(Nodes, 1)
if retVal[0], err = Mul(a2, b2); err != nil { if retVal[0], err = Mul(a2, b2); err != nil {
return nil, errors.Wrap(err, operationError) return nil, errors.Wrap(err, operationError)
} }
return return
} }


func (op maxOp) Do(inputs ...Value) (retVal Value, err error) { func (op maxOp) Do(inputs ...Value) (retVal Value, err error) {
if err = checkArity(op, len(inputs)); err != nil { return reductionDo(op, "max", (*tensor.Dense).Max, op.along, inputs...)
return
}
if arg, ok := inputs[0].(*tensor.Dense); ok {
retVal, err = arg.Max(op.along...)
return
}
return nil, errors.Errorf("Max arg is not a tensor")
} }


func (op maxOp) ReturnsPtr() bool { return true } func (op maxOp) ReturnsPtr() bool { return true }
Expand Down Expand Up @@ -167,62 +221,11 @@ func (op sumOp) Arity() int { return 1 }
// sumOp is a function with this type: // sumOp is a function with this type:
// sumOp :: (Summable a) ⇒ Tensor d a → Tensor d-1 a // sumOp :: (Summable a) ⇒ Tensor d a → Tensor d-1 a
func (op sumOp) Type() hm.Type { func (op sumOp) Type() hm.Type {
a := hm.TypeVariable('a') return reductionType(op.d, op.along)
t := makeTensorType(op.d, a)

if op.inputShape.IsVector() {
return hm.NewFnType(t, a)
}

// if it's a monotonic axes, it's basically summing everything.
if monotonic, incr1 := tensor.IsMonotonicInts(op.along); monotonic && incr1 && len(op.along) == len(op.inputShape) {
return hm.NewFnType(t, a)
}

retType := makeTensorType(op.d-1, a)
return hm.NewFnType(t, retType)
} }


func (op sumOp) InferShape(inputs ...DimSizer) (shape tensor.Shape, err error) { func (op sumOp) InferShape(inputs ...DimSizer) (shape tensor.Shape, err error) {
in := inputs[0].(tensor.Shape) return reductionInferShape(op.along, inputs...)
shapeLogf("input shape: %v", in)
switch {
case in.IsScalar():
shape = scalarShape
case in.IsVector() && !in.IsRowVec() && !in.IsColVec():
if len(op.along) > 1 || (len(op.along) == 1 && op.along[0] != 0) {
return nil, errors.Errorf("Shape mismatch: along is %v. Shape is %v", op.along, in)
}
shape = scalarShape
default:
shape = in.Clone()
if len(op.along) > len(shape) {
return nil, errors.Errorf("Shape mismatch: %v and %v", shape, op.along)
}

// special case (sum all)
if monotonic, incr1 := tensor.IsMonotonicInts(op.along); monotonic && incr1 && len(op.along) == len(shape) && op.along[0] == 0 {
shape = scalarShape
return
}

for _, a := range op.along {
if a >= len(shape) {
return nil, errors.Errorf("Axis %d is greater or equal to the length of the shape %v", a, shape)
}
shape[a] = 1
}

switch {

case shape.IsColVec():
shape = shape[:1]
case shape.IsRowVec():
shape = shape[1:]
}

}
return
} }


func (op sumOp) DiffWRT(i int) []bool { return []bool{true} } func (op sumOp) DiffWRT(i int) []bool { return []bool{true} }
Expand Down Expand Up @@ -342,27 +345,7 @@ func (op sumOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err er
} }


func (op sumOp) Do(inputs ...Value) (retVal Value, err error) { func (op sumOp) Do(inputs ...Value) (retVal Value, err error) {
if err = checkArity(op, len(inputs)); err != nil { return reductionDo(op, "sum", (*tensor.Dense).Sum, op.along, inputs...)
return
}

at := inputs[0].(tensor.Tensor)
switch t := at.(type) {
case *tensor.Dense:
var ret *tensor.Dense
if ret, err = t.Sum(op.along...); err == nil {
if ret.IsScalar() {
retVal, _ = anyToScalar(ret.ScalarValue())
} else {
retVal = ret
}
} else {
return nil, errors.Wrap(err, "failed to apply *tensor.Dense.Sum()")
}
default:
return nil, errors.Errorf(nyiFail, "sumOp.Do()", at)
}
return
} }


func (op sumOp) ReturnsPtr() bool { return true } func (op sumOp) ReturnsPtr() bool { return true }
Expand Down
Loading

0 comments on commit 592126c

Please sign in to comment.