From 94835929ed7c1191cabed412b0e9da60575573c6 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Mon, 27 Apr 2026 21:13:36 +0800 Subject: [PATCH 1/6] added short text field --- README.md | 7 ++ src/handlers/messages.go | 124 ++++++++++++++++++++++++++---- src/handlers/messages_test.go | 140 ++++++++++++++++++++++++++++++++++ src/main.go | 3 +- src/models/models.go | 1 + 5 files changed, 259 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index a3c83a3..7cfa811 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | | `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | | `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | +| `FMSG_API_SHORT_TEXT_SIZE`| `768` | Maximum bytes of message body returned inline as `short_text` for `text/*` UTF-8 messages | | `FMSG_CORS_ORIGINS` | *(optional)* | Comma-separated list of browser origins allowed via CORS, e.g. `https://example.com,https://www.example.com`. Use `*` to allow any origin. When unset, no CORS headers are emitted (server-to-server callers are unaffected). | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, @@ -279,10 +280,16 @@ Retrieves a single message by ID. The authenticated user must be a participant "topic": "Hello", "type": "text/plain", "size": 12, + "short_text": "hello world", "attachments": [] } ``` +The `short_text` field is included only when the message `type` is `text/*` and +the stored body is valid UTF-8. It contains up to `FMSG_API_SHORT_TEXT_SIZE` +bytes (default 768) of the body, truncated on a UTF-8 rune boundary, so UIs +can render a preview without a separate `GET /fmsg/:id/data` round-trip. + **Errors:** | Status | Condition | diff --git a/src/handlers/messages.go b/src/handlers/messages.go index f55f4ed..ae0f050 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "log" "mime" "net/http" @@ -13,6 +14,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5" @@ -24,15 +26,16 @@ import ( // MessageHandler holds dependencies for message routes. type MessageHandler struct { - DB *db.DB - DataDir string - MaxDataSize int64 - MaxMsgSize int64 + DB *db.DB + DataDir string + MaxDataSize int64 + MaxMsgSize int64 + ShortTextSize int } // NewMessageHandler creates a MessageHandler. -func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64) *MessageHandler { - return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize} +func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64, shortTextSize int) *MessageHandler { + return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize, ShortTextSize: shortTextSize} } // messageListItem is the JSON shape for each message in the list response. @@ -53,6 +56,7 @@ type messageListItem struct { Topic string `json:"topic"` Type string `json:"type"` Size int `json:"size"` + ShortText string `json:"short_text,omitempty"` Attachments []models.Attachment `json:"attachments"` } @@ -185,7 +189,7 @@ func (h *MessageHandler) List(c *gin.Context) { ctx := c.Request.Context() rows, err := h.DB.Pool.Query(ctx, - `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size + `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size, m.filepath FROM msg m WHERE EXISTS (SELECT 1 FROM msg_to mt WHERE mt.msg_id = m.id AND mt.addr = $1) OR EXISTS (SELECT 1 FROM msg_add_to mat WHERE mat.msg_id = m.id AND mat.addr = $1) @@ -204,12 +208,14 @@ func (h *MessageHandler) List(c *gin.Context) { var msgIDs []int64 for rows.Next() { var m messageListItem - if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size); err != nil { + var dataPath string + if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size, &dataPath); err != nil { log.Printf("list messages scan: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list messages"}) return } m.HasPid = m.PID != nil + m.ShortText = h.extractShortText(dataPath, m.Type) messages = append(messages, m) msgIDs = append(msgIDs, m.ID) } @@ -300,7 +306,7 @@ func (h *MessageHandler) Sent(c *gin.Context) { ctx := c.Request.Context() rows, err := h.DB.Pool.Query(ctx, - `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size + `SELECT m.id, m.version, m.pid, m.no_reply, m.is_important, m.is_deflate, m.time_sent, m.from_addr, m.topic, m.type, m.size, m.filepath FROM msg m WHERE m.from_addr = $1 ORDER BY m.id DESC @@ -318,12 +324,14 @@ func (h *MessageHandler) Sent(c *gin.Context) { var msgIDs []int64 for rows.Next() { var m messageListItem - if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size); err != nil { + var dataPath string + if err := rows.Scan(&m.ID, &m.Version, &m.PID, &m.NoReply, &m.Important, &m.Deflate, &m.Time, &m.From, &m.Topic, &m.Type, &m.Size, &dataPath); err != nil { log.Printf("list sent messages scan: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sent messages"}) return } m.HasPid = m.PID != nil + m.ShortText = h.extractShortText(dataPath, m.Type) messages = append(messages, m) msgIDs = append(msgIDs, m.ID) } @@ -565,9 +573,8 @@ func (h *MessageHandler) DownloadData(c *gin.Context) { } // Path traversal protection: ensure the path is within DataDir. - cleanPath := filepath.Clean(dataPath) - cleanDataDir := filepath.Clean(h.DataDir) - if !strings.HasPrefix(cleanPath, cleanDataDir+string(filepath.Separator)) { + cleanPath, ok := safeDataPath(dataPath, h.DataDir) + if !ok { log.Printf("download data: path traversal attempt: %s", dataPath) c.JSON(http.StatusForbidden, gin.H{"error": "access denied"}) return @@ -856,19 +863,21 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { // fetchMessage loads a message with its recipients and attachments from the DB. func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models.Message, error) { row := h.DB.Pool.QueryRow(ctx, - `SELECT version, pid, no_reply, is_important, is_deflate, time_sent, from_addr, topic, type, size FROM msg WHERE id = $1`, + `SELECT version, pid, no_reply, is_important, is_deflate, time_sent, from_addr, topic, type, size, filepath FROM msg WHERE id = $1`, msgID, ) msg := &models.Message{} var pid *int64 var timeSent *float64 - if err := row.Scan(&msg.Version, &pid, &msg.NoReply, &msg.Important, &msg.Deflate, &timeSent, &msg.From, &msg.Topic, &msg.Type, &msg.Size); err != nil { + var dataPath string + if err := row.Scan(&msg.Version, &pid, &msg.NoReply, &msg.Important, &msg.Deflate, &timeSent, &msg.From, &msg.Topic, &msg.Type, &msg.Size, &dataPath); err != nil { return nil, err } msg.PID = pid msg.Time = timeSent msg.HasPid = pid != nil + msg.ShortText = h.extractShortText(dataPath, msg.Type) // Load recipients. rows, err := h.DB.Pool.Query(ctx, "SELECT addr FROM msg_to WHERE msg_id = $1", msgID) @@ -1021,6 +1030,91 @@ func isZip(data []byte) bool { return len(data) >= 4 && data[0] == 0x50 && data[1] == 0x4b && data[2] == 0x03 && data[3] == 0x04 } +// safeDataPath cleans dataPath and verifies it lies inside dataDir. Returns +// the cleaned absolute path and true on success, or "" and false otherwise. +func safeDataPath(dataPath, dataDir string) (string, bool) { + if dataPath == "" || dataDir == "" { + return "", false + } + cleanPath := filepath.Clean(dataPath) + cleanDataDir := filepath.Clean(dataDir) + if !strings.HasPrefix(cleanPath, cleanDataDir+string(filepath.Separator)) { + return "", false + } + return cleanPath, true +} + +// isTextMIME reports whether the given Content-Type's media type begins with +// "text/". Charset and other parameters are ignored. +func isTextMIME(mimeType string) bool { + if mimeType == "" { + return false + } + mediaType, _, err := mime.ParseMediaType(mimeType) + if err != nil { + return false + } + return strings.HasPrefix(mediaType, "text/") +} + +// extractShortText reads up to ShortTextSize bytes from the message body +// referenced by dataPath and returns it as a string when the message type +// is text/* and the bytes form valid UTF-8. Truncation is rounded down to +// the last complete UTF-8 rune so the result is always valid UTF-8. +// Returns "" on any failure (non-text type, invalid UTF-8, missing/unsafe +// path, read error). Errors are logged but not propagated. +func (h *MessageHandler) extractShortText(dataPath, mimeType string) string { + if h.ShortTextSize <= 0 { + return "" + } + if !isTextMIME(mimeType) { + return "" + } + cleanPath, ok := safeDataPath(dataPath, h.DataDir) + if !ok { + return "" + } + f, err := os.Open(cleanPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + log.Printf("short text: open %s: %v", cleanPath, err) + } + return "" + } + defer f.Close() + + buf := make([]byte, h.ShortTextSize) + n, err := io.ReadFull(f, buf) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + log.Printf("short text: read %s: %v", cleanPath, err) + return "" + } + buf = buf[:n] + + // If the buffer ends in the middle of a multi-byte UTF-8 sequence, drop + // the trailing partial rune so the result is valid UTF-8. + for len(buf) > 0 && !utf8.RuneStart(buf[len(buf)-1]) { + buf = buf[:len(buf)-1] + } + // Drop the last rune if it is itself the start of an incomplete sequence + // (e.g. a 0xC2 lead byte with no continuation). + if len(buf) > 0 { + last := buf[len(buf)-1] + switch { + case last < 0x80: + // single-byte rune, complete + case last >= 0xC0: + // lead byte with no continuation bytes following + buf = buf[:len(buf)-1] + } + } + + if !utf8.Valid(buf) { + return "" + } + return string(buf) +} + // checkDistinctRecipients returns an error if any address in to or addTo // appears more than once (case-insensitive). func checkDistinctRecipients(to, addTo []string) error { diff --git a/src/handlers/messages_test.go b/src/handlers/messages_test.go index e058ede..27713ac 100644 --- a/src/handlers/messages_test.go +++ b/src/handlers/messages_test.go @@ -1,7 +1,11 @@ package handlers import ( + "os" + "path/filepath" + "strings" "testing" + "unicode/utf8" ) func TestParseAddr(t *testing.T) { @@ -142,3 +146,139 @@ func searchSubstring(s, sub string) bool { } return false } + +func TestIsTextMIME(t *testing.T) { + tests := []struct { + mime string + want bool + }{ + {"text/plain", true}, + {"text/html", true}, + {"text/plain; charset=utf-8", true}, + {"TEXT/PLAIN", true}, + {"application/json", false}, + {"application/octet-stream", false}, + {"image/png", false}, + {"application/pdf", false}, + {"", false}, + {"totally invalid;;;", false}, + } + for _, tt := range tests { + t.Run(tt.mime, func(t *testing.T) { + if got := isTextMIME(tt.mime); got != tt.want { + t.Errorf("isTextMIME(%q) = %v, want %v", tt.mime, got, tt.want) + } + }) + } +} + +func TestSafeDataPath(t *testing.T) { + dir := t.TempDir() + inside := filepath.Join(dir, "sub", "file.txt") + if _, ok := safeDataPath(inside, dir); !ok { + t.Errorf("expected inside path to be allowed: %s", inside) + } + outside := filepath.Join(filepath.Dir(dir), "evil.txt") + if _, ok := safeDataPath(outside, dir); ok { + t.Errorf("expected outside path to be rejected: %s", outside) + } + if _, ok := safeDataPath("", dir); ok { + t.Error("expected empty path to be rejected") + } + if _, ok := safeDataPath(inside, ""); ok { + t.Error("expected empty data dir to be rejected") + } +} + +// writeTempFile writes contents to a fresh file inside dir and returns its path. +func writeTempFile(t *testing.T, dir, name string, contents []byte) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, contents, 0o600); err != nil { + t.Fatalf("write %s: %v", path, err) + } + return path +} + +func TestExtractShortText(t *testing.T) { + dir := t.TempDir() + h := &MessageHandler{DataDir: dir, ShortTextSize: 768} + + t.Run("short text/plain returns full content", func(t *testing.T) { + path := writeTempFile(t, dir, "short.txt", []byte("hello world")) + got := h.extractShortText(path, "text/plain") + if got != "hello world" { + t.Errorf("got %q, want %q", got, "hello world") + } + }) + + t.Run("ascii longer than max truncates exactly", func(t *testing.T) { + body := strings.Repeat("a", 2000) + path := writeTempFile(t, dir, "long.txt", []byte(body)) + got := h.extractShortText(path, "text/plain") + if len(got) != 768 { + t.Errorf("got len %d, want 768", len(got)) + } + if got != strings.Repeat("a", 768) { + t.Errorf("unexpected content") + } + }) + + t.Run("utf8 multibyte truncation respects rune boundaries", func(t *testing.T) { + // "€" is 3 bytes in UTF-8 (0xE2 0x82 0xAC). 768 / 3 = 256 runes + // (= 768 bytes) — exact boundary. Use 257 runes (771 bytes) so the + // truncation must drop a partial rune. + body := strings.Repeat("€", 257) + path := writeTempFile(t, dir, "utf8.txt", []byte(body)) + small := &MessageHandler{DataDir: dir, ShortTextSize: 770} + got := small.extractShortText(path, "text/plain; charset=utf-8") + if !utf8.ValidString(got) { + t.Errorf("result is not valid UTF-8: % x", got) + } + // 770 bytes / 3 bytes per rune => 256 complete runes (768 bytes); + // trailing 2 bytes of the 257th rune must be dropped. + if got != strings.Repeat("€", 256) { + t.Errorf("unexpected truncation: len=%d runes=%d", len(got), utf8.RuneCountInString(got)) + } + }) + + t.Run("non-text mime returns empty", func(t *testing.T) { + path := writeTempFile(t, dir, "img.bin", []byte("hello")) + for _, mt := range []string{"application/octet-stream", "image/png", "application/pdf", "application/json"} { + if got := h.extractShortText(path, mt); got != "" { + t.Errorf("mime %q: got %q, want empty", mt, got) + } + } + }) + + t.Run("invalid utf8 returns empty", func(t *testing.T) { + path := writeTempFile(t, dir, "bad.txt", []byte{0xff, 0xfe, 0xfd, 0xfc}) + if got := h.extractShortText(path, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("missing file returns empty", func(t *testing.T) { + got := h.extractShortText(filepath.Join(dir, "does-not-exist.txt"), "text/plain") + if got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("path traversal returns empty", func(t *testing.T) { + outside := filepath.Join(filepath.Dir(dir), "evil.txt") + _ = os.WriteFile(outside, []byte("evil"), 0o600) + defer os.Remove(outside) + if got := h.extractShortText(outside, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) + + t.Run("zero ShortTextSize disables", func(t *testing.T) { + path := writeTempFile(t, dir, "z.txt", []byte("hello")) + zero := &MessageHandler{DataDir: dir, ShortTextSize: 0} + if got := zero.extractShortText(path, "text/plain"); got != "" { + t.Errorf("got %q, want empty", got) + } + }) +} diff --git a/src/main.go b/src/main.go index 5966ea1..2b88d9c 100644 --- a/src/main.go +++ b/src/main.go @@ -50,6 +50,7 @@ func main() { maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 + shortTextSize := envOrDefaultInt("FMSG_API_SHORT_TEXT_SIZE", 768) // CORS: comma-separated list of allowed browser origins, e.g. // "https://fmsg.io,https://www.fmsg.io". Empty disables CORS. @@ -91,7 +92,7 @@ func main() { router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) // Instantiate handlers. - msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize) + msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize, shortTextSize) attHandler := handlers.NewAttachmentHandler(database, dataDir, maxAttachSize, maxMsgSize) // Register routes under /fmsg, all protected by JWT. diff --git a/src/models/models.go b/src/models/models.go index c19324e..cb8d8d5 100644 --- a/src/models/models.go +++ b/src/models/models.go @@ -23,5 +23,6 @@ type Message struct { Topic string `json:"topic"` Type string `json:"type"` Size int `json:"size"` + ShortText string `json:"short_text,omitempty"` Attachments []Attachment `json:"attachments"` } From 0dd2f7a674dc18a6706116935a05ae48243e945b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 27 Apr 2026 13:28:51 +0000 Subject: [PATCH 2/6] Apply review feedback: auth-gated ShortText, absolute safeDataPath, UTF-8 fix, README update Agent-Logs-Url: https://github.com/markmnl/fmsg-webapi/sessions/8c2bb90d-2c10-4faa-8752-688a147715d8 Co-authored-by: markmnl <2630321+markmnl@users.noreply.github.com> --- README.md | 8 ++++-- src/handlers/messages.go | 60 +++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 7cfa811..fb263dc 100644 --- a/README.md +++ b/README.md @@ -286,9 +286,11 @@ Retrieves a single message by ID. The authenticated user must be a participant ``` The `short_text` field is included only when the message `type` is `text/*` and -the stored body is valid UTF-8. It contains up to `FMSG_API_SHORT_TEXT_SIZE` -bytes (default 768) of the body, truncated on a UTF-8 rune boundary, so UIs -can render a preview without a separate `GET /fmsg/:id/data` round-trip. +the stored body is valid UTF-8. When `FMSG_API_SHORT_TEXT_SIZE` is greater than +`0`, it contains up to that many bytes (default 768) of the body, truncated on +a UTF-8 rune boundary, so UIs can render a preview without a separate +`GET /fmsg/:id/data` round-trip. Set `FMSG_API_SHORT_TEXT_SIZE=0` to disable +`short_text` generation. **Errors:** diff --git a/src/handlers/messages.go b/src/handlers/messages.go index ae0f050..f851463 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -506,7 +506,7 @@ func (h *MessageHandler) Get(c *gin.Context) { } ctx := c.Request.Context() - msg, err := h.fetchMessage(ctx, msgID) + msg, dataPath, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -523,6 +523,9 @@ func (h *MessageHandler) Get(c *gin.Context) { return } + // Compute ShortText only after authorization has been confirmed. + msg.ShortText = h.extractShortText(dataPath, msg.Type) + c.JSON(http.StatusOK, msg) } @@ -592,7 +595,7 @@ func (h *MessageHandler) Update(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -687,7 +690,7 @@ func (h *MessageHandler) Delete(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -759,7 +762,7 @@ func (h *MessageHandler) Send(c *gin.Context) { } ctx := c.Request.Context() - existing, err := h.fetchMessage(ctx, msgID) + existing, _, err := h.fetchMessage(ctx, msgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -861,7 +864,9 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { } // fetchMessage loads a message with its recipients and attachments from the DB. -func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models.Message, error) { +// It also returns the raw filepath stored in the database so callers can use it +// after performing their own authorization checks. +func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models.Message, string, error) { row := h.DB.Pool.QueryRow(ctx, `SELECT version, pid, no_reply, is_important, is_deflate, time_sent, from_addr, topic, type, size, filepath FROM msg WHERE id = $1`, msgID, @@ -872,12 +877,11 @@ func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models var timeSent *float64 var dataPath string if err := row.Scan(&msg.Version, &pid, &msg.NoReply, &msg.Important, &msg.Deflate, &timeSent, &msg.From, &msg.Topic, &msg.Type, &msg.Size, &dataPath); err != nil { - return nil, err + return nil, "", err } msg.PID = pid msg.Time = timeSent msg.HasPid = pid != nil - msg.ShortText = h.extractShortText(dataPath, msg.Type) // Load recipients. rows, err := h.DB.Pool.Query(ctx, "SELECT addr FROM msg_to WHERE msg_id = $1", msgID) @@ -916,7 +920,7 @@ func (h *MessageHandler) fetchMessage(ctx context.Context, msgID int64) (*models attRows.Close() } - return msg, nil + return msg, dataPath, nil } // saveMessageData writes data to the filesystem and returns the absolute path. @@ -1036,9 +1040,21 @@ func safeDataPath(dataPath, dataDir string) (string, bool) { if dataPath == "" || dataDir == "" { return "", false } - cleanPath := filepath.Clean(dataPath) - cleanDataDir := filepath.Clean(dataDir) - if !strings.HasPrefix(cleanPath, cleanDataDir+string(filepath.Separator)) { + absPath, err := filepath.Abs(dataPath) + if err != nil { + return "", false + } + absDataDir, err := filepath.Abs(dataDir) + if err != nil { + return "", false + } + cleanPath := filepath.Clean(absPath) + cleanDataDir := filepath.Clean(absDataDir) + relPath, err := filepath.Rel(cleanDataDir, cleanPath) + if err != nil { + return "", false + } + if strings.HasPrefix(relPath, "..") { return "", false } return cleanPath, true @@ -1091,25 +1107,13 @@ func (h *MessageHandler) extractShortText(dataPath, mimeType string) string { } buf = buf[:n] - // If the buffer ends in the middle of a multi-byte UTF-8 sequence, drop - // the trailing partial rune so the result is valid UTF-8. - for len(buf) > 0 && !utf8.RuneStart(buf[len(buf)-1]) { + // Trim to the largest valid UTF-8 prefix so that we only drop trailing + // incomplete bytes, while preserving complete multi-byte runes that end + // exactly at the buffer boundary. + for len(buf) > 0 && !utf8.Valid(buf) { buf = buf[:len(buf)-1] } - // Drop the last rune if it is itself the start of an incomplete sequence - // (e.g. a 0xC2 lead byte with no continuation). - if len(buf) > 0 { - last := buf[len(buf)-1] - switch { - case last < 0x80: - // single-byte rune, complete - case last >= 0xC0: - // lead byte with no continuation bytes following - buf = buf[:len(buf)-1] - } - } - - if !utf8.Valid(buf) { + if len(buf) == 0 { return "" } return string(buf) From c09338f3c2c9104260946f47692b85b2f3d85b82 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Mon, 27 Apr 2026 21:35:32 +0800 Subject: [PATCH 3/6] fix comments --- src/handlers/attachments.go | 6 +++--- src/handlers/messages.go | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/handlers/attachments.go b/src/handlers/attachments.go index f662215..c92b3da 100644 --- a/src/handlers/attachments.go +++ b/src/handlers/attachments.go @@ -30,7 +30,7 @@ func NewAttachmentHandler(database *db.DB, dataDir string, maxAttachSize, maxMsg return &AttachmentHandler{DB: database, DataDir: dataDir, MaxAttachSize: maxAttachSize, MaxMsgSize: maxMsgSize} } -// Upload handles POST /api/v1/messages/:id/attachments. +// Upload handles POST /fmsg/:id/attachments. func (h *AttachmentHandler) Upload(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -166,7 +166,7 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{"filename": intendedFilename, "size": written}) } -// Download handles GET /api/v1/messages/:id/attachments/:filename. +// Download handles GET /fmsg/:id/attachments/:filename. func (h *AttachmentHandler) Download(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -239,7 +239,7 @@ func (h *AttachmentHandler) Download(c *gin.Context) { c.FileAttachment(cleanPath, filename) } -// DeleteAttachment handles DELETE /api/v1/messages/:id/attachments/:filename. +// DeleteAttachment handles DELETE /fmsg/:id/attachments/:filename. func (h *AttachmentHandler) DeleteAttachment(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) diff --git a/src/handlers/messages.go b/src/handlers/messages.go index f851463..9f64580 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -177,7 +177,7 @@ func (h *MessageHandler) latestMessageIDForRecipient(ctx context.Context, identi return latestID, nil } -// List handles GET /api/v1/messages — lists messages where the authenticated user is a recipient. +// List handles GET /fmsg — lists messages where the authenticated user is a recipient. func (h *MessageHandler) List(c *gin.Context) { identity := middleware.GetIdentity(c) @@ -405,7 +405,7 @@ func (h *MessageHandler) Sent(c *gin.Context) { c.JSON(http.StatusOK, messages) } -// Create handles POST /api/v1/messages — creates a draft message. +// Create handles POST /fmsg — creates a draft message. func (h *MessageHandler) Create(c *gin.Context) { identity := middleware.GetIdentity(c) @@ -497,7 +497,7 @@ func (h *MessageHandler) Create(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{"id": msgID}) } -// Get handles GET /api/v1/messages/:id — retrieves a message. +// Get handles GET /fmsg/:id — retrieves a message. func (h *MessageHandler) Get(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -529,7 +529,7 @@ func (h *MessageHandler) Get(c *gin.Context) { c.JSON(http.StatusOK, msg) } -// DownloadData handles GET /api/v1/messages/:id/data — downloads the message body as a file. +// DownloadData handles GET /fmsg/:id/data — downloads the message body as a file. func (h *MessageHandler) DownloadData(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -586,7 +586,7 @@ func (h *MessageHandler) DownloadData(c *gin.Context) { c.FileAttachment(cleanPath, filepath.Base(cleanPath)) } -// Update handles PUT /api/v1/messages/:id — updates a draft message. +// Update handles PUT /fmsg/:id — updates a draft message. func (h *MessageHandler) Update(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -681,7 +681,7 @@ func (h *MessageHandler) Update(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"id": msgID}) } -// Delete handles DELETE /api/v1/messages/:id — deletes a draft message. +// Delete handles DELETE /fmsg/:id — deletes a draft message. func (h *MessageHandler) Delete(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -753,7 +753,7 @@ func (h *MessageHandler) Delete(c *gin.Context) { c.Status(http.StatusNoContent) } -// Send handles POST /api/v1/messages/:id/send — marks a message as sent. +// Send handles POST /fmsg/:id/send — marks a message as sent. func (h *MessageHandler) Send(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) @@ -797,7 +797,7 @@ type addToInput struct { AddTo []string `json:"add_to"` } -// AddRecipients handles POST /api/v1/messages/:id/add-to — adds additional recipients to a message. +// AddRecipients handles POST /fmsg/:id/add-to — adds additional recipients to a message. func (h *MessageHandler) AddRecipients(c *gin.Context) { identity := middleware.GetIdentity(c) msgID, ok := parseID(c) From 6f1e1a78eaf85b416f942f25caa25b2ab7ed9cea Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Mon, 27 Apr 2026 21:41:17 +0800 Subject: [PATCH 4/6] topic <-> pid <-> add to validation --- README.md | 6 +++--- src/handlers/messages.go | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index fb263dc..730ce16 100644 --- a/README.md +++ b/README.md @@ -253,7 +253,7 @@ When `add_to` recipients are provided, `add_to_from` is automatically populated | Status | Condition | | ------ | --------- | -| `400` | Missing/invalid fields or empty `to` | +| `400` | Missing/invalid fields, empty `to`, `topic` set together with `pid`, or `add_to`/`add_to_from` set without `pid` | | `403` | `from` does not match authenticated user | ### GET `/fmsg/:id` @@ -307,7 +307,7 @@ Updates a draft message. Only the owner (`from`) may update, and the message mus | Status | Condition | | ------ | --------- | -| `400` | Invalid fields | +| `400` | Invalid fields, `topic` set together with `pid`, or `add_to`/`add_to_from` set without `pid` | | `403` | Not the owner, or message already sent | | `404` | Message not found | @@ -358,7 +358,7 @@ New addresses must be distinct among themselves (case-insensitive). | Status | Condition | | ------ | --------- | -| `400` | Empty `add_to` or duplicate addresses in request | +| `400` | Empty `add_to`, duplicate addresses, or target message has no `pid` | | `403` | Authenticated user is not an existing participant (sender or `to` recipient) | | `404` | Message not found | diff --git a/src/handlers/messages.go b/src/handlers/messages.go index 9f64580..ccf518a 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -426,6 +426,11 @@ func (h *MessageHandler) Create(c *gin.Context) { return } + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if int64(len(msg.Data)) > h.MaxDataSize { c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) return @@ -625,6 +630,11 @@ func (h *MessageHandler) Update(c *gin.Context) { return } + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if int64(len(msg.Data)) > h.MaxDataSize { c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) return @@ -820,9 +830,10 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { // Load the message to verify it exists. var fromAddr string + var pid *int64 err := h.DB.Pool.QueryRow(ctx, - "SELECT from_addr FROM msg WHERE id = $1", msgID, - ).Scan(&fromAddr) + "SELECT from_addr, pid FROM msg WHERE id = $1", msgID, + ).Scan(&fromAddr, &pid) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -833,6 +844,12 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } + // add_to is only valid on replies (messages with a pid). + if pid == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "add_to is only valid when pid is supplied"}) + return + } + // Verify the requester is an existing participant (from or msg_to). if fromAddr != identity { var recipientCount int @@ -1119,6 +1136,25 @@ func (h *MessageHandler) extractShortText(dataPath, mimeType string) string { return string(buf) } +// validatePidRelations enforces: +// - If pid is set, topic must be empty (replies inherit topic from parent). +// - If pid is not set, add_to and add_to_from must be empty (a thread +// must exist before recipients can be added to it). +func validatePidRelations(pid *int64, topic string, addTo []string, addToFrom *string) error { + if pid != nil && topic != "" { + return fmt.Errorf("topic must be empty when pid is supplied") + } + if pid == nil { + if len(addTo) > 0 { + return fmt.Errorf("add_to is only valid when pid is supplied") + } + if addToFrom != nil && *addToFrom != "" { + return fmt.Errorf("add_to_from is only valid when pid is supplied") + } + } + return nil +} + // checkDistinctRecipients returns an error if any address in to or addTo // appears more than once (case-insensitive). func checkDistinctRecipients(to, addTo []string) error { From 30da0606c8f6f309d9dd590c27ca382deda98043 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Mon, 27 Apr 2026 21:46:42 +0800 Subject: [PATCH 5/6] check addr valid --- src/handlers/messages.go | 39 ++++++++++++++++++++++++++++++++++++++ src/middleware/jwt.go | 6 +++--- src/middleware/jwt_test.go | 2 +- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/handlers/messages.go b/src/handlers/messages.go index ccf518a..897b923 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -426,6 +426,11 @@ func (h *MessageHandler) Create(c *gin.Context) { return } + if err := validateAddresses(msg.From, msg.To, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -630,6 +635,11 @@ func (h *MessageHandler) Update(c *gin.Context) { return } + if err := validateAddresses(msg.From, msg.To, msg.AddTo, msg.AddToFrom); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := validatePidRelations(msg.PID, msg.Topic, msg.AddTo, msg.AddToFrom); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -826,6 +836,13 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } + for _, addr := range input.AddTo { + if !middleware.IsValidAddr(addr) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid add_to address: %q", addr)}) + return + } + } + ctx := c.Request.Context() // Load the message to verify it exists. @@ -1136,6 +1153,28 @@ func (h *MessageHandler) extractShortText(dataPath, mimeType string) string { return string(buf) } +// validateAddresses returns an error if any of the provided fmsg address +// fields is not a valid "@user@domain" address. addToFrom is optional. +func validateAddresses(from string, to, addTo []string, addToFrom *string) error { + if !middleware.IsValidAddr(from) { + return fmt.Errorf("invalid from address: %q", from) + } + for _, addr := range to { + if !middleware.IsValidAddr(addr) { + return fmt.Errorf("invalid to address: %q", addr) + } + } + for _, addr := range addTo { + if !middleware.IsValidAddr(addr) { + return fmt.Errorf("invalid add_to address: %q", addr) + } + } + if addToFrom != nil && *addToFrom != "" && !middleware.IsValidAddr(*addToFrom) { + return fmt.Errorf("invalid add_to_from address: %q", *addToFrom) + } + return nil +} + // validatePidRelations enforces: // - If pid is set, topic must be empty (replies inherit topic from parent). // - If pid is not set, add_to and add_to_from must be empty (a thread diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 2a6e466..e8f530e 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -151,7 +151,7 @@ func New(cfg Config) (gin.HandlerFunc, error) { } addr, _ := claims["sub"].(string) - if !isValidAddr(addr) { + if !IsValidAddr(addr) { log.Printf("auth rejected: ip=%s reason=invalid_addr sub=%q", c.ClientIP(), addr) respondAuth(c, http.StatusUnauthorized, "invalid identity") return @@ -234,8 +234,8 @@ func GetIdentity(c *gin.Context) string { return addr } -// isValidAddr checks that the address has the form "@user@domain". -func isValidAddr(addr string) bool { +// IsValidAddr checks that the address has the form "@user@domain". +func IsValidAddr(addr string) bool { if len(addr) < 3 { return false } diff --git a/src/middleware/jwt_test.go b/src/middleware/jwt_test.go index 549b1a9..1cbd94b 100644 --- a/src/middleware/jwt_test.go +++ b/src/middleware/jwt_test.go @@ -33,7 +33,7 @@ func TestIsValidAddr(t *testing.T) { } for _, tt := range tests { t.Run(tt.addr, func(t *testing.T) { - if got := isValidAddr(tt.addr); got != tt.want { + if got := IsValidAddr(tt.addr); got != tt.want { t.Errorf("isValidAddr(%q) = %v, want %v", tt.addr, got, tt.want) } }) From 5efbe1aa205f8ace8d270b44a8d73979beb1151e Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Tue, 28 Apr 2026 21:20:14 +0800 Subject: [PATCH 6/6] rm jtw replay check --- README.md | 7 +-- src/middleware/jti_cache.go | 93 ------------------------------------- src/middleware/jwt.go | 25 ---------- src/middleware/jwt_test.go | 29 ++---------- 4 files changed, 6 insertions(+), 148 deletions(-) delete mode 100644 src/middleware/jti_cache.go diff --git a/README.md b/README.md index 730ce16..6679825 100644 --- a/README.md +++ b/README.md @@ -53,19 +53,16 @@ Required claims: | `iat` | Issued-at timestamp (Unix seconds). | | `nbf` | Not-before timestamp. | | `exp` | Expiry timestamp (must be in the future, ±10 s leeway). | -| `jti` | Unique token ID. Used for in-process replay prevention until `exp`. | +| `jti` | Optional unique token ID. | | `aud` | Optional; required only when `FMSG_JWT_AUDIENCE` is set. | A 10-second clock-skew leeway is applied to `iat`/`nbf`/`exp` validation. -Replay prevention is in-process and does not coordinate across multiple API -instances; deploy as a single instance or replace the cache before scaling -horizontally. ### HMAC (development) Active when `FMSG_JWT_JWKS_URL` is unset. Tokens must be HS256-signed with the shared secret in `FMSG_API_JWT_SECRET`. Required claims are `sub` and `exp`; -`iat`/`nbf` are honoured when present. No replay prevention is applied. +`iat`/`nbf` are honoured when present. ## Building diff --git a/src/middleware/jti_cache.go b/src/middleware/jti_cache.go deleted file mode 100644 index cbe2480..0000000 --- a/src/middleware/jti_cache.go +++ /dev/null @@ -1,93 +0,0 @@ -package middleware - -import ( - "sync" - "time" -) - -// jtiCacheMaxEntries bounds memory usage of the in-process replay cache. -// When exceeded, expired entries are swept first; if still over the limit, -// new entries are dropped (the request is still rejected only on a true -// duplicate, never on overflow). -const jtiCacheMaxEntries = 100_000 - -// jtiCache tracks JWT IDs that have already been seen, until their -// corresponding token expiry, to prevent replay attacks. -// -// The cache lives in-process; it does not coordinate across multiple API -// instances. For a horizontally-scaled deployment, replace with a shared -// store (e.g. Postgres or Redis). -type jtiCache struct { - mu sync.Mutex - entries map[string]time.Time - stop chan struct{} -} - -// newJTICache returns a cache with a background sweeper running until Close. -func newJTICache() *jtiCache { - c := &jtiCache{ - entries: make(map[string]time.Time), - stop: make(chan struct{}), - } - go c.sweepLoop(time.Minute) - return c -} - -// Seen atomically checks whether jti has been recorded with an unexpired -// entry; if not, records it with the given expiry. Returns true if the -// jti was already present (i.e. this is a replay). -// -// Empty jti strings are never considered seen (caller decides policy). -func (c *jtiCache) Seen(jti string, exp time.Time) bool { - if jti == "" { - return false - } - now := time.Now() - c.mu.Lock() - defer c.mu.Unlock() - if existing, ok := c.entries[jti]; ok && existing.After(now) { - return true - } - if len(c.entries) >= jtiCacheMaxEntries { - c.sweepLocked(now) - if len(c.entries) >= jtiCacheMaxEntries { - // Cache full of unexpired entries; refuse to grow but do not - // falsely flag the token as a replay. - return false - } - } - c.entries[jti] = exp - return false -} - -// Close stops the background sweeper. -func (c *jtiCache) Close() { - select { - case <-c.stop: - default: - close(c.stop) - } -} - -func (c *jtiCache) sweepLoop(interval time.Duration) { - t := time.NewTicker(interval) - defer t.Stop() - for { - select { - case <-c.stop: - return - case now := <-t.C: - c.mu.Lock() - c.sweepLocked(now) - c.mu.Unlock() - } - } -} - -func (c *jtiCache) sweepLocked(now time.Time) { - for k, exp := range c.entries { - if !exp.After(now) { - delete(c.entries, k) - } - } -} diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index e8f530e..8a3c353 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -67,7 +67,6 @@ type Config struct { // - extracts a Bearer token from the Authorization header, // - parses & verifies the signature according to cfg.Mode, // - validates iss/aud/exp/nbf claims, -// - rejects replays (EdDSA mode only) by tracking jti in-process, // - extracts sub as the user address and validates its shape, // - calls fmsgid to confirm the user is known and accepting messages, // - on success stores the address in the Gin context under IdentityKey. @@ -129,11 +128,6 @@ func New(cfg Config) (gin.HandlerFunc, error) { } parser := jwt.NewParser(parserOpts...) - var replay *jtiCache - if cfg.Mode == ModeEdDSA { - replay = newJTICache() - } - idURL := cfg.IDURL return func(c *gin.Context) { @@ -157,25 +151,6 @@ func New(cfg Config) (gin.HandlerFunc, error) { return } - if replay != nil { - jti, _ := claims["jti"].(string) - if jti == "" { - log.Printf("auth rejected: ip=%s addr=%s reason=missing_jti", c.ClientIP(), addr) - respondAuth(c, http.StatusUnauthorized, "invalid token") - return - } - expTime, err := claims.GetExpirationTime() - if err != nil || expTime == nil { - respondAuth(c, http.StatusUnauthorized, "invalid token") - return - } - if replay.Seen(jti, expTime.Time) { - log.Printf("auth rejected: ip=%s addr=%s reason=jti_replay jti=%s", c.ClientIP(), addr, jti) - respondAuth(c, http.StatusUnauthorized, "token already used") - return - } - } - code, accepting, err := checkFmsgID(idURL, addr) if err != nil { log.Printf("fmsgid check error for %s: %v", addr, err) diff --git a/src/middleware/jwt_test.go b/src/middleware/jwt_test.go index 1cbd94b..b98f653 100644 --- a/src/middleware/jwt_test.go +++ b/src/middleware/jwt_test.go @@ -282,7 +282,7 @@ func TestEdDSAMode_Expired(t *testing.T) { } } -func TestEdDSAMode_Replay(t *testing.T) { +func TestEdDSAMode_Reuse(t *testing.T) { srv := fmsgIDServer(t, http.StatusOK, true) defer srv.Close() priv, jwks := newEdDSAFixture(t) @@ -295,15 +295,15 @@ func TestEdDSAMode_Replay(t *testing.T) { "sub": "@alice@example.com", "iat": time.Now().Unix(), "exp": time.Now().Add(time.Hour).Unix(), - "jti": "replay-me", + "jti": "reuse-me", } tok := signEdDSA(t, priv, "prod-1", claims) if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { t.Fatalf("first call expected 200, got %d", w.Code) } - if w := runMiddleware(t, mw, tok); w.Code != http.StatusUnauthorized { - t.Fatalf("replay expected 401, got %d", w.Code) + if w := runMiddleware(t, mw, tok); w.Code != http.StatusOK { + t.Fatalf("reuse expected 200, got %d", w.Code) } } @@ -325,24 +325,3 @@ func TestEdDSAMode_FmsgIDUnavailable(t *testing.T) { t.Fatalf("expected 503, got %d", w.Code) } } - -func TestJTICache_SeenAndExpiry(t *testing.T) { - c := newJTICache() - defer c.Close() - exp := time.Now().Add(time.Hour) - if c.Seen("a", exp) { - t.Fatal("first Seen should be false") - } - if !c.Seen("a", exp) { - t.Fatal("second Seen should be true") - } - // Expired entry should not count as seen. - if c.Seen("b", time.Now().Add(-time.Second)) { - t.Fatal("expired entry: first Seen should be false") - } - // And subsequently the cached entry, having expired in the past, should - // be replaceable. - if c.Seen("b", time.Now().Add(-time.Second)) { - t.Fatal("expired entry: should not flag as replay") - } -}