diff --git a/backend/python/common-env/transformers/transformers-nvidia.yml b/backend/python/common-env/transformers/transformers-nvidia.yml index 553612344d9..e8d8155b508 100644 --- a/backend/python/common-env/transformers/transformers-nvidia.yml +++ b/backend/python/common-env/transformers/transformers-nvidia.yml @@ -89,8 +89,8 @@ dependencies: - six==1.16.0 - sympy==1.12 - tokenizers - - torch==2.2.1 - - torchvision==0.17.1 + - torch==2.1.2 + - torchvision==0.16.2 - torchaudio==2.1.2 - tqdm==4.66.1 - triton==2.1.0 diff --git a/backend/python/common-env/transformers/transformers.yml b/backend/python/common-env/transformers/transformers.yml index 4738bb3880f..3b3b8fe7ed0 100644 --- a/backend/python/common-env/transformers/transformers.yml +++ b/backend/python/common-env/transformers/transformers.yml @@ -81,8 +81,8 @@ dependencies: - six==1.16.0 - sympy==1.12 - tokenizers - - torch==2.2.1 - - torchvision==0.17.1 + - torch==2.1.2 + - torchvision==0.16.2 - torchaudio==2.1.2 - tqdm==4.66.1 - triton==2.1.0 diff --git a/core/config/application_config.go b/core/config/application_config.go index 03242c3c180..c2d4e13a7a8 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -20,6 +20,7 @@ type ApplicationConfig struct { ImageDir string AudioDir string UploadDir string + ConfigsDir string CORS bool PreloadJSONModels string PreloadModelsFromPath string @@ -252,6 +253,12 @@ func WithUploadDir(uploadDir string) AppOption { } } +func WithConfigsDir(configsDir string) AppOption { + return func(o *ApplicationConfig) { + o.ConfigsDir = configsDir + } +} + func WithApiKeys(apiKeys []string) AppOption { return func(o *ApplicationConfig) { o.ApiKeys = apiKeys diff --git a/core/http/api.go b/core/http/api.go index 039e835b7db..de0a4939fbe 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -3,6 +3,7 @@ package http import ( "encoding/json" "errors" + "github.com/go-skynet/LocalAI/pkg/utils" "os" "strings" @@ -155,8 +156,17 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi }{Version: internal.PrintableVersion()}) }) - // Load upload json - openai.LoadUploadConfig(appConfig.UploadDir) + // Make sure directories exists + os.MkdirAll(appConfig.ImageDir, 0755) + os.MkdirAll(appConfig.AudioDir, 0755) + os.MkdirAll(appConfig.UploadDir, 0755) + os.MkdirAll(appConfig.ConfigsDir, 0755) + os.MkdirAll(appConfig.ModelPath, 0755) + + // Load config jsons + utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) @@ -189,6 +199,26 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + // assistant + app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + // files app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go new file mode 100644 index 00000000000..0e0d8a99dbf --- /dev/null +++ b/core/http/endpoints/openai/assistant.go @@ -0,0 +1,515 @@ +package openai + +import ( + "fmt" + "github.com/go-skynet/LocalAI/core/config" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "net/http" + "sort" + "strconv" + "strings" + "sync/atomic" + "time" +) + +// ToolType defines a type for tool options +type ToolType string + +const ( + CodeInterpreter ToolType = "code_interpreter" + Retrieval ToolType = "retrieval" + Function ToolType = "function" + + MaxCharacterInstructions = 32768 + MaxCharacterDescription = 512 + MaxCharacterName = 256 + MaxToolsSize = 128 + MaxFileIdSize = 20 + MaxCharacterMetadataKey = 64 + MaxCharacterMetadataValue = 512 +) + +type Tool struct { + Type ToolType `json:"type"` +} + +// Assistant represents the structure of an assistant object from the OpenAI API. +type Assistant struct { + ID string `json:"id"` // The unique identifier of the assistant. + Object string `json:"object"` // Object type, which is "assistant". + Created int64 `json:"created"` // The time at which the assistant was created. + Model string `json:"model"` // The model ID used by the assistant. + Name string `json:"name,omitempty"` // The name of the assistant. + Description string `json:"description,omitempty"` // The description of the assistant. + Instructions string `json:"instructions,omitempty"` // The system instructions that the assistant uses. + Tools []Tool `json:"tools,omitempty"` // A list of tools enabled on the assistant. + FileIDs []string `json:"file_ids,omitempty"` // A list of file IDs attached to this assistant. + Metadata map[string]string `json:"metadata,omitempty"` // Set of key-value pairs attached to the assistant. +} + +var ( + Assistants = []Assistant{} // better to return empty array instead of "null" + AssistantsConfigFile = "assistants.json" +) + +type AssistantRequest struct { + Model string `json:"model"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + request := new(AssistantRequest) + if err := c.BodyParser(request); err != nil { + log.Warn().AnErr("Unable to parse AssistantRequest", err) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + if !modelExists(ml, request.Model) { + log.Warn().Msgf("Model: %s was not found in list of models.", request.Model) + return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found") + } + + if request.Tools == nil { + request.Tools = []Tool{} + } + + if request.FileIDs == nil { + request.FileIDs = []string{} + } + + if request.Metadata == nil { + request.Metadata = make(map[string]string) + } + + id := "asst_" + strconv.FormatInt(generateRandomID(), 10) + + assistant := Assistant{ + ID: id, + Object: "assistant", + Created: time.Now().Unix(), + Model: request.Model, + Name: request.Name, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + FileIDs: request.FileIDs, + Metadata: request.Metadata, + } + + Assistants = append(Assistants, assistant) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) + return c.Status(fiber.StatusOK).JSON(assistant) + } +} + +var currentId int64 = 0 + +func generateRandomID() int64 { + atomic.AddInt64(¤tId, 1) + return currentId +} + +func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + // Because we're altering the existing assistants list we should just duplicate it for now. + returnAssistants := Assistants + // Parse query parameters + limitQuery := c.Query("limit", "20") + orderQuery := c.Query("order", "desc") + afterQuery := c.Query("after") + beforeQuery := c.Query("before") + + // Convert string limit to integer + limit, err := strconv.Atoi(limitQuery) + if err != nil { + return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery)) + } + + // Sort assistants + sort.SliceStable(returnAssistants, func(i, j int) bool { + if orderQuery == "asc" { + return returnAssistants[i].Created < returnAssistants[j].Created + } + return returnAssistants[i].Created > returnAssistants[j].Created + }) + + // After and before cursors + if afterQuery != "" { + returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery) + } + if beforeQuery != "" { + returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery) + } + + // Apply limit + if limit < len(returnAssistants) { + returnAssistants = returnAssistants[:limit] + } + + return c.JSON(returnAssistants) + } +} + +// FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID +// We assume that the assistants are already sorted +func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant { + idInt, err := strconv.Atoi(id) + if err != nil { + return assistants // Return original slice if invalid id format is provided + } + + var filteredAssistants []Assistant + + for _, assistant := range assistants { + aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) + if err != nil { + continue // Skip if invalid id in assistant + } + + if aid < idInt { + filteredAssistants = append(filteredAssistants, assistant) + } + } + + return filteredAssistants +} + +// FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID +// We assume that the assistants are already sorted +func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant { + idInt, err := strconv.Atoi(id) + if err != nil { + return assistants // Return original slice if invalid id format is provided + } + + var filteredAssistants []Assistant + + for _, assistant := range assistants { + aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) + if err != nil { + continue // Skip if invalid id in assistant + } + + if aid > idInt { + filteredAssistants = append(filteredAssistants, assistant) + } + } + + return filteredAssistants +} + +func modelExists(ml *model.ModelLoader, modelName string) (found bool) { + found = false + models, err := ml.ListModels() + if err != nil { + return + } + + for _, model := range models { + if model == modelName { + found = true + return + } + } + return +} + +func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + type DeleteAssistantResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + } + + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for i, assistant := range Assistants { + if assistant.ID == assistantID { + Assistants = append(Assistants[:i], Assistants[i+1:]...) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) + return c.Status(fiber.StatusOK).JSON(DeleteAssistantResponse{ + ID: assistantID, + Object: "assistant.deleted", + Deleted: true, + }) + } + } + + log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID) + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantResponse{ + ID: assistantID, + Object: "assistant.deleted", + Deleted: false, + }) + } +} + +func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for _, assistant := range Assistants { + if assistant.ID == assistantID { + return c.Status(fiber.StatusOK).JSON(assistant) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) + } +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` +} + +var ( + AssistantFiles []AssistantFile + AssistantsFileConfigFile = "assistantsFile.json" +) + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type DeleteAssistantFileResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + request := new(AssistantFileRequest) + if err := c.BodyParser(request); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for _, assistant := range Assistants { + if assistant.ID == assistantID { + if len(assistant.FileIDs) > MaxFileIdSize { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name)) + } + + for _, file := range UploadedFiles { + if file.ID == request.FileID { + assistant.FileIDs = append(assistant.FileIDs, request.FileID) + assistantFile := AssistantFile{ + ID: file.ID, + Object: "assistant.file", + CreatedAt: time.Now().Unix(), + AssistantID: assistant.ID, + } + AssistantFiles = append(AssistantFiles, assistantFile) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) + return c.Status(fiber.StatusOK).JSON(assistantFile) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID)) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) + } +} + +func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + type ListAssistantFiles struct { + Data []File + Object string + } + + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + limitQuery := c.Query("limit", "20") + order := c.Query("order", "desc") + limit, err := strconv.Atoi(limitQuery) + if err != nil || limit < 1 || limit > 100 { + limit = 20 // Default to 20 if there's an error or the limit is out of bounds + } + + // Sort files by CreatedAt depending on the order query parameter + if order == "asc" { + sort.Slice(AssistantFiles, func(i, j int) bool { + return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt + }) + } else { // default to "desc" + sort.Slice(AssistantFiles, func(i, j int) bool { + return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt + }) + } + + // Limit the number of files returned + var limitedFiles []AssistantFile + hasMore := false + if len(AssistantFiles) > limit { + hasMore = true + limitedFiles = AssistantFiles[:limit] + } else { + limitedFiles = AssistantFiles + } + + response := map[string]interface{}{ + "object": "list", + "data": limitedFiles, + "first_id": func() string { + if len(limitedFiles) > 0 { + return limitedFiles[0].ID + } + return "" + }(), + "last_id": func() string { + if len(limitedFiles) > 0 { + return limitedFiles[len(limitedFiles)-1].ID + } + return "" + }(), + "has_more": hasMore, + } + + return c.Status(fiber.StatusOK).JSON(response) + } +} + +func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + request := new(AssistantRequest) + if err := c.BodyParser(request); err != nil { + log.Warn().AnErr("Unable to parse AssistantRequest", err) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for i, assistant := range Assistants { + if assistant.ID == assistantID { + newAssistant := Assistant{ + ID: assistantID, + Object: assistant.Object, + Created: assistant.Created, + Model: request.Model, + Name: request.Name, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + FileIDs: request.FileIDs, // todo: should probably verify fileids exist + Metadata: request.Metadata, + } + + // Remove old one and replace with new one + Assistants = append(Assistants[:i], Assistants[i+1:]...) + Assistants = append(Assistants, newAssistant) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants) + return c.Status(fiber.StatusOK).JSON(newAssistant) + } + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) + } +} + +func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + fileId := c.Params("file_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") + } + // First remove file from assistant + for i, assistant := range Assistants { + if assistant.ID == assistantID { + for j, fileId := range assistant.FileIDs { + if fileId == fileId { + Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...) + + // Check if the file exists in the assistantFiles slice + for i, assistantFile := range AssistantFiles { + if assistantFile.ID == fileId { + // Remove the file from the assistantFiles slice + AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) + return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: true, + }) + } + } + } + } + + log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID) + for i, assistantFile := range AssistantFiles { + if assistantFile.AssistantID == assistantID { + + AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) + utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) + + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: true, + }) + } + } + } + } + log.Warn().Msgf("Unable to find assistant: %s", assistantID) + + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: false, + }) + } +} + +func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + fileId := c.Params("file_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") + } + + for _, assistantFile := range AssistantFiles { + if assistantFile.AssistantID == assistantID { + if assistantFile.ID == fileId { + return c.Status(fiber.StatusOK).JSON(assistantFile) + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)) + } + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)) + } +} diff --git a/core/http/endpoints/openai/assistant_test.go b/core/http/endpoints/openai/assistant_test.go new file mode 100644 index 00000000000..bdc41ddaf98 --- /dev/null +++ b/core/http/endpoints/openai/assistant_test.go @@ -0,0 +1,456 @@ +package openai + +import ( + "encoding/json" + "fmt" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +var configsDir string = "/tmp/localai/configs" + +type MockLoader struct { + models []string +} + +func tearDown() func() { + return func() { + UploadedFiles = []File{} + Assistants = []Assistant{} + AssistantFiles = []AssistantFile{} + _ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile)) + _ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile)) + } +} + +func TestAssistantEndpoints(t *testing.T) { + // Preparing the mocked objects + cl := &config.BackendConfigLoader{} + //configsDir := "/tmp/localai/configs" + modelPath := "/tmp/localai/model" + var ml = model.NewModelLoader(modelPath) + + appConfig := &config.ApplicationConfig{ + ConfigsDir: configsDir, + UploadLimitMB: 10, + UploadDir: "test_dir", + ModelPath: modelPath, + } + + _ = os.RemoveAll(appConfig.ConfigsDir) + _ = os.MkdirAll(appConfig.ConfigsDir, 0755) + _ = os.MkdirAll(modelPath, 0755) + os.Create(filepath.Join(modelPath, "ggml-gpt4all-j")) + + app := fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. + }) + + // Create a Test Server + app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig)) + app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig)) + + app.Post("/files", UploadFilesEndpoint(cl, appConfig)) + app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig)) + + t.Run("CreateAssistantEndpoint", func(t *testing.T) { + t.Cleanup(tearDown()) + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, resp, err := createAssistant(app, *ar) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, resp.StatusCode) + + assert.Equal(t, 1, len(Assistants)) + //t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID})) + + assert.Equal(t, ar.Name, resultAssistant.Name) + assert.Equal(t, ar.Model, resultAssistant.Model) + assert.Equal(t, ar.Tools, resultAssistant.Tools) + assert.Equal(t, ar.Description, resultAssistant.Description) + assert.Equal(t, ar.Instructions, resultAssistant.Instructions) + assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs) + assert.Equal(t, ar.Metadata, resultAssistant.Metadata) + }) + + t.Run("ListAssistantsEndpoint", func(t *testing.T) { + var ids []string + var resultAssistant []Assistant + for i := 0; i < 4; i++ { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: fmt.Sprintf("3.5-turbo-%d", i), + Description: fmt.Sprintf("Test Assistant - %d", i), + Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i), + Tools: []Tool{{Type: Function}}, + FileIDs: []string{"fid-1234"}, + Metadata: map[string]string{"meta": "data"}, + } + + //var err error + ra, _, err := createAssistant(app, *ar) + // Because we create the assistants so fast all end up with the same created time. + time.Sleep(time.Second) + resultAssistant = append(resultAssistant, ra) + assert.NoError(t, err) + ids = append(ids, resultAssistant[i].ID) + } + + t.Cleanup(cleanupAllAssistants(t, app, ids)) + + tests := []struct { + name string + reqURL string + expectedStatus int + expectedResult []Assistant + expectedStringResult string + }{ + { + name: "Valid Usage - limit only", + reqURL: "/assistants?limit=2", + expectedStatus: http.StatusOK, + expectedResult: Assistants[:2], // Expecting the first two assistants + }, + { + name: "Valid Usage - order asc", + reqURL: "/assistants?order=asc", + expectedStatus: http.StatusOK, + expectedResult: Assistants, // Expecting all assistants in ascending order + }, + { + name: "Valid Usage - order desc", + reqURL: "/assistants?order=desc", + expectedStatus: http.StatusOK, + expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order + }, + { + name: "Valid Usage - after specific ID", + reqURL: "/assistants?after=2", + expectedStatus: http.StatusOK, + // Note this is correct because it's put in descending order already + expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2 + }, + { + name: "Valid Usage - before specific ID", + reqURL: "/assistants?before=4", + expectedStatus: http.StatusOK, + expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3. + }, + { + name: "Invalid Usage - non-integer limit", + reqURL: "/assistants?limit=two", + expectedStatus: http.StatusBadRequest, + expectedStringResult: "Invalid limit query value: two", + }, + { + name: "Invalid Usage - non-existing id in after", + reqURL: "/assistants?after=100", + expectedStatus: http.StatusOK, + expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil) + response, err := app.Test(request) + assert.NoError(t, err) + assert.Equal(t, tt.expectedStatus, response.StatusCode) + if tt.expectedStatus != fiber.StatusOK { + all, _ := ioutil.ReadAll(response.Body) + assert.Equal(t, tt.expectedStringResult, string(all)) + } else { + var result []Assistant + err = json.NewDecoder(response.Body).Decode(&result) + assert.NoError(t, err) + + assert.Equal(t, tt.expectedResult, result) + } + }) + } + }) + + t.Run("DeleteAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) + _, err = app.Test(deleteReq) + assert.NoError(t, err) + assert.Equal(t, 0, len(Assistants)) + }) + + t.Run("GetAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID})) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + request := httptest.NewRequest(http.MethodGet, target, nil) + response, err := app.Test(request) + assert.NoError(t, err) + + var getAssistant Assistant + err = json.NewDecoder(response.Body).Decode(&getAssistant) + assert.NoError(t, err) + + assert.Equal(t, resultAssistant.ID, getAssistant.ID) + }) + + t.Run("ModifyAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + + modifiedAr := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "4.0-turbo", + Description: "Modified Test Assistant", + Instructions: "You are math teacher answering student questions", + Tools: []Tool{{Type: CodeInterpreter}}, + FileIDs: nil, + Metadata: nil, + } + + modifiedArJson, err := json.Marshal(modifiedAr) + assert.NoError(t, err) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson))) + request.Header.Set(fiber.HeaderContentType, "application/json") + + modifyResponse, err := app.Test(request) + assert.NoError(t, err) + var getAssistant Assistant + err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant) + + t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID})) + + assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change + assert.Equal(t, modifiedAr.Tools, getAssistant.Tools) + assert.Equal(t, modifiedAr.Name, getAssistant.Name) + assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions) + assert.Equal(t, modifiedAr.Description, getAssistant.Description) + }) + + t.Run("CreateAssistantFileEndpoint", func(t *testing.T) { + t.Cleanup(tearDown()) + file, assistant, err := createFileAndAssistant(t, app, appConfig) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + + assert.NoError(t, err) + assert.Equal(t, assistant.ID, af.AssistantID) + }) + t.Run("ListAssistantFilesEndpoint", func(t *testing.T) { + t.Cleanup(tearDown()) + file, assistant, err := createFileAndAssistant(t, app, appConfig) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + assert.NoError(t, err) + + assert.Equal(t, assistant.ID, af.AssistantID) + }) + t.Run("GetAssistantFileEndpoint", func(t *testing.T) { + t.Cleanup(tearDown()) + file, assistant, err := createFileAndAssistant(t, app, appConfig) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + assert.NoError(t, err) + t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID)) + + target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID) + request := httptest.NewRequest(http.MethodGet, target, nil) + response, err := app.Test(request) + assert.NoError(t, err) + + var assistantFile AssistantFile + err = json.NewDecoder(response.Body).Decode(&assistantFile) + assert.NoError(t, err) + + assert.Equal(t, af.ID, assistantFile.ID) + assert.Equal(t, af.AssistantID, assistantFile.AssistantID) + }) + t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) { + t.Cleanup(tearDown()) + file, assistant, err := createFileAndAssistant(t, app, appConfig) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + assert.NoError(t, err) + + cleanupAssistantFile(t, app, af.ID, af.AssistantID)() + + assert.Empty(t, AssistantFiles) + }) + +} + +func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (File, Assistant, error) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + assistant, _, err := createAssistant(app, *ar) + if err != nil { + return File{}, Assistant{}, err + } + t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID})) + + file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o) + t.Cleanup(func() { + _, err := CallFilesDeleteEndpoint(t, app, file.ID) + assert.NoError(t, err) + }) + return file, assistant, nil +} + +func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) { + afrJson, err := json.Marshal(afr) + if err != nil { + return AssistantFile{}, nil, err + } + + target := fmt.Sprintf("/assistants/%s/files", assistantId) + request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson))) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + if err != nil { + return AssistantFile{}, resp, err + } + + var assistantFile AssistantFile + all, err := ioutil.ReadAll(resp.Body) + err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile) + if err != nil { + return AssistantFile{}, resp, err + } + + return assistantFile, resp, nil +} + +func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) { + assistant, err := json.Marshal(ar) + if err != nil { + return Assistant{}, nil, err + } + + request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant))) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + if err != nil { + return Assistant{}, resp, err + } + + bodyString, err := io.ReadAll(resp.Body) + if err != nil { + return Assistant{}, resp, err + } + + var resultAssistant Assistant + err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant) + + return resultAssistant, resp, nil +} + +func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() { + return func() { + for _, assistant := range ids { + target := fmt.Sprintf("/assistants/%s", assistant) + deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) + _, err := app.Test(deleteReq) + if err != nil { + t.Fatalf("Failed to delete assistant %s: %v", assistant, err) + } + } + } +} + +func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() { + return func() { + target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId) + request := httptest.NewRequest(http.MethodDelete, target, nil) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + assert.NoError(t, err) + + var dafr DeleteAssistantFileResponse + err = json.NewDecoder(resp.Body).Decode(&dafr) + assert.NoError(t, err) + assert.True(t, dafr.Deleted) + } +} diff --git a/core/http/endpoints/openai/files.go b/core/http/endpoints/openai/files.go index 5cb8d7a92ea..add9aaa0101 100644 --- a/core/http/endpoints/openai/files.go +++ b/core/http/endpoints/openai/files.go @@ -1,23 +1,22 @@ package openai import ( - "encoding/json" "errors" "fmt" "os" "path/filepath" + "sync/atomic" "time" "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" ) -var uploadedFiles []File +var UploadedFiles []File -const uploadedFilesFile = "uploadedFiles.json" +const UploadedFilesFile = "uploadedFiles.json" // File represents the structure of a file object from the OpenAI API. type File struct { @@ -29,38 +28,6 @@ type File struct { Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.) } -func saveUploadConfig(uploadDir string) { - file, err := json.MarshalIndent(uploadedFiles, "", " ") - if err != nil { - log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) - } - - err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644) - if err != nil { - log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) - } -} - -func LoadUploadConfig(uploadPath string) { - uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile) - - _, err := os.Stat(uploadFilePath) - if os.IsNotExist(err) { - log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath) - return - } - - file, err := os.ReadFile(uploadFilePath) - if err != nil { - log.Error().Msgf("Failed to read file: %s", err) - } else { - err = json.Unmarshal(file, &uploadedFiles) - if err != nil { - log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err) - } - } -} - // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -95,7 +62,7 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli } f := File{ - ID: fmt.Sprintf("file-%d", time.Now().Unix()), + ID: fmt.Sprintf("file-%d", getNextFileId()), Object: "file", Bytes: int(file.Size), CreatedAt: time.Now(), @@ -103,12 +70,19 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli Purpose: purpose, } - uploadedFiles = append(uploadedFiles, f) - saveUploadConfig(appConfig.UploadDir) + UploadedFiles = append(UploadedFiles, f) + utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles) return c.Status(fiber.StatusOK).JSON(f) } } +var currentFileId int64 = 0 + +func getNextFileId() int64 { + atomic.AddInt64(¤tId, 1) + return currentId +} + // ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { type ListFiles struct { @@ -121,9 +95,9 @@ func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Applica purpose := c.Query("purpose") if purpose == "" { - listFiles.Data = uploadedFiles + listFiles.Data = UploadedFiles } else { - for _, f := range uploadedFiles { + for _, f := range UploadedFiles { if purpose == f.Purpose { listFiles.Data = append(listFiles.Data, f) } @@ -140,7 +114,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) { return nil, fmt.Errorf("file_id parameter is required") } - for _, f := range uploadedFiles { + for _, f := range UploadedFiles { if id == f.ID { return &f, nil } @@ -184,14 +158,14 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli } // Remove upload from list - for i, f := range uploadedFiles { + for i, f := range UploadedFiles { if f.ID == file.ID { - uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...) + UploadedFiles = append(UploadedFiles[:i], UploadedFiles[i+1:]...) break } } - saveUploadConfig(appConfig.UploadDir) + utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles) return c.JSON(DeleteStatus{ Id: file.ID, Object: "file", diff --git a/core/http/endpoints/openai/files_test.go b/core/http/endpoints/openai/files_test.go index a036bd0dc2a..e1c1011e193 100644 --- a/core/http/endpoints/openai/files_test.go +++ b/core/http/endpoints/openai/files_test.go @@ -3,6 +3,7 @@ package openai import ( "encoding/json" "fmt" + "github.com/rs/zerolog/log" "io" "mime/multipart" "net/http" @@ -73,6 +74,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) { app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option)) t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) { + t.Cleanup(tearDown()) resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option) assert.NoError(t, err) @@ -80,46 +82,54 @@ func TestUploadFileExceedSizeLimit(t *testing.T) { assert.Contains(t, bodyToString(resp, t), "exceeds upload limit") }) t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) { + t.Cleanup(tearDown()) resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option) assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) assert.Contains(t, bodyToString(resp, t), "Purpose is not defined") }) t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) { + t.Cleanup(tearDown()) f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option) resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option) fmt.Println(f1) - fmt.Printf("ERror: %v", err) + fmt.Printf("ERror: %v\n", err) + fmt.Printf("resp: %+v\n", resp) assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) assert.Contains(t, bodyToString(resp, t), "File already exists") }) t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) { + t.Cleanup(tearDown()) file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) // Check if file exists in the disk - filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt")) + testName := strings.Split(t.Name(), "/")[1] + fileName := testName + "-test.txt" + filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName(fileName)) _, err := os.Stat(filePath) assert.False(t, os.IsNotExist(err)) assert.Equal(t, file.Bytes, 5242880) assert.NotEmpty(t, file.CreatedAt) - assert.Equal(t, file.Filename, "test.txt") + assert.Equal(t, file.Filename, fileName) assert.Equal(t, file.Purpose, "fine-tune") }) t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) { + t.Cleanup(tearDown()) resp, err := CallListFilesEndpoint(t, app, "") assert.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) listFiles := responseToListFile(t, resp) - if len(listFiles.Data) != len(uploadedFiles) { - t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data)) + if len(listFiles.Data) != len(UploadedFiles) { + t.Errorf("Expected %v files, got %v files", len(UploadedFiles), len(listFiles.Data)) } }) t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) { + t.Cleanup(tearDown()) _ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) resp, err := CallListFilesEndpoint(t, app, "fine-tune") @@ -131,6 +141,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) { } }) t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) { + t.Cleanup(tearDown()) resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune") assert.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -142,6 +153,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) { } }) t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) { + t.Cleanup(tearDown()) req := httptest.NewRequest("GET", "/files", nil) resp, _ := app.Test(req) assert.Equal(t, 200, resp.StatusCode) @@ -175,8 +187,10 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt } func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) { + testName := strings.Split(t.Name(), "/")[1] + // Create a file that exceeds the limit - file := createTestFile(t, fileName, fileSize, appConfig) + file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig) // Creating a new HTTP Request body, writer := newMultipartFile(file.Name(), tag, purpose) @@ -188,7 +202,8 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File { // Create a file that exceeds the limit - file := createTestFile(t, fileName, fileSize, appConfig) + testName := strings.Split(t.Name(), "/")[1] + file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig) // Creating a new HTTP Request body, writer := newMultipartFile(file.Name(), tag, purpose) @@ -199,11 +214,12 @@ func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, assert.NoError(t, err) f := responseToFile(t, resp) - id := f.ID - t.Cleanup(func() { - _, err := CallFilesDeleteEndpoint(t, app, id) - assert.NoError(t, err) - }) + //id := f.ID + //t.Cleanup(func() { + // _, err := CallFilesDeleteEndpoint(t, app, id) + // assert.NoError(t, err) + // assert.Empty(t, UploadedFiles) + //}) return f @@ -240,7 +256,8 @@ func createTestFile(t *testing.T, name string, sizeMB int, option *config.Applic t.Fatalf("Error MKDIR: %v", err) } - file, _ := os.Create(name) + file, err := os.Create(name) + assert.NoError(t, err) file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File t.Cleanup(func() { @@ -280,7 +297,7 @@ func responseToListFile(t *testing.T, resp *http.Response) ListFiles { err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles) if err != nil { - fmt.Printf("Failed to decode response: %s", err) + log.Error().Msgf("Failed to decode response: %s", err) } return listFiles diff --git a/docs/data/version.json b/docs/data/version.json index 20ca21c5306..b6372479d17 100644 --- a/docs/data/version.json +++ b/docs/data/version.json @@ -1,3 +1,3 @@ { - "version": "v2.10.1" + "version": "v2.11.0" } diff --git a/main.go b/main.go index 400dcb57a80..651dd1c28fe 100644 --- a/main.go +++ b/main.go @@ -149,6 +149,12 @@ func main() { EnvVars: []string{"UPLOAD_PATH"}, Value: "/tmp/localai/upload", }, + &cli.StringFlag{ + Name: "config-path", + Usage: "Path to store uploads from files api", + EnvVars: []string{"CONFIG_PATH"}, + Value: "/tmp/localai/config", + }, &cli.StringFlag{ Name: "backend-assets-path", Usage: "Path used to extract libraries that are required by some of the backends in runtime.", @@ -241,6 +247,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit config.WithImageDir(ctx.String("image-path")), config.WithAudioDir(ctx.String("audio-path")), config.WithUploadDir(ctx.String("upload-path")), + config.WithConfigsDir(ctx.String("config-path")), config.WithF16(ctx.Bool("f16")), config.WithStringGalleries(ctx.String("galleries")), config.WithModelLibraryURL(ctx.String("remote-library")), diff --git a/pkg/utils/config.go b/pkg/utils/config.go new file mode 100644 index 00000000000..a9167ed364d --- /dev/null +++ b/pkg/utils/config.go @@ -0,0 +1,41 @@ +package utils + +import ( + "encoding/json" + "github.com/rs/zerolog/log" + "os" + "path/filepath" +) + +func SaveConfig(filePath, fileName string, obj any) { + file, err := json.MarshalIndent(obj, "", " ") + if err != nil { + log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) + } + + absolutePath := filepath.Join(filePath, fileName) + err = os.WriteFile(absolutePath, file, 0644) + if err != nil { + log.Error().Msgf("Failed to save configuration file to %s: %s", absolutePath, err) + } +} + +func LoadConfig(filePath, fileName string, obj interface{}) { + uploadFilePath := filepath.Join(filePath, fileName) + + _, err := os.Stat(uploadFilePath) + if os.IsNotExist(err) { + log.Debug().Msgf("No configuration file found at %s", uploadFilePath) + return + } + + file, err := os.ReadFile(uploadFilePath) + if err != nil { + log.Error().Msgf("Failed to read file: %s", err) + } else { + err = json.Unmarshal(file, &obj) + if err != nil { + log.Error().Msgf("Failed to JSON unmarshal the file %s: %v", uploadFilePath, err) + } + } +}