Skip to content

Commit

Permalink
remove logger getters on all interfaces except bot.Client
Browse files Browse the repository at this point in the history
  • Loading branch information
topi314 committed Jul 27, 2022
1 parent ce7454e commit d0018c5
Show file tree
Hide file tree
Showing 15 changed files with 86 additions and 126 deletions.
2 changes: 1 addition & 1 deletion bot/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func BuildClient(token string, config Config, gatewayEventHandlerFunc func(clien
client.httpServer = config.HTTPServer

if config.MemberChunkingManager == nil {
config.MemberChunkingManager = NewMemberChunkingManager(client, config.MemberChunkingFilter)
config.MemberChunkingManager = NewMemberChunkingManager(client, config.Logger, config.MemberChunkingFilter)
}
client.memberChunkingManager = config.MemberChunkingManager

Expand Down
6 changes: 3 additions & 3 deletions bot/event_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (e *eventManagerImpl) HandleGatewayEvent(gatewayEventType gateway.EventType
if handler, ok := e.config.GatewayHandlers[gatewayEventType]; ok {
handler.HandleGatewayEvent(e.client, sequenceNumber, shardID, event)
} else {
e.client.Logger().Warnf("no handler for gateway event '%s' found", gatewayEventType)
e.config.Logger.Warnf("no handler for gateway event '%s' found", gatewayEventType)
}
}

Expand All @@ -127,7 +127,7 @@ func (e *eventManagerImpl) HandleHTTPEvent(respondFunc httpserver.RespondFunc, e
func (e *eventManagerImpl) DispatchEvent(event Event) {
defer func() {
if r := recover(); r != nil {
e.client.Logger().Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack()))
e.config.Logger.Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack()))
return
}
}()
Expand All @@ -138,7 +138,7 @@ func (e *eventManagerImpl) DispatchEvent(event Event) {
go func() {
defer func() {
if r := recover(); r != nil {
e.client.Logger().Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack()))
e.config.Logger.Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack()))
return
}
}()
Expand Down
13 changes: 12 additions & 1 deletion bot/event_manager_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package bot

import (
"github.com/disgoorg/disgo/gateway"
"github.com/disgoorg/log"
)

// DefaultEventManagerConfig returns a new EventManagerConfig with all default values.
func DefaultEventManagerConfig() *EventManagerConfig {
return &EventManagerConfig{}
return &EventManagerConfig{
Logger: log.Default(),
}
}

// EventManagerConfig can be used to configure the EventManager.
type EventManagerConfig struct {
Logger log.Logger
EventListeners []EventListener
AsyncEventsEnabled bool

Expand All @@ -28,6 +32,13 @@ func (c *EventManagerConfig) Apply(opts []EventManagerConfigOpt) {
}
}

// WithEventManagerLogger overrides the default logger in the EventManagerConfig.
func WithEventManagerLogger(logger log.Logger) EventManagerConfigOpt {
return func(config *EventManagerConfig) {
config.Logger = logger
}
}

// WithListeners adds the given EventListener(s) to the EventManagerConfig.
func WithListeners(listeners ...EventListener) EventManagerConfigOpt {
return func(config *EventManagerConfig) {
Expand Down
10 changes: 8 additions & 2 deletions bot/member_chunking_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@ import (
"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/disgo/gateway"
"github.com/disgoorg/disgo/internal/insecurerandstr"
"github.com/disgoorg/log"
"github.com/disgoorg/snowflake/v2"
)

var _ MemberChunkingManager = (*memberChunkingManagerImpl)(nil)

// NewMemberChunkingManager returns a new MemberChunkingManager with the given MemberChunkingFilter.
func NewMemberChunkingManager(client Client, memberChunkingFilter MemberChunkingFilter) MemberChunkingManager {
func NewMemberChunkingManager(client Client, logger log.Logger, memberChunkingFilter MemberChunkingFilter) MemberChunkingManager {
if memberChunkingFilter == nil {
memberChunkingFilter = MemberChunkingFilterNone
}
if logger == nil {
logger = log.Default()
}
return &memberChunkingManagerImpl{
client: client,
logger: logger,
memberChunkingFilter: memberChunkingFilter,
chunkingRequests: map[string]*chunkingRequest{},
}
Expand Down Expand Up @@ -82,6 +87,7 @@ type chunkingRequest struct {

type memberChunkingManagerImpl struct {
client Client
logger log.Logger
memberChunkingFilter MemberChunkingFilter

chunkingRequestsMu sync.RWMutex
Expand All @@ -97,7 +103,7 @@ func (m *memberChunkingManagerImpl) HandleChunk(payload gateway.EventGuildMember
request, ok := m.chunkingRequests[payload.Nonce]
m.chunkingRequestsMu.RUnlock()
if !ok {
m.client.Logger().Debug("received unknown member chunk event: ", payload)
m.logger.Debug("received unknown member chunk event: ", payload)
return
}

Expand Down
5 changes: 0 additions & 5 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package gateway
import (
"context"
"time"

"github.com/disgoorg/log"
)

// Version defines which discord API version disgo should use to connect to discord.
Expand Down Expand Up @@ -64,9 +62,6 @@ type (

// Gateway is what is used to connect to discord.
type Gateway interface {
// Logger returns the logger that is used by the Gateway.
Logger() log.Logger

// ShardID returns the shard ID that this Gateway is configured to use.
ShardID() int

Expand Down
70 changes: 32 additions & 38 deletions gateway/gateway_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"github.com/disgoorg/disgo/discord"
"github.com/disgoorg/disgo/internal/tokenhelper"
"github.com/disgoorg/disgo/json"
"github.com/disgoorg/log"

"github.com/gorilla/websocket"
)

Expand Down Expand Up @@ -51,10 +49,6 @@ type gatewayImpl struct {
lastHeartbeatReceived time.Time
}

func (g *gatewayImpl) Logger() log.Logger {
return g.config.Logger
}

func (g *gatewayImpl) ShardID() int {
return g.config.ShardID
}
Expand Down Expand Up @@ -90,7 +84,7 @@ func (g *gatewayImpl) formatLogs(a ...any) string {
}

func (g *gatewayImpl) Open(ctx context.Context) error {
g.Logger().Debug(g.formatLogs("opening gateway connection"))
g.config.Logger.Debug(g.formatLogs("opening gateway connection"))

g.connMu.Lock()
defer g.connMu.Unlock()
Expand All @@ -111,12 +105,12 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
}()
rawBody, bErr := io.ReadAll(rs.Body)
if bErr != nil {
g.Logger().Error(g.formatLogs("error while reading response body: ", err))
g.config.Logger.Error(g.formatLogs("error while reading response body: ", err))
}
body = string(rawBody)
}

g.Logger().Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body))
g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body))
return err
}

Expand All @@ -142,7 +136,7 @@ func (g *gatewayImpl) Close(ctx context.Context) {

func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message string) {
if g.heartbeatTicker != nil {
g.Logger().Debug(g.formatLogs("closing heartbeat goroutines..."))
g.config.Logger.Debug(g.formatLogs("closing heartbeat goroutines..."))
g.heartbeatTicker.Stop()
g.heartbeatTicker = nil
}
Expand All @@ -151,9 +145,9 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
defer g.connMu.Unlock()
if g.conn != nil {
g.config.RateLimiter.Close(ctx)
g.Logger().Debug(g.formatLogsf("closing gateway connection with code: %d, message: %s", code, message))
g.config.Logger.Debug(g.formatLogsf("closing gateway connection with code: %d, message: %s", code, message))
if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && err != websocket.ErrCloseSent {
g.Logger().Debug(g.formatLogs("error writing close code. error: ", err))
g.config.Logger.Debug(g.formatLogs("error writing close code. error: ", err))
}
_ = g.conn.Close()
g.conn = nil
Expand Down Expand Up @@ -196,7 +190,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er
}

defer g.config.RateLimiter.Unlock()
g.Logger().Trace(g.formatLogs("sending gateway command: ", string(data)))
g.config.Logger.Trace(g.formatLogs("sending gateway command: ", string(data)))
return g.conn.WriteMessage(messageType, data)
}

Expand All @@ -217,12 +211,12 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int, delay time.Dura
case <-timer.C:
}

g.Logger().Debug(g.formatLogs("reconnecting gateway..."))
g.config.Logger.Debug(g.formatLogs("reconnecting gateway..."))
if err := g.Open(ctx); err != nil {
if err == discord.ErrGatewayAlreadyConnected {
return err
}
g.Logger().Error(g.formatLogs("failed to reconnect gateway. error: ", err))
g.config.Logger.Error(g.formatLogs("failed to reconnect gateway. error: ", err))
g.status = StatusDisconnected
return g.reconnectTry(ctx, try+1, delay)
}
Expand All @@ -232,27 +226,27 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int, delay time.Dura
func (g *gatewayImpl) reconnect(ctx context.Context) {
err := g.reconnectTry(ctx, 0, time.Second)
if err != nil {
g.Logger().Error(g.formatLogs("failed to reopen gateway. error: ", err))
g.config.Logger.Error(g.formatLogs("failed to reopen gateway. error: ", err))
}
}

func (g *gatewayImpl) heartbeat() {
g.heartbeatTicker = time.NewTicker(g.heartbeatInterval)
defer g.heartbeatTicker.Stop()
defer g.Logger().Debug(g.formatLogs("exiting heartbeat goroutine..."))
defer g.config.Logger.Debug(g.formatLogs("exiting heartbeat goroutine..."))

for range g.heartbeatTicker.C {
g.sendHeartbeat()
}
}

func (g *gatewayImpl) sendHeartbeat() {
g.Logger().Debug(g.formatLogs("sending heartbeat..."))
g.config.Logger.Debug(g.formatLogs("sending heartbeat..."))

ctx, cancel := context.WithTimeout(context.Background(), g.heartbeatInterval)
defer cancel()
if err := g.Send(ctx, OpcodeHeartbeat, (*MessageDataHeartbeat)(g.config.LastSequenceReceived)); err != nil && err != discord.ErrShardNotConnected {
g.Logger().Error(g.formatLogs("failed to send heartbeat. error: ", err))
g.config.Logger.Error(g.formatLogs("failed to send heartbeat. error: ", err))
g.CloseWithCode(context.TODO(), websocket.CloseServiceRestart, "heartbeat timeout")
go g.reconnect(context.TODO())
return
Expand All @@ -262,7 +256,7 @@ func (g *gatewayImpl) sendHeartbeat() {

func (g *gatewayImpl) identify() {
g.status = StatusIdentifying
g.Logger().Debug(g.formatLogs("sending Identify command..."))
g.config.Logger.Debug(g.formatLogs("sending Identify command..."))

identify := MessageDataIdentify{
Token: g.token,
Expand All @@ -281,7 +275,7 @@ func (g *gatewayImpl) identify() {
}

if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil {
g.Logger().Error(g.formatLogs("error sending Identify command err: ", err))
g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err))
}
g.status = StatusWaitingForReady
}
Expand All @@ -294,14 +288,14 @@ func (g *gatewayImpl) resume() {
Seq: *g.config.LastSequenceReceived,
}

g.Logger().Debug(g.formatLogs("sending Resume command..."))
g.config.Logger.Debug(g.formatLogs("sending Resume command..."))
if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil {
g.Logger().Error(g.formatLogs("error sending resume command err: ", err))
g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err))
}
}

func (g *gatewayImpl) listen(conn *websocket.Conn) {
defer g.Logger().Debug(g.formatLogs("exiting listen goroutine..."))
defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine..."))
loop:
for {
mt, reader, err := conn.NextReader()
Expand All @@ -327,19 +321,19 @@ loop:
} else {
intentsURL = "https://discord.com/developers/applications"
}
g.Logger().Error(g.formatLogsf("disallowed gateway intents supplied. go to %s and enable the privileged intent for your application. intents: %d", intentsURL, g.config.Intents))
g.config.Logger.Error(g.formatLogsf("disallowed gateway intents supplied. go to %s and enable the privileged intent for your application. intents: %d", intentsURL, g.config.Intents))
} else if closeCode == CloseEventCodeInvalidSeq {
g.Logger().Error(g.formatLogs("invalid sequence provided. reconnecting..."))
g.config.Logger.Error(g.formatLogs("invalid sequence provided. reconnecting..."))
g.config.LastSequenceReceived = nil
g.config.SessionID = nil
} else {
g.Logger().Error(g.formatLogsf("gateway close received, reconnect: %t, code: %d, error: %s", g.config.AutoReconnect && reconnect, closeError.Code, closeError.Text))
g.config.Logger.Error(g.formatLogsf("gateway close received, reconnect: %t, code: %d, error: %s", g.config.AutoReconnect && reconnect, closeError.Code, closeError.Text))
}
} else if errors.Is(err, net.ErrClosed) {
// we closed the connection ourselves. Don't try to reconnect here
reconnect = false
} else {
g.Logger().Debug(g.formatLogs("failed to read next message from gateway. error: ", err))
g.config.Logger.Debug(g.formatLogs("failed to read next message from gateway. error: ", err))
}

if g.config.AutoReconnect && reconnect {
Expand All @@ -355,7 +349,7 @@ loop:

event, err := g.parseMessage(mt, reader)
if err != nil {
g.Logger().Error(g.formatLogs("error while parsing gateway message. error: ", err))
g.config.Logger.Error(g.formatLogs("error while parsing gateway message. error: ", err))
continue
}

Expand All @@ -373,22 +367,22 @@ loop:
}

case OpcodeDispatch:
g.Logger().Trace(g.formatLogsf("received: OpcodeDispatch %s, data: %s", event.T, string(event.RawD)))
g.config.Logger.Trace(g.formatLogsf("received: OpcodeDispatch %s, data: %s", event.T, string(event.RawD)))

// set last sequence received
g.config.LastSequenceReceived = &event.S

data, ok := event.D.(EventData)
if !ok && event.D != nil {
g.Logger().Error(g.formatLogsf("invalid event data of type %T received", event.D))
g.config.Logger.Error(g.formatLogsf("invalid event data of type %T received", event.D))
continue
}

// get session id here
if readyEvent, ok := data.(EventReady); ok {
g.config.SessionID = &readyEvent.SessionID
g.status = StatusReady
g.Logger().Debug(g.formatLogs("ready event received"))
g.config.Logger.Debug(g.formatLogs("ready event received"))
}

// push event to the command manager
Expand All @@ -401,18 +395,18 @@ loop:
g.eventHandlerFunc(event.T, event.S, g.config.ShardID, data)

case OpcodeHeartbeat:
g.Logger().Debug(g.formatLogs("received: OpcodeHeartbeat"))
g.config.Logger.Debug(g.formatLogs("received: OpcodeHeartbeat"))
g.sendHeartbeat()

case OpcodeReconnect:
g.Logger().Debug(g.formatLogs("received: OpcodeReconnect"))
g.config.Logger.Debug(g.formatLogs("received: OpcodeReconnect"))
g.CloseWithCode(context.TODO(), websocket.CloseServiceRestart, "received reconnect")
go g.reconnect(context.TODO())
break loop

case OpcodeInvalidSession:
canResume := event.D.(MessageDataInvalidSession)
g.Logger().Debug(g.formatLogs("received: OpcodeInvalidSession, canResume: ", canResume))
g.config.Logger.Debug(g.formatLogs("received: OpcodeInvalidSession, canResume: ", canResume))

code := websocket.CloseNormalClosure
if canResume {
Expand All @@ -428,7 +422,7 @@ loop:
break loop

case OpcodeHeartbeatACK:
g.Logger().Debug(g.formatLogs("received: OpcodeHeartbeatACK"))
g.config.Logger.Debug(g.formatLogs("received: OpcodeHeartbeatACK"))
g.lastHeartbeatReceived = time.Now().UTC()
}
}
Expand All @@ -437,7 +431,7 @@ loop:
func (g *gatewayImpl) parseMessage(mt int, reader io.Reader) (Message, error) {
var readCloser io.ReadCloser
if mt == websocket.BinaryMessage {
g.Logger().Trace(g.formatLogs("binary message received. decompressing..."))
g.config.Logger.Trace(g.formatLogs("binary message received. decompressing..."))
var err error
readCloser, err = zlib.NewReader(reader)
if err != nil {
Expand All @@ -452,7 +446,7 @@ func (g *gatewayImpl) parseMessage(mt int, reader io.Reader) (Message, error) {

var message Message
if err := json.NewDecoder(readCloser).Decode(&message); err != nil {
g.Logger().Error(g.formatLogs("error decoding websocket message: ", err))
g.config.Logger.Error(g.formatLogs("error decoding websocket message: ", err))
return Message{}, err
}
return message, nil
Expand Down

0 comments on commit d0018c5

Please sign in to comment.