Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for float32 in the SparseMax op #439

Merged
merged 2 commits into from
Oct 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 118 additions & 17 deletions op_sparsemax.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func newSparsemaxOp() *sparsemaxOp {
}

// Sparsemax - implements the sparsemax operation described here: http://proceedings.mlr.press/v48/martins16.pdf
// Current implementation only supports float64
func Sparsemax(x *Node) (*Node, error) {
op := newSparsemaxOp()

Expand Down Expand Up @@ -84,7 +83,64 @@ func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
return nil, fmt.Errorf("Can't check Sparsemax input: %w", err)
}

var output interface{}

switch inputTensor.Dtype() {
case tensor.Float64:
output = op.float64sparseMax(inputTensor)
case tensor.Float32:
output = op.float32sparseMax(inputTensor)
default:
return nil, fmt.Errorf("invalid input type for Sparsemax, expected float64 or float32, got: %v", inputTensor.Dtype())
}

return tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Size()), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output)), nil
}

// FIXME: go2 generics
//
// float32sparseMax computes the 1-D sparsemax projection of inputTensor's
// float32 backing data and returns the result as a []float32.
// It mirrors float64sparseMax; see Martins & Astudillo,
// http://proceedings.mlr.press/v48/martins16.pdf.
func (op *sparsemaxOp) float32sparseMax(inputTensor tensor.Tensor) interface{} {
	// Guard against an empty tensor: the threshold computation below
	// indexes cumArray[maxIndex-1], which would panic with zero elements.
	if inputTensor.Size() == 0 {
		return []float32{}
	}

	// Sort a copy in descending order so the support-size search can scan
	// z_(1) >= z_(2) >= ... >= z_(K) once.
	sortedData := make([]float32, inputTensor.Size())

	copy(sortedData, inputTensor.Data().([]float32))

	sort.Slice(sortedData, func(i, j int) bool {
		return sortedData[i] > sortedData[j]
	})

	kArray := make([]float32, len(sortedData))
	cumArray := make([]float32, len(sortedData))
	cumSum := float32(0.0)
	maxIndex := 0

	// Find the support size: the largest k with 1 + k*z_(k) > sum_{j<k} z_(j).
	// (With 0-based i, kArray[i] = 1 + i*z_i and cumArray[i] holds the
	// cumulative sum of the elements strictly before i.)
	for i := 0; i < len(sortedData); i++ {
		kArray[i] = 1 + float32(i)*sortedData[i]
		cumSum += sortedData[i]

		cumArray[i] = cumSum - sortedData[i]

		if kArray[i] > cumArray[i] {
			maxIndex = i + 1
		}
	}

	// Threshold tau derived from the support; values at or below it clamp to 0.
	threshold := (cumArray[maxIndex-1] - 1) / float32(maxIndex)
	output := make([]float32, inputTensor.Size())

	for i := 0; i < inputTensor.Size(); i++ {
		v, _ := inputTensor.At(i)
		vF := v.(float32)

		if vF-threshold > 0 {
			output[i] = vF - threshold
		}
	}

	return output
}

func (op *sparsemaxOp) float64sparseMax(inputTensor tensor.Tensor) interface{} {
sortedData := make([]float64, inputTensor.Size())

copy(sortedData, inputTensor.Data().([]float64))

sort.Slice(sortedData, func(i, j int) bool {
Expand Down Expand Up @@ -119,7 +175,7 @@ func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
}
}

return tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Size()), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output)), nil
return output
}

// DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
Expand Down Expand Up @@ -229,7 +285,7 @@ func (op *sparsemaxDiffOp) checkInput(inputs ...Value) (tensor.Tensor, tensor.Te
switch t := inputs[0].(type) {
case *dualValue:
if in, ok = t.Value.(tensor.Tensor); !ok {
return nil, nil, errors.Errorf("input should be a tensor, got %T", inputs[0])
return nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0])
}
case tensor.Tensor:
in = t
Expand Down Expand Up @@ -265,16 +321,68 @@ func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) {
return nil, fmt.Errorf("sparsemaxDiffOp.Do inputs sizes should be equal")
}

data, ok := inputTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
var diff interface{}

switch inputTensor.Dtype() {
case tensor.Float64:
outputData, ok := gradTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
}

diff = op.float64sparseMaxDiff(inputTensor.Data().([]float64), outputData)
case tensor.Float32:
outputData, ok := gradTensor.Data().([]float32)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float32, got %T", inputTensor.Data())
}

diff = op.float32sparseMaxDiff(inputTensor.Data().([]float32), outputData)
default:
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64 or []float32, got %T", inputTensor.Data())
}

outputData, ok := gradTensor.Data().([]float64)
if !ok {
return nil, fmt.Errorf("sparsemaxDiffOp.Do expected input to be []float64, got %T", inputTensor.Data())
val := tensor.New(
tensor.Of(inputTensor.Dtype()),
tensor.WithShape(inputTensor.Size()),
tensor.WithEngine(inputTensor.Engine()),
tensor.WithBacking(diff),
)

return val, nil
}

// FIXME: go2 generics
//
// float32sparseMaxDiff computes the sparsemax backward pass for float32 data.
// Entries whose forward value (data) is zero lie outside the support and get
// a zero gradient; entries in the support receive their incoming gradient
// minus the mean incoming gradient taken over the support.
func (op *sparsemaxDiffOp) float32sparseMaxDiff(data, outputData []float32) interface{} {
	var (
		supportSize float32
		gradTotal   float32
	)

	grad := make([]float32, len(data))

	// Mark the support with 1s and accumulate the incoming gradient over it.
	for idx, val := range data {
		if val != 0.0 {
			grad[idx] = 1.0
			gradTotal += outputData[idx]
			supportSize++
		}
	}

	var mean float32

	if supportSize > 0 {
		mean = gradTotal / supportSize
	}

	// Scale every entry — including the zero mask entries — so the sign of
	// zero results matches the element-wise product of the original.
	for idx := range grad {
		grad[idx] *= outputData[idx] - mean
	}

	return grad
}

func (op *sparsemaxDiffOp) float64sparseMaxDiff(data, outputData []float64) interface{} {
nonZeros := 0.0
inputSum := 0.0
diff := make([]float64, len(data))
Expand All @@ -300,14 +408,7 @@ func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) {
diff[i] *= (outputData[i] - sum)
}

val := tensor.New(
tensor.Of(inputTensor.Dtype()),
tensor.WithShape(inputTensor.Size()),
tensor.WithEngine(inputTensor.Engine()),
tensor.WithBacking(diff),
)

return val, nil
return diff
}

// ensure it complies with the Op interface
Expand Down
97 changes: 77 additions & 20 deletions op_sparsemax_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,103 @@ import (
)

var testCasesSparseMaxDo = []struct {
input []float64
expected []float64
size int
input interface{}
expected interface{}
}{
{
[]float64{0.3, 0.1, 1.2, 2.3}, []float64{1.3, 1.1, 2.2, 3.3},
4, []float64{0.3, 0.1, 1.2, 2.3}, []float64{1.3, 1.1, 2.2, 3.3},
},
{
[]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
10, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
{
[]float64{0.1, 0.1, 0.1}, []float64{0.3666666666666667, 0.3666666666666667, 0.3666666666666667},
3, []float64{0.1, 0.1, 0.1}, []float64{0.3666666666666667, 0.3666666666666667, 0.3666666666666667},
},
{
[]float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.9, 1.3, 0, 3.7},
4, []float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.9, 1.3, 0, 3.7},
},
{
4, []float32{0.3, 0.1, 1.2, 2.3}, []float32{1.3, 1.1, 2.2, 3.3},
},
{
10, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float32{2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
{
3, []float32{0.1, 0.1, 0.1}, []float32{0.36666664, 0.36666664, 0.36666664},
},
{
4, []float32{-0.1, 0.3, -1.1, 2.7}, []float32{0.9, 1.3, 0, 3.7},
},
}

var testCasesSparseMaxDoDiff = []struct {
input []float64
grad []float64
size int
input interface{}
grad interface{}

expected []float64
expected interface{}
}{
{
5,
[]float64{0.0000, 0.0000, 0.0521, 0.2354, 0.7124},
[]float64{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
[]float64{0, 0, -0.2811999999999999, -0.09789999999999999, 0.3791},
},
{
5,
[]float64{0.0556, 0.0000, 0.7118, 0.2325, 0.0000},
[]float64{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
[]float64{-0.2777, -0.0000, 0.3785, -0.1008, -0.0000},
},
{
5,
[]float64{0.2841, 0.0000, 0.7159, 0.0000, 0.0000},
[]float64{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
[]float64{-0.21585000000000001, 0, 0.21585, 0, 0},
},
{
5,
[]float64{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
[]float64{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
[]float64{-0.07410000000000003, 0, 0.3576, -0.28350000000000003, 0},
},
{
5,
[]float32{0.0000, 0.0000, 0.0521, 0.2354, 0.7124},
[]float32{0.2860, -0.0702, 0.8080, 0.9913, 1.4683},
[]float32{-0, -0, -0.2812, -0.09790003, 0.37909997},
},
{
5,
[]float32{0.0556, 0.0000, 0.7118, 0.2325, 0.0000},
[]float32{0.1109, -1.4741, 0.7671, 0.2878, 0.0334},
[]float32{-0.2777, -0, 0.37849998, -0.10079998, -0},
},
{
5,
[]float32{0.2841, 0.0000, 0.7159, 0.0000, 0.0000},
[]float32{0.2094, -1.0000, 0.6411, -0.5032, -0.3909},
[]float32{-0.21585, -0, 0.21585, -0, -0},
},
{
5,
[]float32{0.2592, 0.0000, 0.6909, 0.0498, 0.0000},
[]float32{0.2094, -1.0000, 0.6411, 0.0000, -0.3909},
[]float32{-0.07409999, -0, 0.3576, -0.2835, -0},
},
}

func TestSparsemaxDo(t *testing.T) {
c := require.New(t)

for i, testCase := range testCasesSparseMaxDo {
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
op := newSparsemaxOp()

out, err := op.Do(tt)
c.NoError(err, "failed test case: %d", i)
c.Equal(testCase.expected, out.Data())
c.Equal(testCase.expected, out.Data(), "failed test case: %d", i)
}
}

Expand All @@ -78,9 +120,18 @@ func TestSparsemaxDoDiff(t *testing.T) {
r, err := ApplyOp(op, a)
c.NoError(err)

aT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(len(testCase.grad)), tensor.WithBacking(testCase.grad))
rT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(make([]float64, len(testCase.grad))))
var backing interface{}

switch testCase.input.(type) {
case []float64:
backing = make([]float64, testCase.size)
case []float32:
backing = make([]float32, testCase.size)
}

aT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.grad))
rT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(backing))

aVal, _, _, _ := anyToValue(aT)
bVal, _, _, _ := anyToValue(bT)
Expand All @@ -105,8 +156,8 @@ func TestSparsemaxDoSymDiff(t *testing.T) {
a := NewTensor(g, Float64, 1, WithName("a"), WithShape(1))
b := NewTensor(g, Float64, 1, WithName("b"), WithShape(1))

aT := tensor.New(tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(len(testCase.grad)), tensor.WithBacking(testCase.grad))
aT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
bT := tensor.New(tensor.WithShape(testCase.size), tensor.WithBacking(testCase.grad))

aVal, _, _, _ := anyToValue(aT)
bVal, _, _, _ := anyToValue(bT)
Expand All @@ -125,19 +176,25 @@ func TestSparsemaxDoSymDiff(t *testing.T) {
c.NoError(vm.RunAll())
c.NoError(vm.Close())

c.Equal(testCase.expected, diff[0].boundTo.Data())
c.Equal(testCase.expected, diff[0].boundTo.Data(), "failed test case: %d", i)
}
}

func TestSparsemaxFull(t *testing.T) {
c := require.New(t)

for i, testCase := range testCasesSparseMaxDo {
tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
expected := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.expected)), tensor.WithBacking(testCase.expected))
dtype := tensor.Float64

if _, ok := testCase.input.([]float32); ok {
dtype = tensor.Float32
}

tt := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.input))
expected := tensor.New(tensor.Of(dtype), tensor.WithShape(testCase.size), tensor.WithBacking(testCase.expected))

g := NewGraph()
inp := NewTensor(g, tensor.Float64, 1, WithShape(len(testCase.input)), WithName("inp"))
inp := NewTensor(g, dtype, 1, WithShape(testCase.size), WithName("inp"))
out := Must(Sparsemax(inp))

vm := NewTapeMachine(g)
Expand Down