diff --git a/api/client.go b/api/client.go index e0b9b0aa08..ccbcbf6bca 100644 --- a/api/client.go +++ b/api/client.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/url" ) @@ -26,47 +25,18 @@ func NewClient(hosts ...string) *Client { } } -func StatusError(status int, message ...string) error { - if status < 400 { - return nil - } - - if len(message) > 0 && len(message[0]) > 0 { - return fmt.Errorf("%d %s: %s", status, http.StatusText(status), message[0]) - } - - return fmt.Errorf("%d %s", status, http.StatusText(status)) -} - -type options struct { - requestBody io.Reader - responseFunc func(bts []byte) error -} - -func OptionRequestBody(data any) func(*options) { - bts, err := json.Marshal(data) - if err != nil { - panic(err) - } - - return func(opts *options) { - opts.requestBody = bytes.NewReader(bts) - } -} - -func OptionResponseFunc(fn func([]byte) error) func(*options) { - return func(opts *options) { - opts.responseFunc = fn - } -} +func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { + var buf *bytes.Buffer + if data != nil { + bts, err := json.Marshal(data) + if err != nil { + return err + } -func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*options)) error { - var opts options - for _, fn := range fns { - fn(&opts) + buf = bytes.NewBuffer(bts) } - request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), opts.requestBody) + request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), buf) if err != nil { return err } @@ -80,25 +50,23 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o } defer response.Body.Close() - if opts.responseFunc != nil { - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - var errorResponse struct { - Error string `json:"error"` - } - - bts := scanner.Bytes() - if err := json.Unmarshal(bts, &errorResponse); err != nil { - return err - } - - if err := StatusError(response.StatusCode, errorResponse.Error); err != nil { - return err - } - - if err := opts.responseFunc(bts); err != nil { - return err - } + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + var errorResponse struct { + Error string `json:"error"` + } + + bts := scanner.Bytes() + if err := json.Unmarshal(bts, &errorResponse); err != nil { + return fmt.Errorf("unmarshal: %w", err) + } + + if len(errorResponse.Error) > 0 { + return fmt.Errorf("stream: %s", errorResponse.Error) + } + + if err := callback(bts); err != nil { + return err } } @@ -108,36 +76,25 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o type GenerateResponseFunc func(GenerateResponse) error func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error { - return c.stream(ctx, http.MethodPost, "/api/generate", - OptionRequestBody(req), - OptionResponseFunc(func(bts []byte) error { - var resp GenerateResponse - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } - - return fn(resp) - }), - ) + return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error { + var resp GenerateResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) } type PullProgressFunc func(PullProgress) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { - return c.stream(ctx, http.MethodPost, "/api/pull", - OptionRequestBody(req), - OptionResponseFunc(func(bts []byte) error { - var resp PullProgress - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } - - if resp.Error.Message != "" { - // couldn't pull the model from the directory, proceed anyway - return nil - } - - return fn(resp) - }), - ) + return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { + var resp PullProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) } diff --git a/api/types.go b/api/types.go index 5dc7488e41..3443ae9be3 100644 --- a/api/types.go +++ b/api/types.go @@ -1,24 +1,5 @@ package api -import ( - "fmt" - "net/http" - "strings" -) - -type Error struct { - Code int32 `json:"code"` - Message string `json:"message"` -} - -func (e Error) Error() string { - if e.Message == "" { - return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code)))) - } - - return e.Message -} - type PullRequest struct { Model string `json:"model"` } @@ -27,7 +8,6 @@ type PullProgress struct { Total int64 `json:"total"` Completed int64 `json:"completed"` Percent float64 `json:"percent"` - Error Error `json:"error"` } type GenerateRequest struct { diff --git a/server/routes.go b/server/routes.go index 7d0fdf72ef..dcd48b07d7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -54,7 +54,7 @@ func generate(c *gin.Context) { } if _, err := os.Stat(req.Model); err != nil { if !errors.Is(err, os.ErrNotExist) { - c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } req.Model = path.Join(cacheDir(), "models", req.Model+".bin") @@ -136,7 +136,7 @@ func Serve(ln net.Listener) error { r.POST("api/pull", func(c *gin.Context) { var req api.PullRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -146,16 +146,10 @@ func Serve(ln net.Listener) error { if err := pull(req.Model, progressCh); err != nil { var opError *net.OpError if errors.As(err, &opError) { - result := api.PullProgress{ - Error: api.Error{ - Code: http.StatusBadGateway, - Message: "failed to get models from directory", - }, - } - c.JSON(http.StatusBadGateway, result) + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } }()