From d0018c5a556d07c3442e855dea176300afe782e6 Mon Sep 17 00:00:00 2001 From: TopiSenpai Date: Wed, 27 Jul 2022 02:54:15 +0200 Subject: [PATCH] remove logger getters on all interfaces except bot.Client --- bot/config.go | 2 +- bot/event_manager.go | 6 +-- bot/event_manager_config.go | 13 +++++- bot/member_chunking_manager.go | 10 +++- gateway/gateway.go | 5 -- gateway/gateway_impl.go | 70 +++++++++++++--------------- gateway/gateway_rate_limiter.go | 5 -- gateway/gateway_rate_limiter_impl.go | 9 +--- rest/rest_client.go | 14 ++---- rest/rest_rate_limiter.go | 4 -- rest/rest_rate_limiter_impl.go | 25 ++++------ sharding/shard_manager.go | 4 -- sharding/shard_manager_impl.go | 25 ++++------ sharding/shard_rate_limiter.go | 5 -- sharding/shard_rate_limiter_impl.go | 15 ++---- 15 files changed, 86 insertions(+), 126 deletions(-) diff --git a/bot/config.go b/bot/config.go index 2ffbfcc7..f1c48779 100644 --- a/bot/config.go +++ b/bot/config.go @@ -305,7 +305,7 @@ func BuildClient(token string, config Config, gatewayEventHandlerFunc func(clien client.httpServer = config.HTTPServer if config.MemberChunkingManager == nil { - config.MemberChunkingManager = NewMemberChunkingManager(client, config.MemberChunkingFilter) + config.MemberChunkingManager = NewMemberChunkingManager(client, config.Logger, config.MemberChunkingFilter) } client.memberChunkingManager = config.MemberChunkingManager diff --git a/bot/event_manager.go b/bot/event_manager.go index 3c507e03..050c9a6e 100644 --- a/bot/event_manager.go +++ b/bot/event_manager.go @@ -114,7 +114,7 @@ func (e *eventManagerImpl) HandleGatewayEvent(gatewayEventType gateway.EventType if handler, ok := e.config.GatewayHandlers[gatewayEventType]; ok { handler.HandleGatewayEvent(e.client, sequenceNumber, shardID, event) } else { - e.client.Logger().Warnf("no handler for gateway event '%s' found", gatewayEventType) + e.config.Logger.Warnf("no handler for gateway event '%s' found", gatewayEventType) } } @@ -127,7 +127,7 @@ func (e *eventManagerImpl) HandleHTTPEvent(respondFunc httpserver.RespondFunc, e func (e *eventManagerImpl) DispatchEvent(event Event) { defer func() { if r := recover(); r != nil { - e.client.Logger().Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack())) + e.config.Logger.Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack())) return } }() @@ -138,7 +138,7 @@ func (e *eventManagerImpl) DispatchEvent(event Event) { go func() { defer func() { if r := recover(); r != nil { - e.client.Logger().Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack())) + e.config.Logger.Errorf("recovered from panic in event listener: %+v\nstack: %s", r, string(debug.Stack())) return } }() diff --git a/bot/event_manager_config.go b/bot/event_manager_config.go index 788ace2e..b26321bf 100644 --- a/bot/event_manager_config.go +++ b/bot/event_manager_config.go @@ -2,15 +2,19 @@ package bot import ( "github.com/disgoorg/disgo/gateway" + "github.com/disgoorg/log" ) // DefaultEventManagerConfig returns a new EventManagerConfig with all default values. func DefaultEventManagerConfig() *EventManagerConfig { - return &EventManagerConfig{} + return &EventManagerConfig{ + Logger: log.Default(), + } } // EventManagerConfig can be used to configure the EventManager. type EventManagerConfig struct { + Logger log.Logger EventListeners []EventListener AsyncEventsEnabled bool @@ -28,6 +32,13 @@ func (c *EventManagerConfig) Apply(opts []EventManagerConfigOpt) { } } +// WithEventManagerLogger overrides the default logger in the EventManagerConfig. +func WithEventManagerLogger(logger log.Logger) EventManagerConfigOpt { + return func(config *EventManagerConfig) { + config.Logger = logger + } +} + // WithListeners adds the given EventListener(s) to the EventManagerConfig. func WithListeners(listeners ...EventListener) EventManagerConfigOpt { return func(config *EventManagerConfig) { diff --git a/bot/member_chunking_manager.go b/bot/member_chunking_manager.go index 0ead92c1..0ef6e78a 100644 --- a/bot/member_chunking_manager.go +++ b/bot/member_chunking_manager.go @@ -7,18 +7,23 @@ import ( "github.com/disgoorg/disgo/discord" "github.com/disgoorg/disgo/gateway" "github.com/disgoorg/disgo/internal/insecurerandstr" + "github.com/disgoorg/log" "github.com/disgoorg/snowflake/v2" ) var _ MemberChunkingManager = (*memberChunkingManagerImpl)(nil) // NewMemberChunkingManager returns a new MemberChunkingManager with the given MemberChunkingFilter. -func NewMemberChunkingManager(client Client, memberChunkingFilter MemberChunkingFilter) MemberChunkingManager { +func NewMemberChunkingManager(client Client, logger log.Logger, memberChunkingFilter MemberChunkingFilter) MemberChunkingManager { if memberChunkingFilter == nil { memberChunkingFilter = MemberChunkingFilterNone } + if logger == nil { + logger = log.Default() + } return &memberChunkingManagerImpl{ client: client, + logger: logger, memberChunkingFilter: memberChunkingFilter, chunkingRequests: map[string]*chunkingRequest{}, } @@ -82,6 +87,7 @@ type chunkingRequest struct { type memberChunkingManagerImpl struct { client Client + logger log.Logger memberChunkingFilter MemberChunkingFilter chunkingRequestsMu sync.RWMutex @@ -97,7 +103,7 @@ func (m *memberChunkingManagerImpl) HandleChunk(payload gateway.EventGuildMember request, ok := m.chunkingRequests[payload.Nonce] m.chunkingRequestsMu.RUnlock() if !ok { - m.client.Logger().Debug("received unknown member chunk event: ", payload) + m.logger.Debug("received unknown member chunk event: ", payload) return } diff --git a/gateway/gateway.go b/gateway/gateway.go index cfdc9f79..8d79931c 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -3,8 +3,6 @@ package gateway import ( "context" "time" - - "github.com/disgoorg/log" ) // Version defines which discord API version disgo should use to connect to discord. @@ -64,9 +62,6 @@ type ( // Gateway is what is used to connect to discord. type Gateway interface { - // Logger returns the logger that is used by the Gateway. - Logger() log.Logger - // ShardID returns the shard ID that this Gateway is configured to use. ShardID() int diff --git a/gateway/gateway_impl.go b/gateway/gateway_impl.go index 11b8a750..c06fafe0 100644 --- a/gateway/gateway_impl.go +++ b/gateway/gateway_impl.go @@ -14,8 +14,6 @@ import ( "github.com/disgoorg/disgo/discord" "github.com/disgoorg/disgo/internal/tokenhelper" "github.com/disgoorg/disgo/json" - "github.com/disgoorg/log" - "github.com/gorilla/websocket" ) @@ -51,10 +49,6 @@ type gatewayImpl struct { lastHeartbeatReceived time.Time } -func (g *gatewayImpl) Logger() log.Logger { - return g.config.Logger -} - func (g *gatewayImpl) ShardID() int { return g.config.ShardID } @@ -90,7 +84,7 @@ func (g *gatewayImpl) formatLogs(a ...any) string { } func (g *gatewayImpl) Open(ctx context.Context) error { - g.Logger().Debug(g.formatLogs("opening gateway connection")) + g.config.Logger.Debug(g.formatLogs("opening gateway connection")) g.connMu.Lock() defer g.connMu.Unlock() @@ -111,12 +105,12 @@ func (g *gatewayImpl) Open(ctx context.Context) error { }() rawBody, bErr := io.ReadAll(rs.Body) if bErr != nil { - g.Logger().Error(g.formatLogs("error while reading response body: ", err)) + g.config.Logger.Error(g.formatLogs("error while reading response body: ", err)) } body = string(rawBody) } - g.Logger().Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body)) + g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body)) return err } @@ -142,7 +136,7 @@ func (g *gatewayImpl) Close(ctx context.Context) { func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message string) { if g.heartbeatTicker != nil { - g.Logger().Debug(g.formatLogs("closing heartbeat goroutines...")) + g.config.Logger.Debug(g.formatLogs("closing heartbeat goroutines...")) g.heartbeatTicker.Stop() g.heartbeatTicker = nil } @@ -151,9 +145,9 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin defer g.connMu.Unlock() if g.conn != nil { g.config.RateLimiter.Close(ctx) - g.Logger().Debug(g.formatLogsf("closing gateway connection with code: %d, message: %s", code, message)) + g.config.Logger.Debug(g.formatLogsf("closing gateway connection with code: %d, message: %s", code, message)) if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && err != websocket.ErrCloseSent { - g.Logger().Debug(g.formatLogs("error writing close code. error: ", err)) + g.config.Logger.Debug(g.formatLogs("error writing close code. error: ", err)) } _ = g.conn.Close() g.conn = nil @@ -196,7 +190,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er } defer g.config.RateLimiter.Unlock() - g.Logger().Trace(g.formatLogs("sending gateway command: ", string(data))) + g.config.Logger.Trace(g.formatLogs("sending gateway command: ", string(data))) return g.conn.WriteMessage(messageType, data) } @@ -217,12 +211,12 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int, delay time.Dura case <-timer.C: } - g.Logger().Debug(g.formatLogs("reconnecting gateway...")) + g.config.Logger.Debug(g.formatLogs("reconnecting gateway...")) if err := g.Open(ctx); err != nil { if err == discord.ErrGatewayAlreadyConnected { return err } - g.Logger().Error(g.formatLogs("failed to reconnect gateway. error: ", err)) + g.config.Logger.Error(g.formatLogs("failed to reconnect gateway. error: ", err)) g.status = StatusDisconnected return g.reconnectTry(ctx, try+1, delay) } @@ -232,14 +226,14 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int, delay time.Dura func (g *gatewayImpl) reconnect(ctx context.Context) { err := g.reconnectTry(ctx, 0, time.Second) if err != nil { - g.Logger().Error(g.formatLogs("failed to reopen gateway. error: ", err)) + g.config.Logger.Error(g.formatLogs("failed to reopen gateway. error: ", err)) } } func (g *gatewayImpl) heartbeat() { g.heartbeatTicker = time.NewTicker(g.heartbeatInterval) defer g.heartbeatTicker.Stop() - defer g.Logger().Debug(g.formatLogs("exiting heartbeat goroutine...")) + defer g.config.Logger.Debug(g.formatLogs("exiting heartbeat goroutine...")) for range g.heartbeatTicker.C { g.sendHeartbeat() @@ -247,12 +241,12 @@ func (g *gatewayImpl) heartbeat() { } func (g *gatewayImpl) sendHeartbeat() { - g.Logger().Debug(g.formatLogs("sending heartbeat...")) + g.config.Logger.Debug(g.formatLogs("sending heartbeat...")) ctx, cancel := context.WithTimeout(context.Background(), g.heartbeatInterval) defer cancel() if err := g.Send(ctx, OpcodeHeartbeat, (*MessageDataHeartbeat)(g.config.LastSequenceReceived)); err != nil && err != discord.ErrShardNotConnected { - g.Logger().Error(g.formatLogs("failed to send heartbeat. error: ", err)) + g.config.Logger.Error(g.formatLogs("failed to send heartbeat. error: ", err)) g.CloseWithCode(context.TODO(), websocket.CloseServiceRestart, "heartbeat timeout") go g.reconnect(context.TODO()) return @@ -262,7 +256,7 @@ func (g *gatewayImpl) sendHeartbeat() { func (g *gatewayImpl) identify() { g.status = StatusIdentifying - g.Logger().Debug(g.formatLogs("sending Identify command...")) + g.config.Logger.Debug(g.formatLogs("sending Identify command...")) identify := MessageDataIdentify{ Token: g.token, @@ -281,7 +275,7 @@ func (g *gatewayImpl) identify() { } if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil { - g.Logger().Error(g.formatLogs("error sending Identify command err: ", err)) + g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err)) } g.status = StatusWaitingForReady } @@ -294,14 +288,14 @@ func (g *gatewayImpl) resume() { Seq: *g.config.LastSequenceReceived, } - g.Logger().Debug(g.formatLogs("sending Resume command...")) + g.config.Logger.Debug(g.formatLogs("sending Resume command...")) if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil { - g.Logger().Error(g.formatLogs("error sending resume command err: ", err)) + g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err)) } } func (g *gatewayImpl) listen(conn *websocket.Conn) { - defer g.Logger().Debug(g.formatLogs("exiting listen goroutine...")) + defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine...")) loop: for { mt, reader, err := conn.NextReader() @@ -327,19 +321,19 @@ loop: } else { intentsURL = "https://discord.com/developers/applications" } - g.Logger().Error(g.formatLogsf("disallowed gateway intents supplied. go to %s and enable the privileged intent for your application. intents: %d", intentsURL, g.config.Intents)) + g.config.Logger.Error(g.formatLogsf("disallowed gateway intents supplied. go to %s and enable the privileged intent for your application. intents: %d", intentsURL, g.config.Intents)) } else if closeCode == CloseEventCodeInvalidSeq { - g.Logger().Error(g.formatLogs("invalid sequence provided. reconnecting...")) + g.config.Logger.Error(g.formatLogs("invalid sequence provided. reconnecting...")) g.config.LastSequenceReceived = nil g.config.SessionID = nil } else { - g.Logger().Error(g.formatLogsf("gateway close received, reconnect: %t, code: %d, error: %s", g.config.AutoReconnect && reconnect, closeError.Code, closeError.Text)) + g.config.Logger.Error(g.formatLogsf("gateway close received, reconnect: %t, code: %d, error: %s", g.config.AutoReconnect && reconnect, closeError.Code, closeError.Text)) } } else if errors.Is(err, net.ErrClosed) { // we closed the connection ourselves. Don't try to reconnect here reconnect = false } else { - g.Logger().Debug(g.formatLogs("failed to read next message from gateway. error: ", err)) + g.config.Logger.Debug(g.formatLogs("failed to read next message from gateway. error: ", err)) } if g.config.AutoReconnect && reconnect { @@ -355,7 +349,7 @@ loop: event, err := g.parseMessage(mt, reader) if err != nil { - g.Logger().Error(g.formatLogs("error while parsing gateway message. error: ", err)) + g.config.Logger.Error(g.formatLogs("error while parsing gateway message. error: ", err)) continue } @@ -373,14 +367,14 @@ loop: } case OpcodeDispatch: - g.Logger().Trace(g.formatLogsf("received: OpcodeDispatch %s, data: %s", event.T, string(event.RawD))) + g.config.Logger.Trace(g.formatLogsf("received: OpcodeDispatch %s, data: %s", event.T, string(event.RawD))) // set last sequence received g.config.LastSequenceReceived = &event.S data, ok := event.D.(EventData) if !ok && event.D != nil { - g.Logger().Error(g.formatLogsf("invalid event data of type %T received", event.D)) + g.config.Logger.Error(g.formatLogsf("invalid event data of type %T received", event.D)) continue } @@ -388,7 +382,7 @@ loop: if readyEvent, ok := data.(EventReady); ok { g.config.SessionID = &readyEvent.SessionID g.status = StatusReady - g.Logger().Debug(g.formatLogs("ready event received")) + g.config.Logger.Debug(g.formatLogs("ready event received")) } // push event to the command manager @@ -401,18 +395,18 @@ loop: g.eventHandlerFunc(event.T, event.S, g.config.ShardID, data) case OpcodeHeartbeat: - g.Logger().Debug(g.formatLogs("received: OpcodeHeartbeat")) + g.config.Logger.Debug(g.formatLogs("received: OpcodeHeartbeat")) g.sendHeartbeat() case OpcodeReconnect: - g.Logger().Debug(g.formatLogs("received: OpcodeReconnect")) + g.config.Logger.Debug(g.formatLogs("received: OpcodeReconnect")) g.CloseWithCode(context.TODO(), websocket.CloseServiceRestart, "received reconnect") go g.reconnect(context.TODO()) break loop case OpcodeInvalidSession: canResume := event.D.(MessageDataInvalidSession) - g.Logger().Debug(g.formatLogs("received: OpcodeInvalidSession, canResume: ", canResume)) + g.config.Logger.Debug(g.formatLogs("received: OpcodeInvalidSession, canResume: ", canResume)) code := websocket.CloseNormalClosure if canResume { @@ -428,7 +422,7 @@ loop: break loop case OpcodeHeartbeatACK: - g.Logger().Debug(g.formatLogs("received: OpcodeHeartbeatACK")) + g.config.Logger.Debug(g.formatLogs("received: OpcodeHeartbeatACK")) g.lastHeartbeatReceived = time.Now().UTC() } } @@ -437,7 +431,7 @@ loop: func (g *gatewayImpl) parseMessage(mt int, reader io.Reader) (Message, error) { var readCloser io.ReadCloser if mt == websocket.BinaryMessage { - g.Logger().Trace(g.formatLogs("binary message received. decompressing...")) + g.config.Logger.Trace(g.formatLogs("binary message received. decompressing...")) var err error readCloser, err = zlib.NewReader(reader) if err != nil { @@ -452,7 +446,7 @@ func (g *gatewayImpl) parseMessage(mt int, reader io.Reader) (Message, error) { var message Message if err := json.NewDecoder(readCloser).Decode(&message); err != nil { - g.Logger().Error(g.formatLogs("error decoding websocket message: ", err)) + g.config.Logger.Error(g.formatLogs("error decoding websocket message: ", err)) return Message{}, err } return message, nil diff --git a/gateway/gateway_rate_limiter.go b/gateway/gateway_rate_limiter.go index 4b1c1b0d..027ec7b0 100644 --- a/gateway/gateway_rate_limiter.go +++ b/gateway/gateway_rate_limiter.go @@ -2,15 +2,10 @@ package gateway import ( "context" - - "github.com/disgoorg/log" ) // RateLimiter provides handles the rate limiting logic for connecting to Discord's Gateway. type RateLimiter interface { - // Logger returns the logger used by the RateLimiter. - Logger() log.Logger - // Close gracefully closes the RateLimiter. // If the context deadline is exceeded, the RateLimiter will be closed immediately. Close(ctx context.Context) diff --git a/gateway/gateway_rate_limiter_impl.go b/gateway/gateway_rate_limiter_impl.go index 33041db1..6cdaaa19 100644 --- a/gateway/gateway_rate_limiter_impl.go +++ b/gateway/gateway_rate_limiter_impl.go @@ -4,7 +4,6 @@ import ( "context" "time" - "github.com/disgoorg/log" "github.com/sasha-s/go-csync" ) @@ -27,10 +26,6 @@ type rateLimiterImpl struct { config RateLimiterConfig } -func (l *rateLimiterImpl) Logger() log.Logger { - return l.config.Logger -} - func (l *rateLimiterImpl) Close(ctx context.Context) { _ = l.mu.CLock(ctx) } @@ -42,7 +37,7 @@ func (l *rateLimiterImpl) Reset() { } func (l *rateLimiterImpl) Wait(ctx context.Context) error { - l.Logger().Trace("locking gateway rate limiter") + l.config.Logger.Trace("locking gateway rate limiter") if err := l.mu.CLock(ctx); err != nil { return err } @@ -72,7 +67,7 @@ func (l *rateLimiterImpl) Wait(ctx context.Context) error { } func (l *rateLimiterImpl) Unlock() { - l.Logger().Trace("unlocking gateway rate limiter") + l.config.Logger.Trace("unlocking gateway rate limiter") now := time.Now() if l.reset.Before(now) { l.reset = now.Add(time.Minute) diff --git a/rest/rest_client.go b/rest/rest_client.go index 7ea81654..a5901461 100644 --- a/rest/rest_client.go +++ b/rest/rest_client.go @@ -13,7 +13,6 @@ import ( "github.com/disgoorg/disgo/discord" "github.com/disgoorg/disgo/rest/route" - "github.com/disgoorg/log" ) // NewClient constructs a new Client with the given Config struct @@ -28,9 +27,6 @@ func NewClient(botToken string, opts ...ConfigOpt) Client { // Client allows doing requests to different endpoints type Client interface { - // Logger returns the logger the rest client uses - Logger() log.Logger - // HTTPClient returns the http.Client the rest client uses HTTPClient() *http.Client @@ -54,10 +50,6 @@ func (c *clientImpl) Close(ctx context.Context) { c.config.HTTPClient.CloseIdleConnections() } -func (c *clientImpl) Logger() log.Logger { - return c.config.Logger -} - func (c *clientImpl) HTTPClient() *http.Client { return c.config.HTTPClient } @@ -90,7 +82,7 @@ func (c *clientImpl) retry(cRoute *route.CompiledAPIRoute, rqBody any, rsBody an return fmt.Errorf("failed to marshal request body: %w", err) } } - c.Logger().Tracef("request to %s, body: %s", rqURL, string(rawRqBody)) + c.config.Logger.Tracef("request to %s, body: %s", rqURL, string(rawRqBody)) } rq, err := http.NewRequest(cRoute.APIRoute.Method().String(), rqURL, bytes.NewReader(rawRqBody)) @@ -150,7 +142,7 @@ func (c *clientImpl) retry(cRoute *route.CompiledAPIRoute, rqBody any, rsBody an if rawRsBody, err = io.ReadAll(rs.Body); err != nil { return fmt.Errorf("error reading response body in rest client: %w", err) } - c.Logger().Tracef("response from %s, code %d, body: %s", rqURL, rs.StatusCode, string(rawRsBody)) + c.config.Logger.Tracef("response from %s, code %d, body: %s", rqURL, rs.StatusCode, string(rawRsBody)) } switch rs.StatusCode { @@ -158,7 +150,7 @@ func (c *clientImpl) retry(cRoute *route.CompiledAPIRoute, rqBody any, rsBody an if rsBody != nil && rs.Body != nil { if err = json.Unmarshal(rawRsBody, rsBody); err != nil { wErr := fmt.Errorf("error unmarshalling response body: %w", err) - c.Logger().Error(wErr) + c.config.Logger.Error(wErr) return wErr } } diff --git a/rest/rest_rate_limiter.go b/rest/rest_rate_limiter.go index 976eb318..a8e31dbb 100644 --- a/rest/rest_rate_limiter.go +++ b/rest/rest_rate_limiter.go @@ -5,14 +5,10 @@ import ( "net/http" "github.com/disgoorg/disgo/rest/route" - "github.com/disgoorg/log" ) // RateLimiter can be used to supply your own rate limit implementation type RateLimiter interface { - // Logger returns the logger the RateLimiter uses - Logger() log.Logger - // MaxRetries returns the maximum number of retries the client should do MaxRetries() int diff --git a/rest/rest_rate_limiter_impl.go b/rest/rest_rate_limiter_impl.go index aa454876..a52cf8ab 100644 --- a/rest/rest_rate_limiter_impl.go +++ b/rest/rest_rate_limiter_impl.go @@ -9,7 +9,6 @@ import ( "time" "github.com/disgoorg/disgo/rest/route" - "github.com/disgoorg/log" "github.com/sasha-s/go-csync" ) @@ -48,10 +47,6 @@ type ( } ) -func (l *rateLimiterImpl) Logger() log.Logger { - return l.config.Logger -} - func (l *rateLimiterImpl) MaxRetries() int { return l.config.MaxRetries } @@ -73,13 +68,13 @@ func (l *rateLimiterImpl) doCleanup() { continue } if b.Reset.Before(now) { - l.Logger().Debugf("cleaning up bucket, Hash: %s, ID: %s, Reset: %s", hash, b.ID, b.Reset) + l.config.Logger.Debugf("cleaning up bucket, Hash: %s, ID: %s, Reset: %s", hash, b.ID, b.Reset) delete(l.buckets, hash) } b.mu.Unlock() } if before != len(l.buckets) { - l.Logger().Debugf("cleaned up %d rate limit buckets", before-len(l.buckets)) + l.config.Logger.Debugf("cleaned up %d rate limit buckets", before-len(l.buckets)) } } @@ -122,10 +117,10 @@ func (l *rateLimiterImpl) getRouteHash(route *route.CompiledAPIRoute) hashMajor func (l *rateLimiterImpl) getBucket(route *route.CompiledAPIRoute, create bool) *bucket { hash := l.getRouteHash(route) - l.Logger().Trace("locking buckets") + l.config.Logger.Trace("locking buckets") l.bucketsMu.Lock() defer func() { - l.Logger().Trace("unlocking buckets") + l.config.Logger.Trace("unlocking buckets") l.bucketsMu.Unlock() }() b, ok := l.buckets[hash] @@ -146,7 +141,7 @@ func (l *rateLimiterImpl) getBucket(route *route.CompiledAPIRoute, create bool) func (l *rateLimiterImpl) WaitBucket(ctx context.Context, route *route.CompiledAPIRoute) error { b := l.getBucket(route, true) - l.Logger().Tracef("locking rest bucket, ID: %s, Limit: %d, Remaining: %d, Reset: %s", b.ID, b.Limit, b.Remaining, b.Reset) + l.config.Logger.Tracef("locking rest bucket, ID: %s, Limit: %d, Remaining: %d, Reset: %s", b.ID, b.Limit, b.Remaining, b.Reset) if err := b.mu.CLock(ctx); err != nil { return err } @@ -182,7 +177,7 @@ func (l *rateLimiterImpl) UnlockBucket(route *route.CompiledAPIRoute, rs *http.R return nil } defer func() { - l.Logger().Tracef("unlocking rest bucket, ID: %s, Limit: %d, Remaining: %d, Reset: %s", b.ID, b.Limit, b.Remaining, b.Reset) + l.config.Logger.Tracef("unlocking rest bucket, ID: %s, Limit: %d, Remaining: %d, Reset: %s", b.ID, b.Limit, b.Remaining, b.Reset) b.mu.Unlock() }() @@ -207,7 +202,7 @@ func (l *rateLimiterImpl) UnlockBucket(route *route.CompiledAPIRoute, rs *http.R resetAfterHeader := rs.Header.Get("X-RateLimit-Reset-After") retryAfterHeader := rs.Header.Get("Retry-After") - l.Logger().Tracef("code: %d, headers: global %t, cloudflare: %t, remaining: %s, limit: %s, reset: %s, retryAfter: %s", rs.StatusCode, global, cloudflare, remainingHeader, limitHeader, resetHeader, retryAfterHeader) + l.config.Logger.Tracef("code: %d, headers: global %t, cloudflare: %t, remaining: %s, limit: %s, reset: %s, retryAfter: %s", rs.StatusCode, global, cloudflare, remainingHeader, limitHeader, resetHeader, retryAfterHeader) if rs.StatusCode == http.StatusTooManyRequests { retryAfter, err := strconv.Atoi(retryAfterHeader) @@ -217,14 +212,14 @@ func (l *rateLimiterImpl) UnlockBucket(route *route.CompiledAPIRoute, rs *http.R reset := time.Now().Add(time.Second * time.Duration(retryAfter)) if global { l.global = reset - l.Logger().Warnf("global rate limit exceeded, retry after: %ds", retryAfter) + l.config.Logger.Warnf("global rate limit exceeded, retry after: %ds", retryAfter) } else if cloudflare { l.global = reset - l.Logger().Warnf("cloudflare rate limit exceeded, retry after: %ds", retryAfter) + l.config.Logger.Warnf("cloudflare rate limit exceeded, retry after: %ds", retryAfter) } else { b.Remaining = 0 b.Reset = reset - l.Logger().Warnf("rate limit on route %s exceeded, retry after: %ds", route.URL(), retryAfter) + l.config.Logger.Warnf("rate limit on route %s exceeded, retry after: %ds", route.URL(), retryAfter) } return nil } diff --git a/sharding/shard_manager.go b/sharding/shard_manager.go index 6e9f51d6..5f4e34d8 100644 --- a/sharding/shard_manager.go +++ b/sharding/shard_manager.go @@ -4,16 +4,12 @@ import ( "context" "github.com/disgoorg/disgo/gateway" - "github.com/disgoorg/log" "github.com/disgoorg/snowflake/v2" ) // ShardManager manages multiple gateway.Gateway connections. // For more information on sharding see: https://discord.com/developers/docs/topics/gateway#sharding type ShardManager interface { - // Logger returns the logger used by the ShardManager. - Logger() log.Logger - // Open opens all configured shards. Open(ctx context.Context) // Close closes all shards. diff --git a/sharding/shard_manager_impl.go b/sharding/shard_manager_impl.go index dd5deaac..6a52ecbb 100644 --- a/sharding/shard_manager_impl.go +++ b/sharding/shard_manager_impl.go @@ -5,7 +5,6 @@ import ( "sync" "github.com/disgoorg/disgo/gateway" - "github.com/disgoorg/log" "github.com/disgoorg/snowflake/v2" "github.com/gorilla/websocket" ) @@ -34,15 +33,11 @@ type shardManagerImpl struct { config Config } -func (m *shardManagerImpl) Logger() log.Logger { - return m.config.Logger -} - func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error) { if closeError, ok := err.(*websocket.CloseError); !m.config.AutoScaling || !ok || gateway.CloseEventCode(closeError.Code) != gateway.CloseEventCodeShardingRequired { return } - m.Logger().Debugf("shard %d requires re-sharding", shard.ShardID()) + m.config.Logger.Debugf("shard %d requires re-sharding", shard.ShardID()) // make sure shard is closed shard.Close(context.TODO()) @@ -72,7 +67,7 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error) { go func() { defer wg.Done() if err := m.config.RateLimiter.WaitBucket(context.TODO(), shardID); err != nil { - m.Logger().Errorf("failed to wait shard bucket %d: %s", shardID, err) + m.config.Logger.Errorf("failed to wait shard bucket %d: %s", shardID, err) return } defer m.config.RateLimiter.UnlockBucket(shardID) @@ -80,16 +75,16 @@ func (m *shardManagerImpl) closeHandler(shard gateway.Gateway, err error) { newShard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(newShardCount))...) m.shards[shardID] = newShard if err := newShard.Open(context.TODO()); err != nil { - m.Logger().Errorf("failed to re shard %d, error: %s", shardID, err) + m.config.Logger.Errorf("failed to re shard %d, error: %s", shardID, err) } }() } wg.Wait() - m.Logger().Debugf("re-sharded shard %d into newShards: %d, newShardCount: %d", shard.ShardID(), newShardIDs, newShardCount) + m.config.Logger.Debugf("re-sharded shard %d into newShards: %d, newShardCount: %d", shard.ShardID(), newShardIDs, newShardCount) } func (m *shardManagerImpl) Open(ctx context.Context) { - m.Logger().Debugf("opening %+v shards...", m.config.ShardIDs) + m.config.Logger.Debugf("opening %+v shards...", m.config.ShardIDs) var wg sync.WaitGroup m.shardsMu.Lock() @@ -104,7 +99,7 @@ func (m *shardManagerImpl) Open(ctx context.Context) { go func() { defer wg.Done() if err := m.config.RateLimiter.WaitBucket(ctx, shardID); err != nil { - m.Logger().Errorf("failed to wait shard bucket %d: %s", shardID, err) + m.config.Logger.Errorf("failed to wait shard bucket %d: %s", shardID, err) return } defer m.config.RateLimiter.UnlockBucket(shardID) @@ -112,7 +107,7 @@ func (m *shardManagerImpl) Open(ctx context.Context) { shard := m.config.GatewayCreateFunc(m.token, m.eventHandlerFunc, m.closeHandler, append(m.config.GatewayConfigOpts, gateway.WithShardID(shardID), gateway.WithShardCount(m.config.ShardCount))...) m.shards[shardID] = shard if err := shard.Open(ctx); err != nil { - m.Logger().Errorf("failed to open shard %d: %s", shardID, err) + m.config.Logger.Errorf("failed to open shard %d: %s", shardID, err) } }() } @@ -120,7 +115,7 @@ func (m *shardManagerImpl) Open(ctx context.Context) { } func (m *shardManagerImpl) Close(ctx context.Context) { - m.Logger().Debugf("closing %v shards...", m.config.ShardIDs) + m.config.Logger.Debugf("closing %v shards...", m.config.ShardIDs) var wg sync.WaitGroup m.shardsMu.Lock() @@ -142,7 +137,7 @@ func (m *shardManagerImpl) OpenShard(ctx context.Context, shardID int) error { } func (m *shardManagerImpl) openShard(ctx context.Context, shardID int, shardCount int) error { - m.Logger().Debugf("opening shard %d...", shardID) + m.config.Logger.Debugf("opening shard %d...", shardID) if err := m.config.RateLimiter.WaitBucket(ctx, shardID); err != nil { return err @@ -158,7 +153,7 @@ func (m *shardManagerImpl) openShard(ctx context.Context, shardID int, shardCoun } func (m *shardManagerImpl) CloseShard(ctx context.Context, shardID int) { - m.Logger().Debugf("closing shard %d...", shardID) + m.config.Logger.Debugf("closing shard %d...", shardID) m.shardsMu.Lock() defer m.shardsMu.Unlock() shard, ok := m.shards[shardID] diff --git a/sharding/shard_rate_limiter.go b/sharding/shard_rate_limiter.go index 3fe5f27b..613cc051 100644 --- a/sharding/shard_rate_limiter.go +++ b/sharding/shard_rate_limiter.go @@ -2,15 +2,10 @@ package sharding import ( "context" - - "github.com/disgoorg/log" ) // RateLimiter limits how many shards can log in to Discord at the same time. type RateLimiter interface { - // Logger returns the logger the RateLimiter uses - Logger() log.Logger - // Close gracefully closes the RateLimiter. // If the context deadline is exceeded, the RateLimiter will be closed immediately. Close(ctx context.Context) diff --git a/sharding/shard_rate_limiter_impl.go b/sharding/shard_rate_limiter_impl.go index 7d6b62a7..ed430527 100644 --- a/sharding/shard_rate_limiter_impl.go +++ b/sharding/shard_rate_limiter_impl.go @@ -5,7 +5,6 @@ import ( "sync" "time" - "github.com/disgoorg/log" "github.com/sasha-s/go-csync" ) @@ -29,10 +28,6 @@ type rateLimiterImpl struct { config RateLimiterConfig } -func (r *rateLimiterImpl) Logger() log.Logger { - return r.config.Logger -} - func (r *rateLimiterImpl) Close(ctx context.Context) { var wg sync.WaitGroup r.mu.Lock() @@ -43,7 +38,7 @@ func (r *rateLimiterImpl) Close(ctx context.Context) { go func() { defer wg.Done() if err := b.mu.CLock(ctx); err != nil { - r.Logger().Error("failed to close bucket: ", err) + r.config.Logger.Error("failed to close bucket: ", err) } b.mu.Unlock() }() @@ -51,10 +46,10 @@ func (r *rateLimiterImpl) Close(ctx context.Context) { } func (r *rateLimiterImpl) getBucket(shardID int, create bool) *bucket { - r.Logger().Debug("locking shard srate limiter") + r.config.Logger.Debug("locking shard srate limiter") r.mu.Lock() defer func() { - r.Logger().Debug("unlocking shard srate limiter") + r.config.Logger.Debug("unlocking shard srate limiter") r.mu.Unlock() }() key := ShardMaxConcurrencyKey(shardID, r.config.MaxConcurrency) @@ -74,7 +69,7 @@ func (r *rateLimiterImpl) getBucket(shardID int, create bool) *bucket { func (r *rateLimiterImpl) WaitBucket(ctx context.Context, shardID int) error { b := r.getBucket(shardID, true) - r.Logger().Debugf("locking shard bucket: Key: %d, Reset: %s", b.Key, b.Reset) + r.config.Logger.Debugf("locking shard bucket: Key: %d, Reset: %s", b.Key, b.Reset) if err := b.mu.CLock(ctx); err != nil { return err } @@ -107,7 +102,7 @@ func (r *rateLimiterImpl) UnlockBucket(shardID int) { return } defer func() { - r.Logger().Debugf("unlocking shard bucket: Key: %d, Reset: %s", b.Key, b.Reset) + r.config.Logger.Debugf("unlocking shard bucket: Key: %d, Reset: %s", b.Key, b.Reset) b.mu.Unlock() }()