Skip to content

Commit

Permalink
Default Keep Alive environment variable (ollama#3094)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Chris-AS1 <8493773+Chris-AS1@users.noreply.github.com>
  • Loading branch information
2 people authored and byebyebruce committed Mar 26, 2024
1 parent 655a1aa commit 8179c23
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 3 deletions.
50 changes: 50 additions & 0 deletions api/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package api

import (
"encoding/json"
"math"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestKeepAliveParsingFromJSON checks how the "keep_alive" field of a
// ChatRequest is decoded from JSON: bare integers are read as seconds,
// strings are read as Go duration strings, and any negative input is
// expected to decode to the maximum representable duration.
func TestKeepAliveParsingFromJSON(t *testing.T) {
	type testCase struct {
		name string
		req  string
		exp  *Duration
	}

	// want builds the expected pointer value for a given duration.
	want := func(d time.Duration) *Duration {
		return &Duration{Duration: d}
	}

	cases := []testCase{
		{name: "Positive Integer", req: `{ "keep_alive": 42 }`, exp: want(42 * time.Second)},
		{name: "Positive Integer String", req: `{ "keep_alive": "42m" }`, exp: want(42 * time.Minute)},
		{name: "Negative Integer", req: `{ "keep_alive": -1 }`, exp: want(math.MaxInt64)},
		{name: "Negative Integer String", req: `{ "keep_alive": "-1m" }`, exp: want(math.MaxInt64)},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var got ChatRequest
			// Decoding must succeed before the parsed value is worth comparing.
			require.NoError(t, json.Unmarshal([]byte(tc.req), &got))

			assert.Equal(t, tc.exp, got.KeepAlive)
		})
	}
}
34 changes: 31 additions & 3 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
Expand All @@ -16,6 +17,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
Expand Down Expand Up @@ -207,7 +209,7 @@ func GenerateHandler(c *gin.Context) {

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
Expand Down Expand Up @@ -384,6 +386,32 @@ func GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}

// getDefaultSessionDuration returns the session duration to use when a
// request does not carry its own keep-alive value.
//
// The OLLAMA_KEEP_ALIVE environment variable is consulted on every call. Its
// value may be either a bare integer, interpreted as a number of seconds
// (e.g. "300"), or any string accepted by time.ParseDuration (e.g. "5m").
// Negative values, and integer values too large to express in nanoseconds,
// are mapped to the maximum representable duration (math.MaxInt64). When the
// variable is unset or cannot be parsed, defaultSessionDuration is returned.
func getDefaultSessionDuration() time.Duration {
	raw, ok := os.LookupEnv("OLLAMA_KEEP_ALIVE")
	if !ok {
		return defaultSessionDuration
	}

	const forever = time.Duration(math.MaxInt64)

	// Largest whole number of seconds that still fits in a time.Duration.
	const maxSeconds = math.MaxInt64 / int64(time.Second)

	if secs, err := strconv.Atoi(raw); err == nil {
		// Decide on the integer itself rather than on the scaled product:
		// secs * time.Second overflows int64 for large magnitudes and would
		// silently wrap to an unrelated duration of either sign.
		if secs < 0 || int64(secs) > maxSeconds {
			return forever
		}
		return time.Duration(secs) * time.Second
	}

	// Not a plain integer; fall back to Go duration syntax.
	d, err := time.ParseDuration(raw)
	if err != nil {
		return defaultSessionDuration
	}
	if d < 0 {
		return forever
	}
	return d
}

func EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
Expand Down Expand Up @@ -427,7 +455,7 @@ func EmbeddingsHandler(c *gin.Context) {

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
Expand Down Expand Up @@ -1228,7 +1256,7 @@ func ChatHandler(c *gin.Context) {

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
Expand Down

0 comments on commit 8179c23

Please sign in to comment.