Skip to content

Commit

Permalink
cmdroute: Implement Router.{With,Group}
Browse files Browse the repository at this point in the history
This commit implements cmdroute.Router.Group and cmdroute.Router.With,
similar to go-chi's Mux.

Fixes #418
  • Loading branch information
diamondburned committed Feb 7, 2024
1 parent dbc4ae8 commit 19518a0
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 44 deletions.
73 changes: 62 additions & 11 deletions api/cmdroute/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

// Router is a router for slash commands. A zero-value Router is a valid router.
type Router struct {
nodes map[string]routeNode
mws []Middleware
stack []*Router
nodes map[string]routeNode
mws []Middleware
parent *Router // parent router, if any
groups []Router // next routers to check, if any
}

type routeNode interface {
Expand Down Expand Up @@ -44,9 +45,6 @@ func NewRouter() *Router {
}

func (r *Router) init() {
if r.stack == nil {
r.stack = []*Router{r}
}
if r.nodes == nil {
r.nodes = make(map[string]routeNode, 4)
}
Expand Down Expand Up @@ -75,7 +73,7 @@ func (r *Router) Use(mws ...Middleware) {
// parent command of the given name.
func (r *Router) Sub(name string, f func(r *Router)) {
sub := NewRouter()
sub.stack = append(append([]*Router(nil), r.stack...), sub)
sub.parent = r
f(sub)

r.add(name, routeNodeSub{sub})
Expand All @@ -92,6 +90,39 @@ func (r *Router) AddFunc(name string, f CommandHandlerFunc) {
r.Add(name, f)
}

// Group creates a subrouter that handles certain commands within the parent
// command. This is useful for assigning middlewares to a group of commands that
// belong to the same parent command.
//
// For example, consider the following:
//
// r := cmdroute.NewRouter()
// r.Group(func(r *cmdroute.Router) {
// r.Use(cmdroute.Deferrable(client, cmdroute.DeferOpts{}))
// r.Add("foo", handleFoo)
// })
// r.Add("bar", handleBar)
//
// In this example, the middleware is only applied to the "foo" command, and not
// the "bar" command.
func (r *Router) Group(f func(r *Router)) {
f(r.With())
}

// With is similar to Group, but it returns a new router instead of calling a
// function with a new router. This is useful for chaining middlewares once,
// such as:
//
// r := cmdroute.NewRouter()
// r.With(cmdroute.Deferrable(client, cmdroute.DeferOpts{})).Add("foo", handleFoo)
func (r *Router) With(mws ...Middleware) *Router {
r.groups = append(r.groups, Router{})
sub := &r.groups[len(r.groups)-1]
sub.parent = r
sub.mws = append(sub.mws, mws...)
return sub
}

// HandleInteraction implements webhook.InteractionHandler. It only handles
// events of type CommandInteraction, otherwise nil is returned.
func (r *Router) HandleInteraction(ev *discord.InteractionEvent) *api.InteractionResponse {
Expand All @@ -113,11 +144,11 @@ func (r *Router) callHandler(ev *discord.InteractionEvent, fn InteractionHandler
// Apply middlewares, parent last, first one added last. This ensures that
// when we call the handler, the middlewares are applied in the order they
// were added.
for i := len(r.stack) - 1; i >= 0; i-- {
r := r.stack[i]
for j := len(r.mws) - 1; j >= 0; j-- {
h = r.mws[j](h)
for r != nil {
for i := len(r.mws) - 1; i >= 0; i-- {
h = r.mws[i](h)
}
r = r.parent
}

return h.HandleInteraction(context.Background(), ev)
Expand Down Expand Up @@ -162,7 +193,27 @@ type handlerData struct {
data discord.CommandInteractionOption
}

// findCommandHandler finds the command handler for the given command name.
// It checks the current router and its groups.
func (r *Router) findCommandHandler(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) {
found, ok := r.findCommandHandlerOnce(ev, data)
if ok {
return found, true
}

for _, sub := range r.groups {
found, ok = sub.findCommandHandlerOnce(ev, data)
if ok {
return found, true
}
}

return handlerData{}, false
}

// findCommandHandlerOnce finds the command handler for the given command name.
// It only checks the current router and not its groups.
func (r *Router) findCommandHandlerOnce(ev *discord.InteractionEvent, data discord.CommandInteractionOption) (handlerData, bool) {
node, ok := r.nodes[data.Name]
if !ok {
return handlerData{}, false
Expand Down
92 changes: 59 additions & 33 deletions api/cmdroute/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,34 +165,45 @@ func TestRouter(t *testing.T) {
})

t.Run("middlewares", func(t *testing.T) {
var stack []string
pushStack := func(s string) Middleware {
return func(next InteractionHandler) InteractionHandler {
return InteractionHandlerFunc(func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse {
stack = append(stack, s)
return next.HandleInteraction(ctx, ev)
})
}
}
var stack middlewareStacker

r := NewRouter()
r.Use(pushStack("root1"))
r.Use(pushStack("root2"))
r.Use(stack.pusher("root1"))
r.Use(stack.pusher("root2"))
r.Sub("test", func(r *Router) {
r.Use(pushStack("sub1.1"))
r.Use(pushStack("sub1.2"))
r.Sub("sub1", func(r *Router) {
r.Use(pushStack("sub2.1"))
r.Use(pushStack("sub2.2"))
r.Add("sub2", assertHandler(t, mockOptions))
// We put test 1 at the start, but test 2 at the end.
// The order should be preserved.
r.Use(stack.pusher("test.1"))

// unused
r.Group(func(r *Router) {
r.Use(stack.pusher("test.3"))
})

// unused
r.With(stack.pusher("test.4"))

r.Group(func(r *Router) {
r.Use(stack.pusher("test.5"))

r.Sub("sub", func(r *Router) {
r.Use(stack.pusher("test.sub.1"))
r.Use(stack.pusher("test.sub.2"))

r.Add("sub2", assertHandler(t, mockOptions))
})
})

// Test 2 goes here.
r.Use(stack.pusher("test.2"))
})

r.HandleInteraction(newInteractionEvent(&discord.CommandInteraction{
ID: 4,
Name: "test",
Options: []discord.CommandInteractionOption{
{
Name: "sub1",
Name: "sub",
Type: discord.SubcommandGroupOptionType,
Options: []discord.CommandInteractionOption{
{
Expand All @@ -205,23 +216,15 @@ func TestRouter(t *testing.T) {
},
}))

expects := []string{
stack.expect(t, []string{
"root1",
"root2",
"sub1.1",
"sub1.2",
"sub2.1",
"sub2.2",
}
if len(stack) != len(expects) {
t.Fatalf("expected stack to have %d elements, got %d", len(expects), len(stack))
}

for i := range expects {
if stack[i] != expects[i] {
t.Fatalf("expected stack[%d] to be %q, got %q", i, expects[i], stack[i])
}
}
"test.1",
"test.2",
"test.5",
"test.sub.1",
"test.sub.2",
})
})

t.Run("deferred", func(t *testing.T) {
Expand Down Expand Up @@ -335,6 +338,7 @@ var mockOptions = []discord.CommandInteractionOption{
},
}

// assertHandler asserts that the given options are equal to the expected options.
func assertHandler(t *testing.T, opts discord.CommandInteractionOptions) CommandHandler {
return CommandHandlerFunc(func(ctx context.Context, data CommandData) *api.InteractionResponseData {
if len(data.Options) != len(opts) {
Expand Down Expand Up @@ -410,3 +414,25 @@ func strInteractionResp(resp *api.InteractionResponse) string {
}
return fmt.Sprintf("%d:%#v", resp.Type, resp.Data)
}

type middlewareStacker []string

func (m *middlewareStacker) pusher(s string) Middleware {
return func(next InteractionHandler) InteractionHandler {
return InteractionHandlerFunc(func(ctx context.Context, ev *discord.InteractionEvent) *api.InteractionResponse {
*m = append(*m, s)
return next.HandleInteraction(ctx, ev)
})
}
}

func (m middlewareStacker) expect(t *testing.T, expects []string) {
if len(m) != len(expects) {
t.Fatalf("expected stack to have %d elements, got %d: %v", len(expects), len(m), m)
}
for i := range expects {
if m[i] != expects[i] {
t.Fatalf("expected stack[%d] to be %q, got %q", i, expects[i], m[i])
}
}
}

0 comments on commit 19518a0

Please sign in to comment.