diff --git a/iter.go b/iter.go index 3a967e3..b7caf85 100644 --- a/iter.go +++ b/iter.go @@ -9,6 +9,7 @@ import ( "reflect" "strconv" "strings" + "time" "github.com/google/go-github/v68/github" ) @@ -123,7 +124,10 @@ func (it *Iterator[T, O]) All() iter.Seq[T] { vals[k] = v[0] } - updateOptions(it.opt, vals) + if err := updateOptions(it.opt, vals); err != nil { + it.err = err + return + } } } } @@ -200,10 +204,17 @@ func (it *Iterator[T, O]) do() ([]T, *github.Response, error) { return nil, nil, errors.New("no func provided") } +var ( + stringTypePtr *string + intTypePtr *int + int64TypePtr *int64 + boolTypePtr *bool +) + // updateOptions will update the github options based on the provided map and the `url` tag. // If the field in the struct has a `url` tag it tries to set the value of the field from the one // found in the map, if any. -func updateOptions(v any, m map[string]string) { +func updateOptions(v any, m map[string]string) error { valueOf := reflect.ValueOf(v) typeOf := reflect.TypeOf(v) @@ -219,29 +230,120 @@ func updateOptions(v any, m map[string]string) { // if field is of type struct then iterate over the pointer if structField.Type.Kind() == reflect.Struct { if fieldValue.CanAddr() { - updateOptions(fieldValue.Addr().Interface(), m) + if err := updateOptions(fieldValue.Addr().Interface(), m); err != nil { + return err + } } } // otherwise check if it has a 'url' tag urlTag := structField.Tag.Get("url") - if urlTag != "" { - urlParam := strings.Split(urlTag, ",")[0] - - if fieldValue.IsValid() && fieldValue.CanSet() { - if v, found := m[urlParam]; found { - switch fieldValue.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if i, err := strconv.Atoi(v); err == nil { - fieldValue.SetInt(int64(i)) - } - case reflect.Ptr: - fieldValue.Set(reflect.ValueOf(&v)) - default: - fieldValue.Set(reflect.ValueOf(v)) - } + if urlTag == "" { + continue + } + + if !fieldValue.IsValid() || !fieldValue.CanSet() { + continue + } + + urlParam := strings.Split(urlTag, ",")[0] + v, found := m[urlParam] + if !found { + continue + } + + switch fieldValue.Kind() { + + // handle string + case reflect.String: + fieldValue.Set(reflect.ValueOf(v)) + + // handle numeric types (int, int8, int16, int32, int64) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if i, err := strconv.Atoi(v); err == nil { + fieldValue.SetInt(int64(i)) + } + + // handle bool + case reflect.Bool: + parsedBool, err := strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as bool: %s", v, err) + } + fieldValue.Set(reflect.ValueOf(parsedBool)) + + // handle pointers (*string, *int, *int64, *bool, *time.Time) + case reflect.Pointer: + switch fieldValue.Type() { + + // handle *string + case reflect.TypeOf(stringTypePtr): + fieldValue.Set(reflect.ValueOf(&v)) + + // handle *int + case reflect.TypeOf(intTypePtr): + parsedInt, err := strconv.Atoi(v) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as int: %s", v, err) } + fieldValue.Set(reflect.ValueOf(&parsedInt)) + + // handle *int64 + case reflect.TypeOf(int64TypePtr): + parsedInt64, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as int64: %s", v, err) + } + fieldValue.Set(reflect.ValueOf(&parsedInt64)) + + // handle *bool + case reflect.TypeOf(boolTypePtr): + parsedBool, err := strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as bool: %s", v, err) + } + fieldValue.Set(reflect.ValueOf(&parsedBool)) + + // handle *time.Time + case reflect.TypeOf(&time.Time{}): + layout := time.RFC3339 + if len(v) == len(time.DateOnly) { + layout = time.DateOnly + } + + result, err := time.Parse(layout, v) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as time.Time: %s", v, err) + } + + fieldValue.Set(reflect.ValueOf(&result)) + + default: + return fmt.Errorf("cannot set '%s' value to unknown pointer of '%s'", v, fieldValue.Type()) } + + case reflect.Struct: + // handle time.Time + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { + layout := time.RFC3339 + if len(v) == len(time.DateOnly) { + layout = time.DateOnly + } + + result, err := time.Parse(layout, v) + if err != nil { + return fmt.Errorf("error while parsing string '%s' as time.Time: %s", v, err) + } + + fieldValue.Set(reflect.ValueOf(result)) + } else { + return fmt.Errorf("cannot set '%s' value to unknown struct '%s'", v, fieldValue.Type()) + } + + default: + return fmt.Errorf("cannot set '%s' value to unknown type '%s'", v, fieldValue.Type()) } } + + return nil } diff --git a/iter_test.go b/iter_test.go index db78732..068ca35 100644 --- a/iter_test.go +++ b/iter_test.go @@ -3,6 +3,7 @@ package ghiter import ( "reflect" "testing" + "time" "github.com/google/go-github/v68/github" ) @@ -13,6 +14,7 @@ func Test_updateOptions(t *testing.T) { opts any queryParams map[string]string expectedOpts any + expectedErr bool }{ { name: "Simple Opts with ListOptions", @@ -78,10 +80,111 @@ func Test_updateOptions(t *testing.T) { }, }, }, + { + name: "date RFC parse", + opts: &github.IssueListCommentsOptions{}, + queryParams: map[string]string{ + "since": "1989-10-02T00:00:00Z", + }, + expectedOpts: &github.IssueListCommentsOptions{ + Since: func() *time.Time { + since := time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC) + return &since + }(), + }, + }, + { + name: "date DateTime parse", + opts: &github.IssueListCommentsOptions{}, + queryParams: map[string]string{ + "since": "1989-10-02", + }, + expectedOpts: &github.IssueListCommentsOptions{ + Since: func() *time.Time { + since := time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC) + return &since + }(), + }, + }, + { + name: "wrong Date", + opts: &github.IssueListCommentsOptions{}, + queryParams: map[string]string{ + "since": "1923230Z", + }, + expectedErr: true, + }, + { + name: "Opts with pointers and multiple ListOptions", + opts: &github.CommitsListOptions{}, + queryParams: map[string]string{ + "since": "1989-10-02T00:00:00Z", + }, + expectedOpts: &github.CommitsListOptions{ + Since: time.Date(1989, time.October, 2, 0, 0, 0, 0, time.UTC), + }, + }, + { + name: "Opts with bool", + opts: &github.ListWorkflowRunsOptions{}, + queryParams: map[string]string{ + "exclude_pull_requests": "true", + }, + expectedOpts: &github.ListWorkflowRunsOptions{ + ExcludePullRequests: true, + }, + }, + { + name: "Opts with bool pointers", + opts: &github.WorkflowRunAttemptOptions{}, + queryParams: map[string]string{ + "exclude_pull_requests": "true", + }, + expectedOpts: &github.WorkflowRunAttemptOptions{ + ExcludePullRequests: func() *bool { + exclude := true + return &exclude + }(), + }, + }, + { + name: "Opts with int pointers", + opts: &github.ListSCIMProvisionedIdentitiesOptions{}, + queryParams: map[string]string{ + "count": "12", + }, + expectedOpts: &github.ListSCIMProvisionedIdentitiesOptions{ + Count: func() *int { + count := 12 + return &count + }(), + }, + }, + { + name: "Opts with int64 pointers", + opts: &github.ListCheckRunsOptions{}, + queryParams: map[string]string{ + "app_id": "12", + }, + expectedOpts: &github.ListCheckRunsOptions{ + AppID: func() *int64 { + count := int64(12) + return &count + }(), + }, + }, } for _, tc := range tt { - updateOptions(tc.opts, tc.queryParams) + err := updateOptions(tc.opts, tc.queryParams) + + if tc.expectedErr { + if err == nil { + t.Fatal("missing expected err\n\n") + } + continue + } + if !reflect.DeepEqual(tc.expectedOpts, tc.opts) { t.Fatalf("structs are not equal:\nexpected:\t%+v\ngot:\t\t%+v\n", tc.expectedOpts, tc.opts) }