Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat

return fn, nil
}

// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
// Tests can override this variable to provide a stub implementation.
var ImageGenerationFunc = ImageGeneration
268 changes: 268 additions & 0 deletions core/http/endpoints/openai/inpainting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
package openai

import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"time"

"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"

"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
)

// InpaintingEndpoint handles POST /v1/images/inpainting
//
// Swagger / OpenAPI docstring (swaggo):
// @Summary Image inpainting
// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
// @Tags images
// @Accept multipart/form-data
// @Produce application/json
// @Param model formData string true "Model identifier"
// @Param prompt formData string true "Text prompt guiding the generation"
// @Param steps formData int false "Number of inference steps (default 25)"
// @Param image formData file true "Original image file"
// @Param mask formData file true "Mask image file (white = area to inpaint)"
// @Success 200 {object} schema.OpenAIResponse
// @Failure 400 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /v1/images/inpainting [post]
func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
// Parse basic form values
modelName := c.FormValue("model")
prompt := c.FormValue("prompt")
stepsStr := c.FormValue("steps")

if modelName == "" || prompt == "" {
log.Error().Msg("Inpainting Endpoint - missing model or prompt")
return echo.ErrBadRequest
}

// steps default
steps := 25
if stepsStr != "" {
if v, err := strconv.Atoi(stepsStr); err == nil {
steps = v
}
}

// Get uploaded files
imageFile, err := c.FormFile("image")
if err != nil {
log.Error().Err(err).Msg("Inpainting Endpoint - missing image file")
return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
}
maskFile, err := c.FormFile("mask")
if err != nil {
log.Error().Err(err).Msg("Inpainting Endpoint - missing mask file")
return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
}

// Read files into memory (small files expected)
imgSrc, err := imageFile.Open()
if err != nil {
return err
}
defer imgSrc.Close()
imgBytes, err := io.ReadAll(imgSrc)
if err != nil {
return err
}

maskSrc, err := maskFile.Open()
if err != nil {
return err
}
defer maskSrc.Close()
maskBytes, err := io.ReadAll(maskSrc)
if err != nil {
return err
}

// Create JSON with base64 fields expected by backend
b64Image := base64.StdEncoding.EncodeToString(imgBytes)
b64Mask := base64.StdEncoding.EncodeToString(maskBytes)

// get model config from context (middleware set it)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
log.Error().Msg("Inpainting Endpoint - model config not found in context")
return echo.ErrBadRequest
}

// Use the GeneratedContentDir so the generated PNG is placed where the
// HTTP static handler serves `/generated-images`.
tmpDir := appConfig.GeneratedContentDir
// Ensure the directory exists
if err := os.MkdirAll(tmpDir, 0750); err != nil {
log.Error().Err(err).Msgf("Inpainting Endpoint - failed to create generated content dir: %s", tmpDir)
return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage")
}
id := uuid.New().String()
jsonPath := filepath.Join(tmpDir, fmt.Sprintf("inpaint_%s.json", id))
jsonFile := map[string]string{
"image": b64Image,
"mask_image": b64Mask,
}
jf, err := os.CreateTemp(tmpDir, "inpaint_")
if err != nil {
return err
}
// setup cleanup on error; if everything succeeds we set success = true
success := false
var dst string
var origRef string
var maskRef string
defer func() {
if !success {
// Best-effort cleanup; log any failures
if jf != nil {
if cerr := jf.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file in cleanup")
}
if name := jf.Name(); name != "" {
if rerr := os.Remove(name); rerr != nil && !os.IsNotExist(rerr) {
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove temp json file %s in cleanup", name)
}
}
}
if jsonPath != "" {
if rerr := os.Remove(jsonPath); rerr != nil && !os.IsNotExist(rerr) {
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove json file %s in cleanup", jsonPath)
}
}
if dst != "" {
if rerr := os.Remove(dst); rerr != nil && !os.IsNotExist(rerr) {
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove dst file %s in cleanup", dst)
}
}
if origRef != "" {
if rerr := os.Remove(origRef); rerr != nil && !os.IsNotExist(rerr) {
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove orig ref file %s in cleanup", origRef)
}
}
if maskRef != "" {
if rerr := os.Remove(maskRef); rerr != nil && !os.IsNotExist(rerr) {
log.Warn().Err(rerr).Msgf("Inpainting Endpoint - failed to remove mask ref file %s in cleanup", maskRef)
}
}
}
}()

// write original image and mask to disk as ref images so backends that
// accept reference image files can use them (maintainer request).
origTmp, err := os.CreateTemp(tmpDir, "refimg_")
if err != nil {
return err
}
if _, err := origTmp.Write(imgBytes); err != nil {
_ = origTmp.Close()
_ = os.Remove(origTmp.Name())
return err
}
if cerr := origTmp.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close orig temp file")
}
origRef = origTmp.Name()

maskTmp, err := os.CreateTemp(tmpDir, "refmask_")
if err != nil {
// cleanup origTmp on error
_ = os.Remove(origRef)
return err
}
if _, err := maskTmp.Write(maskBytes); err != nil {
_ = maskTmp.Close()
_ = os.Remove(maskTmp.Name())
_ = os.Remove(origRef)
return err
}
if cerr := maskTmp.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close mask temp file")
}
maskRef = maskTmp.Name()
// write JSON
enc := json.NewEncoder(jf)
if err := enc.Encode(jsonFile); err != nil {
if cerr := jf.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file after encode error")
}
return err
}
if cerr := jf.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close temp json file")
}
// rename to desired name
if err := os.Rename(jf.Name(), jsonPath); err != nil {
return err
}
// prepare dst
outTmp, err := os.CreateTemp(tmpDir, "out_")
if err != nil {
return err
}
if cerr := outTmp.Close(); cerr != nil {
log.Warn().Err(cerr).Msg("Inpainting Endpoint - failed to close out temp file")
}
dst = outTmp.Name() + ".png"
if err := os.Rename(outTmp.Name(), dst); err != nil {
return err
}

// Determine width/height default
width := 512
height := 512

// Call backend image generation via indirection so tests can stub it
// Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
// Also pass ref images (orig + mask) so backends that support ref images can use them.
refImages := []string{origRef, maskRef}
fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages)
if err != nil {
return err
}

// Execute generation function (blocking)
if err := fn(); err != nil {
return err
}

// On success, build response URL using BaseURL middleware helper and
// the same `generated-images` prefix used by the server static mount.
baseURL := middleware.BaseURL(c)

// Build response using url.JoinPath for correct URL escaping
imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst))
if err != nil {
return err
}

created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Data: []schema.Item{{
URL: imgPath,
}},
}

// mark success so defer cleanup will not remove output files
success = true

return c.JSON(http.StatusOK, resp)
}
}
107 changes: 107 additions & 0 deletions core/http/endpoints/openai/inpainting_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package openai

import (
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/stretchr/testify/require"
)

func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
b := &bytes.Buffer{}
w := multipart.NewWriter(b)
for k, v := range fields {
_ = w.WriteField(k, v)
}
for fname, content := range files {
fw, err := w.CreateFormFile(fname, fname+".png")
require.NoError(t, err)
_, err = fw.Write(content)
require.NoError(t, err)
}
require.NoError(t, w.Close())
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
req.Header.Set("Content-Type", w.FormDataContentType())
return req, w.FormDataContentType()
}

func TestInpainting_MissingFiles(t *testing.T) {
e := echo.New()
// handler requires cl, ml, appConfig but this test verifies missing files early
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())

req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := h(c)
require.Error(t, err)
}

func TestInpainting_HappyPath(t *testing.T) {
// Setup temp generated content dir
tmpDir, err := os.MkdirTemp("", "gencontent")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))

// stub the backend.ImageGenerationFunc
orig := backend.ImageGenerationFunc
backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
fn := func() error {
// write a fake png file to dst
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
}
return fn, nil
}
defer func() { backend.ImageGenerationFunc = orig }()

// prepare multipart request with image and mask
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
reqBuf, _ := makeMultipartRequest(t, fields, files)

rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(reqBuf, rec)

// set a minimal model config in context as handler expects
c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"})

h := InpaintingEndpoint(nil, nil, appConf)

// call handler
err = h(c)
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)

// verify response body contains generated-images path
body := rec.Body.String()
require.Contains(t, body, "generated-images")

// confirm the file was created in tmpDir
// parse out filename from response (naive search)
// find "generated-images/" and extract until closing quote or brace
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
require.True(t, idx >= 0)
rest := rec.Body.Bytes()[idx:]
end := bytes.IndexAny(rest, "\",}\n")
if end == -1 {
end = len(rest)
}
fname := string(rest[len("generated-images/"):end])
// ensure file exists
_, err = os.Stat(filepath.Join(tmpDir, fname))
require.NoError(t, err)
}
4 changes: 2 additions & 2 deletions core/http/endpoints/openai/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, result: %s, tool arguments: %+v", config.Name, t.Name, t.Result, t.ToolArguments)
}),
)

Expand Down
Loading
Loading