Skip to content

Commit

Permalink
fix: support all types in StringSet JSON unmarshal (#1925)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiuker committed Jan 18, 2024
1 parent 76a4146 commit 6ad2b4a
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 7 deletions.
45 changes: 45 additions & 0 deletions pkg/policy/bucket-policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,51 @@ func TestUnmarshalBucketPolicy(t *testing.T) {
}
}
]
}`, shouldSucceed: true},
// Test 10
{policyData: `{
"Version": "2012-10-17",
"Statement": [{
"Effect": "Deny",
"Principal": {
"AWS": [
"*"
]
},
"Action": [
"s3:PutObject"
],
"Resource": [
"arn:aws:s3:::mytest/*"
],
"Condition": {
"Null": {
"s3:x-amz-server-side-encryption": [
true
]
}
}
}]
}`, shouldSucceed: true},
// Test 11
{policyData: `{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Deny",
"Principal": "*",
"Action": "s3:PutObject",
"Resource": [
"arn:aws:s3:::DOC-EXAMPLE-BUCKET1",
"arn:aws:s3:::DOC-EXAMPLE-BUCKET1/*"
],
"Condition": {
"NumericLessThan": {
"s3:TlsVersion": 1.2
}
}
}
]
}`, shouldSucceed: true},
}

Expand Down
11 changes: 4 additions & 7 deletions pkg/set/stringset.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,19 @@ func (set StringSet) MarshalJSON() ([]byte, error) {
}

// UnmarshalJSON - parses JSON data and creates new set with it.
// If 'data' contains JSON string array, the set contains each string.
// If 'data' contains JSON string, the set contains the string as one element.
// If 'data' contains Other JSON types, JSON parse error is returned.
func (set *StringSet) UnmarshalJSON(data []byte) error {
sl := []string{}
sl := []interface{}{}
var err error
if err = json.Unmarshal(data, &sl); err == nil {
*set = make(StringSet)
for _, s := range sl {
set.Add(s)
set.Add(fmt.Sprintf("%v", s))
}
} else {
var s string
var s interface{}
if err = json.Unmarshal(data, &s); err == nil {
*set = make(StringSet)
set.Add(s)
set.Add(fmt.Sprintf("%v", s))
}
}

Expand Down
105 changes: 105 additions & 0 deletions pkg/set/stringset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package set

import (
"fmt"
"reflect"
"sort"
"strings"
"testing"
)
Expand Down Expand Up @@ -346,3 +348,106 @@ func TestStringSetToSlice(t *testing.T) {
}
}
}

func TestStringSet_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
expectResult []string
}
tests := []struct {
name string
set StringSet
args args
wantErr bool
}{
{
name: "test strings",
set: NewStringSet(),
args: args{
data: []byte(`["foo","bar"]`),
expectResult: []string{"foo", "bar"},
},
wantErr: false,
},
{
name: "test string",
set: NewStringSet(),
args: args{
data: []byte(`"foo"`),
expectResult: []string{"foo"},
},
wantErr: false,
},
{
name: "test bools",
set: NewStringSet(),
args: args{
data: []byte(`[false,true]`),
expectResult: []string{"false", "true"},
},
wantErr: false,
},
{
name: "test bool",
set: NewStringSet(),
args: args{
data: []byte(`false`),
expectResult: []string{"false"},
},
wantErr: false,
},
{
name: "test ints",
set: NewStringSet(),
args: args{
data: []byte(`[1,2]`),
expectResult: []string{"1", "2"},
},
wantErr: false,
},
{
name: "test int",
set: NewStringSet(),
args: args{
data: []byte(`1`),
expectResult: []string{"1"},
},
wantErr: false,
},
{
name: "test floats",
set: NewStringSet(),
args: args{
data: []byte(`[1.1,2.2]`),
expectResult: []string{"1.1", "2.2"},
},
wantErr: false,
},
{
name: "test float",
set: NewStringSet(),
args: args{
data: []byte(`1.1`),
expectResult: []string{"1.1"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.set.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
slice := tt.set.ToSlice()
sort.Slice(slice, func(i, j int) bool {
return slice[i] < slice[j]
})
sort.Slice(tt.args.expectResult, func(i, j int) bool {
return tt.args.expectResult[i] < tt.args.expectResult[j]
})
if !reflect.DeepEqual(slice, tt.args.expectResult) {
t.Errorf("StringSet() get %v, want %v", tt.set.ToSlice(), tt.args.expectResult)
}
})
}
}

0 comments on commit 6ad2b4a

Please sign in to comment.