Skip to content

Commit

Permalink
move interaction handling in it's own function, add httpserver.EventI…
Browse files Browse the repository at this point in the history
…nteractionCreate as event payload & remove getters or httpserver.Server
  • Loading branch information
topi314 committed Jul 27, 2022
1 parent 5451077 commit ce7454e
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 151 deletions.
6 changes: 3 additions & 3 deletions bot/event_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type EventManager interface {
HandleGatewayEvent(gatewayEventType gateway.EventType, sequenceNumber int, shardID int, event gateway.EventData)

// HandleHTTPEvent calls the HTTPServerEventHandler for the payload
HandleHTTPEvent(respondFunc httpserver.RespondFunc, event gateway.EventInteractionCreate)
HandleHTTPEvent(respondFunc httpserver.RespondFunc, event httpserver.EventInteractionCreate)

// DispatchEvent dispatches a new Event to the Client's EventListener(s)
DispatchEvent(event Event)
Expand Down Expand Up @@ -97,7 +97,7 @@ func (h *genericGatewayEventHandler[T]) HandleGatewayEvent(client Client, sequen

// HTTPServerEventHandler is used to handle HTTP Event(s)
type HTTPServerEventHandler interface {
HandleHTTPEvent(client Client, respondFunc httpserver.RespondFunc, event gateway.EventInteractionCreate)
HandleHTTPEvent(client Client, respondFunc httpserver.RespondFunc, event httpserver.EventInteractionCreate)
}

type eventManagerImpl struct {
Expand All @@ -118,7 +118,7 @@ func (e *eventManagerImpl) HandleGatewayEvent(gatewayEventType gateway.EventType
}
}

func (e *eventManagerImpl) HandleHTTPEvent(respondFunc httpserver.RespondFunc, event gateway.EventInteractionCreate) {
func (e *eventManagerImpl) HandleHTTPEvent(respondFunc httpserver.RespondFunc, event httpserver.EventInteractionCreate) {
e.mu.Lock()
defer e.mu.Unlock()
e.config.HTTPServerHandler.HandleHTTPEvent(e.client, respondFunc, event)
Expand Down
3 changes: 1 addition & 2 deletions handlers/interaction_create_http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package handlers
import (
"github.com/disgoorg/disgo/bot"
"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/disgo/gateway"
"github.com/disgoorg/disgo/httpserver"
)

var _ bot.HTTPServerEventHandler = (*httpserverHandlerInteractionCreate)(nil)

type httpserverHandlerInteractionCreate struct{}

func (h *httpserverHandlerInteractionCreate) HandleHTTPEvent(client bot.Client, respondFunc httpserver.RespondFunc, event gateway.EventInteractionCreate) {
func (h *httpserverHandlerInteractionCreate) HandleHTTPEvent(client bot.Client, respondFunc httpserver.RespondFunc, event httpserver.EventInteractionCreate) {
// we just want to pong all pings
// no need for any event
if event.Type() == discord.InteractionTypePing {
Expand Down
130 changes: 119 additions & 11 deletions httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,34 @@ import (
"encoding/hex"
"io"
"net/http"
"sync"
"time"

"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/disgo/gateway"
"github.com/disgoorg/disgo/json"
"github.com/disgoorg/log"
)

type (
// EventHandlerFunc is used to handle events from Discord's Outgoing Webhooks
EventHandlerFunc func(responseFunc RespondFunc, event gateway.EventInteractionCreate)
EventHandlerFunc func(responseFunc RespondFunc, event EventInteractionCreate)

// RespondFunc is used to respond to Discord's Outgoing Webhooks
RespondFunc func(response discord.InteractionResponse) error

// EventInteractionCreate is the event payload when an interaction is created via Discord's Outgoing Webhooks
EventInteractionCreate struct {
discord.Interaction
}
)

// Server is used for receiving Discord's interactions via Outgoing Webhooks
type Server interface {
// Logger returns the logger used by the Server
Logger() log.Logger

// PublicKey returns the public key used by the Server
PublicKey() PublicKey

// Start starts the Server
Start()

// Close closes the Server
Close(ctx context.Context)

// Handle passes a payload to the Server for processing
Handle(respondFunc RespondFunc, event gateway.EventInteractionCreate)
}

// VerifyRequest implements the verification side of the discord interactions api signing algorithm, as documented here: https://discord.com/developers/docs/interactions/slash-commands#security-and-authorization
Expand Down Expand Up @@ -80,3 +78,113 @@ func VerifyRequest(r *http.Request, key PublicKey) bool {

return Verify(key, msg.Bytes(), sig)
}

type replyStatus int

const (
replyStatusWaiting replyStatus = iota
replyStatusReplied
replyStatusTimedOut
)

// HandleInteraction handles an interaction from Discord's Outgoing Webhooks. It verifies and parses the interaction and then calls the passed EventHandlerFunc.
func HandleInteraction(publicKey PublicKey, logger log.Logger, handleFunc EventHandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

if ok := VerifyRequest(r, publicKey); !ok {
w.WriteHeader(http.StatusUnauthorized)
data, _ := io.ReadAll(r.Body)
logger.Trace("received http interaction with invalid signature. body: ", string(data))
return
}

defer func() {
_ = r.Body.Close()
}()

buff := new(bytes.Buffer)
rqData, _ := io.ReadAll(io.TeeReader(r.Body, buff))
logger.Trace("received http interaction. body: ", string(rqData))

var v EventInteractionCreate
if err := json.NewDecoder(buff).Decode(&v); err != nil {
logger.Error("error while decoding interaction: ", err)
return
}

// these channels are used to communicate between the http handler and where the interaction is responded to
responseChannel := make(chan discord.InteractionResponse)
defer close(responseChannel)
errorChannel := make(chan error)
defer close(errorChannel)

// status of this interaction with a mutex to ensure usage between multiple goroutines
var (
status replyStatus
mu sync.Mutex
)

// send interaction to our handler
go handleFunc(func(response discord.InteractionResponse) error {
mu.Lock()
defer mu.Unlock()

if status == replyStatusTimedOut {
return discord.ErrInteractionExpired
}

if status == replyStatusReplied {
return discord.ErrInteractionAlreadyReplied
}

status = replyStatusReplied
responseChannel <- response
// wait if we get any error while processing the response
return <-errorChannel
}, v)

var (
body any
err error
)

// wait for the interaction to be responded to or to time out after 3s
timer := time.NewTimer(time.Second * 3)
defer timer.Stop()
select {
case response := <-responseChannel:
if body, err = response.ToBody(); err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
errorChannel <- err
return
}

case <-timer.C:
mu.Lock()
defer mu.Unlock()
status = replyStatusTimedOut

logger.Debug("interaction timed out")
http.Error(w, "interaction timed out", http.StatusRequestTimeout)
return
}

rsBody := &bytes.Buffer{}
multiWriter := io.MultiWriter(w, rsBody)

if multiPart, ok := body.(*discord.MultipartBuffer); ok {
w.Header().Set("Content-Type", multiPart.ContentType)
_, err = io.Copy(multiWriter, multiPart.Buffer)
} else {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(multiWriter).Encode(body)
}
if err != nil {
errorChannel <- err
return
}

rsData, _ := io.ReadAll(rsBody)
logger.Trace("response to http interaction. body: ", string(rsData))
}
}
137 changes: 2 additions & 135 deletions httpserver/server_impl.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
package httpserver

import (
"bytes"
"context"
"encoding/hex"
"io"
"net/http"
"sync"
"time"

"github.com/disgoorg/disgo/gateway"
"github.com/disgoorg/disgo/json"

"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/log"
)

var _ Server = (*serverImpl)(nil)
Expand Down Expand Up @@ -41,20 +31,8 @@ type serverImpl struct {
eventHandlerFunc EventHandlerFunc
}

func (s *serverImpl) Logger() log.Logger {
return s.config.Logger
}

func (s *serverImpl) PublicKey() PublicKey {
return s.publicKey
}

func (s *serverImpl) Handle(respondFunc RespondFunc, event gateway.EventInteractionCreate) {
s.eventHandlerFunc(respondFunc, event)
}

func (s *serverImpl) Start() {
s.config.ServeMux.Handle(s.config.URL, &WebhookInteractionHandler{server: s})
s.config.ServeMux.Handle(s.config.URL, HandleInteraction(s.publicKey, s.config.Logger, s.eventHandlerFunc))
s.config.HTTPServer.Addr = s.config.Address
s.config.HTTPServer.Handler = s.config.ServeMux

Expand All @@ -66,122 +44,11 @@ func (s *serverImpl) Start() {
err = s.config.HTTPServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
s.Logger().Error("error while running http server: ", err)
s.config.Logger.Error("error while running http server: ", err)
}
}()
}

func (s *serverImpl) Close(ctx context.Context) {
_ = s.config.HTTPServer.Shutdown(ctx)
}

// WebhookInteractionHandler implements the http.Handler interface and is used to handle interactions from Discord.
type WebhookInteractionHandler struct {
server Server
}

type replyStatus int

const (
replyStatusWaiting replyStatus = iota
replyStatusReplied
replyStatusTimedOut
)

func (h *WebhookInteractionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if ok := VerifyRequest(r, h.server.PublicKey()); !ok {
w.WriteHeader(http.StatusUnauthorized)
data, _ := io.ReadAll(r.Body)
h.server.Logger().Trace("received http interaction with invalid signature. body: ", string(data))
return
}

defer func() {
_ = r.Body.Close()
}()

buff := new(bytes.Buffer)
rqData, _ := io.ReadAll(io.TeeReader(r.Body, buff))
h.server.Logger().Trace("received http interaction. body: ", string(rqData))

var v gateway.EventInteractionCreate
if err := json.NewDecoder(buff).Decode(&v); err != nil {
h.server.Logger().Error("error while decoding interaction: ", err)
return
}

// these channels are used to communicate between the http handler and where the interaction is responded to
responseChannel := make(chan discord.InteractionResponse)
defer close(responseChannel)
errorChannel := make(chan error)
defer close(errorChannel)

// status of this interaction with a mutex to ensure usage between multiple goroutines
var (
status replyStatus
mu sync.Mutex
)

// send interaction to our handler
go h.server.Handle(func(response discord.InteractionResponse) error {
mu.Lock()
defer mu.Unlock()

if status == replyStatusTimedOut {
return discord.ErrInteractionExpired
}

if status == replyStatusReplied {
return discord.ErrInteractionAlreadyReplied
}

status = replyStatusReplied
responseChannel <- response
// wait if we get any error while processing the response
return <-errorChannel
}, v)

var (
body any
err error
)

// wait for the interaction to be responded to or to time out after 3s
timer := time.NewTimer(time.Second * 3)
defer timer.Stop()
select {
case response := <-responseChannel:
if body, err = response.ToBody(); err != nil {
http.Error(w, "internal server error", http.StatusInternalServerError)
errorChannel <- err
return
}

case <-timer.C:
mu.Lock()
defer mu.Unlock()
status = replyStatusTimedOut

h.server.Logger().Debug("interaction timed out")
http.Error(w, "interaction timed out", http.StatusRequestTimeout)
return
}

rsBody := &bytes.Buffer{}
multiWriter := io.MultiWriter(w, rsBody)

if multiPart, ok := body.(*discord.MultipartBuffer); ok {
w.Header().Set("Content-Type", multiPart.ContentType)
_, err = io.Copy(multiWriter, multiPart.Buffer)
} else {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(multiWriter).Encode(body)
}
if err != nil {
errorChannel <- err
return
}

rsData, _ := io.ReadAll(rsBody)
h.server.Logger().Trace("response to http interaction. body: ", string(rsData))
}

0 comments on commit ce7454e

Please sign in to comment.