Skip to content

Commit

Permalink
Add more midjourney error handlers
Browse files Browse the repository at this point in the history
Handle missing message error on upscale and variations.
Handle pending mod message error and abort when it happens.
Improve error when prompt is banned.
  • Loading branch information
igolaizola committed May 6, 2023
1 parent ea12bb4 commit 0093e86
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
19 changes: 18 additions & 1 deletion pkg/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Image struct {
type Error struct {
error
temporary bool
fatal bool
}

func NewError(err error, temporary bool) Error {
Expand All @@ -52,6 +53,13 @@ func NewError(err error, temporary bool) Error {
}
}

func NewFatal(err error) Error {
return Error{
error: err,
fatal: true,
}
}

func (e Error) Error() string {
return e.error.Error()
}
Expand All @@ -64,6 +72,10 @@ func (e Error) Temporary() bool {
return e.temporary
}

func (e Error) Fatal() bool {
return e.fatal
}

type entry struct {
prompt string
index int
Expand Down Expand Up @@ -294,8 +306,13 @@ func retry(ctx context.Context, fn func(context.Context) error) error {
if err == nil {
return nil
}
// If the error is not temporary, return it
var aiErr Error
// If the error is fatal, stop everything
if errors.As(err, &aiErr) && aiErr.Fatal() {
// TODO: handle fatal errors
panic(err)
}
// If the error is not temporary, return it
if errors.As(err, &aiErr) && !aiErr.Temporary() {
return err
}
Expand Down
18 changes: 17 additions & 1 deletion pkg/ai/midjourney/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,16 @@ func parseEmbedFooter(prompt string, msg *discord.Message) (string, error) {
return fmt.Sprintf("%s%s", prompt, suffixes), nil
}

// Errors parsed from messages
var ErrInvalidParameter = errors.New("invalid parameter")
var ErrInvalidLink = errors.New("invalid link")
var ErrBannedPrompt = errors.New("banned prompt")
var ErrJobQueued = errors.New("job queued")
var ErrQueueFull = errors.New("queue full")
var ErrPendingMod = errors.New("pending mod message")

// Other errors
var ErrMessageNotFound = ai.NewError(errors.New("message not found"), false)

func parseError(msg *discord.Message) error {
if len(msg.Embeds) == 0 {
Expand All @@ -251,7 +256,7 @@ func parseError(msg *discord.Message) error {
case "invalid link":
err := fmt.Errorf("midjourney: %w: %s", ErrInvalidLink, desc)
return ai.NewError(err, false)
case "banned prompt":
case "banned prompt", "banned prompt detected":
err := fmt.Errorf("midjourney: %w: %s", ErrBannedPrompt, desc)
return ai.NewError(err, false)
case "job queued":
Expand All @@ -260,6 +265,9 @@ func parseError(msg *discord.Message) error {
case "queue full":
err := fmt.Errorf("midjourney: %w: %s", ErrQueueFull, desc)
return ai.NewError(err, true)
case "pending mod message":
err := fmt.Errorf("midjourney: %w: %s", ErrPendingMod, desc)
return ai.NewFatal(err)
default:
err := fmt.Errorf("midjourney: %s: %s", title, desc)
return ai.NewError(err, true)
Expand Down Expand Up @@ -517,6 +525,10 @@ func (c *Client) Upscale(ctx context.Context, preview *ai.Preview, index int) (s
// response may be received before it finishes, due to rate limit
// locking.
if _, err := c.c.Do(ctx, "POST", "interactions", upscale); err != nil {
// Check if the message was deleted
if errors.Is(err, discord.ErrMessageNotFound) {
return ErrMessageNotFound
}
return fmt.Errorf("midjourney: couldn't send upscale interaction: %w", err)
}
return nil
Expand Down Expand Up @@ -553,6 +565,10 @@ func (c *Client) Variation(ctx context.Context, preview *ai.Preview, index int)
// response may be received before it finishes, due to rate limit
// locking.
if _, err := c.c.Do(ctx, "POST", "interactions", variation); err != nil {
// Check if the message was deleted
if errors.Is(err, discord.ErrMessageNotFound) {
return ErrMessageNotFound
}
return fmt.Errorf("midjourney: couldn't send variation interaction: %w", err)
}
return nil
Expand Down
38 changes: 38 additions & 0 deletions pkg/discord/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,36 @@ func (c *Client) Do(ctx context.Context, method string, path string, body interf

var errBadGateway = errors.New("discord: bad gateway")

type Error struct {
Code int `json:"code"`
Message string `json:"message"`
temporary bool
}

func (e Error) Error() string {
return fmt.Sprintf("discord: %s (%d)", e.Message, e.Code)
}

func (e Error) Temporary() bool {
return e.temporary
}

var ErrMessageNotFound = &Error{Message: "Unknown Message", Code: 10008, temporary: false}

func parseError(raw string) error {
var err Error
if err := json.Unmarshal([]byte(raw), &err); err != nil {
return nil
}
err.temporary = true
switch err.Code {
case 10008:
return ErrMessageNotFound
default:
return err
}
}

func (c *Client) do(method string, path string, body interface{}) ([]byte, error) {
// Rate limit
c.doLck.Lock()
Expand Down Expand Up @@ -268,6 +298,9 @@ func (c *Client) do(method string, path string, body interface{}) ([]byte, error
return nil, errBadGateway
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if err := parseError(string(data)); err != nil {
return nil, err
}
return nil, fmt.Errorf("discord: request %s returned status code %d (%s)", path, resp.StatusCode, string(data))
}
return data, nil
Expand Down Expand Up @@ -357,6 +390,11 @@ func retry(ctx context.Context, maxAttempts int, fn func() error) error {
if attempts >= maxAttempts {
return err
}
// If the error is not temporary, we stop
var discordErr Error
if errors.As(err, &discordErr) && !discordErr.Temporary() {
return err
}
// Bad gateway usually means discord is down, so we wait before retrying
if errors.Is(err, errBadGateway) {
idx := attempts - 1
Expand Down

0 comments on commit 0093e86

Please sign in to comment.