
Add support for float32 in the SparseMax op (#439)
Co-authored-by: Chewxy <chewxy@gmail.com>
dcu and chewxy committed Oct 11, 2020
1 parent 4e9eb9e commit d26fe26
Showing 2 changed files with 195 additions and 37 deletions.
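
For context on the op being extended: sparsemax (Martins & Astudillo, 2016, the paper linked in the doc comment below) is the Euclidean projection of a score vector onto the probability simplex. The paper's closed form is

    sparsemax(z)_i = max(z_i − τ(z), 0),   τ(z) = (Σ_{j ≤ k(z)} z_(j) − 1) / k(z),

where z_(1) ≥ z_(2) ≥ … is the input sorted in decreasing order and k(z) is the largest k such that 1 + k·z_(k) > Σ_{j ≤ k} z_(j). The diff below adds a float32 twin of the existing float64 sort-and-threshold code; the FIXME markers note that a single generic version has to wait for Go generics.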
135 changes: 118 additions & 17 deletions op_sparsemax.go
@@ -20,7 +20,6 @@ func newSparsemaxOp() *sparsemaxOp {
}

// Sparsemax - implements the sparsemax operation described here: http://proceedings.mlr.press/v48/martins16.pdf
// Current implementation only supports float64
func Sparsemax(x *Node) (*Node, error) {
op := newSparsemaxOp()

@@ -84,7 +83,64 @@ func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
return nil, fmt.Errorf("Can't check Sparsemax input: %w", err)
}

var output interface{}

switch inputTensor.Dtype() {
case tensor.Float64:
output = op.float64sparseMax(inputTensor)
case tensor.Float32:
output = op.float32sparseMax(inputTensor)
default:
return nil, fmt.Errorf("invalid input type for Sparsemax, expected float64 or float32, got: %v", inputTensor.Dtype())
}

return tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Size()), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output)), nil
}

// FIXME: go2 generics
func (op *sparsemaxOp) float32sparseMax(inputTensor tensor.Tensor) interface{} {
sortedData := make([]float32, inputTensor.Size())

copy(sortedData, inputTensor.Data().([]float32))

sort.Slice(sortedData, func(i, j int) bool {
return sortedData[i] > sortedData[j]
})

kArray := make([]float32, len(sortedData))
cumArray := make([]float32, len(sortedData))
cumSum := float32(0.0)
maxIndex := 0

for i := 0; i < len(sortedData); i++ {
kArray[i] = 1 + float32(i)*sortedData[i]
cumSum += sortedData[i]

cumArray[i] = cumSum - sortedData[i]

if kArray[i] > cumArray[i] {
maxIndex = i + 1
}
}

threshold := float32(cumArray[maxIndex-1]-1) / float32(maxIndex)
output := make([]float32, inputTensor.Size())

for i := 0; i < inputTensor.Size(); i++ {
v, _ := inputTensor.At(i)
vF := v.(float32)

if vF-threshold > 0 {
output[i] = vF - threshold
}
}

return output
}

func (op *sparsemaxOp) float64sparseMax(inputTensor tensor.Tensor) interface{} {
sortedData := make([]float64, inputTensor.Size())

copy(sortedData, inputTensor.Data().([]float64))

sort.Slice(sortedData, func(i, j int) bool {
@@ -119,7 +175,7 @@ func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
}
}

return tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Size()), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output)), nil
return output
}

// DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
@@ -229,7 +285,7 @@ func (op *sparsemaxDiffOp) checkInput(inputs ...Value) (tensor.Tensor, tensor.Te
switch t := inputs[0].(type) {
case *dualValue:
if in, ok = t.Value.(tensor.Tensor); !ok {
return nil, nil, errors.Errorf("input should be a tensor, got %T", inputs[0])
return nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0])
}
case tensor.Tensor:
in = t
@@ -265,16 +321,68 @@ func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) {
return nil, fmt.Errorf("sparsemaxDiffOp.Do inputs sizes should be equal")
}

data, ok := inputTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
var diff interface{}

switch inputTensor.Dtype() {
case tensor.Float64:
outputData, ok := gradTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
}

diff = op.float64sparseMaxDiff(inputTensor.Data().([]float64), outputData)
case tensor.Float32:
outputData, ok := gradTensor.Data().([]float32)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float32, got %T", inputTensor.Data())
}

diff = op.float32sparseMaxDiff(inputTensor.Data().([]float32), outputData)
default:
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64 or []float32, got %T", inputTensor.Data())
}

outputData, ok := gradTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
val := tensor.New(
tensor.Of(inputTensor.Dtype()),
tensor.WithShape(inputTensor.Size()),
tensor.WithEngine(inputTensor.Engine()),
tensor.WithBacking(diff),
)

return val, nil
}

// FIXME: go2 generics
func (op *sparsemaxDiffOp) float32sparseMaxDiff(data, outputData []float32) interface{} {
nonZeros := float32(0.0)
inputSum := float32(0.0)
diff := make([]float32, len(data))

for i, v := range data {
if v == 0.0 {
continue
}

diff[i] = 1.0

inputSum += outputData[i]
nonZeros++
}

sum := float32(0.0)

if nonZeros > 0 {
sum = inputSum / nonZeros
}

for i := range diff {
diff[i] *= (outputData[i] - sum)
}

return diff
}

func (op *sparsemaxDiffOp) float64sparseMaxDiff(data, outputData []float64) interface{} {
nonZeros := 0.0
inputSum := 0.0
diff := make([]float64, len(data))
@@ -300,14 +408,7 @@ func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) {
diff[i] *= (outputData[i] - sum)
}

val := tensor.New(
tensor.Of(inputTensor.Dtype()),
tensor.WithShape(inputTensor.Size()),
tensor.WithEngine(inputTensor.Engine()),
tensor.WithBacking(diff),
)

return val, nil
return diff
}

// ensure it complies with the Op interface
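Before the test file, it may help to restate the rule that the new float32sparseMaxDiff (and its float64 twin) applies: the sparsemax Jacobian-vector product lets the incoming gradient through only on the support (entries where its input slice is non-zero) and recenters it by the mean gradient over that support. A minimal standalone sketch under that reading, using plain slices and illustrative names that are not part of the package:

```go
package main

import "fmt"

// sparsemaxJVP32 mirrors the support-masked, mean-centered rule above:
// zero off the support, grad[i] - mean(grad over support) on it.
func sparsemaxJVP32(forward, grad []float32) []float32 {
	var sum, nonZeros float32
	for i, v := range forward {
		if v != 0 {
			sum += grad[i]
			nonZeros++
		}
	}
	var mean float32
	if nonZeros > 0 {
		mean = sum / nonZeros
	}
	out := make([]float32, len(forward))
	for i, v := range forward {
		if v != 0 {
			out[i] = grad[i] - mean
		}
	}
	return out
}

func main() {
	// Third diff test case below: the support is {0, 2}, so the mean
	// gradient over it is (0.2094 + 0.6411) / 2 = 0.42525, giving
	// -0.21585 and +0.21585 on the support and zero elsewhere.
	fmt.Println(sparsemaxJVP32(
		[]float32{0.2841, 0, 0.7159, 0, 0},
		[]float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
	))
}
```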
97 changes: 77 additions & 20 deletions op_sparsemax_test.go
@@ -8,61 +8,103 @@ import (
)

var testCasesSparseMaxDo = []struct {
input []float64
expected []float64
size int
input interface{}
expected interface{}
}{
{
[]float64{0.3, 0.1, 1.2, 2.3}, []float64{1.3, 1.1, 2.2, 3.3},
4, []float64{0.3, 0.1, 1.2, 2.3}, []float64{1.3, 1.1, 2.2, 3.3},
},
{
[]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
10, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
{
[]float64{0.1, 0.1, 0.1}, []float64{0.3666666666666667, 0.3666666666666667, 0.3666666666666667},
3, []float64{0.1, 0.1, 0.1}, []float64{0.3666666666666667, 0.3666666666666667, 0.3666666666666667},
},
{
[]float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.9, 1.3, 0, 3.7},
4, []float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.9, 1.3, 0, 3.7},
},
{
4, []float32{0.3, 0.1, 1.2, 2.3}, []float32{1.3, 1.1, 2.2, 3.3},
},
{
10, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float32{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
{
3, []float32{0.1, 0.1, 0.1}, []float32{0.36666664, 0.36666664, 0.36666664},
},
{
4, []float32{-0.1, 0.3, -1.1, 2.7}, []float32{0.9, 1.3, 0, 3.7},
},
}
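
As a worked check of the op's thresholding, take the [0.1, 0.1, 0.1] case: the loop keeps kArray[i] > cumArray[i] at every step, so maxIndex = 3 and cumArray[maxIndex−1] = 0.1 + 0.1 = 0.2, giving

    threshold = (0.2 − 1) / 3 = −0.2666…,   output_i = 0.1 − (−0.2666…) = 0.3666…

for all three entries. The float32 rows carry the same values at single precision (0.36666664).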

var testCasesSparseMaxDoDiff = []struct {
input []float64
grad []float64
size int
input interface{}
grad interface{}

expected []float64
expected interface{}
}{
{
5,
[]float64{0.0000, 0.0000, 0.0521, 0.2354, 0.7124},
[]float64{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
[]float64{0, 0, -0.2811999999999999, -0.09789999999999999, 0.3791},
},
{
5,
[]float64{0.0556, 0.0000, 0.7118, 0.2325, 0.0000},
[]float64{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
[]float64{-0.2777, -0.0000, 0.3785, -0.1008, -0.0000},
},
{
5,
[]float64{0.2841, 0.0000, 0.7159, 0.0000, 0.0000},
[]float64{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
[]float64{-0.21585000000000001, 0, 0.21585, 0, 0},
},
{
5,
[]float64{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
[]float64{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
[]float64{-0.07410000000000003, 0, 0.3576, -0.28350000000000003, 0},
},
{
5,
[]float32{0.0000, 0.0000, 0.0521, 0.2354, 0.7124},
[]float32{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
[]float32{-0, -0, -0.2812, -0.09790003, 0.37909997},
},
{
5,
[]float32{0.0556, 0.0000, 0.7118, 0.2325, 0.0000},
[]float32{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
[]float32{-0.2777, -0, 0.37849998, -0.10079998, -0},
},
{
5,
[]float32{0.2841, 0.0000, 0.7159, 0.0000, 0.0000},
[]float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
[]float32{-0.21585, -0, 0.21585, -0, -0},
},
{
5,
[]float32{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
[]float32{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
[]float32{-0.07409999, -0, 0.3576, -0.2835, -0},
},
}
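
A note on the −0 entries in the float32 expectations: off the support, the mask value 0 gets multiplied by a negative centered gradient, and under IEEE 754 +0 times a negative number is −0. Go compares −0 == 0 as true, so the equality assertions are unaffected. A tiny check with hypothetical values, runnable as-is:

```go
package main

import "fmt"

func main() {
	var mask float32 = 0
	prod := mask * -0.42525 // off-support mask times a negative centered gradient
	fmt.Println(prod, prod == 0) // prints: -0 true
}
```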

func TestSparsemaxDo(t *testing.T) {
c := require.New(t)

for i, testCase := range testCasesSparseMaxDo {
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
op := newSparsemaxOp()

out, err := op.Do(tt)
c.NoError(err, "failed test case: %d", i)
c.Equal(testCase.expected, out.Data())
c.Equal(testCase.expected, out.Data(), "failed test case: %d", i)
}
}

@@ -78,9 +120,18 @@ func TestSparsemaxDoDiff(t *testing.T) {
r, err := ApplyOp(op, a)
c.NoError(err)

aT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(len(testCase.grad)), tensor.WithBacking(testCase.grad))
rT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(make([]float64, len(testCase.grad))))
var backing interface{}

switch testCase.input.(type) {
case []float64:
backing = make([]float64, testCase.size)
case []float32:
backing = make([]float32, testCase.size)
}

aT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.grad))
rT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(backing))

aVal, _, _, _ := anyToValue(aT)
bVal, _, _, _ := anyToValue(bT)
@@ -105,8 +156,8 @@ func TestSparsemaxDoSymDiff(t *testing.T) {
a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1))
b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1))

aT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(len(testCase.grad)), tensor.WithBacking(testCase.grad))
aT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.grad))

aVal, _, _, _ := anyToValue(aT)
bVal, _, _, _ := anyToValue(bT)
@@ -125,19 +176,25 @@ func TestSparsemaxDoSymDiff(t *testing.T) {
c.NoError(vm.RunAll())
c.NoError(vm.Close())

c.Equal(testCase.expected, diff[0].boundTo.Data())
c.Equal(testCase.expected, diff[0].boundTo.Data(), "failed test case: %d", i)
}
}

func TestSparsemaxFull(t *testing.T) {
c := require.New(t)

for i, testCase := range testCasesSparseMaxDo {
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
expected := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.expected)), tensor.WithBacking(testCase.expected))
dtype := tensor.Float64

if _, ok := testCase.input.([]float32); ok {
dtype = tensor.Float32
}

tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
expected := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.expected))

g := NewGraph()
inp := NewTensor(g, tensor.Float64, 1, WithShape(len(testCase.input)), WithName("inp"))
inp := NewTensor(g, dtype, 1, WithShape(testCase.size), WithName("inp"))
out := Must(Sparsemax(inp))

vm := NewTapeMachine(g)

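To close the loop, here is a sketch of exercising the newly supported float32 path end to end, modeled on TestSparsemaxFull above. It assumes Gorgonia's public API as of this commit (NewGraph, NewTensor, Let, NewTapeMachine) and reuses the last float32 forward case from the test table:

```go
package main

import (
	"fmt"
	"log"

	. "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func main() {
	g := NewGraph()
	inp := NewTensor(g, tensor.Float32, 1, WithShape(4), WithName("inp"))
	out := Must(Sparsemax(inp))

	// Bind a float32 backing tensor to the input node.
	tt := tensor.New(tensor.Of(tensor.Float32), tensor.WithShape(4),
		tensor.WithBacking([]float32{-0.1, 0.3, -1.1, 2.7}))
	if err := Let(inp, tt); err != nil {
		log.Fatal(err)
	}

	vm := NewTapeMachine(g)
	defer vm.Close()

	if err := vm.RunAll(); err != nil {
		log.Fatal(err)
	}

	fmt.Println(out.Value()) // per the test table: [0.9  1.3  0  3.7]
}
```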