Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@
#
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m

# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1"
# LOCALAI_IP_ALLOWLIST=192.168.1.0/24
2 changes: 2 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type RunCMD struct {
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"`

Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
Expand Down Expand Up @@ -192,6 +193,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile)
xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels)
}),
config.WithIPAllowList(r.IpAllowList),
}

if r.DisableMetricsEndpoint {
Expand Down
18 changes: 18 additions & 0 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"
"time"

"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/mudler/xlog"
Expand Down Expand Up @@ -93,6 +94,11 @@ type ApplicationConfig struct {

PathWithoutAuth []string

// ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
IpAllowList string

IPAllowListHelper *utils.IPAllowList

// Agent Pool (LocalAGI integration)
AgentPool AgentPoolConfig
}
Expand Down Expand Up @@ -205,6 +211,18 @@ func WithP2PToken(s string) AppOption {
}
}

func WithIPAllowList(s string) AppOption {
return func(o *ApplicationConfig) {
xlog.Info("Application IpAllowList($LOCALAI_IP_ALLOWLIST)", "value", s)
o.IpAllowList = s
ipAllowListHelper, err := utils.NewIPAllowList(s)
if err != nil {
xlog.Error("Failed to parse IpAllowList", "error", err, "value", s)
}
o.IPAllowListHelper = ipAllowListHelper
}
}

var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}
Expand Down
15 changes: 15 additions & 0 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@ func API(application *application.Application) (*echo.Echo, error) {
e.Use(middleware.Recover())
}

// IP restriction middleware
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would better be suited as a middleware indeed to avoid spreading the logic in utils packages, instead of being registered here.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if application.ApplicationConfig().IPAllowListHelper != nil {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
clientIP := c.RealIP()
if !application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) {
return c.JSON(http.StatusForbidden, schema.ErrorResponse{
Error: &schema.APIError{Message: "Forbidden: your IP is not allowed", Code: http.StatusForbidden},
})
}
return next(c)
}
})
}

// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
Expand Down
96 changes: 96 additions & 0 deletions core/http/utils/ipallowlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package utils

import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)

type IPAllowList struct {
allowList string
cidrs []*net.IPNet
ips []net.IP
mu sync.RWMutex
enabled bool
}

func NewIPAllowList(allowList string) (*IPAllowList, error) {

w := &IPAllowList{}
err := w.Update(allowList)
return w, err
}

func (w *IPAllowList) GetAllowList() string {
return w.allowList
}

func (w *IPAllowList) Update(allowListStr string) error {
var cidrs []*net.IPNet
var ips []net.IP

allowList := make([]string, 0)
if allowListStr != "" {
allowList = strings.Split(allowListStr, ",")
}

for _, item := range allowList {
_, cidrNet, err := net.ParseCIDR(item)
if err == nil {
cidrs = append(cidrs, cidrNet)
} else {
ip := net.ParseIP(item)
if ip != nil {
ips = append(ips, ip)
} else {
return fmt.Errorf("invalid allowList item: %s", item)
}
}
}

w.mu.Lock()
defer w.mu.Unlock()
w.allowList = allowListStr
w.cidrs = cidrs
w.ips = ips
w.enabled = len(cidrs) > 0 || len(ips) > 0
return nil
}

func (w *IPAllowList) IsAllowed(ip interface{}) bool {
if !w.enabled {
return true
}

var parsedIP net.IP
switch v := ip.(type) {
case string:
parsedIP = net.ParseIP(v)
case net.IP:
parsedIP = v
case netip.Addr:
parsedIP = net.IP(v.AsSlice())
}

if parsedIP == nil {
return false
}

w.mu.RLock()
defer w.mu.RUnlock()

for _, cidr := range w.cidrs {
if cidr.Contains(parsedIP) {
return true
}
}

for _, allowedIP := range w.ips {
if parsedIP.Equal(allowedIP) {
return true
}
}
return false
}
36 changes: 36 additions & 0 deletions core/http/utils/ipallowlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package utils_test

import (
. "github.com/mudler/LocalAI/core/http/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("IPAllowList", func() {
It("allows all IPs when allowlist is empty", func() {
w, err := NewIPAllowList("")
Expect(err).ToNot(HaveOccurred())
Expect(w.IsAllowed("192.168.1.100")).To(BeTrue())
})

It("respects CIDRs and explicit IPs", func() {
allowList := "192.168.1.0/24,10.0.0.1,127.0.0.1"
w, err := NewIPAllowList(allowList)
Expect(err).ToNot(HaveOccurred())

cases := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.1", true},
{"127.0.0.1", true},
{"10.0.0.2", false},
{"172.16.0.1", false},
}

for _, tc := range cases {
Expect(w.IsAllowed(tc.ip)).To(Equal(tc.expected), "IP: %s", tc.ip)
}
})
})
96 changes: 96 additions & 0 deletions pkg/utils/ip_allowlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package utils

import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
)

type IPAllowList struct {
allowList string
cidrs []*net.IPNet
ips []net.IP
mu sync.RWMutex
enabled bool
}

func NewIPAllowList(allowList string) (*IPAllowList, error) {

w := &IPAllowList{}
err := w.Update(allowList)
return w, err
}

func (w *IPAllowList) GetAllowList() string {
return w.allowList
}

func (w *IPAllowList) Update(allowListStr string) error {
var cidrs []*net.IPNet
var ips []net.IP

allowList := make([]string, 0)
if allowListStr != "" {
allowList = strings.Split(allowListStr, ",")
}

for _, item := range allowList {
_, cidrNet, err := net.ParseCIDR(item)
if err == nil {
cidrs = append(cidrs, cidrNet)
} else {
ip := net.ParseIP(item)
if ip != nil {
ips = append(ips, ip)
} else {
return fmt.Errorf("invalid allowList item: %s", item)
}
}
}

w.mu.Lock()
defer w.mu.Unlock()
w.allowList = allowListStr
w.cidrs = cidrs
w.ips = ips
w.enabled = len(cidrs) > 0 || len(ips) > 0
return nil
}

func (w *IPAllowList) IsAllowed(ip interface{}) bool {
if !w.enabled {
return true
}

var parsedIP net.IP
switch v := ip.(type) {
case string:
parsedIP = net.ParseIP(v)
case net.IP:
parsedIP = v
case netip.Addr:
parsedIP = net.IP(v.AsSlice())
}

if parsedIP == nil {
return false
}

w.mu.RLock()
defer w.mu.RUnlock()

for _, cidr := range w.cidrs {
if cidr.Contains(parsedIP) {
return true
}
}

for _, allowedIP := range w.ips {
if parsedIP.Equal(allowedIP) {
return true
}
}
return false
}
36 changes: 36 additions & 0 deletions pkg/utils/ip_allowlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package utils_test

import (
. "github.com/mudler/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("IPAllowList", func() {
It("allows all IPs when allowlist is empty", func() {
w, err := NewIPAllowList("")
Expect(err).ToNot(HaveOccurred())
Expect(w.IsAllowed("192.168.1.100")).To(BeTrue())
})

It("respects CIDRs and explicit IPs", func() {
allowList := "192.168.1.0/24,10.0.0.1,127.0.0.1"
w, err := NewIPAllowList(allowList)
Expect(err).ToNot(HaveOccurred())

cases := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.1", true},
{"127.0.0.1", true},
{"10.0.0.2", false},
{"172.16.0.1", false},
}

for _, tc := range cases {
Expect(w.IsAllowed(tc.ip)).To(Equal(tc.expected), "IP: %s", tc.ip)
}
})
})
Loading