Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 120 additions & 18 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strconv"
"strings"
"time"

"github.com/google/go-github/v68/github"
)
Expand Down Expand Up @@ -123,7 +124,10 @@ func (it *Iterator[T, O]) All() iter.Seq[T] {
vals[k] = v[0]
}

updateOptions(it.opt, vals)
if err := updateOptions(it.opt, vals); err != nil {
it.err = err
return
}
}
}
}
Expand Down Expand Up @@ -200,10 +204,17 @@ func (it *Iterator[T, O]) do() ([]T, *github.Response, error) {
return nil, nil, errors.New("no func provided")
}

var (
stringTypePtr *string
intTypePtr *int
int64TypePtr *int64
boolTypePtr *bool
)

// updateOptions will update the github options based on the provided map and the `url` tag.
// If the field in the struct has a `url` tag it tries to set the value of the field from the one
// found in the map, if any.
func updateOptions(v any, m map[string]string) {
func updateOptions(v any, m map[string]string) error {
valueOf := reflect.ValueOf(v)
typeOf := reflect.TypeOf(v)

Expand All @@ -219,29 +230,120 @@ func updateOptions(v any, m map[string]string) {
// if field is of type struct then iterate over the pointer
if structField.Type.Kind() == reflect.Struct {
if fieldValue.CanAddr() {
updateOptions(fieldValue.Addr().Interface(), m)
if err := updateOptions(fieldValue.Addr().Interface(), m); err != nil {
return err
}
}
}

// otherwise check if it has a 'url' tag
urlTag := structField.Tag.Get("url")
if urlTag != "" {
urlParam := strings.Split(urlTag, ",")[0]

if fieldValue.IsValid() && fieldValue.CanSet() {
if v, found := m[urlParam]; found {
switch fieldValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if i, err := strconv.Atoi(v); err == nil {
fieldValue.SetInt(int64(i))
}
case reflect.Ptr:
fieldValue.Set(reflect.ValueOf(&v))
default:
fieldValue.Set(reflect.ValueOf(v))
}
if urlTag == "" {
continue
}

if !fieldValue.IsValid() || !fieldValue.CanSet() {
continue
}

urlParam := strings.Split(urlTag, ",")[0]
v, found := m[urlParam]
if !found {
continue
}

switch fieldValue.Kind() {

// handle string
case reflect.String:
fieldValue.Set(reflect.ValueOf(v))

// handle numeric types (int, int8, int16, int32, int64)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if i, err := strconv.Atoi(v); err == nil {
fieldValue.SetInt(int64(i))
}

// handle bool
case reflect.Bool:
parsedBool, err := strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as bool: %s", v, err)
}
fieldValue.Set(reflect.ValueOf(parsedBool))

// handle pointers (*string, *int, *int64, *bool, *time.Time)
case reflect.Pointer:
switch fieldValue.Type() {

// handle *string
case reflect.TypeOf(stringTypePtr):
fieldValue.Set(reflect.ValueOf(&v))

// handle *int
case reflect.TypeOf(intTypePtr):
parsedInt, err := strconv.Atoi(v)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as int: %s", v, err)
}
fieldValue.Set(reflect.ValueOf(&parsedInt))

// handle *int64
case reflect.TypeOf(int64TypePtr):
parsedInt64, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as int64: %s", v, err)
}
fieldValue.Set(reflect.ValueOf(&parsedInt64))

// handle *bool
case reflect.TypeOf(boolTypePtr):
parsedBool, err := strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as bool: %s", v, err)
}
fieldValue.Set(reflect.ValueOf(&parsedBool))

// handle *time.Time
case reflect.TypeOf(&time.Time{}):
layout := time.RFC3339
if len(v) == len(time.DateOnly) {
layout = time.DateOnly
}

result, err := time.Parse(layout, v)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as time.Time: %s", v, err)
}

fieldValue.Set(reflect.ValueOf(&result))

default:
return fmt.Errorf("cannot set '%s' value to unknown pointer of '%s'", v, fieldValue.Type())
}

case reflect.Struct:
// handle time.Time
if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
layout := time.RFC3339
if len(v) == len(time.DateOnly) {
layout = time.DateOnly
}

result, err := time.Parse(layout, v)
if err != nil {
return fmt.Errorf("error while parsing string '%s' as time.Time: %s", v, err)
}

fieldValue.Set(reflect.ValueOf(result))
} else {
return fmt.Errorf("cannot set '%s' value to unknown struct '%s'", v, fieldValue.Type())
}

default:
return fmt.Errorf("cannot set '%s' value to unknown type '%s'", v, fieldValue.Type())
}
}

return nil
}
105 changes: 104 additions & 1 deletion iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ghiter
import (
"reflect"
"testing"
"time"

"github.com/google/go-github/v68/github"
)
Expand All @@ -13,6 +14,7 @@ func Test_updateOptions(t *testing.T) {
opts any
queryParams map[string]string
expectedOpts any
expectedErr bool
}{
{
name: "Simple Opts with ListOptions",
Expand Down Expand Up @@ -78,10 +80,111 @@ func Test_updateOptions(t *testing.T) {
},
},
},
{
name: "date RFC parse",
opts: &github.IssueListCommentsOptions{},
queryParams: map[string]string{
"since": "1989-10-02T00:00:00Z",
},
expectedOpts: &github.IssueListCommentsOptions{
Since: func() *time.Time {
since := time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC)
return &since
}(),
},
},
{
name: "date DateTime parse",
opts: &github.IssueListCommentsOptions{},
queryParams: map[string]string{
"since": "1989-10-02",
},
expectedOpts: &github.IssueListCommentsOptions{
Since: func() *time.Time {
since := time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC)
return &since
}(),
},
},
{
name: "wrong Date",
opts: &github.IssueListCommentsOptions{},
queryParams: map[string]string{
"since": "1923230Z",
},
expectedErr: true,
},
{
name: "Opts with pointers and multiple ListOptions",
opts: &github.CommitsListOptions{},
queryParams: map[string]string{
"since": "1989-10-02T00:00:00Z",
},
expectedOpts: &github.CommitsListOptions{
Since: time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC),
},
},
{
name: "Opts with bool",
opts: &github.ListWorkflowRunsOptions{},
queryParams: map[string]string{
"exclude_pull_requests": "true",
},
expectedOpts: &github.ListWorkflowRunsOptions{
ExcludePullRequests: true,
},
},
{
name: "Opts with bool pointers",
opts: &github.WorkflowRunAttemptOptions{},
queryParams: map[string]string{
"exclude_pull_requests": "true",
},
expectedOpts: &github.WorkflowRunAttemptOptions{
ExcludePullRequests: func() *bool {
exclude := true
return &exclude
}(),
},
},
{
name: "Opts with int pointers",
opts: &github.ListSCIMProvisionedIdentitiesOptions{},
queryParams: map[string]string{
"count": "12",
},
expectedOpts: &github.ListSCIMProvisionedIdentitiesOptions{
Count: func() *int {
count := 12
return &count
}(),
},
},
{
name: "Opts with int64 pointers",
opts: &github.ListCheckRunsOptions{},
queryParams: map[string]string{
"app_id": "12",
},
expectedOpts: &github.ListCheckRunsOptions{
AppID: func() *int64 {
count := int64(12)
return &count
}(),
},
},
}

for _, tc := range tt {
updateOptions(tc.opts, tc.queryParams)
err := updateOptions(tc.opts, tc.queryParams)

if tc.expectedErr {
if err == nil {
t.Fatal("missing expected err\n\n")
}
continue
}

if !reflect.DeepEqual(tc.expectedOpts, tc.opts) {
t.Fatalf("structs are not equal:\nexpected:\t%+v\ngot:\t\t%+v\n", tc.expectedOpts, tc.opts)
}
Expand Down