diff --git a/pkg/ai/ai.go b/pkg/ai/ai.go index 53d7736..98133bf 100644 --- a/pkg/ai/ai.go +++ b/pkg/ai/ai.go @@ -43,6 +43,7 @@ type Image struct { type Error struct { error temporary bool + fatal bool } func NewError(err error, temporary bool) Error { @@ -52,6 +53,13 @@ func NewError(err error, temporary bool) Error { } } +func NewFatal(err error) Error { + return Error{ + error: err, + fatal: true, + } +} + func (e Error) Error() string { return e.error.Error() } @@ -64,6 +72,10 @@ func (e Error) Temporary() bool { return e.temporary } +func (e Error) Fatal() bool { + return e.fatal +} + type entry struct { prompt string index int @@ -294,8 +306,13 @@ func retry(ctx context.Context, fn func(context.Context) error) error { if err == nil { return nil } - // If the error is not temporary, return it var aiErr Error + // If the error is fatal, stop everything + if errors.As(err, &aiErr) && aiErr.Fatal() { + // TODO: handle fatal errors + panic(err) + } + // If the error is not temporary, return it if errors.As(err, &aiErr) && !aiErr.Temporary() { return err } diff --git a/pkg/ai/midjourney/midjourney.go b/pkg/ai/midjourney/midjourney.go index 35299ee..7362fcb 100644 --- a/pkg/ai/midjourney/midjourney.go +++ b/pkg/ai/midjourney/midjourney.go @@ -230,11 +230,16 @@ func parseEmbedFooter(prompt string, msg *discord.Message) (string, error) { return fmt.Sprintf("%s%s", prompt, suffixes), nil } +// Errors parsed from messages var ErrInvalidParameter = errors.New("invalid parameter") var ErrInvalidLink = errors.New("invalid link") var ErrBannedPrompt = errors.New("banned prompt") var ErrJobQueued = errors.New("job queued") var ErrQueueFull = errors.New("queue full") +var ErrPendingMod = errors.New("pending mod message") + +// Other errors +var ErrMessageNotFound = ai.NewError(errors.New("message not found"), false) func parseError(msg *discord.Message) error { if len(msg.Embeds) == 0 { @@ -251,7 +256,7 @@ func parseError(msg *discord.Message) error { case "invalid link": err := fmt.Errorf("midjourney: %w: %s", ErrInvalidLink, desc) return ai.NewError(err, false) - case "banned prompt": + case "banned prompt", "banned prompt detected": err := fmt.Errorf("midjourney: %w: %s", ErrBannedPrompt, desc) return ai.NewError(err, false) case "job queued": @@ -260,6 +265,9 @@ func parseError(msg *discord.Message) error { case "queue full": err := fmt.Errorf("midjourney: %w: %s", ErrQueueFull, desc) return ai.NewError(err, true) + case "pending mod message": + err := fmt.Errorf("midjourney: %w: %s", ErrPendingMod, desc) + return ai.NewFatal(err) default: err := fmt.Errorf("midjourney: %s: %s", title, desc) return ai.NewError(err, true) @@ -517,6 +525,10 @@ func (c *Client) Upscale(ctx context.Context, preview *ai.Preview, index int) (s // response may be received before it finishes, due to rate limit // locking. if _, err := c.c.Do(ctx, "POST", "interactions", upscale); err != nil { + // Check if the message was deleted + if errors.Is(err, discord.ErrMessageNotFound) { + return ErrMessageNotFound + } return fmt.Errorf("midjourney: couldn't send upscale interaction: %w", err) } return nil @@ -553,6 +565,10 @@ func (c *Client) Variation(ctx context.Context, preview *ai.Preview, index int) // response may be received before it finishes, due to rate limit // locking. if _, err := c.c.Do(ctx, "POST", "interactions", variation); err != nil { + // Check if the message was deleted + if errors.Is(err, discord.ErrMessageNotFound) { + return ErrMessageNotFound + } return fmt.Errorf("midjourney: couldn't send variation interaction: %w", err) } return nil diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go index e76d2f5..1e5e9bd 100644 --- a/pkg/discord/discord.go +++ b/pkg/discord/discord.go @@ -194,6 +194,36 @@ func (c *Client) Do(ctx context.Context, method string, path string, body interf var errBadGateway = errors.New("discord: bad gateway") +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + temporary bool +} + +func (e Error) Error() string { + return fmt.Sprintf("discord: %s (%d)", e.Message, e.Code) +} + +func (e Error) Temporary() bool { + return e.temporary +} + +var ErrMessageNotFound = &Error{Message: "Unknown Message", Code: 10008, temporary: false} + +func parseError(raw string) error { + var err Error + if err := json.Unmarshal([]byte(raw), &err); err != nil { + return nil + } + err.temporary = true + switch err.Code { + case 10008: + return ErrMessageNotFound + default: + return err + } +} + func (c *Client) do(method string, path string, body interface{}) ([]byte, error) { // Rate limit c.doLck.Lock() @@ -268,6 +298,9 @@ func (c *Client) do(method string, path string, body interface{}) ([]byte, error return nil, errBadGateway } if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if err := parseError(string(data)); err != nil { + return nil, err + } return nil, fmt.Errorf("discord: request %s returned status code %d (%s)", path, resp.StatusCode, string(data)) } return data, nil @@ -357,6 +390,11 @@ func retry(ctx context.Context, maxAttempts int, fn func() error) error { if attempts >= maxAttempts { return err } + // If the error is not temporary, we stop + var discordErr Error + if errors.As(err, &discordErr) && !discordErr.Temporary() { + return err + } // Bad gateway usually means discord is down, so we wait before retrying if errors.Is(err, errBadGateway) { idx := attempts - 1