Skip to content

Commit

Permalink
allow scheduler functions pertaining to a job to be called in any ord…
Browse files Browse the repository at this point in the history
…er (#453)
  • Loading branch information
JohnRoesler committed Apr 14, 2023
1 parent 23c3543 commit c283137
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 104 deletions.
61 changes: 18 additions & 43 deletions scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,17 @@ type Scheduler struct {
updateJob bool // so the scheduler knows to create a new job or update the current
waitForInterval bool // defaults jobs to waiting for first interval to start
singletonMode bool // defaults all jobs to use SingletonMode()
jobCreated bool // so the scheduler knows a job was created prior to calling Every or Cron

startBlockingStopChanMutex sync.Mutex
startBlockingStopChan chan struct{} // stops the scheduler

// tracks whether we're in a chain of scheduling methods for a job
// a chain is started with any of the scheduler methods that operate
// upon a job and are ended with one of [ Do(), Update() ] - note that
// Update() calls Do(), so really they all end with Do().
// This allows the caller to begin with any job related scheduler method
// and only with one of [ Every(), EveryRandom(), Cron(), CronWithSeconds(), MonthFirstWeekday() ]
inScheduleChain bool
}

// days in a week
Expand Down Expand Up @@ -497,22 +504,9 @@ func (s *Scheduler) NextRun() (*Job, time.Time) {
// The default unit is Seconds(). Call a different unit in the chain
// if you would like to change that. For example, Minutes(), Hours(), etc.
func (s *Scheduler) EveryRandom(lower, upper int) *Scheduler {
job := s.newJob(0)
if s.updateJob || s.jobCreated {
job = s.getCurrentJob()
}
job := s.getCurrentJob()

job.setRandomInterval(lower, upper)

if s.updateJob || s.jobCreated {
s.setJobs(append(s.Jobs()[:len(s.Jobs())-1], job))
if s.jobCreated {
s.jobCreated = false
}
} else {
s.setJobs(append(s.Jobs(), job))
}

return s
}

Expand All @@ -521,10 +515,7 @@ func (s *Scheduler) EveryRandom(lower, upper int) *Scheduler {
// parses with time.ParseDuration().
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (s *Scheduler) Every(interval any) *Scheduler {
job := s.newJob(0)
if s.updateJob || s.jobCreated {
job = s.getCurrentJob()
}
job := s.getCurrentJob()

switch interval := interval.(type) {
case int:
Expand All @@ -547,15 +538,6 @@ func (s *Scheduler) Every(interval any) *Scheduler {
job.error = wrapOrError(job.error, ErrInvalidIntervalType)
}

if s.updateJob || s.jobCreated {
s.setJobs(append(s.Jobs()[:len(s.Jobs())-1], job))
if s.jobCreated {
s.jobCreated = false
}
} else {
s.setJobs(append(s.Jobs(), job))
}

return s
}

Expand Down Expand Up @@ -871,6 +853,7 @@ func (s *Scheduler) stopJobs(jobs []*Job) {

func (s *Scheduler) doCommon(jobFun any, params ...any) (*Job, error) {
job := s.getCurrentJob()
s.inScheduleChain = false

jobUnit := job.getUnit()
jobLastRun := job.LastRun()
Expand Down Expand Up @@ -1207,17 +1190,17 @@ func (s *Scheduler) Sunday() *Scheduler {
}

func (s *Scheduler) getCurrentJob() *Job {

if len(s.Jobs()) == 0 {
s.setJobs([]*Job{s.newJob(0)})
s.jobCreated = true
if !s.inScheduleChain {
s.jobsMutex.Lock()
s.jobs = append(s.jobs, s.newJob(0))
s.jobsMutex.Unlock()
s.inScheduleChain = true
}

s.jobsMutex.RLock()
defer s.jobsMutex.RUnlock()

return s.jobs[len(s.jobs)-1]

}

func (s *Scheduler) now() time.Time {
Expand All @@ -1243,6 +1226,7 @@ func (s *Scheduler) Job(j *Job) *Scheduler {
s.Swap(len(jobs)-1, index)
}
}
s.inScheduleChain = true
s.updateJob = true
return s
}
Expand Down Expand Up @@ -1281,10 +1265,7 @@ func (s *Scheduler) CronWithSeconds(cronExpression string) *Scheduler {
}

func (s *Scheduler) cron(cronExpression string, withSeconds bool) *Scheduler {
job := s.newJob(0)
if s.updateJob || s.jobCreated {
job = s.getCurrentJob()
}
job := s.getCurrentJob()

var withLocation string
if strings.HasPrefix(cronExpression, "TZ=") || strings.HasPrefix(cronExpression, "CRON_TZ=") {
Expand Down Expand Up @@ -1313,12 +1294,6 @@ func (s *Scheduler) cron(cronExpression string, withSeconds bool) *Scheduler {
job.setUnit(crontab)
job.startsImmediately = false

if s.updateJob || s.jobCreated {
s.setJobs(append(s.Jobs()[:len(s.Jobs())-1], job))
s.jobCreated = false
} else {
s.setJobs(append(s.Jobs(), job))
}
return s
}

Expand Down
163 changes: 102 additions & 61 deletions scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestAt(t *testing.T) {
atTime3 := time.Now().UTC().Add(time.Hour * -1).Round(time.Second)

s := NewScheduler(time.UTC)
job, err := s.Every(1).Week().At(atTime1).At(atTime2).At(atTime3).Do(func() {})
job, err := s.Week().At(atTime1).At(atTime2).At(atTime3).Every(1).Do(func() {})
require.NoError(t, err)
s.StartAsync()

Expand Down Expand Up @@ -512,11 +512,11 @@ func TestScheduler_Remove(t *testing.T) {
t.Run("remove from non-running", func(t *testing.T) {
s := NewScheduler(time.UTC)
s.TagsUnique()
_, err := s.Every(1).Minute().Tag("tag1").Do(task)
_, err := s.Minute().Tag("tag1").Every(1).Do(task)
require.NoError(t, err)
_, err = s.Every(1).Minute().Do(taskWithParams, 1, "hello")
require.NoError(t, err)
_, err = s.Every(1).Minute().Do(task)
_, err = s.Minute().Every(1).Do(task)
require.NoError(t, err)

assert.Equal(t, 3, s.Len(), "Incorrect number of jobs")
Expand Down Expand Up @@ -1732,17 +1732,17 @@ func TestScheduler_Job(t *testing.T) {

j1, err := s.Every("1s").Do(func() {})
require.NoError(t, err)
assert.Equal(t, j1, s.getCurrentJob())
assert.Equal(t, j1, s.jobs[len(s.jobs)-1])

j2, err := s.Every("1s").Do(func() {})
require.NoError(t, err)
assert.Equal(t, j2, s.getCurrentJob())
assert.Equal(t, j2, s.jobs[len(s.jobs)-1])

s.Job(j1)
assert.Equal(t, j1, s.getCurrentJob())
assert.Equal(t, j1, s.jobs[len(s.jobs)-1])

s.Job(j2)
assert.Equal(t, j2, s.getCurrentJob())
assert.Equal(t, j2, s.jobs[len(s.jobs)-1])
}

func TestScheduler_Update(t *testing.T) {
Expand Down Expand Up @@ -2038,10 +2038,18 @@ func TestScheduler_WaitForSchedules(t *testing.T) {
var counterMutex sync.RWMutex
counter := 0

_, err := s.Every("1s").Do(func() { counterMutex.Lock(); defer counterMutex.Unlock(); counter++ })
_, err := s.Every("1s").Do(func() {
counterMutex.Lock()
defer counterMutex.Unlock()
counter++
})
require.NoError(t, err)

_, err = s.CronWithSeconds("*/1 * * * * *").Do(func() { counterMutex.Lock(); defer counterMutex.Unlock(); counter++ })
_, err = s.CronWithSeconds("*/1 * * * * *").Do(func() {
counterMutex.Lock()
defer counterMutex.Unlock()
counter++
})
require.NoError(t, err)
s.StartAsync()

Expand Down Expand Up @@ -2433,65 +2441,98 @@ func TestScheduler_DoWithJobDetails(t *testing.T) {
})
}

func TestScheduler_GetAllTags(t *testing.T) {
t.Run("tags unique", func(t *testing.T) {
testCases := []struct {
description string
tags []string
expected []string
}{
{"no tags", []string{}, nil},
{"one tag", []string{"tag1"}, []string{"tag1"}},
{"two tags", []string{"tag1", "tag2"}, []string{"tag1", "tag2"}},
{"two tags with duplicates", []string{"tag1", "tag2", "tag1"}, []string{"tag1", "tag2"}},
}
func TestScheduler_GetAllTags_Unique(t *testing.T) {
testCases := []struct {
description string
tags []string
expected []string
expectedError error
}{
{"unique: no tags", []string{}, nil, nil},
{"unique: one tag", []string{"tag1"}, []string{"tag1"}, nil},
{"unique: two tags", []string{"tag1", "tag2"}, []string{"tag1", "tag2"}, nil},
{"unique: two tags with duplicates", []string{"tag1", "tag2", "tag1"}, nil, ErrTagsUnique("tag1")},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
s := NewScheduler(time.UTC)
s.TagsUnique()
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
s := NewScheduler(time.UTC)
s.TagsUnique()

for _, tag := range tc.tags {
_, err := s.Tag(tag).Every("100ms").Do(func() {})
require.NoError(t, err)
}
_, err := s.Tag(tc.tags...).Every("100ms").Do(func() {})
if tc.expectedError == nil {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, tc.expectedError.Error())
}

tags := s.GetAllTags()
sort.Strings(tc.expected)
sort.Strings(tags)
tags := s.GetAllTags()
sort.Strings(tc.expected)
sort.Strings(tags)

assert.Equal(t, tc.expected, tags)
})
}
})
assert.Equal(t, tc.expected, tags)
})
}
}

t.Run("tags not unique", func(t *testing.T) {
testCases := []struct {
description string
tags []string
expected []string
}{
{"no tags", []string{}, nil},
{"one tag", []string{"tag1"}, []string{"tag1"}},
{"two tags", []string{"tag1", "tag2"}, []string{"tag1", "tag2"}},
{"two tags with duplicates", []string{"tag1", "tag2", "tag1"}, []string{"tag1", "tag2", "tag1"}},
}
func TestScheduler_GetAllTags_NotUnique(t *testing.T) {
testCases := []struct {
description string
tags []string
expected []string
}{
{"no tags", []string{}, nil},
{"one tag", []string{"tag1"}, []string{"tag1"}},
{"two tags", []string{"tag1", "tag2"}, []string{"tag1", "tag2"}},
{"two tags with duplicates", []string{"tag1", "tag2", "tag1"}, []string{"tag1", "tag2", "tag1"}},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
s := NewScheduler(time.UTC)
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
s := NewScheduler(time.UTC)

for _, tag := range tc.tags {
_, err := s.Tag(tag).Every("100ms").Do(func() {})
require.NoError(t, err)
}
_, err := s.Tag(tc.tags...).Every("100ms").Do(func() {})
require.NoError(t, err)

tags := s.GetAllTags()
sort.Strings(tc.expected)
sort.Strings(tags)
tags := s.GetAllTags()
sort.Strings(tc.expected)
sort.Strings(tags)

assert.Equal(t, tc.expected, tags)
})
}
}

func TestScheduler_ChainOrder(t *testing.T) {
s := NewScheduler(time.UTC)

func1 := func() { panic("func 1 not implemented") }
func2 := func() { panic("func 2 not implemented") }
func3 := func() { panic("func 3 not implemented") }

funcs := []any{func1, func2, func3}

_, err := s.Tag("1").SingletonMode().Milliseconds().EveryRandom(100, 200).Do(func1)
require.NoError(t, err)

_, err = s.Monday().Every(4).Tag("2").Do(func2)
require.NoError(t, err)

_, err = s.Months(1).Tag("3").Every(1).SingletonMode().At("1:00").Do(func3)
require.NoError(t, err)

require.Len(t, s.jobs, 3)
for i, j := range s.jobs {
assert.Equal(t, fmt.Sprint(funcs[i]), fmt.Sprint(j.function))
}

err = s.RemoveByTag("2")
require.NoError(t, err)

require.Len(t, s.jobs, 2)
funcs = append(funcs[:1], funcs[2])
for i, j := range s.jobs {
assert.Equal(t, fmt.Sprint(funcs[i]), fmt.Sprint(j.function))
}

assert.Equal(t, tc.expected, tags)
})
}
})
}

0 comments on commit c283137

Please sign in to comment.