Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Exempting IP Ranges #429

Merged
merged 7 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ package cmd
import (
"errors"
"fmt"
"heckel.io/ntfy/log"
"io/fs"
"math"
"net"
"net/netip"
"os"
"os/signal"
"strings"
"syscall"
"time"

"heckel.io/ntfy/log"

"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/server"
Expand Down Expand Up @@ -208,16 +210,14 @@ func execServe(c *cli.Context) error {
}

// Resolve hosts
visitorRequestLimitExemptIPs := make([]string, 0)
visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
for _, host := range visitorRequestLimitExemptHosts {
ips, err := net.LookupIP(host)
ips, err := parseIPHostPrefix(host)
if err != nil {
log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
continue
}
for _, ip := range ips {
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
}
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
}

// Run server
Expand Down Expand Up @@ -303,6 +303,33 @@ func sigHandlerConfigReload(config string) {
}
}

func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
//try parsing as prefix
prefix, err := netip.ParsePrefix(host)
if err == nil {
prefixes = append(prefixes, prefix.Masked()) // Masked returns the prefix in its canonical form, the same for every ip in the range. This exists for ease of debugging. For example, 10.1.2.3/16 is 10.1.0.0/16.
return prefixes, nil // success
}

// not a prefix, parse as host or IP
// LookupHost passes through an IP as is
ips, err := net.LookupHost(host)
if err != nil {
return nil, err
}
for _, i := range ips {
ip, err := netip.ParseAddr(i)
if err == nil {
prefix, err := ip.Prefix(ip.BitLen())
if err != nil {
return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error())
}
prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip
}
}
return
}

func reloadLogLevel(inputSource altsrc.InputSourceContext) {
newLevelStr, err := inputSource.String("log-level")
if err != nil {
Expand Down
28 changes: 23 additions & 5 deletions cmd/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ package cmd

import (
"fmt"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/test"
"heckel.io/ntfy/util"
"math/rand"
"os"
"os/exec"
"path/filepath"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/test"
"heckel.io/ntfy/util"
)

func init() {
Expand Down Expand Up @@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
require.Equal(t, "mytopic", m.Topic)
}

func TestIP_Host_Parsing(t *testing.T) {
cases := map[string]string{
"1.1.1.1": "1.1.1.1/32",
"fd00::1234": "fd00::1234/128",
"192.168.0.3/24": "192.168.0.0/24",
"10.1.2.3/8": "10.0.0.0/8",
"201:be93::4a6/21": "201:b800::/21",
}
for q, expectedAnswer := range cases {
ips, err := parseIPHostPrefix(q)
require.Nil(t, err)
assert.Equal(t, 1, len(ips))
assert.Equal(t, expectedAnswer, ips[0].String())
}
}

Comment on lines +75 to +90
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I just omit testing hostnames (because I'd have to mock the DNS resolver and all that)? The code is not all that complicated.

func newEmptyFile(t *testing.T) string {
filename := filepath.Join(t.TempDir(), "empty")
require.Nil(t, os.WriteFile(filename, []byte{}, 0600))
Expand Down
5 changes: 3 additions & 2 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"io/fs"
"net/netip"
"time"
)

Expand Down Expand Up @@ -92,7 +93,7 @@ type Config struct {
VisitorAttachmentDailyBandwidthLimit int
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []string
VisitorRequestExemptIPAddrs []netip.Prefix
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
BehindProxy bool
Expand Down Expand Up @@ -135,7 +136,7 @@ func NewConfig() *Config {
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorRequestExemptIPAddrs: make([]string, 0),
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
BehindProxy: false,
Expand Down
15 changes: 11 additions & 4 deletions server/message_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"encoding/json"
"errors"
"fmt"
"net/netip"
"strings"
"time"

_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"strings"
"time"
)

var (
Expand Down Expand Up @@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentSize,
attachmentExpires,
attachmentURL,
m.Sender,
m.Sender.String(),
m.Encoding,
published,
)
Expand Down Expand Up @@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
return nil, err
}
}
senderIP, err := netip.ParseAddr(sender)
if err != nil {
senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
}

var att *attachment
if attachmentName != "" && attachmentURL != "" {
att = &attachment{
Expand All @@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
Icon: icon,
Actions: actions,
Attachment: att,
Sender: sender,
Sender: senderIP, // Must parse assuming database must be correct
Encoding: encoding,
})
}
Expand Down
20 changes: 13 additions & 7 deletions server/message_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ package server
import (
"database/sql"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net/netip"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
)

func TestSqliteCache_Messages(t *testing.T) {
Expand Down Expand Up @@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix()
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Sender = "1.2.3.4"
m.Sender = exampleIP1234
m.Attachment = &attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Expand All @@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.Sender = "1.2.3.4"
m.Sender = exampleIP1234
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expand All @@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.Sender = "1.2.3.4"
m.Sender = exampleIP1234
m.Attachment = &attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Expand All @@ -327,15 +333,15 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())

require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())

size, err := c.AttachmentBytesUsed("1.2.3.4")
require.Nil(t, err)
Expand Down
30 changes: 22 additions & 8 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -42,7 +43,7 @@ type Server struct {
smtpServerBackend *smtpBackend
smtpSender mailer
topics map[string]*topic
visitors map[string]*visitor
visitors map[netip.Addr]*visitor
firebaseClient *firebaseClient
messages int64
auth auth.Auther
Expand Down Expand Up @@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
smtpSender: mailer,
topics: topics,
auth: auther,
visitors: make(map[string]*visitor),
visitors: make(map[netip.Addr]*visitor),
}, nil
}

Expand Down Expand Up @@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebaseClient == nil {
return
}
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
Expand Down Expand Up @@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {

func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests
Expand Down Expand Up @@ -1436,21 +1437,34 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor {
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
ipport, err := netip.ParseAddrPort(remoteAddr)
ip := ipport.Addr()
if err != nil {
ip = remoteAddr // This should not happen in real life; only in tests.
// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
ip, err = netip.ParseAddr(remoteAddr)
if err != nil {
ip = netip.IPv4Unspecified()
log.Error("Unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
}
}
if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
// only the right-most address can be trusted (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
ip = strings.TrimSpace(util.LastString(ips, remoteAddr))
myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
// fall back to regular remote address if X-Forwarded-For is damaged
} else {
ip = myip
}

}
return s.visitorFromIP(ip)
}

func (s *Server) visitorFromIP(ip string) *visitor {
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip]
Expand Down
10 changes: 6 additions & 4 deletions server/server_firebase_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package server
import (
"encoding/json"
"errors"
"firebase.google.com/go/v4/messaging"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
"net/netip"
"strings"
"sync"
"testing"

"firebase.google.com/go/v4/messaging"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/auth"
)

type testAuther struct {
Expand Down Expand Up @@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))

require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages()))
Expand Down
6 changes: 4 additions & 2 deletions server/server_matrix_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package server

import (
"github.com/stretchr/testify/require"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
Expand Down Expand Up @@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
func TestMatrix_WriteMatrixError(t *testing.T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
v := newVisitor(newTestConfig(t), nil, "1.2.3.4")
v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())
Expand Down
Loading