Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/sshido-relay/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.25.0
require (
cloud.google.com/go/firestore v1.22.0
github.com/sideshow/apns2 v0.25.0
golang.org/x/time v0.15.0
google.golang.org/api v0.276.0
google.golang.org/grpc v1.80.0
modernc.org/sqlite v1.34.1
Expand Down Expand Up @@ -42,7 +43,6 @@ require (
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
Expand Down
29 changes: 26 additions & 3 deletions server/sshido-relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/payload"
"github.com/sideshow/apns2/token"
"golang.org/x/time/rate"
)

type config struct {
Expand Down Expand Up @@ -84,10 +85,16 @@ func main() {
}
defer s.store.Close()

// In-process per-IP token bucket. 5 rps with burst 10 stops the
// abuse vectors in issue #6 (subscriber-spam, notify-URL spam) and
// leaves real users — who hit /subscribe once per device and /n/<id>
// a few times per agent task — far below the ceiling.
limiter := newIPLimiter(rate.Limit(5), 10)

mux := http.NewServeMux()
mux.HandleFunc("/health", s.health)
mux.HandleFunc("/subscribe", s.subscribe)
mux.HandleFunc("/n/", s.notify)
mux.HandleFunc("/subscribe", limiter.middleware(s.subscribe))
mux.HandleFunc("/n/", limiter.middleware(s.notify))
mux.HandleFunc("/privacy", s.privacy)
mux.HandleFunc("/self-host", s.selfHost)
mux.HandleFunc("/", s.landing)
Expand Down Expand Up @@ -175,8 +182,18 @@ func (s *server) subscribe(w http.ResponseWriter, r *http.Request) {
http.Error(w, "method not allowed", 405)
return
}
r.Body = http.MaxBytesReader(w, r.Body, 1<<10) // 1 KiB
var req subscribeReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.DeviceToken == "" {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
http.Error(w, "request too large", http.StatusRequestEntityTooLarge)
return
}
http.Error(w, "bad body", 400)
return
}
if req.DeviceToken == "" {
http.Error(w, "bad body", 400)
return
}
Expand Down Expand Up @@ -222,8 +239,14 @@ func (s *server) notify(w http.ResponseWriter, r *http.Request) {
http.Error(w, "store error", 500)
return
}
r.Body = http.MaxBytesReader(w, r.Body, 4<<10) // 4 KiB — APNs payload max
var req notifyReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
http.Error(w, "request too large", http.StatusRequestEntityTooLarge)
return
}
http.Error(w, "bad body", 400)
return
}
Expand Down
100 changes: 100 additions & 0 deletions server/sshido-relay/rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"net/http"
"strings"
"sync"
"time"

"golang.org/x/time/rate"
)

// ipLimiter is a per-source-IP token-bucket rate limiter. It keeps a
// rate.Limiter for each IP it has seen and sweeps entries that haven't
// been touched within ttl.
//
// State is in-process only. Cloud Run runs up to max-instances copies,
// so the actual ceiling is rps × instances. The abuse vectors we care
// about (subscriber-spam, leaked-notify-URL spam) are unaffected by
// that ceiling — even 3× the per-instance limit kills the attack
// without affecting real users.
type ipLimiter struct {
mu sync.Mutex
limits map[string]*ipEntry
rps rate.Limit
burst int
ttl time.Duration
}

type ipEntry struct {
limiter *rate.Limiter
lastSeen time.Time
}

func newIPLimiter(rps rate.Limit, burst int) *ipLimiter {
l := &ipLimiter{
limits: map[string]*ipEntry{},
rps: rps,
burst: burst,
ttl: 10 * time.Minute,
}
go l.sweepLoop()
return l
}

func (l *ipLimiter) allow(ip string) bool {
l.mu.Lock()
defer l.mu.Unlock()
e, ok := l.limits[ip]
now := time.Now()
if !ok {
e = &ipEntry{limiter: rate.NewLimiter(l.rps, l.burst), lastSeen: now}
l.limits[ip] = e
return e.limiter.Allow()
}
e.lastSeen = now
return e.limiter.Allow()
}

func (l *ipLimiter) sweepLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
cutoff := time.Now().Add(-l.ttl)
l.mu.Lock()
for ip, e := range l.limits {
if e.lastSeen.Before(cutoff) {
delete(l.limits, ip)
}
}
l.mu.Unlock()
}
}

func (l *ipLimiter) middleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !l.allow(clientIP(r)) {
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return
}
next(w, r)
}
}

// clientIP extracts the originating client's IP from a request. Cloud
// Run sets X-Forwarded-For with the chain `client, lb1, lb2`; the
// leftmost entry is the real client. If the header is unset (local
// dev, direct requests), fall back to r.RemoteAddr stripped of port.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if i := strings.Index(xff, ","); i >= 0 {
return strings.TrimSpace(xff[:i])
}
return strings.TrimSpace(xff)
}
addr := r.RemoteAddr
if i := strings.LastIndex(addr, ":"); i >= 0 {
return addr[:i]
}
return addr
}
157 changes: 157 additions & 0 deletions server/sshido-relay/relay_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package main

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"golang.org/x/time/rate"
)

func newTestServer(t *testing.T) *server {
t.Helper()
store, err := newSQLiteStore(":memory:")
if err != nil {
t.Fatalf("newSQLiteStore: %v", err)
}
t.Cleanup(func() { _ = store.Close() })
return &server{
cfg: config{publicURL: "http://test.example", privacyContact: "privacy@sshido.com"},
store: store,
bundle: "com.sshido.app",
// apns nil → notify handler returns 202 "queued (no APNs configured)"
}
}

// MaxBytesReader needs an http.ResponseWriter; httptest.NewRecorder
// satisfies it. We test the handlers directly so we don't pull in
// the rate-limit middleware (covered separately below).

func TestSubscribeBodyCap(t *testing.T) {
s := newTestServer(t)
// Build valid-looking JSON whose string value blows past the 1 KiB
// cap. Junk-byte garbage (2 MiB of 'a') would trip the JSON syntax
// error before MaxBytesReader gets a chance, so use a real field.
huge := strings.Repeat("a", 2<<20) // 2 MiB
body := `{"deviceToken":"` + huge + `"}`
req := httptest.NewRequest(http.MethodPost, "/subscribe", strings.NewReader(body))
w := httptest.NewRecorder()
s.subscribe(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("expected 413 for oversized /subscribe; got %d body=%q", w.Code, w.Body.String())
}
}

func TestNotifyBodyCap(t *testing.T) {
s := newTestServer(t)
// Seed a subscriber so LookupByID succeeds before we hit the body cap.
sub, err := s.store.UpsertByDeviceToken(t.Context(), "devtok123", func() string { return "fixedID" }, time.Now().Unix())
if err != nil {
t.Fatalf("seed subscriber: %v", err)
}

huge := strings.Repeat("a", 2<<20) // 2 MiB
body := `{"title":"` + huge + `"}`
req := httptest.NewRequest(http.MethodPost, "/n/"+sub.ID, strings.NewReader(body))
w := httptest.NewRecorder()
s.notify(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("expected 413 for oversized /n/<id>; got %d body=%q", w.Code, w.Body.String())
}
}

func TestSubscribeHappyPath(t *testing.T) {
s := newTestServer(t)
req := httptest.NewRequest(http.MethodPost, "/subscribe", strings.NewReader(`{"deviceToken":"abc"}`))
w := httptest.NewRecorder()
s.subscribe(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200; got %d body=%q", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), `"id":`) || !strings.Contains(w.Body.String(), `"notifyURL":`) {
t.Fatalf("unexpected response body: %q", w.Body.String())
}
}

func TestRateLimitReturns429(t *testing.T) {
// Tight bucket so the test runs fast: 1 rps with burst 3.
// 10 immediate requests from the same IP should yield ~3 passes
// and ~7 throttled (the bucket refills at 1 rps).
lim := newIPLimiter(rate.Limit(1), 3)
s := newTestServer(t)
h := lim.middleware(s.subscribe)

var throttled, accepted int
for i := 0; i < 10; i++ {
req := httptest.NewRequest(http.MethodPost, "/subscribe", strings.NewReader(`{"deviceToken":"abc"}`))
req.RemoteAddr = "203.0.113.7:1234"
w := httptest.NewRecorder()
h(w, req)
switch w.Code {
case http.StatusTooManyRequests:
throttled++
case http.StatusOK:
accepted++
}
}
if throttled == 0 {
t.Fatalf("expected at least one 429 after 10 rapid requests; got accepted=%d throttled=%d", accepted, throttled)
}
if accepted == 0 {
t.Fatalf("expected at least one 200 (burst capacity); got accepted=%d throttled=%d", accepted, throttled)
}
}

func TestRateLimitIsPerIP(t *testing.T) {
// Different source IPs share neither the bucket nor each other's
// throttling — each gets its own limiter.
lim := newIPLimiter(rate.Limit(1), 2)
s := newTestServer(t)
h := lim.middleware(s.subscribe)

send := func(ip string) int {
req := httptest.NewRequest(http.MethodPost, "/subscribe", strings.NewReader(`{"deviceToken":"abc"}`))
req.RemoteAddr = ip + ":4444"
w := httptest.NewRecorder()
h(w, req)
return w.Code
}

// Exhaust IP A's burst.
_ = send("203.0.113.1")
_ = send("203.0.113.1")
if code := send("203.0.113.1"); code != http.StatusTooManyRequests {
t.Fatalf("expected IP A to be throttled by 3rd request; got %d", code)
}
// IP B should still get a fresh bucket.
if code := send("203.0.113.2"); code != http.StatusOK {
t.Fatalf("expected IP B to be accepted; got %d", code)
}
}

func TestClientIPParsesXForwardedFor(t *testing.T) {
cases := map[string]string{
"203.0.113.5": "203.0.113.5",
"203.0.113.5, 10.0.0.1": "203.0.113.5",
" 203.0.113.5 , 10.0.0.1 ": "203.0.113.5",
"2001:db8::1, 192.168.1.1": "2001:db8::1",
}
for header, want := range cases {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", header)
got := clientIP(req)
if got != want {
t.Fatalf("clientIP(%q) = %q; want %q", header, got, want)
}
}
}

func TestClientIPFallsBackToRemoteAddr(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "198.51.100.7:55555"
if got := clientIP(req); got != "198.51.100.7" {
t.Fatalf("clientIP fallback = %q; want 198.51.100.7", got)
}
}