diff --git a/api/types_test.go b/api/types_test.go new file mode 100644 index 0000000000..5a093be259 --- /dev/null +++ b/api/types_test.go @@ -0,0 +1,50 @@ +package api + +import ( + "encoding/json" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKeepAliveParsingFromJSON(t *testing.T) { + tests := []struct { + name string + req string + exp *Duration + }{ + { + name: "Positive Integer", + req: `{ "keep_alive": 42 }`, + exp: &Duration{42 * time.Second}, + }, + { + name: "Positive Integer String", + req: `{ "keep_alive": "42m" }`, + exp: &Duration{42 * time.Minute}, + }, + { + name: "Negative Integer", + req: `{ "keep_alive": -1 }`, + exp: &Duration{math.MaxInt64}, + }, + { + name: "Negative Integer String", + req: `{ "keep_alive": "-1m" }`, + exp: &Duration{math.MaxInt64}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var dec ChatRequest + err := json.Unmarshal([]byte(test.req), &dec) + require.NoError(t, err) + + assert.Equal(t, test.exp, dec.KeepAlive) + }) + } +} diff --git a/server/routes.go b/server/routes.go index d99c858ca7..f71bbd91b5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -8,6 +8,7 @@ import ( "io" "io/fs" "log/slog" + "math" "net" "net/http" "net/netip" @@ -16,6 +17,7 @@ import ( "path/filepath" "reflect" "runtime" + "strconv" "strings" "sync" "syscall" @@ -207,7 +209,7 @@ func GenerateHandler(c *gin.Context) { var sessionDuration time.Duration if req.KeepAlive == nil { - sessionDuration = defaultSessionDuration + sessionDuration = getDefaultSessionDuration() } else { sessionDuration = req.KeepAlive.Duration } @@ -384,6 +386,32 @@ func GenerateHandler(c *gin.Context) { streamResponse(c, ch) } +func getDefaultSessionDuration() time.Duration { + if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists { + v, err := strconv.Atoi(t) + if err != nil { + d, err := time.ParseDuration(t) + if err != nil { + return defaultSessionDuration + } + + if d < 0 { + return time.Duration(math.MaxInt64) + } + + return d + } + + d := time.Duration(v) * time.Second + if d < 0 { + return time.Duration(math.MaxInt64) + } + return d + } + + return defaultSessionDuration +} + func EmbeddingsHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -427,7 +455,7 @@ func EmbeddingsHandler(c *gin.Context) { var sessionDuration time.Duration if req.KeepAlive == nil { - sessionDuration = defaultSessionDuration + sessionDuration = getDefaultSessionDuration() } else { sessionDuration = req.KeepAlive.Duration } @@ -1228,7 +1256,7 @@ func ChatHandler(c *gin.Context) { var sessionDuration time.Duration if req.KeepAlive == nil { - sessionDuration = defaultSessionDuration + sessionDuration = getDefaultSessionDuration() } else { sessionDuration = req.KeepAlive.Duration }