From 4177574ec7876df29eef98f037fb03c1b6e55121 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 16 May 2026 13:18:00 +0800 Subject: [PATCH 1/4] refactor to cmd/ pkg/ idomatic go --- AGENTS.md | 8 +- src/go.mod => go.mod | 0 src/go.sum => go.sum | 0 pkg/protocol/README.md | 40 + pkg/protocol/protocol.go | 332 ++++++++ src/.env.example | 23 - src/deflate.go | 157 ---- src/deflate_test.go | 549 ------------ src/defs.go | 248 ------ src/defs_test.go | 854 ------------------- src/dns.go | 148 ---- src/host.go | 1735 -------------------------------------- src/host_test.go | 622 -------------- src/id.go | 92 -- src/outgoing.go | 66 -- src/outgoing_test.go | 142 ---- src/sender.go | 628 -------------- src/store.go | 592 ------------- src/store_test.go | 139 --- 19 files changed, 377 insertions(+), 5998 deletions(-) rename src/go.mod => go.mod (100%) rename src/go.sum => go.sum (100%) create mode 100644 pkg/protocol/README.md create mode 100644 pkg/protocol/protocol.go delete mode 100644 src/.env.example delete mode 100644 src/deflate.go delete mode 100644 src/deflate_test.go delete mode 100644 src/defs.go delete mode 100644 src/defs_test.go delete mode 100644 src/dns.go delete mode 100644 src/host.go delete mode 100644 src/host_test.go delete mode 100644 src/id.go delete mode 100644 src/outgoing.go delete mode 100644 src/outgoing_test.go delete mode 100644 src/sender.go delete mode 100644 src/store.go delete mode 100644 src/store_test.go diff --git a/AGENTS.md b/AGENTS.md index 300059f..1accc1e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,8 @@ All code MUST conform to the specification. When in doubt, re-read SPEC.md and f ## Build & Test - Language: Go -- Module path: `src/` -- Build: `cd src && go build ./...` -- Test: `cd src && go test ./...` +- Module path: repo root (`go.mod` at `fmsgd/go.mod`) +- Binary source: `cmd/fmsgd/` +- Shared protocol package: `pkg/protocol/` +- Build: `go build ./...` (from repo root) +- Test: `go test ./...` (from repo root) diff --git a/src/go.mod b/go.mod similarity index 100% rename from src/go.mod rename to go.mod diff --git a/src/go.sum b/go.sum similarity index 100% rename from src/go.sum rename to go.sum diff --git a/pkg/protocol/README.md b/pkg/protocol/README.md new file mode 100644 index 0000000..2219762 --- /dev/null +++ b/pkg/protocol/README.md @@ -0,0 +1,40 @@ +# pkg/protocol + +This package implements the fmsg wire protocol encoding and hashing as defined in [SPEC.md](../../SPEC.md). + +## Import + +```go +import "github.com/markmnl/fmsgd/pkg/protocol" +``` + +## What it provides + +- **Types**: `FMsgHeader`, `FMsgAddress`, `FMsgAttachmentHeader` +- **Flag constants**: `FlagHasPid`, `FlagHasAddTo`, `FlagCommonType`, `FlagImportant`, `FlagNoReply`, `FlagDeflate` +- **Wire encoding**: `FMsgHeader.Encode()` — serialises a header to the exact byte sequence defined in SPEC.md +- **Hashing**: `FMsgHeader.GetHeaderHash()` — SHA-256 of the encoded header; `FMsgHeader.GetMessageHash()` — SHA-256 of header + decompressed body + decompressed attachments +- **Common media type lookup**: `GetCommonMediaType(id)`, `GetCommonMediaTypeID(mimeType)` + +## Example: compute a message hash + +```go +h := &protocol.FMsgHeader{ + Version: 1, + Flags: 0, + From: protocol.FMsgAddress{User: "alice", Domain: "example.com"}, + To: []protocol.FMsgAddress{{User: "bob", Domain: "other.com"}}, + Timestamp: 1700000000.0, + Topic: "hello", + Type: "text/plain", + Size: 5, + Filepath: "/path/to/body/file", +} +hash, err := h.GetMessageHash() +``` + +## Notes + +- `Encode()` produces fields 1–12 of the fmsg wire format (header through attachment headers). Message data and attachment data follow separately on the wire. +- `GetMessageHash()` hashes over **decompressed** data even when the stored file is zlib-compressed (`FlagDeflate` set). Set `ExpandedSize` accordingly. +- `Filepath` and `FMsgAttachmentHeader.Filepath` must point to readable files on disk for `GetMessageHash()` to succeed. diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go new file mode 100644 index 0000000..3ce3152 --- /dev/null +++ b/pkg/protocol/protocol.go @@ -0,0 +1,332 @@ +// Package protocol implements the fmsg wire protocol encoding and hashing. +// +// It provides the core types and methods needed to build, encode, and hash +// fmsg messages as defined in the fmsg protocol specification (SPEC.md). +// +// To compute a message hash from database fields, populate an [FMsgHeader] +// (including Filepath and per-attachment Filepath values), then call +// [FMsgHeader.GetMessageHash]. +package protocol + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "os" + "strings" +) + +// Flag bit assignments per SPEC.md. +// Bits 6–7 are reserved and must be zero on the wire. +const ( + FlagHasPid uint8 = 1 // bit 0: pid field present; message is a reply + FlagHasAddTo uint8 = 1 << 1 // bit 1: add-to addresses present + FlagCommonType uint8 = 1 << 2 // bit 2: type encoded as common type ID, not string + FlagImportant uint8 = 1 << 3 // bit 3: sender marks message as important + FlagNoReply uint8 = 1 << 4 // bit 4: sender will discard replies + FlagDeflate uint8 = 1 << 5 // bit 5: message body is zlib-deflate compressed +) + +// FMsgAddress is an fmsg address of the form @user@domain. +type FMsgAddress struct { + User string + Domain string +} + +// ToString returns the address in @user@domain form. +func (addr *FMsgAddress) ToString() string { + return fmt.Sprintf("@%s@%s", addr.User, addr.Domain) +} + +// FMsgAttachmentHeader holds the wire-level metadata for a single attachment. +type FMsgAttachmentHeader struct { + Flags uint8 + TypeID uint8 + Type string + Filename string + Size uint32 // wire (possibly compressed) size in bytes + ExpandedSize uint32 // decompressed size; present on wire iff bit 1 of Flags is set + + Filepath string // path to attachment data on disk +} + +// FMsgHeader holds all fields of an fmsg message header. +// +// Fields ChallengeHash, ChallengeCompleted, and InitialResponseCode are +// fmsgd server-runtime state; they are not part of the wire format and can +// be ignored by other consumers of this package. +type FMsgHeader struct { + Version uint8 + Flags uint8 + Pid []byte + From FMsgAddress + To []FMsgAddress + AddToFrom *FMsgAddress // present when FlagHasAddTo is set + AddTo []FMsgAddress + Timestamp float64 + TypeID uint8 + Topic string + Type string + + Size uint32 // wire (possibly compressed) size of the message body + ExpandedSize uint32 // decompressed size; present on wire iff FlagDeflate set + Attachments []FMsgAttachmentHeader + + HeaderHash []byte + ChallengeHash [32]byte // fmsgd server field: challenge response hash + ChallengeCompleted bool // fmsgd server field: true if challenge was completed + InitialResponseCode uint8 // fmsgd server field: protocol response code (11/64/65) + Filepath string // path to message body data on disk + messageHash []byte +} + +// Encode serialises the message header to wire format as defined in SPEC.md. +// The returned bytes cover all fields up to and including the attachment +// headers (fields 1–12 per spec). This method panics on internal buffer errors +// rather than returning an error. +func (h *FMsgHeader) Encode() []byte { + var b bytes.Buffer + b.WriteByte(h.Version) + b.WriteByte(h.Flags) + if h.Flags&FlagHasPid == 1 { + b.Write(h.Pid[:]) + } + str := h.From.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + b.WriteByte(byte(len(h.To))) + for _, addr := range h.To { + str = addr.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + } + if h.Flags&FlagHasAddTo != 0 { + // add-to-from address (field 6) + addToFrom := h.AddToFrom + if addToFrom == nil { + addToFrom = &h.From + } + str := addToFrom.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + // add-to addresses (field 7) + b.WriteByte(byte(len(h.AddTo))) + for _, addr := range h.AddTo { + str = addr.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + } + } + if err := binary.Write(&b, binary.LittleEndian, h.Timestamp); err != nil { + panic(err) + } + // topic is only present when pid is NOT set + if h.Flags&FlagHasPid == 0 { + b.WriteByte(byte(len(h.Topic))) + b.WriteString(h.Topic) + } + if h.Flags&FlagCommonType != 0 { + typeID := h.TypeID + if typeID == 0 { + if id, ok := getCommonMediaTypeID(h.Type); ok { + typeID = id + } + } + b.WriteByte(typeID) + } else { + b.WriteByte(byte(len(h.Type))) + b.WriteString(h.Type) + } + // size (uint32 LE) + if err := binary.Write(&b, binary.LittleEndian, h.Size); err != nil { + panic(err) + } + // expanded size (uint32 LE) — present iff zlib-deflate flag set + if h.Flags&FlagDeflate != 0 { + if err := binary.Write(&b, binary.LittleEndian, h.ExpandedSize); err != nil { + panic(err) + } + } + // attachment headers + b.WriteByte(byte(len(h.Attachments))) + for _, att := range h.Attachments { + b.WriteByte(att.Flags) + if att.Flags&1 != 0 { + typeID := att.TypeID + if typeID == 0 { + if id, ok := getCommonMediaTypeID(att.Type); ok { + typeID = id + } + } + b.WriteByte(typeID) + } else { + b.WriteByte(byte(len(att.Type))) + b.WriteString(att.Type) + } + b.WriteByte(byte(len(att.Filename))) + b.WriteString(att.Filename) + if err := binary.Write(&b, binary.LittleEndian, att.Size); err != nil { + panic(err) + } + // attachment expanded size — present iff attachment zlib-deflate flag set + if att.Flags&(1<<1) != 0 { + if err := binary.Write(&b, binary.LittleEndian, att.ExpandedSize); err != nil { + panic(err) + } + } + } + return b.Bytes() +} + +// String returns a human-readable summary of the header fields. +func (h *FMsgHeader) String() string { + var b strings.Builder + fmt.Fprintf(&b, "v%d flags=%d", h.Version, h.Flags) + if len(h.Pid) > 0 { + fmt.Fprintf(&b, " pid=%s", hex.EncodeToString(h.Pid)) + } + fmt.Fprintf(&b, "\nfrom:\t%s", h.From.ToString()) + for i, addr := range h.To { + if i == 0 { + fmt.Fprintf(&b, "\nto:\t%s", addr.ToString()) + } else { + fmt.Fprintf(&b, "\n\t%s", addr.ToString()) + } + } + for i, addr := range h.AddTo { + if i == 0 { + fmt.Fprintf(&b, "\nadd to:\t%s", addr.ToString()) + } else { + fmt.Fprintf(&b, "\n\t%s", addr.ToString()) + } + } + fmt.Fprintf(&b, "\ntopic:\t%s", h.Topic) + fmt.Fprintf(&b, "\ntype:\t%s", h.Type) + fmt.Fprintf(&b, "\nsize:\t%d", h.Size) + return b.String() +} + +// GetHeaderHash returns the SHA-256 hash of the encoded header (fields 1–12). +// The result is cached after the first call. +func (h *FMsgHeader) GetHeaderHash() []byte { + if h.HeaderHash == nil { + b := sha256.Sum256(h.Encode()) + h.HeaderHash = b[:] + } + return h.HeaderHash +} + +// GetMessageHash returns the SHA-256 hash of the full message: +// encoded header + decompressed message body + decompressed attachment data. +// The result is cached after the first call. +func (h *FMsgHeader) GetMessageHash() ([]byte, error) { + if h.messageHash == nil { + hash := sha256.New() + + headerBytes := h.Encode() + if _, err := io.Copy(hash, bytes.NewBuffer(headerBytes)); err != nil { + return nil, err + } + + if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0, h.ExpandedSize); err != nil { + return nil, err + } + + // include attachment data (sequential byte sequences following + // the message body, bounded by attachment header sizes) + for _, att := range h.Attachments { + compressed := att.Flags&(1<<1) != 0 + if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed, att.ExpandedSize); err != nil { + return nil, fmt.Errorf("hash attachment %s: %w", att.Filename, err) + } + } + + h.messageHash = hash.Sum(nil) + } + return h.messageHash, nil +} + +// HashPayload reads wireSize bytes from the file at filepath, decompressing +// via zlib if deflated, and writes the (decompressed) bytes to dst. +// The message hash is always computed over decompressed data. +func HashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + defer f.Close() + + if deflated { + lr := io.LimitReader(f, wireSize) + zr, err := zlib.NewReader(lr) + if err != nil { + return err + } + written, err := io.Copy(dst, zr) + _ = zr.Close() + if err != nil { + return err + } + if uint32(written) != expandedSize { + return fmt.Errorf("decompressed size %d does not match declared expanded size %d", written, expandedSize) + } + return nil + } + + _, err = io.CopyN(dst, f, wireSize) + return err +} + +// commonMediaTypes maps common type IDs (1–64) to MIME strings per SPEC.md §4. +// Unmapped IDs must be rejected with response code 1 (invalid). +var commonMediaTypes = map[uint8]string{ + 1: "application/epub+zip", 2: "application/gzip", 3: "application/json", 4: "application/msword", + 5: "application/octet-stream", 6: "application/pdf", 7: "application/rtf", 8: "application/vnd.amazon.ebook", + 9: "application/vnd.ms-excel", 10: "application/vnd.ms-powerpoint", + 11: "application/vnd.oasis.opendocument.presentation", 12: "application/vnd.oasis.opendocument.spreadsheet", + 13: "application/vnd.oasis.opendocument.text", + 14: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + 15: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + 16: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + 17: "application/x-tar", 18: "application/xhtml+xml", 19: "application/xml", 20: "application/zip", + 21: "audio/aac", 22: "audio/midi", 23: "audio/mpeg", 24: "audio/ogg", 25: "audio/opus", 26: "audio/vnd.wave", 27: "audio/webm", + 28: "font/otf", 29: "font/ttf", 30: "font/woff", 31: "font/woff2", + 32: "image/apng", 33: "image/avif", 34: "image/bmp", 35: "image/gif", 36: "image/heic", 37: "image/jpeg", 38: "image/png", + 39: "image/svg+xml", 40: "image/tiff", 41: "image/webp", + 42: "model/3mf", 43: "model/gltf-binary", 44: "model/obj", 45: "model/step", 46: "model/stl", 47: "model/vnd.usdz+zip", + 48: "text/calendar", 49: "text/css", 50: "text/csv", 51: "text/html", 52: "text/javascript", 53: "text/markdown", + 54: "text/plain;charset=US-ASCII", 55: "text/plain;charset=UTF-16", 56: "text/plain;charset=UTF-8", 57: "text/vcard", + 58: "video/H264", 59: "video/H265", 60: "video/H266", 61: "video/ogg", 62: "video/VP8", 63: "video/VP9", 64: "video/webm", +} + +// GetCommonMediaType returns the MIME type string for a common type ID, or +// ("", false) if the ID is not in the mapping (should be rejected per spec). +func GetCommonMediaType(id uint8) (string, bool) { + s, ok := commonMediaTypes[id] + return s, ok +} + +// GetCommonMediaTypeID returns the common type ID for a MIME string, or +// (0, false) if the string has no assigned ID. +func GetCommonMediaTypeID(mediaType string) (uint8, bool) { + for id, mime := range commonMediaTypes { + if mime == mediaType { + return id, true + } + } + return 0, false +} + +// getCommonMediaTypeID is the unexported alias used internally by Encode. +func getCommonMediaTypeID(mediaType string) (uint8, bool) { + return GetCommonMediaTypeID(mediaType) +} + +// hashPayload is the unexported alias used internally by GetMessageHash. +func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { + return HashPayload(dst, filepath, wireSize, deflated, expandedSize) +} diff --git a/src/.env.example b/src/.env.example deleted file mode 100644 index dbedf7e..0000000 --- a/src/.env.example +++ /dev/null @@ -1,23 +0,0 @@ -# fmsgd Environment Variables - -# Required -FMSG_DATA_DIR=/var/lib/fmsgd/ -FMSG_DOMAIN=example.com -FMSG_ID_URL=http://127.0.0.1:8080 - - -FMSG_MAX_MSG_SIZE=10240 -FMSG_MAX_EXPANDED_SIZE=10240 -FMSG_MAX_PAST_TIME_DELTA=604800 -FMSG_MAX_FUTURE_TIME_DELTA=300 -FMSG_MIN_DOWNLOAD_RATE=5000 -FMSG_MIN_UPLOAD_RATE=5000 -FMSG_READ_BUFFER_SIZE=1600 - -# PostgreSQL connection variables (see https://www.postgresql.org/docs/current/libpq-envars.html) -PGHOST=127.0.0.1 -PGPORT=5432 -PGUSER= -PGPASSWORD= -PGDATABASE=fmsgd -PGSSLMODE=disable diff --git a/src/deflate.go b/src/deflate.go deleted file mode 100644 index 68b85a7..0000000 --- a/src/deflate.go +++ /dev/null @@ -1,157 +0,0 @@ -package main - -import ( - "bytes" - "compress/zlib" - "io" - "os" - "strings" -) - -// minDeflateSize is the minimum payload size in bytes before compression is -// attempted. -const minDeflateSize uint32 = 512 - -// incompressibleTypes lists media types (lowercased, without parameters) that -// are already compressed or otherwise unlikely to benefit from zlib-deflate. -var incompressibleTypes = map[string]bool{ - // images - "image/jpeg": true, "image/png": true, "image/gif": true, - "image/webp": true, "image/heic": true, "image/avif": true, - "image/apng": true, - // audio - "audio/aac": true, "audio/mpeg": true, "audio/ogg": true, - "audio/opus": true, "audio/webm": true, - // video - "video/h264": true, "video/h265": true, "video/h266": true, - "video/ogg": true, "video/vp8": true, "video/vp9": true, - "video/webm": true, - // archives / compressed containers - "application/gzip": true, "application/zip": true, - "application/epub+zip": true, - "application/octet-stream": true, - // zip-based office formats - "application/vnd.oasis.opendocument.presentation": true, - "application/vnd.oasis.opendocument.spreadsheet": true, - "application/vnd.oasis.opendocument.text": true, - "application/vnd.openxmlformats-officedocument.presentationml.presentation": true, - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true, - "application/vnd.openxmlformats-officedocument.wordprocessingml.document": true, - "application/vnd.amazon.ebook": true, - // fonts (compressed) - "font/woff": true, "font/woff2": true, - // pdf (internally compressed) - "application/pdf": true, - // 3d models (compressed containers) - "model/3mf": true, "model/gltf-binary": true, - "model/vnd.usdz+zip": true, -} - -// shouldCompress reports whether compression should be attempted for a payload -// with the given media type and size. It returns false for payloads that are -// too small or use a media type known to be already compressed. -func shouldCompress(mediaType string, dataSize uint32) bool { - if dataSize < minDeflateSize { - return false - } - t := strings.ToLower(mediaType) - if i := strings.IndexByte(t, ';'); i >= 0 { - t = strings.TrimRight(t[:i], " ") - } - return !incompressibleTypes[t] -} - -// deflateSampleSize is the number of bytes sampled from the start of a file -// to estimate compressibility before committing to a full-file compression -// pass. Chosen large enough for zlib to find patterns but small enough to be -// fast even on very large files. -const deflateSampleSize = 8192 - -// probeSample compresses up to deflateSampleSize bytes from the start of src -// and reports whether the ratio looks promising (compressed < 80% of input). -// src is seeked back to the start on return. -func probeSample(src *os.File, srcSize uint32) (bool, error) { - sampleLen := int64(deflateSampleSize) - if int64(srcSize) < sampleLen { - sampleLen = int64(srcSize) - } - - var buf bytes.Buffer - zw := zlib.NewWriter(&buf) - if _, err := io.CopyN(zw, src, sampleLen); err != nil { - _ = zw.Close() - return false, err - } - if err := zw.Close(); err != nil { - return false, err - } - - if _, err := src.Seek(0, io.SeekStart); err != nil { - return false, err - } - - return int64(buf.Len()) < sampleLen*8/10, nil -} - -// tryCompress compresses the file at srcPath using zlib-deflate and writes the -// result to a temporary file. For files larger than deflateSampleSize it first -// compresses a prefix sample to estimate compressibility, avoiding a full pass -// over files that won't compress well. It returns worthwhile=true only when -// the compressed output is less than 80% of the original size (at least a 20% -// reduction). When not worthwhile the temporary file is removed. When -// worthwhile the caller is responsible for removing the file at dstPath. -func tryCompress(srcPath string, srcSize uint32) (dstPath string, compressedSize uint32, worthwhile bool, err error) { - src, err := os.Open(srcPath) - if err != nil { - return "", 0, false, err - } - defer src.Close() - - // For files larger than the sample size, probe a prefix first. - if srcSize > deflateSampleSize { - promising, err := probeSample(src, srcSize) - if err != nil { - return "", 0, false, err - } - if !promising { - return "", 0, false, nil - } - } - - dst, err := os.CreateTemp("", "fmsg-deflate-*") - if err != nil { - return "", 0, false, err - } - dstName := dst.Name() - - zw := zlib.NewWriter(dst) - if _, err := io.Copy(zw, src); err != nil { - _ = zw.Close() - _ = dst.Close() - _ = os.Remove(dstName) - return "", 0, false, err - } - if err := zw.Close(); err != nil { - _ = dst.Close() - _ = os.Remove(dstName) - return "", 0, false, err - } - if err := dst.Close(); err != nil { - _ = os.Remove(dstName) - return "", 0, false, err - } - - fi, err := os.Stat(dstName) - if err != nil { - _ = os.Remove(dstName) - return "", 0, false, err - } - - cSize := uint32(fi.Size()) - if cSize >= srcSize*8/10 { - _ = os.Remove(dstName) - return "", 0, false, nil - } - - return dstName, cSize, true, nil -} diff --git a/src/deflate_test.go b/src/deflate_test.go deleted file mode 100644 index 7c10c34..0000000 --- a/src/deflate_test.go +++ /dev/null @@ -1,549 +0,0 @@ -package main - -import ( - "bytes" - "compress/zlib" - "crypto/rand" - "crypto/sha256" - "io" - "os" - "strings" - "testing" -) - -// --- shouldDeflate tests --- - -func TestShouldDeflate_TextTypes(t *testing.T) { - compressible := []string{ - "text/plain;charset=UTF-8", - "text/html", - "text/markdown", - "text/csv", - "text/css", - "text/javascript", - "text/calendar", - "text/vcard", - "text/plain;charset=US-ASCII", - "text/plain;charset=UTF-16", - "application/json", - "application/xml", - "application/xhtml+xml", - "application/rtf", - "application/x-tar", - "application/msword", - "application/vnd.ms-excel", - "application/vnd.ms-powerpoint", - "image/svg+xml", - "audio/midi", - "model/obj", - "model/step", - "model/stl", - } - for _, mt := range compressible { - if !shouldCompress(mt, 1024) { - t.Errorf("shouldCompress(%q, 1024) = false, want true", mt) - } - } -} - -func TestShouldDeflate_IncompressibleTypes(t *testing.T) { - skip := []string{ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", - "image/heic", - "image/avif", - "image/apng", - "audio/aac", - "audio/mpeg", - "audio/ogg", - "audio/opus", - "audio/webm", - "video/H264", - "video/H265", - "video/H266", - "video/ogg", - "video/VP8", - "video/VP9", - "video/webm", - "application/gzip", - "application/zip", - "application/epub+zip", - "application/octet-stream", - "application/pdf", - "application/vnd.oasis.opendocument.presentation", - "application/vnd.oasis.opendocument.spreadsheet", - "application/vnd.oasis.opendocument.text", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.amazon.ebook", - "font/woff", - "font/woff2", - "model/3mf", - "model/gltf-binary", - "model/vnd.usdz+zip", - } - for _, mt := range skip { - if shouldCompress(mt, 1024) { - t.Errorf("shouldCompress(%q, 1024) = true, want false", mt) - } - } -} - -func TestShouldDeflate_SmallPayload(t *testing.T) { - sizes := []uint32{0, 1, 100, 511} - for _, sz := range sizes { - if shouldCompress("text/plain;charset=UTF-8", sz) { - t.Errorf("shouldCompress(text/plain, %d) = true, want false", sz) - } - } -} - -func TestShouldDeflate_EdgeCases(t *testing.T) { - // Exactly at threshold: should attempt - if !shouldCompress("text/plain;charset=UTF-8", 512) { - t.Error("shouldDeflate at threshold 512 should return true") - } - // Unknown type: default to try compression - if !shouldCompress("application/x-custom", 1024) { - t.Error("shouldDeflate for unknown type should return true") - } - // Type with parameters should match base type - if shouldCompress("application/pdf; charset=utf-8", 1024) { - t.Error("shouldDeflate should strip parameters and match application/pdf") - } - // Case insensitive - if shouldCompress("VIDEO/H264", 1024) { - t.Error("shouldDeflate should be case-insensitive") - } -} - -// --- tryDeflate tests --- - -func writeTempFile(t *testing.T, data []byte) string { - t.Helper() - f, err := os.CreateTemp("", "deflate-test-*") - if err != nil { - t.Fatal(err) - } - if _, err := f.Write(data); err != nil { - f.Close() - os.Remove(f.Name()) - t.Fatal(err) - } - f.Close() - return f.Name() -} - -func TestTryDeflate_CompressibleData(t *testing.T) { - original := []byte(strings.Repeat("hello world, this is compressible text data! ", 100)) - srcPath := writeTempFile(t, original) - defer os.Remove(srcPath) - - dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected compression to be worthwhile for repetitive text") - } - defer os.Remove(dstPath) - - if cSize >= uint32(len(original))*8/10 { - t.Errorf("compressed size %d not < 80%% of original %d", cSize, len(original)) - } - - // Verify the compressed file decompresses to the original data - f, err := os.Open(dstPath) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - zr, err := zlib.NewReader(f) - if err != nil { - t.Fatal(err) - } - decompressed, err := io.ReadAll(zr) - zr.Close() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decompressed, original) { - t.Error("decompressed data does not match original") - } -} - -func TestTryDeflate_IncompressibleData(t *testing.T) { - // Random bytes are effectively incompressible - data := make([]byte, 2048) - if _, err := rand.Read(data); err != nil { - t.Fatal(err) - } - srcPath := writeTempFile(t, data) - defer os.Remove(srcPath) - - _, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) - if err != nil { - t.Fatal(err) - } - if worthwhile { - t.Error("expected compression of random data to not be worthwhile") - } -} - -func TestTryDeflate_RoundTrip(t *testing.T) { - original := []byte(strings.Repeat("Round-trip test data with enough repetition to compress well. ", 50)) - srcPath := writeTempFile(t, original) - defer os.Remove(srcPath) - - dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected compression to be worthwhile") - } - defer os.Remove(dstPath) - - // Read compressed file - compressed, err := os.ReadFile(dstPath) - if err != nil { - t.Fatal(err) - } - if uint32(len(compressed)) != cSize { - t.Errorf("compressed file size %d != reported size %d", len(compressed), cSize) - } - - // Decompress and verify - zr, err := zlib.NewReader(bytes.NewReader(compressed)) - if err != nil { - t.Fatal(err) - } - decompressed, err := io.ReadAll(zr) - zr.Close() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decompressed, original) { - t.Errorf("round-trip mismatch: got %d bytes, want %d bytes", len(decompressed), len(original)) - } -} - -func TestTryDeflate_CleanupOnNotWorthwhile(t *testing.T) { - // Random data won't compress well — the temp file should be removed - data := make([]byte, 2048) - if _, err := rand.Read(data); err != nil { - t.Fatal(err) - } - srcPath := writeTempFile(t, data) - defer os.Remove(srcPath) - - dstPath, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) - if err != nil { - t.Fatal(err) - } - if worthwhile { - defer os.Remove(dstPath) - t.Fatal("expected not worthwhile for random data") - } - // dstPath should be empty and no leaked temp file - if dstPath != "" { - t.Errorf("expected empty dstPath when not worthwhile, got %q", dstPath) - } -} - -func TestTryDeflate_ProbeRejectsLargeIncompressible(t *testing.T) { - // A file larger than deflateSampleSize filled with random bytes should be - // rejected by the sample probe without writing a full compressed file. - data := make([]byte, deflateSampleSize+4096) - if _, err := rand.Read(data); err != nil { - t.Fatal(err) - } - srcPath := writeTempFile(t, data) - defer os.Remove(srcPath) - - _, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) - if err != nil { - t.Fatal(err) - } - if worthwhile { - t.Error("expected probe to reject large random data") - } -} - -func TestTryDeflate_ProbeAcceptsLargeCompressible(t *testing.T) { - // A file larger than deflateSampleSize filled with repetitive text should - // pass the probe and compress the full file successfully. - data := []byte(strings.Repeat("probe compressible test data! ", 1000)) - if len(data) <= deflateSampleSize { - t.Fatalf("test data %d bytes not larger than sample size %d", len(data), deflateSampleSize) - } - srcPath := writeTempFile(t, data) - defer os.Remove(srcPath) - - dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(data))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected large compressible data to be worthwhile") - } - defer os.Remove(dstPath) - - if cSize >= uint32(len(data))*8/10 { - t.Errorf("compressed size %d not < 80%% of original %d", cSize, len(data)) - } - - // Verify round-trip - f, err := os.Open(dstPath) - if err != nil { - t.Fatal(err) - } - defer f.Close() - zr, err := zlib.NewReader(f) - if err != nil { - t.Fatal(err) - } - decompressed, err := io.ReadAll(zr) - zr.Close() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decompressed, data) { - t.Error("decompressed data does not match original") - } -} - -// --- Hash determinism tests --- - -func TestGetMessageHash_WithDeflate(t *testing.T) { - // Create repetitive data that compresses well - original := []byte(strings.Repeat("deflate hash test data ", 100)) - srcPath := writeTempFile(t, original) - defer os.Remove(srcPath) - - // Compress it - dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected compression to be worthwhile") - } - defer os.Remove(dstPath) - - // Build header with deflate flag pointing at compressed file - h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - Size: cSize, - ExpandedSize: uint32(len(original)), - Filepath: dstPath, - } - - msgHash, err := h.GetMessageHash() - if err != nil { - t.Fatal(err) - } - - // Manually compute expected: SHA-256(encoded header + decompressed data) - expected := sha256.New() - expected.Write(h.Encode()) - expected.Write(original) - expectedHash := expected.Sum(nil) - - if !bytes.Equal(msgHash, expectedHash) { - t.Errorf("hash mismatch:\n got %x\n want %x", msgHash, expectedHash) - } -} - -func TestGetMessageHash_WithoutDeflate(t *testing.T) { - original := []byte(strings.Repeat("no deflate hash test ", 100)) - srcPath := writeTempFile(t, original) - defer os.Remove(srcPath) - - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - Size: uint32(len(original)), - Filepath: srcPath, - } - - msgHash, err := h.GetMessageHash() - if err != nil { - t.Fatal(err) - } - - expected := sha256.New() - expected.Write(h.Encode()) - expected.Write(original) - expectedHash := expected.Sum(nil) - - if !bytes.Equal(msgHash, expectedHash) { - t.Errorf("hash mismatch:\n got %x\n want %x", msgHash, expectedHash) - } -} - -func TestGetMessageHash_DeflateChangesHash(t *testing.T) { - // The same data produces different message hashes depending on whether - // it is deflated, because the header bytes differ (flags and size fields). - original := []byte(strings.Repeat("deflate vs plain ", 100)) - srcPath := writeTempFile(t, original) - defer os.Remove(srcPath) - - dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected compression to be worthwhile") - } - defer os.Remove(dstPath) - - base := FMsgHeader{ - Version: 1, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - } - - // Hash without deflate - plain := base - plain.Flags = 0 - plain.Size = uint32(len(original)) - plain.Filepath = srcPath - hashPlain, err := plain.GetMessageHash() - if err != nil { - t.Fatal(err) - } - - // Hash with deflate - deflated := base - deflated.Flags = FlagDeflate - deflated.Size = cSize - deflated.ExpandedSize = uint32(len(original)) - deflated.Filepath = dstPath - hashDeflated, err := deflated.GetMessageHash() - if err != nil { - t.Fatal(err) - } - - if bytes.Equal(hashPlain, hashDeflated) { - t.Error("expected different hashes for deflated vs non-deflated wire representations") - } -} - -func TestGetMessageHash_AttachmentDeflate(t *testing.T) { - msgData := []byte("short message body that fits in a file") - msgPath := writeTempFile(t, msgData) - defer os.Remove(msgPath) - - attOriginal := []byte(strings.Repeat("attachment data for compression test ", 100)) - attSrcPath := writeTempFile(t, attOriginal) - defer os.Remove(attSrcPath) - - attDstPath, attCSize, worthwhile, err := tryCompress(attSrcPath, uint32(len(attOriginal))) - if err != nil { - t.Fatal(err) - } - if !worthwhile { - t.Fatal("expected attachment compression to be worthwhile") - } - defer os.Remove(attDstPath) - - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - Size: uint32(len(msgData)), - Filepath: msgPath, - Attachments: []FMsgAttachmentHeader{ - { - Flags: 1 << 1, // attachment deflate bit - Type: "text/csv", - Filename: "data.csv", - Size: attCSize, - ExpandedSize: uint32(len(attOriginal)), - Filepath: attDstPath, - }, - }, - } - - msgHash, err := h.GetMessageHash() - if err != nil { - t.Fatal(err) - } - - // Manually compute: SHA-256(header + msg data + decompressed attachment) - expected := sha256.New() - expected.Write(h.Encode()) - expected.Write(msgData) - expected.Write(attOriginal) - expectedHash := expected.Sum(nil) - - if !bytes.Equal(msgHash, expectedHash) { - t.Errorf("attachment hash mismatch:\n got %x\n want %x", msgHash, expectedHash) - } -} - -// --- Encode flag tests --- - -func TestEncode_DeflateFlag(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - } - b := h.Encode() - if b[1]&FlagDeflate == 0 { - t.Error("deflate flag bit (5) not set in encoded header flags byte") - } -} - -func TestEncode_AttachmentDeflateFlag(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - Attachments: []FMsgAttachmentHeader{ - {Flags: 1 << 1, Type: "text/plain", Filename: "test.txt", Size: 100}, - }, - } - b := h.Encode() - // The encoded header ends with attachment headers. Find the attachment - // flags byte: it's the first byte after the attachment count byte. - // The attachment count is at len(b) - (1 + 1 + len("text/plain") + 1 + len("test.txt") + 4) - 1 - // Simpler: just verify the flags byte value appears in the output. - // The attachment count byte (1) followed by attachment flags byte (0x02). - found := false - for i := 0; i < len(b)-1; i++ { - if b[i] == 1 && b[i+1] == (1<<1) { // count=1, flags=0x02 - found = true - break - } - } - if !found { - t.Error("attachment deflate flag bit (1) not found in encoded header") - } -} diff --git a/src/defs.go b/src/defs.go deleted file mode 100644 index a0ee4fa..0000000 --- a/src/defs.go +++ /dev/null @@ -1,248 +0,0 @@ -package main - -import ( - "bytes" - "compress/zlib" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "os" - "strings" -) - -type FMsgAddress struct { - User string - Domain string -} - -type FMsgAttachmentHeader struct { - Flags uint8 - TypeID uint8 - Type string - Filename string - Size uint32 - ExpandedSize uint32 - - Filepath string -} - -type FMsgHeader struct { - Version uint8 - Flags uint8 - Pid []byte - From FMsgAddress - To []FMsgAddress - AddToFrom *FMsgAddress // Present when has-add-to flag is set - AddTo []FMsgAddress - Timestamp float64 - TypeID uint8 - Topic string - Type string - - // Size in bytes of entire message - Size uint32 - ExpandedSize uint32 // Decompressed size; present on wire iff FlagDeflate set - Attachments []FMsgAttachmentHeader - - HeaderHash []byte - ChallengeHash [32]byte - ChallengeCompleted bool // True if challenge was initiated and completed - InitialResponseCode uint8 // Protocol response chosen after header validation (11/64/65) - Filepath string - messageHash []byte -} - -// Returns a string representation of an address in the form @user@example.com -func (addr *FMsgAddress) ToString() string { - return fmt.Sprintf("@%s@%s", addr.User, addr.Domain) -} - -// Encode the message header to wire format as a []byte. This includes all -// fields up to and including the attachment headers per spec. This function -// will panic on error instead of returning one. -func (h *FMsgHeader) Encode() []byte { - var b bytes.Buffer - b.WriteByte(h.Version) - b.WriteByte(h.Flags) - if h.Flags&FlagHasPid == 1 { - b.Write(h.Pid[:]) - } - str := h.From.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - b.WriteByte(byte(len(h.To))) - for _, addr := range h.To { - str = addr.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - } - if h.Flags&FlagHasAddTo != 0 { - // add-to-from address (field 6) - addToFrom := h.AddToFrom - if addToFrom == nil { - addToFrom = &h.From - } - str := addToFrom.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - // add-to addresses (field 7) - b.WriteByte(byte(len(h.AddTo))) - for _, addr := range h.AddTo { - str = addr.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - } - } - if err := binary.Write(&b, binary.LittleEndian, h.Timestamp); err != nil { - panic(err) - } - // topic is only present when pid is NOT set - if h.Flags&FlagHasPid == 0 { - b.WriteByte(byte(len(h.Topic))) - b.WriteString(h.Topic) - } - if h.Flags&FlagCommonType != 0 { - typeID := h.TypeID - if typeID == 0 { - if id, ok := getCommonMediaTypeID(h.Type); ok { - typeID = id - } - } - b.WriteByte(typeID) - } else { - b.WriteByte(byte(len(h.Type))) - b.WriteString(h.Type) - } - // size (uint32 LE) - if err := binary.Write(&b, binary.LittleEndian, h.Size); err != nil { - panic(err) - } - // expanded size (uint32 LE) — present iff zlib-deflate flag set - if h.Flags&FlagDeflate != 0 { - if err := binary.Write(&b, binary.LittleEndian, h.ExpandedSize); err != nil { - panic(err) - } - } - // attachment headers - b.WriteByte(byte(len(h.Attachments))) - for _, att := range h.Attachments { - b.WriteByte(att.Flags) - if att.Flags&1 != 0 { - typeID := att.TypeID - if typeID == 0 { - if id, ok := getCommonMediaTypeID(att.Type); ok { - typeID = id - } - } - b.WriteByte(typeID) - } else { - b.WriteByte(byte(len(att.Type))) - b.WriteString(att.Type) - } - b.WriteByte(byte(len(att.Filename))) - b.WriteString(att.Filename) - if err := binary.Write(&b, binary.LittleEndian, att.Size); err != nil { - panic(err) - } - // attachment expanded size — present iff attachment zlib-deflate flag set - if att.Flags&(1<<1) != 0 { - if err := binary.Write(&b, binary.LittleEndian, att.ExpandedSize); err != nil { - panic(err) - } - } - } - return b.Bytes() -} - -// String returns a human-readable summary of the header fields. -func (h *FMsgHeader) String() string { - var b strings.Builder - fmt.Fprintf(&b, "v%d flags=%d", h.Version, h.Flags) - if len(h.Pid) > 0 { - fmt.Fprintf(&b, " pid=%s", hex.EncodeToString(h.Pid)) - } - fmt.Fprintf(&b, "\nfrom:\t%s", h.From.ToString()) - for i, addr := range h.To { - if i == 0 { - fmt.Fprintf(&b, "\nto:\t%s", addr.ToString()) - } else { - fmt.Fprintf(&b, "\n\t%s", addr.ToString()) - } - } - for i, addr := range h.AddTo { - if i == 0 { - fmt.Fprintf(&b, "\nadd to:\t%s", addr.ToString()) - } else { - fmt.Fprintf(&b, "\n\t%s", addr.ToString()) - } - } - fmt.Fprintf(&b, "\ntopic:\t%s", h.Topic) - fmt.Fprintf(&b, "\ntype:\t%s", h.Type) - fmt.Fprintf(&b, "\nsize:\t%d", h.Size) - return b.String() -} - -func (h *FMsgHeader) GetHeaderHash() []byte { - if h.HeaderHash == nil { - b := sha256.Sum256(h.Encode()) - h.HeaderHash = b[:] - } - return h.HeaderHash -} - -func (h *FMsgHeader) GetMessageHash() ([]byte, error) { - if h.messageHash == nil { - hash := sha256.New() - - headerBytes := h.Encode() - if _, err := io.Copy(hash, bytes.NewBuffer(headerBytes)); err != nil { - return nil, err - } - - if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0, h.ExpandedSize); err != nil { - return nil, err - } - - // include attachment data (sequential byte sequences following - // the message body, bounded by attachment header sizes) - for _, att := range h.Attachments { - compressed := att.Flags&(1<<1) != 0 - if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed, att.ExpandedSize); err != nil { - return nil, fmt.Errorf("hash attachment %s: %w", att.Filename, err) - } - } - - h.messageHash = hash.Sum(nil) - } - return h.messageHash, nil -} - -func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - defer f.Close() - - if deflated { - lr := io.LimitReader(f, wireSize) - zr, err := zlib.NewReader(lr) - if err != nil { - return err - } - written, err := io.Copy(dst, zr) - _ = zr.Close() - if err != nil { - return err - } - if uint32(written) != expandedSize { - return fmt.Errorf("decompressed size %d does not match declared expanded size %d", written, expandedSize) - } - return nil - } - - _, err = io.CopyN(dst, f, wireSize) - return err -} diff --git a/src/defs_test.go b/src/defs_test.go deleted file mode 100644 index 591b555..0000000 --- a/src/defs_test.go +++ /dev/null @@ -1,854 +0,0 @@ -package main - -import ( - "bytes" - "compress/zlib" - "crypto/sha256" - "encoding/binary" - "io" - "math" - "os" - "path/filepath" - "testing" -) - -func TestAddressToString(t *testing.T) { - tests := []struct { - addr FMsgAddress - want string - }{ - {FMsgAddress{User: "alice", Domain: "example.com"}, "@alice@example.com"}, - {FMsgAddress{User: "Bob", Domain: "EXAMPLE.COM"}, "@Bob@EXAMPLE.COM"}, - {FMsgAddress{User: "a-b.c", Domain: "x.y.z"}, "@a-b.c@x.y.z"}, - } - for _, tt := range tests { - got := tt.addr.ToString() - if got != tt.want { - t.Errorf("FMsgAddress{%q, %q}.ToString() = %q, want %q", tt.addr.User, tt.addr.Domain, got, tt.want) - } - } -} - -func TestEncodeMinimalHeader(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Timestamp: 1700000000.0, - Topic: "hello", - Type: "text/plain", - } - b := h.Encode() - - r := bytes.NewReader(b) - - // version - ver, _ := r.ReadByte() - if ver != 1 { - t.Fatalf("version = %d, want 1", ver) - } - - // flags - flags, _ := r.ReadByte() - if flags != 0 { - t.Fatalf("flags = %d, want 0", flags) - } - - // from address - fromLen, _ := r.ReadByte() - fromBytes := make([]byte, fromLen) - r.Read(fromBytes) - if string(fromBytes) != "@alice@a.com" { - t.Fatalf("from = %q, want %q", string(fromBytes), "@alice@a.com") - } - - // to count - toCount, _ := r.ReadByte() - if toCount != 1 { - t.Fatalf("to count = %d, want 1", toCount) - } - - // to[0] - toLen, _ := r.ReadByte() - toBytes := make([]byte, toLen) - r.Read(toBytes) - if string(toBytes) != "@bob@b.com" { - t.Fatalf("to[0] = %q, want %q", string(toBytes), "@bob@b.com") - } - - // timestamp - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - if ts != 1700000000.0 { - t.Fatalf("timestamp = %f, want 1700000000.0", ts) - } - - // topic - topicLen, _ := r.ReadByte() - topicBytes := make([]byte, topicLen) - r.Read(topicBytes) - if string(topicBytes) != "hello" { - t.Fatalf("topic = %q, want %q", string(topicBytes), "hello") - } - - // type - typeLen, _ := r.ReadByte() - typeBytes := make([]byte, typeLen) - r.Read(typeBytes) - if string(typeBytes) != "text/plain" { - t.Fatalf("type = %q, want %q", string(typeBytes), "text/plain") - } - - // size (uint32 LE) - var size uint32 - binary.Read(r, binary.LittleEndian, &size) - if size != 0 { - t.Fatalf("size = %d, want 0", size) - } - - // attachment count - attachCount, _ := r.ReadByte() - if attachCount != 0 { - t.Fatalf("attach count = %d, want 0", attachCount) - } - - // should have consumed entire buffer - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes", r.Len()) - } -} - -func TestEncodeWithPid(t *testing.T) { - pid := make([]byte, 32) - for i := range pid { - pid[i] = byte(i) - } - h := &FMsgHeader{ - Version: 1, - Flags: FlagHasPid, - Pid: pid, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "should be omitted", - Type: "text/plain", - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - // pid should be next 32 bytes - pidOut := make([]byte, 32) - n, _ := r.Read(pidOut) - if n != 32 { - t.Fatalf("pid bytes read = %d, want 32", n) - } - if !bytes.Equal(pidOut, pid) { - t.Fatalf("pid mismatch") - } - - // skip from - fLen, _ := r.ReadByte() - fBuf := make([]byte, fLen) - r.Read(fBuf) - - // skip to count + to[0] - toCount, _ := r.ReadByte() - if toCount != 1 { - t.Fatalf("to count = %d, want 1", toCount) - } - tLen, _ := r.ReadByte() - tBuf := make([]byte, tLen) - r.Read(tBuf) - - // skip timestamp - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - - // topic must NOT be present when pid is set — next byte should be type length - typeLen, _ := r.ReadByte() - typeBytes := make([]byte, typeLen) - r.Read(typeBytes) - if string(typeBytes) != "text/plain" { - t.Fatalf("expected type field directly after timestamp, got %q", string(typeBytes)) - } - - // size + attachment count - var size uint32 - binary.Read(r, binary.LittleEndian, &size) - attachCount, _ := r.ReadByte() - if attachCount != 0 { - t.Fatalf("attach count = %d, want 0", attachCount) - } - - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes", r.Len()) - } -} - -func TestEncodeWithAddTo(t *testing.T) { - pid := make([]byte, 32) - h := &FMsgHeader{ - Version: 1, - Flags: FlagHasPid | FlagHasAddTo, - Pid: pid, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - AddToFrom: &FMsgAddress{User: "a", Domain: "b.com"}, - AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - // skip pid (32 bytes) - pidBuf := make([]byte, 32) - r.Read(pidBuf) - - // skip from - fLen, _ := r.ReadByte() - fBuf := make([]byte, fLen) - r.Read(fBuf) - - // skip to count + to[0] - toCount, _ := r.ReadByte() - if toCount != 1 { - t.Fatalf("to count = %d, want 1", toCount) - } - tLen, _ := r.ReadByte() - tBuf := make([]byte, tLen) - r.Read(tBuf) - - // add-to-from - addToFromLen, _ := r.ReadByte() - addToFrom := make([]byte, addToFromLen) - r.Read(addToFrom) - if string(addToFrom) != "@a@b.com" { - t.Fatalf("add-to-from = %q, want %q", string(addToFrom), "@a@b.com") - } - - // add to count - addToCount, _ := r.ReadByte() - if addToCount != 1 { - t.Fatalf("add to count = %d, want 1", addToCount) - } - - // add to[0] - atLen, _ := r.ReadByte() - atBuf := make([]byte, atLen) - r.Read(atBuf) - if string(atBuf) != "@e@f.com" { - t.Fatalf("add to[0] = %q, want %q", string(atBuf), "@e@f.com") - } -} - -func TestEncodeWithAddToDefaultsAddToFromToFromAddress(t *testing.T) { - pid := make([]byte, 32) - h := &FMsgHeader{ - Version: 1, - Flags: FlagHasPid | FlagHasAddTo, - Pid: pid, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, - Timestamp: 0, - Type: "text/plain", - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - r.Read(make([]byte, 32)) - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - r.ReadByte() // to count - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - - addToFromLen, _ := r.ReadByte() - addToFrom := make([]byte, addToFromLen) - r.Read(addToFrom) - if string(addToFrom) != "@a@b.com" { - t.Fatalf("default add-to-from = %q, want %q", string(addToFrom), "@a@b.com") - } -} - -func TestEncodeNoAddToWhenFlagUnset(t *testing.T) { - // When FlagHasAddTo is NOT set, add-to addresses should not appear on the wire - // even if the AddTo slice is populated. - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, // should be ignored - Timestamp: 0, - Topic: "", - Type: "text/plain", - } - withAddTo := h.Encode() - - h2 := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - } - withoutAddTo := h2.Encode() - - if !bytes.Equal(withAddTo, withoutAddTo) { - t.Fatalf("encoded bytes differ when AddTo populated but flag unset") - } -} - -func TestGetHeaderHash(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Timestamp: 1700000000.0, - Topic: "test", - Type: "text/plain", - } - hash := h.GetHeaderHash() - if len(hash) != 32 { - t.Fatalf("hash length = %d, want 32", len(hash)) - } - - // Must be deterministic - hash2 := h.GetHeaderHash() - if !bytes.Equal(hash, hash2) { - t.Fatal("GetHeaderHash not deterministic") - } - - // Must match manual SHA-256 of Encode() - expected := sha256.Sum256(h.Encode()) - if !bytes.Equal(hash, expected[:]) { - t.Fatal("GetHeaderHash does not match sha256(Encode())") - } -} - -func TestGetHeaderHashCommonTypeMatchesWireIDEncoding(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: FlagCommonType, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Timestamp: 1700000000, - Topic: "x", - TypeID: 3, - Type: "application/json", - } - expected := sha256.Sum256(h.Encode()) - got := h.GetHeaderHash() - if !bytes.Equal(got, expected[:]) { - t.Fatalf("GetHeaderHash mismatch for common type ID") - } -} - -func TestStringOutput(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}, {User: "carol", Domain: "c.com"}}, - Timestamp: 0, - Topic: "greetings", - Type: "text/plain", - Size: 42, - } - s := h.String() - - // Check key substrings are present - for _, want := range []string{ - "v1", - "@alice@a.com", - "@bob@b.com", - "@carol@c.com", - "greetings", - "text/plain", - "42", - } { - if !bytes.Contains([]byte(s), []byte(want)) { - t.Errorf("String() missing %q", want) - } - } -} - -func TestStringWithAddTo(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: FlagHasAddTo, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - AddTo: []FMsgAddress{{User: "dave", Domain: "d.com"}}, - Topic: "t", - Type: "text/plain", - } - s := h.String() - if !bytes.Contains([]byte(s), []byte("add to:")) { - t.Error("String() missing 'add to:' label") - } - if !bytes.Contains([]byte(s), []byte("@dave@d.com")) { - t.Error("String() missing add-to address") - } -} - -func TestEncodeTimestampEncoding(t *testing.T) { - // Verify the timestamp is encoded as little-endian float64 - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 1700000000.5, - Topic: "", - Type: "", - } - b := h.Encode() - - // Find timestamp position: version(1) + flags(1) + from(1+len) + to_count(1) + to[0](1+len) - fromStr := "@a@b.com" - toStr := "@c@d.com" - offset := 1 + 1 + 1 + len(fromStr) + 1 + 1 + len(toStr) // = 2 + 9 + 10 = 21 - tsBytes := b[offset : offset+8] - - bits := binary.LittleEndian.Uint64(tsBytes) - ts := math.Float64frombits(bits) - if ts != 1700000000.5 { - t.Fatalf("timestamp = %f, want 1700000000.5", ts) - } - - // After timestamp: topic(1+0) + type(1+0) + size(4) + attach_count(1) = 7 bytes - if r := bytes.NewReader(b[offset+8:]); r.Len() != 7 { - t.Fatalf("trailing bytes after timestamp = %d, want 7", r.Len()) - } -} - -func TestEncodeWithAttachments(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - Size: 100, - Attachments: []FMsgAttachmentHeader{ - {Flags: 0, Type: "image/png", Filename: "pic.png", Size: 2048}, - {Flags: 1, TypeID: 38, Type: "image/png", Filename: "doc.txt", Size: 512}, - }, - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - // skip from - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - // skip to count + to[0] - r.ReadByte() - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - // skip timestamp - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - // skip topic - topicLen, _ := r.ReadByte() - r.Read(make([]byte, topicLen)) - // skip type - typeLen, _ := r.ReadByte() - r.Read(make([]byte, typeLen)) - - // size - var size uint32 - binary.Read(r, binary.LittleEndian, &size) - if size != 100 { - t.Fatalf("size = %d, want 100", size) - } - - // attachment count - attachCount, _ := r.ReadByte() - if attachCount != 2 { - t.Fatalf("attach count = %d, want 2", attachCount) - } - - // attachment 0 - att0Flags, _ := r.ReadByte() - if att0Flags != 0 { - t.Fatalf("att[0] flags = %d, want 0", att0Flags) - } - att0TypeLen, _ := r.ReadByte() - att0Type := make([]byte, att0TypeLen) - r.Read(att0Type) - if string(att0Type) != "image/png" { - t.Fatalf("att[0] type = %q, want %q", string(att0Type), "image/png") - } - att0FnLen, _ := r.ReadByte() - att0Fn := make([]byte, att0FnLen) - r.Read(att0Fn) - if string(att0Fn) != "pic.png" { - t.Fatalf("att[0] filename = %q, want %q", string(att0Fn), "pic.png") - } - var att0Size uint32 - binary.Read(r, binary.LittleEndian, &att0Size) - if att0Size != 2048 { - t.Fatalf("att[0] size = %d, want 2048", att0Size) - } - - // attachment 1 - att1Flags, _ := r.ReadByte() - if att1Flags != 1 { - t.Fatalf("att[1] flags = %d, want 1", att1Flags) - } - att1TypeID, _ := r.ReadByte() - if att1TypeID != 38 { - t.Fatalf("att[1] type ID = %d, want 38", att1TypeID) - } - att1FnLen, _ := r.ReadByte() - att1Fn := make([]byte, att1FnLen) - r.Read(att1Fn) - if string(att1Fn) != "doc.txt" { - t.Fatalf("att[1] filename = %q, want %q", string(att1Fn), "doc.txt") - } - var att1Size uint32 - binary.Read(r, binary.LittleEndian, &att1Size) - if att1Size != 512 { - t.Fatalf("att[1] size = %d, want 512", att1Size) - } - - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes", r.Len()) - } -} - -func TestEncodeWithCommonMessageType(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: FlagCommonType, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - TypeID: 3, - Type: "application/json", - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - // skip from - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - // skip to count + to[0] - r.ReadByte() - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - // skip timestamp - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - // skip topic - topicLen, _ := r.ReadByte() - r.Read(make([]byte, topicLen)) - - typeID, _ := r.ReadByte() - if typeID != 3 { - t.Fatalf("type ID = %d, want 3", typeID) - } -} - -func TestGetMessageHashUsesDecompressedPayloads(t *testing.T) { - compress := func(data []byte) []byte { - var b bytes.Buffer - w := zlib.NewWriter(&b) - if _, err := w.Write(data); err != nil { - t.Fatalf("zlib write: %v", err) - } - if err := w.Close(); err != nil { - t.Fatalf("zlib close: %v", err) - } - return b.Bytes() - } - - msgPlain := []byte("hello compressed body") - attPlain := []byte("hello compressed attachment") - msgWire := compress(msgPlain) - attWire := compress(attPlain) - - tmpDir := t.TempDir() - msgPath := filepath.Join(tmpDir, "msg.bin") - if err := os.WriteFile(msgPath, msgWire, 0600); err != nil { - t.Fatalf("write msg file: %v", err) - } - attPath := filepath.Join(tmpDir, "att.bin") - if err := os.WriteFile(attPath, attWire, 0600); err != nil { - t.Fatalf("write attachment file: %v", err) - } - - h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Timestamp: 1700000000, - Topic: "t", - Type: "text/plain", - Size: uint32(len(msgWire)), - ExpandedSize: uint32(len(msgPlain)), - Attachments: []FMsgAttachmentHeader{ - {Flags: 1 << 1, Type: "application/octet-stream", Filename: "a.bin", Size: uint32(len(attWire)), ExpandedSize: uint32(len(attPlain)), Filepath: attPath}, - }, - Filepath: msgPath, - } - - got, err := h.GetMessageHash() - if err != nil { - t.Fatalf("GetMessageHash() error: %v", err) - } - - manual := sha256.New() - if _, err := io.Copy(manual, bytes.NewReader(h.Encode())); err != nil { - t.Fatalf("manual header copy: %v", err) - } - if _, err := manual.Write(msgPlain); err != nil { - t.Fatalf("manual msg write: %v", err) - } - if _, err := manual.Write(attPlain); err != nil { - t.Fatalf("manual att write: %v", err) - } - want := manual.Sum(nil) - - if !bytes.Equal(got, want) { - t.Fatalf("message hash mismatch: got %x want %x", got, want) - } -} - -func TestEncodeExpandedSizePresentWhenDeflateSet(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - Size: 50, - ExpandedSize: 200, - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - // skip from - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - // skip to count + to[0] - r.ReadByte() - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - // skip timestamp - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - // skip topic - topicLen, _ := r.ReadByte() - r.Read(make([]byte, topicLen)) - // skip type - typeLen, _ := r.ReadByte() - r.Read(make([]byte, typeLen)) - - // size - var size uint32 - binary.Read(r, binary.LittleEndian, &size) - if size != 50 { - t.Fatalf("size = %d, want 50", size) - } - - // expanded size must be present because FlagDeflate is set - var expandedSize uint32 - if err := binary.Read(r, binary.LittleEndian, &expandedSize); err != nil { - t.Fatalf("reading expanded size: %v", err) - } - if expandedSize != 200 { - t.Fatalf("expanded size = %d, want 200", expandedSize) - } - - // attachment count - attachCount, _ := r.ReadByte() - if attachCount != 0 { - t.Fatalf("attach count = %d, want 0", attachCount) - } - - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes", r.Len()) - } -} - -func TestEncodeNoExpandedSizeWhenDeflateUnset(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - Size: 100, - ExpandedSize: 999, // must NOT appear on wire - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - r.ReadByte() - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - topicLen, _ := r.ReadByte() - r.Read(make([]byte, topicLen)) - typeLen, _ := r.ReadByte() - r.Read(make([]byte, typeLen)) - - var size uint32 - binary.Read(r, binary.LittleEndian, &size) - if size != 100 { - t.Fatalf("size = %d, want 100", size) - } - - // No expanded size field; next byte should be attachment count = 0 - attachCount, _ := r.ReadByte() - if attachCount != 0 { - t.Fatalf("attach count = %d, want 0", attachCount) - } - - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes (expanded size should not be present when deflate unset)", r.Len()) - } -} - -func TestEncodeAttachmentExpandedSizePresentWhenDeflateSet(t *testing.T) { - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 0, - Topic: "", - Type: "text/plain", - Size: 0, - Attachments: []FMsgAttachmentHeader{ - // attachment with zlib-deflate flag (bit 1 = 0b00000010) - {Flags: 1 << 1, Type: "text/plain", Filename: "doc.txt", Size: 60, ExpandedSize: 300}, - }, - } - b := h.Encode() - r := bytes.NewReader(b) - - r.ReadByte() // version - r.ReadByte() // flags - fLen, _ := r.ReadByte() - r.Read(make([]byte, fLen)) - r.ReadByte() - tLen, _ := r.ReadByte() - r.Read(make([]byte, tLen)) - var ts float64 - binary.Read(r, binary.LittleEndian, &ts) - topicLen, _ := r.ReadByte() - r.Read(make([]byte, topicLen)) - typeLen, _ := r.ReadByte() - r.Read(make([]byte, typeLen)) - var msgSize uint32 - binary.Read(r, binary.LittleEndian, &msgSize) - - // attachment count - attachCount, _ := r.ReadByte() - if attachCount != 1 { - t.Fatalf("attach count = %d, want 1", attachCount) - } - - // attachment flags - attFlags, _ := r.ReadByte() - if attFlags != 1<<1 { - t.Fatalf("att flags = %d, want %d", attFlags, 1<<1) - } - // type (length-prefixed) - attTypeLen, _ := r.ReadByte() - r.Read(make([]byte, attTypeLen)) - // filename - attFnLen, _ := r.ReadByte() - r.Read(make([]byte, attFnLen)) - // wire size - var attSize uint32 - binary.Read(r, binary.LittleEndian, &attSize) - if attSize != 60 { - t.Fatalf("att size = %d, want 60", attSize) - } - // expanded size must be present - var attExpandedSize uint32 - if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { - t.Fatalf("reading att expanded size: %v", err) - } - if attExpandedSize != 300 { - t.Fatalf("att expanded size = %d, want 300", attExpandedSize) - } - - if r.Len() != 0 { - t.Fatalf("unexpected %d trailing bytes", r.Len()) - } -} - -func TestHashPayloadRejectsExpandedSizeMismatch(t *testing.T) { - compress := func(data []byte) []byte { - var b bytes.Buffer - w := zlib.NewWriter(&b) - w.Write(data) - w.Close() - return b.Bytes() - } - - plain := []byte("hello world this is test data") - wire := compress(plain) - - tmpDir := t.TempDir() - p := filepath.Join(tmpDir, "data.bin") - if err := os.WriteFile(p, wire, 0600); err != nil { - t.Fatalf("write file: %v", err) - } - - // Correct expanded size should succeed - var dst bytes.Buffer - if err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))); err != nil { - t.Fatalf("hashPayload with correct expanded size: %v", err) - } - - // Wrong expanded size should fail - dst.Reset() - err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))+1) - if err == nil { - t.Fatal("hashPayload with wrong expanded size: expected error, got nil") - } -} diff --git a/src/dns.go b/src/dns.go deleted file mode 100644 index a5e6daa..0000000 --- a/src/dns.go +++ /dev/null @@ -1,148 +0,0 @@ -package main - -import ( - "fmt" - "io" - "log" - "net" - "net/http" - "os" - "strings" - "time" - - "github.com/miekg/dns" -) - -func dnssecRequired() bool { - return os.Getenv("FMSG_REQUIRE_DNSSEC") == "true" -} - -func resolverAuthenticatedData(name string, qtype uint16) (bool, error) { - cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf") - if err != nil { - return false, err - } - - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(name), qtype) - msg.SetEdns0(4096, true) - - client := &dns.Client{Timeout: 5 * time.Second} - var lastErr error - for _, server := range cfg.Servers { - addr := net.JoinHostPort(server, cfg.Port) - resp, _, err := client.Exchange(msg, addr) - if err != nil { - lastErr = err - continue - } - if resp == nil { - lastErr = fmt.Errorf("nil DNS response from %s", addr) - continue - } - if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { - lastErr = fmt.Errorf("dns rcode %d from %s", resp.Rcode, addr) - continue - } - return resp.AuthenticatedData, nil - } - - if lastErr == nil { - lastErr = fmt.Errorf("no DNS resolvers configured") - } - return false, lastErr -} - -// lookupAuthorisedIPs resolves fmsg. for A and AAAA records -func lookupAuthorisedIPs(domain string) ([]net.IP, error) { - fmsgDomain := "fmsg." + domain - ips, err := net.LookupIP(fmsgDomain) - if err != nil { - return nil, fmt.Errorf("DNS lookup for %s failed: %w", fmsgDomain, err) - } - if len(ips) == 0 { - return nil, fmt.Errorf("no A/AAAA records found for %s", fmsgDomain) - } - - if dnssecRequired() { - adA, errA := resolverAuthenticatedData(fmsgDomain, dns.TypeA) - adAAAA, errAAAA := resolverAuthenticatedData(fmsgDomain, dns.TypeAAAA) - if !adA && !adAAAA { - if errA != nil && errAAAA != nil { - return nil, fmt.Errorf("dnssec validation failed for %s: A=%v AAAA=%v", fmsgDomain, errA, errAAAA) - } - return nil, fmt.Errorf("dnssec validation failed for %s: resolver did not set AD bit", fmsgDomain) - } - } - - return ips, nil -} - -// getExternalIP discovers this host's external IP address -func getExternalIP() (net.IP, error) { - services := []string{ - "https://api.ipify.org", - "https://checkip.amazonaws.com", - "https://icanhazip.com", - } - client := &http.Client{Timeout: 10 * time.Second} - var lastErr error - for _, svc := range services { - resp, err := client.Get(svc) - if err != nil { - lastErr = fmt.Errorf("%s: %w", svc, err) - continue - } - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - lastErr = fmt.Errorf("%s: failed to read response: %w", svc, err) - continue - } - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("%s: unexpected status %d", svc, resp.StatusCode) - continue - } - ip := net.ParseIP(strings.TrimSpace(string(body))) - if ip == nil { - lastErr = fmt.Errorf("%s: failed to parse IP from response: %s", svc, string(body)) - continue - } - return ip, nil - } - return nil, fmt.Errorf("all external IP services failed, last error: %w", lastErr) -} - -// verifyDomainIP checks that this host's external IP is present in the -// fmsg. authorised IP set. Panics if not found. -func verifyDomainIP(domain string) { - externalIP, err := getExternalIP() - if err != nil { - log.Panicf("ERROR: failed to get external IP: %s", err) - } - log.Printf("INFO: external IP: %s", externalIP) - - authorisedIPs, err := lookupAuthorisedIPs(domain) - if err != nil { - log.Panicf("ERROR: failed to lookup fmsg.%s: %s", domain, err) - } - - for _, ip := range authorisedIPs { - if externalIP.Equal(ip) { - log.Printf("INFO: external IP %s found in fmsg.%s authorised IPs", externalIP, domain) - return - } - } - - log.Panicf("ERROR: external IP %s not found in fmsg.%s authorised IPs %v", externalIP, domain, authorisedIPs) -} - -// checkDomainIP verifies the external IP is authorised unless -// FMSG_SKIP_DOMAIN_IP_CHECK is set to "true". -func checkDomainIP(domain string) { - if os.Getenv("FMSG_SKIP_DOMAIN_IP_CHECK") == "true" { - log.Println("INFO: skipping domain IP verification (FMSG_SKIP_DOMAIN_IP_CHECK=true)") - return - } - verifyDomainIP(domain) -} diff --git a/src/host.go b/src/host.go deleted file mode 100644 index 5bc6e6e..0000000 --- a/src/host.go +++ /dev/null @@ -1,1735 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "compress/zlib" - "crypto/tls" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "log" - "math" - "mime" - "net" - "net/url" - "os" - "path/filepath" - "strings" - "time" - "unicode" - "unicode/utf8" - - env "github.com/caitlinelfring/go-env-default" - "github.com/joho/godotenv" - "github.com/levenlabs/golib/timeutil" -) - -const ( - InboxDirName = "in" - OutboxDirName = "out" - - // Flag bit assignments per SPEC.md: - // bit 0 = has pid, bit 1 = has add to, bit 2 = common type, bit 3 = important, - // bit 4 = no reply, bit 5 = deflate, bits 6-7 = reserved. - FlagHasPid uint8 = 1 - FlagHasAddTo uint8 = 1 << 1 - FlagCommonType uint8 = 1 << 2 - FlagImportant uint8 = 1 << 3 - FlagNoReply uint8 = 1 << 4 - FlagDeflate uint8 = 1 << 5 - - RejectCodeInvalid uint8 = 1 - RejectCodeUnsupportedVersion uint8 = 2 - RejectCodeUndisclosed uint8 = 3 - RejectCodeTooBig uint8 = 4 - RejectCodeInsufficentResources uint8 = 5 - RejectCodeParentNotFound uint8 = 6 - RejectCodePastTime uint8 = 7 - RejectCodeFutureTime uint8 = 8 - RejectCodeTimeTravel uint8 = 9 - RejectCodeDuplicate uint8 = 10 - AcceptCodeAddTo uint8 = 11 - AcceptCodeContinue uint8 = 64 - AcceptCodeSkipData uint8 = 65 - - RejectCodeUserUnknown uint8 = 100 - RejectCodeUserFull uint8 = 101 - RejectCodeUserNotAccepting uint8 = 102 - RejectCodeUserDuplicate uint8 = 103 - RejectCodeUserUndisclosed uint8 = 105 - - RejectCodeAccept uint8 = 200 - - messageReservedBitsMask uint8 = 0b11000000 - attachmentReservedBitsMask uint8 = 0b11111100 -) - -// responseCodeName returns the human-friendly name for a response code. -func responseCodeName(code uint8) string { - switch code { - case RejectCodeInvalid: - return "invalid" - case RejectCodeUnsupportedVersion: - return "unsupported version" - case RejectCodeUndisclosed: - return "undisclosed" - case RejectCodeTooBig: - return "too big" - case RejectCodeInsufficentResources: - return "insufficient resources" - case RejectCodeParentNotFound: - return "parent not found" - case RejectCodePastTime: - return "past time" - case RejectCodeFutureTime: - return "future time" - case RejectCodeTimeTravel: - return "time travel" - case RejectCodeDuplicate: - return "duplicate" - case AcceptCodeAddTo: - return "accept add to" - case RejectCodeUserUnknown: - return "user unknown" - case RejectCodeUserFull: - return "user full" - case RejectCodeUserNotAccepting: - return "user not accepting" - case RejectCodeUserDuplicate: - return "user duplicate" - case RejectCodeUserUndisclosed: - return "user undisclosed" - case RejectCodeAccept: - return "accept" - default: - return fmt.Sprintf("unknown(%d)", code) - } -} - -var ErrProtocolViolation = errors.New("protocol violation") - -// commonMediaTypes maps common type IDs to their MIME strings per SPEC.md §4. -// IDs 1–64; unmapped IDs must be rejected with code 1 (invalid). -var commonMediaTypes = map[uint8]string{ - 1: "application/epub+zip", 2: "application/gzip", 3: "application/json", 4: "application/msword", - 5: "application/octet-stream", 6: "application/pdf", 7: "application/rtf", 8: "application/vnd.amazon.ebook", - 9: "application/vnd.ms-excel", 10: "application/vnd.ms-powerpoint", - 11: "application/vnd.oasis.opendocument.presentation", 12: "application/vnd.oasis.opendocument.spreadsheet", - 13: "application/vnd.oasis.opendocument.text", - 14: "application/vnd.openxmlformats-officedocument.presentationml.presentation", - 15: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - 16: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - 17: "application/x-tar", 18: "application/xhtml+xml", 19: "application/xml", 20: "application/zip", - 21: "audio/aac", 22: "audio/midi", 23: "audio/mpeg", 24: "audio/ogg", 25: "audio/opus", 26: "audio/vnd.wave", 27: "audio/webm", - 28: "font/otf", 29: "font/ttf", 30: "font/woff", 31: "font/woff2", - 32: "image/apng", 33: "image/avif", 34: "image/bmp", 35: "image/gif", 36: "image/heic", 37: "image/jpeg", 38: "image/png", - 39: "image/svg+xml", 40: "image/tiff", 41: "image/webp", - 42: "model/3mf", 43: "model/gltf-binary", 44: "model/obj", 45: "model/step", 46: "model/stl", 47: "model/vnd.usdz+zip", - 48: "text/calendar", 49: "text/css", 50: "text/csv", 51: "text/html", 52: "text/javascript", 53: "text/markdown", - 54: "text/plain;charset=US-ASCII", 55: "text/plain;charset=UTF-16", 56: "text/plain;charset=UTF-8", 57: "text/vcard", - 58: "video/H264", 59: "video/H265", 60: "video/H266", 61: "video/ogg", 62: "video/VP8", 63: "video/VP9", 64: "video/webm", -} - -// getCommonMediaType returns the MIME type string for a common type ID, or -// empty string + false if the ID is not mapped (should be rejected per spec). -func getCommonMediaType(id uint8) (string, bool) { - s, ok := commonMediaTypes[id] - return s, ok -} - -// getCommonMediaTypeID returns the common type ID for a MIME string. -func getCommonMediaTypeID(mediaType string) (uint8, bool) { - for id, mime := range commonMediaTypes { - if mime == mediaType { - return id, true - } - } - return 0, false -} - -var Port = 4930 - -// The only reason RemotePort would ever be different from Port is when running two fmsg hosts on the same machine so the same port is unavaliable. -var RemotePort = 4930 -var PastTimeDelta float64 = 7 * 24 * 60 * 60 -var FutureTimeDelta float64 = 300 -var MinDownloadRate float64 = 5000 -var MinUploadRate float64 = 5000 -var ReadBufferSize = 1600 -var MaxMessageSize = uint32(1024 * 10) -var MaxExpandedSize = uint32(1024 * 10) -var SkipAuthorisedIPs = false -var TLSInsecureSkipVerify = false -var DataDir = "got on startup" -var Domain = "got on startup" -var IDURI = "got on startup" -var AtRune, _ = utf8.DecodeRuneInString("@") -var MinNetIODeadline = 6 * time.Second - -var serverTLSConfig *tls.Config - -func buildServerTLSConfig() *tls.Config { - certFile := os.Getenv("FMSG_TLS_CERT") - keyFile := os.Getenv("FMSG_TLS_KEY") - if certFile == "" || keyFile == "" { - log.Fatalf("ERROR: FMSG_TLS_CERT and FMSG_TLS_KEY must be set") - } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - log.Fatalf("ERROR: loading TLS certificate: %s", err) - } - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - }, - NextProtos: []string{"fmsg/1"}, - } -} - -func buildClientTLSConfig(serverName string) *tls.Config { - return &tls.Config{ - ServerName: serverName, - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: TLSInsecureSkipVerify, - NextProtos: []string{"fmsg/1"}, - } -} - -// loadEnvConfig reads env vars (after godotenv.Load so .env is picked up). -func loadEnvConfig() { - Port = env.GetIntDefault("FMSG_PORT", 4930) - RemotePort = env.GetIntDefault("FMSG_REMOTE_PORT", 4930) - PastTimeDelta = env.GetFloatDefault("FMSG_MAX_PAST_TIME_DELTA", 7*24*60*60) - FutureTimeDelta = env.GetFloatDefault("FMSG_MAX_FUTURE_TIME_DELTA", 300) - MinDownloadRate = env.GetFloatDefault("FMSG_MIN_DOWNLOAD_RATE", 5000) - MinUploadRate = env.GetFloatDefault("FMSG_MIN_UPLOAD_RATE", 5000) - ReadBufferSize = env.GetIntDefault("FMSG_READ_BUFFER_SIZE", 1600) - MaxMessageSize = uint32(env.GetIntDefault("FMSG_MAX_MSG_SIZE", 1024*10)) - MaxExpandedSize = uint32(env.GetIntDefault("FMSG_MAX_EXPANDED_SIZE", int(MaxMessageSize))) - SkipAuthorisedIPs = os.Getenv("FMSG_SKIP_AUTHORISED_IPS") == "true" - TLSInsecureSkipVerify = os.Getenv("FMSG_TLS_INSECURE_SKIP_VERIFY") == "true" -} - -// Updates DataDir from environment, panics if not a valid directory. -func setDataDir() { - value, hasValue := os.LookupEnv("FMSG_DATA_DIR") - if !hasValue { - log.Panic("ERROR: FMSG_DATA_DIR not set") - } - _, err := os.Stat(value) - if err != nil { - log.Panicf("ERROR: FMSG_DATA_DIR, %s: %s", value, err) - } - DataDir = value -} - -// Updates Domain from environment, panics if not a valid domain. -func setDomain() { - domain, hasValue := os.LookupEnv("FMSG_DOMAIN") - if !hasValue { - log.Panicln("ERROR: FMSG_DOMAIN not set") - } - _, err := net.LookupHost("fmsg." + domain) - if err != nil { - log.Panicf("ERROR: FMSG_DOMAIN, %s: %s\n", domain, err) - } - Domain = domain - - // verify our external IP is in the fmsg authorised IP set - checkDomainIP(domain) -} - -// Updates IDURL from environment, panics if not valid. -func setIDURL() { - rawUrl, hasValue := os.LookupEnv("FMSG_ID_URL") - if !hasValue { - log.Panicln("ERROR: FMSG_ID_URL not set") - } - url, err := url.Parse(rawUrl) - if err != nil { - log.Panicf("ERROR: FMSG_ID_URL not valid, %s: %s", rawUrl, err) - } - _, err = net.LookupHost(url.Hostname()) - if err != nil { - log.Panicf("ERROR: FMSG_ID_URL lookup failed, %s: %s", url, err) - } - // TODO ping URL to verify its up and responding in a timely manner - IDURI = rawUrl - log.Printf("INFO: ID URL: %s", IDURI) -} - -func calcNetIODuration(sizeInBytes int, bytesPerSecond float64) time.Duration { - rate := float64(sizeInBytes) / bytesPerSecond - d := time.Duration(rate * float64(time.Second)) - if d < MinNetIODeadline { - return MinNetIODeadline - } - return d -} - -func isValidUser(s string) bool { - if !utf8.ValidString(s) || len(s) == 0 || len(s) > 64 { - return false - } - - isSpecial := func(r rune) bool { - return r == '-' || r == '_' || r == '.' - } - - runes := []rune(s) - if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { - return false - } - - lastWasSpecial := false - for _, c := range runes { - if unicode.IsLetter(c) || unicode.IsNumber(c) { - lastWasSpecial = false - continue - } - if !isSpecial(c) { - return false - } - if lastWasSpecial { - return false - } - lastWasSpecial = true - } - return true -} - -func isASCIIBytes(b []byte) bool { - for _, c := range b { - if c > 127 { - return false - } - } - return true -} - -func isValidAttachmentFilename(name string) bool { - if !utf8.ValidString(name) || len(name) == 0 || len(name) >= 256 { - return false - } - - isSpecial := func(r rune) bool { - return r == '-' || r == '_' || r == ' ' || r == '.' - } - - runes := []rune(name) - if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { - return false - } - - lastWasSpecial := false - for _, r := range runes { - if unicode.IsLetter(r) || unicode.IsNumber(r) { - lastWasSpecial = false - continue - } - if !isSpecial(r) { - return false - } - if lastWasSpecial { - return false - } - lastWasSpecial = true - } - - return true -} - -func isMessageRetrievable(msg *FMsgHeader) bool { - if msg == nil { - return false - } - if msg.Filepath != "" { - st, err := os.Stat(msg.Filepath) - if err == nil && !st.IsDir() { - return true - } - } - if len(msg.Pid) == 0 { - return false - } - parentID, err := lookupMsgIdByHash(msg.Pid) - if err != nil || parentID == 0 { - return false - } - parentMsg, err := getMsgByID(parentID) - if err != nil { - return false - } - if parentMsg == nil { - return false - } - return isMessageRetrievable(parentMsg) -} - -func isParentParticipant(parent *FMsgHeader, addr *FMsgAddress) bool { - if parent == nil || addr == nil { - return false - } - target := strings.ToLower(addr.ToString()) - if strings.ToLower(parent.From.ToString()) == target { - return true - } - for i := range parent.To { - if strings.ToLower(parent.To[i].ToString()) == target { - return true - } - } - if parent.AddToFrom != nil && strings.ToLower(parent.AddToFrom.ToString()) == target { - return true - } - for i := range parent.AddTo { - if strings.ToLower(parent.AddTo[i].ToString()) == target { - return true - } - } - return false -} - -func isValidDomain(s string) bool { - if len(s) == 0 || len(s) > 253 { - return false - } - if s == "localhost" { - return true - } - labels := strings.Split(s, ".") - if len(labels) < 2 { - return false - } - for _, label := range labels { - if len(label) == 0 || len(label) > 63 { - return false - } - if label[0] == '-' || label[len(label)-1] == '-' { - return false - } - for _, c := range label { - if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-') { - return false - } - } - } - return true -} - -func parseAddress(b []byte) (*FMsgAddress, error) { - if len(b) < 4 { - return nil, fmt.Errorf("invalid address: too short (%d bytes)", len(b)) - } - var addr = &FMsgAddress{} - addrStr := string(b) - firstAt := strings.IndexRune(addrStr, AtRune) - if firstAt == -1 || firstAt != 0 { - return addr, fmt.Errorf("invalid address, must start with @ %s", addr) - } - lastAt := strings.LastIndex(addrStr, "@") - if lastAt == firstAt { - return addr, fmt.Errorf("invalid address, must have second @ %s", addr) - } - addr.User = addrStr[1:lastAt] - if !isValidUser(addr.User) { - return addr, fmt.Errorf("invalid user in address: %s", addr.User) - } - addr.Domain = addrStr[lastAt+1:] - if !isValidDomain(addr.Domain) { - return addr, fmt.Errorf("invalid domain in address: %s", addr.Domain) - } - return addr, nil -} - -// Reads byte slice prefixed with uint8 size from reader supplied -func ReadUInt8Slice(r io.Reader) ([]byte, error) { - var size byte - err := binary.Read(r, binary.LittleEndian, &size) - if err != nil { - return nil, err - } - return io.ReadAll(io.LimitReader(r, int64(size))) -} - -func readAddress(r io.Reader) (*FMsgAddress, error) { - slice, err := ReadUInt8Slice(r) - if err != nil { - return nil, err - } - return parseAddress(slice) -} - -func handleChallenge(c net.Conn, r *bufio.Reader) error { - hashSlice, err := io.ReadAll(io.LimitReader(r, 32)) - if err != nil { - return err - } - hash := *(*[32]byte)(hashSlice) - log.Printf("INFO: CHALLENGE <-- %s", hex.EncodeToString(hashSlice)) - - // Verify the challenger's IP is the Host-B IP registered for this message - // (§10.5 step 2). An unrecognised hash OR a mismatched IP both → TERMINATE. - remoteIP, _, _ := net.SplitHostPort(c.RemoteAddr().String()) - header, exists := lookupOutgoing(hash, remoteIP) - if !exists { - return fmt.Errorf("challenge for unknown message: %s, from: %s", hex.EncodeToString(hashSlice), c.RemoteAddr().String()) - } - msgHash, err := header.GetMessageHash() - if err != nil { - return err - } - if _, err := c.Write(msgHash); err != nil { - return err - } - return nil -} - -func rejectAccept(c net.Conn, codes []byte) error { - _, err := c.Write(codes) - return err -} - -func sendCode(c net.Conn, code uint8) error { - return rejectAccept(c, []byte{code}) -} - -func validateMessageFlags(c net.Conn, flags uint8) error { - if flags&messageReservedBitsMask != 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("reserved message flag bits set: %#08b", flags) - } - return nil -} - -func validateAttachmentFlags(c net.Conn, flags uint8) error { - if flags&attachmentReservedBitsMask != 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("reserved attachment flag bits set: %#08b", flags) - } - return nil -} - -func hasDomainRecipient(addrs []FMsgAddress, domain string) bool { - for _, addr := range addrs { - if strings.EqualFold(addr.Domain, domain) { - return true - } - } - return false -} - -func determineSenderDomain(h *FMsgHeader) string { - if len(h.AddTo) > 0 && h.AddToFrom != nil { - return h.AddToFrom.Domain - } - return h.From.Domain -} - -func verifySenderIP(c net.Conn, senderDomain string) error { - if SkipAuthorisedIPs { - return nil - } - - remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String()) - if err != nil { - log.Printf("WARN: failed to parse remote address for DNS check: %s", err) - return fmt.Errorf("DNS verification failed") - } - - remoteIP := net.ParseIP(remoteHost) - if remoteIP == nil { - log.Printf("WARN: failed to parse remote IP: %s", remoteHost) - return fmt.Errorf("DNS verification failed") - } - - authorisedIPs, err := lookupAuthorisedIPs(senderDomain) - if err != nil { - log.Printf("WARN: DNS lookup failed for fmsg.%s: %s", senderDomain, err) - return fmt.Errorf("DNS verification failed") - } - - for _, ip := range authorisedIPs { - if remoteIP.Equal(ip) { - return nil - } - } - - log.Printf("WARN: remote IP %s not in authorised IPs for fmsg.%s", remoteIP.String(), senderDomain) - return fmt.Errorf("DNS verification failed") -} - -func handleAddToPath(c net.Conn, h *FMsgHeader) (*FMsgHeader, error) { - if len(h.AddTo) == 0 { - return h, nil - } - - addToHasOurDomain := hasDomainRecipient(h.AddTo, Domain) - - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return h, err - } - - if parentID == 0 { - h.InitialResponseCode = AcceptCodeContinue - return h, nil - } - - parentMsg, err := getMsgByID(parentID) - if err != nil { - return h, err - } - if parentMsg == nil || !isMessageRetrievable(parentMsg) { - h.InitialResponseCode = AcceptCodeContinue - return h, nil - } - - if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { - if err := sendCode(c, RejectCodeTimeTravel); err != nil { - return h, err - } - return h, fmt.Errorf("add-to: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) - } - - if addToHasOurDomain { - h.InitialResponseCode = AcceptCodeSkipData - return h, nil - } - - h.Filepath = parentMsg.Filepath - for i := range h.Attachments { - if i < len(parentMsg.Attachments) { - h.Attachments[i].Filepath = parentMsg.Attachments[i].Filepath - } - } - h.InitialResponseCode = AcceptCodeAddTo - return h, nil -} - -func validatePidReplyPath(c net.Conn, h *FMsgHeader) error { - if len(h.AddTo) != 0 || h.Flags&FlagHasPid == 0 { - return nil - } - - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return err - } - if parentID == 0 { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent not found for pid %s", hex.EncodeToString(h.Pid)) - } - - parentMsg, err := getMsgByID(parentID) - if err != nil { - return err - } - if parentMsg == nil { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent message not found by ID %d", parentID) - } - if !isMessageRetrievable(parentMsg) { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent is not retrievable for msg %d", parentID) - } - - if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { - if err := sendCode(c, RejectCodeTimeTravel); err != nil { - return err - } - return fmt.Errorf("pid reply: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) - } - if !isParentParticipant(parentMsg, &h.From) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("pid reply: sender %s was not a participant of parent", h.From.ToString()) - } - - return nil -} - -func readVersionOrChallenge(c net.Conn, r *bufio.Reader, h *FMsgHeader) (bool, error) { - v, err := r.ReadByte() - if err != nil { - return false, err - } - if v >= 129 { - challengeVersion := 256 - int(v) - if challengeVersion == 1 { - return true, handleChallenge(c, r) - } - if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { - log.Printf("WARN: failed to send unsupported version response: %s", err) - } - return false, fmt.Errorf("unsupported challenge version: %d", challengeVersion) - } - if v != 1 { - if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { - log.Printf("WARN: failed to send unsupported version response: %s", err) - } - return false, fmt.Errorf("unsupported message version: %d", v) - } - h.Version = v - return false, nil -} - -func readToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader) (map[string]bool, error) { - num, err := r.ReadByte() - if err != nil { - return nil, err - } - if num == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return nil, err - } - return nil, fmt.Errorf("to count must be >= 1") - } - seen := make(map[string]bool) - for num > 0 { - addr, err := readAddress(r) - if err != nil { - return nil, err - } - key := strings.ToLower(addr.ToString()) - if seen[key] { - return nil, fmt.Errorf("duplicate recipient address: %s", addr.ToString()) - } - seen[key] = true - h.To = append(h.To, *addr) - num-- - } - return seen, nil -} - -func readAddToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader, seen map[string]bool) error { - if h.Flags&FlagHasAddTo == 0 { - return nil - } - if h.Flags&FlagHasPid == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add to exists but pid does not") - } - - addToFrom, err := readAddress(r) - if err != nil { - if err2 := sendCode(c, RejectCodeInvalid); err2 != nil { - return err2 - } - return fmt.Errorf("reading add-to-from address: %w", err) - } - - addToFromKey := strings.ToLower(addToFrom.ToString()) - fromKey := strings.ToLower(h.From.ToString()) - inFromOrTo := fromKey == addToFromKey - if !inFromOrTo { - for _, toAddr := range h.To { - if strings.ToLower(toAddr.ToString()) == addToFromKey { - inFromOrTo = true - break - } - } - } - if !inFromOrTo { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add-to-from (%s) not in from or to", addToFrom.ToString()) - } - h.AddToFrom = addToFrom - - addToCount, err := r.ReadByte() - if err != nil { - return err - } - if addToCount == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add to flag set but count is 0") - } - - addToSeen := make(map[string]bool) - for addToCount > 0 { - addr, err := readAddress(r) - if err != nil { - return err - } - key := strings.ToLower(addr.ToString()) - if addToSeen[key] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("duplicate recipient address in add to: %s", addr.ToString()) - } - addToSeen[key] = true - if seen[key] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add-to address already in to: %s", addr.ToString()) - } - h.AddTo = append(h.AddTo, *addr) - addToCount-- - } - - return nil -} - -func readAndValidateTimestamp(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - if err := binary.Read(r, binary.LittleEndian, &h.Timestamp); err != nil { - return err - } - now := timeutil.TimestampNow().Float64() - delta := now - h.Timestamp - if PastTimeDelta > 0 && delta > PastTimeDelta { - if err := sendCode(c, RejectCodePastTime); err != nil { - return err - } - return fmt.Errorf("message timestamp: %f too far in past, delta: %fs", h.Timestamp, delta) - } - if FutureTimeDelta > 0 && delta < 0 && math.Abs(delta) > FutureTimeDelta { - if err := sendCode(c, RejectCodeFutureTime); err != nil { - return err - } - return fmt.Errorf("message timestamp: %f too far in future, delta: %fs", h.Timestamp, delta) - } - return nil -} - -func readMessageType(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - if h.Flags&FlagCommonType != 0 { - typeID, err := r.ReadByte() - if err != nil { - return err - } - mtype, ok := getCommonMediaType(typeID) - if !ok { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("unmapped common type ID: %d", typeID) - } - h.TypeID = typeID - h.Type = mtype - return nil - } - - mime, err := ReadUInt8Slice(r) - if err != nil { - return err - } - if !isASCIIBytes(mime) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("message media type must be US-ASCII") - } - h.Type = string(mime) - return nil -} - -func readAttachmentType(c net.Conn, r *bufio.Reader, flags uint8) (string, uint8, error) { - if flags&(1<<0) != 0 { - typeID, err := r.ReadByte() - if err != nil { - return "", 0, err - } - mtype, ok := getCommonMediaType(typeID) - if !ok { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return "", 0, err - } - return "", 0, fmt.Errorf("unmapped attachment common type ID: %d", typeID) - } - return mtype, typeID, nil - } - - typeBytes, err := ReadUInt8Slice(r) - if err != nil { - return "", 0, err - } - if !isASCIIBytes(typeBytes) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return "", 0, err - } - return "", 0, fmt.Errorf("attachment media type must be US-ASCII") - } - return string(typeBytes), 0, nil -} - -func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - var attachCount uint8 - if err := binary.Read(r, binary.LittleEndian, &attachCount); err != nil { - return err - } - - totalSize := h.Size - // When message is compressed, expanded size comes from the header field. - // When uncompressed, the wire size IS the expanded size. - var totalExpandedSize uint32 - if h.Flags&FlagDeflate != 0 { - totalExpandedSize = h.ExpandedSize - } else { - totalExpandedSize = h.Size - } - filenameSeen := make(map[string]bool) - for i := uint8(0); i < attachCount; i++ { - attFlags, err := r.ReadByte() - if err != nil { - return err - } - if err := validateAttachmentFlags(c, attFlags); err != nil { - return err - } - - attType, attTypeID, err := readAttachmentType(c, r, attFlags) - if err != nil { - return err - } - - filenameBytes, err := ReadUInt8Slice(r) - if err != nil { - return err - } - filename := string(filenameBytes) - if !isValidAttachmentFilename(filename) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("invalid attachment filename: %s", filename) - } - filenameKey := strings.ToLower(filename) - if filenameSeen[filenameKey] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("duplicate attachment filename: %s", filename) - } - filenameSeen[filenameKey] = true - - var attSize uint32 - if err := binary.Read(r, binary.LittleEndian, &attSize); err != nil { - return err - } - - // read attachment expanded size — present iff attachment zlib-deflate flag set (§5) - var attExpandedSize uint32 - if attFlags&(1<<1) != 0 { - if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { - return err - } - totalExpandedSize += attExpandedSize - } else { - // uncompressed: expanded size equals wire size - totalExpandedSize += attSize - } - - h.Attachments = append(h.Attachments, FMsgAttachmentHeader{ - Flags: attFlags, - TypeID: attTypeID, - Type: attType, - Filename: filename, - Size: attSize, - ExpandedSize: attExpandedSize, - }) - totalSize += attSize - } - - if totalSize > MaxMessageSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return err - } - return fmt.Errorf("total message size %d exceeds max %d", totalSize, MaxMessageSize) - } - - if totalExpandedSize > MaxExpandedSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return err - } - return fmt.Errorf("total expanded size %d exceeds MAX_EXPANDED_SIZE %d", totalExpandedSize, MaxExpandedSize) - } - - return nil -} - -func readHeader(c net.Conn) (*FMsgHeader, *bufio.Reader, error) { - r := bufio.NewReaderSize(c, ReadBufferSize) - var h = &FMsgHeader{InitialResponseCode: AcceptCodeContinue} - - d := calcNetIODuration(66000, MinDownloadRate) // max possible header size - c.SetReadDeadline(time.Now().Add(d)) - - handled, err := readVersionOrChallenge(c, r, h) - if err != nil { - if handled { - return nil, r, err - } - return h, r, err - } - if handled { - return nil, r, nil - } - - // read flags - flags, err := r.ReadByte() - if err != nil { - return h, r, err - } - h.Flags = flags - if err := validateMessageFlags(c, flags); err != nil { - return h, r, err - } - - // read pid if any - if flags&FlagHasPid == 1 { - pid, err := io.ReadAll(io.LimitReader(r, 32)) - if err != nil { - return h, r, err - } - h.Pid = make([]byte, 32) - copy(h.Pid, pid) - } - - // read from address - from, err := readAddress(r) - if err != nil { - return h, r, err - } - - h.From = *from - - seen, err := readToRecipients(c, r, h) - if err != nil { - return h, r, err - } - - if err := readAddToRecipients(c, r, h, seen); err != nil { - return h, r, err - } - - if err := readAndValidateTimestamp(c, r, h); err != nil { - return h, r, err - } - - // read topic — only present when pid is NOT set (first message in a thread) - if flags&FlagHasPid == 0 { - topic, err := ReadUInt8Slice(r) - if err != nil { - return h, r, err - } - h.Topic = string(topic) - } - - if err := readMessageType(c, r, h); err != nil { - return h, r, err - } - - // read message size - if err := binary.Read(r, binary.LittleEndian, &h.Size); err != nil { - return h, r, err - } - // read expanded size — present iff zlib-deflate flag is set (§2 field 12) - if h.Flags&FlagDeflate != 0 { - if err := binary.Read(r, binary.LittleEndian, &h.ExpandedSize); err != nil { - return h, r, err - } - if h.ExpandedSize > MaxExpandedSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return h, r, err - } - return h, r, fmt.Errorf("expanded size %d exceeds MAX_EXPANDED_SIZE %d", h.ExpandedSize, MaxExpandedSize) - } - } - // Size check is deferred until attachment headers are parsed (see below) - - if err := readAttachmentHeaders(c, r, h); err != nil { - return h, r, err - } - - log.Printf("INFO: <-- MSG\n%s", h) - - if !hasDomainRecipient(h.To, Domain) && !hasDomainRecipient(h.AddTo, Domain) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return h, r, err - } - return h, r, fmt.Errorf("no recipients for domain %s", Domain) - } - - if err := verifySenderIP(c, determineSenderDomain(h)); err != nil { - return nil, r, err - } - - h, err = handleAddToPath(c, h) - if err != nil { - return h, r, err - } - if h == nil { - return nil, r, nil - } - - if err := validatePidReplyPath(c, h); err != nil { - return h, r, err - } - - return h, r, nil -} - -// Sends CHALLENGE request to sender, receiving and storing the challenge hash. -// DNS verification of the remote IP is performed during header exchange (readHeader). -// TODO [Spec step 2]: The spec defines challenge modes (NEVER, ALWAYS, -// HAS_NOT_PARTICIPATED, DIFFERENT_DOMAIN) as implementation choices. -// Currently defaults to ALWAYS. Implement configurable challenge mode. -func challenge(conn net.Conn, h *FMsgHeader, senderDomain string) error { - - // Connection 2 MUST target the same IP as Connection 1 (spec 2.1). - remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) - if err != nil { - return fmt.Errorf("failed to parse remote address for challenge: %w", err) - } - conn2, err := tls.Dial("tcp", net.JoinHostPort(remoteHost, fmt.Sprintf("%d", RemotePort)), buildClientTLSConfig("fmsg."+senderDomain)) - if err != nil { - return err - } - version := uint8(255) - if err := binary.Write(conn2, binary.LittleEndian, version); err != nil { - return err - } - hash := h.GetHeaderHash() - log.Printf("INFO: --> CHALLENGE\t%s\n", hex.EncodeToString(hash)) - if _, err := conn2.Write(hash); err != nil { - return err - } - - // read challenge response - resp, err := io.ReadAll(io.LimitReader(conn2, 32)) - if err != nil { - return err - } - if len(resp) != 32 { - return fmt.Errorf("challenge response size %d, expected 32", len(resp)) - } - copy(h.ChallengeHash[:], resp) - h.ChallengeCompleted = true - log.Printf("INFO: <-- CHALLENGE RESP\t%s\n", hex.EncodeToString(resp)) - - // gracefully close 2nd connection - if err := conn2.Close(); err != nil { - return err - } - - return nil -} - -func validateMsgRecvForAddr(h *FMsgHeader, addr *FMsgAddress, msgHash []byte) (code uint8, err error) { - duplicate, err := hasAddrReceivedMsgHash(msgHash, addr) - if err != nil { - return RejectCodeUserUndisclosed, err - } - if duplicate { - return RejectCodeUserDuplicate, nil - } - - detail, err := getAddressDetail(addr) - if err != nil { - return RejectCodeUserUndisclosed, err - } - if detail == nil { - return RejectCodeUserUnknown, nil - } - - // check user accepting new - if !detail.AcceptingNew { - return RejectCodeUserNotAccepting, nil - } - - // check user limits - if detail.LimitRecvCountPer1d > -1 && detail.RecvCountPer1d+1 > detail.LimitRecvCountPer1d { - log.Printf("WARN: Message rejected: RecvCountPer1d would exceed LimitRecvCountPer1d %d", detail.LimitRecvCountPer1d) - return RejectCodeUserFull, nil - } - if detail.LimitRecvSizePer1d > -1 && detail.RecvSizePer1d+int64(h.Size) > detail.LimitRecvSizePer1d { - log.Printf("WARN: Message rejected: RecvSizePer1d would exceed LimitRecvSizePer1d %d", detail.LimitRecvSizePer1d) - return RejectCodeUserFull, nil - } - if detail.LimitRecvSizeTotal > -1 && detail.RecvSizeTotal+int64(h.Size) > detail.LimitRecvSizeTotal { - log.Printf("WARN: Message rejected: RecvSizeTotal would exceed LimitRecvSizeTotal %d", detail.LimitRecvSizeTotal) - return RejectCodeUserFull, nil - } - - return RejectCodeAccept, nil -} - -// uniqueFilepath generates a unique file path in the given directory, -// appending a counter suffix if the base name already exists. -func uniqueFilepath(dir string, timestamp uint32, ext string) string { - base := fmt.Sprintf("%d", timestamp) - fp := filepath.Join(dir, base+ext) - if _, err := os.Stat(fp); os.IsNotExist(err) { - return fp - } - for i := 1; ; i++ { - fp = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, i, ext)) - if _, err := os.Stat(fp); os.IsNotExist(err) { - return fp - } - } -} - -func localRecipients(h *FMsgHeader) []FMsgAddress { - addrs := make([]FMsgAddress, 0, len(h.To)+len(h.AddTo)) - for _, addr := range h.To { - if strings.EqualFold(addr.Domain, Domain) { - addrs = append(addrs, addr) - } - } - for _, addr := range h.AddTo { - if strings.EqualFold(addr.Domain, Domain) { - addrs = append(addrs, addr) - } - } - return addrs -} - -func allLocalRecipientsHaveMessageHash(msgHash []byte, addrs []FMsgAddress) (bool, error) { - if len(addrs) == 0 { - return false, nil - } - for i := range addrs { - duplicate, err := hasAddrReceivedMsgHash(msgHash, &addrs[i]) - if err != nil { - return false, err - } - if !duplicate { - return false, nil - } - } - return true, nil -} - -func markAllCodes(codes []byte, code uint8) { - for i := range codes { - codes[i] = code - } -} - -func prepareMessageData(r io.Reader, h *FMsgHeader, skipData bool) ([]string, error) { - if skipData { - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return nil, err - } - if parentID == 0 { - return nil, fmt.Errorf("%w code 65 requires stored parent for pid %s", ErrProtocolViolation, hex.EncodeToString(h.Pid)) - } - parentMsg, err := getMsgByID(parentID) - if err != nil { - return nil, err - } - if parentMsg == nil || parentMsg.Filepath == "" { - return nil, fmt.Errorf("%w code 65 parent data unavailable for msg %d", ErrProtocolViolation, parentID) - } - h.Filepath = parentMsg.Filepath - return nil, nil - } - - createdPaths := make([]string, 0, 1+len(h.Attachments)) - - fd, err := os.CreateTemp("", "fmsg-download-*") - if err != nil { - return nil, err - } - - if _, err := io.CopyN(fd, r, int64(h.Size)); err != nil { - fd.Close() - _ = os.Remove(fd.Name()) - return nil, err - } - if err := fd.Close(); err != nil { - _ = os.Remove(fd.Name()) - return nil, err - } - - h.Filepath = fd.Name() - createdPaths = append(createdPaths, fd.Name()) - - for i := range h.Attachments { - afd, err := os.CreateTemp("", "fmsg-attachment-*") - if err != nil { - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - - if _, err := io.CopyN(afd, r, int64(h.Attachments[i].Size)); err != nil { - afd.Close() - _ = os.Remove(afd.Name()) - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - if err := afd.Close(); err != nil { - _ = os.Remove(afd.Name()) - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - h.Attachments[i].Filepath = afd.Name() - createdPaths = append(createdPaths, afd.Name()) - } - - return createdPaths, nil -} - -func cleanupFiles(paths []string) { - for _, path := range paths { - if path == "" { - continue - } - _ = os.Remove(path) - } -} - -func copyMessagePayload(src *os.File, dstPath string, compressed bool, wireSize uint32) error { - if _, err := src.Seek(0, io.SeekStart); err != nil { - return err - } - - fd2, err := os.Create(dstPath) - if err != nil { - return err - } - - var copyErr error - if compressed { - lr := io.LimitReader(src, int64(wireSize)) - zr, err := zlib.NewReader(lr) - if err != nil { - fd2.Close() - _ = os.Remove(dstPath) - return err - } - _, copyErr = io.Copy(fd2, zr) - _ = zr.Close() - } else { - _, copyErr = io.CopyN(fd2, src, int64(wireSize)) - } - if err := fd2.Close(); err != nil { - return err - } - - if copyErr != nil { - _ = os.Remove(dstPath) - return copyErr - } - return nil -} - -func uniqueAttachmentPath(dir string, timestamp uint32, idx int, filename string) string { - ext := filepath.Ext(filename) - base := fmt.Sprintf("%d_att_%d", timestamp, idx) - p := filepath.Join(dir, base+ext) - if _, err := os.Stat(p); os.IsNotExist(err) { - return p - } - for n := 1; ; n++ { - p = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, n, ext)) - if _, err := os.Stat(p); os.IsNotExist(err) { - return p - } - } -} - -func persistAttachmentPayloads(h *FMsgHeader, dirpath string) error { - for i := range h.Attachments { - a := &h.Attachments[i] - src, err := os.Open(a.Filepath) - if err != nil { - return err - } - dstPath := uniqueAttachmentPath(dirpath, uint32(h.Timestamp), i, a.Filename) - compressed := a.Flags&(1<<1) != 0 - err = copyMessagePayload(src, dstPath, compressed, a.Size) - src.Close() - if err != nil { - return err - } - a.Filepath = dstPath - } - return nil -} - -func storeAcceptedMessage(h *FMsgHeader, codes []byte, acceptedTo []FMsgAddress, acceptedAddTo []FMsgAddress, primaryFilepath string) bool { - if len(acceptedTo) == 0 && len(acceptedAddTo) == 0 { - return false - } - - origTo := h.To - origAddTo := h.AddTo - h.To = acceptedTo - h.AddTo = acceptedAddTo - h.Filepath = primaryFilepath - if err := storeMsgDetail(h); err != nil { - log.Printf("ERROR: storing message: %s", err) - h.To = origTo - h.AddTo = origAddTo - for i := range codes { - if codes[i] == RejectCodeAccept { - codes[i] = RejectCodeUndisclosed - } - } - return false - } - - h.To = origTo - h.AddTo = origAddTo - allAccepted := append(acceptedTo, acceptedAddTo...) - for i := range allAccepted { - if err := postMsgStatRecv(&allAccepted[i], h.Timestamp, int(h.Size)); err != nil { - log.Printf("WARN: Failed to post msg recv stat: %s", err) - } - } - return true -} - -func downloadMessage(c net.Conn, r io.Reader, h *FMsgHeader, skipData bool) error { - addrs := localRecipients(h) - if len(addrs) == 0 { - return fmt.Errorf("%w our domain: %s, not in recipient list", ErrProtocolViolation, Domain) - } - codes := make([]byte, len(addrs)) - - createdPaths, err := prepareMessageData(r, h, skipData) - if err != nil { - return err - } - cleanupOnReturn := !skipData - defer func() { - if cleanupOnReturn { - cleanupFiles(createdPaths) - } - }() - - // verify hash matches challenge response when challenge was completed - msgHash, err := h.GetMessageHash() - if err != nil { - return err - } - if h.ChallengeCompleted && !bytes.Equal(h.ChallengeHash[:], msgHash) { - challengeHashStr := hex.EncodeToString(h.ChallengeHash[:]) - actualHashStr := hex.EncodeToString(msgHash) - return fmt.Errorf("%w actual hash: %s mismatch challenge response: %s", ErrProtocolViolation, actualHashStr, challengeHashStr) - } - - // pid/add-to validation is handled during header exchange in readHeader(). - - // determine file extension from MIME type - exts, _ := mime.ExtensionsByType(h.Type) - var ext string - if exts == nil { - ext = ".unknown" - } else { - ext = exts[0] - } - - src, err := os.Open(h.Filepath) - if err != nil { - return err - } - defer src.Close() - - // validate each recipient and copy message for accepted ones - // Build a set of add-to addresses for later classification - addToSet := make(map[string]bool) - for _, addr := range h.AddTo { - addToSet[strings.ToLower(addr.ToString())] = true - } - acceptedTo := []FMsgAddress{} - acceptedAddTo := []FMsgAddress{} - var primaryFilepath string - for i, addr := range addrs { - code, err := validateMsgRecvForAddr(h, &addr, msgHash) - if err != nil { - return err - } - if code != RejectCodeAccept { - log.Printf("WARN: Rejected message to: %s: %s (%d)", addr.ToString(), responseCodeName(code), code) - codes[i] = code - continue - } - - // copy to recipient's directory - dirpath := filepath.Join(DataDir, addr.Domain, addr.User, InboxDirName) - if err := os.MkdirAll(dirpath, 0750); err != nil { - return err - } - - fp := uniqueFilepath(dirpath, uint32(h.Timestamp), ext) - if err := copyMessagePayload(src, fp, h.Flags&FlagDeflate != 0, h.Size); err != nil { - log.Printf("ERROR: copying downloaded message from: %s, to: %s", h.Filepath, fp) - codes[i] = RejectCodeUndisclosed - continue - } - - codes[i] = RejectCodeAccept - if addToSet[strings.ToLower(addr.ToString())] { - acceptedAddTo = append(acceptedAddTo, addr) - } else { - acceptedTo = append(acceptedTo, addr) - } - if primaryFilepath == "" { - primaryFilepath = fp - if err := persistAttachmentPayloads(h, filepath.Dir(primaryFilepath)); err != nil { - log.Printf("ERROR: copying attachment payloads for message storage: %s", err) - codes[i] = RejectCodeUndisclosed - primaryFilepath = "" - acceptedTo = acceptedTo[:0] - acceptedAddTo = acceptedAddTo[:0] - continue - } - } - } - - stored := storeAcceptedMessage(h, codes, acceptedTo, acceptedAddTo, primaryFilepath) - if stored { - cleanupOnReturn = false - } - - return rejectAccept(c, codes) -} - -// resolvePostChallengeCode determines the initial response code to send after -// the optional challenge (§10.4). Code 11 (accept add-to) is returned as-is -// since it has no local recipients to duplicate-check. For the skip-data (65) -// and continue (64) paths, a completed challenge with all-local-duplicate -// produces code 10 (duplicate) instead. -func resolvePostChallengeCode(initialCode uint8, challengeCompleted bool, allLocalDup bool) uint8 { - if initialCode == AcceptCodeAddTo { - return AcceptCodeAddTo - } - if challengeCompleted && allLocalDup { - return RejectCodeDuplicate - } - if initialCode == AcceptCodeSkipData { - return AcceptCodeSkipData - } - return AcceptCodeContinue -} - -func abortConn(c net.Conn) { - if tcp, ok := c.(*net.TCPConn); ok { - tcp.SetLinger(0) - } - _ = c.Close() -} - -type responseTrackingConn struct { - net.Conn - wroteResponse bool -} - -func (c *responseTrackingConn) Write(b []byte) (int, error) { - n, err := c.Conn.Write(b) - if n > 0 { - c.wroteResponse = true - } - return n, err -} - -func handleConn(c net.Conn) { - defer func() { - if r := recover(); r != nil { - log.Printf("ERROR: Recovered in handleConn: %v", r) - } - }() - - log.Printf("INFO: Connection from: %s\n", c.RemoteAddr().String()) - tc := &responseTrackingConn{Conn: c} - - // read header - header, r, err := readHeader(tc) - if err != nil { - log.Printf("WARN: reading header from, %s: %s", c.RemoteAddr().String(), err) - if tc.wroteResponse { - _ = c.Close() - return - } - abortConn(c) - return - } - - // if no header AND no error this was a challenge thats been handeled - if header == nil { - c.Close() - return - } - - if err := challenge(tc, header, determineSenderDomain(header)); err != nil { - log.Printf("ERROR: Challenge failed to, %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - - // §10.4: Determine initial response code after optional challenge. - // Code 11 (add-to, no local recipients) does not need a dup check. - // Codes 65 and 64 both require a dup check when challenge was completed. - allLocalDup := false - if header.ChallengeCompleted && header.InitialResponseCode != AcceptCodeAddTo { - addrs := localRecipients(header) - var err error - allLocalDup, err = allLocalRecipientsHaveMessageHash(header.ChallengeHash[:], addrs) - if err != nil { - log.Printf("ERROR: duplicate check failed for %s: %s", c.RemoteAddr().String(), err) - if err := sendCode(c, RejectCodeUndisclosed); err != nil { - abortConn(c) - return - } - _ = c.Close() - return - } - } - - code := resolvePostChallengeCode(header.InitialResponseCode, header.ChallengeCompleted, allLocalDup) - skipData := false - - switch code { - case AcceptCodeAddTo: - // No local add-to recipients; store header and respond code 11, close. - if err := storeMsgHeaderOnly(header); err != nil { - log.Printf("ERROR: storing add-to header: %s", err) - if err := sendCode(c, RejectCodeUndisclosed); err != nil { - abortConn(c) - return - } - _ = c.Close() - return - } - if err := sendCode(c, AcceptCodeAddTo); err != nil { - log.Printf("ERROR: failed sending code 11 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - log.Printf("INFO: additional recipients received (code 11) for pid %s", hex.EncodeToString(header.Pid)) - c.Close() - return - case RejectCodeDuplicate: - if err := sendCode(c, RejectCodeDuplicate); err != nil { - log.Printf("ERROR: failed sending code 10 to %s: %s", c.RemoteAddr().String(), err) - } - c.Close() - return - case AcceptCodeSkipData: - if err := sendCode(c, AcceptCodeSkipData); err != nil { - log.Printf("ERROR: failed sending code 65 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - skipData = true - log.Printf("INFO: sent code 65 (skip data) to %s", c.RemoteAddr().String()) - default: - if err := sendCode(c, AcceptCodeContinue); err != nil { - log.Printf("ERROR: failed sending code 64 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - log.Printf("INFO: sent code 64 (continue) to %s", c.RemoteAddr().String()) - } - - // store message - deadlineBytes := int(header.Size) - if skipData { - deadlineBytes = 1 - } - c.SetReadDeadline(time.Now().Add(calcNetIODuration(deadlineBytes, MinDownloadRate))) - if err := downloadMessage(c, r, header, skipData); err != nil { - // if error was a protocal violation, abort; otherise let sender know there was an internal error - log.Printf("ERROR: Download failed from, %s: %s", c.RemoteAddr().String(), err) - if errors.Is(err, ErrProtocolViolation) { - abortConn(c) - return - } else { - _ = sendCode(c, RejectCodeUndisclosed) - } - } - - // gracefully close 1st connection - c.Close() -} - -func main() { - - initOutgoing() - - // load environment variables from .env file if present - if err := godotenv.Load(); err != nil { - log.Printf("INFO: Could not load .env file: %v", err) - } - - // read env config (must be after godotenv.Load) - loadEnvConfig() - - // determine listen address from args - listenAddress := "127.0.0.1" - for _, arg := range os.Args[1:] { - listenAddress = arg - } - - // initalize database - err := testDb() - if err != nil { - log.Fatalf("ERROR: connecting to database: %s\n", err) - } - - // set DataDir, Domain and IDURL from env - setDataDir() - setDomain() - setIDURL() - - // load TLS configuration (must be after loadEnvConfig for FMSG_TLS_INSECURE_SKIP_VERIFY) - serverTLSConfig = buildServerTLSConfig() - - // start sender in background (small delay so listener is ready first) - go func() { - time.Sleep(1 * time.Second) - startSender() - }() - - // start listening - addr := fmt.Sprintf("%s:%d", listenAddress, Port) - ln, err := tls.Listen("tcp", addr, serverTLSConfig) - if err != nil { - log.Fatal(err) - } - log.Printf("INFO: Ready to receive on %s\n", addr) - for { - conn, err := ln.Accept() - if err != nil { - log.Printf("ERROR: Accept connection from %s returned: %s\n", ln.Addr().String(), err) - } else { - go handleConn(conn) - } - } - -} diff --git a/src/host_test.go b/src/host_test.go deleted file mode 100644 index 48c54be..0000000 --- a/src/host_test.go +++ /dev/null @@ -1,622 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "encoding/binary" - "net" - "testing" - "time" -) - -type testAddr string - -func (a testAddr) Network() string { return "tcp" } -func (a testAddr) String() string { return string(a) } - -type testConn struct { - bytes.Buffer -} - -func (c *testConn) Read(b []byte) (int, error) { return 0, nil } -func (c *testConn) Write(b []byte) (int, error) { return c.Buffer.Write(b) } -func (c *testConn) Close() error { return nil } -func (c *testConn) LocalAddr() net.Addr { return testAddr("127.0.0.1:1000") } -func (c *testConn) RemoteAddr() net.Addr { return testAddr("127.0.0.1:2000") } -func (c *testConn) SetDeadline(t time.Time) error { return nil } -func (c *testConn) SetReadDeadline(t time.Time) error { return nil } -func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } - -func TestIsValidUser(t *testing.T) { - valid := []string{"alice", "Bob", "a-b", "a_b", "a.b", "user123", "A", "u\u00f1icode", "\u7528\u62371"} - for _, u := range valid { - if !isValidUser(u) { - t.Errorf("isValidUser(%q) = false, want true", u) - } - } - - invalid := []string{"", " ", "a b", "a@b", "a/b", string(make([]byte, 65)), "-alice", "alice-", "a..b", "a-_b"} - for _, u := range invalid { - if isValidUser(u) { - t.Errorf("isValidUser(%q) = true, want false", u) - } - } -} - -func TestIsValidDomain(t *testing.T) { - valid := []string{"example.com", "a.b.c", "foo-bar.com", "localhost"} - for _, d := range valid { - if !isValidDomain(d) { - t.Errorf("isValidDomain(%q) = false, want true", d) - } - } - - invalid := []string{ - "", - "nodot", // no dots and not localhost - ".leading.dot", // empty label - "trailing.", // empty label - "-start.com", // label starts with hyphen - "end-.com", // label ends with hyphen - "has space.com", - } - for _, d := range invalid { - if isValidDomain(d) { - t.Errorf("isValidDomain(%q) = true, want false", d) - } - } -} - -func TestParseAddress(t *testing.T) { - tests := []struct { - input string - wantErr bool - user string - domain string - }{ - {"@alice@example.com", false, "alice", "example.com"}, - {"@Bob@EXAMPLE.COM", false, "Bob", "EXAMPLE.COM"}, - {"@a-b.c@x.y.z", false, "a-b.c", "x.y.z"}, - // errors - {"alice@example.com", true, "", ""}, // missing leading @ - {"@alice", true, "", ""}, // missing second @ - {"@", true, "", ""}, // too short - {"ab", true, "", ""}, // too short - {"@@example.com", true, "", ""}, // empty user - {"@alice@", true, "", ""}, // empty domain (not valid) - {"@alice@nodot", true, "", ""}, // domain with no dot (not localhost) - } - for _, tt := range tests { - addr, err := parseAddress([]byte(tt.input)) - if tt.wantErr { - if err == nil { - t.Errorf("parseAddress(%q) = nil error, want error", tt.input) - } - continue - } - if err != nil { - t.Errorf("parseAddress(%q) error: %v", tt.input, err) - continue - } - if addr.User != tt.user || addr.Domain != tt.domain { - t.Errorf("parseAddress(%q) = {%q, %q}, want {%q, %q}", tt.input, addr.User, addr.Domain, tt.user, tt.domain) - } - } -} - -func TestReadUInt8Slice(t *testing.T) { - // Build a buffer: uint8 length = 5, then "hello" - var buf bytes.Buffer - buf.WriteByte(5) - buf.WriteString("hello") - // Extra trailing bytes should not be consumed - buf.WriteString("extra") - - slice, err := ReadUInt8Slice(&buf) - if err != nil { - t.Fatalf("ReadUInt8Slice error: %v", err) - } - if string(slice) != "hello" { - t.Fatalf("ReadUInt8Slice = %q, want %q", string(slice), "hello") - } - // "extra" should remain - rest := make([]byte, 5) - n, _ := buf.Read(rest) - if string(rest[:n]) != "extra" { - t.Fatalf("remaining bytes = %q, want %q", string(rest[:n]), "extra") - } -} - -func TestReadUInt8SliceEmpty(t *testing.T) { - var buf bytes.Buffer - buf.WriteByte(0) // zero-length slice - - slice, err := ReadUInt8Slice(&buf) - if err != nil { - t.Fatalf("ReadUInt8Slice error: %v", err) - } - if len(slice) != 0 { - t.Fatalf("expected empty slice, got len %d", len(slice)) - } -} - -func TestCalcNetIODuration(t *testing.T) { - // Small sizes should return MinNetIODeadline - d := calcNetIODuration(100, 5000) - if d < MinNetIODeadline { - t.Fatalf("calcNetIODuration(100, 5000) = %v, want >= %v", d, MinNetIODeadline) - } - - // Large sizes should exceed MinNetIODeadline - d = calcNetIODuration(1_000_000, 5000) - expected := time.Duration(float64(1_000_000) / 5000 * float64(time.Second)) // 200s - if d != expected { - t.Fatalf("calcNetIODuration(1000000, 5000) = %v, want %v", d, expected) - } -} - -func TestResponseCodeName(t *testing.T) { - tests := []struct { - code uint8 - want string - }{ - {RejectCodeInvalid, "invalid"}, - {RejectCodeUnsupportedVersion, "unsupported version"}, - {RejectCodeUndisclosed, "undisclosed"}, - {RejectCodeTooBig, "too big"}, - {RejectCodeInsufficentResources, "insufficient resources"}, - {RejectCodeParentNotFound, "parent not found"}, - {RejectCodePastTime, "past time"}, - {RejectCodeFutureTime, "future time"}, - {RejectCodeTimeTravel, "time travel"}, - {RejectCodeDuplicate, "duplicate"}, - {AcceptCodeAddTo, "accept add to"}, - {RejectCodeUserUnknown, "user unknown"}, - {RejectCodeUserFull, "user full"}, - {RejectCodeUserNotAccepting, "user not accepting"}, - {RejectCodeUserDuplicate, "user duplicate"}, - {RejectCodeUserUndisclosed, "user undisclosed"}, - {RejectCodeAccept, "accept"}, - {99, "unknown(99)"}, - } - for _, tt := range tests { - got := responseCodeName(tt.code) - if got != tt.want { - t.Errorf("responseCodeName(%d) = %q, want %q", tt.code, got, tt.want) - } - } -} - -func TestPerRecipientDuplicateAndUndisclosedCodeValues(t *testing.T) { - if RejectCodeUserDuplicate != 103 { - t.Fatalf("RejectCodeUserDuplicate = %d, want 103", RejectCodeUserDuplicate) - } - if RejectCodeUserUndisclosed != 105 { - t.Fatalf("RejectCodeUserUndisclosed = %d, want 105", RejectCodeUserUndisclosed) - } -} - -func TestFlagConstants(t *testing.T) { - // Verify flag bit assignments match SPEC.md - if FlagHasPid != 1 { - t.Errorf("FlagHasPid = %d, want 1 (bit 0)", FlagHasPid) - } - if FlagHasAddTo != 2 { - t.Errorf("FlagHasAddTo = %d, want 2 (bit 1)", FlagHasAddTo) - } - if FlagCommonType != 4 { - t.Errorf("FlagCommonType = %d, want 4 (bit 2)", FlagCommonType) - } - if FlagImportant != 8 { - t.Errorf("FlagImportant = %d, want 8 (bit 3)", FlagImportant) - } - if FlagNoReply != 16 { - t.Errorf("FlagNoReply = %d, want 16 (bit 4)", FlagNoReply) - } - if FlagDeflate != 32 { - t.Errorf("FlagDeflate = %d, want 32 (bit 5)", FlagDeflate) - } -} - -func encodeUInt8String(t *testing.T, s string) []byte { - t.Helper() - if len(s) > 255 { - t.Fatalf("string too long for uint8 prefix: %d", len(s)) - } - b := []byte{byte(len(s))} - b = append(b, []byte(s)...) - return b -} - -func TestHasDomainRecipient(t *testing.T) { - addrs := []FMsgAddress{ - {User: "alice", Domain: "example.com"}, - {User: "bob", Domain: "other.org"}, - } - if !hasDomainRecipient(addrs, "EXAMPLE.COM") { - t.Fatalf("expected domain match") - } - if hasDomainRecipient(addrs, "missing.test") { - t.Fatalf("did not expect domain match") - } -} - -func TestDetermineSenderDomain(t *testing.T) { - h := &FMsgHeader{ - From: FMsgAddress{User: "alice", Domain: "from.example"}, - } - if got := determineSenderDomain(h); got != "from.example" { - t.Fatalf("determineSenderDomain() = %q, want %q", got, "from.example") - } - - h.AddTo = []FMsgAddress{{User: "new", Domain: "to.example"}} - h.AddToFrom = &FMsgAddress{User: "bob", Domain: "sender.example"} - if got := determineSenderDomain(h); got != "sender.example" { - t.Fatalf("determineSenderDomain() = %q, want %q", got, "sender.example") - } -} - -func TestReadToRecipients(t *testing.T) { - b := []byte{2} - b = append(b, encodeUInt8String(t, "@alice@example.com")...) - b = append(b, encodeUInt8String(t, "@bob@example.com")...) - - h := &FMsgHeader{} - seen, err := readToRecipients(nil, bufio.NewReader(bytes.NewReader(b)), h) - if err != nil { - t.Fatalf("readToRecipients returned error: %v", err) - } - if len(h.To) != 2 { - t.Fatalf("len(h.To) = %d, want 2", len(h.To)) - } - if !seen["@alice@example.com"] || !seen["@bob@example.com"] { - t.Fatalf("seen map missing expected recipients: %#v", seen) - } -} - -func TestReadAddToRecipients(t *testing.T) { - h := &FMsgHeader{ - Flags: FlagHasPid | FlagHasAddTo, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "example.com"}}, - } - seen := map[string]bool{"@bob@example.com": true} - - b := []byte{} - b = append(b, encodeUInt8String(t, "@alice@example.com")...) // add-to-from - b = append(b, 1) // add-to count - b = append(b, encodeUInt8String(t, "@carol@example.com")...) - - err := readAddToRecipients(nil, bufio.NewReader(bytes.NewReader(b)), h, seen) - if err != nil { - t.Fatalf("readAddToRecipients returned error: %v", err) - } - if h.AddToFrom == nil || h.AddToFrom.ToString() != "@alice@example.com" { - t.Fatalf("unexpected AddToFrom: %+v", h.AddToFrom) - } - if len(h.AddTo) != 1 || h.AddTo[0].ToString() != "@carol@example.com" { - t.Fatalf("unexpected AddTo: %+v", h.AddTo) - } -} - -func TestReadMessageType(t *testing.T) { - hCommon := &FMsgHeader{Flags: FlagCommonType} - if err := readMessageType(nil, bufio.NewReader(bytes.NewReader([]byte{3})), hCommon); err != nil { - t.Fatalf("readMessageType(common) error: %v", err) - } - if hCommon.TypeID != 3 { - t.Fatalf("common type ID = %d, want 3", hCommon.TypeID) - } - if hCommon.Type != "application/json" { - t.Fatalf("common type = %q, want %q", hCommon.Type, "application/json") - } - - hText := &FMsgHeader{Flags: 0} - b := encodeUInt8String(t, "text/plain") - if err := readMessageType(nil, bufio.NewReader(bytes.NewReader(b)), hText); err != nil { - t.Fatalf("readMessageType(string) error: %v", err) - } - if hText.Type != "text/plain" { - t.Fatalf("string type = %q, want %q", hText.Type, "text/plain") - } -} - -func TestReadAttachmentHeaders(t *testing.T) { - origMax := MaxMessageSize - MaxMessageSize = 1024 - t.Cleanup(func() { - MaxMessageSize = origMax - }) - - h := &FMsgHeader{Size: 10} - b := []byte{1} // attachment count - b = append(b, 0) // attachment flags (no common type) - b = append(b, encodeUInt8String(t, "text/plain")...) - b = append(b, encodeUInt8String(t, "file.txt")...) - - var sz [4]byte - binary.LittleEndian.PutUint32(sz[:], 12) - b = append(b, sz[:]...) - - err := readAttachmentHeaders(nil, bufio.NewReader(bytes.NewReader(b)), h) - if err != nil { - t.Fatalf("readAttachmentHeaders returned error: %v", err) - } - if len(h.Attachments) != 1 { - t.Fatalf("len(h.Attachments) = %d, want 1", len(h.Attachments)) - } - att := h.Attachments[0] - if att.TypeID != 0 { - t.Fatalf("attachment type ID = %d, want 0 for non-common", att.TypeID) - } - if att.Type != "text/plain" || att.Filename != "file.txt" || att.Size != 12 { - t.Fatalf("unexpected attachment parsed: %+v", att) - } -} - -func TestReadAddToRecipientsRejectsWhenPidMissing(t *testing.T) { - h := &FMsgHeader{Flags: FlagHasAddTo} - c := &testConn{} - - err := readAddToRecipients(c, bufio.NewReader(bytes.NewReader(nil)), h, map[string]bool{}) - if err == nil { - t.Fatalf("expected error when add-to flag is set without pid") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadAddToRecipientsRejectsDuplicateAddTo(t *testing.T) { - h := &FMsgHeader{ - Flags: FlagHasPid | FlagHasAddTo, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "example.com"}}, - } - c := &testConn{} - seen := map[string]bool{"@bob@example.com": true} - - b := []byte{} - b = append(b, encodeUInt8String(t, "@alice@example.com")...) // add-to-from - b = append(b, 2) // add-to count - b = append(b, encodeUInt8String(t, "@carol@example.com")...) - b = append(b, encodeUInt8String(t, "@carol@example.com")...) - - err := readAddToRecipients(c, bufio.NewReader(bytes.NewReader(b)), h, seen) - if err == nil { - t.Fatalf("expected duplicate add-to error") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadMessageTypeRejectsUnknownCommonType(t *testing.T) { - h := &FMsgHeader{Flags: FlagCommonType} - c := &testConn{} - - err := readMessageType(c, bufio.NewReader(bytes.NewReader([]byte{200})), h) - if err == nil { - t.Fatalf("expected error for unknown common type") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadMessageTypeRejectsNonASCIIStringType(t *testing.T) { - h := &FMsgHeader{Flags: 0} - c := &testConn{} - - b := encodeUInt8String(t, "text/\u03c0lain") - err := readMessageType(c, bufio.NewReader(bytes.NewReader(b)), h) - if err == nil { - t.Fatalf("expected error for non-ASCII message type") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadAttachmentTypeRejectsNonASCIIStringType(t *testing.T) { - c := &testConn{} - b := encodeUInt8String(t, "text/\u03c0lain") - - _, _, err := readAttachmentType(c, bufio.NewReader(bytes.NewReader(b)), 0) - if err == nil { - t.Fatalf("expected error for non-ASCII attachment type") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadAttachmentHeadersRejectsInvalidFilename(t *testing.T) { - origMax := MaxMessageSize - MaxMessageSize = 1024 - t.Cleanup(func() { - MaxMessageSize = origMax - }) - - h := &FMsgHeader{Size: 10} - c := &testConn{} - b := []byte{1} - b = append(b, 0) - b = append(b, encodeUInt8String(t, "text/plain")...) - b = append(b, encodeUInt8String(t, "bad..name")...) // invalid: consecutive special chars - - var sz [4]byte - binary.LittleEndian.PutUint32(sz[:], 12) - b = append(b, sz[:]...) - - err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) - if err == nil { - t.Fatalf("expected error for invalid attachment filename") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadAttachmentHeadersRejectsTooBig(t *testing.T) { - origMax := MaxMessageSize - MaxMessageSize = 20 - t.Cleanup(func() { - MaxMessageSize = origMax - }) - - h := &FMsgHeader{Size: 15} - c := &testConn{} - b := []byte{1} - b = append(b, 0) - b = append(b, encodeUInt8String(t, "text/plain")...) - b = append(b, encodeUInt8String(t, "file.txt")...) - - var sz [4]byte - binary.LittleEndian.PutUint32(sz[:], 10) - b = append(b, sz[:]...) - - err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) - if err == nil { - t.Fatalf("expected size overflow error") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeTooBig { - t.Fatalf("expected reject code %d, got %v", RejectCodeTooBig, got) - } -} - -func TestValidateMessageFlagsRejectsReservedBits(t *testing.T) { - c := &testConn{} - err := validateMessageFlags(c, 1<<6) - if err == nil { - t.Fatalf("expected error for reserved message flag bit") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestReadAttachmentHeadersRejectsReservedAttachmentBits(t *testing.T) { - h := &FMsgHeader{Size: 0} - c := &testConn{} - - // attachment count=1, then attachment flags with reserved bit 2 set - b := []byte{1, 1 << 2} - err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) - if err == nil { - t.Fatalf("expected error for reserved attachment flag bits") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { - t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) - } -} - -func TestResolvePostChallengeCode(t *testing.T) { - tests := []struct { - name string - initialCode uint8 - challengeCompleted bool - allLocalDup bool - want uint8 - }{ - // Add-to (code 11) path — never overridden by dup check. - {"add-to no challenge", AcceptCodeAddTo, false, false, AcceptCodeAddTo}, - {"add-to challenge no dup", AcceptCodeAddTo, true, false, AcceptCodeAddTo}, - {"add-to challenge all dup", AcceptCodeAddTo, true, true, AcceptCodeAddTo}, - - // Continue (code 64) path — dup check yields code 10 when all dup. - {"continue no challenge", AcceptCodeContinue, false, false, AcceptCodeContinue}, - {"continue challenge no dup", AcceptCodeContinue, true, false, AcceptCodeContinue}, - {"continue challenge all dup", AcceptCodeContinue, true, true, RejectCodeDuplicate}, - - // Skip-data (code 65) path — dup check yields code 10 when all dup. - {"skip-data no challenge", AcceptCodeSkipData, false, false, AcceptCodeSkipData}, - {"skip-data challenge no dup", AcceptCodeSkipData, true, false, AcceptCodeSkipData}, - {"skip-data challenge all dup", AcceptCodeSkipData, true, true, RejectCodeDuplicate}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := resolvePostChallengeCode(tt.initialCode, tt.challengeCompleted, tt.allLocalDup) - if got != tt.want { - t.Errorf("resolvePostChallengeCode(%d, %v, %v) = %d (%s), want %d (%s)", - tt.initialCode, tt.challengeCompleted, tt.allLocalDup, - got, responseCodeName(got), tt.want, responseCodeName(tt.want)) - } - }) - } -} - -func TestReadAttachmentHeadersReadsExpandedSizeForCompressedAttachment(t *testing.T) { - origMax := MaxMessageSize - origExpanded := MaxExpandedSize - MaxMessageSize = 1024 - MaxExpandedSize = 1024 - t.Cleanup(func() { - MaxMessageSize = origMax - MaxExpandedSize = origExpanded - }) - - h := &FMsgHeader{Size: 0} - b := []byte{1} // 1 attachment - b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) - b = append(b, encodeUInt8String(t, "text/plain")...) - b = append(b, encodeUInt8String(t, "file.txt")...) - - var wireSize [4]byte - binary.LittleEndian.PutUint32(wireSize[:], 50) - b = append(b, wireSize[:]...) - - var expandedSize [4]byte - binary.LittleEndian.PutUint32(expandedSize[:], 200) - b = append(b, expandedSize[:]...) - - err := readAttachmentHeaders(nil, bufio.NewReader(bytes.NewReader(b)), h) - if err != nil { - t.Fatalf("readAttachmentHeaders returned error: %v", err) - } - if len(h.Attachments) != 1 { - t.Fatalf("len(h.Attachments) = %d, want 1", len(h.Attachments)) - } - att := h.Attachments[0] - if att.Size != 50 { - t.Fatalf("att.Size = %d, want 50", att.Size) - } - if att.ExpandedSize != 200 { - t.Fatalf("att.ExpandedSize = %d, want 200", att.ExpandedSize) - } -} - -func TestReadAttachmentHeadersRejectsExpandedSizeExceedsMax(t *testing.T) { - origMax := MaxMessageSize - origExpanded := MaxExpandedSize - MaxMessageSize = 1024 - MaxExpandedSize = 100 - t.Cleanup(func() { - MaxMessageSize = origMax - MaxExpandedSize = origExpanded - }) - - h := &FMsgHeader{Size: 0} - c := &testConn{} - b := []byte{1} // 1 attachment - b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) - b = append(b, encodeUInt8String(t, "text/plain")...) - b = append(b, encodeUInt8String(t, "file.txt")...) - - var wireSize [4]byte - binary.LittleEndian.PutUint32(wireSize[:], 50) - b = append(b, wireSize[:]...) - - // expanded size exceeds MaxExpandedSize=100 - var expandedSize [4]byte - binary.LittleEndian.PutUint32(expandedSize[:], 200) - b = append(b, expandedSize[:]...) - - err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) - if err == nil { - t.Fatalf("expected error when expanded size exceeds max") - } - if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeTooBig { - t.Fatalf("expected reject code %d, got %v", RejectCodeTooBig, got) - } -} diff --git a/src/id.go b/src/id.go deleted file mode 100644 index 32b0bab..0000000 --- a/src/id.go +++ /dev/null @@ -1,92 +0,0 @@ -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/url" -) - -type AddressDetail struct { - Address string `json:"address"` - DisplayName string `json:"displayName"` - AcceptingNew bool `json:"acceptingNew"` - LimitRecvSizeTotal int64 `json:"limitRecvSizeTotal"` - LimitRecvSizePerMsg int64 `json:"limitRecvSizePerMsg"` - LimitRecvSizePer1d int64 `json:"limitRecvSizePer1d"` - LimitRecvCountPer1d int64 `json:"limitRecvCountPer1d"` - LimitSendSizeTotal int64 `json:"limitSendSizeTotal"` - LimitSendSizePerMsg int64 `json:"limitSendSizePerMsg"` - LimitSendSizePer1d int64 `json:"limitSendSizePer1d"` - LimitSendCountPer1d int64 `json:"limitSendCountPer1d"` - RecvSizeTotal int64 `json:"recvSizeTotal"` - RecvSizePer1d int64 `json:"recvSizePer1d"` - RecvCountPer1d int64 `json:"recvCountPer1d"` - SendSizeTotal int64 `json:"sendSizeTotal"` - SendSizePer1d int64 `json:"sendSizePer1d"` - SendCountPer1d int64 `json:"sendCountPer1d"` - Tags []string `json:"tags"` -} - -// Returns pointer to an AddressDetail populated by querying fmsg Id standard at FMSG_ID_URL for -// address supplied. If the address is not found returns nil, nil. -func getAddressDetail(addr *FMsgAddress) (*AddressDetail, error) { - uri := IDURI + "/fmsgid/" + url.PathEscape(addr.ToString()) - resp, err := http.Get(uri) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return nil, nil - } - - var detail AddressDetail - err = json.NewDecoder(resp.Body).Decode(&detail) - if err != nil { - return nil, err - } - - return &detail, nil -} - -func postMsgStatSend(addr *FMsgAddress, timestamp float64, size int) error { - return postMsgStat(addr, timestamp, size, true) -} - -func postMsgStatRecv(addr *FMsgAddress, timestamp float64, size int) error { - return postMsgStat(addr, timestamp, size, false) -} - -func postMsgStat(addr *FMsgAddress, timestamp float64, size int, isSending bool) error { - var part string - if isSending { - part = "send" - } else { - part = "recv" - } - uri := fmt.Sprintf("%s/fmsgid/%s", IDURI, part) - - payload := map[string]interface{}{ - "address": addr.ToString(), - "ts": timestamp, - "size": size} - jsonPayload, err := json.Marshal(payload) - if err != nil { - return err - } - - resp, err := http.Post(uri, "application/json", bytes.NewBuffer(jsonPayload)) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("POST %s returned %d", uri, resp.StatusCode) - } - - return nil -} diff --git a/src/outgoing.go b/src/outgoing.go deleted file mode 100644 index 242811f..0000000 --- a/src/outgoing.go +++ /dev/null @@ -1,66 +0,0 @@ -package main - -import "sync" - -// outgoingEntry tracks an in-flight outgoing message header together with -// the set of Host-B IPs currently being serviced for that message. -// The IP set is used to validate incoming challenges (§10.5 step 2). -type outgoingEntry struct { - header *FMsgHeader - ips map[string]struct{} -} - -// outgoingMap indexes in-flight message headers by their header hash. -// All access is synchronised via outgoingMu. -var outgoingMap map[[32]byte]*outgoingEntry -var outgoingMu sync.RWMutex - -func initOutgoing() { - outgoingMap = make(map[[32]byte]*outgoingEntry) -} - -// registerOutgoing records hash → (header, ip) so challenge handlers can look -// it up. Multiple IPs may be registered for the same hash when the same message -// is being concurrently delivered to different domains (§10.2 step 2). -func registerOutgoing(hash [32]byte, h *FMsgHeader, ip string) { - outgoingMu.Lock() - e, ok := outgoingMap[hash] - if !ok { - e = &outgoingEntry{header: h, ips: make(map[string]struct{})} - outgoingMap[hash] = e - } - e.ips[ip] = struct{}{} - outgoingMu.Unlock() -} - -// lookupOutgoing returns the header for hash iff ip is a registered Host-B IP -// for that entry. Returns (nil, false) if the hash is unknown or ip is not in -// the registered set (§10.5 step 2). -func lookupOutgoing(hash [32]byte, ip string) (*FMsgHeader, bool) { - outgoingMu.RLock() - e, ok := outgoingMap[hash] - if !ok { - outgoingMu.RUnlock() - return nil, false - } - _, ipOK := e.ips[ip] - h := e.header - outgoingMu.RUnlock() - if !ipOK { - return nil, false - } - return h, true -} - -// removeOutgoingIP removes ip from the entry's IP set. When the set becomes -// empty the map entry is deleted entirely (§10.2 step 7). -func removeOutgoingIP(hash [32]byte, ip string) { - outgoingMu.Lock() - if e, ok := outgoingMap[hash]; ok { - delete(e.ips, ip) - if len(e.ips) == 0 { - delete(outgoingMap, hash) - } - } - outgoingMu.Unlock() -} diff --git a/src/outgoing_test.go b/src/outgoing_test.go deleted file mode 100644 index af75412..0000000 --- a/src/outgoing_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package main - -import ( - "testing" -) - -func TestOutgoingMapOperations(t *testing.T) { - initOutgoing() - - h := &FMsgHeader{ - Version: 1, - Flags: 0, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Topic: "test", - Type: "text/plain", - } - - var hash [32]byte - copy(hash[:], h.GetHeaderHash()) - - const ip = "1.2.3.4" - - // Lookup before register should fail - _, ok := lookupOutgoing(hash, ip) - if ok { - t.Fatal("lookupOutgoing found entry before register") - } - - // Register - registerOutgoing(hash, h, ip) - - // Lookup with correct IP should succeed - got, ok := lookupOutgoing(hash, ip) - if !ok { - t.Fatal("lookupOutgoing failed after register") - } - if got != h { - t.Fatal("lookupOutgoing returned different pointer") - } - - // Lookup with wrong IP should fail - _, ok = lookupOutgoing(hash, "9.9.9.9") - if ok { - t.Fatal("lookupOutgoing succeeded with wrong IP") - } - - // Remove IP — entry should be gone - removeOutgoingIP(hash, ip) - - _, ok = lookupOutgoing(hash, ip) - if ok { - t.Fatal("lookupOutgoing found entry after removeOutgoingIP") - } -} - -func TestOutgoingMapMultipleIPs(t *testing.T) { - initOutgoing() - - h := &FMsgHeader{ - Version: 1, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 1.0, - Type: "text/plain", - } - - var hash [32]byte - copy(hash[:], h.GetHeaderHash()) - - registerOutgoing(hash, h, "1.1.1.1") - registerOutgoing(hash, h, "2.2.2.2") - - // Both IPs should resolve - for _, ip := range []string{"1.1.1.1", "2.2.2.2"} { - if _, ok := lookupOutgoing(hash, ip); !ok { - t.Errorf("expected lookup to succeed for IP %s", ip) - } - } - - // Removing first IP still leaves entry for second - removeOutgoingIP(hash, "1.1.1.1") - if _, ok := lookupOutgoing(hash, "1.1.1.1"); ok { - t.Error("1.1.1.1 still present after remove") - } - if _, ok := lookupOutgoing(hash, "2.2.2.2"); !ok { - t.Error("2.2.2.2 missing after removing 1.1.1.1") - } - - // Removing last IP deletes the entry - removeOutgoingIP(hash, "2.2.2.2") - if _, ok := lookupOutgoing(hash, "2.2.2.2"); ok { - t.Error("entry still present after removing last IP") - } -} - -func TestOutgoingMapMultipleEntries(t *testing.T) { - initOutgoing() - - h1 := &FMsgHeader{ - Version: 1, - From: FMsgAddress{User: "a", Domain: "b.com"}, - To: []FMsgAddress{{User: "c", Domain: "d.com"}}, - Timestamp: 1.0, - Type: "text/plain", - } - h2 := &FMsgHeader{ - Version: 1, - From: FMsgAddress{User: "x", Domain: "y.com"}, - To: []FMsgAddress{{User: "z", Domain: "w.com"}}, - Timestamp: 2.0, - Type: "text/plain", - } - - var hash1, hash2 [32]byte - copy(hash1[:], h1.GetHeaderHash()) - copy(hash2[:], h2.GetHeaderHash()) - - registerOutgoing(hash1, h1, "1.1.1.1") - registerOutgoing(hash2, h2, "2.2.2.2") - - got1, ok1 := lookupOutgoing(hash1, "1.1.1.1") - got2, ok2 := lookupOutgoing(hash2, "2.2.2.2") - - if !ok1 || got1 != h1 { - t.Error("failed to look up h1") - } - if !ok2 || got2 != h2 { - t.Error("failed to look up h2") - } - - // Remove one, other should remain - removeOutgoingIP(hash1, "1.1.1.1") - _, ok1 = lookupOutgoing(hash1, "1.1.1.1") - _, ok2 = lookupOutgoing(hash2, "2.2.2.2") - if ok1 { - t.Error("h1 still present after remove") - } - if !ok2 { - t.Error("h2 missing after removing h1") - } -} diff --git a/src/sender.go b/src/sender.go deleted file mode 100644 index a33a1e8..0000000 --- a/src/sender.go +++ /dev/null @@ -1,628 +0,0 @@ -package main - -import ( - "crypto/tls" - "database/sql" - "encoding/hex" - "fmt" - "io" - "log" - "net" - "os" - "strings" - "time" - - env "github.com/caitlinelfring/go-env-default" - "github.com/levenlabs/golib/timeutil" - "github.com/lib/pq" -) - -var RetryInterval float64 = 20 -var RetryMaxAge float64 = 86400 -var PollInterval = 10 -var MaxConcurrentSend = 1024 - -// localResponseCodeNoResponse is stored only in the database; it is not an fmsg protocol response code. -const localResponseCodeNoResponse = -1 - -var retryableResponseCodes = []int16{ - int16(localResponseCodeNoResponse), - int16(RejectCodeUndisclosed), - int16(RejectCodeInsufficentResources), - int16(RejectCodeUserFull), -} - -func loadSenderEnvConfig() { - RetryInterval = env.GetFloatDefault("FMSG_RETRY_INTERVAL", 20) - RetryMaxAge = env.GetFloatDefault("FMSG_RETRY_MAX_AGE", 86400) - PollInterval = env.GetIntDefault("FMSG_POLL_INTERVAL", 10) - MaxConcurrentSend = env.GetIntDefault("FMSG_MAX_CONCURRENT_SEND", 1024) -} - -// pendingTarget identifies a (message, domain) pair that needs delivery. -type pendingTarget struct { - MsgID int64 - Domain string -} - -// findPendingTargets discovers (msg_id, domain) pairs with undelivered, -// retryable recipients. This is a lightweight read-only query — row-level -// locks are acquired per-delivery in deliverMessage. -func findPendingTargets() ([]pendingTarget, error) { - db, err := sql.Open("postgres", "") - if err != nil { - return nil, err - } - defer db.Close() - - now := timeutil.TimestampNow().Float64() - - // query both msg_to and msg_add_to for pending targets - rows, err := db.Query(` - SELECT mt.msg_id, mt.addr - FROM msg_to mt - INNER JOIN msg m ON m.id = mt.msg_id - WHERE mt.time_delivered IS NULL - AND m.time_sent IS NOT NULL - AND (mt.response_code IS NULL OR mt.response_code = ANY($4)) - AND (mt.time_last_attempt IS NULL OR ($1 - mt.time_last_attempt) > LEAST($2 * POWER(2.0, GREATEST(mt.attempt_count - 1, 0)::float), $3)) - AND ($1 - m.time_sent) < $3 - UNION ALL - SELECT mat.msg_id, mat.addr - FROM msg_add_to mat - INNER JOIN msg m ON m.id = mat.msg_id - WHERE mat.time_delivered IS NULL - AND m.time_sent IS NOT NULL - AND (mat.response_code IS NULL OR mat.response_code = ANY($4)) - AND (mat.time_last_attempt IS NULL OR ($1 - mat.time_last_attempt) > LEAST($2 * POWER(2.0, GREATEST(mat.attempt_count - 1, 0)::float), $3)) - AND ($1 - m.time_sent) < $3 - `, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) - if err != nil { - return nil, err - } - defer rows.Close() - - type key struct { - msgID int64 - domain string - } - seen := make(map[key]bool) - var targets []pendingTarget - - for rows.Next() { - var msgID int64 - var addr string - if err := rows.Scan(&msgID, &addr); err != nil { - return nil, err - } - lastAt := strings.LastIndex(addr, "@") - if lastAt == -1 { - continue - } - domain := addr[lastAt+1:] - if strings.EqualFold(domain, Domain) { - continue // local domain — no remote delivery needed - } - k := key{msgID, domain} - if !seen[k] { - seen[k] = true - targets = append(targets, pendingTarget{MsgID: msgID, Domain: domain}) - } - } - return targets, rows.Err() -} - -// sendMsgData transmits the message body then all attachment payloads on conn. -func sendMsgData(conn net.Conn, h *FMsgHeader) error { - fd, err := os.Open(h.Filepath) - if err != nil { - return fmt.Errorf("opening data file %s: %w", h.Filepath, err) - } - defer fd.Close() - - conn.SetWriteDeadline(time.Now().Add(calcNetIODuration(int(h.Size), MinUploadRate))) - if _, err := io.CopyN(conn, fd, int64(h.Size)); err != nil { - return fmt.Errorf("sending data: %w", err) - } - for _, att := range h.Attachments { - af, err := os.Open(att.Filepath) - if err != nil { - return fmt.Errorf("opening attachment %s: %w", att.Filename, err) - } - _, copyErr := io.CopyN(conn, af, int64(att.Size)) - af.Close() - if copyErr != nil { - return fmt.Errorf("sending attachment %s: %w", att.Filename, copyErr) - } - } - return nil -} - -// updateRecipient records a delivery outcome for one address in table. -// Deliveries set time_delivered; failures set time_last_attempt and increment -// attempt_count to drive exponential back-off on subsequent retries. -func updateRecipient(tx *sql.Tx, table, addr string, msgID int64, now float64, code int, delivered bool) { - var err error - if delivered { - _, err = tx.Exec(fmt.Sprintf(` - UPDATE %s SET time_delivered = $1, response_code = $2 - WHERE msg_id = $3 AND addr = $4 - `, table), now, code, msgID, addr) - } else { - _, err = tx.Exec(fmt.Sprintf(` - UPDATE %s SET time_last_attempt = $1, response_code = $2, - attempt_count = attempt_count + 1 - WHERE msg_id = $3 AND addr = $4 - `, table), now, code, msgID, addr) - } - if err != nil { - log.Printf("ERROR: sender: update recipient %s: %s", addr, err) - } -} - -// updateAllLocked applies the same outcome to every locked to and add-to address. -func updateAllLocked(tx *sql.Tx, lockedAddrs, lockedAddToAddrs []string, msgID int64, now float64, code int, delivered bool) { - for _, a := range lockedAddrs { - updateRecipient(tx, "msg_to", a, msgID, now, code, delivered) - } - for _, a := range lockedAddToAddrs { - updateRecipient(tx, "msg_add_to", a, msgID, now, code, delivered) - } -} - -// commitOrLog commits the transaction and marks it as committed. -func commitOrLog(tx *sql.Tx, committed *bool, msgID int64) { - if err := tx.Commit(); err != nil { - log.Printf("ERROR: sender: commit tx for msg %d: %s", msgID, err) - } else { - *committed = true - } -} - -func recordRetryableFailure(tx *sql.Tx, committed *bool, lockedAddrs, lockedAddToAddrs []string, msgID int64) { - now := timeutil.TimestampNow().Float64() - updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, msgID, now, localResponseCodeNoResponse, false) - commitOrLog(tx, committed, msgID) -} - -// deliverMessage handles delivery of a single message to a single remote domain. -// -// It manages its own database transaction with the following lifecycle: -// - Locks the pending msg_to rows for this (message, domain) via FOR UPDATE SKIP LOCKED. -// - Loads the full message including ALL recipients (for the original wire header). -// - Sends the complete original message to the remote host. -// - On success: updates time_delivered + response_code, commits. -// - On rejection (got response code): updates response_code + time_last_attempt, commits. -// - On early delivery error: records a retryable failure, commits, and backs off. -func deliverMessage(target pendingTarget) { - if strings.EqualFold(target.Domain, Domain) { - // local domain — mark as delivered rather than sending remotely - db, err := sql.Open("postgres", "") - if err != nil { - log.Printf("ERROR: sender: db open for local delivery: %s", err) - return - } - defer db.Close() - now := timeutil.TimestampNow().Float64() - if _, err := db.Exec(` - UPDATE msg_to SET time_delivered = $1, response_code = 200 - WHERE msg_id = $2 AND time_delivered IS NULL - AND lower(split_part(addr, '@', 3)) = lower($3) - `, now, target.MsgID, target.Domain); err != nil { - log.Printf("ERROR: sender: marking local recipients delivered for msg %d: %s", target.MsgID, err) - } - return - } - - db, err := sql.Open("postgres", "") - if err != nil { - log.Printf("ERROR: sender: db open: %s", err) - return - } - defer db.Close() - - tx, err := db.Begin() - if err != nil { - log.Printf("ERROR: sender: begin tx: %s", err) - return - } - committed := false - defer func() { - if !committed { - tx.Rollback() - } - }() - - now := timeutil.TimestampNow().Float64() - - // Lock pending (undelivered, retryable) msg_to rows for this message - // on the target domain. SKIP LOCKED avoids blocking concurrent senders. - lockRows, err := tx.Query(` - SELECT mt.addr - FROM msg_to mt - INNER JOIN msg m ON m.id = mt.msg_id - WHERE mt.msg_id = $1 - AND mt.time_delivered IS NULL - AND m.time_sent IS NOT NULL - AND (mt.response_code IS NULL OR mt.response_code = ANY($5)) - AND (mt.time_last_attempt IS NULL OR ($2 - mt.time_last_attempt) > LEAST($3 * POWER(2.0, GREATEST(mt.attempt_count - 1, 0)::float), $4)) - AND ($2 - m.time_sent) < $4 - FOR UPDATE OF mt SKIP LOCKED - `, target.MsgID, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) - if err != nil { - log.Printf("ERROR: sender: lock rows for msg %d: %s", target.MsgID, err) - return - } - - var lockedAddrs []string - for lockRows.Next() { - var addr string - if err := lockRows.Scan(&addr); err != nil { - lockRows.Close() - log.Printf("ERROR: sender: scan locked addr: %s", err) - return - } - lastAt := strings.LastIndex(addr, "@") - if lastAt != -1 && strings.EqualFold(addr[lastAt+1:], target.Domain) { - lockedAddrs = append(lockedAddrs, addr) - } - } - lockRows.Close() - if err := lockRows.Err(); err != nil { - log.Printf("ERROR: sender: lock rows err for msg %d: %s", target.MsgID, err) - return - } - - // Also lock pending msg_add_to rows for this message on the target domain. - lockAddToRows, err := tx.Query(` - SELECT mat.addr - FROM msg_add_to mat - INNER JOIN msg m ON m.id = mat.msg_id - WHERE mat.msg_id = $1 - AND mat.time_delivered IS NULL - AND m.time_sent IS NOT NULL - AND (mat.response_code IS NULL OR mat.response_code = ANY($5)) - AND (mat.time_last_attempt IS NULL OR ($2 - mat.time_last_attempt) > LEAST($3 * POWER(2.0, GREATEST(mat.attempt_count - 1, 0)::float), $4)) - AND ($2 - m.time_sent) < $4 - FOR UPDATE OF mat SKIP LOCKED - `, target.MsgID, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) - if err != nil { - log.Printf("ERROR: sender: lock add-to rows for msg %d: %s", target.MsgID, err) - return - } - - var lockedAddToAddrs []string - for lockAddToRows.Next() { - var addr string - if err := lockAddToRows.Scan(&addr); err != nil { - lockAddToRows.Close() - log.Printf("ERROR: sender: scan locked add-to addr: %s", err) - return - } - lastAt := strings.LastIndex(addr, "@") - if lastAt != -1 && strings.EqualFold(addr[lastAt+1:], target.Domain) { - lockedAddToAddrs = append(lockedAddToAddrs, addr) - } - } - lockAddToRows.Close() - if err := lockAddToRows.Err(); err != nil { - log.Printf("ERROR: sender: lock add-to rows err for msg %d: %s", target.MsgID, err) - return - } - - if len(lockedAddrs) == 0 && len(lockedAddToAddrs) == 0 { - return // already locked by another sender or no longer eligible - } - - deferRetry := true - defer func() { - if deferRetry && !committed { - recordRetryableFailure(tx, &committed, lockedAddrs, lockedAddToAddrs, target.MsgID) - } - }() - - // Load the full message from msg table - h, err := loadMsg(tx, target.MsgID) - if err != nil { - log.Printf("ERROR: sender: %s", err) - return - } - - // Try zlib-deflate compression for message data and attachment data. - // Compressed temp files are cleaned up after delivery completes. - var deflateCleanup []string - defer func() { - for _, p := range deflateCleanup { - _ = os.Remove(p) - } - }() - if shouldCompress(h.Type, h.Size) { - dp, cs, ok, derr := tryCompress(h.Filepath, h.Size) - if derr != nil { - log.Printf("WARN: sender: compress msg data for msg %d: %s", target.MsgID, derr) - } else if ok { - log.Printf("INFO: sender: compressed msg %d data: %d -> %d bytes", target.MsgID, h.Size, cs) - deflateCleanup = append(deflateCleanup, dp) - h.Filepath = dp - h.ExpandedSize = h.Size - h.Size = cs - h.Flags |= FlagDeflate - } - } - for i := range h.Attachments { - att := &h.Attachments[i] - if shouldCompress(att.Type, att.Size) { - dp, cs, ok, derr := tryCompress(att.Filepath, att.Size) - if derr != nil { - log.Printf("WARN: sender: compress attachment %s for msg %d: %s", att.Filename, target.MsgID, derr) - } else if ok { - log.Printf("INFO: sender: compressed msg %d attachment %s: %d -> %d bytes", target.MsgID, att.Filename, att.Size, cs) - deflateCleanup = append(deflateCleanup, dp) - att.Filepath = dp - att.ExpandedSize = att.Size - att.Size = cs - att.Flags |= 1 << 1 - } - } - } - - // Ensure sha256 is populated for outgoing messages so future pid lookups - // (e.g. add-to notifications referencing this message) can find it. - msgHash, err := h.GetMessageHash() - if err != nil { - log.Printf("ERROR: sender: computing message hash for msg %d: %s", target.MsgID, err) - return - } - if _, err := tx.Exec(`UPDATE msg SET sha256 = $1 WHERE id = $2 AND sha256 IS NULL`, - msgHash, target.MsgID); err != nil { - log.Printf("ERROR: sender: storing sha256 for msg %d: %s", target.MsgID, err) - return - } - if err := resolvePendingChildLinks(txParentLinkStore{tx: tx}, target.MsgID, msgHash); err != nil { - log.Printf("ERROR: sender: resolving child pids for msg %d: %s", target.MsgID, err) - return - } - - // Compute header hash now; registerOutgoing with Host B's IP happens after - // the connection is established (IP needed for challenge validation §10.5). - hash := h.GetHeaderHash() - hashArr := *(*[32]byte)(hash) - - // Build the list of recipients on the target domain in to then add-to order. - // Per spec, per-recipient response codes follow the same order. - lockedSet := make(map[string]bool) - for _, a := range lockedAddrs { - lockedSet[strings.ToLower(a)] = true - } - for _, a := range lockedAddToAddrs { - lockedSet[strings.ToLower(a)] = true - } - type domainRecip struct { - addr string - isLocked bool - isAddTo bool - } - var domainRecips []domainRecip - for _, addr := range h.To { - if strings.EqualFold(addr.Domain, target.Domain) { - s := addr.ToString() - domainRecips = append(domainRecips, domainRecip{ - addr: s, - isLocked: lockedSet[strings.ToLower(s)], - isAddTo: false, - }) - } - } - for _, addr := range h.AddTo { - if strings.EqualFold(addr.Domain, target.Domain) { - s := addr.ToString() - domainRecips = append(domainRecips, domainRecip{ - addr: s, - isLocked: lockedSet[strings.ToLower(s)], - isAddTo: true, - }) - } - } - - // --- network delivery --- - - targetIPs, err := lookupAuthorisedIPs(target.Domain) - if err != nil { - log.Printf("ERROR: sender: DNS lookup for fmsg.%s failed: %s", target.Domain, err) - return - } - - var conn net.Conn - dialer := &net.Dialer{Timeout: 10 * time.Second} - tlsConf := buildClientTLSConfig("fmsg." + target.Domain) - for _, ip := range targetIPs { - addr := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", RemotePort)) - conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsConf) - if err == nil { - break - } - log.Printf("WARN: sender: connect to %s failed: %s", addr, err) - } - if conn == nil { - log.Printf("ERROR: sender: could not connect to any IP for fmsg.%s", target.Domain) - return - } - defer conn.Close() - - // Register in outgoing map with Host B's IP before sending the header so - // any incoming challenge can be matched by hash AND IP (§10.2 step 2). - connectedIP := conn.RemoteAddr().(*net.TCPAddr).IP.String() - log.Printf("INFO: sender: registering outgoing message %s (%s)", hex.EncodeToString(hashArr[:]), connectedIP) - registerOutgoing(hashArr, h, connectedIP) - defer removeOutgoingIP(hashArr, connectedIP) - - // Step 3: Transmit message header. - if _, err := conn.Write(h.Encode()); err != nil { - log.Printf("ERROR: sender: writing header for msg %d: %s", target.MsgID, err) - return - } - - // Step 5: Read the initial response byte before sending any data (§10.2 step 5). - // The challenge handler may fire on a separate goroutine during this wait. - conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - initCode := make([]byte, 1) - if _, err := io.ReadFull(conn, initCode); err != nil { - log.Printf("ERROR: sender: reading initial response for msg %d: %s", target.MsgID, err) - return - } - now = timeutil.TimestampNow().Float64() - isAddToMsg := h.Flags&FlagHasAddTo != 0 - - switch initCode[0] { - case AcceptCodeContinue: // 64 — send data + attachments, then per-recipient codes - if err := sendMsgData(conn, h); err != nil { - log.Printf("ERROR: sender: %s (msg %d)", err, target.MsgID) - return - } - case AcceptCodeSkipData: // 65 — add-to, parent stored, recipients on this host; skip data - if !isAddToMsg { - log.Printf("WARN: sender: msg %d received protocol-invalid code 65 from %s for non-add-to message, terminating", - target.MsgID, target.Domain) - return - } - // do not transmit data; per-recipient codes follow below - case AcceptCodeAddTo: // 11 — add-to accepted, no recipients on this host - if !isAddToMsg { - log.Printf("WARN: sender: msg %d received protocol-invalid code 11 from %s for non-add-to message, terminating", - target.MsgID, target.Domain) - return - } - log.Printf("INFO: sender: msg %d add-to accepted by %s (code 11)", target.MsgID, target.Domain) - updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, target.MsgID, now, int(initCode[0]), true) - deferRetry = false - commitOrLog(tx, &committed, target.MsgID) - return - default: - if initCode[0] >= 1 && initCode[0] <= 10 { - // global rejection - log.Printf("WARN: sender: msg %d rejected by %s: %s (%d)", - target.MsgID, target.Domain, responseCodeName(initCode[0]), initCode[0]) - updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, target.MsgID, now, int(initCode[0]), false) - deferRetry = false - commitOrLog(tx, &committed, target.MsgID) - } else { - // unexpected code — TERMINATE - log.Printf("WARN: sender: msg %d unexpected response code %d from %s, terminating", - target.MsgID, initCode[0], target.Domain) - } - return - } - - // Step 6: Read one per-recipient code per recipient on this host, in - // to-field order then add-to order (§10.2 step 6). - conn.SetReadDeadline(time.Now().Add(30 * time.Second)) - codes := make([]byte, len(domainRecips)) - if _, err := io.ReadFull(conn, codes); err != nil { - log.Printf("ERROR: sender: reading per-recipient codes for msg %d: %s", target.MsgID, err) - return - } - now = timeutil.TimestampNow().Float64() - - for i, dr := range domainRecips { - if !dr.isLocked { - continue - } - c := codes[i] - table := "msg_to" - if dr.isAddTo { - table = "msg_add_to" - } - delivered := c == RejectCodeAccept - if delivered { - log.Printf("INFO: sender: delivered msg %d to %s", target.MsgID, dr.addr) - } else { - log.Printf("WARN: sender: msg %d to %s: %s (%d)", target.MsgID, dr.addr, responseCodeName(c), c) - } - updateRecipient(tx, table, dr.addr, target.MsgID, now, int(c), delivered) - } - - deferRetry = false - commitOrLog(tx, &committed, target.MsgID) -} - -// processPendingMessages finds messages needing delivery and dispatches a -// goroutine per (message, domain) pair, bounded by the semaphore. -func processPendingMessages(sem chan struct{}) { - targets, err := findPendingTargets() - if err != nil { - log.Printf("ERROR: sender: finding pending targets: %s", err) - return - } - if len(targets) == 0 { - return - } - log.Printf("INFO: sender: found %d pending target(s)", len(targets)) - - for _, t := range targets { - sem <- struct{}{} // acquire - go func(t pendingTarget) { - defer func() { <-sem }() - deliverMessage(t) - }(t) - } -} - -// startSender runs the sender loop: polls the database periodically and also -// listens for PostgreSQL notifications for immediate pickup of new messages. -func startSender() { - loadSenderEnvConfig() - log.Printf("INFO: sender: started (poll=%ds, retry=%.0fs, max_concurrent=%d)", - PollInterval, RetryInterval, MaxConcurrentSend) - - sem := make(chan struct{}, MaxConcurrentSend) - - // set up PostgreSQL LISTEN for immediate notification - notifyCh := make(chan struct{}, 1) - go func() { - listener := pq.NewListener("", 10*time.Second, time.Minute, func(ev pq.ListenerEventType, err error) { - if err != nil { - log.Printf("ERROR: sender: pg listener: %s", err) - } - }) - if err := listener.Listen("new_msg_to"); err != nil { - log.Printf("ERROR: sender: could not LISTEN on new_msg_to: %s", err) - return - } - defer listener.Close() - log.Println("INFO: sender: listening for new_msg_to notifications") - for { - select { - case n := <-listener.Notify: - if n != nil { - log.Printf("INFO: sender: notification received: %s", n.Extra) - select { - case notifyCh <- struct{}{}: - default: - } - } - case <-time.After(32 * time.Second): - // ping to keep connection alive - if err := listener.Ping(); err != nil { - log.Printf("ERROR: sender: pg listener ping: %s", err) - } - } - } - }() - - // initial poll on startup - processPendingMessages(sem) - - ticker := time.NewTicker(time.Duration(PollInterval) * time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - processPendingMessages(sem) - case <-notifyCh: - // small delay to batch rapid inserts - time.Sleep(256 * time.Millisecond) - processPendingMessages(sem) - } - } -} diff --git a/src/store.go b/src/store.go deleted file mode 100644 index b2689f2..0000000 --- a/src/store.go +++ /dev/null @@ -1,592 +0,0 @@ -package main - -import ( - "database/sql" - "fmt" - "log" - "strings" - - "github.com/levenlabs/golib/timeutil" - _ "github.com/lib/pq" -) - -func testDb() error { - db, err := sql.Open("postgres", "") - if err != nil { - return err - } - defer db.Close() - err = db.Ping() - if err != nil { - return err - } - - var dbName, user, host, port string - _ = db.QueryRow("SELECT current_database()").Scan(&dbName) - _ = db.QueryRow("SELECT current_user").Scan(&user) - _ = db.QueryRow("SELECT inet_server_addr()::text").Scan(&host) - _ = db.QueryRow("SELECT inet_server_port()::text").Scan(&port) - log.Printf("INFO: Database connected: %s@%s:%s/%s", user, host, port, dbName) - - // verify required tables exist - for _, table := range []string{"msg", "msg_to", "msg_add_to", "msg_attachment"} { - var exists bool - err = db.QueryRow(`SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_name = $1 - )`, table).Scan(&exists) - if err != nil { - return fmt.Errorf("checking table %s: %w", table, err) - } - if !exists { - return fmt.Errorf("required table %s does not exist", table) - } - } - return nil -} - -// lookupMsgIdByHash returns the msg id for a message with the given SHA256 hash, -// or 0 if no such message exists. -func lookupMsgIdByHash(hash []byte) (int64, error) { - db, err := sql.Open("postgres", "") - if err != nil { - return 0, err - } - defer db.Close() - - var id int64 - err = db.QueryRow("SELECT id FROM msg WHERE sha256 = $1", hash).Scan(&id) - if err == sql.ErrNoRows { - return 0, nil - } - return id, err -} - -// hasAddrReceivedMsgHash reports whether addr has already received a stored -// message identified by hash. -func hasAddrReceivedMsgHash(hash []byte, addr *FMsgAddress) (bool, error) { - if addr == nil || len(hash) == 0 { - return false, nil - } - - db, err := sql.Open("postgres", "") - if err != nil { - return false, err - } - defer db.Close() - - addrStr := strings.ToLower(addr.ToString()) - - var exists bool - err = db.QueryRow(` - SELECT EXISTS ( - SELECT 1 - FROM msg m - JOIN msg_to mt ON mt.msg_id = m.id - WHERE m.sha256 = $1 - AND lower(mt.addr) = $2 - AND mt.time_delivered IS NOT NULL - UNION ALL - SELECT 1 - FROM msg m - JOIN msg_add_to mat ON mat.msg_id = m.id - WHERE m.sha256 = $1 - AND lower(mat.addr) = $2 - AND mat.time_delivered IS NOT NULL - ) - `, hash, addrStr).Scan(&exists) - if err != nil { - return false, err - } - - return exists, nil -} - -type parentLinkStore interface { - lookupParentID(parentHash []byte) (int64, error) - setParentID(msgID int64, parentID int64) error - setPendingChildrenParentID(parentID int64, parentHash []byte) error -} - -type txParentLinkStore struct { - tx *sql.Tx -} - -func (s txParentLinkStore) lookupParentID(parentHash []byte) (int64, error) { - var id int64 - err := s.tx.QueryRow("SELECT id FROM msg WHERE sha256 = $1", parentHash).Scan(&id) - if err == sql.ErrNoRows { - return 0, nil - } - return id, err -} - -func (s txParentLinkStore) setParentID(msgID int64, parentID int64) error { - _, err := s.tx.Exec("UPDATE msg SET pid = $1 WHERE id = $2", parentID, msgID) - return err -} - -func (s txParentLinkStore) setPendingChildrenParentID(parentID int64, parentHash []byte) error { - _, err := s.tx.Exec("UPDATE msg SET pid = $1 WHERE psha256 = $2 AND pid IS NULL", parentID, parentHash) - return err -} - -func resolveStoredParent(store parentLinkStore, msgID int64, parentHash []byte, requireParent bool) error { - if len(parentHash) == 0 { - return nil - } - - parentID, err := store.lookupParentID(parentHash) - if err != nil { - return err - } - if parentID == 0 { - if requireParent { - return fmt.Errorf("parent message not found for psha256 %x", parentHash) - } - return nil - } - - return store.setParentID(msgID, parentID) -} - -func resolvePendingChildLinks(store parentLinkStore, parentID int64, parentHash []byte) error { - if len(parentHash) == 0 { - return nil - } - return store.setPendingChildrenParentID(parentID, parentHash) -} - -func resolveMsgParentLinks(tx *sql.Tx, msgID int64, msgHash []byte, parentHash []byte, requireParent bool) error { - store := txParentLinkStore{tx: tx} - if err := resolveStoredParent(store, msgID, parentHash, requireParent); err != nil { - return err - } - return resolvePendingChildLinks(store, msgID, msgHash) -} - -func requiresStoredParent(msg *FMsgHeader) bool { - return len(msg.Pid) > 0 && msg.Flags&FlagHasAddTo == 0 -} - -func wirePidForLoadedMessage(storedParentHash []byte, msgHash []byte, hasAddTo bool) []byte { - if hasAddTo { - return msgHash - } - return storedParentHash -} - -// getMsgByID loads a message and all its recipients from the database by msg ID. -// Returns the full FMsgHeader or nil if the message doesn't exist. -func getMsgByID(msgID int64) (*FMsgHeader, error) { - db, err := sql.Open("postgres", "") - if err != nil { - return nil, err - } - defer db.Close() - - tx, err := db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - - h, err := loadMsg(tx, msgID) - if err != nil { - // If the message doesn't exist, loadMsg will return an error, - // but we want to distinguish "not found" from other errors - if err.Error() == "no rows in result set" || err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - - return h, nil -} - -func storeMsgDetail(msg *FMsgHeader) error { - - db, err := sql.Open("postgres", "") - if err != nil { - return err - } - defer db.Close() - - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - msgHash, err := msg.GetMessageHash() - if err != nil { - return err - } - - var addToFrom interface{} - if msg.AddToFrom != nil { - addToFrom = msg.AddToFrom.ToString() - } - - var msgID int64 - err = tx.QueryRow(`insert into msg (version - , no_reply - , is_important - , is_deflate - , time_sent - , from_addr - , add_to_from - , topic - , type - , sha256 - , psha256 - , size - , filepath) -values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) -returning id`, - msg.Version, - msg.Flags&FlagNoReply != 0, - msg.Flags&FlagImportant != 0, - msg.Flags&FlagDeflate != 0, - msg.Timestamp, - msg.From.ToString(), - addToFrom, - msg.Topic, - msg.Type, - msgHash, - msg.Pid, - int(msg.Size), - msg.Filepath).Scan(&msgID) - if err != nil { - return err - } - - stmt, err := tx.Prepare(`insert into msg_to (msg_id, addr, time_delivered) -values ($1, $2, $3)`) - if err != nil { - return err - } - defer stmt.Close() - - now := timeutil.TimestampNow().Float64() - for _, addr := range msg.To { - // recipients on our domain are already delivered; others are pending - var delivered interface{} - if addr.Domain == Domain { - delivered = now - } - if _, err := stmt.Exec(msgID, addr.ToString(), delivered); err != nil { - return err - } - } - - // insert add-to recipients into msg_add_to - if len(msg.AddTo) > 0 { - addToStmt, err := tx.Prepare(`insert into msg_add_to (msg_id, addr, time_delivered) -values ($1, $2, $3)`) - if err != nil { - return err - } - defer addToStmt.Close() - - for _, addr := range msg.AddTo { - var delivered interface{} - if addr.Domain == Domain { - delivered = now - } - if _, err := addToStmt.Exec(msgID, addr.ToString(), delivered); err != nil { - return err - } - } - } - - if len(msg.Attachments) > 0 { - attStmt, err := tx.Prepare(`insert into msg_attachment (msg_id, position, flags, type, filename, filesize, filepath) -values ($1, $2, $3, $4, $5, $6, $7)`) - if err != nil { - return err - } - defer attStmt.Close() - - for i := range msg.Attachments { - att := msg.Attachments[i] - if _, err := attStmt.Exec(msgID, i, int(att.Flags), att.Type, att.Filename, int(att.Size), att.Filepath); err != nil { - return err - } - } - } - - if err := resolveMsgParentLinks(tx, msgID, msgHash, msg.Pid, requiresStoredParent(msg)); err != nil { - return err - } - - return tx.Commit() - -} - -// storeMsgHeaderOnly stores just the message header for add-to notifications -// (spec code 11). Only the header is recorded so the header hash can be -// faithfully computed for subsequent messages referencing this one via pid. -func storeMsgHeaderOnly(msg *FMsgHeader) error { - db, err := sql.Open("postgres", "") - if err != nil { - return err - } - defer db.Close() - - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - msgHash, err := msg.GetMessageHash() - if err != nil { - return err - } - - var addToFrom interface{} - if msg.AddToFrom != nil { - addToFrom = msg.AddToFrom.ToString() - } - - var msgID int64 - err = tx.QueryRow(`insert into msg (version - , no_reply - , is_important - , is_deflate - , time_sent - , from_addr - , add_to_from - , topic - , type - , sha256 - , psha256 - , size - , filepath) -values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) -returning id`, - msg.Version, - msg.Flags&FlagNoReply != 0, - msg.Flags&FlagImportant != 0, - msg.Flags&FlagDeflate != 0, - msg.Timestamp, - msg.From.ToString(), - addToFrom, - msg.Topic, - msg.Type, - msgHash, - msg.Pid, - int(msg.Size), - "").Scan(&msgID) - if err != nil { - return err - } - - // insert to recipients (for record keeping) - toStmt, err := tx.Prepare(`insert into msg_to (msg_id, addr) values ($1, $2)`) - if err != nil { - return err - } - defer toStmt.Close() - for _, addr := range msg.To { - if _, err := toStmt.Exec(msgID, addr.ToString()); err != nil { - return err - } - } - - // insert add-to recipients - if len(msg.AddTo) > 0 { - addToStmt, err := tx.Prepare(`insert into msg_add_to (msg_id, addr) values ($1, $2)`) - if err != nil { - return err - } - defer addToStmt.Close() - for _, addr := range msg.AddTo { - if _, err := addToStmt.Exec(msgID, addr.ToString()); err != nil { - return err - } - } - } - - if len(msg.Attachments) > 0 { - attStmt, err := tx.Prepare(`insert into msg_attachment (msg_id, position, flags, type, filename, filesize, filepath) -values ($1, $2, $3, $4, $5, $6, $7)`) - if err != nil { - return err - } - defer attStmt.Close() - - for i := range msg.Attachments { - att := msg.Attachments[i] - if _, err := attStmt.Exec(msgID, i, int(att.Flags), att.Type, att.Filename, int(att.Size), att.Filepath); err != nil { - return err - } - } - } - - if err := resolveMsgParentLinks(tx, msgID, msgHash, msg.Pid, requiresStoredParent(msg)); err != nil { - return err - } - - return tx.Commit() -} - -// loadMsg loads a message and all its recipients from the database within the -// given transaction and returns a fully populated FMsgHeader. -func loadMsg(tx *sql.Tx, msgID int64) (*FMsgHeader, error) { - var version, size int - var noReply, isImportant, isDeflate bool - var pid, msgHash []byte - var fromAddr, topic, typ, filepath string - var addToFromAddr sql.NullString - var timeSent float64 - err := tx.QueryRow(` - SELECT version, no_reply, is_important, is_deflate, psha256, sha256, from_addr, add_to_from, topic, type, time_sent, size, filepath - FROM msg WHERE id = $1 - `, msgID).Scan(&version, &noReply, &isImportant, &isDeflate, &pid, &msgHash, &fromAddr, &addToFromAddr, &topic, &typ, &timeSent, &size, &filepath) - if err != nil { - return nil, fmt.Errorf("load msg %d: %w", msgID, err) - } - - recipRows, err := tx.Query(`SELECT addr FROM msg_to WHERE msg_id = $1 ORDER BY id`, msgID) - if err != nil { - return nil, fmt.Errorf("load recipients for msg %d: %w", msgID, err) - } - var allRecipientAddrs []string - for recipRows.Next() { - var a string - if err := recipRows.Scan(&a); err != nil { - recipRows.Close() - return nil, fmt.Errorf("scan recipient addr: %w", err) - } - allRecipientAddrs = append(allRecipientAddrs, a) - } - recipRows.Close() - if err := recipRows.Err(); err != nil { - return nil, fmt.Errorf("recipients query err for msg %d: %w", msgID, err) - } - - from, err := parseAddress([]byte(fromAddr)) - if err != nil { - return nil, fmt.Errorf("invalid from address %s: %w", fromAddr, err) - } - allTo := make([]FMsgAddress, 0, len(allRecipientAddrs)) - for _, a := range allRecipientAddrs { - addr, err := parseAddress([]byte(a)) - if err != nil { - return nil, fmt.Errorf("invalid to address %s: %w", a, err) - } - allTo = append(allTo, *addr) - } - - // load add-to recipients from msg_add_to - addToRows, err := tx.Query(`SELECT addr FROM msg_add_to WHERE msg_id = $1 ORDER BY id`, msgID) - if err != nil { - return nil, fmt.Errorf("load add-to recipients for msg %d: %w", msgID, err) - } - var allAddTo []FMsgAddress - for addToRows.Next() { - var a string - if err := addToRows.Scan(&a); err != nil { - addToRows.Close() - return nil, fmt.Errorf("scan add-to addr: %w", err) - } - addr, err := parseAddress([]byte(a)) - if err != nil { - addToRows.Close() - return nil, fmt.Errorf("invalid add-to address %s: %w", a, err) - } - allAddTo = append(allAddTo, *addr) - } - addToRows.Close() - if err := addToRows.Err(); err != nil { - return nil, fmt.Errorf("add-to recipients query err for msg %d: %w", msgID, err) - } - - attRows, err := tx.Query(` - SELECT flags, type, filename, filesize, filepath - FROM msg_attachment - WHERE msg_id = $1 - ORDER BY position, filename - `, msgID) - if err != nil { - return nil, fmt.Errorf("load attachments for msg %d: %w", msgID, err) - } - attachments := []FMsgAttachmentHeader{} - for attRows.Next() { - var flags, filesize int - var typ, filename, filepath string - if err := attRows.Scan(&flags, &typ, &filename, &filesize, &filepath); err != nil { - attRows.Close() - return nil, fmt.Errorf("scan attachment row: %w", err) - } - attachments = append(attachments, FMsgAttachmentHeader{ - Flags: uint8(flags), - Type: typ, - Filename: filename, - Size: uint32(filesize), - Filepath: filepath, - }) - } - attRows.Close() - if err := attRows.Err(); err != nil { - return nil, fmt.Errorf("attachments query err for msg %d: %w", msgID, err) - } - - // Compute flags bitfield from stored booleans and loaded data. - // has_pid and has_add_to are derived from actual data rather than stored, - // so add-to recipients added after the original message are included. - // - // When add-to recipients exist, the wire pid references the message being - // shared, not that message's parent. This keeps add-to on replies pointing - // at the reply payload rather than the root message. - pid = wirePidForLoadedMessage(pid, msgHash, len(allAddTo) > 0) - - var addToFrom *FMsgAddress - if addToFromAddr.Valid && addToFromAddr.String != "" { - addr, err := parseAddress([]byte(addToFromAddr.String)) - if err != nil { - return nil, fmt.Errorf("invalid add_to_from address %s: %w", addToFromAddr.String, err) - } - addToFrom = addr - } - if len(allAddTo) > 0 && addToFrom == nil { - // Backward-compatibility for older rows before add_to_from existed. - fallback := *from - addToFrom = &fallback - } - - var flags uint8 - if len(pid) > 0 { - flags |= FlagHasPid - } - if len(allAddTo) > 0 { - flags |= FlagHasAddTo - } - if noReply { - flags |= FlagNoReply - } - if isImportant { - flags |= FlagImportant - } - if isDeflate { - flags |= FlagDeflate - } - - return &FMsgHeader{ - Version: uint8(version), - Flags: flags, - Pid: pid, - From: *from, - To: allTo, - AddToFrom: addToFrom, - AddTo: allAddTo, - Timestamp: timeSent, - Topic: topic, - Type: typ, - Size: uint32(size), - Attachments: attachments, - Filepath: filepath, - }, nil -} diff --git a/src/store_test.go b/src/store_test.go deleted file mode 100644 index 6bfdf98..0000000 --- a/src/store_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -import ( - "bytes" - "errors" - "testing" -) - -type fakeParentLinkStore struct { - parentID int64 - lookupErr error - - lookupHash []byte - setMsgID int64 - setParentIDValue int64 - setCalled bool - pendingParentID int64 - pendingParentHash []byte - pendingCalled bool -} - -func (s *fakeParentLinkStore) lookupParentID(parentHash []byte) (int64, error) { - s.lookupHash = append([]byte(nil), parentHash...) - return s.parentID, s.lookupErr -} - -func (s *fakeParentLinkStore) setParentID(msgID int64, parentID int64) error { - s.setCalled = true - s.setMsgID = msgID - s.setParentIDValue = parentID - return nil -} - -func (s *fakeParentLinkStore) setPendingChildrenParentID(parentID int64, parentHash []byte) error { - s.pendingCalled = true - s.pendingParentID = parentID - s.pendingParentHash = append([]byte(nil), parentHash...) - return nil -} - -func TestResolveStoredParentRequiresExistingParent(t *testing.T) { - store := &fakeParentLinkStore{} - parentHash := []byte{1, 2, 3} - - err := resolveStoredParent(store, 10, parentHash, true) - if err == nil { - t.Fatal("resolveStoredParent returned nil error for required missing parent") - } - if !bytes.Equal(store.lookupHash, parentHash) { - t.Fatalf("lookup hash = %v, want %v", store.lookupHash, parentHash) - } - if store.setCalled { - t.Fatal("setParentID was called for missing parent") - } -} - -func TestResolveStoredParentAllowsOptionalMissingParent(t *testing.T) { - store := &fakeParentLinkStore{} - - if err := resolveStoredParent(store, 10, []byte{1, 2, 3}, false); err != nil { - t.Fatalf("resolveStoredParent returned error for optional missing parent: %v", err) - } - if store.setCalled { - t.Fatal("setParentID was called for optional missing parent") - } -} - -func TestResolveStoredParentSetsPidWhenParentExists(t *testing.T) { - store := &fakeParentLinkStore{parentID: 42} - - if err := resolveStoredParent(store, 10, []byte{1, 2, 3}, true); err != nil { - t.Fatalf("resolveStoredParent returned error: %v", err) - } - if !store.setCalled { - t.Fatal("setParentID was not called") - } - if store.setMsgID != 10 || store.setParentIDValue != 42 { - t.Fatalf("setParentID called with msgID=%d parentID=%d, want msgID=10 parentID=42", store.setMsgID, store.setParentIDValue) - } -} - -func TestResolveStoredParentPropagatesLookupError(t *testing.T) { - lookupErr := errors.New("lookup failed") - store := &fakeParentLinkStore{lookupErr: lookupErr} - - err := resolveStoredParent(store, 10, []byte{1, 2, 3}, true) - if !errors.Is(err, lookupErr) { - t.Fatalf("resolveStoredParent error = %v, want %v", err, lookupErr) - } - if store.setCalled { - t.Fatal("setParentID was called after lookup error") - } -} - -func TestResolvePendingChildLinksBackfillsByParentHash(t *testing.T) { - store := &fakeParentLinkStore{} - parentHash := []byte{4, 5, 6} - - if err := resolvePendingChildLinks(store, 42, parentHash); err != nil { - t.Fatalf("resolvePendingChildLinks returned error: %v", err) - } - if !store.pendingCalled { - t.Fatal("setPendingChildrenParentID was not called") - } - if store.pendingParentID != 42 || !bytes.Equal(store.pendingParentHash, parentHash) { - t.Fatalf("pending update got parentID=%d hash=%v, want parentID=42 hash=%v", store.pendingParentID, store.pendingParentHash, parentHash) - } -} - -func TestRequiresStoredParentUsesAddToFlag(t *testing.T) { - parentHash := []byte{1, 2, 3} - - if !requiresStoredParent(&FMsgHeader{Flags: FlagHasPid, Pid: parentHash}) { - t.Fatal("normal reply did not require stored parent") - } - if requiresStoredParent(&FMsgHeader{Flags: FlagHasPid | FlagHasAddTo, Pid: parentHash}) { - t.Fatal("add-to message required stored parent") - } -} - -func TestWirePidForLoadedMessageAddToReferencesSharedMessage(t *testing.T) { - parentHash := []byte{1, 2, 3} - msgHash := []byte{4, 5, 6} - - got := wirePidForLoadedMessage(parentHash, msgHash, true) - if !bytes.Equal(got, msgHash) { - t.Fatalf("add-to wire pid = %v, want message hash %v", got, msgHash) - } -} - -func TestWirePidForLoadedMessageReplyKeepsParentHash(t *testing.T) { - parentHash := []byte{1, 2, 3} - msgHash := []byte{4, 5, 6} - - got := wirePidForLoadedMessage(parentHash, msgHash, false) - if !bytes.Equal(got, parentHash) { - t.Fatalf("reply wire pid = %v, want parent hash %v", got, parentHash) - } -} From 4862bd4f863540ee7a74569bb2780afb6f4dee34 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 16 May 2026 13:23:31 +0800 Subject: [PATCH 2/4] update readme and workflows to use cmd/ --- .github/workflows/go1.25.yml | 7 ++----- README.md | 3 +-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/go1.25.yml b/.github/workflows/go1.25.yml index 2d9436a..6ddf095 100644 --- a/.github/workflows/go1.25.yml +++ b/.github/workflows/go1.25.yml @@ -10,9 +10,6 @@ jobs: build: runs-on: ubuntu-latest - defaults: - run: - working-directory: ./src steps: - uses: actions/checkout@v3 @@ -22,7 +19,7 @@ jobs: go-version: '1.25' - name: Build - run: go build -v . + run: go build -v ./... - name: Test - run: go test -v . + run: go test -v ./... diff --git a/README.md b/README.md index 5c9019c..d56a3db 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,7 @@ Implementation of [fmsg](https://github.com/markmnl/fmsg) host written in Go! Us Tested with Go 1.25 on Linux and Windows, AMD64 and ARM 1. Clone this repository -2. Navigate to src/ -2. Run `go build .` +2. Run `go build ./cmd/fmsgd/` ## Environment From dfd94e28512461a0c91a8f4f3fdd479a512ec5a6 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 16 May 2026 13:44:25 +0800 Subject: [PATCH 3/4] pkg/protocol --> pkg/fmsg --- AGENTS.md | 2 +- pkg/fmsg/README.md | 57 +++ pkg/{protocol/protocol.go => fmsg/fmsg.go} | 44 +- pkg/fmsg/fmsg_test.go | 547 +++++++++++++++++++++ pkg/protocol/README.md | 40 -- 5 files changed, 626 insertions(+), 64 deletions(-) create mode 100644 pkg/fmsg/README.md rename pkg/{protocol/protocol.go => fmsg/fmsg.go} (90%) create mode 100644 pkg/fmsg/fmsg_test.go delete mode 100644 pkg/protocol/README.md diff --git a/AGENTS.md b/AGENTS.md index 1accc1e..4f4029e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -21,6 +21,6 @@ All code MUST conform to the specification. When in doubt, re-read SPEC.md and f - Language: Go - Module path: repo root (`go.mod` at `fmsgd/go.mod`) - Binary source: `cmd/fmsgd/` -- Shared protocol package: `pkg/protocol/` +- Shared protocol package: `pkg/fmsg/` - Build: `go build ./...` (from repo root) - Test: `go test ./...` (from repo root) diff --git a/pkg/fmsg/README.md b/pkg/fmsg/README.md new file mode 100644 index 0000000..3294d47 --- /dev/null +++ b/pkg/fmsg/README.md @@ -0,0 +1,57 @@ +# fmsg + +Go package implementing the [fmsg protocol](../../SPEC.md) message types, wire-format encoding, and SHA-256 hashing. + +``` +import "github.com/markmnl/fmsgd/pkg/fmsg" +``` + +## Types + +| Type | Description | +|---|---| +| `fmsg.Header` | All fields of an fmsg message header | +| `fmsg.Address` | fmsg address (`@user@domain`) | +| `fmsg.AttachmentHeader` | Wire-level metadata for a single attachment | + +## Flag constants + +```go +fmsg.FlagHasPid // bit 0: message is a reply (pid field present) +fmsg.FlagHasAddTo // bit 1: add-to addresses present +fmsg.FlagCommonType // bit 2: type encoded as common media type ID +fmsg.FlagImportant // bit 3: sender marks as important +fmsg.FlagNoReply // bit 4: sender will discard replies +fmsg.FlagDeflate // bit 5: body is zlib-deflate compressed +``` + +## Usage + +### Encode a header to wire format + +```go +h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: float64(time.Now().UnixMicro()) / 1e6, + Topic: "hello", + Type: "text/plain", + Size: uint32(len(body)), +} +wire := h.Encode() +``` + +### Hash a message (for verification or storage) + +```go +// Set h.Filepath (and each attachment's Filepath) before calling. +hash, err := h.GetMessageHash() +``` + +### Look up a common media type + +```go +mtype, ok := fmsg.GetCommonMediaType(id) // ID → "text/plain" +id, ok := fmsg.GetCommonMediaTypeID(mt) // "text/plain" → ID +``` diff --git a/pkg/protocol/protocol.go b/pkg/fmsg/fmsg.go similarity index 90% rename from pkg/protocol/protocol.go rename to pkg/fmsg/fmsg.go index 3ce3152..69c1ebc 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/fmsg/fmsg.go @@ -1,12 +1,10 @@ -// Package protocol implements the fmsg wire protocol encoding and hashing. +// Package fmsg defines the fmsg message types and implements wire-format +// encoding and hashing as specified in SPEC.md. // -// It provides the core types and methods needed to build, encode, and hash -// fmsg messages as defined in the fmsg protocol specification (SPEC.md). -// -// To compute a message hash from database fields, populate an [FMsgHeader] +// To compute a message hash from database fields, populate a [Header] // (including Filepath and per-attachment Filepath values), then call -// [FMsgHeader.GetMessageHash]. -package protocol +// [Header.GetMessageHash]. +package fmsg import ( "bytes" @@ -31,19 +29,19 @@ const ( FlagDeflate uint8 = 1 << 5 // bit 5: message body is zlib-deflate compressed ) -// FMsgAddress is an fmsg address of the form @user@domain. -type FMsgAddress struct { +// Address is an fmsg address of the form @user@domain. +type Address struct { User string Domain string } // ToString returns the address in @user@domain form. -func (addr *FMsgAddress) ToString() string { +func (addr *Address) ToString() string { return fmt.Sprintf("@%s@%s", addr.User, addr.Domain) } -// FMsgAttachmentHeader holds the wire-level metadata for a single attachment. -type FMsgAttachmentHeader struct { +// AttachmentHeader holds the wire-level metadata for a single attachment. +type AttachmentHeader struct { Flags uint8 TypeID uint8 Type string @@ -54,19 +52,19 @@ type FMsgAttachmentHeader struct { Filepath string // path to attachment data on disk } -// FMsgHeader holds all fields of an fmsg message header. +// Header holds all fields of an fmsg message header. // // Fields ChallengeHash, ChallengeCompleted, and InitialResponseCode are // fmsgd server-runtime state; they are not part of the wire format and can // be ignored by other consumers of this package. -type FMsgHeader struct { +type Header struct { Version uint8 Flags uint8 Pid []byte - From FMsgAddress - To []FMsgAddress - AddToFrom *FMsgAddress // present when FlagHasAddTo is set - AddTo []FMsgAddress + From Address + To []Address + AddToFrom *Address // present when FlagHasAddTo is set + AddTo []Address Timestamp float64 TypeID uint8 Topic string @@ -74,7 +72,7 @@ type FMsgHeader struct { Size uint32 // wire (possibly compressed) size of the message body ExpandedSize uint32 // decompressed size; present on wire iff FlagDeflate set - Attachments []FMsgAttachmentHeader + Attachments []AttachmentHeader HeaderHash []byte ChallengeHash [32]byte // fmsgd server field: challenge response hash @@ -88,7 +86,7 @@ type FMsgHeader struct { // The returned bytes cover all fields up to and including the attachment // headers (fields 1–12 per spec). This method panics on internal buffer errors // rather than returning an error. -func (h *FMsgHeader) Encode() []byte { +func (h *Header) Encode() []byte { var b bytes.Buffer b.WriteByte(h.Version) b.WriteByte(h.Flags) @@ -183,7 +181,7 @@ func (h *FMsgHeader) Encode() []byte { } // String returns a human-readable summary of the header fields. -func (h *FMsgHeader) String() string { +func (h *Header) String() string { var b strings.Builder fmt.Fprintf(&b, "v%d flags=%d", h.Version, h.Flags) if len(h.Pid) > 0 { @@ -212,7 +210,7 @@ func (h *FMsgHeader) String() string { // GetHeaderHash returns the SHA-256 hash of the encoded header (fields 1–12). // The result is cached after the first call. -func (h *FMsgHeader) GetHeaderHash() []byte { +func (h *Header) GetHeaderHash() []byte { if h.HeaderHash == nil { b := sha256.Sum256(h.Encode()) h.HeaderHash = b[:] @@ -223,7 +221,7 @@ func (h *FMsgHeader) GetHeaderHash() []byte { // GetMessageHash returns the SHA-256 hash of the full message: // encoded header + decompressed message body + decompressed attachment data. // The result is cached after the first call. -func (h *FMsgHeader) GetMessageHash() ([]byte, error) { +func (h *Header) GetMessageHash() ([]byte, error) { if h.messageHash == nil { hash := sha256.New() diff --git a/pkg/fmsg/fmsg_test.go b/pkg/fmsg/fmsg_test.go new file mode 100644 index 0000000..d263392 --- /dev/null +++ b/pkg/fmsg/fmsg_test.go @@ -0,0 +1,547 @@ +package fmsg_test + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "os" + "testing" + + "github.com/markmnl/fmsgd/pkg/fmsg" +) + +// ── Flag constants ──────────────────────────────────────────────────────────── + +func TestFlagConstants(t *testing.T) { + tests := []struct { + name string + got uint8 + want uint8 + }{ + {"FlagHasPid", fmsg.FlagHasPid, 0x01}, + {"FlagHasAddTo", fmsg.FlagHasAddTo, 0x02}, + {"FlagCommonType", fmsg.FlagCommonType, 0x04}, + {"FlagImportant", fmsg.FlagImportant, 0x08}, + {"FlagNoReply", fmsg.FlagNoReply, 0x10}, + {"FlagDeflate", fmsg.FlagDeflate, 0x20}, + } + for _, tt := range tests { + if tt.got != tt.want { + t.Errorf("%s = 0x%02X, want 0x%02X", tt.name, tt.got, tt.want) + } + } +} + +// ── Address ─────────────────────────────────────────────────────────────────── + +func TestAddressToString(t *testing.T) { + tests := []struct { + addr fmsg.Address + want string + }{ + {fmsg.Address{User: "alice", Domain: "example.com"}, "@alice@example.com"}, + {fmsg.Address{User: "bob", Domain: "b.io"}, "@bob@b.io"}, + {fmsg.Address{User: "x", Domain: "y.z"}, "@x@y.z"}, + } + for _, tt := range tests { + if got := tt.addr.ToString(); got != tt.want { + t.Errorf("ToString() = %q, want %q", got, tt.want) + } + } +} + +// ── Encode ──────────────────────────────────────────────────────────────────── + +// buildExpectedWire manually constructs the expected wire bytes for a minimal +// header: version=1, flags=0, from=@alice@example.com, to=[@bob@example.com], +// timestamp=0, topic="hi", type="text/plain", size=12, no attachments. +// This is the ground-truth used by TestHeaderEncode. +func buildExpectedWire() []byte { + var b bytes.Buffer + b.WriteByte(1) // version + b.WriteByte(0) // flags + from := "@alice@example.com" + b.WriteByte(byte(len(from))) + b.WriteString(from) + b.WriteByte(1) // to count + to := "@bob@example.com" + b.WriteByte(byte(len(to))) + b.WriteString(to) + _ = binary.Write(&b, binary.LittleEndian, float64(0)) // timestamp + b.WriteByte(byte(len("hi"))) + b.WriteString("hi") + b.WriteByte(byte(len("text/plain"))) + b.WriteString("text/plain") + _ = binary.Write(&b, binary.LittleEndian, uint32(12)) // size + b.WriteByte(0) // attachment count + return b.Bytes() +} + +func TestHeaderEncode(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: 12, + } + got := h.Encode() + want := buildExpectedWire() + if !bytes.Equal(got, want) { + t.Errorf("Encode():\n got %x\n want %x", got, want) + } +} + +func TestHeaderEncodeHasPid(t *testing.T) { + // When FlagHasPid is set: 32 pid bytes written at offset 2; topic absent. + pid := make([]byte, 32) + for i := range pid { + pid[i] = byte(i) + } + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagHasPid, + Pid: pid, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "should-not-appear", + Type: "text/plain", + Size: 0, + } + wire := h.Encode() + if !bytes.Equal(wire[2:34], pid) { + t.Error("Encode() with FlagHasPid: pid bytes not at wire[2:34]") + } + if bytes.Contains(wire, []byte("should-not-appear")) { + t.Error("Encode() with FlagHasPid: topic must be absent from wire") + } +} + +func TestHeaderEncodeDeflate(t *testing.T) { + // When FlagDeflate is set, ExpandedSize uint32 follows Size. + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagDeflate, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + Type: "application/octet-stream", + Size: 100, + ExpandedSize: 9999, + } + wire := h.Encode() + // Locate size bytes: offset = 1+1 + 1+7 + 1 + 1+7 + 8 + 1+1 + 1+24 = 55 + // version+flags = 2 + // fromlen+"@a@b.io" = 1+7 = 8; total 10 + // tocount = 1; total 11 + // tolen+"@c@d.io" = 1+7 = 8; total 19 + // timestamp = 8; total 27 + // topiclen+"t" = 1+1 = 2; total 29 + // typelen+"application/octet-stream" = 1+24 = 25; total 54 + // size uint32 = 4 bytes at [54:58] + // expanded size uint32 = 4 bytes at [58:62] + var size, expanded uint32 + _ = binary.Read(bytes.NewReader(wire[54:58]), binary.LittleEndian, &size) + _ = binary.Read(bytes.NewReader(wire[58:62]), binary.LittleEndian, &expanded) + if size != 100 { + t.Errorf("Encode() DeflateFlag: Size = %d, want 100", size) + } + if expanded != 9999 { + t.Errorf("Encode() DeflateFlag: ExpandedSize = %d, want 9999", expanded) + } +} + +func TestHeaderEncodeCommonType(t *testing.T) { + // When FlagCommonType is set, TypeID byte is written instead of type string. + // "text/csv" is ID 50. + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagCommonType, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + TypeID: 50, + Type: "text/csv", + Size: 0, + } + wire := h.Encode() + if bytes.Contains(wire, []byte("text/csv")) { + t.Error("Encode() with FlagCommonType: type string must not appear in wire") + } + // TypeID byte is at offset: 2+8+1+8+8+2 = 29 + // (version+flags) + (fromlen+from) + tocount + (tolen+to) + timestamp + (topiclen+topic) + if wire[29] != 50 { + t.Errorf("Encode() with FlagCommonType: type byte at [29] = %d, want 50", wire[29]) + } +} + +func TestHeaderEncodeAttachment(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + Type: "text/plain", + Size: 5, + Attachments: []fmsg.AttachmentHeader{ + { + Flags: 0, + Type: "image/png", + Filename: "pic.png", + Size: 1024, + }, + }, + } + wire := h.Encode() + if !bytes.Contains(wire, []byte("pic.png")) { + t.Error("Encode(): attachment filename not in wire") + } + if !bytes.Contains(wire, []byte("image/png")) { + t.Error("Encode(): attachment type not in wire") + } +} + +// ── GetHeaderHash ───────────────────────────────────────────────────────────── + +func TestGetHeaderHash(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: 12, + } + got := h.GetHeaderHash() + want := sha256.Sum256(buildExpectedWire()) + if !bytes.Equal(got, want[:]) { + t.Errorf("GetHeaderHash() = %x, want %x", got, want) + } +} + +func TestGetHeaderHashCached(t *testing.T) { + h := &fmsg.Header{ + Version: 1, Flags: 0, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", Size: 0, + } + h1 := h.GetHeaderHash() + h2 := h.GetHeaderHash() + if &h1[0] != &h2[0] { + t.Error("GetHeaderHash() should return the same slice on repeated calls (cached)") + } +} + +// ── GetMessageHash ──────────────────────────────────────────────────────────── + +// TestGetMessageHashSmall verifies a complete small-message hash against an +// independently computed expected value: sha256(encoded_header || body). +func TestGetMessageHashSmall(t *testing.T) { + const body = "Hello, fmsg!" + + f, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString(body); err != nil { + t.Fatal(err) + } + f.Close() + + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: uint32(len(body)), + Filepath: f.Name(), + } + + got, err := h.GetMessageHash() + if err != nil { + t.Fatalf("GetMessageHash() error: %v", err) + } + + // Expected: sha256(header_wire || body_bytes) computed independently. + wantHash := sha256.New() + wantHash.Write(buildExpectedWireWithSize(uint32(len(body)))) + wantHash.Write([]byte(body)) + want := wantHash.Sum(nil) + + if !bytes.Equal(got, want) { + t.Errorf("GetMessageHash():\n got %s\n want %s", hex.EncodeToString(got), hex.EncodeToString(want)) + } + + // Golden hash guards against regressions in Encode() or the hash algorithm. + const wantGolden = "eaefdd1cf1868078ff38ba882cbd31a2297af3db33cec888bd3e088bafdbcc3b" + if hex.EncodeToString(got) != wantGolden { + t.Errorf("GetMessageHash() golden:\n got %s\n want %s", hex.EncodeToString(got), wantGolden) + } +} + +// buildExpectedWireWithSize is like buildExpectedWire but uses a variable body size. +func buildExpectedWireWithSize(size uint32) []byte { + var b bytes.Buffer + b.WriteByte(1) // version + b.WriteByte(0) // flags + from := "@alice@example.com" + b.WriteByte(byte(len(from))) + b.WriteString(from) + b.WriteByte(1) + to := "@bob@example.com" + b.WriteByte(byte(len(to))) + b.WriteString(to) + _ = binary.Write(&b, binary.LittleEndian, float64(0)) + b.WriteByte(byte(len("hi"))) + b.WriteString("hi") + b.WriteByte(byte(len("text/plain"))) + b.WriteString("text/plain") + _ = binary.Write(&b, binary.LittleEndian, size) + b.WriteByte(0) + return b.Bytes() +} + +func TestGetMessageHashCached(t *testing.T) { + const body = "cached" + f, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + f.WriteString(body) + f.Close() + + h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", + Size: uint32(len(body)), + Filepath: f.Name(), + } + h1, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + h2, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + if &h1[0] != &h2[0] { + t.Error("GetMessageHash() should return the same slice on repeated calls (cached)") + } +} + +func TestGetMessageHashWithAttachment(t *testing.T) { + const body = "body data" + const attData = "attachment data" + + bodyFile, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + bodyFile.WriteString(body) + bodyFile.Close() + + attFile, err := os.CreateTemp(t.TempDir(), "fmsg-att-*") + if err != nil { + t.Fatal(err) + } + attFile.WriteString(attData) + attFile.Close() + + h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", + Size: uint32(len(body)), + Filepath: bodyFile.Name(), + Attachments: []fmsg.AttachmentHeader{ + { + Flags: 0, + Type: "image/png", + Filename: "pic.png", + Size: uint32(len(attData)), + Filepath: attFile.Name(), + }, + }, + } + got, err := h.GetMessageHash() + if err != nil { + t.Fatalf("GetMessageHash() with attachment error: %v", err) + } + + // Expected: sha256(header_wire || body || attachment_data) + hw := sha256.New() + hw.Write(h.Encode()) + hw.Write([]byte(body)) + hw.Write([]byte(attData)) + want := hw.Sum(nil) + + if !bytes.Equal(got, want) { + t.Errorf("GetMessageHash() with attachment:\n got %x\n want %x", got, want) + } +} + +// ── HashPayload ─────────────────────────────────────────────────────────────── + +func TestHashPayloadPlain(t *testing.T) { + content := []byte("payload bytes") + f, err := os.CreateTemp(t.TempDir(), "fmsg-payload-*") + if err != nil { + t.Fatal(err) + } + f.Write(content) + f.Close() + + var dst bytes.Buffer + if err := fmsg.HashPayload(&dst, f.Name(), int64(len(content)), false, 0); err != nil { + t.Fatalf("HashPayload() error: %v", err) + } + if !bytes.Equal(dst.Bytes(), content) { + t.Errorf("HashPayload() wrote %x, want %x", dst.Bytes(), content) + } +} + +func TestHashPayloadDeflated(t *testing.T) { + plain := []byte("hello compressed world") + + // Write zlib-compressed content to a temp file. + f, err := os.CreateTemp(t.TempDir(), "fmsg-deflated-*") + if err != nil { + t.Fatal(err) + } + zw := zlib.NewWriter(f) + zw.Write(plain) + zw.Close() + wireSize, _ := f.Seek(0, 1) // current offset = compressed size + f.Close() + + var dst bytes.Buffer + if err := fmsg.HashPayload(&dst, f.Name(), wireSize, true, uint32(len(plain))); err != nil { + t.Fatalf("HashPayload() deflated error: %v", err) + } + if !bytes.Equal(dst.Bytes(), plain) { + t.Errorf("HashPayload() deflated wrote %q, want %q", dst.Bytes(), plain) + } +} + +func TestHashPayloadDeflatedSizeMismatch(t *testing.T) { + plain := []byte("data") + f, err := os.CreateTemp(t.TempDir(), "fmsg-deflated-*") + if err != nil { + t.Fatal(err) + } + zw := zlib.NewWriter(f) + zw.Write(plain) + zw.Close() + wireSize, _ := f.Seek(0, 1) + f.Close() + + var dst bytes.Buffer + err = fmsg.HashPayload(&dst, f.Name(), wireSize, true, uint32(len(plain))+99) + if err == nil { + t.Error("HashPayload() should error when expanded size does not match") + } +} + +// ── Common media types ──────────────────────────────────────────────────────── + +func TestGetCommonMediaType(t *testing.T) { + tests := []struct { + id uint8 + want string + }{ + {3, "application/json"}, + {6, "application/pdf"}, + {38, "image/png"}, + {50, "text/csv"}, + {56, "text/plain;charset=UTF-8"}, + {64, "video/webm"}, + } + for _, tt := range tests { + got, ok := fmsg.GetCommonMediaType(tt.id) + if !ok { + t.Errorf("GetCommonMediaType(%d): not found", tt.id) + } + if got != tt.want { + t.Errorf("GetCommonMediaType(%d) = %q, want %q", tt.id, got, tt.want) + } + } +} + +func TestGetCommonMediaTypeUnknown(t *testing.T) { + _, ok := fmsg.GetCommonMediaType(0) + if ok { + t.Error("GetCommonMediaType(0) should return false") + } + _, ok = fmsg.GetCommonMediaType(65) + if ok { + t.Error("GetCommonMediaType(65) should return false") + } +} + +func TestGetCommonMediaTypeID(t *testing.T) { + tests := []struct { + mime string + want uint8 + }{ + {"application/json", 3}, + {"application/pdf", 6}, + {"image/png", 38}, + {"text/csv", 50}, + {"text/plain;charset=UTF-8", 56}, + {"video/webm", 64}, + } + for _, tt := range tests { + got, ok := fmsg.GetCommonMediaTypeID(tt.mime) + if !ok { + t.Errorf("GetCommonMediaTypeID(%q): not found", tt.mime) + } + if got != tt.want { + t.Errorf("GetCommonMediaTypeID(%q) = %d, want %d", tt.mime, got, tt.want) + } + } +} + +func TestGetCommonMediaTypeIDUnknown(t *testing.T) { + _, ok := fmsg.GetCommonMediaTypeID("application/unknown") + if ok { + t.Error("GetCommonMediaTypeID(unknown) should return false") + } +} + +func TestCommonMediaTypeRoundTrip(t *testing.T) { + // Every ID in the valid range 1–64 must round-trip: ID → string → ID. + for id := uint8(1); id <= 64; id++ { + mime, ok := fmsg.GetCommonMediaType(id) + if !ok { + t.Errorf("GetCommonMediaType(%d): not found", id) + continue + } + got, ok := fmsg.GetCommonMediaTypeID(mime) + if !ok { + t.Errorf("GetCommonMediaTypeID(%q): not found (from ID %d)", mime, id) + continue + } + if got != id { + t.Errorf("round-trip ID %d → %q → %d", id, mime, got) + } + } +} diff --git a/pkg/protocol/README.md b/pkg/protocol/README.md deleted file mode 100644 index 2219762..0000000 --- a/pkg/protocol/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# pkg/protocol - -This package implements the fmsg wire protocol encoding and hashing as defined in [SPEC.md](../../SPEC.md). - -## Import - -```go -import "github.com/markmnl/fmsgd/pkg/protocol" -``` - -## What it provides - -- **Types**: `FMsgHeader`, `FMsgAddress`, `FMsgAttachmentHeader` -- **Flag constants**: `FlagHasPid`, `FlagHasAddTo`, `FlagCommonType`, `FlagImportant`, `FlagNoReply`, `FlagDeflate` -- **Wire encoding**: `FMsgHeader.Encode()` — serialises a header to the exact byte sequence defined in SPEC.md -- **Hashing**: `FMsgHeader.GetHeaderHash()` — SHA-256 of the encoded header; `FMsgHeader.GetMessageHash()` — SHA-256 of header + decompressed body + decompressed attachments -- **Common media type lookup**: `GetCommonMediaType(id)`, `GetCommonMediaTypeID(mimeType)` - -## Example: compute a message hash - -```go -h := &protocol.FMsgHeader{ - Version: 1, - Flags: 0, - From: protocol.FMsgAddress{User: "alice", Domain: "example.com"}, - To: []protocol.FMsgAddress{{User: "bob", Domain: "other.com"}}, - Timestamp: 1700000000.0, - Topic: "hello", - Type: "text/plain", - Size: 5, - Filepath: "/path/to/body/file", -} -hash, err := h.GetMessageHash() -``` - -## Notes - -- `Encode()` produces fields 1–12 of the fmsg wire format (header through attachment headers). Message data and attachment data follow separately on the wire. -- `GetMessageHash()` hashes over **decompressed** data even when the stored file is zlib-compressed (`FlagDeflate` set). Set `ExpandedSize` accordingly. -- `Filepath` and `FMsgAttachmentHeader.Filepath` must point to readable files on disk for `GetMessageHash()` to succeed. From 67388ac7e9c5ad999e11554733e85f26e85007eb Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 16 May 2026 13:51:14 +0800 Subject: [PATCH 4/4] unignore fmsgd --- .gitignore | 2 +- cmd/fmsgd/deflate.go | 157 ++++ cmd/fmsgd/deflate_test.go | 549 ++++++++++++ cmd/fmsgd/defs.go | 18 + cmd/fmsgd/defs_test.go | 856 ++++++++++++++++++ cmd/fmsgd/dns.go | 148 ++++ cmd/fmsgd/host.go | 1688 ++++++++++++++++++++++++++++++++++++ cmd/fmsgd/host_test.go | 622 +++++++++++++ cmd/fmsgd/id.go | 92 ++ cmd/fmsgd/outgoing.go | 66 ++ cmd/fmsgd/outgoing_test.go | 142 +++ cmd/fmsgd/sender.go | 628 ++++++++++++++ cmd/fmsgd/store.go | 592 +++++++++++++ cmd/fmsgd/store_test.go | 139 +++ 14 files changed, 5698 insertions(+), 1 deletion(-) create mode 100644 cmd/fmsgd/deflate.go create mode 100644 cmd/fmsgd/deflate_test.go create mode 100644 cmd/fmsgd/defs.go create mode 100644 cmd/fmsgd/defs_test.go create mode 100644 cmd/fmsgd/dns.go create mode 100644 cmd/fmsgd/host.go create mode 100644 cmd/fmsgd/host_test.go create mode 100644 cmd/fmsgd/id.go create mode 100644 cmd/fmsgd/outgoing.go create mode 100644 cmd/fmsgd/outgoing_test.go create mode 100644 cmd/fmsgd/sender.go create mode 100644 cmd/fmsgd/store.go create mode 100644 cmd/fmsgd/store_test.go diff --git a/.gitignore b/.gitignore index 0077797..3c19fc4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ .claude/ *.exe .env -fmsgd \ No newline at end of file +/cmd/fmsgd/fmsgd \ No newline at end of file diff --git a/cmd/fmsgd/deflate.go b/cmd/fmsgd/deflate.go new file mode 100644 index 0000000..68b85a7 --- /dev/null +++ b/cmd/fmsgd/deflate.go @@ -0,0 +1,157 @@ +package main + +import ( + "bytes" + "compress/zlib" + "io" + "os" + "strings" +) + +// minDeflateSize is the minimum payload size in bytes before compression is +// attempted. +const minDeflateSize uint32 = 512 + +// incompressibleTypes lists media types (lowercased, without parameters) that +// are already compressed or otherwise unlikely to benefit from zlib-deflate. +var incompressibleTypes = map[string]bool{ + // images + "image/jpeg": true, "image/png": true, "image/gif": true, + "image/webp": true, "image/heic": true, "image/avif": true, + "image/apng": true, + // audio + "audio/aac": true, "audio/mpeg": true, "audio/ogg": true, + "audio/opus": true, "audio/webm": true, + // video + "video/h264": true, "video/h265": true, "video/h266": true, + "video/ogg": true, "video/vp8": true, "video/vp9": true, + "video/webm": true, + // archives / compressed containers + "application/gzip": true, "application/zip": true, + "application/epub+zip": true, + "application/octet-stream": true, + // zip-based office formats + "application/vnd.oasis.opendocument.presentation": true, + "application/vnd.oasis.opendocument.spreadsheet": true, + "application/vnd.oasis.opendocument.text": true, + "application/vnd.openxmlformats-officedocument.presentationml.presentation": true, + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true, + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": true, + "application/vnd.amazon.ebook": true, + // fonts (compressed) + "font/woff": true, "font/woff2": true, + // pdf (internally compressed) + "application/pdf": true, + // 3d models (compressed containers) + "model/3mf": true, "model/gltf-binary": true, + "model/vnd.usdz+zip": true, +} + +// shouldCompress reports whether compression should be attempted for a payload +// with the given media type and size. It returns false for payloads that are +// too small or use a media type known to be already compressed. +func shouldCompress(mediaType string, dataSize uint32) bool { + if dataSize < minDeflateSize { + return false + } + t := strings.ToLower(mediaType) + if i := strings.IndexByte(t, ';'); i >= 0 { + t = strings.TrimRight(t[:i], " ") + } + return !incompressibleTypes[t] +} + +// deflateSampleSize is the number of bytes sampled from the start of a file +// to estimate compressibility before committing to a full-file compression +// pass. Chosen large enough for zlib to find patterns but small enough to be +// fast even on very large files. +const deflateSampleSize = 8192 + +// probeSample compresses up to deflateSampleSize bytes from the start of src +// and reports whether the ratio looks promising (compressed < 80% of input). +// src is seeked back to the start on return. +func probeSample(src *os.File, srcSize uint32) (bool, error) { + sampleLen := int64(deflateSampleSize) + if int64(srcSize) < sampleLen { + sampleLen = int64(srcSize) + } + + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + if _, err := io.CopyN(zw, src, sampleLen); err != nil { + _ = zw.Close() + return false, err + } + if err := zw.Close(); err != nil { + return false, err + } + + if _, err := src.Seek(0, io.SeekStart); err != nil { + return false, err + } + + return int64(buf.Len()) < sampleLen*8/10, nil +} + +// tryCompress compresses the file at srcPath using zlib-deflate and writes the +// result to a temporary file. For files larger than deflateSampleSize it first +// compresses a prefix sample to estimate compressibility, avoiding a full pass +// over files that won't compress well. It returns worthwhile=true only when +// the compressed output is less than 80% of the original size (at least a 20% +// reduction). When not worthwhile the temporary file is removed. When +// worthwhile the caller is responsible for removing the file at dstPath. +func tryCompress(srcPath string, srcSize uint32) (dstPath string, compressedSize uint32, worthwhile bool, err error) { + src, err := os.Open(srcPath) + if err != nil { + return "", 0, false, err + } + defer src.Close() + + // For files larger than the sample size, probe a prefix first. + if srcSize > deflateSampleSize { + promising, err := probeSample(src, srcSize) + if err != nil { + return "", 0, false, err + } + if !promising { + return "", 0, false, nil + } + } + + dst, err := os.CreateTemp("", "fmsg-deflate-*") + if err != nil { + return "", 0, false, err + } + dstName := dst.Name() + + zw := zlib.NewWriter(dst) + if _, err := io.Copy(zw, src); err != nil { + _ = zw.Close() + _ = dst.Close() + _ = os.Remove(dstName) + return "", 0, false, err + } + if err := zw.Close(); err != nil { + _ = dst.Close() + _ = os.Remove(dstName) + return "", 0, false, err + } + if err := dst.Close(); err != nil { + _ = os.Remove(dstName) + return "", 0, false, err + } + + fi, err := os.Stat(dstName) + if err != nil { + _ = os.Remove(dstName) + return "", 0, false, err + } + + cSize := uint32(fi.Size()) + if cSize >= srcSize*8/10 { + _ = os.Remove(dstName) + return "", 0, false, nil + } + + return dstName, cSize, true, nil +} diff --git a/cmd/fmsgd/deflate_test.go b/cmd/fmsgd/deflate_test.go new file mode 100644 index 0000000..7c10c34 --- /dev/null +++ b/cmd/fmsgd/deflate_test.go @@ -0,0 +1,549 @@ +package main + +import ( + "bytes" + "compress/zlib" + "crypto/rand" + "crypto/sha256" + "io" + "os" + "strings" + "testing" +) + +// --- shouldDeflate tests --- + +func TestShouldDeflate_TextTypes(t *testing.T) { + compressible := []string{ + "text/plain;charset=UTF-8", + "text/html", + "text/markdown", + "text/csv", + "text/css", + "text/javascript", + "text/calendar", + "text/vcard", + "text/plain;charset=US-ASCII", + "text/plain;charset=UTF-16", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/rtf", + "application/x-tar", + "application/msword", + "application/vnd.ms-excel", + "application/vnd.ms-powerpoint", + "image/svg+xml", + "audio/midi", + "model/obj", + "model/step", + "model/stl", + } + for _, mt := range compressible { + if !shouldCompress(mt, 1024) { + t.Errorf("shouldCompress(%q, 1024) = false, want true", mt) + } + } +} + +func TestShouldDeflate_IncompressibleTypes(t *testing.T) { + skip := []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + "image/heic", + "image/avif", + "image/apng", + "audio/aac", + "audio/mpeg", + "audio/ogg", + "audio/opus", + "audio/webm", + "video/H264", + "video/H265", + "video/H266", + "video/ogg", + "video/VP8", + "video/VP9", + "video/webm", + "application/gzip", + "application/zip", + "application/epub+zip", + "application/octet-stream", + "application/pdf", + "application/vnd.oasis.opendocument.presentation", + "application/vnd.oasis.opendocument.spreadsheet", + "application/vnd.oasis.opendocument.text", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.amazon.ebook", + "font/woff", + "font/woff2", + "model/3mf", + "model/gltf-binary", + "model/vnd.usdz+zip", + } + for _, mt := range skip { + if shouldCompress(mt, 1024) { + t.Errorf("shouldCompress(%q, 1024) = true, want false", mt) + } + } +} + +func TestShouldDeflate_SmallPayload(t *testing.T) { + sizes := []uint32{0, 1, 100, 511} + for _, sz := range sizes { + if shouldCompress("text/plain;charset=UTF-8", sz) { + t.Errorf("shouldCompress(text/plain, %d) = true, want false", sz) + } + } +} + +func TestShouldDeflate_EdgeCases(t *testing.T) { + // Exactly at threshold: should attempt + if !shouldCompress("text/plain;charset=UTF-8", 512) { + t.Error("shouldDeflate at threshold 512 should return true") + } + // Unknown type: default to try compression + if !shouldCompress("application/x-custom", 1024) { + t.Error("shouldDeflate for unknown type should return true") + } + // Type with parameters should match base type + if shouldCompress("application/pdf; charset=utf-8", 1024) { + t.Error("shouldDeflate should strip parameters and match application/pdf") + } + // Case insensitive + if shouldCompress("VIDEO/H264", 1024) { + t.Error("shouldDeflate should be case-insensitive") + } +} + +// --- tryDeflate tests --- + +func writeTempFile(t *testing.T, data []byte) string { + t.Helper() + f, err := os.CreateTemp("", "deflate-test-*") + if err != nil { + t.Fatal(err) + } + if _, err := f.Write(data); err != nil { + f.Close() + os.Remove(f.Name()) + t.Fatal(err) + } + f.Close() + return f.Name() +} + +func TestTryDeflate_CompressibleData(t *testing.T) { + original := []byte(strings.Repeat("hello world, this is compressible text data! ", 100)) + srcPath := writeTempFile(t, original) + defer os.Remove(srcPath) + + dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected compression to be worthwhile for repetitive text") + } + defer os.Remove(dstPath) + + if cSize >= uint32(len(original))*8/10 { + t.Errorf("compressed size %d not < 80%% of original %d", cSize, len(original)) + } + + // Verify the compressed file decompresses to the original data + f, err := os.Open(dstPath) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + zr, err := zlib.NewReader(f) + if err != nil { + t.Fatal(err) + } + decompressed, err := io.ReadAll(zr) + zr.Close() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decompressed, original) { + t.Error("decompressed data does not match original") + } +} + +func TestTryDeflate_IncompressibleData(t *testing.T) { + // Random bytes are effectively incompressible + data := make([]byte, 2048) + if _, err := rand.Read(data); err != nil { + t.Fatal(err) + } + srcPath := writeTempFile(t, data) + defer os.Remove(srcPath) + + _, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) + if err != nil { + t.Fatal(err) + } + if worthwhile { + t.Error("expected compression of random data to not be worthwhile") + } +} + +func TestTryDeflate_RoundTrip(t *testing.T) { + original := []byte(strings.Repeat("Round-trip test data with enough repetition to compress well. ", 50)) + srcPath := writeTempFile(t, original) + defer os.Remove(srcPath) + + dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected compression to be worthwhile") + } + defer os.Remove(dstPath) + + // Read compressed file + compressed, err := os.ReadFile(dstPath) + if err != nil { + t.Fatal(err) + } + if uint32(len(compressed)) != cSize { + t.Errorf("compressed file size %d != reported size %d", len(compressed), cSize) + } + + // Decompress and verify + zr, err := zlib.NewReader(bytes.NewReader(compressed)) + if err != nil { + t.Fatal(err) + } + decompressed, err := io.ReadAll(zr) + zr.Close() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decompressed, original) { + t.Errorf("round-trip mismatch: got %d bytes, want %d bytes", len(decompressed), len(original)) + } +} + +func TestTryDeflate_CleanupOnNotWorthwhile(t *testing.T) { + // Random data won't compress well — the temp file should be removed + data := make([]byte, 2048) + if _, err := rand.Read(data); err != nil { + t.Fatal(err) + } + srcPath := writeTempFile(t, data) + defer os.Remove(srcPath) + + dstPath, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) + if err != nil { + t.Fatal(err) + } + if worthwhile { + defer os.Remove(dstPath) + t.Fatal("expected not worthwhile for random data") + } + // dstPath should be empty and no leaked temp file + if dstPath != "" { + t.Errorf("expected empty dstPath when not worthwhile, got %q", dstPath) + } +} + +func TestTryDeflate_ProbeRejectsLargeIncompressible(t *testing.T) { + // A file larger than deflateSampleSize filled with random bytes should be + // rejected by the sample probe without writing a full compressed file. + data := make([]byte, deflateSampleSize+4096) + if _, err := rand.Read(data); err != nil { + t.Fatal(err) + } + srcPath := writeTempFile(t, data) + defer os.Remove(srcPath) + + _, _, worthwhile, err := tryCompress(srcPath, uint32(len(data))) + if err != nil { + t.Fatal(err) + } + if worthwhile { + t.Error("expected probe to reject large random data") + } +} + +func TestTryDeflate_ProbeAcceptsLargeCompressible(t *testing.T) { + // A file larger than deflateSampleSize filled with repetitive text should + // pass the probe and compress the full file successfully. + data := []byte(strings.Repeat("probe compressible test data! ", 1000)) + if len(data) <= deflateSampleSize { + t.Fatalf("test data %d bytes not larger than sample size %d", len(data), deflateSampleSize) + } + srcPath := writeTempFile(t, data) + defer os.Remove(srcPath) + + dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(data))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected large compressible data to be worthwhile") + } + defer os.Remove(dstPath) + + if cSize >= uint32(len(data))*8/10 { + t.Errorf("compressed size %d not < 80%% of original %d", cSize, len(data)) + } + + // Verify round-trip + f, err := os.Open(dstPath) + if err != nil { + t.Fatal(err) + } + defer f.Close() + zr, err := zlib.NewReader(f) + if err != nil { + t.Fatal(err) + } + decompressed, err := io.ReadAll(zr) + zr.Close() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decompressed, data) { + t.Error("decompressed data does not match original") + } +} + +// --- Hash determinism tests --- + +func TestGetMessageHash_WithDeflate(t *testing.T) { + // Create repetitive data that compresses well + original := []byte(strings.Repeat("deflate hash test data ", 100)) + srcPath := writeTempFile(t, original) + defer os.Remove(srcPath) + + // Compress it + dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected compression to be worthwhile") + } + defer os.Remove(dstPath) + + // Build header with deflate flag pointing at compressed file + h := &FMsgHeader{ + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + Size: cSize, + ExpandedSize: uint32(len(original)), + Filepath: dstPath, + } + + msgHash, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + + // Manually compute expected: SHA-256(encoded header + decompressed data) + expected := sha256.New() + expected.Write(h.Encode()) + expected.Write(original) + expectedHash := expected.Sum(nil) + + if !bytes.Equal(msgHash, expectedHash) { + t.Errorf("hash mismatch:\n got %x\n want %x", msgHash, expectedHash) + } +} + +func TestGetMessageHash_WithoutDeflate(t *testing.T) { + original := []byte(strings.Repeat("no deflate hash test ", 100)) + srcPath := writeTempFile(t, original) + defer os.Remove(srcPath) + + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + Size: uint32(len(original)), + Filepath: srcPath, + } + + msgHash, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + + expected := sha256.New() + expected.Write(h.Encode()) + expected.Write(original) + expectedHash := expected.Sum(nil) + + if !bytes.Equal(msgHash, expectedHash) { + t.Errorf("hash mismatch:\n got %x\n want %x", msgHash, expectedHash) + } +} + +func TestGetMessageHash_DeflateChangesHash(t *testing.T) { + // The same data produces different message hashes depending on whether + // it is deflated, because the header bytes differ (flags and size fields). + original := []byte(strings.Repeat("deflate vs plain ", 100)) + srcPath := writeTempFile(t, original) + defer os.Remove(srcPath) + + dstPath, cSize, worthwhile, err := tryCompress(srcPath, uint32(len(original))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected compression to be worthwhile") + } + defer os.Remove(dstPath) + + base := FMsgHeader{ + Version: 1, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + } + + // Hash without deflate + plain := base + plain.Flags = 0 + plain.Size = uint32(len(original)) + plain.Filepath = srcPath + hashPlain, err := plain.GetMessageHash() + if err != nil { + t.Fatal(err) + } + + // Hash with deflate + deflated := base + deflated.Flags = FlagDeflate + deflated.Size = cSize + deflated.ExpandedSize = uint32(len(original)) + deflated.Filepath = dstPath + hashDeflated, err := deflated.GetMessageHash() + if err != nil { + t.Fatal(err) + } + + if bytes.Equal(hashPlain, hashDeflated) { + t.Error("expected different hashes for deflated vs non-deflated wire representations") + } +} + +func TestGetMessageHash_AttachmentDeflate(t *testing.T) { + msgData := []byte("short message body that fits in a file") + msgPath := writeTempFile(t, msgData) + defer os.Remove(msgPath) + + attOriginal := []byte(strings.Repeat("attachment data for compression test ", 100)) + attSrcPath := writeTempFile(t, attOriginal) + defer os.Remove(attSrcPath) + + attDstPath, attCSize, worthwhile, err := tryCompress(attSrcPath, uint32(len(attOriginal))) + if err != nil { + t.Fatal(err) + } + if !worthwhile { + t.Fatal("expected attachment compression to be worthwhile") + } + defer os.Remove(attDstPath) + + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + Size: uint32(len(msgData)), + Filepath: msgPath, + Attachments: []FMsgAttachmentHeader{ + { + Flags: 1 << 1, // attachment deflate bit + Type: "text/csv", + Filename: "data.csv", + Size: attCSize, + ExpandedSize: uint32(len(attOriginal)), + Filepath: attDstPath, + }, + }, + } + + msgHash, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + + // Manually compute: SHA-256(header + msg data + decompressed attachment) + expected := sha256.New() + expected.Write(h.Encode()) + expected.Write(msgData) + expected.Write(attOriginal) + expectedHash := expected.Sum(nil) + + if !bytes.Equal(msgHash, expectedHash) { + t.Errorf("attachment hash mismatch:\n got %x\n want %x", msgHash, expectedHash) + } +} + +// --- Encode flag tests --- + +func TestEncode_DeflateFlag(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + } + b := h.Encode() + if b[1]&FlagDeflate == 0 { + t.Error("deflate flag bit (5) not set in encoded header flags byte") + } +} + +func TestEncode_AttachmentDeflateFlag(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + Attachments: []FMsgAttachmentHeader{ + {Flags: 1 << 1, Type: "text/plain", Filename: "test.txt", Size: 100}, + }, + } + b := h.Encode() + // The encoded header ends with attachment headers. Find the attachment + // flags byte: it's the first byte after the attachment count byte. + // The attachment count is at len(b) - (1 + 1 + len("text/plain") + 1 + len("test.txt") + 4) - 1 + // Simpler: just verify the flags byte value appears in the output. + // The attachment count byte (1) followed by attachment flags byte (0x02). + found := false + for i := 0; i < len(b)-1; i++ { + if b[i] == 1 && b[i+1] == (1<<1) { // count=1, flags=0x02 + found = true + break + } + } + if !found { + t.Error("attachment deflate flag bit (1) not found in encoded header") + } +} diff --git a/cmd/fmsgd/defs.go b/cmd/fmsgd/defs.go new file mode 100644 index 0000000..cf18055 --- /dev/null +++ b/cmd/fmsgd/defs.go @@ -0,0 +1,18 @@ +package main + +import "github.com/markmnl/fmsgd/pkg/fmsg" + +// Type aliases so all internal code continues to use unqualified names. +type FMsgAddress = fmsg.Address +type FMsgAttachmentHeader = fmsg.AttachmentHeader +type FMsgHeader = fmsg.Header + +// Flag constants forwarded from the fmsg package. +const ( + FlagHasPid = fmsg.FlagHasPid + FlagHasAddTo = fmsg.FlagHasAddTo + FlagCommonType = fmsg.FlagCommonType + FlagImportant = fmsg.FlagImportant + FlagNoReply = fmsg.FlagNoReply + FlagDeflate = fmsg.FlagDeflate +) diff --git a/cmd/fmsgd/defs_test.go b/cmd/fmsgd/defs_test.go new file mode 100644 index 0000000..52a1a3a --- /dev/null +++ b/cmd/fmsgd/defs_test.go @@ -0,0 +1,856 @@ +package main + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/binary" + "io" + "math" + "os" + "path/filepath" + "testing" + + "github.com/markmnl/fmsgd/pkg/fmsg" +) + +func TestAddressToString(t *testing.T) { + tests := []struct { + addr FMsgAddress + want string + }{ + {FMsgAddress{User: "alice", Domain: "example.com"}, "@alice@example.com"}, + {FMsgAddress{User: "Bob", Domain: "EXAMPLE.COM"}, "@Bob@EXAMPLE.COM"}, + {FMsgAddress{User: "a-b.c", Domain: "x.y.z"}, "@a-b.c@x.y.z"}, + } + for _, tt := range tests { + got := tt.addr.ToString() + if got != tt.want { + t.Errorf("FMsgAddress{%q, %q}.ToString() = %q, want %q", tt.addr.User, tt.addr.Domain, got, tt.want) + } + } +} + +func TestEncodeMinimalHeader(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Timestamp: 1700000000.0, + Topic: "hello", + Type: "text/plain", + } + b := h.Encode() + + r := bytes.NewReader(b) + + // version + ver, _ := r.ReadByte() + if ver != 1 { + t.Fatalf("version = %d, want 1", ver) + } + + // flags + flags, _ := r.ReadByte() + if flags != 0 { + t.Fatalf("flags = %d, want 0", flags) + } + + // from address + fromLen, _ := r.ReadByte() + fromBytes := make([]byte, fromLen) + r.Read(fromBytes) + if string(fromBytes) != "@alice@a.com" { + t.Fatalf("from = %q, want %q", string(fromBytes), "@alice@a.com") + } + + // to count + toCount, _ := r.ReadByte() + if toCount != 1 { + t.Fatalf("to count = %d, want 1", toCount) + } + + // to[0] + toLen, _ := r.ReadByte() + toBytes := make([]byte, toLen) + r.Read(toBytes) + if string(toBytes) != "@bob@b.com" { + t.Fatalf("to[0] = %q, want %q", string(toBytes), "@bob@b.com") + } + + // timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + if ts != 1700000000.0 { + t.Fatalf("timestamp = %f, want 1700000000.0", ts) + } + + // topic + topicLen, _ := r.ReadByte() + topicBytes := make([]byte, topicLen) + r.Read(topicBytes) + if string(topicBytes) != "hello" { + t.Fatalf("topic = %q, want %q", string(topicBytes), "hello") + } + + // type + typeLen, _ := r.ReadByte() + typeBytes := make([]byte, typeLen) + r.Read(typeBytes) + if string(typeBytes) != "text/plain" { + t.Fatalf("type = %q, want %q", string(typeBytes), "text/plain") + } + + // size (uint32 LE) + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 0 { + t.Fatalf("size = %d, want 0", size) + } + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + // should have consumed entire buffer + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestEncodeWithPid(t *testing.T) { + pid := make([]byte, 32) + for i := range pid { + pid[i] = byte(i) + } + h := &FMsgHeader{ + Version: 1, + Flags: FlagHasPid, + Pid: pid, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "should be omitted", + Type: "text/plain", + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // pid should be next 32 bytes + pidOut := make([]byte, 32) + n, _ := r.Read(pidOut) + if n != 32 { + t.Fatalf("pid bytes read = %d, want 32", n) + } + if !bytes.Equal(pidOut, pid) { + t.Fatalf("pid mismatch") + } + + // skip from + fLen, _ := r.ReadByte() + fBuf := make([]byte, fLen) + r.Read(fBuf) + + // skip to count + to[0] + toCount, _ := r.ReadByte() + if toCount != 1 { + t.Fatalf("to count = %d, want 1", toCount) + } + tLen, _ := r.ReadByte() + tBuf := make([]byte, tLen) + r.Read(tBuf) + + // skip timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + + // topic must NOT be present when pid is set — next byte should be type length + typeLen, _ := r.ReadByte() + typeBytes := make([]byte, typeLen) + r.Read(typeBytes) + if string(typeBytes) != "text/plain" { + t.Fatalf("expected type field directly after timestamp, got %q", string(typeBytes)) + } + + // size + attachment count + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestEncodeWithAddTo(t *testing.T) { + pid := make([]byte, 32) + h := &FMsgHeader{ + Version: 1, + Flags: FlagHasPid | FlagHasAddTo, + Pid: pid, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + AddToFrom: &FMsgAddress{User: "a", Domain: "b.com"}, + AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // skip pid (32 bytes) + pidBuf := make([]byte, 32) + r.Read(pidBuf) + + // skip from + fLen, _ := r.ReadByte() + fBuf := make([]byte, fLen) + r.Read(fBuf) + + // skip to count + to[0] + toCount, _ := r.ReadByte() + if toCount != 1 { + t.Fatalf("to count = %d, want 1", toCount) + } + tLen, _ := r.ReadByte() + tBuf := make([]byte, tLen) + r.Read(tBuf) + + // add-to-from + addToFromLen, _ := r.ReadByte() + addToFrom := make([]byte, addToFromLen) + r.Read(addToFrom) + if string(addToFrom) != "@a@b.com" { + t.Fatalf("add-to-from = %q, want %q", string(addToFrom), "@a@b.com") + } + + // add to count + addToCount, _ := r.ReadByte() + if addToCount != 1 { + t.Fatalf("add to count = %d, want 1", addToCount) + } + + // add to[0] + atLen, _ := r.ReadByte() + atBuf := make([]byte, atLen) + r.Read(atBuf) + if string(atBuf) != "@e@f.com" { + t.Fatalf("add to[0] = %q, want %q", string(atBuf), "@e@f.com") + } +} + +func TestEncodeWithAddToDefaultsAddToFromToFromAddress(t *testing.T) { + pid := make([]byte, 32) + h := &FMsgHeader{ + Version: 1, + Flags: FlagHasPid | FlagHasAddTo, + Pid: pid, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, + Timestamp: 0, + Type: "text/plain", + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + r.Read(make([]byte, 32)) + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + r.ReadByte() // to count + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + + addToFromLen, _ := r.ReadByte() + addToFrom := make([]byte, addToFromLen) + r.Read(addToFrom) + if string(addToFrom) != "@a@b.com" { + t.Fatalf("default add-to-from = %q, want %q", string(addToFrom), "@a@b.com") + } +} + +func TestEncodeNoAddToWhenFlagUnset(t *testing.T) { + // When FlagHasAddTo is NOT set, add-to addresses should not appear on the wire + // even if the AddTo slice is populated. + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + AddTo: []FMsgAddress{{User: "e", Domain: "f.com"}}, // should be ignored + Timestamp: 0, + Topic: "", + Type: "text/plain", + } + withAddTo := h.Encode() + + h2 := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + } + withoutAddTo := h2.Encode() + + if !bytes.Equal(withAddTo, withoutAddTo) { + t.Fatalf("encoded bytes differ when AddTo populated but flag unset") + } +} + +func TestGetHeaderHash(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Timestamp: 1700000000.0, + Topic: "test", + Type: "text/plain", + } + hash := h.GetHeaderHash() + if len(hash) != 32 { + t.Fatalf("hash length = %d, want 32", len(hash)) + } + + // Must be deterministic + hash2 := h.GetHeaderHash() + if !bytes.Equal(hash, hash2) { + t.Fatal("GetHeaderHash not deterministic") + } + + // Must match manual SHA-256 of Encode() + expected := sha256.Sum256(h.Encode()) + if !bytes.Equal(hash, expected[:]) { + t.Fatal("GetHeaderHash does not match sha256(Encode())") + } +} + +func TestGetHeaderHashCommonTypeMatchesWireIDEncoding(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagCommonType, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Timestamp: 1700000000, + Topic: "x", + TypeID: 3, + Type: "application/json", + } + expected := sha256.Sum256(h.Encode()) + got := h.GetHeaderHash() + if !bytes.Equal(got, expected[:]) { + t.Fatalf("GetHeaderHash mismatch for common type ID") + } +} + +func TestStringOutput(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}, {User: "carol", Domain: "c.com"}}, + Timestamp: 0, + Topic: "greetings", + Type: "text/plain", + Size: 42, + } + s := h.String() + + // Check key substrings are present + for _, want := range []string{ + "v1", + "@alice@a.com", + "@bob@b.com", + "@carol@c.com", + "greetings", + "text/plain", + "42", + } { + if !bytes.Contains([]byte(s), []byte(want)) { + t.Errorf("String() missing %q", want) + } + } +} + +func TestStringWithAddTo(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagHasAddTo, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + AddTo: []FMsgAddress{{User: "dave", Domain: "d.com"}}, + Topic: "t", + Type: "text/plain", + } + s := h.String() + if !bytes.Contains([]byte(s), []byte("add to:")) { + t.Error("String() missing 'add to:' label") + } + if !bytes.Contains([]byte(s), []byte("@dave@d.com")) { + t.Error("String() missing add-to address") + } +} + +func TestEncodeTimestampEncoding(t *testing.T) { + // Verify the timestamp is encoded as little-endian float64 + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 1700000000.5, + Topic: "", + Type: "", + } + b := h.Encode() + + // Find timestamp position: version(1) + flags(1) + from(1+len) + to_count(1) + to[0](1+len) + fromStr := "@a@b.com" + toStr := "@c@d.com" + offset := 1 + 1 + 1 + len(fromStr) + 1 + 1 + len(toStr) // = 2 + 9 + 10 = 21 + tsBytes := b[offset : offset+8] + + bits := binary.LittleEndian.Uint64(tsBytes) + ts := math.Float64frombits(bits) + if ts != 1700000000.5 { + t.Fatalf("timestamp = %f, want 1700000000.5", ts) + } + + // After timestamp: topic(1+0) + type(1+0) + size(4) + attach_count(1) = 7 bytes + if r := bytes.NewReader(b[offset+8:]); r.Len() != 7 { + t.Fatalf("trailing bytes after timestamp = %d, want 7", r.Len()) + } +} + +func TestEncodeWithAttachments(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 100, + Attachments: []FMsgAttachmentHeader{ + {Flags: 0, Type: "image/png", Filename: "pic.png", Size: 2048}, + {Flags: 1, TypeID: 38, Type: "image/png", Filename: "doc.txt", Size: 512}, + }, + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // skip from + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + // skip to count + to[0] + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + // skip timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + // skip topic + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + // skip type + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + + // size + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 100 { + t.Fatalf("size = %d, want 100", size) + } + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 2 { + t.Fatalf("attach count = %d, want 2", attachCount) + } + + // attachment 0 + att0Flags, _ := r.ReadByte() + if att0Flags != 0 { + t.Fatalf("att[0] flags = %d, want 0", att0Flags) + } + att0TypeLen, _ := r.ReadByte() + att0Type := make([]byte, att0TypeLen) + r.Read(att0Type) + if string(att0Type) != "image/png" { + t.Fatalf("att[0] type = %q, want %q", string(att0Type), "image/png") + } + att0FnLen, _ := r.ReadByte() + att0Fn := make([]byte, att0FnLen) + r.Read(att0Fn) + if string(att0Fn) != "pic.png" { + t.Fatalf("att[0] filename = %q, want %q", string(att0Fn), "pic.png") + } + var att0Size uint32 + binary.Read(r, binary.LittleEndian, &att0Size) + if att0Size != 2048 { + t.Fatalf("att[0] size = %d, want 2048", att0Size) + } + + // attachment 1 + att1Flags, _ := r.ReadByte() + if att1Flags != 1 { + t.Fatalf("att[1] flags = %d, want 1", att1Flags) + } + att1TypeID, _ := r.ReadByte() + if att1TypeID != 38 { + t.Fatalf("att[1] type ID = %d, want 38", att1TypeID) + } + att1FnLen, _ := r.ReadByte() + att1Fn := make([]byte, att1FnLen) + r.Read(att1Fn) + if string(att1Fn) != "doc.txt" { + t.Fatalf("att[1] filename = %q, want %q", string(att1Fn), "doc.txt") + } + var att1Size uint32 + binary.Read(r, binary.LittleEndian, &att1Size) + if att1Size != 512 { + t.Fatalf("att[1] size = %d, want 512", att1Size) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestEncodeWithCommonMessageType(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagCommonType, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + TypeID: 3, + Type: "application/json", + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // skip from + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + // skip to count + to[0] + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + // skip timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + // skip topic + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + + typeID, _ := r.ReadByte() + if typeID != 3 { + t.Fatalf("type ID = %d, want 3", typeID) + } +} + +func TestGetMessageHashUsesDecompressedPayloads(t *testing.T) { + compress := func(data []byte) []byte { + var b bytes.Buffer + w := zlib.NewWriter(&b) + if _, err := w.Write(data); err != nil { + t.Fatalf("zlib write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("zlib close: %v", err) + } + return b.Bytes() + } + + msgPlain := []byte("hello compressed body") + attPlain := []byte("hello compressed attachment") + msgWire := compress(msgPlain) + attWire := compress(attPlain) + + tmpDir := t.TempDir() + msgPath := filepath.Join(tmpDir, "msg.bin") + if err := os.WriteFile(msgPath, msgWire, 0600); err != nil { + t.Fatalf("write msg file: %v", err) + } + attPath := filepath.Join(tmpDir, "att.bin") + if err := os.WriteFile(attPath, attWire, 0600); err != nil { + t.Fatalf("write attachment file: %v", err) + } + + h := &FMsgHeader{ + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Timestamp: 1700000000, + Topic: "t", + Type: "text/plain", + Size: uint32(len(msgWire)), + ExpandedSize: uint32(len(msgPlain)), + Attachments: []FMsgAttachmentHeader{ + {Flags: 1 << 1, Type: "application/octet-stream", Filename: "a.bin", Size: uint32(len(attWire)), ExpandedSize: uint32(len(attPlain)), Filepath: attPath}, + }, + Filepath: msgPath, + } + + got, err := h.GetMessageHash() + if err != nil { + t.Fatalf("GetMessageHash() error: %v", err) + } + + manual := sha256.New() + if _, err := io.Copy(manual, bytes.NewReader(h.Encode())); err != nil { + t.Fatalf("manual header copy: %v", err) + } + if _, err := manual.Write(msgPlain); err != nil { + t.Fatalf("manual msg write: %v", err) + } + if _, err := manual.Write(attPlain); err != nil { + t.Fatalf("manual att write: %v", err) + } + want := manual.Sum(nil) + + if !bytes.Equal(got, want) { + t.Fatalf("message hash mismatch: got %x want %x", got, want) + } +} + +func TestEncodeExpandedSizePresentWhenDeflateSet(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 50, + ExpandedSize: 200, + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // skip from + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + // skip to count + to[0] + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + // skip timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + // skip topic + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + // skip type + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + + // size + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 50 { + t.Fatalf("size = %d, want 50", size) + } + + // expanded size must be present because FlagDeflate is set + var expandedSize uint32 + if err := binary.Read(r, binary.LittleEndian, &expandedSize); err != nil { + t.Fatalf("reading expanded size: %v", err) + } + if expandedSize != 200 { + t.Fatalf("expanded size = %d, want 200", expandedSize) + } + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestEncodeNoExpandedSizeWhenDeflateUnset(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 100, + ExpandedSize: 999, // must NOT appear on wire + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 100 { + t.Fatalf("size = %d, want 100", size) + } + + // No expanded size field; next byte should be attachment count = 0 + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes (expanded size should not be present when deflate unset)", r.Len()) + } +} + +func TestEncodeAttachmentExpandedSizePresentWhenDeflateSet(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 0, + Attachments: []FMsgAttachmentHeader{ + // attachment with zlib-deflate flag (bit 1 = 0b00000010) + {Flags: 1 << 1, Type: "text/plain", Filename: "doc.txt", Size: 60, ExpandedSize: 300}, + }, + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + var msgSize uint32 + binary.Read(r, binary.LittleEndian, &msgSize) + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 1 { + t.Fatalf("attach count = %d, want 1", attachCount) + } + + // attachment flags + attFlags, _ := r.ReadByte() + if attFlags != 1<<1 { + t.Fatalf("att flags = %d, want %d", attFlags, 1<<1) + } + // type (length-prefixed) + attTypeLen, _ := r.ReadByte() + r.Read(make([]byte, attTypeLen)) + // filename + attFnLen, _ := r.ReadByte() + r.Read(make([]byte, attFnLen)) + // wire size + var attSize uint32 + binary.Read(r, binary.LittleEndian, &attSize) + if attSize != 60 { + t.Fatalf("att size = %d, want 60", attSize) + } + // expanded size must be present + var attExpandedSize uint32 + if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { + t.Fatalf("reading att expanded size: %v", err) + } + if attExpandedSize != 300 { + t.Fatalf("att expanded size = %d, want 300", attExpandedSize) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestHashPayloadRejectsExpandedSizeMismatch(t *testing.T) { + compress := func(data []byte) []byte { + var b bytes.Buffer + w := zlib.NewWriter(&b) + w.Write(data) + w.Close() + return b.Bytes() + } + + plain := []byte("hello world this is test data") + wire := compress(plain) + + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "data.bin") + if err := os.WriteFile(p, wire, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + // Correct expanded size should succeed + var dst bytes.Buffer + if err := fmsg.HashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))); err != nil { + t.Fatalf("hashPayload with correct expanded size: %v", err) + } + + // Wrong expanded size should fail + dst.Reset() + err := fmsg.HashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))+1) + if err == nil { + t.Fatal("hashPayload with wrong expanded size: expected error, got nil") + } +} diff --git a/cmd/fmsgd/dns.go b/cmd/fmsgd/dns.go new file mode 100644 index 0000000..a5e6daa --- /dev/null +++ b/cmd/fmsgd/dns.go @@ -0,0 +1,148 @@ +package main + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "strings" + "time" + + "github.com/miekg/dns" +) + +func dnssecRequired() bool { + return os.Getenv("FMSG_REQUIRE_DNSSEC") == "true" +} + +func resolverAuthenticatedData(name string, qtype uint16) (bool, error) { + cfg, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + return false, err + } + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(name), qtype) + msg.SetEdns0(4096, true) + + client := &dns.Client{Timeout: 5 * time.Second} + var lastErr error + for _, server := range cfg.Servers { + addr := net.JoinHostPort(server, cfg.Port) + resp, _, err := client.Exchange(msg, addr) + if err != nil { + lastErr = err + continue + } + if resp == nil { + lastErr = fmt.Errorf("nil DNS response from %s", addr) + continue + } + if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { + lastErr = fmt.Errorf("dns rcode %d from %s", resp.Rcode, addr) + continue + } + return resp.AuthenticatedData, nil + } + + if lastErr == nil { + lastErr = fmt.Errorf("no DNS resolvers configured") + } + return false, lastErr +} + +// lookupAuthorisedIPs resolves fmsg. for A and AAAA records +func lookupAuthorisedIPs(domain string) ([]net.IP, error) { + fmsgDomain := "fmsg." + domain + ips, err := net.LookupIP(fmsgDomain) + if err != nil { + return nil, fmt.Errorf("DNS lookup for %s failed: %w", fmsgDomain, err) + } + if len(ips) == 0 { + return nil, fmt.Errorf("no A/AAAA records found for %s", fmsgDomain) + } + + if dnssecRequired() { + adA, errA := resolverAuthenticatedData(fmsgDomain, dns.TypeA) + adAAAA, errAAAA := resolverAuthenticatedData(fmsgDomain, dns.TypeAAAA) + if !adA && !adAAAA { + if errA != nil && errAAAA != nil { + return nil, fmt.Errorf("dnssec validation failed for %s: A=%v AAAA=%v", fmsgDomain, errA, errAAAA) + } + return nil, fmt.Errorf("dnssec validation failed for %s: resolver did not set AD bit", fmsgDomain) + } + } + + return ips, nil +} + +// getExternalIP discovers this host's external IP address +func getExternalIP() (net.IP, error) { + services := []string{ + "https://api.ipify.org", + "https://checkip.amazonaws.com", + "https://icanhazip.com", + } + client := &http.Client{Timeout: 10 * time.Second} + var lastErr error + for _, svc := range services { + resp, err := client.Get(svc) + if err != nil { + lastErr = fmt.Errorf("%s: %w", svc, err) + continue + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("%s: failed to read response: %w", svc, err) + continue + } + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("%s: unexpected status %d", svc, resp.StatusCode) + continue + } + ip := net.ParseIP(strings.TrimSpace(string(body))) + if ip == nil { + lastErr = fmt.Errorf("%s: failed to parse IP from response: %s", svc, string(body)) + continue + } + return ip, nil + } + return nil, fmt.Errorf("all external IP services failed, last error: %w", lastErr) +} + +// verifyDomainIP checks that this host's external IP is present in the +// fmsg. authorised IP set. Panics if not found. +func verifyDomainIP(domain string) { + externalIP, err := getExternalIP() + if err != nil { + log.Panicf("ERROR: failed to get external IP: %s", err) + } + log.Printf("INFO: external IP: %s", externalIP) + + authorisedIPs, err := lookupAuthorisedIPs(domain) + if err != nil { + log.Panicf("ERROR: failed to lookup fmsg.%s: %s", domain, err) + } + + for _, ip := range authorisedIPs { + if externalIP.Equal(ip) { + log.Printf("INFO: external IP %s found in fmsg.%s authorised IPs", externalIP, domain) + return + } + } + + log.Panicf("ERROR: external IP %s not found in fmsg.%s authorised IPs %v", externalIP, domain, authorisedIPs) +} + +// checkDomainIP verifies the external IP is authorised unless +// FMSG_SKIP_DOMAIN_IP_CHECK is set to "true". +func checkDomainIP(domain string) { + if os.Getenv("FMSG_SKIP_DOMAIN_IP_CHECK") == "true" { + log.Println("INFO: skipping domain IP verification (FMSG_SKIP_DOMAIN_IP_CHECK=true)") + return + } + verifyDomainIP(domain) +} diff --git a/cmd/fmsgd/host.go b/cmd/fmsgd/host.go new file mode 100644 index 0000000..194fed9 --- /dev/null +++ b/cmd/fmsgd/host.go @@ -0,0 +1,1688 @@ +package main + +import ( + "bufio" + "bytes" + "compress/zlib" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "math" + "mime" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" + "unicode" + "unicode/utf8" + + env "github.com/caitlinelfring/go-env-default" + "github.com/joho/godotenv" + "github.com/levenlabs/golib/timeutil" + + "github.com/markmnl/fmsgd/pkg/fmsg" +) + +const ( + InboxDirName = "in" + OutboxDirName = "out" + + RejectCodeInvalid uint8 = 1 + RejectCodeUnsupportedVersion uint8 = 2 + RejectCodeUndisclosed uint8 = 3 + RejectCodeTooBig uint8 = 4 + RejectCodeInsufficentResources uint8 = 5 + RejectCodeParentNotFound uint8 = 6 + RejectCodePastTime uint8 = 7 + RejectCodeFutureTime uint8 = 8 + RejectCodeTimeTravel uint8 = 9 + RejectCodeDuplicate uint8 = 10 + AcceptCodeAddTo uint8 = 11 + AcceptCodeContinue uint8 = 64 + AcceptCodeSkipData uint8 = 65 + + RejectCodeUserUnknown uint8 = 100 + RejectCodeUserFull uint8 = 101 + RejectCodeUserNotAccepting uint8 = 102 + RejectCodeUserDuplicate uint8 = 103 + RejectCodeUserUndisclosed uint8 = 105 + + RejectCodeAccept uint8 = 200 + + messageReservedBitsMask uint8 = 0b11000000 + attachmentReservedBitsMask uint8 = 0b11111100 +) + +// responseCodeName returns the human-friendly name for a response code. +func responseCodeName(code uint8) string { + switch code { + case RejectCodeInvalid: + return "invalid" + case RejectCodeUnsupportedVersion: + return "unsupported version" + case RejectCodeUndisclosed: + return "undisclosed" + case RejectCodeTooBig: + return "too big" + case RejectCodeInsufficentResources: + return "insufficient resources" + case RejectCodeParentNotFound: + return "parent not found" + case RejectCodePastTime: + return "past time" + case RejectCodeFutureTime: + return "future time" + case RejectCodeTimeTravel: + return "time travel" + case RejectCodeDuplicate: + return "duplicate" + case AcceptCodeAddTo: + return "accept add to" + case RejectCodeUserUnknown: + return "user unknown" + case RejectCodeUserFull: + return "user full" + case RejectCodeUserNotAccepting: + return "user not accepting" + case RejectCodeUserDuplicate: + return "user duplicate" + case RejectCodeUserUndisclosed: + return "user undisclosed" + case RejectCodeAccept: + return "accept" + default: + return fmt.Sprintf("unknown(%d)", code) + } +} + +var ErrProtocolViolation = errors.New("protocol violation") + +var Port = 4930 + +// The only reason RemotePort would ever be different from Port is when running two fmsg hosts on the same machine so the same port is unavaliable. +var RemotePort = 4930 +var PastTimeDelta float64 = 7 * 24 * 60 * 60 +var FutureTimeDelta float64 = 300 +var MinDownloadRate float64 = 5000 +var MinUploadRate float64 = 5000 +var ReadBufferSize = 1600 +var MaxMessageSize = uint32(1024 * 10) +var MaxExpandedSize = uint32(1024 * 10) +var SkipAuthorisedIPs = false +var TLSInsecureSkipVerify = false +var DataDir = "got on startup" +var Domain = "got on startup" +var IDURI = "got on startup" +var AtRune, _ = utf8.DecodeRuneInString("@") +var MinNetIODeadline = 6 * time.Second + +var serverTLSConfig *tls.Config + +func buildServerTLSConfig() *tls.Config { + certFile := os.Getenv("FMSG_TLS_CERT") + keyFile := os.Getenv("FMSG_TLS_KEY") + if certFile == "" || keyFile == "" { + log.Fatalf("ERROR: FMSG_TLS_CERT and FMSG_TLS_KEY must be set") + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("ERROR: loading TLS certificate: %s", err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + }, + NextProtos: []string{"fmsg/1"}, + } +} + +func buildClientTLSConfig(serverName string) *tls.Config { + return &tls.Config{ + ServerName: serverName, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: TLSInsecureSkipVerify, + NextProtos: []string{"fmsg/1"}, + } +} + +// loadEnvConfig reads env vars (after godotenv.Load so .env is picked up). +func loadEnvConfig() { + Port = env.GetIntDefault("FMSG_PORT", 4930) + RemotePort = env.GetIntDefault("FMSG_REMOTE_PORT", 4930) + PastTimeDelta = env.GetFloatDefault("FMSG_MAX_PAST_TIME_DELTA", 7*24*60*60) + FutureTimeDelta = env.GetFloatDefault("FMSG_MAX_FUTURE_TIME_DELTA", 300) + MinDownloadRate = env.GetFloatDefault("FMSG_MIN_DOWNLOAD_RATE", 5000) + MinUploadRate = env.GetFloatDefault("FMSG_MIN_UPLOAD_RATE", 5000) + ReadBufferSize = env.GetIntDefault("FMSG_READ_BUFFER_SIZE", 1600) + MaxMessageSize = uint32(env.GetIntDefault("FMSG_MAX_MSG_SIZE", 1024*10)) + MaxExpandedSize = uint32(env.GetIntDefault("FMSG_MAX_EXPANDED_SIZE", int(MaxMessageSize))) + SkipAuthorisedIPs = os.Getenv("FMSG_SKIP_AUTHORISED_IPS") == "true" + TLSInsecureSkipVerify = os.Getenv("FMSG_TLS_INSECURE_SKIP_VERIFY") == "true" +} + +// Updates DataDir from environment, panics if not a valid directory. +func setDataDir() { + value, hasValue := os.LookupEnv("FMSG_DATA_DIR") + if !hasValue { + log.Panic("ERROR: FMSG_DATA_DIR not set") + } + _, err := os.Stat(value) + if err != nil { + log.Panicf("ERROR: FMSG_DATA_DIR, %s: %s", value, err) + } + DataDir = value +} + +// Updates Domain from environment, panics if not a valid domain. +func setDomain() { + domain, hasValue := os.LookupEnv("FMSG_DOMAIN") + if !hasValue { + log.Panicln("ERROR: FMSG_DOMAIN not set") + } + _, err := net.LookupHost("fmsg." + domain) + if err != nil { + log.Panicf("ERROR: FMSG_DOMAIN, %s: %s\n", domain, err) + } + Domain = domain + + // verify our external IP is in the fmsg authorised IP set + checkDomainIP(domain) +} + +// Updates IDURL from environment, panics if not valid. +func setIDURL() { + rawUrl, hasValue := os.LookupEnv("FMSG_ID_URL") + if !hasValue { + log.Panicln("ERROR: FMSG_ID_URL not set") + } + url, err := url.Parse(rawUrl) + if err != nil { + log.Panicf("ERROR: FMSG_ID_URL not valid, %s: %s", rawUrl, err) + } + _, err = net.LookupHost(url.Hostname()) + if err != nil { + log.Panicf("ERROR: FMSG_ID_URL lookup failed, %s: %s", url, err) + } + // TODO ping URL to verify its up and responding in a timely manner + IDURI = rawUrl + log.Printf("INFO: ID URL: %s", IDURI) +} + +func calcNetIODuration(sizeInBytes int, bytesPerSecond float64) time.Duration { + rate := float64(sizeInBytes) / bytesPerSecond + d := time.Duration(rate * float64(time.Second)) + if d < MinNetIODeadline { + return MinNetIODeadline + } + return d +} + +func isValidUser(s string) bool { + if !utf8.ValidString(s) || len(s) == 0 || len(s) > 64 { + return false + } + + isSpecial := func(r rune) bool { + return r == '-' || r == '_' || r == '.' + } + + runes := []rune(s) + if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { + return false + } + + lastWasSpecial := false + for _, c := range runes { + if unicode.IsLetter(c) || unicode.IsNumber(c) { + lastWasSpecial = false + continue + } + if !isSpecial(c) { + return false + } + if lastWasSpecial { + return false + } + lastWasSpecial = true + } + return true +} + +func isASCIIBytes(b []byte) bool { + for _, c := range b { + if c > 127 { + return false + } + } + return true +} + +func isValidAttachmentFilename(name string) bool { + if !utf8.ValidString(name) || len(name) == 0 || len(name) >= 256 { + return false + } + + isSpecial := func(r rune) bool { + return r == '-' || r == '_' || r == ' ' || r == '.' + } + + runes := []rune(name) + if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { + return false + } + + lastWasSpecial := false + for _, r := range runes { + if unicode.IsLetter(r) || unicode.IsNumber(r) { + lastWasSpecial = false + continue + } + if !isSpecial(r) { + return false + } + if lastWasSpecial { + return false + } + lastWasSpecial = true + } + + return true +} + +func isMessageRetrievable(msg *FMsgHeader) bool { + if msg == nil { + return false + } + if msg.Filepath != "" { + st, err := os.Stat(msg.Filepath) + if err == nil && !st.IsDir() { + return true + } + } + if len(msg.Pid) == 0 { + return false + } + parentID, err := lookupMsgIdByHash(msg.Pid) + if err != nil || parentID == 0 { + return false + } + parentMsg, err := getMsgByID(parentID) + if err != nil { + return false + } + if parentMsg == nil { + return false + } + return isMessageRetrievable(parentMsg) +} + +func isParentParticipant(parent *FMsgHeader, addr *FMsgAddress) bool { + if parent == nil || addr == nil { + return false + } + target := strings.ToLower(addr.ToString()) + if strings.ToLower(parent.From.ToString()) == target { + return true + } + for i := range parent.To { + if strings.ToLower(parent.To[i].ToString()) == target { + return true + } + } + if parent.AddToFrom != nil && strings.ToLower(parent.AddToFrom.ToString()) == target { + return true + } + for i := range parent.AddTo { + if strings.ToLower(parent.AddTo[i].ToString()) == target { + return true + } + } + return false +} + +func isValidDomain(s string) bool { + if len(s) == 0 || len(s) > 253 { + return false + } + if s == "localhost" { + return true + } + labels := strings.Split(s, ".") + if len(labels) < 2 { + return false + } + for _, label := range labels { + if len(label) == 0 || len(label) > 63 { + return false + } + if label[0] == '-' || label[len(label)-1] == '-' { + return false + } + for _, c := range label { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + } + return true +} + +func parseAddress(b []byte) (*FMsgAddress, error) { + if len(b) < 4 { + return nil, fmt.Errorf("invalid address: too short (%d bytes)", len(b)) + } + var addr = &FMsgAddress{} + addrStr := string(b) + firstAt := strings.IndexRune(addrStr, AtRune) + if firstAt == -1 || firstAt != 0 { + return addr, fmt.Errorf("invalid address, must start with @ %s", addr) + } + lastAt := strings.LastIndex(addrStr, "@") + if lastAt == firstAt { + return addr, fmt.Errorf("invalid address, must have second @ %s", addr) + } + addr.User = addrStr[1:lastAt] + if !isValidUser(addr.User) { + return addr, fmt.Errorf("invalid user in address: %s", addr.User) + } + addr.Domain = addrStr[lastAt+1:] + if !isValidDomain(addr.Domain) { + return addr, fmt.Errorf("invalid domain in address: %s", addr.Domain) + } + return addr, nil +} + +// Reads byte slice prefixed with uint8 size from reader supplied +func ReadUInt8Slice(r io.Reader) ([]byte, error) { + var size byte + err := binary.Read(r, binary.LittleEndian, &size) + if err != nil { + return nil, err + } + return io.ReadAll(io.LimitReader(r, int64(size))) +} + +func readAddress(r io.Reader) (*FMsgAddress, error) { + slice, err := ReadUInt8Slice(r) + if err != nil { + return nil, err + } + return parseAddress(slice) +} + +func handleChallenge(c net.Conn, r *bufio.Reader) error { + hashSlice, err := io.ReadAll(io.LimitReader(r, 32)) + if err != nil { + return err + } + hash := *(*[32]byte)(hashSlice) + log.Printf("INFO: CHALLENGE <-- %s", hex.EncodeToString(hashSlice)) + + // Verify the challenger's IP is the Host-B IP registered for this message + // (§10.5 step 2). An unrecognised hash OR a mismatched IP both → TERMINATE. + remoteIP, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + header, exists := lookupOutgoing(hash, remoteIP) + if !exists { + return fmt.Errorf("challenge for unknown message: %s, from: %s", hex.EncodeToString(hashSlice), c.RemoteAddr().String()) + } + msgHash, err := header.GetMessageHash() + if err != nil { + return err + } + if _, err := c.Write(msgHash); err != nil { + return err + } + return nil +} + +func rejectAccept(c net.Conn, codes []byte) error { + _, err := c.Write(codes) + return err +} + +func sendCode(c net.Conn, code uint8) error { + return rejectAccept(c, []byte{code}) +} + +func validateMessageFlags(c net.Conn, flags uint8) error { + if flags&messageReservedBitsMask != 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("reserved message flag bits set: %#08b", flags) + } + return nil +} + +func validateAttachmentFlags(c net.Conn, flags uint8) error { + if flags&attachmentReservedBitsMask != 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("reserved attachment flag bits set: %#08b", flags) + } + return nil +} + +func hasDomainRecipient(addrs []FMsgAddress, domain string) bool { + for _, addr := range addrs { + if strings.EqualFold(addr.Domain, domain) { + return true + } + } + return false +} + +func determineSenderDomain(h *FMsgHeader) string { + if len(h.AddTo) > 0 && h.AddToFrom != nil { + return h.AddToFrom.Domain + } + return h.From.Domain +} + +func verifySenderIP(c net.Conn, senderDomain string) error { + if SkipAuthorisedIPs { + return nil + } + + remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String()) + if err != nil { + log.Printf("WARN: failed to parse remote address for DNS check: %s", err) + return fmt.Errorf("DNS verification failed") + } + + remoteIP := net.ParseIP(remoteHost) + if remoteIP == nil { + log.Printf("WARN: failed to parse remote IP: %s", remoteHost) + return fmt.Errorf("DNS verification failed") + } + + authorisedIPs, err := lookupAuthorisedIPs(senderDomain) + if err != nil { + log.Printf("WARN: DNS lookup failed for fmsg.%s: %s", senderDomain, err) + return fmt.Errorf("DNS verification failed") + } + + for _, ip := range authorisedIPs { + if remoteIP.Equal(ip) { + return nil + } + } + + log.Printf("WARN: remote IP %s not in authorised IPs for fmsg.%s", remoteIP.String(), senderDomain) + return fmt.Errorf("DNS verification failed") +} + +func handleAddToPath(c net.Conn, h *FMsgHeader) (*FMsgHeader, error) { + if len(h.AddTo) == 0 { + return h, nil + } + + addToHasOurDomain := hasDomainRecipient(h.AddTo, Domain) + + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return h, err + } + + if parentID == 0 { + h.InitialResponseCode = AcceptCodeContinue + return h, nil + } + + parentMsg, err := getMsgByID(parentID) + if err != nil { + return h, err + } + if parentMsg == nil || !isMessageRetrievable(parentMsg) { + h.InitialResponseCode = AcceptCodeContinue + return h, nil + } + + if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { + if err := sendCode(c, RejectCodeTimeTravel); err != nil { + return h, err + } + return h, fmt.Errorf("add-to: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) + } + + if addToHasOurDomain { + h.InitialResponseCode = AcceptCodeSkipData + return h, nil + } + + h.Filepath = parentMsg.Filepath + for i := range h.Attachments { + if i < len(parentMsg.Attachments) { + h.Attachments[i].Filepath = parentMsg.Attachments[i].Filepath + } + } + h.InitialResponseCode = AcceptCodeAddTo + return h, nil +} + +func validatePidReplyPath(c net.Conn, h *FMsgHeader) error { + if len(h.AddTo) != 0 || h.Flags&FlagHasPid == 0 { + return nil + } + + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return err + } + if parentID == 0 { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent not found for pid %s", hex.EncodeToString(h.Pid)) + } + + parentMsg, err := getMsgByID(parentID) + if err != nil { + return err + } + if parentMsg == nil { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent message not found by ID %d", parentID) + } + if !isMessageRetrievable(parentMsg) { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent is not retrievable for msg %d", parentID) + } + + if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { + if err := sendCode(c, RejectCodeTimeTravel); err != nil { + return err + } + return fmt.Errorf("pid reply: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) + } + if !isParentParticipant(parentMsg, &h.From) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("pid reply: sender %s was not a participant of parent", h.From.ToString()) + } + + return nil +} + +func readVersionOrChallenge(c net.Conn, r *bufio.Reader, h *FMsgHeader) (bool, error) { + v, err := r.ReadByte() + if err != nil { + return false, err + } + if v >= 129 { + challengeVersion := 256 - int(v) + if challengeVersion == 1 { + return true, handleChallenge(c, r) + } + if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { + log.Printf("WARN: failed to send unsupported version response: %s", err) + } + return false, fmt.Errorf("unsupported challenge version: %d", challengeVersion) + } + if v != 1 { + if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { + log.Printf("WARN: failed to send unsupported version response: %s", err) + } + return false, fmt.Errorf("unsupported message version: %d", v) + } + h.Version = v + return false, nil +} + +func readToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader) (map[string]bool, error) { + num, err := r.ReadByte() + if err != nil { + return nil, err + } + if num == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return nil, err + } + return nil, fmt.Errorf("to count must be >= 1") + } + seen := make(map[string]bool) + for num > 0 { + addr, err := readAddress(r) + if err != nil { + return nil, err + } + key := strings.ToLower(addr.ToString()) + if seen[key] { + return nil, fmt.Errorf("duplicate recipient address: %s", addr.ToString()) + } + seen[key] = true + h.To = append(h.To, *addr) + num-- + } + return seen, nil +} + +func readAddToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader, seen map[string]bool) error { + if h.Flags&FlagHasAddTo == 0 { + return nil + } + if h.Flags&FlagHasPid == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add to exists but pid does not") + } + + addToFrom, err := readAddress(r) + if err != nil { + if err2 := sendCode(c, RejectCodeInvalid); err2 != nil { + return err2 + } + return fmt.Errorf("reading add-to-from address: %w", err) + } + + addToFromKey := strings.ToLower(addToFrom.ToString()) + fromKey := strings.ToLower(h.From.ToString()) + inFromOrTo := fromKey == addToFromKey + if !inFromOrTo { + for _, toAddr := range h.To { + if strings.ToLower(toAddr.ToString()) == addToFromKey { + inFromOrTo = true + break + } + } + } + if !inFromOrTo { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add-to-from (%s) not in from or to", addToFrom.ToString()) + } + h.AddToFrom = addToFrom + + addToCount, err := r.ReadByte() + if err != nil { + return err + } + if addToCount == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add to flag set but count is 0") + } + + addToSeen := make(map[string]bool) + for addToCount > 0 { + addr, err := readAddress(r) + if err != nil { + return err + } + key := strings.ToLower(addr.ToString()) + if addToSeen[key] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("duplicate recipient address in add to: %s", addr.ToString()) + } + addToSeen[key] = true + if seen[key] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add-to address already in to: %s", addr.ToString()) + } + h.AddTo = append(h.AddTo, *addr) + addToCount-- + } + + return nil +} + +func readAndValidateTimestamp(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + if err := binary.Read(r, binary.LittleEndian, &h.Timestamp); err != nil { + return err + } + now := timeutil.TimestampNow().Float64() + delta := now - h.Timestamp + if PastTimeDelta > 0 && delta > PastTimeDelta { + if err := sendCode(c, RejectCodePastTime); err != nil { + return err + } + return fmt.Errorf("message timestamp: %f too far in past, delta: %fs", h.Timestamp, delta) + } + if FutureTimeDelta > 0 && delta < 0 && math.Abs(delta) > FutureTimeDelta { + if err := sendCode(c, RejectCodeFutureTime); err != nil { + return err + } + return fmt.Errorf("message timestamp: %f too far in future, delta: %fs", h.Timestamp, delta) + } + return nil +} + +func readMessageType(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + if h.Flags&FlagCommonType != 0 { + typeID, err := r.ReadByte() + if err != nil { + return err + } + mtype, ok := fmsg.GetCommonMediaType(typeID) + if !ok { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("unmapped common type ID: %d", typeID) + } + h.TypeID = typeID + h.Type = mtype + return nil + } + + mime, err := ReadUInt8Slice(r) + if err != nil { + return err + } + if !isASCIIBytes(mime) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("message media type must be US-ASCII") + } + h.Type = string(mime) + return nil +} + +func readAttachmentType(c net.Conn, r *bufio.Reader, flags uint8) (string, uint8, error) { + if flags&(1<<0) != 0 { + typeID, err := r.ReadByte() + if err != nil { + return "", 0, err + } + mtype, ok := fmsg.GetCommonMediaType(typeID) + if !ok { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return "", 0, err + } + return "", 0, fmt.Errorf("unmapped attachment common type ID: %d", typeID) + } + return mtype, typeID, nil + } + + typeBytes, err := ReadUInt8Slice(r) + if err != nil { + return "", 0, err + } + if !isASCIIBytes(typeBytes) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return "", 0, err + } + return "", 0, fmt.Errorf("attachment media type must be US-ASCII") + } + return string(typeBytes), 0, nil +} + +func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + var attachCount uint8 + if err := binary.Read(r, binary.LittleEndian, &attachCount); err != nil { + return err + } + + totalSize := h.Size + // When message is compressed, expanded size comes from the header field. + // When uncompressed, the wire size IS the expanded size. + var totalExpandedSize uint32 + if h.Flags&FlagDeflate != 0 { + totalExpandedSize = h.ExpandedSize + } else { + totalExpandedSize = h.Size + } + filenameSeen := make(map[string]bool) + for i := uint8(0); i < attachCount; i++ { + attFlags, err := r.ReadByte() + if err != nil { + return err + } + if err := validateAttachmentFlags(c, attFlags); err != nil { + return err + } + + attType, attTypeID, err := readAttachmentType(c, r, attFlags) + if err != nil { + return err + } + + filenameBytes, err := ReadUInt8Slice(r) + if err != nil { + return err + } + filename := string(filenameBytes) + if !isValidAttachmentFilename(filename) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("invalid attachment filename: %s", filename) + } + filenameKey := strings.ToLower(filename) + if filenameSeen[filenameKey] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("duplicate attachment filename: %s", filename) + } + filenameSeen[filenameKey] = true + + var attSize uint32 + if err := binary.Read(r, binary.LittleEndian, &attSize); err != nil { + return err + } + + // read attachment expanded size — present iff attachment zlib-deflate flag set (§5) + var attExpandedSize uint32 + if attFlags&(1<<1) != 0 { + if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { + return err + } + totalExpandedSize += attExpandedSize + } else { + // uncompressed: expanded size equals wire size + totalExpandedSize += attSize + } + + h.Attachments = append(h.Attachments, FMsgAttachmentHeader{ + Flags: attFlags, + TypeID: attTypeID, + Type: attType, + Filename: filename, + Size: attSize, + ExpandedSize: attExpandedSize, + }) + totalSize += attSize + } + + if totalSize > MaxMessageSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return err + } + return fmt.Errorf("total message size %d exceeds max %d", totalSize, MaxMessageSize) + } + + if totalExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return err + } + return fmt.Errorf("total expanded size %d exceeds MAX_EXPANDED_SIZE %d", totalExpandedSize, MaxExpandedSize) + } + + return nil +} + +func readHeader(c net.Conn) (*FMsgHeader, *bufio.Reader, error) { + r := bufio.NewReaderSize(c, ReadBufferSize) + var h = &FMsgHeader{InitialResponseCode: AcceptCodeContinue} + + d := calcNetIODuration(66000, MinDownloadRate) // max possible header size + c.SetReadDeadline(time.Now().Add(d)) + + handled, err := readVersionOrChallenge(c, r, h) + if err != nil { + if handled { + return nil, r, err + } + return h, r, err + } + if handled { + return nil, r, nil + } + + // read flags + flags, err := r.ReadByte() + if err != nil { + return h, r, err + } + h.Flags = flags + if err := validateMessageFlags(c, flags); err != nil { + return h, r, err + } + + // read pid if any + if flags&FlagHasPid == 1 { + pid, err := io.ReadAll(io.LimitReader(r, 32)) + if err != nil { + return h, r, err + } + h.Pid = make([]byte, 32) + copy(h.Pid, pid) + } + + // read from address + from, err := readAddress(r) + if err != nil { + return h, r, err + } + + h.From = *from + + seen, err := readToRecipients(c, r, h) + if err != nil { + return h, r, err + } + + if err := readAddToRecipients(c, r, h, seen); err != nil { + return h, r, err + } + + if err := readAndValidateTimestamp(c, r, h); err != nil { + return h, r, err + } + + // read topic — only present when pid is NOT set (first message in a thread) + if flags&FlagHasPid == 0 { + topic, err := ReadUInt8Slice(r) + if err != nil { + return h, r, err + } + h.Topic = string(topic) + } + + if err := readMessageType(c, r, h); err != nil { + return h, r, err + } + + // read message size + if err := binary.Read(r, binary.LittleEndian, &h.Size); err != nil { + return h, r, err + } + // read expanded size — present iff zlib-deflate flag is set (§2 field 12) + if h.Flags&FlagDeflate != 0 { + if err := binary.Read(r, binary.LittleEndian, &h.ExpandedSize); err != nil { + return h, r, err + } + if h.ExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return h, r, err + } + return h, r, fmt.Errorf("expanded size %d exceeds MAX_EXPANDED_SIZE %d", h.ExpandedSize, MaxExpandedSize) + } + } + // Size check is deferred until attachment headers are parsed (see below) + + if err := readAttachmentHeaders(c, r, h); err != nil { + return h, r, err + } + + log.Printf("INFO: <-- MSG\n%s", h) + + if !hasDomainRecipient(h.To, Domain) && !hasDomainRecipient(h.AddTo, Domain) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return h, r, err + } + return h, r, fmt.Errorf("no recipients for domain %s", Domain) + } + + if err := verifySenderIP(c, determineSenderDomain(h)); err != nil { + return nil, r, err + } + + h, err = handleAddToPath(c, h) + if err != nil { + return h, r, err + } + if h == nil { + return nil, r, nil + } + + if err := validatePidReplyPath(c, h); err != nil { + return h, r, err + } + + return h, r, nil +} + +// Sends CHALLENGE request to sender, receiving and storing the challenge hash. +// DNS verification of the remote IP is performed during header exchange (readHeader). +// TODO [Spec step 2]: The spec defines challenge modes (NEVER, ALWAYS, +// HAS_NOT_PARTICIPATED, DIFFERENT_DOMAIN) as implementation choices. +// Currently defaults to ALWAYS. Implement configurable challenge mode. +func challenge(conn net.Conn, h *FMsgHeader, senderDomain string) error { + + // Connection 2 MUST target the same IP as Connection 1 (spec 2.1). + remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return fmt.Errorf("failed to parse remote address for challenge: %w", err) + } + conn2, err := tls.Dial("tcp", net.JoinHostPort(remoteHost, fmt.Sprintf("%d", RemotePort)), buildClientTLSConfig("fmsg."+senderDomain)) + if err != nil { + return err + } + version := uint8(255) + if err := binary.Write(conn2, binary.LittleEndian, version); err != nil { + return err + } + hash := h.GetHeaderHash() + log.Printf("INFO: --> CHALLENGE\t%s\n", hex.EncodeToString(hash)) + if _, err := conn2.Write(hash); err != nil { + return err + } + + // read challenge response + resp, err := io.ReadAll(io.LimitReader(conn2, 32)) + if err != nil { + return err + } + if len(resp) != 32 { + return fmt.Errorf("challenge response size %d, expected 32", len(resp)) + } + copy(h.ChallengeHash[:], resp) + h.ChallengeCompleted = true + log.Printf("INFO: <-- CHALLENGE RESP\t%s\n", hex.EncodeToString(resp)) + + // gracefully close 2nd connection + if err := conn2.Close(); err != nil { + return err + } + + return nil +} + +func validateMsgRecvForAddr(h *FMsgHeader, addr *FMsgAddress, msgHash []byte) (code uint8, err error) { + duplicate, err := hasAddrReceivedMsgHash(msgHash, addr) + if err != nil { + return RejectCodeUserUndisclosed, err + } + if duplicate { + return RejectCodeUserDuplicate, nil + } + + detail, err := getAddressDetail(addr) + if err != nil { + return RejectCodeUserUndisclosed, err + } + if detail == nil { + return RejectCodeUserUnknown, nil + } + + // check user accepting new + if !detail.AcceptingNew { + return RejectCodeUserNotAccepting, nil + } + + // check user limits + if detail.LimitRecvCountPer1d > -1 && detail.RecvCountPer1d+1 > detail.LimitRecvCountPer1d { + log.Printf("WARN: Message rejected: RecvCountPer1d would exceed LimitRecvCountPer1d %d", detail.LimitRecvCountPer1d) + return RejectCodeUserFull, nil + } + if detail.LimitRecvSizePer1d > -1 && detail.RecvSizePer1d+int64(h.Size) > detail.LimitRecvSizePer1d { + log.Printf("WARN: Message rejected: RecvSizePer1d would exceed LimitRecvSizePer1d %d", detail.LimitRecvSizePer1d) + return RejectCodeUserFull, nil + } + if detail.LimitRecvSizeTotal > -1 && detail.RecvSizeTotal+int64(h.Size) > detail.LimitRecvSizeTotal { + log.Printf("WARN: Message rejected: RecvSizeTotal would exceed LimitRecvSizeTotal %d", detail.LimitRecvSizeTotal) + return RejectCodeUserFull, nil + } + + return RejectCodeAccept, nil +} + +// uniqueFilepath generates a unique file path in the given directory, +// appending a counter suffix if the base name already exists. +func uniqueFilepath(dir string, timestamp uint32, ext string) string { + base := fmt.Sprintf("%d", timestamp) + fp := filepath.Join(dir, base+ext) + if _, err := os.Stat(fp); os.IsNotExist(err) { + return fp + } + for i := 1; ; i++ { + fp = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, i, ext)) + if _, err := os.Stat(fp); os.IsNotExist(err) { + return fp + } + } +} + +func localRecipients(h *FMsgHeader) []FMsgAddress { + addrs := make([]FMsgAddress, 0, len(h.To)+len(h.AddTo)) + for _, addr := range h.To { + if strings.EqualFold(addr.Domain, Domain) { + addrs = append(addrs, addr) + } + } + for _, addr := range h.AddTo { + if strings.EqualFold(addr.Domain, Domain) { + addrs = append(addrs, addr) + } + } + return addrs +} + +func allLocalRecipientsHaveMessageHash(msgHash []byte, addrs []FMsgAddress) (bool, error) { + if len(addrs) == 0 { + return false, nil + } + for i := range addrs { + duplicate, err := hasAddrReceivedMsgHash(msgHash, &addrs[i]) + if err != nil { + return false, err + } + if !duplicate { + return false, nil + } + } + return true, nil +} + +func markAllCodes(codes []byte, code uint8) { + for i := range codes { + codes[i] = code + } +} + +func prepareMessageData(r io.Reader, h *FMsgHeader, skipData bool) ([]string, error) { + if skipData { + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return nil, err + } + if parentID == 0 { + return nil, fmt.Errorf("%w code 65 requires stored parent for pid %s", ErrProtocolViolation, hex.EncodeToString(h.Pid)) + } + parentMsg, err := getMsgByID(parentID) + if err != nil { + return nil, err + } + if parentMsg == nil || parentMsg.Filepath == "" { + return nil, fmt.Errorf("%w code 65 parent data unavailable for msg %d", ErrProtocolViolation, parentID) + } + h.Filepath = parentMsg.Filepath + return nil, nil + } + + createdPaths := make([]string, 0, 1+len(h.Attachments)) + + fd, err := os.CreateTemp("", "fmsg-download-*") + if err != nil { + return nil, err + } + + if _, err := io.CopyN(fd, r, int64(h.Size)); err != nil { + fd.Close() + _ = os.Remove(fd.Name()) + return nil, err + } + if err := fd.Close(); err != nil { + _ = os.Remove(fd.Name()) + return nil, err + } + + h.Filepath = fd.Name() + createdPaths = append(createdPaths, fd.Name()) + + for i := range h.Attachments { + afd, err := os.CreateTemp("", "fmsg-attachment-*") + if err != nil { + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + + if _, err := io.CopyN(afd, r, int64(h.Attachments[i].Size)); err != nil { + afd.Close() + _ = os.Remove(afd.Name()) + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + if err := afd.Close(); err != nil { + _ = os.Remove(afd.Name()) + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + h.Attachments[i].Filepath = afd.Name() + createdPaths = append(createdPaths, afd.Name()) + } + + return createdPaths, nil +} + +func cleanupFiles(paths []string) { + for _, path := range paths { + if path == "" { + continue + } + _ = os.Remove(path) + } +} + +func copyMessagePayload(src *os.File, dstPath string, compressed bool, wireSize uint32) error { + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + + fd2, err := os.Create(dstPath) + if err != nil { + return err + } + + var copyErr error + if compressed { + lr := io.LimitReader(src, int64(wireSize)) + zr, err := zlib.NewReader(lr) + if err != nil { + fd2.Close() + _ = os.Remove(dstPath) + return err + } + _, copyErr = io.Copy(fd2, zr) + _ = zr.Close() + } else { + _, copyErr = io.CopyN(fd2, src, int64(wireSize)) + } + if err := fd2.Close(); err != nil { + return err + } + + if copyErr != nil { + _ = os.Remove(dstPath) + return copyErr + } + return nil +} + +func uniqueAttachmentPath(dir string, timestamp uint32, idx int, filename string) string { + ext := filepath.Ext(filename) + base := fmt.Sprintf("%d_att_%d", timestamp, idx) + p := filepath.Join(dir, base+ext) + if _, err := os.Stat(p); os.IsNotExist(err) { + return p + } + for n := 1; ; n++ { + p = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, n, ext)) + if _, err := os.Stat(p); os.IsNotExist(err) { + return p + } + } +} + +func persistAttachmentPayloads(h *FMsgHeader, dirpath string) error { + for i := range h.Attachments { + a := &h.Attachments[i] + src, err := os.Open(a.Filepath) + if err != nil { + return err + } + dstPath := uniqueAttachmentPath(dirpath, uint32(h.Timestamp), i, a.Filename) + compressed := a.Flags&(1<<1) != 0 + err = copyMessagePayload(src, dstPath, compressed, a.Size) + src.Close() + if err != nil { + return err + } + a.Filepath = dstPath + } + return nil +} + +func storeAcceptedMessage(h *FMsgHeader, codes []byte, acceptedTo []FMsgAddress, acceptedAddTo []FMsgAddress, primaryFilepath string) bool { + if len(acceptedTo) == 0 && len(acceptedAddTo) == 0 { + return false + } + + origTo := h.To + origAddTo := h.AddTo + h.To = acceptedTo + h.AddTo = acceptedAddTo + h.Filepath = primaryFilepath + if err := storeMsgDetail(h); err != nil { + log.Printf("ERROR: storing message: %s", err) + h.To = origTo + h.AddTo = origAddTo + for i := range codes { + if codes[i] == RejectCodeAccept { + codes[i] = RejectCodeUndisclosed + } + } + return false + } + + h.To = origTo + h.AddTo = origAddTo + allAccepted := append(acceptedTo, acceptedAddTo...) + for i := range allAccepted { + if err := postMsgStatRecv(&allAccepted[i], h.Timestamp, int(h.Size)); err != nil { + log.Printf("WARN: Failed to post msg recv stat: %s", err) + } + } + return true +} + +func downloadMessage(c net.Conn, r io.Reader, h *FMsgHeader, skipData bool) error { + addrs := localRecipients(h) + if len(addrs) == 0 { + return fmt.Errorf("%w our domain: %s, not in recipient list", ErrProtocolViolation, Domain) + } + codes := make([]byte, len(addrs)) + + createdPaths, err := prepareMessageData(r, h, skipData) + if err != nil { + return err + } + cleanupOnReturn := !skipData + defer func() { + if cleanupOnReturn { + cleanupFiles(createdPaths) + } + }() + + // verify hash matches challenge response when challenge was completed + msgHash, err := h.GetMessageHash() + if err != nil { + return err + } + if h.ChallengeCompleted && !bytes.Equal(h.ChallengeHash[:], msgHash) { + challengeHashStr := hex.EncodeToString(h.ChallengeHash[:]) + actualHashStr := hex.EncodeToString(msgHash) + return fmt.Errorf("%w actual hash: %s mismatch challenge response: %s", ErrProtocolViolation, actualHashStr, challengeHashStr) + } + + // pid/add-to validation is handled during header exchange in readHeader(). + + // determine file extension from MIME type + exts, _ := mime.ExtensionsByType(h.Type) + var ext string + if exts == nil { + ext = ".unknown" + } else { + ext = exts[0] + } + + src, err := os.Open(h.Filepath) + if err != nil { + return err + } + defer src.Close() + + // validate each recipient and copy message for accepted ones + // Build a set of add-to addresses for later classification + addToSet := make(map[string]bool) + for _, addr := range h.AddTo { + addToSet[strings.ToLower(addr.ToString())] = true + } + acceptedTo := []FMsgAddress{} + acceptedAddTo := []FMsgAddress{} + var primaryFilepath string + for i, addr := range addrs { + code, err := validateMsgRecvForAddr(h, &addr, msgHash) + if err != nil { + return err + } + if code != RejectCodeAccept { + log.Printf("WARN: Rejected message to: %s: %s (%d)", addr.ToString(), responseCodeName(code), code) + codes[i] = code + continue + } + + // copy to recipient's directory + dirpath := filepath.Join(DataDir, addr.Domain, addr.User, InboxDirName) + if err := os.MkdirAll(dirpath, 0750); err != nil { + return err + } + + fp := uniqueFilepath(dirpath, uint32(h.Timestamp), ext) + if err := copyMessagePayload(src, fp, h.Flags&FlagDeflate != 0, h.Size); err != nil { + log.Printf("ERROR: copying downloaded message from: %s, to: %s", h.Filepath, fp) + codes[i] = RejectCodeUndisclosed + continue + } + + codes[i] = RejectCodeAccept + if addToSet[strings.ToLower(addr.ToString())] { + acceptedAddTo = append(acceptedAddTo, addr) + } else { + acceptedTo = append(acceptedTo, addr) + } + if primaryFilepath == "" { + primaryFilepath = fp + if err := persistAttachmentPayloads(h, filepath.Dir(primaryFilepath)); err != nil { + log.Printf("ERROR: copying attachment payloads for message storage: %s", err) + codes[i] = RejectCodeUndisclosed + primaryFilepath = "" + acceptedTo = acceptedTo[:0] + acceptedAddTo = acceptedAddTo[:0] + continue + } + } + } + + stored := storeAcceptedMessage(h, codes, acceptedTo, acceptedAddTo, primaryFilepath) + if stored { + cleanupOnReturn = false + } + + return rejectAccept(c, codes) +} + +// resolvePostChallengeCode determines the initial response code to send after +// the optional challenge (§10.4). Code 11 (accept add-to) is returned as-is +// since it has no local recipients to duplicate-check. For the skip-data (65) +// and continue (64) paths, a completed challenge with all-local-duplicate +// produces code 10 (duplicate) instead. +func resolvePostChallengeCode(initialCode uint8, challengeCompleted bool, allLocalDup bool) uint8 { + if initialCode == AcceptCodeAddTo { + return AcceptCodeAddTo + } + if challengeCompleted && allLocalDup { + return RejectCodeDuplicate + } + if initialCode == AcceptCodeSkipData { + return AcceptCodeSkipData + } + return AcceptCodeContinue +} + +func abortConn(c net.Conn) { + if tcp, ok := c.(*net.TCPConn); ok { + tcp.SetLinger(0) + } + _ = c.Close() +} + +type responseTrackingConn struct { + net.Conn + wroteResponse bool +} + +func (c *responseTrackingConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.wroteResponse = true + } + return n, err +} + +func handleConn(c net.Conn) { + defer func() { + if r := recover(); r != nil { + log.Printf("ERROR: Recovered in handleConn: %v", r) + } + }() + + log.Printf("INFO: Connection from: %s\n", c.RemoteAddr().String()) + tc := &responseTrackingConn{Conn: c} + + // read header + header, r, err := readHeader(tc) + if err != nil { + log.Printf("WARN: reading header from, %s: %s", c.RemoteAddr().String(), err) + if tc.wroteResponse { + _ = c.Close() + return + } + abortConn(c) + return + } + + // if no header AND no error this was a challenge thats been handeled + if header == nil { + c.Close() + return + } + + if err := challenge(tc, header, determineSenderDomain(header)); err != nil { + log.Printf("ERROR: Challenge failed to, %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + + // §10.4: Determine initial response code after optional challenge. + // Code 11 (add-to, no local recipients) does not need a dup check. + // Codes 65 and 64 both require a dup check when challenge was completed. + allLocalDup := false + if header.ChallengeCompleted && header.InitialResponseCode != AcceptCodeAddTo { + addrs := localRecipients(header) + var err error + allLocalDup, err = allLocalRecipientsHaveMessageHash(header.ChallengeHash[:], addrs) + if err != nil { + log.Printf("ERROR: duplicate check failed for %s: %s", c.RemoteAddr().String(), err) + if err := sendCode(c, RejectCodeUndisclosed); err != nil { + abortConn(c) + return + } + _ = c.Close() + return + } + } + + code := resolvePostChallengeCode(header.InitialResponseCode, header.ChallengeCompleted, allLocalDup) + skipData := false + + switch code { + case AcceptCodeAddTo: + // No local add-to recipients; store header and respond code 11, close. + if err := storeMsgHeaderOnly(header); err != nil { + log.Printf("ERROR: storing add-to header: %s", err) + if err := sendCode(c, RejectCodeUndisclosed); err != nil { + abortConn(c) + return + } + _ = c.Close() + return + } + if err := sendCode(c, AcceptCodeAddTo); err != nil { + log.Printf("ERROR: failed sending code 11 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + log.Printf("INFO: additional recipients received (code 11) for pid %s", hex.EncodeToString(header.Pid)) + c.Close() + return + case RejectCodeDuplicate: + if err := sendCode(c, RejectCodeDuplicate); err != nil { + log.Printf("ERROR: failed sending code 10 to %s: %s", c.RemoteAddr().String(), err) + } + c.Close() + return + case AcceptCodeSkipData: + if err := sendCode(c, AcceptCodeSkipData); err != nil { + log.Printf("ERROR: failed sending code 65 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + skipData = true + log.Printf("INFO: sent code 65 (skip data) to %s", c.RemoteAddr().String()) + default: + if err := sendCode(c, AcceptCodeContinue); err != nil { + log.Printf("ERROR: failed sending code 64 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + log.Printf("INFO: sent code 64 (continue) to %s", c.RemoteAddr().String()) + } + + // store message + deadlineBytes := int(header.Size) + if skipData { + deadlineBytes = 1 + } + c.SetReadDeadline(time.Now().Add(calcNetIODuration(deadlineBytes, MinDownloadRate))) + if err := downloadMessage(c, r, header, skipData); err != nil { + // if error was a protocal violation, abort; otherise let sender know there was an internal error + log.Printf("ERROR: Download failed from, %s: %s", c.RemoteAddr().String(), err) + if errors.Is(err, ErrProtocolViolation) { + abortConn(c) + return + } else { + _ = sendCode(c, RejectCodeUndisclosed) + } + } + + // gracefully close 1st connection + c.Close() +} + +func main() { + + initOutgoing() + + // load environment variables from .env file if present + if err := godotenv.Load(); err != nil { + log.Printf("INFO: Could not load .env file: %v", err) + } + + // read env config (must be after godotenv.Load) + loadEnvConfig() + + // determine listen address from args + listenAddress := "127.0.0.1" + for _, arg := range os.Args[1:] { + listenAddress = arg + } + + // initalize database + err := testDb() + if err != nil { + log.Fatalf("ERROR: connecting to database: %s\n", err) + } + + // set DataDir, Domain and IDURL from env + setDataDir() + setDomain() + setIDURL() + + // load TLS configuration (must be after loadEnvConfig for FMSG_TLS_INSECURE_SKIP_VERIFY) + serverTLSConfig = buildServerTLSConfig() + + // start sender in background (small delay so listener is ready first) + go func() { + time.Sleep(1 * time.Second) + startSender() + }() + + // start listening + addr := fmt.Sprintf("%s:%d", listenAddress, Port) + ln, err := tls.Listen("tcp", addr, serverTLSConfig) + if err != nil { + log.Fatal(err) + } + log.Printf("INFO: Ready to receive on %s\n", addr) + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("ERROR: Accept connection from %s returned: %s\n", ln.Addr().String(), err) + } else { + go handleConn(conn) + } + } + +} diff --git a/cmd/fmsgd/host_test.go b/cmd/fmsgd/host_test.go new file mode 100644 index 0000000..48c54be --- /dev/null +++ b/cmd/fmsgd/host_test.go @@ -0,0 +1,622 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/binary" + "net" + "testing" + "time" +) + +type testAddr string + +func (a testAddr) Network() string { return "tcp" } +func (a testAddr) String() string { return string(a) } + +type testConn struct { + bytes.Buffer +} + +func (c *testConn) Read(b []byte) (int, error) { return 0, nil } +func (c *testConn) Write(b []byte) (int, error) { return c.Buffer.Write(b) } +func (c *testConn) Close() error { return nil } +func (c *testConn) LocalAddr() net.Addr { return testAddr("127.0.0.1:1000") } +func (c *testConn) RemoteAddr() net.Addr { return testAddr("127.0.0.1:2000") } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestIsValidUser(t *testing.T) { + valid := []string{"alice", "Bob", "a-b", "a_b", "a.b", "user123", "A", "u\u00f1icode", "\u7528\u62371"} + for _, u := range valid { + if !isValidUser(u) { + t.Errorf("isValidUser(%q) = false, want true", u) + } + } + + invalid := []string{"", " ", "a b", "a@b", "a/b", string(make([]byte, 65)), "-alice", "alice-", "a..b", "a-_b"} + for _, u := range invalid { + if isValidUser(u) { + t.Errorf("isValidUser(%q) = true, want false", u) + } + } +} + +func TestIsValidDomain(t *testing.T) { + valid := []string{"example.com", "a.b.c", "foo-bar.com", "localhost"} + for _, d := range valid { + if !isValidDomain(d) { + t.Errorf("isValidDomain(%q) = false, want true", d) + } + } + + invalid := []string{ + "", + "nodot", // no dots and not localhost + ".leading.dot", // empty label + "trailing.", // empty label + "-start.com", // label starts with hyphen + "end-.com", // label ends with hyphen + "has space.com", + } + for _, d := range invalid { + if isValidDomain(d) { + t.Errorf("isValidDomain(%q) = true, want false", d) + } + } +} + +func TestParseAddress(t *testing.T) { + tests := []struct { + input string + wantErr bool + user string + domain string + }{ + {"@alice@example.com", false, "alice", "example.com"}, + {"@Bob@EXAMPLE.COM", false, "Bob", "EXAMPLE.COM"}, + {"@a-b.c@x.y.z", false, "a-b.c", "x.y.z"}, + // errors + {"alice@example.com", true, "", ""}, // missing leading @ + {"@alice", true, "", ""}, // missing second @ + {"@", true, "", ""}, // too short + {"ab", true, "", ""}, // too short + {"@@example.com", true, "", ""}, // empty user + {"@alice@", true, "", ""}, // empty domain (not valid) + {"@alice@nodot", true, "", ""}, // domain with no dot (not localhost) + } + for _, tt := range tests { + addr, err := parseAddress([]byte(tt.input)) + if tt.wantErr { + if err == nil { + t.Errorf("parseAddress(%q) = nil error, want error", tt.input) + } + continue + } + if err != nil { + t.Errorf("parseAddress(%q) error: %v", tt.input, err) + continue + } + if addr.User != tt.user || addr.Domain != tt.domain { + t.Errorf("parseAddress(%q) = {%q, %q}, want {%q, %q}", tt.input, addr.User, addr.Domain, tt.user, tt.domain) + } + } +} + +func TestReadUInt8Slice(t *testing.T) { + // Build a buffer: uint8 length = 5, then "hello" + var buf bytes.Buffer + buf.WriteByte(5) + buf.WriteString("hello") + // Extra trailing bytes should not be consumed + buf.WriteString("extra") + + slice, err := ReadUInt8Slice(&buf) + if err != nil { + t.Fatalf("ReadUInt8Slice error: %v", err) + } + if string(slice) != "hello" { + t.Fatalf("ReadUInt8Slice = %q, want %q", string(slice), "hello") + } + // "extra" should remain + rest := make([]byte, 5) + n, _ := buf.Read(rest) + if string(rest[:n]) != "extra" { + t.Fatalf("remaining bytes = %q, want %q", string(rest[:n]), "extra") + } +} + +func TestReadUInt8SliceEmpty(t *testing.T) { + var buf bytes.Buffer + buf.WriteByte(0) // zero-length slice + + slice, err := ReadUInt8Slice(&buf) + if err != nil { + t.Fatalf("ReadUInt8Slice error: %v", err) + } + if len(slice) != 0 { + t.Fatalf("expected empty slice, got len %d", len(slice)) + } +} + +func TestCalcNetIODuration(t *testing.T) { + // Small sizes should return MinNetIODeadline + d := calcNetIODuration(100, 5000) + if d < MinNetIODeadline { + t.Fatalf("calcNetIODuration(100, 5000) = %v, want >= %v", d, MinNetIODeadline) + } + + // Large sizes should exceed MinNetIODeadline + d = calcNetIODuration(1_000_000, 5000) + expected := time.Duration(float64(1_000_000) / 5000 * float64(time.Second)) // 200s + if d != expected { + t.Fatalf("calcNetIODuration(1000000, 5000) = %v, want %v", d, expected) + } +} + +func TestResponseCodeName(t *testing.T) { + tests := []struct { + code uint8 + want string + }{ + {RejectCodeInvalid, "invalid"}, + {RejectCodeUnsupportedVersion, "unsupported version"}, + {RejectCodeUndisclosed, "undisclosed"}, + {RejectCodeTooBig, "too big"}, + {RejectCodeInsufficentResources, "insufficient resources"}, + {RejectCodeParentNotFound, "parent not found"}, + {RejectCodePastTime, "past time"}, + {RejectCodeFutureTime, "future time"}, + {RejectCodeTimeTravel, "time travel"}, + {RejectCodeDuplicate, "duplicate"}, + {AcceptCodeAddTo, "accept add to"}, + {RejectCodeUserUnknown, "user unknown"}, + {RejectCodeUserFull, "user full"}, + {RejectCodeUserNotAccepting, "user not accepting"}, + {RejectCodeUserDuplicate, "user duplicate"}, + {RejectCodeUserUndisclosed, "user undisclosed"}, + {RejectCodeAccept, "accept"}, + {99, "unknown(99)"}, + } + for _, tt := range tests { + got := responseCodeName(tt.code) + if got != tt.want { + t.Errorf("responseCodeName(%d) = %q, want %q", tt.code, got, tt.want) + } + } +} + +func TestPerRecipientDuplicateAndUndisclosedCodeValues(t *testing.T) { + if RejectCodeUserDuplicate != 103 { + t.Fatalf("RejectCodeUserDuplicate = %d, want 103", RejectCodeUserDuplicate) + } + if RejectCodeUserUndisclosed != 105 { + t.Fatalf("RejectCodeUserUndisclosed = %d, want 105", RejectCodeUserUndisclosed) + } +} + +func TestFlagConstants(t *testing.T) { + // Verify flag bit assignments match SPEC.md + if FlagHasPid != 1 { + t.Errorf("FlagHasPid = %d, want 1 (bit 0)", FlagHasPid) + } + if FlagHasAddTo != 2 { + t.Errorf("FlagHasAddTo = %d, want 2 (bit 1)", FlagHasAddTo) + } + if FlagCommonType != 4 { + t.Errorf("FlagCommonType = %d, want 4 (bit 2)", FlagCommonType) + } + if FlagImportant != 8 { + t.Errorf("FlagImportant = %d, want 8 (bit 3)", FlagImportant) + } + if FlagNoReply != 16 { + t.Errorf("FlagNoReply = %d, want 16 (bit 4)", FlagNoReply) + } + if FlagDeflate != 32 { + t.Errorf("FlagDeflate = %d, want 32 (bit 5)", FlagDeflate) + } +} + +func encodeUInt8String(t *testing.T, s string) []byte { + t.Helper() + if len(s) > 255 { + t.Fatalf("string too long for uint8 prefix: %d", len(s)) + } + b := []byte{byte(len(s))} + b = append(b, []byte(s)...) + return b +} + +func TestHasDomainRecipient(t *testing.T) { + addrs := []FMsgAddress{ + {User: "alice", Domain: "example.com"}, + {User: "bob", Domain: "other.org"}, + } + if !hasDomainRecipient(addrs, "EXAMPLE.COM") { + t.Fatalf("expected domain match") + } + if hasDomainRecipient(addrs, "missing.test") { + t.Fatalf("did not expect domain match") + } +} + +func TestDetermineSenderDomain(t *testing.T) { + h := &FMsgHeader{ + From: FMsgAddress{User: "alice", Domain: "from.example"}, + } + if got := determineSenderDomain(h); got != "from.example" { + t.Fatalf("determineSenderDomain() = %q, want %q", got, "from.example") + } + + h.AddTo = []FMsgAddress{{User: "new", Domain: "to.example"}} + h.AddToFrom = &FMsgAddress{User: "bob", Domain: "sender.example"} + if got := determineSenderDomain(h); got != "sender.example" { + t.Fatalf("determineSenderDomain() = %q, want %q", got, "sender.example") + } +} + +func TestReadToRecipients(t *testing.T) { + b := []byte{2} + b = append(b, encodeUInt8String(t, "@alice@example.com")...) + b = append(b, encodeUInt8String(t, "@bob@example.com")...) + + h := &FMsgHeader{} + seen, err := readToRecipients(nil, bufio.NewReader(bytes.NewReader(b)), h) + if err != nil { + t.Fatalf("readToRecipients returned error: %v", err) + } + if len(h.To) != 2 { + t.Fatalf("len(h.To) = %d, want 2", len(h.To)) + } + if !seen["@alice@example.com"] || !seen["@bob@example.com"] { + t.Fatalf("seen map missing expected recipients: %#v", seen) + } +} + +func TestReadAddToRecipients(t *testing.T) { + h := &FMsgHeader{ + Flags: FlagHasPid | FlagHasAddTo, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "example.com"}}, + } + seen := map[string]bool{"@bob@example.com": true} + + b := []byte{} + b = append(b, encodeUInt8String(t, "@alice@example.com")...) // add-to-from + b = append(b, 1) // add-to count + b = append(b, encodeUInt8String(t, "@carol@example.com")...) + + err := readAddToRecipients(nil, bufio.NewReader(bytes.NewReader(b)), h, seen) + if err != nil { + t.Fatalf("readAddToRecipients returned error: %v", err) + } + if h.AddToFrom == nil || h.AddToFrom.ToString() != "@alice@example.com" { + t.Fatalf("unexpected AddToFrom: %+v", h.AddToFrom) + } + if len(h.AddTo) != 1 || h.AddTo[0].ToString() != "@carol@example.com" { + t.Fatalf("unexpected AddTo: %+v", h.AddTo) + } +} + +func TestReadMessageType(t *testing.T) { + hCommon := &FMsgHeader{Flags: FlagCommonType} + if err := readMessageType(nil, bufio.NewReader(bytes.NewReader([]byte{3})), hCommon); err != nil { + t.Fatalf("readMessageType(common) error: %v", err) + } + if hCommon.TypeID != 3 { + t.Fatalf("common type ID = %d, want 3", hCommon.TypeID) + } + if hCommon.Type != "application/json" { + t.Fatalf("common type = %q, want %q", hCommon.Type, "application/json") + } + + hText := &FMsgHeader{Flags: 0} + b := encodeUInt8String(t, "text/plain") + if err := readMessageType(nil, bufio.NewReader(bytes.NewReader(b)), hText); err != nil { + t.Fatalf("readMessageType(string) error: %v", err) + } + if hText.Type != "text/plain" { + t.Fatalf("string type = %q, want %q", hText.Type, "text/plain") + } +} + +func TestReadAttachmentHeaders(t *testing.T) { + origMax := MaxMessageSize + MaxMessageSize = 1024 + t.Cleanup(func() { + MaxMessageSize = origMax + }) + + h := &FMsgHeader{Size: 10} + b := []byte{1} // attachment count + b = append(b, 0) // attachment flags (no common type) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var sz [4]byte + binary.LittleEndian.PutUint32(sz[:], 12) + b = append(b, sz[:]...) + + err := readAttachmentHeaders(nil, bufio.NewReader(bytes.NewReader(b)), h) + if err != nil { + t.Fatalf("readAttachmentHeaders returned error: %v", err) + } + if len(h.Attachments) != 1 { + t.Fatalf("len(h.Attachments) = %d, want 1", len(h.Attachments)) + } + att := h.Attachments[0] + if att.TypeID != 0 { + t.Fatalf("attachment type ID = %d, want 0 for non-common", att.TypeID) + } + if att.Type != "text/plain" || att.Filename != "file.txt" || att.Size != 12 { + t.Fatalf("unexpected attachment parsed: %+v", att) + } +} + +func TestReadAddToRecipientsRejectsWhenPidMissing(t *testing.T) { + h := &FMsgHeader{Flags: FlagHasAddTo} + c := &testConn{} + + err := readAddToRecipients(c, bufio.NewReader(bytes.NewReader(nil)), h, map[string]bool{}) + if err == nil { + t.Fatalf("expected error when add-to flag is set without pid") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadAddToRecipientsRejectsDuplicateAddTo(t *testing.T) { + h := &FMsgHeader{ + Flags: FlagHasPid | FlagHasAddTo, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "example.com"}}, + } + c := &testConn{} + seen := map[string]bool{"@bob@example.com": true} + + b := []byte{} + b = append(b, encodeUInt8String(t, "@alice@example.com")...) // add-to-from + b = append(b, 2) // add-to count + b = append(b, encodeUInt8String(t, "@carol@example.com")...) + b = append(b, encodeUInt8String(t, "@carol@example.com")...) + + err := readAddToRecipients(c, bufio.NewReader(bytes.NewReader(b)), h, seen) + if err == nil { + t.Fatalf("expected duplicate add-to error") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadMessageTypeRejectsUnknownCommonType(t *testing.T) { + h := &FMsgHeader{Flags: FlagCommonType} + c := &testConn{} + + err := readMessageType(c, bufio.NewReader(bytes.NewReader([]byte{200})), h) + if err == nil { + t.Fatalf("expected error for unknown common type") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadMessageTypeRejectsNonASCIIStringType(t *testing.T) { + h := &FMsgHeader{Flags: 0} + c := &testConn{} + + b := encodeUInt8String(t, "text/\u03c0lain") + err := readMessageType(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected error for non-ASCII message type") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadAttachmentTypeRejectsNonASCIIStringType(t *testing.T) { + c := &testConn{} + b := encodeUInt8String(t, "text/\u03c0lain") + + _, _, err := readAttachmentType(c, bufio.NewReader(bytes.NewReader(b)), 0) + if err == nil { + t.Fatalf("expected error for non-ASCII attachment type") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadAttachmentHeadersRejectsInvalidFilename(t *testing.T) { + origMax := MaxMessageSize + MaxMessageSize = 1024 + t.Cleanup(func() { + MaxMessageSize = origMax + }) + + h := &FMsgHeader{Size: 10} + c := &testConn{} + b := []byte{1} + b = append(b, 0) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "bad..name")...) // invalid: consecutive special chars + + var sz [4]byte + binary.LittleEndian.PutUint32(sz[:], 12) + b = append(b, sz[:]...) + + err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected error for invalid attachment filename") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadAttachmentHeadersRejectsTooBig(t *testing.T) { + origMax := MaxMessageSize + MaxMessageSize = 20 + t.Cleanup(func() { + MaxMessageSize = origMax + }) + + h := &FMsgHeader{Size: 15} + c := &testConn{} + b := []byte{1} + b = append(b, 0) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var sz [4]byte + binary.LittleEndian.PutUint32(sz[:], 10) + b = append(b, sz[:]...) + + err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected size overflow error") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeTooBig { + t.Fatalf("expected reject code %d, got %v", RejectCodeTooBig, got) + } +} + +func TestValidateMessageFlagsRejectsReservedBits(t *testing.T) { + c := &testConn{} + err := validateMessageFlags(c, 1<<6) + if err == nil { + t.Fatalf("expected error for reserved message flag bit") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestReadAttachmentHeadersRejectsReservedAttachmentBits(t *testing.T) { + h := &FMsgHeader{Size: 0} + c := &testConn{} + + // attachment count=1, then attachment flags with reserved bit 2 set + b := []byte{1, 1 << 2} + err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected error for reserved attachment flag bits") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeInvalid { + t.Fatalf("expected reject code %d, got %v", RejectCodeInvalid, got) + } +} + +func TestResolvePostChallengeCode(t *testing.T) { + tests := []struct { + name string + initialCode uint8 + challengeCompleted bool + allLocalDup bool + want uint8 + }{ + // Add-to (code 11) path — never overridden by dup check. + {"add-to no challenge", AcceptCodeAddTo, false, false, AcceptCodeAddTo}, + {"add-to challenge no dup", AcceptCodeAddTo, true, false, AcceptCodeAddTo}, + {"add-to challenge all dup", AcceptCodeAddTo, true, true, AcceptCodeAddTo}, + + // Continue (code 64) path — dup check yields code 10 when all dup. + {"continue no challenge", AcceptCodeContinue, false, false, AcceptCodeContinue}, + {"continue challenge no dup", AcceptCodeContinue, true, false, AcceptCodeContinue}, + {"continue challenge all dup", AcceptCodeContinue, true, true, RejectCodeDuplicate}, + + // Skip-data (code 65) path — dup check yields code 10 when all dup. + {"skip-data no challenge", AcceptCodeSkipData, false, false, AcceptCodeSkipData}, + {"skip-data challenge no dup", AcceptCodeSkipData, true, false, AcceptCodeSkipData}, + {"skip-data challenge all dup", AcceptCodeSkipData, true, true, RejectCodeDuplicate}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolvePostChallengeCode(tt.initialCode, tt.challengeCompleted, tt.allLocalDup) + if got != tt.want { + t.Errorf("resolvePostChallengeCode(%d, %v, %v) = %d (%s), want %d (%s)", + tt.initialCode, tt.challengeCompleted, tt.allLocalDup, + got, responseCodeName(got), tt.want, responseCodeName(tt.want)) + } + }) + } +} + +func TestReadAttachmentHeadersReadsExpandedSizeForCompressedAttachment(t *testing.T) { + origMax := MaxMessageSize + origExpanded := MaxExpandedSize + MaxMessageSize = 1024 + MaxExpandedSize = 1024 + t.Cleanup(func() { + MaxMessageSize = origMax + MaxExpandedSize = origExpanded + }) + + h := &FMsgHeader{Size: 0} + b := []byte{1} // 1 attachment + b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var wireSize [4]byte + binary.LittleEndian.PutUint32(wireSize[:], 50) + b = append(b, wireSize[:]...) + + var expandedSize [4]byte + binary.LittleEndian.PutUint32(expandedSize[:], 200) + b = append(b, expandedSize[:]...) + + err := readAttachmentHeaders(nil, bufio.NewReader(bytes.NewReader(b)), h) + if err != nil { + t.Fatalf("readAttachmentHeaders returned error: %v", err) + } + if len(h.Attachments) != 1 { + t.Fatalf("len(h.Attachments) = %d, want 1", len(h.Attachments)) + } + att := h.Attachments[0] + if att.Size != 50 { + t.Fatalf("att.Size = %d, want 50", att.Size) + } + if att.ExpandedSize != 200 { + t.Fatalf("att.ExpandedSize = %d, want 200", att.ExpandedSize) + } +} + +func TestReadAttachmentHeadersRejectsExpandedSizeExceedsMax(t *testing.T) { + origMax := MaxMessageSize + origExpanded := MaxExpandedSize + MaxMessageSize = 1024 + MaxExpandedSize = 100 + t.Cleanup(func() { + MaxMessageSize = origMax + MaxExpandedSize = origExpanded + }) + + h := &FMsgHeader{Size: 0} + c := &testConn{} + b := []byte{1} // 1 attachment + b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var wireSize [4]byte + binary.LittleEndian.PutUint32(wireSize[:], 50) + b = append(b, wireSize[:]...) + + // expanded size exceeds MaxExpandedSize=100 + var expandedSize [4]byte + binary.LittleEndian.PutUint32(expandedSize[:], 200) + b = append(b, expandedSize[:]...) + + err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected error when expanded size exceeds max") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeTooBig { + t.Fatalf("expected reject code %d, got %v", RejectCodeTooBig, got) + } +} diff --git a/cmd/fmsgd/id.go b/cmd/fmsgd/id.go new file mode 100644 index 0000000..49c22e2 --- /dev/null +++ b/cmd/fmsgd/id.go @@ -0,0 +1,92 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +type AddressDetail struct { + Address string `json:"address"` + DisplayName string `json:"displayName"` + AcceptingNew bool `json:"acceptingNew"` + LimitRecvSizeTotal int64 `json:"limitRecvSizeTotal"` + LimitRecvSizePerMsg int64 `json:"limitRecvSizePerMsg"` + LimitRecvSizePer1d int64 `json:"limitRecvSizePer1d"` + LimitRecvCountPer1d int64 `json:"limitRecvCountPer1d"` + LimitSendSizeTotal int64 `json:"limitSendSizeTotal"` + LimitSendSizePerMsg int64 `json:"limitSendSizePerMsg"` + LimitSendSizePer1d int64 `json:"limitSendSizePer1d"` + LimitSendCountPer1d int64 `json:"limitSendCountPer1d"` + RecvSizeTotal int64 `json:"recvSizeTotal"` + RecvSizePer1d int64 `json:"recvSizePer1d"` + RecvCountPer1d int64 `json:"recvCountPer1d"` + SendSizeTotal int64 `json:"sendSizeTotal"` + SendSizePer1d int64 `json:"sendSizePer1d"` + SendCountPer1d int64 `json:"sendCountPer1d"` + Tags []string `json:"tags"` +} + +// Returns pointer to an AddressDetail populated by querying fmsg Id standard at FMSG_ID_URL for +// address supplied. If the address is not found returns nil, nil. +func getAddressDetail(addr *FMsgAddress) (*AddressDetail, error) { + uri := IDURI + "/fmsgid/" + url.PathEscape(addr.ToString()) + resp, err := http.Get(uri) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } + + var detail AddressDetail + err = json.NewDecoder(resp.Body).Decode(&detail) + if err != nil { + return nil, err + } + + return &detail, nil +} + +func postMsgStatSend(addr *FMsgAddress, timestamp float64, size int) error { + return postMsgStat(addr, timestamp, size, true) +} + +func postMsgStatRecv(addr *FMsgAddress, timestamp float64, size int) error { + return postMsgStat(addr, timestamp, size, false) +} + +func postMsgStat(addr *FMsgAddress, timestamp float64, size int, isSending bool) error { + var part string + if isSending { + part = "send" + } else { + part = "recv" + } + uri := fmt.Sprintf("%s/fmsgid/%s", IDURI, part) + + payload := map[string]interface{}{ + "address": addr.ToString(), + "ts": timestamp, + "size": size} + jsonPayload, err := json.Marshal(payload) + if err != nil { + return err + } + + resp, err := http.Post(uri, "application/json", bytes.NewBuffer(jsonPayload)) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("POST %s returned %d", uri, resp.StatusCode) + } + + return nil +} diff --git a/cmd/fmsgd/outgoing.go b/cmd/fmsgd/outgoing.go new file mode 100644 index 0000000..242811f --- /dev/null +++ b/cmd/fmsgd/outgoing.go @@ -0,0 +1,66 @@ +package main + +import "sync" + +// outgoingEntry tracks an in-flight outgoing message header together with +// the set of Host-B IPs currently being serviced for that message. +// The IP set is used to validate incoming challenges (§10.5 step 2). +type outgoingEntry struct { + header *FMsgHeader + ips map[string]struct{} +} + +// outgoingMap indexes in-flight message headers by their header hash. +// All access is synchronised via outgoingMu. +var outgoingMap map[[32]byte]*outgoingEntry +var outgoingMu sync.RWMutex + +func initOutgoing() { + outgoingMap = make(map[[32]byte]*outgoingEntry) +} + +// registerOutgoing records hash → (header, ip) so challenge handlers can look +// it up. Multiple IPs may be registered for the same hash when the same message +// is being concurrently delivered to different domains (§10.2 step 2). +func registerOutgoing(hash [32]byte, h *FMsgHeader, ip string) { + outgoingMu.Lock() + e, ok := outgoingMap[hash] + if !ok { + e = &outgoingEntry{header: h, ips: make(map[string]struct{})} + outgoingMap[hash] = e + } + e.ips[ip] = struct{}{} + outgoingMu.Unlock() +} + +// lookupOutgoing returns the header for hash iff ip is a registered Host-B IP +// for that entry. Returns (nil, false) if the hash is unknown or ip is not in +// the registered set (§10.5 step 2). +func lookupOutgoing(hash [32]byte, ip string) (*FMsgHeader, bool) { + outgoingMu.RLock() + e, ok := outgoingMap[hash] + if !ok { + outgoingMu.RUnlock() + return nil, false + } + _, ipOK := e.ips[ip] + h := e.header + outgoingMu.RUnlock() + if !ipOK { + return nil, false + } + return h, true +} + +// removeOutgoingIP removes ip from the entry's IP set. When the set becomes +// empty the map entry is deleted entirely (§10.2 step 7). +func removeOutgoingIP(hash [32]byte, ip string) { + outgoingMu.Lock() + if e, ok := outgoingMap[hash]; ok { + delete(e.ips, ip) + if len(e.ips) == 0 { + delete(outgoingMap, hash) + } + } + outgoingMu.Unlock() +} diff --git a/cmd/fmsgd/outgoing_test.go b/cmd/fmsgd/outgoing_test.go new file mode 100644 index 0000000..af75412 --- /dev/null +++ b/cmd/fmsgd/outgoing_test.go @@ -0,0 +1,142 @@ +package main + +import ( + "testing" +) + +func TestOutgoingMapOperations(t *testing.T) { + initOutgoing() + + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Topic: "test", + Type: "text/plain", + } + + var hash [32]byte + copy(hash[:], h.GetHeaderHash()) + + const ip = "1.2.3.4" + + // Lookup before register should fail + _, ok := lookupOutgoing(hash, ip) + if ok { + t.Fatal("lookupOutgoing found entry before register") + } + + // Register + registerOutgoing(hash, h, ip) + + // Lookup with correct IP should succeed + got, ok := lookupOutgoing(hash, ip) + if !ok { + t.Fatal("lookupOutgoing failed after register") + } + if got != h { + t.Fatal("lookupOutgoing returned different pointer") + } + + // Lookup with wrong IP should fail + _, ok = lookupOutgoing(hash, "9.9.9.9") + if ok { + t.Fatal("lookupOutgoing succeeded with wrong IP") + } + + // Remove IP — entry should be gone + removeOutgoingIP(hash, ip) + + _, ok = lookupOutgoing(hash, ip) + if ok { + t.Fatal("lookupOutgoing found entry after removeOutgoingIP") + } +} + +func TestOutgoingMapMultipleIPs(t *testing.T) { + initOutgoing() + + h := &FMsgHeader{ + Version: 1, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 1.0, + Type: "text/plain", + } + + var hash [32]byte + copy(hash[:], h.GetHeaderHash()) + + registerOutgoing(hash, h, "1.1.1.1") + registerOutgoing(hash, h, "2.2.2.2") + + // Both IPs should resolve + for _, ip := range []string{"1.1.1.1", "2.2.2.2"} { + if _, ok := lookupOutgoing(hash, ip); !ok { + t.Errorf("expected lookup to succeed for IP %s", ip) + } + } + + // Removing first IP still leaves entry for second + removeOutgoingIP(hash, "1.1.1.1") + if _, ok := lookupOutgoing(hash, "1.1.1.1"); ok { + t.Error("1.1.1.1 still present after remove") + } + if _, ok := lookupOutgoing(hash, "2.2.2.2"); !ok { + t.Error("2.2.2.2 missing after removing 1.1.1.1") + } + + // Removing last IP deletes the entry + removeOutgoingIP(hash, "2.2.2.2") + if _, ok := lookupOutgoing(hash, "2.2.2.2"); ok { + t.Error("entry still present after removing last IP") + } +} + +func TestOutgoingMapMultipleEntries(t *testing.T) { + initOutgoing() + + h1 := &FMsgHeader{ + Version: 1, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 1.0, + Type: "text/plain", + } + h2 := &FMsgHeader{ + Version: 1, + From: FMsgAddress{User: "x", Domain: "y.com"}, + To: []FMsgAddress{{User: "z", Domain: "w.com"}}, + Timestamp: 2.0, + Type: "text/plain", + } + + var hash1, hash2 [32]byte + copy(hash1[:], h1.GetHeaderHash()) + copy(hash2[:], h2.GetHeaderHash()) + + registerOutgoing(hash1, h1, "1.1.1.1") + registerOutgoing(hash2, h2, "2.2.2.2") + + got1, ok1 := lookupOutgoing(hash1, "1.1.1.1") + got2, ok2 := lookupOutgoing(hash2, "2.2.2.2") + + if !ok1 || got1 != h1 { + t.Error("failed to look up h1") + } + if !ok2 || got2 != h2 { + t.Error("failed to look up h2") + } + + // Remove one, other should remain + removeOutgoingIP(hash1, "1.1.1.1") + _, ok1 = lookupOutgoing(hash1, "1.1.1.1") + _, ok2 = lookupOutgoing(hash2, "2.2.2.2") + if ok1 { + t.Error("h1 still present after remove") + } + if !ok2 { + t.Error("h2 missing after removing h1") + } +} diff --git a/cmd/fmsgd/sender.go b/cmd/fmsgd/sender.go new file mode 100644 index 0000000..a33a1e8 --- /dev/null +++ b/cmd/fmsgd/sender.go @@ -0,0 +1,628 @@ +package main + +import ( + "crypto/tls" + "database/sql" + "encoding/hex" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "time" + + env "github.com/caitlinelfring/go-env-default" + "github.com/levenlabs/golib/timeutil" + "github.com/lib/pq" +) + +var RetryInterval float64 = 20 +var RetryMaxAge float64 = 86400 +var PollInterval = 10 +var MaxConcurrentSend = 1024 + +// localResponseCodeNoResponse is stored only in the database; it is not an fmsg protocol response code. +const localResponseCodeNoResponse = -1 + +var retryableResponseCodes = []int16{ + int16(localResponseCodeNoResponse), + int16(RejectCodeUndisclosed), + int16(RejectCodeInsufficentResources), + int16(RejectCodeUserFull), +} + +func loadSenderEnvConfig() { + RetryInterval = env.GetFloatDefault("FMSG_RETRY_INTERVAL", 20) + RetryMaxAge = env.GetFloatDefault("FMSG_RETRY_MAX_AGE", 86400) + PollInterval = env.GetIntDefault("FMSG_POLL_INTERVAL", 10) + MaxConcurrentSend = env.GetIntDefault("FMSG_MAX_CONCURRENT_SEND", 1024) +} + +// pendingTarget identifies a (message, domain) pair that needs delivery. +type pendingTarget struct { + MsgID int64 + Domain string +} + +// findPendingTargets discovers (msg_id, domain) pairs with undelivered, +// retryable recipients. This is a lightweight read-only query — row-level +// locks are acquired per-delivery in deliverMessage. +func findPendingTargets() ([]pendingTarget, error) { + db, err := sql.Open("postgres", "") + if err != nil { + return nil, err + } + defer db.Close() + + now := timeutil.TimestampNow().Float64() + + // query both msg_to and msg_add_to for pending targets + rows, err := db.Query(` + SELECT mt.msg_id, mt.addr + FROM msg_to mt + INNER JOIN msg m ON m.id = mt.msg_id + WHERE mt.time_delivered IS NULL + AND m.time_sent IS NOT NULL + AND (mt.response_code IS NULL OR mt.response_code = ANY($4)) + AND (mt.time_last_attempt IS NULL OR ($1 - mt.time_last_attempt) > LEAST($2 * POWER(2.0, GREATEST(mt.attempt_count - 1, 0)::float), $3)) + AND ($1 - m.time_sent) < $3 + UNION ALL + SELECT mat.msg_id, mat.addr + FROM msg_add_to mat + INNER JOIN msg m ON m.id = mat.msg_id + WHERE mat.time_delivered IS NULL + AND m.time_sent IS NOT NULL + AND (mat.response_code IS NULL OR mat.response_code = ANY($4)) + AND (mat.time_last_attempt IS NULL OR ($1 - mat.time_last_attempt) > LEAST($2 * POWER(2.0, GREATEST(mat.attempt_count - 1, 0)::float), $3)) + AND ($1 - m.time_sent) < $3 + `, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) + if err != nil { + return nil, err + } + defer rows.Close() + + type key struct { + msgID int64 + domain string + } + seen := make(map[key]bool) + var targets []pendingTarget + + for rows.Next() { + var msgID int64 + var addr string + if err := rows.Scan(&msgID, &addr); err != nil { + return nil, err + } + lastAt := strings.LastIndex(addr, "@") + if lastAt == -1 { + continue + } + domain := addr[lastAt+1:] + if strings.EqualFold(domain, Domain) { + continue // local domain — no remote delivery needed + } + k := key{msgID, domain} + if !seen[k] { + seen[k] = true + targets = append(targets, pendingTarget{MsgID: msgID, Domain: domain}) + } + } + return targets, rows.Err() +} + +// sendMsgData transmits the message body then all attachment payloads on conn. +func sendMsgData(conn net.Conn, h *FMsgHeader) error { + fd, err := os.Open(h.Filepath) + if err != nil { + return fmt.Errorf("opening data file %s: %w", h.Filepath, err) + } + defer fd.Close() + + conn.SetWriteDeadline(time.Now().Add(calcNetIODuration(int(h.Size), MinUploadRate))) + if _, err := io.CopyN(conn, fd, int64(h.Size)); err != nil { + return fmt.Errorf("sending data: %w", err) + } + for _, att := range h.Attachments { + af, err := os.Open(att.Filepath) + if err != nil { + return fmt.Errorf("opening attachment %s: %w", att.Filename, err) + } + _, copyErr := io.CopyN(conn, af, int64(att.Size)) + af.Close() + if copyErr != nil { + return fmt.Errorf("sending attachment %s: %w", att.Filename, copyErr) + } + } + return nil +} + +// updateRecipient records a delivery outcome for one address in table. +// Deliveries set time_delivered; failures set time_last_attempt and increment +// attempt_count to drive exponential back-off on subsequent retries. +func updateRecipient(tx *sql.Tx, table, addr string, msgID int64, now float64, code int, delivered bool) { + var err error + if delivered { + _, err = tx.Exec(fmt.Sprintf(` + UPDATE %s SET time_delivered = $1, response_code = $2 + WHERE msg_id = $3 AND addr = $4 + `, table), now, code, msgID, addr) + } else { + _, err = tx.Exec(fmt.Sprintf(` + UPDATE %s SET time_last_attempt = $1, response_code = $2, + attempt_count = attempt_count + 1 + WHERE msg_id = $3 AND addr = $4 + `, table), now, code, msgID, addr) + } + if err != nil { + log.Printf("ERROR: sender: update recipient %s: %s", addr, err) + } +} + +// updateAllLocked applies the same outcome to every locked to and add-to address. +func updateAllLocked(tx *sql.Tx, lockedAddrs, lockedAddToAddrs []string, msgID int64, now float64, code int, delivered bool) { + for _, a := range lockedAddrs { + updateRecipient(tx, "msg_to", a, msgID, now, code, delivered) + } + for _, a := range lockedAddToAddrs { + updateRecipient(tx, "msg_add_to", a, msgID, now, code, delivered) + } +} + +// commitOrLog commits the transaction and marks it as committed. +func commitOrLog(tx *sql.Tx, committed *bool, msgID int64) { + if err := tx.Commit(); err != nil { + log.Printf("ERROR: sender: commit tx for msg %d: %s", msgID, err) + } else { + *committed = true + } +} + +func recordRetryableFailure(tx *sql.Tx, committed *bool, lockedAddrs, lockedAddToAddrs []string, msgID int64) { + now := timeutil.TimestampNow().Float64() + updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, msgID, now, localResponseCodeNoResponse, false) + commitOrLog(tx, committed, msgID) +} + +// deliverMessage handles delivery of a single message to a single remote domain. +// +// It manages its own database transaction with the following lifecycle: +// - Locks the pending msg_to rows for this (message, domain) via FOR UPDATE SKIP LOCKED. +// - Loads the full message including ALL recipients (for the original wire header). +// - Sends the complete original message to the remote host. +// - On success: updates time_delivered + response_code, commits. +// - On rejection (got response code): updates response_code + time_last_attempt, commits. +// - On early delivery error: records a retryable failure, commits, and backs off. +func deliverMessage(target pendingTarget) { + if strings.EqualFold(target.Domain, Domain) { + // local domain — mark as delivered rather than sending remotely + db, err := sql.Open("postgres", "") + if err != nil { + log.Printf("ERROR: sender: db open for local delivery: %s", err) + return + } + defer db.Close() + now := timeutil.TimestampNow().Float64() + if _, err := db.Exec(` + UPDATE msg_to SET time_delivered = $1, response_code = 200 + WHERE msg_id = $2 AND time_delivered IS NULL + AND lower(split_part(addr, '@', 3)) = lower($3) + `, now, target.MsgID, target.Domain); err != nil { + log.Printf("ERROR: sender: marking local recipients delivered for msg %d: %s", target.MsgID, err) + } + return + } + + db, err := sql.Open("postgres", "") + if err != nil { + log.Printf("ERROR: sender: db open: %s", err) + return + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + log.Printf("ERROR: sender: begin tx: %s", err) + return + } + committed := false + defer func() { + if !committed { + tx.Rollback() + } + }() + + now := timeutil.TimestampNow().Float64() + + // Lock pending (undelivered, retryable) msg_to rows for this message + // on the target domain. SKIP LOCKED avoids blocking concurrent senders. + lockRows, err := tx.Query(` + SELECT mt.addr + FROM msg_to mt + INNER JOIN msg m ON m.id = mt.msg_id + WHERE mt.msg_id = $1 + AND mt.time_delivered IS NULL + AND m.time_sent IS NOT NULL + AND (mt.response_code IS NULL OR mt.response_code = ANY($5)) + AND (mt.time_last_attempt IS NULL OR ($2 - mt.time_last_attempt) > LEAST($3 * POWER(2.0, GREATEST(mt.attempt_count - 1, 0)::float), $4)) + AND ($2 - m.time_sent) < $4 + FOR UPDATE OF mt SKIP LOCKED + `, target.MsgID, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) + if err != nil { + log.Printf("ERROR: sender: lock rows for msg %d: %s", target.MsgID, err) + return + } + + var lockedAddrs []string + for lockRows.Next() { + var addr string + if err := lockRows.Scan(&addr); err != nil { + lockRows.Close() + log.Printf("ERROR: sender: scan locked addr: %s", err) + return + } + lastAt := strings.LastIndex(addr, "@") + if lastAt != -1 && strings.EqualFold(addr[lastAt+1:], target.Domain) { + lockedAddrs = append(lockedAddrs, addr) + } + } + lockRows.Close() + if err := lockRows.Err(); err != nil { + log.Printf("ERROR: sender: lock rows err for msg %d: %s", target.MsgID, err) + return + } + + // Also lock pending msg_add_to rows for this message on the target domain. + lockAddToRows, err := tx.Query(` + SELECT mat.addr + FROM msg_add_to mat + INNER JOIN msg m ON m.id = mat.msg_id + WHERE mat.msg_id = $1 + AND mat.time_delivered IS NULL + AND m.time_sent IS NOT NULL + AND (mat.response_code IS NULL OR mat.response_code = ANY($5)) + AND (mat.time_last_attempt IS NULL OR ($2 - mat.time_last_attempt) > LEAST($3 * POWER(2.0, GREATEST(mat.attempt_count - 1, 0)::float), $4)) + AND ($2 - m.time_sent) < $4 + FOR UPDATE OF mat SKIP LOCKED + `, target.MsgID, now, RetryInterval, RetryMaxAge, pq.Array(retryableResponseCodes)) + if err != nil { + log.Printf("ERROR: sender: lock add-to rows for msg %d: %s", target.MsgID, err) + return + } + + var lockedAddToAddrs []string + for lockAddToRows.Next() { + var addr string + if err := lockAddToRows.Scan(&addr); err != nil { + lockAddToRows.Close() + log.Printf("ERROR: sender: scan locked add-to addr: %s", err) + return + } + lastAt := strings.LastIndex(addr, "@") + if lastAt != -1 && strings.EqualFold(addr[lastAt+1:], target.Domain) { + lockedAddToAddrs = append(lockedAddToAddrs, addr) + } + } + lockAddToRows.Close() + if err := lockAddToRows.Err(); err != nil { + log.Printf("ERROR: sender: lock add-to rows err for msg %d: %s", target.MsgID, err) + return + } + + if len(lockedAddrs) == 0 && len(lockedAddToAddrs) == 0 { + return // already locked by another sender or no longer eligible + } + + deferRetry := true + defer func() { + if deferRetry && !committed { + recordRetryableFailure(tx, &committed, lockedAddrs, lockedAddToAddrs, target.MsgID) + } + }() + + // Load the full message from msg table + h, err := loadMsg(tx, target.MsgID) + if err != nil { + log.Printf("ERROR: sender: %s", err) + return + } + + // Try zlib-deflate compression for message data and attachment data. + // Compressed temp files are cleaned up after delivery completes. + var deflateCleanup []string + defer func() { + for _, p := range deflateCleanup { + _ = os.Remove(p) + } + }() + if shouldCompress(h.Type, h.Size) { + dp, cs, ok, derr := tryCompress(h.Filepath, h.Size) + if derr != nil { + log.Printf("WARN: sender: compress msg data for msg %d: %s", target.MsgID, derr) + } else if ok { + log.Printf("INFO: sender: compressed msg %d data: %d -> %d bytes", target.MsgID, h.Size, cs) + deflateCleanup = append(deflateCleanup, dp) + h.Filepath = dp + h.ExpandedSize = h.Size + h.Size = cs + h.Flags |= FlagDeflate + } + } + for i := range h.Attachments { + att := &h.Attachments[i] + if shouldCompress(att.Type, att.Size) { + dp, cs, ok, derr := tryCompress(att.Filepath, att.Size) + if derr != nil { + log.Printf("WARN: sender: compress attachment %s for msg %d: %s", att.Filename, target.MsgID, derr) + } else if ok { + log.Printf("INFO: sender: compressed msg %d attachment %s: %d -> %d bytes", target.MsgID, att.Filename, att.Size, cs) + deflateCleanup = append(deflateCleanup, dp) + att.Filepath = dp + att.ExpandedSize = att.Size + att.Size = cs + att.Flags |= 1 << 1 + } + } + } + + // Ensure sha256 is populated for outgoing messages so future pid lookups + // (e.g. add-to notifications referencing this message) can find it. + msgHash, err := h.GetMessageHash() + if err != nil { + log.Printf("ERROR: sender: computing message hash for msg %d: %s", target.MsgID, err) + return + } + if _, err := tx.Exec(`UPDATE msg SET sha256 = $1 WHERE id = $2 AND sha256 IS NULL`, + msgHash, target.MsgID); err != nil { + log.Printf("ERROR: sender: storing sha256 for msg %d: %s", target.MsgID, err) + return + } + if err := resolvePendingChildLinks(txParentLinkStore{tx: tx}, target.MsgID, msgHash); err != nil { + log.Printf("ERROR: sender: resolving child pids for msg %d: %s", target.MsgID, err) + return + } + + // Compute header hash now; registerOutgoing with Host B's IP happens after + // the connection is established (IP needed for challenge validation §10.5). + hash := h.GetHeaderHash() + hashArr := *(*[32]byte)(hash) + + // Build the list of recipients on the target domain in to then add-to order. + // Per spec, per-recipient response codes follow the same order. + lockedSet := make(map[string]bool) + for _, a := range lockedAddrs { + lockedSet[strings.ToLower(a)] = true + } + for _, a := range lockedAddToAddrs { + lockedSet[strings.ToLower(a)] = true + } + type domainRecip struct { + addr string + isLocked bool + isAddTo bool + } + var domainRecips []domainRecip + for _, addr := range h.To { + if strings.EqualFold(addr.Domain, target.Domain) { + s := addr.ToString() + domainRecips = append(domainRecips, domainRecip{ + addr: s, + isLocked: lockedSet[strings.ToLower(s)], + isAddTo: false, + }) + } + } + for _, addr := range h.AddTo { + if strings.EqualFold(addr.Domain, target.Domain) { + s := addr.ToString() + domainRecips = append(domainRecips, domainRecip{ + addr: s, + isLocked: lockedSet[strings.ToLower(s)], + isAddTo: true, + }) + } + } + + // --- network delivery --- + + targetIPs, err := lookupAuthorisedIPs(target.Domain) + if err != nil { + log.Printf("ERROR: sender: DNS lookup for fmsg.%s failed: %s", target.Domain, err) + return + } + + var conn net.Conn + dialer := &net.Dialer{Timeout: 10 * time.Second} + tlsConf := buildClientTLSConfig("fmsg." + target.Domain) + for _, ip := range targetIPs { + addr := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", RemotePort)) + conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsConf) + if err == nil { + break + } + log.Printf("WARN: sender: connect to %s failed: %s", addr, err) + } + if conn == nil { + log.Printf("ERROR: sender: could not connect to any IP for fmsg.%s", target.Domain) + return + } + defer conn.Close() + + // Register in outgoing map with Host B's IP before sending the header so + // any incoming challenge can be matched by hash AND IP (§10.2 step 2). + connectedIP := conn.RemoteAddr().(*net.TCPAddr).IP.String() + log.Printf("INFO: sender: registering outgoing message %s (%s)", hex.EncodeToString(hashArr[:]), connectedIP) + registerOutgoing(hashArr, h, connectedIP) + defer removeOutgoingIP(hashArr, connectedIP) + + // Step 3: Transmit message header. + if _, err := conn.Write(h.Encode()); err != nil { + log.Printf("ERROR: sender: writing header for msg %d: %s", target.MsgID, err) + return + } + + // Step 5: Read the initial response byte before sending any data (§10.2 step 5). + // The challenge handler may fire on a separate goroutine during this wait. + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + initCode := make([]byte, 1) + if _, err := io.ReadFull(conn, initCode); err != nil { + log.Printf("ERROR: sender: reading initial response for msg %d: %s", target.MsgID, err) + return + } + now = timeutil.TimestampNow().Float64() + isAddToMsg := h.Flags&FlagHasAddTo != 0 + + switch initCode[0] { + case AcceptCodeContinue: // 64 — send data + attachments, then per-recipient codes + if err := sendMsgData(conn, h); err != nil { + log.Printf("ERROR: sender: %s (msg %d)", err, target.MsgID) + return + } + case AcceptCodeSkipData: // 65 — add-to, parent stored, recipients on this host; skip data + if !isAddToMsg { + log.Printf("WARN: sender: msg %d received protocol-invalid code 65 from %s for non-add-to message, terminating", + target.MsgID, target.Domain) + return + } + // do not transmit data; per-recipient codes follow below + case AcceptCodeAddTo: // 11 — add-to accepted, no recipients on this host + if !isAddToMsg { + log.Printf("WARN: sender: msg %d received protocol-invalid code 11 from %s for non-add-to message, terminating", + target.MsgID, target.Domain) + return + } + log.Printf("INFO: sender: msg %d add-to accepted by %s (code 11)", target.MsgID, target.Domain) + updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, target.MsgID, now, int(initCode[0]), true) + deferRetry = false + commitOrLog(tx, &committed, target.MsgID) + return + default: + if initCode[0] >= 1 && initCode[0] <= 10 { + // global rejection + log.Printf("WARN: sender: msg %d rejected by %s: %s (%d)", + target.MsgID, target.Domain, responseCodeName(initCode[0]), initCode[0]) + updateAllLocked(tx, lockedAddrs, lockedAddToAddrs, target.MsgID, now, int(initCode[0]), false) + deferRetry = false + commitOrLog(tx, &committed, target.MsgID) + } else { + // unexpected code — TERMINATE + log.Printf("WARN: sender: msg %d unexpected response code %d from %s, terminating", + target.MsgID, initCode[0], target.Domain) + } + return + } + + // Step 6: Read one per-recipient code per recipient on this host, in + // to-field order then add-to order (§10.2 step 6). + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + codes := make([]byte, len(domainRecips)) + if _, err := io.ReadFull(conn, codes); err != nil { + log.Printf("ERROR: sender: reading per-recipient codes for msg %d: %s", target.MsgID, err) + return + } + now = timeutil.TimestampNow().Float64() + + for i, dr := range domainRecips { + if !dr.isLocked { + continue + } + c := codes[i] + table := "msg_to" + if dr.isAddTo { + table = "msg_add_to" + } + delivered := c == RejectCodeAccept + if delivered { + log.Printf("INFO: sender: delivered msg %d to %s", target.MsgID, dr.addr) + } else { + log.Printf("WARN: sender: msg %d to %s: %s (%d)", target.MsgID, dr.addr, responseCodeName(c), c) + } + updateRecipient(tx, table, dr.addr, target.MsgID, now, int(c), delivered) + } + + deferRetry = false + commitOrLog(tx, &committed, target.MsgID) +} + +// processPendingMessages finds messages needing delivery and dispatches a +// goroutine per (message, domain) pair, bounded by the semaphore. +func processPendingMessages(sem chan struct{}) { + targets, err := findPendingTargets() + if err != nil { + log.Printf("ERROR: sender: finding pending targets: %s", err) + return + } + if len(targets) == 0 { + return + } + log.Printf("INFO: sender: found %d pending target(s)", len(targets)) + + for _, t := range targets { + sem <- struct{}{} // acquire + go func(t pendingTarget) { + defer func() { <-sem }() + deliverMessage(t) + }(t) + } +} + +// startSender runs the sender loop: polls the database periodically and also +// listens for PostgreSQL notifications for immediate pickup of new messages. +func startSender() { + loadSenderEnvConfig() + log.Printf("INFO: sender: started (poll=%ds, retry=%.0fs, max_concurrent=%d)", + PollInterval, RetryInterval, MaxConcurrentSend) + + sem := make(chan struct{}, MaxConcurrentSend) + + // set up PostgreSQL LISTEN for immediate notification + notifyCh := make(chan struct{}, 1) + go func() { + listener := pq.NewListener("", 10*time.Second, time.Minute, func(ev pq.ListenerEventType, err error) { + if err != nil { + log.Printf("ERROR: sender: pg listener: %s", err) + } + }) + if err := listener.Listen("new_msg_to"); err != nil { + log.Printf("ERROR: sender: could not LISTEN on new_msg_to: %s", err) + return + } + defer listener.Close() + log.Println("INFO: sender: listening for new_msg_to notifications") + for { + select { + case n := <-listener.Notify: + if n != nil { + log.Printf("INFO: sender: notification received: %s", n.Extra) + select { + case notifyCh <- struct{}{}: + default: + } + } + case <-time.After(32 * time.Second): + // ping to keep connection alive + if err := listener.Ping(); err != nil { + log.Printf("ERROR: sender: pg listener ping: %s", err) + } + } + } + }() + + // initial poll on startup + processPendingMessages(sem) + + ticker := time.NewTicker(time.Duration(PollInterval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + processPendingMessages(sem) + case <-notifyCh: + // small delay to batch rapid inserts + time.Sleep(256 * time.Millisecond) + processPendingMessages(sem) + } + } +} diff --git a/cmd/fmsgd/store.go b/cmd/fmsgd/store.go new file mode 100644 index 0000000..b2689f2 --- /dev/null +++ b/cmd/fmsgd/store.go @@ -0,0 +1,592 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + "strings" + + "github.com/levenlabs/golib/timeutil" + _ "github.com/lib/pq" +) + +func testDb() error { + db, err := sql.Open("postgres", "") + if err != nil { + return err + } + defer db.Close() + err = db.Ping() + if err != nil { + return err + } + + var dbName, user, host, port string + _ = db.QueryRow("SELECT current_database()").Scan(&dbName) + _ = db.QueryRow("SELECT current_user").Scan(&user) + _ = db.QueryRow("SELECT inet_server_addr()::text").Scan(&host) + _ = db.QueryRow("SELECT inet_server_port()::text").Scan(&port) + log.Printf("INFO: Database connected: %s@%s:%s/%s", user, host, port, dbName) + + // verify required tables exist + for _, table := range []string{"msg", "msg_to", "msg_add_to", "msg_attachment"} { + var exists bool + err = db.QueryRow(`SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + )`, table).Scan(&exists) + if err != nil { + return fmt.Errorf("checking table %s: %w", table, err) + } + if !exists { + return fmt.Errorf("required table %s does not exist", table) + } + } + return nil +} + +// lookupMsgIdByHash returns the msg id for a message with the given SHA256 hash, +// or 0 if no such message exists. +func lookupMsgIdByHash(hash []byte) (int64, error) { + db, err := sql.Open("postgres", "") + if err != nil { + return 0, err + } + defer db.Close() + + var id int64 + err = db.QueryRow("SELECT id FROM msg WHERE sha256 = $1", hash).Scan(&id) + if err == sql.ErrNoRows { + return 0, nil + } + return id, err +} + +// hasAddrReceivedMsgHash reports whether addr has already received a stored +// message identified by hash. +func hasAddrReceivedMsgHash(hash []byte, addr *FMsgAddress) (bool, error) { + if addr == nil || len(hash) == 0 { + return false, nil + } + + db, err := sql.Open("postgres", "") + if err != nil { + return false, err + } + defer db.Close() + + addrStr := strings.ToLower(addr.ToString()) + + var exists bool + err = db.QueryRow(` + SELECT EXISTS ( + SELECT 1 + FROM msg m + JOIN msg_to mt ON mt.msg_id = m.id + WHERE m.sha256 = $1 + AND lower(mt.addr) = $2 + AND mt.time_delivered IS NOT NULL + UNION ALL + SELECT 1 + FROM msg m + JOIN msg_add_to mat ON mat.msg_id = m.id + WHERE m.sha256 = $1 + AND lower(mat.addr) = $2 + AND mat.time_delivered IS NOT NULL + ) + `, hash, addrStr).Scan(&exists) + if err != nil { + return false, err + } + + return exists, nil +} + +type parentLinkStore interface { + lookupParentID(parentHash []byte) (int64, error) + setParentID(msgID int64, parentID int64) error + setPendingChildrenParentID(parentID int64, parentHash []byte) error +} + +type txParentLinkStore struct { + tx *sql.Tx +} + +func (s txParentLinkStore) lookupParentID(parentHash []byte) (int64, error) { + var id int64 + err := s.tx.QueryRow("SELECT id FROM msg WHERE sha256 = $1", parentHash).Scan(&id) + if err == sql.ErrNoRows { + return 0, nil + } + return id, err +} + +func (s txParentLinkStore) setParentID(msgID int64, parentID int64) error { + _, err := s.tx.Exec("UPDATE msg SET pid = $1 WHERE id = $2", parentID, msgID) + return err +} + +func (s txParentLinkStore) setPendingChildrenParentID(parentID int64, parentHash []byte) error { + _, err := s.tx.Exec("UPDATE msg SET pid = $1 WHERE psha256 = $2 AND pid IS NULL", parentID, parentHash) + return err +} + +func resolveStoredParent(store parentLinkStore, msgID int64, parentHash []byte, requireParent bool) error { + if len(parentHash) == 0 { + return nil + } + + parentID, err := store.lookupParentID(parentHash) + if err != nil { + return err + } + if parentID == 0 { + if requireParent { + return fmt.Errorf("parent message not found for psha256 %x", parentHash) + } + return nil + } + + return store.setParentID(msgID, parentID) +} + +func resolvePendingChildLinks(store parentLinkStore, parentID int64, parentHash []byte) error { + if len(parentHash) == 0 { + return nil + } + return store.setPendingChildrenParentID(parentID, parentHash) +} + +func resolveMsgParentLinks(tx *sql.Tx, msgID int64, msgHash []byte, parentHash []byte, requireParent bool) error { + store := txParentLinkStore{tx: tx} + if err := resolveStoredParent(store, msgID, parentHash, requireParent); err != nil { + return err + } + return resolvePendingChildLinks(store, msgID, msgHash) +} + +func requiresStoredParent(msg *FMsgHeader) bool { + return len(msg.Pid) > 0 && msg.Flags&FlagHasAddTo == 0 +} + +func wirePidForLoadedMessage(storedParentHash []byte, msgHash []byte, hasAddTo bool) []byte { + if hasAddTo { + return msgHash + } + return storedParentHash +} + +// getMsgByID loads a message and all its recipients from the database by msg ID. +// Returns the full FMsgHeader or nil if the message doesn't exist. +func getMsgByID(msgID int64) (*FMsgHeader, error) { + db, err := sql.Open("postgres", "") + if err != nil { + return nil, err + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + h, err := loadMsg(tx, msgID) + if err != nil { + // If the message doesn't exist, loadMsg will return an error, + // but we want to distinguish "not found" from other errors + if err.Error() == "no rows in result set" || err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + return h, nil +} + +func storeMsgDetail(msg *FMsgHeader) error { + + db, err := sql.Open("postgres", "") + if err != nil { + return err + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + msgHash, err := msg.GetMessageHash() + if err != nil { + return err + } + + var addToFrom interface{} + if msg.AddToFrom != nil { + addToFrom = msg.AddToFrom.ToString() + } + + var msgID int64 + err = tx.QueryRow(`insert into msg (version + , no_reply + , is_important + , is_deflate + , time_sent + , from_addr + , add_to_from + , topic + , type + , sha256 + , psha256 + , size + , filepath) +values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +returning id`, + msg.Version, + msg.Flags&FlagNoReply != 0, + msg.Flags&FlagImportant != 0, + msg.Flags&FlagDeflate != 0, + msg.Timestamp, + msg.From.ToString(), + addToFrom, + msg.Topic, + msg.Type, + msgHash, + msg.Pid, + int(msg.Size), + msg.Filepath).Scan(&msgID) + if err != nil { + return err + } + + stmt, err := tx.Prepare(`insert into msg_to (msg_id, addr, time_delivered) +values ($1, $2, $3)`) + if err != nil { + return err + } + defer stmt.Close() + + now := timeutil.TimestampNow().Float64() + for _, addr := range msg.To { + // recipients on our domain are already delivered; others are pending + var delivered interface{} + if addr.Domain == Domain { + delivered = now + } + if _, err := stmt.Exec(msgID, addr.ToString(), delivered); err != nil { + return err + } + } + + // insert add-to recipients into msg_add_to + if len(msg.AddTo) > 0 { + addToStmt, err := tx.Prepare(`insert into msg_add_to (msg_id, addr, time_delivered) +values ($1, $2, $3)`) + if err != nil { + return err + } + defer addToStmt.Close() + + for _, addr := range msg.AddTo { + var delivered interface{} + if addr.Domain == Domain { + delivered = now + } + if _, err := addToStmt.Exec(msgID, addr.ToString(), delivered); err != nil { + return err + } + } + } + + if len(msg.Attachments) > 0 { + attStmt, err := tx.Prepare(`insert into msg_attachment (msg_id, position, flags, type, filename, filesize, filepath) +values ($1, $2, $3, $4, $5, $6, $7)`) + if err != nil { + return err + } + defer attStmt.Close() + + for i := range msg.Attachments { + att := msg.Attachments[i] + if _, err := attStmt.Exec(msgID, i, int(att.Flags), att.Type, att.Filename, int(att.Size), att.Filepath); err != nil { + return err + } + } + } + + if err := resolveMsgParentLinks(tx, msgID, msgHash, msg.Pid, requiresStoredParent(msg)); err != nil { + return err + } + + return tx.Commit() + +} + +// storeMsgHeaderOnly stores just the message header for add-to notifications +// (spec code 11). Only the header is recorded so the header hash can be +// faithfully computed for subsequent messages referencing this one via pid. +func storeMsgHeaderOnly(msg *FMsgHeader) error { + db, err := sql.Open("postgres", "") + if err != nil { + return err + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + msgHash, err := msg.GetMessageHash() + if err != nil { + return err + } + + var addToFrom interface{} + if msg.AddToFrom != nil { + addToFrom = msg.AddToFrom.ToString() + } + + var msgID int64 + err = tx.QueryRow(`insert into msg (version + , no_reply + , is_important + , is_deflate + , time_sent + , from_addr + , add_to_from + , topic + , type + , sha256 + , psha256 + , size + , filepath) +values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +returning id`, + msg.Version, + msg.Flags&FlagNoReply != 0, + msg.Flags&FlagImportant != 0, + msg.Flags&FlagDeflate != 0, + msg.Timestamp, + msg.From.ToString(), + addToFrom, + msg.Topic, + msg.Type, + msgHash, + msg.Pid, + int(msg.Size), + "").Scan(&msgID) + if err != nil { + return err + } + + // insert to recipients (for record keeping) + toStmt, err := tx.Prepare(`insert into msg_to (msg_id, addr) values ($1, $2)`) + if err != nil { + return err + } + defer toStmt.Close() + for _, addr := range msg.To { + if _, err := toStmt.Exec(msgID, addr.ToString()); err != nil { + return err + } + } + + // insert add-to recipients + if len(msg.AddTo) > 0 { + addToStmt, err := tx.Prepare(`insert into msg_add_to (msg_id, addr) values ($1, $2)`) + if err != nil { + return err + } + defer addToStmt.Close() + for _, addr := range msg.AddTo { + if _, err := addToStmt.Exec(msgID, addr.ToString()); err != nil { + return err + } + } + } + + if len(msg.Attachments) > 0 { + attStmt, err := tx.Prepare(`insert into msg_attachment (msg_id, position, flags, type, filename, filesize, filepath) +values ($1, $2, $3, $4, $5, $6, $7)`) + if err != nil { + return err + } + defer attStmt.Close() + + for i := range msg.Attachments { + att := msg.Attachments[i] + if _, err := attStmt.Exec(msgID, i, int(att.Flags), att.Type, att.Filename, int(att.Size), att.Filepath); err != nil { + return err + } + } + } + + if err := resolveMsgParentLinks(tx, msgID, msgHash, msg.Pid, requiresStoredParent(msg)); err != nil { + return err + } + + return tx.Commit() +} + +// loadMsg loads a message and all its recipients from the database within the +// given transaction and returns a fully populated FMsgHeader. +func loadMsg(tx *sql.Tx, msgID int64) (*FMsgHeader, error) { + var version, size int + var noReply, isImportant, isDeflate bool + var pid, msgHash []byte + var fromAddr, topic, typ, filepath string + var addToFromAddr sql.NullString + var timeSent float64 + err := tx.QueryRow(` + SELECT version, no_reply, is_important, is_deflate, psha256, sha256, from_addr, add_to_from, topic, type, time_sent, size, filepath + FROM msg WHERE id = $1 + `, msgID).Scan(&version, &noReply, &isImportant, &isDeflate, &pid, &msgHash, &fromAddr, &addToFromAddr, &topic, &typ, &timeSent, &size, &filepath) + if err != nil { + return nil, fmt.Errorf("load msg %d: %w", msgID, err) + } + + recipRows, err := tx.Query(`SELECT addr FROM msg_to WHERE msg_id = $1 ORDER BY id`, msgID) + if err != nil { + return nil, fmt.Errorf("load recipients for msg %d: %w", msgID, err) + } + var allRecipientAddrs []string + for recipRows.Next() { + var a string + if err := recipRows.Scan(&a); err != nil { + recipRows.Close() + return nil, fmt.Errorf("scan recipient addr: %w", err) + } + allRecipientAddrs = append(allRecipientAddrs, a) + } + recipRows.Close() + if err := recipRows.Err(); err != nil { + return nil, fmt.Errorf("recipients query err for msg %d: %w", msgID, err) + } + + from, err := parseAddress([]byte(fromAddr)) + if err != nil { + return nil, fmt.Errorf("invalid from address %s: %w", fromAddr, err) + } + allTo := make([]FMsgAddress, 0, len(allRecipientAddrs)) + for _, a := range allRecipientAddrs { + addr, err := parseAddress([]byte(a)) + if err != nil { + return nil, fmt.Errorf("invalid to address %s: %w", a, err) + } + allTo = append(allTo, *addr) + } + + // load add-to recipients from msg_add_to + addToRows, err := tx.Query(`SELECT addr FROM msg_add_to WHERE msg_id = $1 ORDER BY id`, msgID) + if err != nil { + return nil, fmt.Errorf("load add-to recipients for msg %d: %w", msgID, err) + } + var allAddTo []FMsgAddress + for addToRows.Next() { + var a string + if err := addToRows.Scan(&a); err != nil { + addToRows.Close() + return nil, fmt.Errorf("scan add-to addr: %w", err) + } + addr, err := parseAddress([]byte(a)) + if err != nil { + addToRows.Close() + return nil, fmt.Errorf("invalid add-to address %s: %w", a, err) + } + allAddTo = append(allAddTo, *addr) + } + addToRows.Close() + if err := addToRows.Err(); err != nil { + return nil, fmt.Errorf("add-to recipients query err for msg %d: %w", msgID, err) + } + + attRows, err := tx.Query(` + SELECT flags, type, filename, filesize, filepath + FROM msg_attachment + WHERE msg_id = $1 + ORDER BY position, filename + `, msgID) + if err != nil { + return nil, fmt.Errorf("load attachments for msg %d: %w", msgID, err) + } + attachments := []FMsgAttachmentHeader{} + for attRows.Next() { + var flags, filesize int + var typ, filename, filepath string + if err := attRows.Scan(&flags, &typ, &filename, &filesize, &filepath); err != nil { + attRows.Close() + return nil, fmt.Errorf("scan attachment row: %w", err) + } + attachments = append(attachments, FMsgAttachmentHeader{ + Flags: uint8(flags), + Type: typ, + Filename: filename, + Size: uint32(filesize), + Filepath: filepath, + }) + } + attRows.Close() + if err := attRows.Err(); err != nil { + return nil, fmt.Errorf("attachments query err for msg %d: %w", msgID, err) + } + + // Compute flags bitfield from stored booleans and loaded data. + // has_pid and has_add_to are derived from actual data rather than stored, + // so add-to recipients added after the original message are included. + // + // When add-to recipients exist, the wire pid references the message being + // shared, not that message's parent. This keeps add-to on replies pointing + // at the reply payload rather than the root message. + pid = wirePidForLoadedMessage(pid, msgHash, len(allAddTo) > 0) + + var addToFrom *FMsgAddress + if addToFromAddr.Valid && addToFromAddr.String != "" { + addr, err := parseAddress([]byte(addToFromAddr.String)) + if err != nil { + return nil, fmt.Errorf("invalid add_to_from address %s: %w", addToFromAddr.String, err) + } + addToFrom = addr + } + if len(allAddTo) > 0 && addToFrom == nil { + // Backward-compatibility for older rows before add_to_from existed. + fallback := *from + addToFrom = &fallback + } + + var flags uint8 + if len(pid) > 0 { + flags |= FlagHasPid + } + if len(allAddTo) > 0 { + flags |= FlagHasAddTo + } + if noReply { + flags |= FlagNoReply + } + if isImportant { + flags |= FlagImportant + } + if isDeflate { + flags |= FlagDeflate + } + + return &FMsgHeader{ + Version: uint8(version), + Flags: flags, + Pid: pid, + From: *from, + To: allTo, + AddToFrom: addToFrom, + AddTo: allAddTo, + Timestamp: timeSent, + Topic: topic, + Type: typ, + Size: uint32(size), + Attachments: attachments, + Filepath: filepath, + }, nil +} diff --git a/cmd/fmsgd/store_test.go b/cmd/fmsgd/store_test.go new file mode 100644 index 0000000..6bfdf98 --- /dev/null +++ b/cmd/fmsgd/store_test.go @@ -0,0 +1,139 @@ +package main + +import ( + "bytes" + "errors" + "testing" +) + +type fakeParentLinkStore struct { + parentID int64 + lookupErr error + + lookupHash []byte + setMsgID int64 + setParentIDValue int64 + setCalled bool + pendingParentID int64 + pendingParentHash []byte + pendingCalled bool +} + +func (s *fakeParentLinkStore) lookupParentID(parentHash []byte) (int64, error) { + s.lookupHash = append([]byte(nil), parentHash...) + return s.parentID, s.lookupErr +} + +func (s *fakeParentLinkStore) setParentID(msgID int64, parentID int64) error { + s.setCalled = true + s.setMsgID = msgID + s.setParentIDValue = parentID + return nil +} + +func (s *fakeParentLinkStore) setPendingChildrenParentID(parentID int64, parentHash []byte) error { + s.pendingCalled = true + s.pendingParentID = parentID + s.pendingParentHash = append([]byte(nil), parentHash...) + return nil +} + +func TestResolveStoredParentRequiresExistingParent(t *testing.T) { + store := &fakeParentLinkStore{} + parentHash := []byte{1, 2, 3} + + err := resolveStoredParent(store, 10, parentHash, true) + if err == nil { + t.Fatal("resolveStoredParent returned nil error for required missing parent") + } + if !bytes.Equal(store.lookupHash, parentHash) { + t.Fatalf("lookup hash = %v, want %v", store.lookupHash, parentHash) + } + if store.setCalled { + t.Fatal("setParentID was called for missing parent") + } +} + +func TestResolveStoredParentAllowsOptionalMissingParent(t *testing.T) { + store := &fakeParentLinkStore{} + + if err := resolveStoredParent(store, 10, []byte{1, 2, 3}, false); err != nil { + t.Fatalf("resolveStoredParent returned error for optional missing parent: %v", err) + } + if store.setCalled { + t.Fatal("setParentID was called for optional missing parent") + } +} + +func TestResolveStoredParentSetsPidWhenParentExists(t *testing.T) { + store := &fakeParentLinkStore{parentID: 42} + + if err := resolveStoredParent(store, 10, []byte{1, 2, 3}, true); err != nil { + t.Fatalf("resolveStoredParent returned error: %v", err) + } + if !store.setCalled { + t.Fatal("setParentID was not called") + } + if store.setMsgID != 10 || store.setParentIDValue != 42 { + t.Fatalf("setParentID called with msgID=%d parentID=%d, want msgID=10 parentID=42", store.setMsgID, store.setParentIDValue) + } +} + +func TestResolveStoredParentPropagatesLookupError(t *testing.T) { + lookupErr := errors.New("lookup failed") + store := &fakeParentLinkStore{lookupErr: lookupErr} + + err := resolveStoredParent(store, 10, []byte{1, 2, 3}, true) + if !errors.Is(err, lookupErr) { + t.Fatalf("resolveStoredParent error = %v, want %v", err, lookupErr) + } + if store.setCalled { + t.Fatal("setParentID was called after lookup error") + } +} + +func TestResolvePendingChildLinksBackfillsByParentHash(t *testing.T) { + store := &fakeParentLinkStore{} + parentHash := []byte{4, 5, 6} + + if err := resolvePendingChildLinks(store, 42, parentHash); err != nil { + t.Fatalf("resolvePendingChildLinks returned error: %v", err) + } + if !store.pendingCalled { + t.Fatal("setPendingChildrenParentID was not called") + } + if store.pendingParentID != 42 || !bytes.Equal(store.pendingParentHash, parentHash) { + t.Fatalf("pending update got parentID=%d hash=%v, want parentID=42 hash=%v", store.pendingParentID, store.pendingParentHash, parentHash) + } +} + +func TestRequiresStoredParentUsesAddToFlag(t *testing.T) { + parentHash := []byte{1, 2, 3} + + if !requiresStoredParent(&FMsgHeader{Flags: FlagHasPid, Pid: parentHash}) { + t.Fatal("normal reply did not require stored parent") + } + if requiresStoredParent(&FMsgHeader{Flags: FlagHasPid | FlagHasAddTo, Pid: parentHash}) { + t.Fatal("add-to message required stored parent") + } +} + +func TestWirePidForLoadedMessageAddToReferencesSharedMessage(t *testing.T) { + parentHash := []byte{1, 2, 3} + msgHash := []byte{4, 5, 6} + + got := wirePidForLoadedMessage(parentHash, msgHash, true) + if !bytes.Equal(got, msgHash) { + t.Fatalf("add-to wire pid = %v, want message hash %v", got, msgHash) + } +} + +func TestWirePidForLoadedMessageReplyKeepsParentHash(t *testing.T) { + parentHash := []byte{1, 2, 3} + msgHash := []byte{4, 5, 6} + + got := wirePidForLoadedMessage(parentHash, msgHash, false) + if !bytes.Equal(got, parentHash) { + t.Fatalf("reply wire pid = %v, want parent hash %v", got, parentHash) + } +}