diff --git a/.github/workflows/go1.25.yml b/.github/workflows/go1.25.yml index 2d9436a..6ddf095 100644 --- a/.github/workflows/go1.25.yml +++ b/.github/workflows/go1.25.yml @@ -10,9 +10,6 @@ jobs: build: runs-on: ubuntu-latest - defaults: - run: - working-directory: ./src steps: - uses: actions/checkout@v3 @@ -22,7 +19,7 @@ jobs: go-version: '1.25' - name: Build - run: go build -v . + run: go build -v ./... - name: Test - run: go test -v . + run: go test -v ./... diff --git a/.gitignore b/.gitignore index 0077797..3c19fc4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ .claude/ *.exe .env -fmsgd \ No newline at end of file +/cmd/fmsgd/fmsgd \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 300059f..4f4029e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,8 @@ All code MUST conform to the specification. When in doubt, re-read SPEC.md and f ## Build & Test - Language: Go -- Module path: `src/` -- Build: `cd src && go build ./...` -- Test: `cd src && go test ./...` +- Module path: repo root (`go.mod` at `fmsgd/go.mod`) +- Binary source: `cmd/fmsgd/` +- Shared protocol package: `pkg/fmsg/` +- Build: `go build ./...` (from repo root) +- Test: `go test ./...` (from repo root) diff --git a/README.md b/README.md index 5c9019c..d56a3db 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,7 @@ Implementation of [fmsg](https://github.com/markmnl/fmsg) host written in Go! Us Tested with Go 1.25 on Linux and Windows, AMD64 and ARM 1. Clone this repository -2. Navigate to src/ -2. Run `go build .` +2. Run `go build ./cmd/fmsgd/` ## Environment diff --git a/src/deflate.go b/cmd/fmsgd/deflate.go similarity index 100% rename from src/deflate.go rename to cmd/fmsgd/deflate.go diff --git a/src/deflate_test.go b/cmd/fmsgd/deflate_test.go similarity index 100% rename from src/deflate_test.go rename to cmd/fmsgd/deflate_test.go diff --git a/cmd/fmsgd/defs.go b/cmd/fmsgd/defs.go new file mode 100644 index 0000000..cf18055 --- /dev/null +++ b/cmd/fmsgd/defs.go @@ -0,0 +1,18 @@ +package main + +import "github.com/markmnl/fmsgd/pkg/fmsg" + +// Type aliases so all internal code continues to use unqualified names. +type FMsgAddress = fmsg.Address +type FMsgAttachmentHeader = fmsg.AttachmentHeader +type FMsgHeader = fmsg.Header + +// Flag constants forwarded from the fmsg package. +const ( + FlagHasPid = fmsg.FlagHasPid + FlagHasAddTo = fmsg.FlagHasAddTo + FlagCommonType = fmsg.FlagCommonType + FlagImportant = fmsg.FlagImportant + FlagNoReply = fmsg.FlagNoReply + FlagDeflate = fmsg.FlagDeflate +) diff --git a/src/defs_test.go b/cmd/fmsgd/defs_test.go similarity index 99% rename from src/defs_test.go rename to cmd/fmsgd/defs_test.go index 591b555..52a1a3a 100644 --- a/src/defs_test.go +++ b/cmd/fmsgd/defs_test.go @@ -10,6 +10,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/markmnl/fmsgd/pkg/fmsg" ) func TestAddressToString(t *testing.T) { @@ -841,13 +843,13 @@ func TestHashPayloadRejectsExpandedSizeMismatch(t *testing.T) { // Correct expanded size should succeed var dst bytes.Buffer - if err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))); err != nil { + if err := fmsg.HashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))); err != nil { t.Fatalf("hashPayload with correct expanded size: %v", err) } // Wrong expanded size should fail dst.Reset() - err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))+1) + err := fmsg.HashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))+1) if err == nil { t.Fatal("hashPayload with wrong expanded size: expected error, got nil") } diff --git a/src/dns.go b/cmd/fmsgd/dns.go similarity index 100% rename from src/dns.go rename to cmd/fmsgd/dns.go diff --git a/src/host.go b/cmd/fmsgd/host.go similarity index 90% rename from src/host.go rename to cmd/fmsgd/host.go index 5bc6e6e..194fed9 100644 --- a/src/host.go +++ b/cmd/fmsgd/host.go @@ -1,1735 +1,1688 @@ -package main - -import ( - "bufio" - "bytes" - "compress/zlib" - "crypto/tls" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "log" - "math" - "mime" - "net" - "net/url" - "os" - "path/filepath" - "strings" - "time" - "unicode" - "unicode/utf8" - - env "github.com/caitlinelfring/go-env-default" - "github.com/joho/godotenv" - "github.com/levenlabs/golib/timeutil" -) - -const ( - InboxDirName = "in" - OutboxDirName = "out" - - // Flag bit assignments per SPEC.md: - // bit 0 = has pid, bit 1 = has add to, bit 2 = common type, bit 3 = important, - // bit 4 = no reply, bit 5 = deflate, bits 6-7 = reserved. - FlagHasPid uint8 = 1 - FlagHasAddTo uint8 = 1 << 1 - FlagCommonType uint8 = 1 << 2 - FlagImportant uint8 = 1 << 3 - FlagNoReply uint8 = 1 << 4 - FlagDeflate uint8 = 1 << 5 - - RejectCodeInvalid uint8 = 1 - RejectCodeUnsupportedVersion uint8 = 2 - RejectCodeUndisclosed uint8 = 3 - RejectCodeTooBig uint8 = 4 - RejectCodeInsufficentResources uint8 = 5 - RejectCodeParentNotFound uint8 = 6 - RejectCodePastTime uint8 = 7 - RejectCodeFutureTime uint8 = 8 - RejectCodeTimeTravel uint8 = 9 - RejectCodeDuplicate uint8 = 10 - AcceptCodeAddTo uint8 = 11 - AcceptCodeContinue uint8 = 64 - AcceptCodeSkipData uint8 = 65 - - RejectCodeUserUnknown uint8 = 100 - RejectCodeUserFull uint8 = 101 - RejectCodeUserNotAccepting uint8 = 102 - RejectCodeUserDuplicate uint8 = 103 - RejectCodeUserUndisclosed uint8 = 105 - - RejectCodeAccept uint8 = 200 - - messageReservedBitsMask uint8 = 0b11000000 - attachmentReservedBitsMask uint8 = 0b11111100 -) - -// responseCodeName returns the human-friendly name for a response code. -func responseCodeName(code uint8) string { - switch code { - case RejectCodeInvalid: - return "invalid" - case RejectCodeUnsupportedVersion: - return "unsupported version" - case RejectCodeUndisclosed: - return "undisclosed" - case RejectCodeTooBig: - return "too big" - case RejectCodeInsufficentResources: - return "insufficient resources" - case RejectCodeParentNotFound: - return "parent not found" - case RejectCodePastTime: - return "past time" - case RejectCodeFutureTime: - return "future time" - case RejectCodeTimeTravel: - return "time travel" - case RejectCodeDuplicate: - return "duplicate" - case AcceptCodeAddTo: - return "accept add to" - case RejectCodeUserUnknown: - return "user unknown" - case RejectCodeUserFull: - return "user full" - case RejectCodeUserNotAccepting: - return "user not accepting" - case RejectCodeUserDuplicate: - return "user duplicate" - case RejectCodeUserUndisclosed: - return "user undisclosed" - case RejectCodeAccept: - return "accept" - default: - return fmt.Sprintf("unknown(%d)", code) - } -} - -var ErrProtocolViolation = errors.New("protocol violation") - -// commonMediaTypes maps common type IDs to their MIME strings per SPEC.md §4. -// IDs 1–64; unmapped IDs must be rejected with code 1 (invalid). -var commonMediaTypes = map[uint8]string{ - 1: "application/epub+zip", 2: "application/gzip", 3: "application/json", 4: "application/msword", - 5: "application/octet-stream", 6: "application/pdf", 7: "application/rtf", 8: "application/vnd.amazon.ebook", - 9: "application/vnd.ms-excel", 10: "application/vnd.ms-powerpoint", - 11: "application/vnd.oasis.opendocument.presentation", 12: "application/vnd.oasis.opendocument.spreadsheet", - 13: "application/vnd.oasis.opendocument.text", - 14: "application/vnd.openxmlformats-officedocument.presentationml.presentation", - 15: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - 16: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - 17: "application/x-tar", 18: "application/xhtml+xml", 19: "application/xml", 20: "application/zip", - 21: "audio/aac", 22: "audio/midi", 23: "audio/mpeg", 24: "audio/ogg", 25: "audio/opus", 26: "audio/vnd.wave", 27: "audio/webm", - 28: "font/otf", 29: "font/ttf", 30: "font/woff", 31: "font/woff2", - 32: "image/apng", 33: "image/avif", 34: "image/bmp", 35: "image/gif", 36: "image/heic", 37: "image/jpeg", 38: "image/png", - 39: "image/svg+xml", 40: "image/tiff", 41: "image/webp", - 42: "model/3mf", 43: "model/gltf-binary", 44: "model/obj", 45: "model/step", 46: "model/stl", 47: "model/vnd.usdz+zip", - 48: "text/calendar", 49: "text/css", 50: "text/csv", 51: "text/html", 52: "text/javascript", 53: "text/markdown", - 54: "text/plain;charset=US-ASCII", 55: "text/plain;charset=UTF-16", 56: "text/plain;charset=UTF-8", 57: "text/vcard", - 58: "video/H264", 59: "video/H265", 60: "video/H266", 61: "video/ogg", 62: "video/VP8", 63: "video/VP9", 64: "video/webm", -} - -// getCommonMediaType returns the MIME type string for a common type ID, or -// empty string + false if the ID is not mapped (should be rejected per spec). -func getCommonMediaType(id uint8) (string, bool) { - s, ok := commonMediaTypes[id] - return s, ok -} - -// getCommonMediaTypeID returns the common type ID for a MIME string. -func getCommonMediaTypeID(mediaType string) (uint8, bool) { - for id, mime := range commonMediaTypes { - if mime == mediaType { - return id, true - } - } - return 0, false -} - -var Port = 4930 - -// The only reason RemotePort would ever be different from Port is when running two fmsg hosts on the same machine so the same port is unavaliable. -var RemotePort = 4930 -var PastTimeDelta float64 = 7 * 24 * 60 * 60 -var FutureTimeDelta float64 = 300 -var MinDownloadRate float64 = 5000 -var MinUploadRate float64 = 5000 -var ReadBufferSize = 1600 -var MaxMessageSize = uint32(1024 * 10) -var MaxExpandedSize = uint32(1024 * 10) -var SkipAuthorisedIPs = false -var TLSInsecureSkipVerify = false -var DataDir = "got on startup" -var Domain = "got on startup" -var IDURI = "got on startup" -var AtRune, _ = utf8.DecodeRuneInString("@") -var MinNetIODeadline = 6 * time.Second - -var serverTLSConfig *tls.Config - -func buildServerTLSConfig() *tls.Config { - certFile := os.Getenv("FMSG_TLS_CERT") - keyFile := os.Getenv("FMSG_TLS_KEY") - if certFile == "" || keyFile == "" { - log.Fatalf("ERROR: FMSG_TLS_CERT and FMSG_TLS_KEY must be set") - } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - log.Fatalf("ERROR: loading TLS certificate: %s", err) - } - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - }, - NextProtos: []string{"fmsg/1"}, - } -} - -func buildClientTLSConfig(serverName string) *tls.Config { - return &tls.Config{ - ServerName: serverName, - MinVersion: tls.VersionTLS12, - InsecureSkipVerify: TLSInsecureSkipVerify, - NextProtos: []string{"fmsg/1"}, - } -} - -// loadEnvConfig reads env vars (after godotenv.Load so .env is picked up). -func loadEnvConfig() { - Port = env.GetIntDefault("FMSG_PORT", 4930) - RemotePort = env.GetIntDefault("FMSG_REMOTE_PORT", 4930) - PastTimeDelta = env.GetFloatDefault("FMSG_MAX_PAST_TIME_DELTA", 7*24*60*60) - FutureTimeDelta = env.GetFloatDefault("FMSG_MAX_FUTURE_TIME_DELTA", 300) - MinDownloadRate = env.GetFloatDefault("FMSG_MIN_DOWNLOAD_RATE", 5000) - MinUploadRate = env.GetFloatDefault("FMSG_MIN_UPLOAD_RATE", 5000) - ReadBufferSize = env.GetIntDefault("FMSG_READ_BUFFER_SIZE", 1600) - MaxMessageSize = uint32(env.GetIntDefault("FMSG_MAX_MSG_SIZE", 1024*10)) - MaxExpandedSize = uint32(env.GetIntDefault("FMSG_MAX_EXPANDED_SIZE", int(MaxMessageSize))) - SkipAuthorisedIPs = os.Getenv("FMSG_SKIP_AUTHORISED_IPS") == "true" - TLSInsecureSkipVerify = os.Getenv("FMSG_TLS_INSECURE_SKIP_VERIFY") == "true" -} - -// Updates DataDir from environment, panics if not a valid directory. -func setDataDir() { - value, hasValue := os.LookupEnv("FMSG_DATA_DIR") - if !hasValue { - log.Panic("ERROR: FMSG_DATA_DIR not set") - } - _, err := os.Stat(value) - if err != nil { - log.Panicf("ERROR: FMSG_DATA_DIR, %s: %s", value, err) - } - DataDir = value -} - -// Updates Domain from environment, panics if not a valid domain. -func setDomain() { - domain, hasValue := os.LookupEnv("FMSG_DOMAIN") - if !hasValue { - log.Panicln("ERROR: FMSG_DOMAIN not set") - } - _, err := net.LookupHost("fmsg." + domain) - if err != nil { - log.Panicf("ERROR: FMSG_DOMAIN, %s: %s\n", domain, err) - } - Domain = domain - - // verify our external IP is in the fmsg authorised IP set - checkDomainIP(domain) -} - -// Updates IDURL from environment, panics if not valid. -func setIDURL() { - rawUrl, hasValue := os.LookupEnv("FMSG_ID_URL") - if !hasValue { - log.Panicln("ERROR: FMSG_ID_URL not set") - } - url, err := url.Parse(rawUrl) - if err != nil { - log.Panicf("ERROR: FMSG_ID_URL not valid, %s: %s", rawUrl, err) - } - _, err = net.LookupHost(url.Hostname()) - if err != nil { - log.Panicf("ERROR: FMSG_ID_URL lookup failed, %s: %s", url, err) - } - // TODO ping URL to verify its up and responding in a timely manner - IDURI = rawUrl - log.Printf("INFO: ID URL: %s", IDURI) -} - -func calcNetIODuration(sizeInBytes int, bytesPerSecond float64) time.Duration { - rate := float64(sizeInBytes) / bytesPerSecond - d := time.Duration(rate * float64(time.Second)) - if d < MinNetIODeadline { - return MinNetIODeadline - } - return d -} - -func isValidUser(s string) bool { - if !utf8.ValidString(s) || len(s) == 0 || len(s) > 64 { - return false - } - - isSpecial := func(r rune) bool { - return r == '-' || r == '_' || r == '.' - } - - runes := []rune(s) - if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { - return false - } - - lastWasSpecial := false - for _, c := range runes { - if unicode.IsLetter(c) || unicode.IsNumber(c) { - lastWasSpecial = false - continue - } - if !isSpecial(c) { - return false - } - if lastWasSpecial { - return false - } - lastWasSpecial = true - } - return true -} - -func isASCIIBytes(b []byte) bool { - for _, c := range b { - if c > 127 { - return false - } - } - return true -} - -func isValidAttachmentFilename(name string) bool { - if !utf8.ValidString(name) || len(name) == 0 || len(name) >= 256 { - return false - } - - isSpecial := func(r rune) bool { - return r == '-' || r == '_' || r == ' ' || r == '.' - } - - runes := []rune(name) - if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { - return false - } - - lastWasSpecial := false - for _, r := range runes { - if unicode.IsLetter(r) || unicode.IsNumber(r) { - lastWasSpecial = false - continue - } - if !isSpecial(r) { - return false - } - if lastWasSpecial { - return false - } - lastWasSpecial = true - } - - return true -} - -func isMessageRetrievable(msg *FMsgHeader) bool { - if msg == nil { - return false - } - if msg.Filepath != "" { - st, err := os.Stat(msg.Filepath) - if err == nil && !st.IsDir() { - return true - } - } - if len(msg.Pid) == 0 { - return false - } - parentID, err := lookupMsgIdByHash(msg.Pid) - if err != nil || parentID == 0 { - return false - } - parentMsg, err := getMsgByID(parentID) - if err != nil { - return false - } - if parentMsg == nil { - return false - } - return isMessageRetrievable(parentMsg) -} - -func isParentParticipant(parent *FMsgHeader, addr *FMsgAddress) bool { - if parent == nil || addr == nil { - return false - } - target := strings.ToLower(addr.ToString()) - if strings.ToLower(parent.From.ToString()) == target { - return true - } - for i := range parent.To { - if strings.ToLower(parent.To[i].ToString()) == target { - return true - } - } - if parent.AddToFrom != nil && strings.ToLower(parent.AddToFrom.ToString()) == target { - return true - } - for i := range parent.AddTo { - if strings.ToLower(parent.AddTo[i].ToString()) == target { - return true - } - } - return false -} - -func isValidDomain(s string) bool { - if len(s) == 0 || len(s) > 253 { - return false - } - if s == "localhost" { - return true - } - labels := strings.Split(s, ".") - if len(labels) < 2 { - return false - } - for _, label := range labels { - if len(label) == 0 || len(label) > 63 { - return false - } - if label[0] == '-' || label[len(label)-1] == '-' { - return false - } - for _, c := range label { - if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-') { - return false - } - } - } - return true -} - -func parseAddress(b []byte) (*FMsgAddress, error) { - if len(b) < 4 { - return nil, fmt.Errorf("invalid address: too short (%d bytes)", len(b)) - } - var addr = &FMsgAddress{} - addrStr := string(b) - firstAt := strings.IndexRune(addrStr, AtRune) - if firstAt == -1 || firstAt != 0 { - return addr, fmt.Errorf("invalid address, must start with @ %s", addr) - } - lastAt := strings.LastIndex(addrStr, "@") - if lastAt == firstAt { - return addr, fmt.Errorf("invalid address, must have second @ %s", addr) - } - addr.User = addrStr[1:lastAt] - if !isValidUser(addr.User) { - return addr, fmt.Errorf("invalid user in address: %s", addr.User) - } - addr.Domain = addrStr[lastAt+1:] - if !isValidDomain(addr.Domain) { - return addr, fmt.Errorf("invalid domain in address: %s", addr.Domain) - } - return addr, nil -} - -// Reads byte slice prefixed with uint8 size from reader supplied -func ReadUInt8Slice(r io.Reader) ([]byte, error) { - var size byte - err := binary.Read(r, binary.LittleEndian, &size) - if err != nil { - return nil, err - } - return io.ReadAll(io.LimitReader(r, int64(size))) -} - -func readAddress(r io.Reader) (*FMsgAddress, error) { - slice, err := ReadUInt8Slice(r) - if err != nil { - return nil, err - } - return parseAddress(slice) -} - -func handleChallenge(c net.Conn, r *bufio.Reader) error { - hashSlice, err := io.ReadAll(io.LimitReader(r, 32)) - if err != nil { - return err - } - hash := *(*[32]byte)(hashSlice) - log.Printf("INFO: CHALLENGE <-- %s", hex.EncodeToString(hashSlice)) - - // Verify the challenger's IP is the Host-B IP registered for this message - // (§10.5 step 2). An unrecognised hash OR a mismatched IP both → TERMINATE. - remoteIP, _, _ := net.SplitHostPort(c.RemoteAddr().String()) - header, exists := lookupOutgoing(hash, remoteIP) - if !exists { - return fmt.Errorf("challenge for unknown message: %s, from: %s", hex.EncodeToString(hashSlice), c.RemoteAddr().String()) - } - msgHash, err := header.GetMessageHash() - if err != nil { - return err - } - if _, err := c.Write(msgHash); err != nil { - return err - } - return nil -} - -func rejectAccept(c net.Conn, codes []byte) error { - _, err := c.Write(codes) - return err -} - -func sendCode(c net.Conn, code uint8) error { - return rejectAccept(c, []byte{code}) -} - -func validateMessageFlags(c net.Conn, flags uint8) error { - if flags&messageReservedBitsMask != 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("reserved message flag bits set: %#08b", flags) - } - return nil -} - -func validateAttachmentFlags(c net.Conn, flags uint8) error { - if flags&attachmentReservedBitsMask != 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("reserved attachment flag bits set: %#08b", flags) - } - return nil -} - -func hasDomainRecipient(addrs []FMsgAddress, domain string) bool { - for _, addr := range addrs { - if strings.EqualFold(addr.Domain, domain) { - return true - } - } - return false -} - -func determineSenderDomain(h *FMsgHeader) string { - if len(h.AddTo) > 0 && h.AddToFrom != nil { - return h.AddToFrom.Domain - } - return h.From.Domain -} - -func verifySenderIP(c net.Conn, senderDomain string) error { - if SkipAuthorisedIPs { - return nil - } - - remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String()) - if err != nil { - log.Printf("WARN: failed to parse remote address for DNS check: %s", err) - return fmt.Errorf("DNS verification failed") - } - - remoteIP := net.ParseIP(remoteHost) - if remoteIP == nil { - log.Printf("WARN: failed to parse remote IP: %s", remoteHost) - return fmt.Errorf("DNS verification failed") - } - - authorisedIPs, err := lookupAuthorisedIPs(senderDomain) - if err != nil { - log.Printf("WARN: DNS lookup failed for fmsg.%s: %s", senderDomain, err) - return fmt.Errorf("DNS verification failed") - } - - for _, ip := range authorisedIPs { - if remoteIP.Equal(ip) { - return nil - } - } - - log.Printf("WARN: remote IP %s not in authorised IPs for fmsg.%s", remoteIP.String(), senderDomain) - return fmt.Errorf("DNS verification failed") -} - -func handleAddToPath(c net.Conn, h *FMsgHeader) (*FMsgHeader, error) { - if len(h.AddTo) == 0 { - return h, nil - } - - addToHasOurDomain := hasDomainRecipient(h.AddTo, Domain) - - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return h, err - } - - if parentID == 0 { - h.InitialResponseCode = AcceptCodeContinue - return h, nil - } - - parentMsg, err := getMsgByID(parentID) - if err != nil { - return h, err - } - if parentMsg == nil || !isMessageRetrievable(parentMsg) { - h.InitialResponseCode = AcceptCodeContinue - return h, nil - } - - if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { - if err := sendCode(c, RejectCodeTimeTravel); err != nil { - return h, err - } - return h, fmt.Errorf("add-to: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) - } - - if addToHasOurDomain { - h.InitialResponseCode = AcceptCodeSkipData - return h, nil - } - - h.Filepath = parentMsg.Filepath - for i := range h.Attachments { - if i < len(parentMsg.Attachments) { - h.Attachments[i].Filepath = parentMsg.Attachments[i].Filepath - } - } - h.InitialResponseCode = AcceptCodeAddTo - return h, nil -} - -func validatePidReplyPath(c net.Conn, h *FMsgHeader) error { - if len(h.AddTo) != 0 || h.Flags&FlagHasPid == 0 { - return nil - } - - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return err - } - if parentID == 0 { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent not found for pid %s", hex.EncodeToString(h.Pid)) - } - - parentMsg, err := getMsgByID(parentID) - if err != nil { - return err - } - if parentMsg == nil { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent message not found by ID %d", parentID) - } - if !isMessageRetrievable(parentMsg) { - if err := sendCode(c, RejectCodeParentNotFound); err != nil { - return err - } - return fmt.Errorf("pid reply: parent is not retrievable for msg %d", parentID) - } - - if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { - if err := sendCode(c, RejectCodeTimeTravel); err != nil { - return err - } - return fmt.Errorf("pid reply: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) - } - if !isParentParticipant(parentMsg, &h.From) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("pid reply: sender %s was not a participant of parent", h.From.ToString()) - } - - return nil -} - -func readVersionOrChallenge(c net.Conn, r *bufio.Reader, h *FMsgHeader) (bool, error) { - v, err := r.ReadByte() - if err != nil { - return false, err - } - if v >= 129 { - challengeVersion := 256 - int(v) - if challengeVersion == 1 { - return true, handleChallenge(c, r) - } - if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { - log.Printf("WARN: failed to send unsupported version response: %s", err) - } - return false, fmt.Errorf("unsupported challenge version: %d", challengeVersion) - } - if v != 1 { - if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { - log.Printf("WARN: failed to send unsupported version response: %s", err) - } - return false, fmt.Errorf("unsupported message version: %d", v) - } - h.Version = v - return false, nil -} - -func readToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader) (map[string]bool, error) { - num, err := r.ReadByte() - if err != nil { - return nil, err - } - if num == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return nil, err - } - return nil, fmt.Errorf("to count must be >= 1") - } - seen := make(map[string]bool) - for num > 0 { - addr, err := readAddress(r) - if err != nil { - return nil, err - } - key := strings.ToLower(addr.ToString()) - if seen[key] { - return nil, fmt.Errorf("duplicate recipient address: %s", addr.ToString()) - } - seen[key] = true - h.To = append(h.To, *addr) - num-- - } - return seen, nil -} - -func readAddToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader, seen map[string]bool) error { - if h.Flags&FlagHasAddTo == 0 { - return nil - } - if h.Flags&FlagHasPid == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add to exists but pid does not") - } - - addToFrom, err := readAddress(r) - if err != nil { - if err2 := sendCode(c, RejectCodeInvalid); err2 != nil { - return err2 - } - return fmt.Errorf("reading add-to-from address: %w", err) - } - - addToFromKey := strings.ToLower(addToFrom.ToString()) - fromKey := strings.ToLower(h.From.ToString()) - inFromOrTo := fromKey == addToFromKey - if !inFromOrTo { - for _, toAddr := range h.To { - if strings.ToLower(toAddr.ToString()) == addToFromKey { - inFromOrTo = true - break - } - } - } - if !inFromOrTo { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add-to-from (%s) not in from or to", addToFrom.ToString()) - } - h.AddToFrom = addToFrom - - addToCount, err := r.ReadByte() - if err != nil { - return err - } - if addToCount == 0 { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add to flag set but count is 0") - } - - addToSeen := make(map[string]bool) - for addToCount > 0 { - addr, err := readAddress(r) - if err != nil { - return err - } - key := strings.ToLower(addr.ToString()) - if addToSeen[key] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("duplicate recipient address in add to: %s", addr.ToString()) - } - addToSeen[key] = true - if seen[key] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("add-to address already in to: %s", addr.ToString()) - } - h.AddTo = append(h.AddTo, *addr) - addToCount-- - } - - return nil -} - -func readAndValidateTimestamp(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - if err := binary.Read(r, binary.LittleEndian, &h.Timestamp); err != nil { - return err - } - now := timeutil.TimestampNow().Float64() - delta := now - h.Timestamp - if PastTimeDelta > 0 && delta > PastTimeDelta { - if err := sendCode(c, RejectCodePastTime); err != nil { - return err - } - return fmt.Errorf("message timestamp: %f too far in past, delta: %fs", h.Timestamp, delta) - } - if FutureTimeDelta > 0 && delta < 0 && math.Abs(delta) > FutureTimeDelta { - if err := sendCode(c, RejectCodeFutureTime); err != nil { - return err - } - return fmt.Errorf("message timestamp: %f too far in future, delta: %fs", h.Timestamp, delta) - } - return nil -} - -func readMessageType(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - if h.Flags&FlagCommonType != 0 { - typeID, err := r.ReadByte() - if err != nil { - return err - } - mtype, ok := getCommonMediaType(typeID) - if !ok { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("unmapped common type ID: %d", typeID) - } - h.TypeID = typeID - h.Type = mtype - return nil - } - - mime, err := ReadUInt8Slice(r) - if err != nil { - return err - } - if !isASCIIBytes(mime) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("message media type must be US-ASCII") - } - h.Type = string(mime) - return nil -} - -func readAttachmentType(c net.Conn, r *bufio.Reader, flags uint8) (string, uint8, error) { - if flags&(1<<0) != 0 { - typeID, err := r.ReadByte() - if err != nil { - return "", 0, err - } - mtype, ok := getCommonMediaType(typeID) - if !ok { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return "", 0, err - } - return "", 0, fmt.Errorf("unmapped attachment common type ID: %d", typeID) - } - return mtype, typeID, nil - } - - typeBytes, err := ReadUInt8Slice(r) - if err != nil { - return "", 0, err - } - if !isASCIIBytes(typeBytes) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return "", 0, err - } - return "", 0, fmt.Errorf("attachment media type must be US-ASCII") - } - return string(typeBytes), 0, nil -} - -func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { - var attachCount uint8 - if err := binary.Read(r, binary.LittleEndian, &attachCount); err != nil { - return err - } - - totalSize := h.Size - // When message is compressed, expanded size comes from the header field. - // When uncompressed, the wire size IS the expanded size. - var totalExpandedSize uint32 - if h.Flags&FlagDeflate != 0 { - totalExpandedSize = h.ExpandedSize - } else { - totalExpandedSize = h.Size - } - filenameSeen := make(map[string]bool) - for i := uint8(0); i < attachCount; i++ { - attFlags, err := r.ReadByte() - if err != nil { - return err - } - if err := validateAttachmentFlags(c, attFlags); err != nil { - return err - } - - attType, attTypeID, err := readAttachmentType(c, r, attFlags) - if err != nil { - return err - } - - filenameBytes, err := ReadUInt8Slice(r) - if err != nil { - return err - } - filename := string(filenameBytes) - if !isValidAttachmentFilename(filename) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("invalid attachment filename: %s", filename) - } - filenameKey := strings.ToLower(filename) - if filenameSeen[filenameKey] { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return err - } - return fmt.Errorf("duplicate attachment filename: %s", filename) - } - filenameSeen[filenameKey] = true - - var attSize uint32 - if err := binary.Read(r, binary.LittleEndian, &attSize); err != nil { - return err - } - - // read attachment expanded size — present iff attachment zlib-deflate flag set (§5) - var attExpandedSize uint32 - if attFlags&(1<<1) != 0 { - if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { - return err - } - totalExpandedSize += attExpandedSize - } else { - // uncompressed: expanded size equals wire size - totalExpandedSize += attSize - } - - h.Attachments = append(h.Attachments, FMsgAttachmentHeader{ - Flags: attFlags, - TypeID: attTypeID, - Type: attType, - Filename: filename, - Size: attSize, - ExpandedSize: attExpandedSize, - }) - totalSize += attSize - } - - if totalSize > MaxMessageSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return err - } - return fmt.Errorf("total message size %d exceeds max %d", totalSize, MaxMessageSize) - } - - if totalExpandedSize > MaxExpandedSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return err - } - return fmt.Errorf("total expanded size %d exceeds MAX_EXPANDED_SIZE %d", totalExpandedSize, MaxExpandedSize) - } - - return nil -} - -func readHeader(c net.Conn) (*FMsgHeader, *bufio.Reader, error) { - r := bufio.NewReaderSize(c, ReadBufferSize) - var h = &FMsgHeader{InitialResponseCode: AcceptCodeContinue} - - d := calcNetIODuration(66000, MinDownloadRate) // max possible header size - c.SetReadDeadline(time.Now().Add(d)) - - handled, err := readVersionOrChallenge(c, r, h) - if err != nil { - if handled { - return nil, r, err - } - return h, r, err - } - if handled { - return nil, r, nil - } - - // read flags - flags, err := r.ReadByte() - if err != nil { - return h, r, err - } - h.Flags = flags - if err := validateMessageFlags(c, flags); err != nil { - return h, r, err - } - - // read pid if any - if flags&FlagHasPid == 1 { - pid, err := io.ReadAll(io.LimitReader(r, 32)) - if err != nil { - return h, r, err - } - h.Pid = make([]byte, 32) - copy(h.Pid, pid) - } - - // read from address - from, err := readAddress(r) - if err != nil { - return h, r, err - } - - h.From = *from - - seen, err := readToRecipients(c, r, h) - if err != nil { - return h, r, err - } - - if err := readAddToRecipients(c, r, h, seen); err != nil { - return h, r, err - } - - if err := readAndValidateTimestamp(c, r, h); err != nil { - return h, r, err - } - - // read topic — only present when pid is NOT set (first message in a thread) - if flags&FlagHasPid == 0 { - topic, err := ReadUInt8Slice(r) - if err != nil { - return h, r, err - } - h.Topic = string(topic) - } - - if err := readMessageType(c, r, h); err != nil { - return h, r, err - } - - // read message size - if err := binary.Read(r, binary.LittleEndian, &h.Size); err != nil { - return h, r, err - } - // read expanded size — present iff zlib-deflate flag is set (§2 field 12) - if h.Flags&FlagDeflate != 0 { - if err := binary.Read(r, binary.LittleEndian, &h.ExpandedSize); err != nil { - return h, r, err - } - if h.ExpandedSize > MaxExpandedSize { - if err := sendCode(c, RejectCodeTooBig); err != nil { - return h, r, err - } - return h, r, fmt.Errorf("expanded size %d exceeds MAX_EXPANDED_SIZE %d", h.ExpandedSize, MaxExpandedSize) - } - } - // Size check is deferred until attachment headers are parsed (see below) - - if err := readAttachmentHeaders(c, r, h); err != nil { - return h, r, err - } - - log.Printf("INFO: <-- MSG\n%s", h) - - if !hasDomainRecipient(h.To, Domain) && !hasDomainRecipient(h.AddTo, Domain) { - if err := sendCode(c, RejectCodeInvalid); err != nil { - return h, r, err - } - return h, r, fmt.Errorf("no recipients for domain %s", Domain) - } - - if err := verifySenderIP(c, determineSenderDomain(h)); err != nil { - return nil, r, err - } - - h, err = handleAddToPath(c, h) - if err != nil { - return h, r, err - } - if h == nil { - return nil, r, nil - } - - if err := validatePidReplyPath(c, h); err != nil { - return h, r, err - } - - return h, r, nil -} - -// Sends CHALLENGE request to sender, receiving and storing the challenge hash. -// DNS verification of the remote IP is performed during header exchange (readHeader). -// TODO [Spec step 2]: The spec defines challenge modes (NEVER, ALWAYS, -// HAS_NOT_PARTICIPATED, DIFFERENT_DOMAIN) as implementation choices. -// Currently defaults to ALWAYS. Implement configurable challenge mode. -func challenge(conn net.Conn, h *FMsgHeader, senderDomain string) error { - - // Connection 2 MUST target the same IP as Connection 1 (spec 2.1). - remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) - if err != nil { - return fmt.Errorf("failed to parse remote address for challenge: %w", err) - } - conn2, err := tls.Dial("tcp", net.JoinHostPort(remoteHost, fmt.Sprintf("%d", RemotePort)), buildClientTLSConfig("fmsg."+senderDomain)) - if err != nil { - return err - } - version := uint8(255) - if err := binary.Write(conn2, binary.LittleEndian, version); err != nil { - return err - } - hash := h.GetHeaderHash() - log.Printf("INFO: --> CHALLENGE\t%s\n", hex.EncodeToString(hash)) - if _, err := conn2.Write(hash); err != nil { - return err - } - - // read challenge response - resp, err := io.ReadAll(io.LimitReader(conn2, 32)) - if err != nil { - return err - } - if len(resp) != 32 { - return fmt.Errorf("challenge response size %d, expected 32", len(resp)) - } - copy(h.ChallengeHash[:], resp) - h.ChallengeCompleted = true - log.Printf("INFO: <-- CHALLENGE RESP\t%s\n", hex.EncodeToString(resp)) - - // gracefully close 2nd connection - if err := conn2.Close(); err != nil { - return err - } - - return nil -} - -func validateMsgRecvForAddr(h *FMsgHeader, addr *FMsgAddress, msgHash []byte) (code uint8, err error) { - duplicate, err := hasAddrReceivedMsgHash(msgHash, addr) - if err != nil { - return RejectCodeUserUndisclosed, err - } - if duplicate { - return RejectCodeUserDuplicate, nil - } - - detail, err := getAddressDetail(addr) - if err != nil { - return RejectCodeUserUndisclosed, err - } - if detail == nil { - return RejectCodeUserUnknown, nil - } - - // check user accepting new - if !detail.AcceptingNew { - return RejectCodeUserNotAccepting, nil - } - - // check user limits - if detail.LimitRecvCountPer1d > -1 && detail.RecvCountPer1d+1 > detail.LimitRecvCountPer1d { - log.Printf("WARN: Message rejected: RecvCountPer1d would exceed LimitRecvCountPer1d %d", detail.LimitRecvCountPer1d) - return RejectCodeUserFull, nil - } - if detail.LimitRecvSizePer1d > -1 && detail.RecvSizePer1d+int64(h.Size) > detail.LimitRecvSizePer1d { - log.Printf("WARN: Message rejected: RecvSizePer1d would exceed LimitRecvSizePer1d %d", detail.LimitRecvSizePer1d) - return RejectCodeUserFull, nil - } - if detail.LimitRecvSizeTotal > -1 && detail.RecvSizeTotal+int64(h.Size) > detail.LimitRecvSizeTotal { - log.Printf("WARN: Message rejected: RecvSizeTotal would exceed LimitRecvSizeTotal %d", detail.LimitRecvSizeTotal) - return RejectCodeUserFull, nil - } - - return RejectCodeAccept, nil -} - -// uniqueFilepath generates a unique file path in the given directory, -// appending a counter suffix if the base name already exists. -func uniqueFilepath(dir string, timestamp uint32, ext string) string { - base := fmt.Sprintf("%d", timestamp) - fp := filepath.Join(dir, base+ext) - if _, err := os.Stat(fp); os.IsNotExist(err) { - return fp - } - for i := 1; ; i++ { - fp = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, i, ext)) - if _, err := os.Stat(fp); os.IsNotExist(err) { - return fp - } - } -} - -func localRecipients(h *FMsgHeader) []FMsgAddress { - addrs := make([]FMsgAddress, 0, len(h.To)+len(h.AddTo)) - for _, addr := range h.To { - if strings.EqualFold(addr.Domain, Domain) { - addrs = append(addrs, addr) - } - } - for _, addr := range h.AddTo { - if strings.EqualFold(addr.Domain, Domain) { - addrs = append(addrs, addr) - } - } - return addrs -} - -func allLocalRecipientsHaveMessageHash(msgHash []byte, addrs []FMsgAddress) (bool, error) { - if len(addrs) == 0 { - return false, nil - } - for i := range addrs { - duplicate, err := hasAddrReceivedMsgHash(msgHash, &addrs[i]) - if err != nil { - return false, err - } - if !duplicate { - return false, nil - } - } - return true, nil -} - -func markAllCodes(codes []byte, code uint8) { - for i := range codes { - codes[i] = code - } -} - -func prepareMessageData(r io.Reader, h *FMsgHeader, skipData bool) ([]string, error) { - if skipData { - parentID, err := lookupMsgIdByHash(h.Pid) - if err != nil { - return nil, err - } - if parentID == 0 { - return nil, fmt.Errorf("%w code 65 requires stored parent for pid %s", ErrProtocolViolation, hex.EncodeToString(h.Pid)) - } - parentMsg, err := getMsgByID(parentID) - if err != nil { - return nil, err - } - if parentMsg == nil || parentMsg.Filepath == "" { - return nil, fmt.Errorf("%w code 65 parent data unavailable for msg %d", ErrProtocolViolation, parentID) - } - h.Filepath = parentMsg.Filepath - return nil, nil - } - - createdPaths := make([]string, 0, 1+len(h.Attachments)) - - fd, err := os.CreateTemp("", "fmsg-download-*") - if err != nil { - return nil, err - } - - if _, err := io.CopyN(fd, r, int64(h.Size)); err != nil { - fd.Close() - _ = os.Remove(fd.Name()) - return nil, err - } - if err := fd.Close(); err != nil { - _ = os.Remove(fd.Name()) - return nil, err - } - - h.Filepath = fd.Name() - createdPaths = append(createdPaths, fd.Name()) - - for i := range h.Attachments { - afd, err := os.CreateTemp("", "fmsg-attachment-*") - if err != nil { - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - - if _, err := io.CopyN(afd, r, int64(h.Attachments[i].Size)); err != nil { - afd.Close() - _ = os.Remove(afd.Name()) - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - if err := afd.Close(); err != nil { - _ = os.Remove(afd.Name()) - for _, path := range createdPaths { - _ = os.Remove(path) - } - return nil, err - } - h.Attachments[i].Filepath = afd.Name() - createdPaths = append(createdPaths, afd.Name()) - } - - return createdPaths, nil -} - -func cleanupFiles(paths []string) { - for _, path := range paths { - if path == "" { - continue - } - _ = os.Remove(path) - } -} - -func copyMessagePayload(src *os.File, dstPath string, compressed bool, wireSize uint32) error { - if _, err := src.Seek(0, io.SeekStart); err != nil { - return err - } - - fd2, err := os.Create(dstPath) - if err != nil { - return err - } - - var copyErr error - if compressed { - lr := io.LimitReader(src, int64(wireSize)) - zr, err := zlib.NewReader(lr) - if err != nil { - fd2.Close() - _ = os.Remove(dstPath) - return err - } - _, copyErr = io.Copy(fd2, zr) - _ = zr.Close() - } else { - _, copyErr = io.CopyN(fd2, src, int64(wireSize)) - } - if err := fd2.Close(); err != nil { - return err - } - - if copyErr != nil { - _ = os.Remove(dstPath) - return copyErr - } - return nil -} - -func uniqueAttachmentPath(dir string, timestamp uint32, idx int, filename string) string { - ext := filepath.Ext(filename) - base := fmt.Sprintf("%d_att_%d", timestamp, idx) - p := filepath.Join(dir, base+ext) - if _, err := os.Stat(p); os.IsNotExist(err) { - return p - } - for n := 1; ; n++ { - p = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, n, ext)) - if _, err := os.Stat(p); os.IsNotExist(err) { - return p - } - } -} - -func persistAttachmentPayloads(h *FMsgHeader, dirpath string) error { - for i := range h.Attachments { - a := &h.Attachments[i] - src, err := os.Open(a.Filepath) - if err != nil { - return err - } - dstPath := uniqueAttachmentPath(dirpath, uint32(h.Timestamp), i, a.Filename) - compressed := a.Flags&(1<<1) != 0 - err = copyMessagePayload(src, dstPath, compressed, a.Size) - src.Close() - if err != nil { - return err - } - a.Filepath = dstPath - } - return nil -} - -func storeAcceptedMessage(h *FMsgHeader, codes []byte, acceptedTo []FMsgAddress, acceptedAddTo []FMsgAddress, primaryFilepath string) bool { - if len(acceptedTo) == 0 && len(acceptedAddTo) == 0 { - return false - } - - origTo := h.To - origAddTo := h.AddTo - h.To = acceptedTo - h.AddTo = acceptedAddTo - h.Filepath = primaryFilepath - if err := storeMsgDetail(h); err != nil { - log.Printf("ERROR: storing message: %s", err) - h.To = origTo - h.AddTo = origAddTo - for i := range codes { - if codes[i] == RejectCodeAccept { - codes[i] = RejectCodeUndisclosed - } - } - return false - } - - h.To = origTo - h.AddTo = origAddTo - allAccepted := append(acceptedTo, acceptedAddTo...) - for i := range allAccepted { - if err := postMsgStatRecv(&allAccepted[i], h.Timestamp, int(h.Size)); err != nil { - log.Printf("WARN: Failed to post msg recv stat: %s", err) - } - } - return true -} - -func downloadMessage(c net.Conn, r io.Reader, h *FMsgHeader, skipData bool) error { - addrs := localRecipients(h) - if len(addrs) == 0 { - return fmt.Errorf("%w our domain: %s, not in recipient list", ErrProtocolViolation, Domain) - } - codes := make([]byte, len(addrs)) - - createdPaths, err := prepareMessageData(r, h, skipData) - if err != nil { - return err - } - cleanupOnReturn := !skipData - defer func() { - if cleanupOnReturn { - cleanupFiles(createdPaths) - } - }() - - // verify hash matches challenge response when challenge was completed - msgHash, err := h.GetMessageHash() - if err != nil { - return err - } - if h.ChallengeCompleted && !bytes.Equal(h.ChallengeHash[:], msgHash) { - challengeHashStr := hex.EncodeToString(h.ChallengeHash[:]) - actualHashStr := hex.EncodeToString(msgHash) - return fmt.Errorf("%w actual hash: %s mismatch challenge response: %s", ErrProtocolViolation, actualHashStr, challengeHashStr) - } - - // pid/add-to validation is handled during header exchange in readHeader(). - - // determine file extension from MIME type - exts, _ := mime.ExtensionsByType(h.Type) - var ext string - if exts == nil { - ext = ".unknown" - } else { - ext = exts[0] - } - - src, err := os.Open(h.Filepath) - if err != nil { - return err - } - defer src.Close() - - // validate each recipient and copy message for accepted ones - // Build a set of add-to addresses for later classification - addToSet := make(map[string]bool) - for _, addr := range h.AddTo { - addToSet[strings.ToLower(addr.ToString())] = true - } - acceptedTo := []FMsgAddress{} - acceptedAddTo := []FMsgAddress{} - var primaryFilepath string - for i, addr := range addrs { - code, err := validateMsgRecvForAddr(h, &addr, msgHash) - if err != nil { - return err - } - if code != RejectCodeAccept { - log.Printf("WARN: Rejected message to: %s: %s (%d)", addr.ToString(), responseCodeName(code), code) - codes[i] = code - continue - } - - // copy to recipient's directory - dirpath := filepath.Join(DataDir, addr.Domain, addr.User, InboxDirName) - if err := os.MkdirAll(dirpath, 0750); err != nil { - return err - } - - fp := uniqueFilepath(dirpath, uint32(h.Timestamp), ext) - if err := copyMessagePayload(src, fp, h.Flags&FlagDeflate != 0, h.Size); err != nil { - log.Printf("ERROR: copying downloaded message from: %s, to: %s", h.Filepath, fp) - codes[i] = RejectCodeUndisclosed - continue - } - - codes[i] = RejectCodeAccept - if addToSet[strings.ToLower(addr.ToString())] { - acceptedAddTo = append(acceptedAddTo, addr) - } else { - acceptedTo = append(acceptedTo, addr) - } - if primaryFilepath == "" { - primaryFilepath = fp - if err := persistAttachmentPayloads(h, filepath.Dir(primaryFilepath)); err != nil { - log.Printf("ERROR: copying attachment payloads for message storage: %s", err) - codes[i] = RejectCodeUndisclosed - primaryFilepath = "" - acceptedTo = acceptedTo[:0] - acceptedAddTo = acceptedAddTo[:0] - continue - } - } - } - - stored := storeAcceptedMessage(h, codes, acceptedTo, acceptedAddTo, primaryFilepath) - if stored { - cleanupOnReturn = false - } - - return rejectAccept(c, codes) -} - -// resolvePostChallengeCode determines the initial response code to send after -// the optional challenge (§10.4). Code 11 (accept add-to) is returned as-is -// since it has no local recipients to duplicate-check. For the skip-data (65) -// and continue (64) paths, a completed challenge with all-local-duplicate -// produces code 10 (duplicate) instead. -func resolvePostChallengeCode(initialCode uint8, challengeCompleted bool, allLocalDup bool) uint8 { - if initialCode == AcceptCodeAddTo { - return AcceptCodeAddTo - } - if challengeCompleted && allLocalDup { - return RejectCodeDuplicate - } - if initialCode == AcceptCodeSkipData { - return AcceptCodeSkipData - } - return AcceptCodeContinue -} - -func abortConn(c net.Conn) { - if tcp, ok := c.(*net.TCPConn); ok { - tcp.SetLinger(0) - } - _ = c.Close() -} - -type responseTrackingConn struct { - net.Conn - wroteResponse bool -} - -func (c *responseTrackingConn) Write(b []byte) (int, error) { - n, err := c.Conn.Write(b) - if n > 0 { - c.wroteResponse = true - } - return n, err -} - -func handleConn(c net.Conn) { - defer func() { - if r := recover(); r != nil { - log.Printf("ERROR: Recovered in handleConn: %v", r) - } - }() - - log.Printf("INFO: Connection from: %s\n", c.RemoteAddr().String()) - tc := &responseTrackingConn{Conn: c} - - // read header - header, r, err := readHeader(tc) - if err != nil { - log.Printf("WARN: reading header from, %s: %s", c.RemoteAddr().String(), err) - if tc.wroteResponse { - _ = c.Close() - return - } - abortConn(c) - return - } - - // if no header AND no error this was a challenge thats been handeled - if header == nil { - c.Close() - return - } - - if err := challenge(tc, header, determineSenderDomain(header)); err != nil { - log.Printf("ERROR: Challenge failed to, %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - - // §10.4: Determine initial response code after optional challenge. - // Code 11 (add-to, no local recipients) does not need a dup check. - // Codes 65 and 64 both require a dup check when challenge was completed. - allLocalDup := false - if header.ChallengeCompleted && header.InitialResponseCode != AcceptCodeAddTo { - addrs := localRecipients(header) - var err error - allLocalDup, err = allLocalRecipientsHaveMessageHash(header.ChallengeHash[:], addrs) - if err != nil { - log.Printf("ERROR: duplicate check failed for %s: %s", c.RemoteAddr().String(), err) - if err := sendCode(c, RejectCodeUndisclosed); err != nil { - abortConn(c) - return - } - _ = c.Close() - return - } - } - - code := resolvePostChallengeCode(header.InitialResponseCode, header.ChallengeCompleted, allLocalDup) - skipData := false - - switch code { - case AcceptCodeAddTo: - // No local add-to recipients; store header and respond code 11, close. - if err := storeMsgHeaderOnly(header); err != nil { - log.Printf("ERROR: storing add-to header: %s", err) - if err := sendCode(c, RejectCodeUndisclosed); err != nil { - abortConn(c) - return - } - _ = c.Close() - return - } - if err := sendCode(c, AcceptCodeAddTo); err != nil { - log.Printf("ERROR: failed sending code 11 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - log.Printf("INFO: additional recipients received (code 11) for pid %s", hex.EncodeToString(header.Pid)) - c.Close() - return - case RejectCodeDuplicate: - if err := sendCode(c, RejectCodeDuplicate); err != nil { - log.Printf("ERROR: failed sending code 10 to %s: %s", c.RemoteAddr().String(), err) - } - c.Close() - return - case AcceptCodeSkipData: - if err := sendCode(c, AcceptCodeSkipData); err != nil { - log.Printf("ERROR: failed sending code 65 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - skipData = true - log.Printf("INFO: sent code 65 (skip data) to %s", c.RemoteAddr().String()) - default: - if err := sendCode(c, AcceptCodeContinue); err != nil { - log.Printf("ERROR: failed sending code 64 to %s: %s", c.RemoteAddr().String(), err) - abortConn(c) - return - } - log.Printf("INFO: sent code 64 (continue) to %s", c.RemoteAddr().String()) - } - - // store message - deadlineBytes := int(header.Size) - if skipData { - deadlineBytes = 1 - } - c.SetReadDeadline(time.Now().Add(calcNetIODuration(deadlineBytes, MinDownloadRate))) - if err := downloadMessage(c, r, header, skipData); err != nil { - // if error was a protocal violation, abort; otherise let sender know there was an internal error - log.Printf("ERROR: Download failed from, %s: %s", c.RemoteAddr().String(), err) - if errors.Is(err, ErrProtocolViolation) { - abortConn(c) - return - } else { - _ = sendCode(c, RejectCodeUndisclosed) - } - } - - // gracefully close 1st connection - c.Close() -} - -func main() { - - initOutgoing() - - // load environment variables from .env file if present - if err := godotenv.Load(); err != nil { - log.Printf("INFO: Could not load .env file: %v", err) - } - - // read env config (must be after godotenv.Load) - loadEnvConfig() - - // determine listen address from args - listenAddress := "127.0.0.1" - for _, arg := range os.Args[1:] { - listenAddress = arg - } - - // initalize database - err := testDb() - if err != nil { - log.Fatalf("ERROR: connecting to database: %s\n", err) - } - - // set DataDir, Domain and IDURL from env - setDataDir() - setDomain() - setIDURL() - - // load TLS configuration (must be after loadEnvConfig for FMSG_TLS_INSECURE_SKIP_VERIFY) - serverTLSConfig = buildServerTLSConfig() - - // start sender in background (small delay so listener is ready first) - go func() { - time.Sleep(1 * time.Second) - startSender() - }() - - // start listening - addr := fmt.Sprintf("%s:%d", listenAddress, Port) - ln, err := tls.Listen("tcp", addr, serverTLSConfig) - if err != nil { - log.Fatal(err) - } - log.Printf("INFO: Ready to receive on %s\n", addr) - for { - conn, err := ln.Accept() - if err != nil { - log.Printf("ERROR: Accept connection from %s returned: %s\n", ln.Addr().String(), err) - } else { - go handleConn(conn) - } - } - -} +package main + +import ( + "bufio" + "bytes" + "compress/zlib" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "math" + "mime" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" + "unicode" + "unicode/utf8" + + env "github.com/caitlinelfring/go-env-default" + "github.com/joho/godotenv" + "github.com/levenlabs/golib/timeutil" + + "github.com/markmnl/fmsgd/pkg/fmsg" +) + +const ( + InboxDirName = "in" + OutboxDirName = "out" + + RejectCodeInvalid uint8 = 1 + RejectCodeUnsupportedVersion uint8 = 2 + RejectCodeUndisclosed uint8 = 3 + RejectCodeTooBig uint8 = 4 + RejectCodeInsufficentResources uint8 = 5 + RejectCodeParentNotFound uint8 = 6 + RejectCodePastTime uint8 = 7 + RejectCodeFutureTime uint8 = 8 + RejectCodeTimeTravel uint8 = 9 + RejectCodeDuplicate uint8 = 10 + AcceptCodeAddTo uint8 = 11 + AcceptCodeContinue uint8 = 64 + AcceptCodeSkipData uint8 = 65 + + RejectCodeUserUnknown uint8 = 100 + RejectCodeUserFull uint8 = 101 + RejectCodeUserNotAccepting uint8 = 102 + RejectCodeUserDuplicate uint8 = 103 + RejectCodeUserUndisclosed uint8 = 105 + + RejectCodeAccept uint8 = 200 + + messageReservedBitsMask uint8 = 0b11000000 + attachmentReservedBitsMask uint8 = 0b11111100 +) + +// responseCodeName returns the human-friendly name for a response code. +func responseCodeName(code uint8) string { + switch code { + case RejectCodeInvalid: + return "invalid" + case RejectCodeUnsupportedVersion: + return "unsupported version" + case RejectCodeUndisclosed: + return "undisclosed" + case RejectCodeTooBig: + return "too big" + case RejectCodeInsufficentResources: + return "insufficient resources" + case RejectCodeParentNotFound: + return "parent not found" + case RejectCodePastTime: + return "past time" + case RejectCodeFutureTime: + return "future time" + case RejectCodeTimeTravel: + return "time travel" + case RejectCodeDuplicate: + return "duplicate" + case AcceptCodeAddTo: + return "accept add to" + case RejectCodeUserUnknown: + return "user unknown" + case RejectCodeUserFull: + return "user full" + case RejectCodeUserNotAccepting: + return "user not accepting" + case RejectCodeUserDuplicate: + return "user duplicate" + case RejectCodeUserUndisclosed: + return "user undisclosed" + case RejectCodeAccept: + return "accept" + default: + return fmt.Sprintf("unknown(%d)", code) + } +} + +var ErrProtocolViolation = errors.New("protocol violation") + +var Port = 4930 + +// The only reason RemotePort would ever be different from Port is when running two fmsg hosts on the same machine so the same port is unavaliable. +var RemotePort = 4930 +var PastTimeDelta float64 = 7 * 24 * 60 * 60 +var FutureTimeDelta float64 = 300 +var MinDownloadRate float64 = 5000 +var MinUploadRate float64 = 5000 +var ReadBufferSize = 1600 +var MaxMessageSize = uint32(1024 * 10) +var MaxExpandedSize = uint32(1024 * 10) +var SkipAuthorisedIPs = false +var TLSInsecureSkipVerify = false +var DataDir = "got on startup" +var Domain = "got on startup" +var IDURI = "got on startup" +var AtRune, _ = utf8.DecodeRuneInString("@") +var MinNetIODeadline = 6 * time.Second + +var serverTLSConfig *tls.Config + +func buildServerTLSConfig() *tls.Config { + certFile := os.Getenv("FMSG_TLS_CERT") + keyFile := os.Getenv("FMSG_TLS_KEY") + if certFile == "" || keyFile == "" { + log.Fatalf("ERROR: FMSG_TLS_CERT and FMSG_TLS_KEY must be set") + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("ERROR: loading TLS certificate: %s", err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + }, + NextProtos: []string{"fmsg/1"}, + } +} + +func buildClientTLSConfig(serverName string) *tls.Config { + return &tls.Config{ + ServerName: serverName, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: TLSInsecureSkipVerify, + NextProtos: []string{"fmsg/1"}, + } +} + +// loadEnvConfig reads env vars (after godotenv.Load so .env is picked up). +func loadEnvConfig() { + Port = env.GetIntDefault("FMSG_PORT", 4930) + RemotePort = env.GetIntDefault("FMSG_REMOTE_PORT", 4930) + PastTimeDelta = env.GetFloatDefault("FMSG_MAX_PAST_TIME_DELTA", 7*24*60*60) + FutureTimeDelta = env.GetFloatDefault("FMSG_MAX_FUTURE_TIME_DELTA", 300) + MinDownloadRate = env.GetFloatDefault("FMSG_MIN_DOWNLOAD_RATE", 5000) + MinUploadRate = env.GetFloatDefault("FMSG_MIN_UPLOAD_RATE", 5000) + ReadBufferSize = env.GetIntDefault("FMSG_READ_BUFFER_SIZE", 1600) + MaxMessageSize = uint32(env.GetIntDefault("FMSG_MAX_MSG_SIZE", 1024*10)) + MaxExpandedSize = uint32(env.GetIntDefault("FMSG_MAX_EXPANDED_SIZE", int(MaxMessageSize))) + SkipAuthorisedIPs = os.Getenv("FMSG_SKIP_AUTHORISED_IPS") == "true" + TLSInsecureSkipVerify = os.Getenv("FMSG_TLS_INSECURE_SKIP_VERIFY") == "true" +} + +// Updates DataDir from environment, panics if not a valid directory. +func setDataDir() { + value, hasValue := os.LookupEnv("FMSG_DATA_DIR") + if !hasValue { + log.Panic("ERROR: FMSG_DATA_DIR not set") + } + _, err := os.Stat(value) + if err != nil { + log.Panicf("ERROR: FMSG_DATA_DIR, %s: %s", value, err) + } + DataDir = value +} + +// Updates Domain from environment, panics if not a valid domain. +func setDomain() { + domain, hasValue := os.LookupEnv("FMSG_DOMAIN") + if !hasValue { + log.Panicln("ERROR: FMSG_DOMAIN not set") + } + _, err := net.LookupHost("fmsg." + domain) + if err != nil { + log.Panicf("ERROR: FMSG_DOMAIN, %s: %s\n", domain, err) + } + Domain = domain + + // verify our external IP is in the fmsg authorised IP set + checkDomainIP(domain) +} + +// Updates IDURL from environment, panics if not valid. +func setIDURL() { + rawUrl, hasValue := os.LookupEnv("FMSG_ID_URL") + if !hasValue { + log.Panicln("ERROR: FMSG_ID_URL not set") + } + url, err := url.Parse(rawUrl) + if err != nil { + log.Panicf("ERROR: FMSG_ID_URL not valid, %s: %s", rawUrl, err) + } + _, err = net.LookupHost(url.Hostname()) + if err != nil { + log.Panicf("ERROR: FMSG_ID_URL lookup failed, %s: %s", url, err) + } + // TODO ping URL to verify its up and responding in a timely manner + IDURI = rawUrl + log.Printf("INFO: ID URL: %s", IDURI) +} + +func calcNetIODuration(sizeInBytes int, bytesPerSecond float64) time.Duration { + rate := float64(sizeInBytes) / bytesPerSecond + d := time.Duration(rate * float64(time.Second)) + if d < MinNetIODeadline { + return MinNetIODeadline + } + return d +} + +func isValidUser(s string) bool { + if !utf8.ValidString(s) || len(s) == 0 || len(s) > 64 { + return false + } + + isSpecial := func(r rune) bool { + return r == '-' || r == '_' || r == '.' + } + + runes := []rune(s) + if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { + return false + } + + lastWasSpecial := false + for _, c := range runes { + if unicode.IsLetter(c) || unicode.IsNumber(c) { + lastWasSpecial = false + continue + } + if !isSpecial(c) { + return false + } + if lastWasSpecial { + return false + } + lastWasSpecial = true + } + return true +} + +func isASCIIBytes(b []byte) bool { + for _, c := range b { + if c > 127 { + return false + } + } + return true +} + +func isValidAttachmentFilename(name string) bool { + if !utf8.ValidString(name) || len(name) == 0 || len(name) >= 256 { + return false + } + + isSpecial := func(r rune) bool { + return r == '-' || r == '_' || r == ' ' || r == '.' + } + + runes := []rune(name) + if isSpecial(runes[0]) || isSpecial(runes[len(runes)-1]) { + return false + } + + lastWasSpecial := false + for _, r := range runes { + if unicode.IsLetter(r) || unicode.IsNumber(r) { + lastWasSpecial = false + continue + } + if !isSpecial(r) { + return false + } + if lastWasSpecial { + return false + } + lastWasSpecial = true + } + + return true +} + +func isMessageRetrievable(msg *FMsgHeader) bool { + if msg == nil { + return false + } + if msg.Filepath != "" { + st, err := os.Stat(msg.Filepath) + if err == nil && !st.IsDir() { + return true + } + } + if len(msg.Pid) == 0 { + return false + } + parentID, err := lookupMsgIdByHash(msg.Pid) + if err != nil || parentID == 0 { + return false + } + parentMsg, err := getMsgByID(parentID) + if err != nil { + return false + } + if parentMsg == nil { + return false + } + return isMessageRetrievable(parentMsg) +} + +func isParentParticipant(parent *FMsgHeader, addr *FMsgAddress) bool { + if parent == nil || addr == nil { + return false + } + target := strings.ToLower(addr.ToString()) + if strings.ToLower(parent.From.ToString()) == target { + return true + } + for i := range parent.To { + if strings.ToLower(parent.To[i].ToString()) == target { + return true + } + } + if parent.AddToFrom != nil && strings.ToLower(parent.AddToFrom.ToString()) == target { + return true + } + for i := range parent.AddTo { + if strings.ToLower(parent.AddTo[i].ToString()) == target { + return true + } + } + return false +} + +func isValidDomain(s string) bool { + if len(s) == 0 || len(s) > 253 { + return false + } + if s == "localhost" { + return true + } + labels := strings.Split(s, ".") + if len(labels) < 2 { + return false + } + for _, label := range labels { + if len(label) == 0 || len(label) > 63 { + return false + } + if label[0] == '-' || label[len(label)-1] == '-' { + return false + } + for _, c := range label { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + } + return true +} + +func parseAddress(b []byte) (*FMsgAddress, error) { + if len(b) < 4 { + return nil, fmt.Errorf("invalid address: too short (%d bytes)", len(b)) + } + var addr = &FMsgAddress{} + addrStr := string(b) + firstAt := strings.IndexRune(addrStr, AtRune) + if firstAt == -1 || firstAt != 0 { + return addr, fmt.Errorf("invalid address, must start with @ %s", addr) + } + lastAt := strings.LastIndex(addrStr, "@") + if lastAt == firstAt { + return addr, fmt.Errorf("invalid address, must have second @ %s", addr) + } + addr.User = addrStr[1:lastAt] + if !isValidUser(addr.User) { + return addr, fmt.Errorf("invalid user in address: %s", addr.User) + } + addr.Domain = addrStr[lastAt+1:] + if !isValidDomain(addr.Domain) { + return addr, fmt.Errorf("invalid domain in address: %s", addr.Domain) + } + return addr, nil +} + +// Reads byte slice prefixed with uint8 size from reader supplied +func ReadUInt8Slice(r io.Reader) ([]byte, error) { + var size byte + err := binary.Read(r, binary.LittleEndian, &size) + if err != nil { + return nil, err + } + return io.ReadAll(io.LimitReader(r, int64(size))) +} + +func readAddress(r io.Reader) (*FMsgAddress, error) { + slice, err := ReadUInt8Slice(r) + if err != nil { + return nil, err + } + return parseAddress(slice) +} + +func handleChallenge(c net.Conn, r *bufio.Reader) error { + hashSlice, err := io.ReadAll(io.LimitReader(r, 32)) + if err != nil { + return err + } + hash := *(*[32]byte)(hashSlice) + log.Printf("INFO: CHALLENGE <-- %s", hex.EncodeToString(hashSlice)) + + // Verify the challenger's IP is the Host-B IP registered for this message + // (§10.5 step 2). An unrecognised hash OR a mismatched IP both → TERMINATE. + remoteIP, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + header, exists := lookupOutgoing(hash, remoteIP) + if !exists { + return fmt.Errorf("challenge for unknown message: %s, from: %s", hex.EncodeToString(hashSlice), c.RemoteAddr().String()) + } + msgHash, err := header.GetMessageHash() + if err != nil { + return err + } + if _, err := c.Write(msgHash); err != nil { + return err + } + return nil +} + +func rejectAccept(c net.Conn, codes []byte) error { + _, err := c.Write(codes) + return err +} + +func sendCode(c net.Conn, code uint8) error { + return rejectAccept(c, []byte{code}) +} + +func validateMessageFlags(c net.Conn, flags uint8) error { + if flags&messageReservedBitsMask != 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("reserved message flag bits set: %#08b", flags) + } + return nil +} + +func validateAttachmentFlags(c net.Conn, flags uint8) error { + if flags&attachmentReservedBitsMask != 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("reserved attachment flag bits set: %#08b", flags) + } + return nil +} + +func hasDomainRecipient(addrs []FMsgAddress, domain string) bool { + for _, addr := range addrs { + if strings.EqualFold(addr.Domain, domain) { + return true + } + } + return false +} + +func determineSenderDomain(h *FMsgHeader) string { + if len(h.AddTo) > 0 && h.AddToFrom != nil { + return h.AddToFrom.Domain + } + return h.From.Domain +} + +func verifySenderIP(c net.Conn, senderDomain string) error { + if SkipAuthorisedIPs { + return nil + } + + remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String()) + if err != nil { + log.Printf("WARN: failed to parse remote address for DNS check: %s", err) + return fmt.Errorf("DNS verification failed") + } + + remoteIP := net.ParseIP(remoteHost) + if remoteIP == nil { + log.Printf("WARN: failed to parse remote IP: %s", remoteHost) + return fmt.Errorf("DNS verification failed") + } + + authorisedIPs, err := lookupAuthorisedIPs(senderDomain) + if err != nil { + log.Printf("WARN: DNS lookup failed for fmsg.%s: %s", senderDomain, err) + return fmt.Errorf("DNS verification failed") + } + + for _, ip := range authorisedIPs { + if remoteIP.Equal(ip) { + return nil + } + } + + log.Printf("WARN: remote IP %s not in authorised IPs for fmsg.%s", remoteIP.String(), senderDomain) + return fmt.Errorf("DNS verification failed") +} + +func handleAddToPath(c net.Conn, h *FMsgHeader) (*FMsgHeader, error) { + if len(h.AddTo) == 0 { + return h, nil + } + + addToHasOurDomain := hasDomainRecipient(h.AddTo, Domain) + + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return h, err + } + + if parentID == 0 { + h.InitialResponseCode = AcceptCodeContinue + return h, nil + } + + parentMsg, err := getMsgByID(parentID) + if err != nil { + return h, err + } + if parentMsg == nil || !isMessageRetrievable(parentMsg) { + h.InitialResponseCode = AcceptCodeContinue + return h, nil + } + + if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { + if err := sendCode(c, RejectCodeTimeTravel); err != nil { + return h, err + } + return h, fmt.Errorf("add-to: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) + } + + if addToHasOurDomain { + h.InitialResponseCode = AcceptCodeSkipData + return h, nil + } + + h.Filepath = parentMsg.Filepath + for i := range h.Attachments { + if i < len(parentMsg.Attachments) { + h.Attachments[i].Filepath = parentMsg.Attachments[i].Filepath + } + } + h.InitialResponseCode = AcceptCodeAddTo + return h, nil +} + +func validatePidReplyPath(c net.Conn, h *FMsgHeader) error { + if len(h.AddTo) != 0 || h.Flags&FlagHasPid == 0 { + return nil + } + + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return err + } + if parentID == 0 { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent not found for pid %s", hex.EncodeToString(h.Pid)) + } + + parentMsg, err := getMsgByID(parentID) + if err != nil { + return err + } + if parentMsg == nil { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent message not found by ID %d", parentID) + } + if !isMessageRetrievable(parentMsg) { + if err := sendCode(c, RejectCodeParentNotFound); err != nil { + return err + } + return fmt.Errorf("pid reply: parent is not retrievable for msg %d", parentID) + } + + if parentMsg.Timestamp-FutureTimeDelta > h.Timestamp { + if err := sendCode(c, RejectCodeTimeTravel); err != nil { + return err + } + return fmt.Errorf("pid reply: time travel detected (parent time %f, current %f)", parentMsg.Timestamp, h.Timestamp) + } + if !isParentParticipant(parentMsg, &h.From) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("pid reply: sender %s was not a participant of parent", h.From.ToString()) + } + + return nil +} + +func readVersionOrChallenge(c net.Conn, r *bufio.Reader, h *FMsgHeader) (bool, error) { + v, err := r.ReadByte() + if err != nil { + return false, err + } + if v >= 129 { + challengeVersion := 256 - int(v) + if challengeVersion == 1 { + return true, handleChallenge(c, r) + } + if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { + log.Printf("WARN: failed to send unsupported version response: %s", err) + } + return false, fmt.Errorf("unsupported challenge version: %d", challengeVersion) + } + if v != 1 { + if err := sendCode(c, RejectCodeUnsupportedVersion); err != nil { + log.Printf("WARN: failed to send unsupported version response: %s", err) + } + return false, fmt.Errorf("unsupported message version: %d", v) + } + h.Version = v + return false, nil +} + +func readToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader) (map[string]bool, error) { + num, err := r.ReadByte() + if err != nil { + return nil, err + } + if num == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return nil, err + } + return nil, fmt.Errorf("to count must be >= 1") + } + seen := make(map[string]bool) + for num > 0 { + addr, err := readAddress(r) + if err != nil { + return nil, err + } + key := strings.ToLower(addr.ToString()) + if seen[key] { + return nil, fmt.Errorf("duplicate recipient address: %s", addr.ToString()) + } + seen[key] = true + h.To = append(h.To, *addr) + num-- + } + return seen, nil +} + +func readAddToRecipients(c net.Conn, r *bufio.Reader, h *FMsgHeader, seen map[string]bool) error { + if h.Flags&FlagHasAddTo == 0 { + return nil + } + if h.Flags&FlagHasPid == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add to exists but pid does not") + } + + addToFrom, err := readAddress(r) + if err != nil { + if err2 := sendCode(c, RejectCodeInvalid); err2 != nil { + return err2 + } + return fmt.Errorf("reading add-to-from address: %w", err) + } + + addToFromKey := strings.ToLower(addToFrom.ToString()) + fromKey := strings.ToLower(h.From.ToString()) + inFromOrTo := fromKey == addToFromKey + if !inFromOrTo { + for _, toAddr := range h.To { + if strings.ToLower(toAddr.ToString()) == addToFromKey { + inFromOrTo = true + break + } + } + } + if !inFromOrTo { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add-to-from (%s) not in from or to", addToFrom.ToString()) + } + h.AddToFrom = addToFrom + + addToCount, err := r.ReadByte() + if err != nil { + return err + } + if addToCount == 0 { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add to flag set but count is 0") + } + + addToSeen := make(map[string]bool) + for addToCount > 0 { + addr, err := readAddress(r) + if err != nil { + return err + } + key := strings.ToLower(addr.ToString()) + if addToSeen[key] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("duplicate recipient address in add to: %s", addr.ToString()) + } + addToSeen[key] = true + if seen[key] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("add-to address already in to: %s", addr.ToString()) + } + h.AddTo = append(h.AddTo, *addr) + addToCount-- + } + + return nil +} + +func readAndValidateTimestamp(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + if err := binary.Read(r, binary.LittleEndian, &h.Timestamp); err != nil { + return err + } + now := timeutil.TimestampNow().Float64() + delta := now - h.Timestamp + if PastTimeDelta > 0 && delta > PastTimeDelta { + if err := sendCode(c, RejectCodePastTime); err != nil { + return err + } + return fmt.Errorf("message timestamp: %f too far in past, delta: %fs", h.Timestamp, delta) + } + if FutureTimeDelta > 0 && delta < 0 && math.Abs(delta) > FutureTimeDelta { + if err := sendCode(c, RejectCodeFutureTime); err != nil { + return err + } + return fmt.Errorf("message timestamp: %f too far in future, delta: %fs", h.Timestamp, delta) + } + return nil +} + +func readMessageType(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + if h.Flags&FlagCommonType != 0 { + typeID, err := r.ReadByte() + if err != nil { + return err + } + mtype, ok := fmsg.GetCommonMediaType(typeID) + if !ok { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("unmapped common type ID: %d", typeID) + } + h.TypeID = typeID + h.Type = mtype + return nil + } + + mime, err := ReadUInt8Slice(r) + if err != nil { + return err + } + if !isASCIIBytes(mime) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("message media type must be US-ASCII") + } + h.Type = string(mime) + return nil +} + +func readAttachmentType(c net.Conn, r *bufio.Reader, flags uint8) (string, uint8, error) { + if flags&(1<<0) != 0 { + typeID, err := r.ReadByte() + if err != nil { + return "", 0, err + } + mtype, ok := fmsg.GetCommonMediaType(typeID) + if !ok { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return "", 0, err + } + return "", 0, fmt.Errorf("unmapped attachment common type ID: %d", typeID) + } + return mtype, typeID, nil + } + + typeBytes, err := ReadUInt8Slice(r) + if err != nil { + return "", 0, err + } + if !isASCIIBytes(typeBytes) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return "", 0, err + } + return "", 0, fmt.Errorf("attachment media type must be US-ASCII") + } + return string(typeBytes), 0, nil +} + +func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { + var attachCount uint8 + if err := binary.Read(r, binary.LittleEndian, &attachCount); err != nil { + return err + } + + totalSize := h.Size + // When message is compressed, expanded size comes from the header field. + // When uncompressed, the wire size IS the expanded size. + var totalExpandedSize uint32 + if h.Flags&FlagDeflate != 0 { + totalExpandedSize = h.ExpandedSize + } else { + totalExpandedSize = h.Size + } + filenameSeen := make(map[string]bool) + for i := uint8(0); i < attachCount; i++ { + attFlags, err := r.ReadByte() + if err != nil { + return err + } + if err := validateAttachmentFlags(c, attFlags); err != nil { + return err + } + + attType, attTypeID, err := readAttachmentType(c, r, attFlags) + if err != nil { + return err + } + + filenameBytes, err := ReadUInt8Slice(r) + if err != nil { + return err + } + filename := string(filenameBytes) + if !isValidAttachmentFilename(filename) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("invalid attachment filename: %s", filename) + } + filenameKey := strings.ToLower(filename) + if filenameSeen[filenameKey] { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return err + } + return fmt.Errorf("duplicate attachment filename: %s", filename) + } + filenameSeen[filenameKey] = true + + var attSize uint32 + if err := binary.Read(r, binary.LittleEndian, &attSize); err != nil { + return err + } + + // read attachment expanded size — present iff attachment zlib-deflate flag set (§5) + var attExpandedSize uint32 + if attFlags&(1<<1) != 0 { + if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { + return err + } + totalExpandedSize += attExpandedSize + } else { + // uncompressed: expanded size equals wire size + totalExpandedSize += attSize + } + + h.Attachments = append(h.Attachments, FMsgAttachmentHeader{ + Flags: attFlags, + TypeID: attTypeID, + Type: attType, + Filename: filename, + Size: attSize, + ExpandedSize: attExpandedSize, + }) + totalSize += attSize + } + + if totalSize > MaxMessageSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return err + } + return fmt.Errorf("total message size %d exceeds max %d", totalSize, MaxMessageSize) + } + + if totalExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return err + } + return fmt.Errorf("total expanded size %d exceeds MAX_EXPANDED_SIZE %d", totalExpandedSize, MaxExpandedSize) + } + + return nil +} + +func readHeader(c net.Conn) (*FMsgHeader, *bufio.Reader, error) { + r := bufio.NewReaderSize(c, ReadBufferSize) + var h = &FMsgHeader{InitialResponseCode: AcceptCodeContinue} + + d := calcNetIODuration(66000, MinDownloadRate) // max possible header size + c.SetReadDeadline(time.Now().Add(d)) + + handled, err := readVersionOrChallenge(c, r, h) + if err != nil { + if handled { + return nil, r, err + } + return h, r, err + } + if handled { + return nil, r, nil + } + + // read flags + flags, err := r.ReadByte() + if err != nil { + return h, r, err + } + h.Flags = flags + if err := validateMessageFlags(c, flags); err != nil { + return h, r, err + } + + // read pid if any + if flags&FlagHasPid == 1 { + pid, err := io.ReadAll(io.LimitReader(r, 32)) + if err != nil { + return h, r, err + } + h.Pid = make([]byte, 32) + copy(h.Pid, pid) + } + + // read from address + from, err := readAddress(r) + if err != nil { + return h, r, err + } + + h.From = *from + + seen, err := readToRecipients(c, r, h) + if err != nil { + return h, r, err + } + + if err := readAddToRecipients(c, r, h, seen); err != nil { + return h, r, err + } + + if err := readAndValidateTimestamp(c, r, h); err != nil { + return h, r, err + } + + // read topic — only present when pid is NOT set (first message in a thread) + if flags&FlagHasPid == 0 { + topic, err := ReadUInt8Slice(r) + if err != nil { + return h, r, err + } + h.Topic = string(topic) + } + + if err := readMessageType(c, r, h); err != nil { + return h, r, err + } + + // read message size + if err := binary.Read(r, binary.LittleEndian, &h.Size); err != nil { + return h, r, err + } + // read expanded size — present iff zlib-deflate flag is set (§2 field 12) + if h.Flags&FlagDeflate != 0 { + if err := binary.Read(r, binary.LittleEndian, &h.ExpandedSize); err != nil { + return h, r, err + } + if h.ExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return h, r, err + } + return h, r, fmt.Errorf("expanded size %d exceeds MAX_EXPANDED_SIZE %d", h.ExpandedSize, MaxExpandedSize) + } + } + // Size check is deferred until attachment headers are parsed (see below) + + if err := readAttachmentHeaders(c, r, h); err != nil { + return h, r, err + } + + log.Printf("INFO: <-- MSG\n%s", h) + + if !hasDomainRecipient(h.To, Domain) && !hasDomainRecipient(h.AddTo, Domain) { + if err := sendCode(c, RejectCodeInvalid); err != nil { + return h, r, err + } + return h, r, fmt.Errorf("no recipients for domain %s", Domain) + } + + if err := verifySenderIP(c, determineSenderDomain(h)); err != nil { + return nil, r, err + } + + h, err = handleAddToPath(c, h) + if err != nil { + return h, r, err + } + if h == nil { + return nil, r, nil + } + + if err := validatePidReplyPath(c, h); err != nil { + return h, r, err + } + + return h, r, nil +} + +// Sends CHALLENGE request to sender, receiving and storing the challenge hash. +// DNS verification of the remote IP is performed during header exchange (readHeader). +// TODO [Spec step 2]: The spec defines challenge modes (NEVER, ALWAYS, +// HAS_NOT_PARTICIPATED, DIFFERENT_DOMAIN) as implementation choices. +// Currently defaults to ALWAYS. Implement configurable challenge mode. +func challenge(conn net.Conn, h *FMsgHeader, senderDomain string) error { + + // Connection 2 MUST target the same IP as Connection 1 (spec 2.1). + remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return fmt.Errorf("failed to parse remote address for challenge: %w", err) + } + conn2, err := tls.Dial("tcp", net.JoinHostPort(remoteHost, fmt.Sprintf("%d", RemotePort)), buildClientTLSConfig("fmsg."+senderDomain)) + if err != nil { + return err + } + version := uint8(255) + if err := binary.Write(conn2, binary.LittleEndian, version); err != nil { + return err + } + hash := h.GetHeaderHash() + log.Printf("INFO: --> CHALLENGE\t%s\n", hex.EncodeToString(hash)) + if _, err := conn2.Write(hash); err != nil { + return err + } + + // read challenge response + resp, err := io.ReadAll(io.LimitReader(conn2, 32)) + if err != nil { + return err + } + if len(resp) != 32 { + return fmt.Errorf("challenge response size %d, expected 32", len(resp)) + } + copy(h.ChallengeHash[:], resp) + h.ChallengeCompleted = true + log.Printf("INFO: <-- CHALLENGE RESP\t%s\n", hex.EncodeToString(resp)) + + // gracefully close 2nd connection + if err := conn2.Close(); err != nil { + return err + } + + return nil +} + +func validateMsgRecvForAddr(h *FMsgHeader, addr *FMsgAddress, msgHash []byte) (code uint8, err error) { + duplicate, err := hasAddrReceivedMsgHash(msgHash, addr) + if err != nil { + return RejectCodeUserUndisclosed, err + } + if duplicate { + return RejectCodeUserDuplicate, nil + } + + detail, err := getAddressDetail(addr) + if err != nil { + return RejectCodeUserUndisclosed, err + } + if detail == nil { + return RejectCodeUserUnknown, nil + } + + // check user accepting new + if !detail.AcceptingNew { + return RejectCodeUserNotAccepting, nil + } + + // check user limits + if detail.LimitRecvCountPer1d > -1 && detail.RecvCountPer1d+1 > detail.LimitRecvCountPer1d { + log.Printf("WARN: Message rejected: RecvCountPer1d would exceed LimitRecvCountPer1d %d", detail.LimitRecvCountPer1d) + return RejectCodeUserFull, nil + } + if detail.LimitRecvSizePer1d > -1 && detail.RecvSizePer1d+int64(h.Size) > detail.LimitRecvSizePer1d { + log.Printf("WARN: Message rejected: RecvSizePer1d would exceed LimitRecvSizePer1d %d", detail.LimitRecvSizePer1d) + return RejectCodeUserFull, nil + } + if detail.LimitRecvSizeTotal > -1 && detail.RecvSizeTotal+int64(h.Size) > detail.LimitRecvSizeTotal { + log.Printf("WARN: Message rejected: RecvSizeTotal would exceed LimitRecvSizeTotal %d", detail.LimitRecvSizeTotal) + return RejectCodeUserFull, nil + } + + return RejectCodeAccept, nil +} + +// uniqueFilepath generates a unique file path in the given directory, +// appending a counter suffix if the base name already exists. +func uniqueFilepath(dir string, timestamp uint32, ext string) string { + base := fmt.Sprintf("%d", timestamp) + fp := filepath.Join(dir, base+ext) + if _, err := os.Stat(fp); os.IsNotExist(err) { + return fp + } + for i := 1; ; i++ { + fp = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, i, ext)) + if _, err := os.Stat(fp); os.IsNotExist(err) { + return fp + } + } +} + +func localRecipients(h *FMsgHeader) []FMsgAddress { + addrs := make([]FMsgAddress, 0, len(h.To)+len(h.AddTo)) + for _, addr := range h.To { + if strings.EqualFold(addr.Domain, Domain) { + addrs = append(addrs, addr) + } + } + for _, addr := range h.AddTo { + if strings.EqualFold(addr.Domain, Domain) { + addrs = append(addrs, addr) + } + } + return addrs +} + +func allLocalRecipientsHaveMessageHash(msgHash []byte, addrs []FMsgAddress) (bool, error) { + if len(addrs) == 0 { + return false, nil + } + for i := range addrs { + duplicate, err := hasAddrReceivedMsgHash(msgHash, &addrs[i]) + if err != nil { + return false, err + } + if !duplicate { + return false, nil + } + } + return true, nil +} + +func markAllCodes(codes []byte, code uint8) { + for i := range codes { + codes[i] = code + } +} + +func prepareMessageData(r io.Reader, h *FMsgHeader, skipData bool) ([]string, error) { + if skipData { + parentID, err := lookupMsgIdByHash(h.Pid) + if err != nil { + return nil, err + } + if parentID == 0 { + return nil, fmt.Errorf("%w code 65 requires stored parent for pid %s", ErrProtocolViolation, hex.EncodeToString(h.Pid)) + } + parentMsg, err := getMsgByID(parentID) + if err != nil { + return nil, err + } + if parentMsg == nil || parentMsg.Filepath == "" { + return nil, fmt.Errorf("%w code 65 parent data unavailable for msg %d", ErrProtocolViolation, parentID) + } + h.Filepath = parentMsg.Filepath + return nil, nil + } + + createdPaths := make([]string, 0, 1+len(h.Attachments)) + + fd, err := os.CreateTemp("", "fmsg-download-*") + if err != nil { + return nil, err + } + + if _, err := io.CopyN(fd, r, int64(h.Size)); err != nil { + fd.Close() + _ = os.Remove(fd.Name()) + return nil, err + } + if err := fd.Close(); err != nil { + _ = os.Remove(fd.Name()) + return nil, err + } + + h.Filepath = fd.Name() + createdPaths = append(createdPaths, fd.Name()) + + for i := range h.Attachments { + afd, err := os.CreateTemp("", "fmsg-attachment-*") + if err != nil { + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + + if _, err := io.CopyN(afd, r, int64(h.Attachments[i].Size)); err != nil { + afd.Close() + _ = os.Remove(afd.Name()) + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + if err := afd.Close(); err != nil { + _ = os.Remove(afd.Name()) + for _, path := range createdPaths { + _ = os.Remove(path) + } + return nil, err + } + h.Attachments[i].Filepath = afd.Name() + createdPaths = append(createdPaths, afd.Name()) + } + + return createdPaths, nil +} + +func cleanupFiles(paths []string) { + for _, path := range paths { + if path == "" { + continue + } + _ = os.Remove(path) + } +} + +func copyMessagePayload(src *os.File, dstPath string, compressed bool, wireSize uint32) error { + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + + fd2, err := os.Create(dstPath) + if err != nil { + return err + } + + var copyErr error + if compressed { + lr := io.LimitReader(src, int64(wireSize)) + zr, err := zlib.NewReader(lr) + if err != nil { + fd2.Close() + _ = os.Remove(dstPath) + return err + } + _, copyErr = io.Copy(fd2, zr) + _ = zr.Close() + } else { + _, copyErr = io.CopyN(fd2, src, int64(wireSize)) + } + if err := fd2.Close(); err != nil { + return err + } + + if copyErr != nil { + _ = os.Remove(dstPath) + return copyErr + } + return nil +} + +func uniqueAttachmentPath(dir string, timestamp uint32, idx int, filename string) string { + ext := filepath.Ext(filename) + base := fmt.Sprintf("%d_att_%d", timestamp, idx) + p := filepath.Join(dir, base+ext) + if _, err := os.Stat(p); os.IsNotExist(err) { + return p + } + for n := 1; ; n++ { + p = filepath.Join(dir, fmt.Sprintf("%s_%d%s", base, n, ext)) + if _, err := os.Stat(p); os.IsNotExist(err) { + return p + } + } +} + +func persistAttachmentPayloads(h *FMsgHeader, dirpath string) error { + for i := range h.Attachments { + a := &h.Attachments[i] + src, err := os.Open(a.Filepath) + if err != nil { + return err + } + dstPath := uniqueAttachmentPath(dirpath, uint32(h.Timestamp), i, a.Filename) + compressed := a.Flags&(1<<1) != 0 + err = copyMessagePayload(src, dstPath, compressed, a.Size) + src.Close() + if err != nil { + return err + } + a.Filepath = dstPath + } + return nil +} + +func storeAcceptedMessage(h *FMsgHeader, codes []byte, acceptedTo []FMsgAddress, acceptedAddTo []FMsgAddress, primaryFilepath string) bool { + if len(acceptedTo) == 0 && len(acceptedAddTo) == 0 { + return false + } + + origTo := h.To + origAddTo := h.AddTo + h.To = acceptedTo + h.AddTo = acceptedAddTo + h.Filepath = primaryFilepath + if err := storeMsgDetail(h); err != nil { + log.Printf("ERROR: storing message: %s", err) + h.To = origTo + h.AddTo = origAddTo + for i := range codes { + if codes[i] == RejectCodeAccept { + codes[i] = RejectCodeUndisclosed + } + } + return false + } + + h.To = origTo + h.AddTo = origAddTo + allAccepted := append(acceptedTo, acceptedAddTo...) + for i := range allAccepted { + if err := postMsgStatRecv(&allAccepted[i], h.Timestamp, int(h.Size)); err != nil { + log.Printf("WARN: Failed to post msg recv stat: %s", err) + } + } + return true +} + +func downloadMessage(c net.Conn, r io.Reader, h *FMsgHeader, skipData bool) error { + addrs := localRecipients(h) + if len(addrs) == 0 { + return fmt.Errorf("%w our domain: %s, not in recipient list", ErrProtocolViolation, Domain) + } + codes := make([]byte, len(addrs)) + + createdPaths, err := prepareMessageData(r, h, skipData) + if err != nil { + return err + } + cleanupOnReturn := !skipData + defer func() { + if cleanupOnReturn { + cleanupFiles(createdPaths) + } + }() + + // verify hash matches challenge response when challenge was completed + msgHash, err := h.GetMessageHash() + if err != nil { + return err + } + if h.ChallengeCompleted && !bytes.Equal(h.ChallengeHash[:], msgHash) { + challengeHashStr := hex.EncodeToString(h.ChallengeHash[:]) + actualHashStr := hex.EncodeToString(msgHash) + return fmt.Errorf("%w actual hash: %s mismatch challenge response: %s", ErrProtocolViolation, actualHashStr, challengeHashStr) + } + + // pid/add-to validation is handled during header exchange in readHeader(). + + // determine file extension from MIME type + exts, _ := mime.ExtensionsByType(h.Type) + var ext string + if exts == nil { + ext = ".unknown" + } else { + ext = exts[0] + } + + src, err := os.Open(h.Filepath) + if err != nil { + return err + } + defer src.Close() + + // validate each recipient and copy message for accepted ones + // Build a set of add-to addresses for later classification + addToSet := make(map[string]bool) + for _, addr := range h.AddTo { + addToSet[strings.ToLower(addr.ToString())] = true + } + acceptedTo := []FMsgAddress{} + acceptedAddTo := []FMsgAddress{} + var primaryFilepath string + for i, addr := range addrs { + code, err := validateMsgRecvForAddr(h, &addr, msgHash) + if err != nil { + return err + } + if code != RejectCodeAccept { + log.Printf("WARN: Rejected message to: %s: %s (%d)", addr.ToString(), responseCodeName(code), code) + codes[i] = code + continue + } + + // copy to recipient's directory + dirpath := filepath.Join(DataDir, addr.Domain, addr.User, InboxDirName) + if err := os.MkdirAll(dirpath, 0750); err != nil { + return err + } + + fp := uniqueFilepath(dirpath, uint32(h.Timestamp), ext) + if err := copyMessagePayload(src, fp, h.Flags&FlagDeflate != 0, h.Size); err != nil { + log.Printf("ERROR: copying downloaded message from: %s, to: %s", h.Filepath, fp) + codes[i] = RejectCodeUndisclosed + continue + } + + codes[i] = RejectCodeAccept + if addToSet[strings.ToLower(addr.ToString())] { + acceptedAddTo = append(acceptedAddTo, addr) + } else { + acceptedTo = append(acceptedTo, addr) + } + if primaryFilepath == "" { + primaryFilepath = fp + if err := persistAttachmentPayloads(h, filepath.Dir(primaryFilepath)); err != nil { + log.Printf("ERROR: copying attachment payloads for message storage: %s", err) + codes[i] = RejectCodeUndisclosed + primaryFilepath = "" + acceptedTo = acceptedTo[:0] + acceptedAddTo = acceptedAddTo[:0] + continue + } + } + } + + stored := storeAcceptedMessage(h, codes, acceptedTo, acceptedAddTo, primaryFilepath) + if stored { + cleanupOnReturn = false + } + + return rejectAccept(c, codes) +} + +// resolvePostChallengeCode determines the initial response code to send after +// the optional challenge (§10.4). Code 11 (accept add-to) is returned as-is +// since it has no local recipients to duplicate-check. For the skip-data (65) +// and continue (64) paths, a completed challenge with all-local-duplicate +// produces code 10 (duplicate) instead. +func resolvePostChallengeCode(initialCode uint8, challengeCompleted bool, allLocalDup bool) uint8 { + if initialCode == AcceptCodeAddTo { + return AcceptCodeAddTo + } + if challengeCompleted && allLocalDup { + return RejectCodeDuplicate + } + if initialCode == AcceptCodeSkipData { + return AcceptCodeSkipData + } + return AcceptCodeContinue +} + +func abortConn(c net.Conn) { + if tcp, ok := c.(*net.TCPConn); ok { + tcp.SetLinger(0) + } + _ = c.Close() +} + +type responseTrackingConn struct { + net.Conn + wroteResponse bool +} + +func (c *responseTrackingConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.wroteResponse = true + } + return n, err +} + +func handleConn(c net.Conn) { + defer func() { + if r := recover(); r != nil { + log.Printf("ERROR: Recovered in handleConn: %v", r) + } + }() + + log.Printf("INFO: Connection from: %s\n", c.RemoteAddr().String()) + tc := &responseTrackingConn{Conn: c} + + // read header + header, r, err := readHeader(tc) + if err != nil { + log.Printf("WARN: reading header from, %s: %s", c.RemoteAddr().String(), err) + if tc.wroteResponse { + _ = c.Close() + return + } + abortConn(c) + return + } + + // if no header AND no error this was a challenge thats been handeled + if header == nil { + c.Close() + return + } + + if err := challenge(tc, header, determineSenderDomain(header)); err != nil { + log.Printf("ERROR: Challenge failed to, %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + + // §10.4: Determine initial response code after optional challenge. + // Code 11 (add-to, no local recipients) does not need a dup check. + // Codes 65 and 64 both require a dup check when challenge was completed. + allLocalDup := false + if header.ChallengeCompleted && header.InitialResponseCode != AcceptCodeAddTo { + addrs := localRecipients(header) + var err error + allLocalDup, err = allLocalRecipientsHaveMessageHash(header.ChallengeHash[:], addrs) + if err != nil { + log.Printf("ERROR: duplicate check failed for %s: %s", c.RemoteAddr().String(), err) + if err := sendCode(c, RejectCodeUndisclosed); err != nil { + abortConn(c) + return + } + _ = c.Close() + return + } + } + + code := resolvePostChallengeCode(header.InitialResponseCode, header.ChallengeCompleted, allLocalDup) + skipData := false + + switch code { + case AcceptCodeAddTo: + // No local add-to recipients; store header and respond code 11, close. + if err := storeMsgHeaderOnly(header); err != nil { + log.Printf("ERROR: storing add-to header: %s", err) + if err := sendCode(c, RejectCodeUndisclosed); err != nil { + abortConn(c) + return + } + _ = c.Close() + return + } + if err := sendCode(c, AcceptCodeAddTo); err != nil { + log.Printf("ERROR: failed sending code 11 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + log.Printf("INFO: additional recipients received (code 11) for pid %s", hex.EncodeToString(header.Pid)) + c.Close() + return + case RejectCodeDuplicate: + if err := sendCode(c, RejectCodeDuplicate); err != nil { + log.Printf("ERROR: failed sending code 10 to %s: %s", c.RemoteAddr().String(), err) + } + c.Close() + return + case AcceptCodeSkipData: + if err := sendCode(c, AcceptCodeSkipData); err != nil { + log.Printf("ERROR: failed sending code 65 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + skipData = true + log.Printf("INFO: sent code 65 (skip data) to %s", c.RemoteAddr().String()) + default: + if err := sendCode(c, AcceptCodeContinue); err != nil { + log.Printf("ERROR: failed sending code 64 to %s: %s", c.RemoteAddr().String(), err) + abortConn(c) + return + } + log.Printf("INFO: sent code 64 (continue) to %s", c.RemoteAddr().String()) + } + + // store message + deadlineBytes := int(header.Size) + if skipData { + deadlineBytes = 1 + } + c.SetReadDeadline(time.Now().Add(calcNetIODuration(deadlineBytes, MinDownloadRate))) + if err := downloadMessage(c, r, header, skipData); err != nil { + // if error was a protocal violation, abort; otherise let sender know there was an internal error + log.Printf("ERROR: Download failed from, %s: %s", c.RemoteAddr().String(), err) + if errors.Is(err, ErrProtocolViolation) { + abortConn(c) + return + } else { + _ = sendCode(c, RejectCodeUndisclosed) + } + } + + // gracefully close 1st connection + c.Close() +} + +func main() { + + initOutgoing() + + // load environment variables from .env file if present + if err := godotenv.Load(); err != nil { + log.Printf("INFO: Could not load .env file: %v", err) + } + + // read env config (must be after godotenv.Load) + loadEnvConfig() + + // determine listen address from args + listenAddress := "127.0.0.1" + for _, arg := range os.Args[1:] { + listenAddress = arg + } + + // initalize database + err := testDb() + if err != nil { + log.Fatalf("ERROR: connecting to database: %s\n", err) + } + + // set DataDir, Domain and IDURL from env + setDataDir() + setDomain() + setIDURL() + + // load TLS configuration (must be after loadEnvConfig for FMSG_TLS_INSECURE_SKIP_VERIFY) + serverTLSConfig = buildServerTLSConfig() + + // start sender in background (small delay so listener is ready first) + go func() { + time.Sleep(1 * time.Second) + startSender() + }() + + // start listening + addr := fmt.Sprintf("%s:%d", listenAddress, Port) + ln, err := tls.Listen("tcp", addr, serverTLSConfig) + if err != nil { + log.Fatal(err) + } + log.Printf("INFO: Ready to receive on %s\n", addr) + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("ERROR: Accept connection from %s returned: %s\n", ln.Addr().String(), err) + } else { + go handleConn(conn) + } + } + +} diff --git a/src/host_test.go b/cmd/fmsgd/host_test.go similarity index 100% rename from src/host_test.go rename to cmd/fmsgd/host_test.go diff --git a/src/id.go b/cmd/fmsgd/id.go similarity index 96% rename from src/id.go rename to cmd/fmsgd/id.go index 32b0bab..49c22e2 100644 --- a/src/id.go +++ b/cmd/fmsgd/id.go @@ -1,92 +1,92 @@ -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/url" -) - -type AddressDetail struct { - Address string `json:"address"` - DisplayName string `json:"displayName"` - AcceptingNew bool `json:"acceptingNew"` - LimitRecvSizeTotal int64 `json:"limitRecvSizeTotal"` - LimitRecvSizePerMsg int64 `json:"limitRecvSizePerMsg"` - LimitRecvSizePer1d int64 `json:"limitRecvSizePer1d"` - LimitRecvCountPer1d int64 `json:"limitRecvCountPer1d"` - LimitSendSizeTotal int64 `json:"limitSendSizeTotal"` - LimitSendSizePerMsg int64 `json:"limitSendSizePerMsg"` - LimitSendSizePer1d int64 `json:"limitSendSizePer1d"` - LimitSendCountPer1d int64 `json:"limitSendCountPer1d"` - RecvSizeTotal int64 `json:"recvSizeTotal"` - RecvSizePer1d int64 `json:"recvSizePer1d"` - RecvCountPer1d int64 `json:"recvCountPer1d"` - SendSizeTotal int64 `json:"sendSizeTotal"` - SendSizePer1d int64 `json:"sendSizePer1d"` - SendCountPer1d int64 `json:"sendCountPer1d"` - Tags []string `json:"tags"` -} - -// Returns pointer to an AddressDetail populated by querying fmsg Id standard at FMSG_ID_URL for -// address supplied. If the address is not found returns nil, nil. -func getAddressDetail(addr *FMsgAddress) (*AddressDetail, error) { - uri := IDURI + "/fmsgid/" + url.PathEscape(addr.ToString()) - resp, err := http.Get(uri) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return nil, nil - } - - var detail AddressDetail - err = json.NewDecoder(resp.Body).Decode(&detail) - if err != nil { - return nil, err - } - - return &detail, nil -} - -func postMsgStatSend(addr *FMsgAddress, timestamp float64, size int) error { - return postMsgStat(addr, timestamp, size, true) -} - -func postMsgStatRecv(addr *FMsgAddress, timestamp float64, size int) error { - return postMsgStat(addr, timestamp, size, false) -} - -func postMsgStat(addr *FMsgAddress, timestamp float64, size int, isSending bool) error { - var part string - if isSending { - part = "send" - } else { - part = "recv" - } - uri := fmt.Sprintf("%s/fmsgid/%s", IDURI, part) - - payload := map[string]interface{}{ - "address": addr.ToString(), - "ts": timestamp, - "size": size} - jsonPayload, err := json.Marshal(payload) - if err != nil { - return err - } - - resp, err := http.Post(uri, "application/json", bytes.NewBuffer(jsonPayload)) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("POST %s returned %d", uri, resp.StatusCode) - } - - return nil -} +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +type AddressDetail struct { + Address string `json:"address"` + DisplayName string `json:"displayName"` + AcceptingNew bool `json:"acceptingNew"` + LimitRecvSizeTotal int64 `json:"limitRecvSizeTotal"` + LimitRecvSizePerMsg int64 `json:"limitRecvSizePerMsg"` + LimitRecvSizePer1d int64 `json:"limitRecvSizePer1d"` + LimitRecvCountPer1d int64 `json:"limitRecvCountPer1d"` + LimitSendSizeTotal int64 `json:"limitSendSizeTotal"` + LimitSendSizePerMsg int64 `json:"limitSendSizePerMsg"` + LimitSendSizePer1d int64 `json:"limitSendSizePer1d"` + LimitSendCountPer1d int64 `json:"limitSendCountPer1d"` + RecvSizeTotal int64 `json:"recvSizeTotal"` + RecvSizePer1d int64 `json:"recvSizePer1d"` + RecvCountPer1d int64 `json:"recvCountPer1d"` + SendSizeTotal int64 `json:"sendSizeTotal"` + SendSizePer1d int64 `json:"sendSizePer1d"` + SendCountPer1d int64 `json:"sendCountPer1d"` + Tags []string `json:"tags"` +} + +// Returns pointer to an AddressDetail populated by querying fmsg Id standard at FMSG_ID_URL for +// address supplied. If the address is not found returns nil, nil. +func getAddressDetail(addr *FMsgAddress) (*AddressDetail, error) { + uri := IDURI + "/fmsgid/" + url.PathEscape(addr.ToString()) + resp, err := http.Get(uri) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } + + var detail AddressDetail + err = json.NewDecoder(resp.Body).Decode(&detail) + if err != nil { + return nil, err + } + + return &detail, nil +} + +func postMsgStatSend(addr *FMsgAddress, timestamp float64, size int) error { + return postMsgStat(addr, timestamp, size, true) +} + +func postMsgStatRecv(addr *FMsgAddress, timestamp float64, size int) error { + return postMsgStat(addr, timestamp, size, false) +} + +func postMsgStat(addr *FMsgAddress, timestamp float64, size int, isSending bool) error { + var part string + if isSending { + part = "send" + } else { + part = "recv" + } + uri := fmt.Sprintf("%s/fmsgid/%s", IDURI, part) + + payload := map[string]interface{}{ + "address": addr.ToString(), + "ts": timestamp, + "size": size} + jsonPayload, err := json.Marshal(payload) + if err != nil { + return err + } + + resp, err := http.Post(uri, "application/json", bytes.NewBuffer(jsonPayload)) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("POST %s returned %d", uri, resp.StatusCode) + } + + return nil +} diff --git a/src/outgoing.go b/cmd/fmsgd/outgoing.go similarity index 100% rename from src/outgoing.go rename to cmd/fmsgd/outgoing.go diff --git a/src/outgoing_test.go b/cmd/fmsgd/outgoing_test.go similarity index 100% rename from src/outgoing_test.go rename to cmd/fmsgd/outgoing_test.go diff --git a/src/sender.go b/cmd/fmsgd/sender.go similarity index 100% rename from src/sender.go rename to cmd/fmsgd/sender.go diff --git a/src/store.go b/cmd/fmsgd/store.go similarity index 100% rename from src/store.go rename to cmd/fmsgd/store.go diff --git a/src/store_test.go b/cmd/fmsgd/store_test.go similarity index 100% rename from src/store_test.go rename to cmd/fmsgd/store_test.go diff --git a/src/go.mod b/go.mod similarity index 100% rename from src/go.mod rename to go.mod diff --git a/src/go.sum b/go.sum similarity index 100% rename from src/go.sum rename to go.sum diff --git a/pkg/fmsg/README.md b/pkg/fmsg/README.md new file mode 100644 index 0000000..3294d47 --- /dev/null +++ b/pkg/fmsg/README.md @@ -0,0 +1,57 @@ +# fmsg + +Go package implementing the [fmsg protocol](../../SPEC.md) message types, wire-format encoding, and SHA-256 hashing. + +``` +import "github.com/markmnl/fmsgd/pkg/fmsg" +``` + +## Types + +| Type | Description | +|---|---| +| `fmsg.Header` | All fields of an fmsg message header | +| `fmsg.Address` | fmsg address (`@user@domain`) | +| `fmsg.AttachmentHeader` | Wire-level metadata for a single attachment | + +## Flag constants + +```go +fmsg.FlagHasPid // bit 0: message is a reply (pid field present) +fmsg.FlagHasAddTo // bit 1: add-to addresses present +fmsg.FlagCommonType // bit 2: type encoded as common media type ID +fmsg.FlagImportant // bit 3: sender marks as important +fmsg.FlagNoReply // bit 4: sender will discard replies +fmsg.FlagDeflate // bit 5: body is zlib-deflate compressed +``` + +## Usage + +### Encode a header to wire format + +```go +h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: float64(time.Now().UnixMicro()) / 1e6, + Topic: "hello", + Type: "text/plain", + Size: uint32(len(body)), +} +wire := h.Encode() +``` + +### Hash a message (for verification or storage) + +```go +// Set h.Filepath (and each attachment's Filepath) before calling. +hash, err := h.GetMessageHash() +``` + +### Look up a common media type + +```go +mtype, ok := fmsg.GetCommonMediaType(id) // ID → "text/plain" +id, ok := fmsg.GetCommonMediaTypeID(mt) // "text/plain" → ID +``` diff --git a/pkg/fmsg/fmsg.go b/pkg/fmsg/fmsg.go new file mode 100644 index 0000000..69c1ebc --- /dev/null +++ b/pkg/fmsg/fmsg.go @@ -0,0 +1,330 @@ +// Package fmsg defines the fmsg message types and implements wire-format +// encoding and hashing as specified in SPEC.md. +// +// To compute a message hash from database fields, populate a [Header] +// (including Filepath and per-attachment Filepath values), then call +// [Header.GetMessageHash]. +package fmsg + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "os" + "strings" +) + +// Flag bit assignments per SPEC.md. +// Bits 6–7 are reserved and must be zero on the wire. +const ( + FlagHasPid uint8 = 1 // bit 0: pid field present; message is a reply + FlagHasAddTo uint8 = 1 << 1 // bit 1: add-to addresses present + FlagCommonType uint8 = 1 << 2 // bit 2: type encoded as common type ID, not string + FlagImportant uint8 = 1 << 3 // bit 3: sender marks message as important + FlagNoReply uint8 = 1 << 4 // bit 4: sender will discard replies + FlagDeflate uint8 = 1 << 5 // bit 5: message body is zlib-deflate compressed +) + +// Address is an fmsg address of the form @user@domain. +type Address struct { + User string + Domain string +} + +// ToString returns the address in @user@domain form. +func (addr *Address) ToString() string { + return fmt.Sprintf("@%s@%s", addr.User, addr.Domain) +} + +// AttachmentHeader holds the wire-level metadata for a single attachment. +type AttachmentHeader struct { + Flags uint8 + TypeID uint8 + Type string + Filename string + Size uint32 // wire (possibly compressed) size in bytes + ExpandedSize uint32 // decompressed size; present on wire iff bit 1 of Flags is set + + Filepath string // path to attachment data on disk +} + +// Header holds all fields of an fmsg message header. +// +// Fields ChallengeHash, ChallengeCompleted, and InitialResponseCode are +// fmsgd server-runtime state; they are not part of the wire format and can +// be ignored by other consumers of this package. +type Header struct { + Version uint8 + Flags uint8 + Pid []byte + From Address + To []Address + AddToFrom *Address // present when FlagHasAddTo is set + AddTo []Address + Timestamp float64 + TypeID uint8 + Topic string + Type string + + Size uint32 // wire (possibly compressed) size of the message body + ExpandedSize uint32 // decompressed size; present on wire iff FlagDeflate set + Attachments []AttachmentHeader + + HeaderHash []byte + ChallengeHash [32]byte // fmsgd server field: challenge response hash + ChallengeCompleted bool // fmsgd server field: true if challenge was completed + InitialResponseCode uint8 // fmsgd server field: protocol response code (11/64/65) + Filepath string // path to message body data on disk + messageHash []byte +} + +// Encode serialises the message header to wire format as defined in SPEC.md. +// The returned bytes cover all fields up to and including the attachment +// headers (fields 1–12 per spec). This method panics on internal buffer errors +// rather than returning an error. +func (h *Header) Encode() []byte { + var b bytes.Buffer + b.WriteByte(h.Version) + b.WriteByte(h.Flags) + if h.Flags&FlagHasPid == 1 { + b.Write(h.Pid[:]) + } + str := h.From.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + b.WriteByte(byte(len(h.To))) + for _, addr := range h.To { + str = addr.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + } + if h.Flags&FlagHasAddTo != 0 { + // add-to-from address (field 6) + addToFrom := h.AddToFrom + if addToFrom == nil { + addToFrom = &h.From + } + str := addToFrom.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + // add-to addresses (field 7) + b.WriteByte(byte(len(h.AddTo))) + for _, addr := range h.AddTo { + str = addr.ToString() + b.WriteByte(byte(len(str))) + b.WriteString(str) + } + } + if err := binary.Write(&b, binary.LittleEndian, h.Timestamp); err != nil { + panic(err) + } + // topic is only present when pid is NOT set + if h.Flags&FlagHasPid == 0 { + b.WriteByte(byte(len(h.Topic))) + b.WriteString(h.Topic) + } + if h.Flags&FlagCommonType != 0 { + typeID := h.TypeID + if typeID == 0 { + if id, ok := getCommonMediaTypeID(h.Type); ok { + typeID = id + } + } + b.WriteByte(typeID) + } else { + b.WriteByte(byte(len(h.Type))) + b.WriteString(h.Type) + } + // size (uint32 LE) + if err := binary.Write(&b, binary.LittleEndian, h.Size); err != nil { + panic(err) + } + // expanded size (uint32 LE) — present iff zlib-deflate flag set + if h.Flags&FlagDeflate != 0 { + if err := binary.Write(&b, binary.LittleEndian, h.ExpandedSize); err != nil { + panic(err) + } + } + // attachment headers + b.WriteByte(byte(len(h.Attachments))) + for _, att := range h.Attachments { + b.WriteByte(att.Flags) + if att.Flags&1 != 0 { + typeID := att.TypeID + if typeID == 0 { + if id, ok := getCommonMediaTypeID(att.Type); ok { + typeID = id + } + } + b.WriteByte(typeID) + } else { + b.WriteByte(byte(len(att.Type))) + b.WriteString(att.Type) + } + b.WriteByte(byte(len(att.Filename))) + b.WriteString(att.Filename) + if err := binary.Write(&b, binary.LittleEndian, att.Size); err != nil { + panic(err) + } + // attachment expanded size — present iff attachment zlib-deflate flag set + if att.Flags&(1<<1) != 0 { + if err := binary.Write(&b, binary.LittleEndian, att.ExpandedSize); err != nil { + panic(err) + } + } + } + return b.Bytes() +} + +// String returns a human-readable summary of the header fields. +func (h *Header) String() string { + var b strings.Builder + fmt.Fprintf(&b, "v%d flags=%d", h.Version, h.Flags) + if len(h.Pid) > 0 { + fmt.Fprintf(&b, " pid=%s", hex.EncodeToString(h.Pid)) + } + fmt.Fprintf(&b, "\nfrom:\t%s", h.From.ToString()) + for i, addr := range h.To { + if i == 0 { + fmt.Fprintf(&b, "\nto:\t%s", addr.ToString()) + } else { + fmt.Fprintf(&b, "\n\t%s", addr.ToString()) + } + } + for i, addr := range h.AddTo { + if i == 0 { + fmt.Fprintf(&b, "\nadd to:\t%s", addr.ToString()) + } else { + fmt.Fprintf(&b, "\n\t%s", addr.ToString()) + } + } + fmt.Fprintf(&b, "\ntopic:\t%s", h.Topic) + fmt.Fprintf(&b, "\ntype:\t%s", h.Type) + fmt.Fprintf(&b, "\nsize:\t%d", h.Size) + return b.String() +} + +// GetHeaderHash returns the SHA-256 hash of the encoded header (fields 1–12). +// The result is cached after the first call. +func (h *Header) GetHeaderHash() []byte { + if h.HeaderHash == nil { + b := sha256.Sum256(h.Encode()) + h.HeaderHash = b[:] + } + return h.HeaderHash +} + +// GetMessageHash returns the SHA-256 hash of the full message: +// encoded header + decompressed message body + decompressed attachment data. +// The result is cached after the first call. +func (h *Header) GetMessageHash() ([]byte, error) { + if h.messageHash == nil { + hash := sha256.New() + + headerBytes := h.Encode() + if _, err := io.Copy(hash, bytes.NewBuffer(headerBytes)); err != nil { + return nil, err + } + + if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0, h.ExpandedSize); err != nil { + return nil, err + } + + // include attachment data (sequential byte sequences following + // the message body, bounded by attachment header sizes) + for _, att := range h.Attachments { + compressed := att.Flags&(1<<1) != 0 + if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed, att.ExpandedSize); err != nil { + return nil, fmt.Errorf("hash attachment %s: %w", att.Filename, err) + } + } + + h.messageHash = hash.Sum(nil) + } + return h.messageHash, nil +} + +// HashPayload reads wireSize bytes from the file at filepath, decompressing +// via zlib if deflated, and writes the (decompressed) bytes to dst. +// The message hash is always computed over decompressed data. +func HashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + defer f.Close() + + if deflated { + lr := io.LimitReader(f, wireSize) + zr, err := zlib.NewReader(lr) + if err != nil { + return err + } + written, err := io.Copy(dst, zr) + _ = zr.Close() + if err != nil { + return err + } + if uint32(written) != expandedSize { + return fmt.Errorf("decompressed size %d does not match declared expanded size %d", written, expandedSize) + } + return nil + } + + _, err = io.CopyN(dst, f, wireSize) + return err +} + +// commonMediaTypes maps common type IDs (1–64) to MIME strings per SPEC.md §4. +// Unmapped IDs must be rejected with response code 1 (invalid). +var commonMediaTypes = map[uint8]string{ + 1: "application/epub+zip", 2: "application/gzip", 3: "application/json", 4: "application/msword", + 5: "application/octet-stream", 6: "application/pdf", 7: "application/rtf", 8: "application/vnd.amazon.ebook", + 9: "application/vnd.ms-excel", 10: "application/vnd.ms-powerpoint", + 11: "application/vnd.oasis.opendocument.presentation", 12: "application/vnd.oasis.opendocument.spreadsheet", + 13: "application/vnd.oasis.opendocument.text", + 14: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + 15: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + 16: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + 17: "application/x-tar", 18: "application/xhtml+xml", 19: "application/xml", 20: "application/zip", + 21: "audio/aac", 22: "audio/midi", 23: "audio/mpeg", 24: "audio/ogg", 25: "audio/opus", 26: "audio/vnd.wave", 27: "audio/webm", + 28: "font/otf", 29: "font/ttf", 30: "font/woff", 31: "font/woff2", + 32: "image/apng", 33: "image/avif", 34: "image/bmp", 35: "image/gif", 36: "image/heic", 37: "image/jpeg", 38: "image/png", + 39: "image/svg+xml", 40: "image/tiff", 41: "image/webp", + 42: "model/3mf", 43: "model/gltf-binary", 44: "model/obj", 45: "model/step", 46: "model/stl", 47: "model/vnd.usdz+zip", + 48: "text/calendar", 49: "text/css", 50: "text/csv", 51: "text/html", 52: "text/javascript", 53: "text/markdown", + 54: "text/plain;charset=US-ASCII", 55: "text/plain;charset=UTF-16", 56: "text/plain;charset=UTF-8", 57: "text/vcard", + 58: "video/H264", 59: "video/H265", 60: "video/H266", 61: "video/ogg", 62: "video/VP8", 63: "video/VP9", 64: "video/webm", +} + +// GetCommonMediaType returns the MIME type string for a common type ID, or +// ("", false) if the ID is not in the mapping (should be rejected per spec). +func GetCommonMediaType(id uint8) (string, bool) { + s, ok := commonMediaTypes[id] + return s, ok +} + +// GetCommonMediaTypeID returns the common type ID for a MIME string, or +// (0, false) if the string has no assigned ID. +func GetCommonMediaTypeID(mediaType string) (uint8, bool) { + for id, mime := range commonMediaTypes { + if mime == mediaType { + return id, true + } + } + return 0, false +} + +// getCommonMediaTypeID is the unexported alias used internally by Encode. +func getCommonMediaTypeID(mediaType string) (uint8, bool) { + return GetCommonMediaTypeID(mediaType) +} + +// hashPayload is the unexported alias used internally by GetMessageHash. +func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { + return HashPayload(dst, filepath, wireSize, deflated, expandedSize) +} diff --git a/pkg/fmsg/fmsg_test.go b/pkg/fmsg/fmsg_test.go new file mode 100644 index 0000000..d263392 --- /dev/null +++ b/pkg/fmsg/fmsg_test.go @@ -0,0 +1,547 @@ +package fmsg_test + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "os" + "testing" + + "github.com/markmnl/fmsgd/pkg/fmsg" +) + +// ── Flag constants ──────────────────────────────────────────────────────────── + +func TestFlagConstants(t *testing.T) { + tests := []struct { + name string + got uint8 + want uint8 + }{ + {"FlagHasPid", fmsg.FlagHasPid, 0x01}, + {"FlagHasAddTo", fmsg.FlagHasAddTo, 0x02}, + {"FlagCommonType", fmsg.FlagCommonType, 0x04}, + {"FlagImportant", fmsg.FlagImportant, 0x08}, + {"FlagNoReply", fmsg.FlagNoReply, 0x10}, + {"FlagDeflate", fmsg.FlagDeflate, 0x20}, + } + for _, tt := range tests { + if tt.got != tt.want { + t.Errorf("%s = 0x%02X, want 0x%02X", tt.name, tt.got, tt.want) + } + } +} + +// ── Address ─────────────────────────────────────────────────────────────────── + +func TestAddressToString(t *testing.T) { + tests := []struct { + addr fmsg.Address + want string + }{ + {fmsg.Address{User: "alice", Domain: "example.com"}, "@alice@example.com"}, + {fmsg.Address{User: "bob", Domain: "b.io"}, "@bob@b.io"}, + {fmsg.Address{User: "x", Domain: "y.z"}, "@x@y.z"}, + } + for _, tt := range tests { + if got := tt.addr.ToString(); got != tt.want { + t.Errorf("ToString() = %q, want %q", got, tt.want) + } + } +} + +// ── Encode ──────────────────────────────────────────────────────────────────── + +// buildExpectedWire manually constructs the expected wire bytes for a minimal +// header: version=1, flags=0, from=@alice@example.com, to=[@bob@example.com], +// timestamp=0, topic="hi", type="text/plain", size=12, no attachments. +// This is the ground-truth used by TestHeaderEncode. +func buildExpectedWire() []byte { + var b bytes.Buffer + b.WriteByte(1) // version + b.WriteByte(0) // flags + from := "@alice@example.com" + b.WriteByte(byte(len(from))) + b.WriteString(from) + b.WriteByte(1) // to count + to := "@bob@example.com" + b.WriteByte(byte(len(to))) + b.WriteString(to) + _ = binary.Write(&b, binary.LittleEndian, float64(0)) // timestamp + b.WriteByte(byte(len("hi"))) + b.WriteString("hi") + b.WriteByte(byte(len("text/plain"))) + b.WriteString("text/plain") + _ = binary.Write(&b, binary.LittleEndian, uint32(12)) // size + b.WriteByte(0) // attachment count + return b.Bytes() +} + +func TestHeaderEncode(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: 12, + } + got := h.Encode() + want := buildExpectedWire() + if !bytes.Equal(got, want) { + t.Errorf("Encode():\n got %x\n want %x", got, want) + } +} + +func TestHeaderEncodeHasPid(t *testing.T) { + // When FlagHasPid is set: 32 pid bytes written at offset 2; topic absent. + pid := make([]byte, 32) + for i := range pid { + pid[i] = byte(i) + } + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagHasPid, + Pid: pid, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "should-not-appear", + Type: "text/plain", + Size: 0, + } + wire := h.Encode() + if !bytes.Equal(wire[2:34], pid) { + t.Error("Encode() with FlagHasPid: pid bytes not at wire[2:34]") + } + if bytes.Contains(wire, []byte("should-not-appear")) { + t.Error("Encode() with FlagHasPid: topic must be absent from wire") + } +} + +func TestHeaderEncodeDeflate(t *testing.T) { + // When FlagDeflate is set, ExpandedSize uint32 follows Size. + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagDeflate, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + Type: "application/octet-stream", + Size: 100, + ExpandedSize: 9999, + } + wire := h.Encode() + // Locate size bytes: offset = 1+1 + 1+7 + 1 + 1+7 + 8 + 1+1 + 1+24 = 55 + // version+flags = 2 + // fromlen+"@a@b.io" = 1+7 = 8; total 10 + // tocount = 1; total 11 + // tolen+"@c@d.io" = 1+7 = 8; total 19 + // timestamp = 8; total 27 + // topiclen+"t" = 1+1 = 2; total 29 + // typelen+"application/octet-stream" = 1+24 = 25; total 54 + // size uint32 = 4 bytes at [54:58] + // expanded size uint32 = 4 bytes at [58:62] + var size, expanded uint32 + _ = binary.Read(bytes.NewReader(wire[54:58]), binary.LittleEndian, &size) + _ = binary.Read(bytes.NewReader(wire[58:62]), binary.LittleEndian, &expanded) + if size != 100 { + t.Errorf("Encode() DeflateFlag: Size = %d, want 100", size) + } + if expanded != 9999 { + t.Errorf("Encode() DeflateFlag: ExpandedSize = %d, want 9999", expanded) + } +} + +func TestHeaderEncodeCommonType(t *testing.T) { + // When FlagCommonType is set, TypeID byte is written instead of type string. + // "text/csv" is ID 50. + h := &fmsg.Header{ + Version: 1, + Flags: fmsg.FlagCommonType, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + TypeID: 50, + Type: "text/csv", + Size: 0, + } + wire := h.Encode() + if bytes.Contains(wire, []byte("text/csv")) { + t.Error("Encode() with FlagCommonType: type string must not appear in wire") + } + // TypeID byte is at offset: 2+8+1+8+8+2 = 29 + // (version+flags) + (fromlen+from) + tocount + (tolen+to) + timestamp + (topiclen+topic) + if wire[29] != 50 { + t.Errorf("Encode() with FlagCommonType: type byte at [29] = %d, want 50", wire[29]) + } +} + +func TestHeaderEncodeAttachment(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Timestamp: 0, + Topic: "t", + Type: "text/plain", + Size: 5, + Attachments: []fmsg.AttachmentHeader{ + { + Flags: 0, + Type: "image/png", + Filename: "pic.png", + Size: 1024, + }, + }, + } + wire := h.Encode() + if !bytes.Contains(wire, []byte("pic.png")) { + t.Error("Encode(): attachment filename not in wire") + } + if !bytes.Contains(wire, []byte("image/png")) { + t.Error("Encode(): attachment type not in wire") + } +} + +// ── GetHeaderHash ───────────────────────────────────────────────────────────── + +func TestGetHeaderHash(t *testing.T) { + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: 12, + } + got := h.GetHeaderHash() + want := sha256.Sum256(buildExpectedWire()) + if !bytes.Equal(got, want[:]) { + t.Errorf("GetHeaderHash() = %x, want %x", got, want) + } +} + +func TestGetHeaderHashCached(t *testing.T) { + h := &fmsg.Header{ + Version: 1, Flags: 0, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", Size: 0, + } + h1 := h.GetHeaderHash() + h2 := h.GetHeaderHash() + if &h1[0] != &h2[0] { + t.Error("GetHeaderHash() should return the same slice on repeated calls (cached)") + } +} + +// ── GetMessageHash ──────────────────────────────────────────────────────────── + +// TestGetMessageHashSmall verifies a complete small-message hash against an +// independently computed expected value: sha256(encoded_header || body). +func TestGetMessageHashSmall(t *testing.T) { + const body = "Hello, fmsg!" + + f, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString(body); err != nil { + t.Fatal(err) + } + f.Close() + + h := &fmsg.Header{ + Version: 1, + Flags: 0, + From: fmsg.Address{User: "alice", Domain: "example.com"}, + To: []fmsg.Address{{User: "bob", Domain: "example.com"}}, + Timestamp: 0, + Topic: "hi", + Type: "text/plain", + Size: uint32(len(body)), + Filepath: f.Name(), + } + + got, err := h.GetMessageHash() + if err != nil { + t.Fatalf("GetMessageHash() error: %v", err) + } + + // Expected: sha256(header_wire || body_bytes) computed independently. + wantHash := sha256.New() + wantHash.Write(buildExpectedWireWithSize(uint32(len(body)))) + wantHash.Write([]byte(body)) + want := wantHash.Sum(nil) + + if !bytes.Equal(got, want) { + t.Errorf("GetMessageHash():\n got %s\n want %s", hex.EncodeToString(got), hex.EncodeToString(want)) + } + + // Golden hash guards against regressions in Encode() or the hash algorithm. + const wantGolden = "eaefdd1cf1868078ff38ba882cbd31a2297af3db33cec888bd3e088bafdbcc3b" + if hex.EncodeToString(got) != wantGolden { + t.Errorf("GetMessageHash() golden:\n got %s\n want %s", hex.EncodeToString(got), wantGolden) + } +} + +// buildExpectedWireWithSize is like buildExpectedWire but uses a variable body size. +func buildExpectedWireWithSize(size uint32) []byte { + var b bytes.Buffer + b.WriteByte(1) // version + b.WriteByte(0) // flags + from := "@alice@example.com" + b.WriteByte(byte(len(from))) + b.WriteString(from) + b.WriteByte(1) + to := "@bob@example.com" + b.WriteByte(byte(len(to))) + b.WriteString(to) + _ = binary.Write(&b, binary.LittleEndian, float64(0)) + b.WriteByte(byte(len("hi"))) + b.WriteString("hi") + b.WriteByte(byte(len("text/plain"))) + b.WriteString("text/plain") + _ = binary.Write(&b, binary.LittleEndian, size) + b.WriteByte(0) + return b.Bytes() +} + +func TestGetMessageHashCached(t *testing.T) { + const body = "cached" + f, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + f.WriteString(body) + f.Close() + + h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", + Size: uint32(len(body)), + Filepath: f.Name(), + } + h1, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + h2, err := h.GetMessageHash() + if err != nil { + t.Fatal(err) + } + if &h1[0] != &h2[0] { + t.Error("GetMessageHash() should return the same slice on repeated calls (cached)") + } +} + +func TestGetMessageHashWithAttachment(t *testing.T) { + const body = "body data" + const attData = "attachment data" + + bodyFile, err := os.CreateTemp(t.TempDir(), "fmsg-body-*") + if err != nil { + t.Fatal(err) + } + bodyFile.WriteString(body) + bodyFile.Close() + + attFile, err := os.CreateTemp(t.TempDir(), "fmsg-att-*") + if err != nil { + t.Fatal(err) + } + attFile.WriteString(attData) + attFile.Close() + + h := &fmsg.Header{ + Version: 1, + From: fmsg.Address{User: "a", Domain: "b.io"}, + To: []fmsg.Address{{User: "c", Domain: "d.io"}}, + Type: "text/plain", + Size: uint32(len(body)), + Filepath: bodyFile.Name(), + Attachments: []fmsg.AttachmentHeader{ + { + Flags: 0, + Type: "image/png", + Filename: "pic.png", + Size: uint32(len(attData)), + Filepath: attFile.Name(), + }, + }, + } + got, err := h.GetMessageHash() + if err != nil { + t.Fatalf("GetMessageHash() with attachment error: %v", err) + } + + // Expected: sha256(header_wire || body || attachment_data) + hw := sha256.New() + hw.Write(h.Encode()) + hw.Write([]byte(body)) + hw.Write([]byte(attData)) + want := hw.Sum(nil) + + if !bytes.Equal(got, want) { + t.Errorf("GetMessageHash() with attachment:\n got %x\n want %x", got, want) + } +} + +// ── HashPayload ─────────────────────────────────────────────────────────────── + +func TestHashPayloadPlain(t *testing.T) { + content := []byte("payload bytes") + f, err := os.CreateTemp(t.TempDir(), "fmsg-payload-*") + if err != nil { + t.Fatal(err) + } + f.Write(content) + f.Close() + + var dst bytes.Buffer + if err := fmsg.HashPayload(&dst, f.Name(), int64(len(content)), false, 0); err != nil { + t.Fatalf("HashPayload() error: %v", err) + } + if !bytes.Equal(dst.Bytes(), content) { + t.Errorf("HashPayload() wrote %x, want %x", dst.Bytes(), content) + } +} + +func TestHashPayloadDeflated(t *testing.T) { + plain := []byte("hello compressed world") + + // Write zlib-compressed content to a temp file. + f, err := os.CreateTemp(t.TempDir(), "fmsg-deflated-*") + if err != nil { + t.Fatal(err) + } + zw := zlib.NewWriter(f) + zw.Write(plain) + zw.Close() + wireSize, _ := f.Seek(0, 1) // current offset = compressed size + f.Close() + + var dst bytes.Buffer + if err := fmsg.HashPayload(&dst, f.Name(), wireSize, true, uint32(len(plain))); err != nil { + t.Fatalf("HashPayload() deflated error: %v", err) + } + if !bytes.Equal(dst.Bytes(), plain) { + t.Errorf("HashPayload() deflated wrote %q, want %q", dst.Bytes(), plain) + } +} + +func TestHashPayloadDeflatedSizeMismatch(t *testing.T) { + plain := []byte("data") + f, err := os.CreateTemp(t.TempDir(), "fmsg-deflated-*") + if err != nil { + t.Fatal(err) + } + zw := zlib.NewWriter(f) + zw.Write(plain) + zw.Close() + wireSize, _ := f.Seek(0, 1) + f.Close() + + var dst bytes.Buffer + err = fmsg.HashPayload(&dst, f.Name(), wireSize, true, uint32(len(plain))+99) + if err == nil { + t.Error("HashPayload() should error when expanded size does not match") + } +} + +// ── Common media types ──────────────────────────────────────────────────────── + +func TestGetCommonMediaType(t *testing.T) { + tests := []struct { + id uint8 + want string + }{ + {3, "application/json"}, + {6, "application/pdf"}, + {38, "image/png"}, + {50, "text/csv"}, + {56, "text/plain;charset=UTF-8"}, + {64, "video/webm"}, + } + for _, tt := range tests { + got, ok := fmsg.GetCommonMediaType(tt.id) + if !ok { + t.Errorf("GetCommonMediaType(%d): not found", tt.id) + } + if got != tt.want { + t.Errorf("GetCommonMediaType(%d) = %q, want %q", tt.id, got, tt.want) + } + } +} + +func TestGetCommonMediaTypeUnknown(t *testing.T) { + _, ok := fmsg.GetCommonMediaType(0) + if ok { + t.Error("GetCommonMediaType(0) should return false") + } + _, ok = fmsg.GetCommonMediaType(65) + if ok { + t.Error("GetCommonMediaType(65) should return false") + } +} + +func TestGetCommonMediaTypeID(t *testing.T) { + tests := []struct { + mime string + want uint8 + }{ + {"application/json", 3}, + {"application/pdf", 6}, + {"image/png", 38}, + {"text/csv", 50}, + {"text/plain;charset=UTF-8", 56}, + {"video/webm", 64}, + } + for _, tt := range tests { + got, ok := fmsg.GetCommonMediaTypeID(tt.mime) + if !ok { + t.Errorf("GetCommonMediaTypeID(%q): not found", tt.mime) + } + if got != tt.want { + t.Errorf("GetCommonMediaTypeID(%q) = %d, want %d", tt.mime, got, tt.want) + } + } +} + +func TestGetCommonMediaTypeIDUnknown(t *testing.T) { + _, ok := fmsg.GetCommonMediaTypeID("application/unknown") + if ok { + t.Error("GetCommonMediaTypeID(unknown) should return false") + } +} + +func TestCommonMediaTypeRoundTrip(t *testing.T) { + // Every ID in the valid range 1–64 must round-trip: ID → string → ID. + for id := uint8(1); id <= 64; id++ { + mime, ok := fmsg.GetCommonMediaType(id) + if !ok { + t.Errorf("GetCommonMediaType(%d): not found", id) + continue + } + got, ok := fmsg.GetCommonMediaTypeID(mime) + if !ok { + t.Errorf("GetCommonMediaTypeID(%q): not found (from ID %d)", mime, id) + continue + } + if got != id { + t.Errorf("round-trip ID %d → %q → %d", id, mime, got) + } + } +} diff --git a/src/.env.example b/src/.env.example deleted file mode 100644 index dbedf7e..0000000 --- a/src/.env.example +++ /dev/null @@ -1,23 +0,0 @@ -# fmsgd Environment Variables - -# Required -FMSG_DATA_DIR=/var/lib/fmsgd/ -FMSG_DOMAIN=example.com -FMSG_ID_URL=http://127.0.0.1:8080 - - -FMSG_MAX_MSG_SIZE=10240 -FMSG_MAX_EXPANDED_SIZE=10240 -FMSG_MAX_PAST_TIME_DELTA=604800 -FMSG_MAX_FUTURE_TIME_DELTA=300 -FMSG_MIN_DOWNLOAD_RATE=5000 -FMSG_MIN_UPLOAD_RATE=5000 -FMSG_READ_BUFFER_SIZE=1600 - -# PostgreSQL connection variables (see https://www.postgresql.org/docs/current/libpq-envars.html) -PGHOST=127.0.0.1 -PGPORT=5432 -PGUSER= -PGPASSWORD= -PGDATABASE=fmsgd -PGSSLMODE=disable diff --git a/src/defs.go b/src/defs.go deleted file mode 100644 index a0ee4fa..0000000 --- a/src/defs.go +++ /dev/null @@ -1,248 +0,0 @@ -package main - -import ( - "bytes" - "compress/zlib" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "os" - "strings" -) - -type FMsgAddress struct { - User string - Domain string -} - -type FMsgAttachmentHeader struct { - Flags uint8 - TypeID uint8 - Type string - Filename string - Size uint32 - ExpandedSize uint32 - - Filepath string -} - -type FMsgHeader struct { - Version uint8 - Flags uint8 - Pid []byte - From FMsgAddress - To []FMsgAddress - AddToFrom *FMsgAddress // Present when has-add-to flag is set - AddTo []FMsgAddress - Timestamp float64 - TypeID uint8 - Topic string - Type string - - // Size in bytes of entire message - Size uint32 - ExpandedSize uint32 // Decompressed size; present on wire iff FlagDeflate set - Attachments []FMsgAttachmentHeader - - HeaderHash []byte - ChallengeHash [32]byte - ChallengeCompleted bool // True if challenge was initiated and completed - InitialResponseCode uint8 // Protocol response chosen after header validation (11/64/65) - Filepath string - messageHash []byte -} - -// Returns a string representation of an address in the form @user@example.com -func (addr *FMsgAddress) ToString() string { - return fmt.Sprintf("@%s@%s", addr.User, addr.Domain) -} - -// Encode the message header to wire format as a []byte. This includes all -// fields up to and including the attachment headers per spec. This function -// will panic on error instead of returning one. -func (h *FMsgHeader) Encode() []byte { - var b bytes.Buffer - b.WriteByte(h.Version) - b.WriteByte(h.Flags) - if h.Flags&FlagHasPid == 1 { - b.Write(h.Pid[:]) - } - str := h.From.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - b.WriteByte(byte(len(h.To))) - for _, addr := range h.To { - str = addr.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - } - if h.Flags&FlagHasAddTo != 0 { - // add-to-from address (field 6) - addToFrom := h.AddToFrom - if addToFrom == nil { - addToFrom = &h.From - } - str := addToFrom.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - // add-to addresses (field 7) - b.WriteByte(byte(len(h.AddTo))) - for _, addr := range h.AddTo { - str = addr.ToString() - b.WriteByte(byte(len(str))) - b.WriteString(str) - } - } - if err := binary.Write(&b, binary.LittleEndian, h.Timestamp); err != nil { - panic(err) - } - // topic is only present when pid is NOT set - if h.Flags&FlagHasPid == 0 { - b.WriteByte(byte(len(h.Topic))) - b.WriteString(h.Topic) - } - if h.Flags&FlagCommonType != 0 { - typeID := h.TypeID - if typeID == 0 { - if id, ok := getCommonMediaTypeID(h.Type); ok { - typeID = id - } - } - b.WriteByte(typeID) - } else { - b.WriteByte(byte(len(h.Type))) - b.WriteString(h.Type) - } - // size (uint32 LE) - if err := binary.Write(&b, binary.LittleEndian, h.Size); err != nil { - panic(err) - } - // expanded size (uint32 LE) — present iff zlib-deflate flag set - if h.Flags&FlagDeflate != 0 { - if err := binary.Write(&b, binary.LittleEndian, h.ExpandedSize); err != nil { - panic(err) - } - } - // attachment headers - b.WriteByte(byte(len(h.Attachments))) - for _, att := range h.Attachments { - b.WriteByte(att.Flags) - if att.Flags&1 != 0 { - typeID := att.TypeID - if typeID == 0 { - if id, ok := getCommonMediaTypeID(att.Type); ok { - typeID = id - } - } - b.WriteByte(typeID) - } else { - b.WriteByte(byte(len(att.Type))) - b.WriteString(att.Type) - } - b.WriteByte(byte(len(att.Filename))) - b.WriteString(att.Filename) - if err := binary.Write(&b, binary.LittleEndian, att.Size); err != nil { - panic(err) - } - // attachment expanded size — present iff attachment zlib-deflate flag set - if att.Flags&(1<<1) != 0 { - if err := binary.Write(&b, binary.LittleEndian, att.ExpandedSize); err != nil { - panic(err) - } - } - } - return b.Bytes() -} - -// String returns a human-readable summary of the header fields. -func (h *FMsgHeader) String() string { - var b strings.Builder - fmt.Fprintf(&b, "v%d flags=%d", h.Version, h.Flags) - if len(h.Pid) > 0 { - fmt.Fprintf(&b, " pid=%s", hex.EncodeToString(h.Pid)) - } - fmt.Fprintf(&b, "\nfrom:\t%s", h.From.ToString()) - for i, addr := range h.To { - if i == 0 { - fmt.Fprintf(&b, "\nto:\t%s", addr.ToString()) - } else { - fmt.Fprintf(&b, "\n\t%s", addr.ToString()) - } - } - for i, addr := range h.AddTo { - if i == 0 { - fmt.Fprintf(&b, "\nadd to:\t%s", addr.ToString()) - } else { - fmt.Fprintf(&b, "\n\t%s", addr.ToString()) - } - } - fmt.Fprintf(&b, "\ntopic:\t%s", h.Topic) - fmt.Fprintf(&b, "\ntype:\t%s", h.Type) - fmt.Fprintf(&b, "\nsize:\t%d", h.Size) - return b.String() -} - -func (h *FMsgHeader) GetHeaderHash() []byte { - if h.HeaderHash == nil { - b := sha256.Sum256(h.Encode()) - h.HeaderHash = b[:] - } - return h.HeaderHash -} - -func (h *FMsgHeader) GetMessageHash() ([]byte, error) { - if h.messageHash == nil { - hash := sha256.New() - - headerBytes := h.Encode() - if _, err := io.Copy(hash, bytes.NewBuffer(headerBytes)); err != nil { - return nil, err - } - - if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0, h.ExpandedSize); err != nil { - return nil, err - } - - // include attachment data (sequential byte sequences following - // the message body, bounded by attachment header sizes) - for _, att := range h.Attachments { - compressed := att.Flags&(1<<1) != 0 - if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed, att.ExpandedSize); err != nil { - return nil, fmt.Errorf("hash attachment %s: %w", att.Filename, err) - } - } - - h.messageHash = hash.Sum(nil) - } - return h.messageHash, nil -} - -func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - defer f.Close() - - if deflated { - lr := io.LimitReader(f, wireSize) - zr, err := zlib.NewReader(lr) - if err != nil { - return err - } - written, err := io.Copy(dst, zr) - _ = zr.Close() - if err != nil { - return err - } - if uint32(written) != expandedSize { - return fmt.Errorf("decompressed size %d does not match declared expanded size %d", written, expandedSize) - } - return nil - } - - _, err = io.CopyN(dst, f, wireSize) - return err -}