From f293d2e4b6bc149fa59f099bb150eef95bafa6ad Mon Sep 17 00:00:00 2001 From: tmtrts Date: Mon, 26 Oct 2015 23:48:02 +0100 Subject: [PATCH] Use ReadUntill to protect against ddos. --- mta/mta.go | 29 ++++-- mta/mta_test.go | 16 +-- smtp/datareader_test.go | 86 ++++++++++++++++ smtp/parser.go | 36 +++---- smtp/parser_test.go | 24 ++--- smtp/protocol.go | 220 +++++++++++++++++++++++----------------- 6 files changed, 273 insertions(+), 138 deletions(-) create mode 100644 smtp/datareader_test.go diff --git a/mta/mta.go b/mta/mta.go index 522c29f..83b8168 100644 --- a/mta/mta.go +++ b/mta/mta.go @@ -205,15 +205,32 @@ func (s *Mta) HandleClient(proto smtp.Protocol) { }) var c *smtp.Cmd - var ok bool + var err error quit := false cmdC := make(chan bool) nextCmd := func() bool { go func() { - c, ok = proto.GetCmd() - cmdC <- true + for { + c, err = proto.GetCmd() + + if err != nil { + if err == smtp.ErrLtl { + proto.Send(smtp.Answer{ + Status: smtp.SyntaxError, + Message: "Line too long.", + }) + } else { + // Not a line too long error. What to do? + cmdC <- true + return + } + } else { + break + } + } + cmdC <- false }() select { @@ -225,8 +242,8 @@ func (s *Mta) HandleClient(proto smtp.Protocol) { }) return true } - case _ = <-cmdC: - return false + case q := <-cmdC: + return q } @@ -235,7 +252,7 @@ func (s *Mta) HandleClient(proto smtp.Protocol) { quit = nextCmd() - for ok == true && quit == false { + for quit == false { //log.Printf("Received cmd: %#v", *c) diff --git a/mta/mta_test.go b/mta/mta_test.go index 1d8949b..d78ef1c 100644 --- a/mta/mta_test.go +++ b/mta/mta_test.go @@ -1,7 +1,9 @@ package mta import ( + "bufio" "bytes" + "io" "testing" "github.com/gopistolet/gopistolet/smtp" @@ -47,18 +49,18 @@ func (p *testProtocol) Send(cmd smtp.Cmd) { } } -func (p *testProtocol) GetCmd() (*smtp.Cmd, bool) { +func (p *testProtocol) GetCmd() (*smtp.Cmd, error) { p.ctx.So(len(p.cmds), c.ShouldBeGreaterThan, 0) cmd := p.cmds[0] p.cmds = p.cmds[1:] if cmd == nil { - return nil, false + return nil, io.EOF } //c.Printf("SENDING: %#v\n", cmd) - return &cmd, true + return &cmd, nil } func (p *testProtocol) Close() { @@ -219,7 +221,7 @@ func TestMailAnswersCorrectSequence(t *testing.T) { To: getMailWithoutError("guy2@somewhere.test"), }, smtp.DataCmd{ - R: *smtp.NewDataReader(bytes.NewReader([]byte("Some test email\n.\n"))), + R: *smtp.NewDataReader(bufio.NewReader(bytes.NewReader([]byte("Some test email\n.\n")))), }, smtp.QuitCmd{}, }, @@ -450,7 +452,7 @@ func TestReset(t *testing.T) { To: getMailWithoutError("guy1@somewhere.test"), }, smtp.DataCmd{ - R: *smtp.NewDataReader(bytes.NewReader([]byte("Some email content\n.\n"))), + R: *smtp.NewDataReader(bufio.NewReader(bytes.NewReader([]byte("Some email content\n.\n")))), }, smtp.RcptCmd{ To: getMailWithoutError("someguy@somewhere.test"), @@ -517,7 +519,7 @@ func TestReset(t *testing.T) { To: getMailWithoutError("guy1@somewhere.test"), }, smtp.DataCmd{ - R: *smtp.NewDataReader(bytes.NewReader([]byte("Some email\n.\n"))), + R: *smtp.NewDataReader(bufio.NewReader(bytes.NewReader([]byte("Some email\n.\n")))), }, smtp.QuitCmd{}, }, @@ -592,7 +594,7 @@ func TestReset(t *testing.T) { To: getMailWithoutError("guy1@somewhere.test"), }, smtp.DataCmd{ - R: *smtp.NewDataReader(bytes.NewReader([]byte("Some email\n.\n"))), + R: *smtp.NewDataReader(bufio.NewReader(bytes.NewReader([]byte("Some email\n.\n")))), }, smtp.QuitCmd{}, }, diff --git a/smtp/datareader_test.go b/smtp/datareader_test.go new file mode 100644 index 0000000..3f95bd1 --- /dev/null +++ b/smtp/datareader_test.go @@ -0,0 +1,86 @@ +package smtp + +import ( + "bufio" + "bytes" + "io/ioutil" + "testing" +) + +func compare(t *testing.T, data []byte, expected []byte) { + br := bufio.NewReader(bytes.NewReader(data)) + + dataReader := NewDataReader(br) + output, err := ioutil.ReadAll(dataReader) + if bytes.Compare(output, expected) != 0 { + t.Errorf("Expected %v\ngot %v\n", expected, output) + } + if err != nil { + t.Errorf("Did not expect error: %v", err) + } + +} + +func expectError(t *testing.T, data []byte, expected error) { + br := bufio.NewReader(bytes.NewReader(data)) + dataReader := NewDataReader(br) + _, err := ioutil.ReadAll(dataReader) + if err != expected { + t.Errorf("Expected error: %v, got: %v", expected, err) + } + +} + +func TestDataReaderValid(t *testing.T) { + data := []byte("Some test mail\nblablabla\n.\n") + expected := []byte("Some test mail\nblablabla\n") + compare(t, data, expected) + + data = []byte("Some test mail\nblablabla\n.\nshould not read this") + expected = []byte("Some test mail\nblablabla\n") + compare(t, data, expected) + + data = []byte("Some test mail\n..blablabla\n.\n") + expected = []byte("Some test mail\n.blablabla\n") + compare(t, data, expected) + + data = []byte("Some test mail\n.blablabla\n.\n") + expected = []byte("Some test mail\nblablabla\n") + compare(t, data, expected) + + // first line is 1000 chars + data = []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\n.\n") + expected = []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\n") + compare(t, data, expected) + + // first line is 1001 chars but starts with a dot, so server should see it as 1000 + data = []byte(".aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsdddddd\n.\n") + expected = []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsdddddd\n") + compare(t, data, expected) + + // first line is 1000 chars, second 10, third 1000 + data = []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\naj ge je a t\naafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\n.\n") + expected = []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\naj ge je a t\naafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd\n") + compare(t, data, expected) +} + +func TestDataReaderInvalid(t *testing.T) { + data := []byte("Some test mail\nblablabla\nno ending dot") + expectError(t, data, ErrIncomplete) + + data = []byte("Some test mail\r\nDot on invalid place\n.test") + expectError(t, data, ErrIncomplete) + + data = []byte("") + expectError(t, data, ErrIncomplete) +} + +func TestDataReaderTooLong(t *testing.T) { + // length === 1001 + data := []byte("aafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd3\n") + expectError(t, data, ErrLtl) + + // first line is small, second is 1003 + data = []byte("Some text :)\naafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddddddddfsdaafsddddddd321\n") + expectError(t, data, ErrLtl) +} diff --git a/smtp/parser.go b/smtp/parser.go index c726fa3..733354f 100644 --- a/smtp/parser.go +++ b/smtp/parser.go @@ -1,6 +1,9 @@ package smtp -import "bufio" +import ( + "bufio" + "log" +) import "strings" import "errors" @@ -21,20 +24,9 @@ func (p *parser) ParseCommand(br *bufio.Reader) (command Cmd, err error) { servers (see Section 4). */ - line, err := br.ReadString('\n') - if err != nil { - return nil, err - } - - for line == "" { - line, err = br.ReadString('\n') - if err != nil { - return nil, err - } - } - var address *MailAddress - verb, args, err := parseLine(line) + verb, args, err := parseLine(br) + log.Printf("Verb: %s. args: %v", verb, args) if err != nil { return nil, err } @@ -156,7 +148,8 @@ func (p *parser) ParseCommand(br *bufio.Reader) (command Cmd, err error) { default: { - command = UnknownCmd{Cmd: verb, Line: strings.TrimSuffix(line, "\n")} + // TODO: CLEAN THIS UP + command = UnknownCmd{Cmd: verb, Line: strings.TrimSuffix(verb, "\n")} } } @@ -165,7 +158,7 @@ func (p *parser) ParseCommand(br *bufio.Reader) (command Cmd, err error) { } // parseLine returns the verb of the line and a list of all comma separated arguments -func parseLine(line string) (verb string, args []string, err error) { +func parseLine(br *bufio.Reader) (verb string, args []string, err error) { /* RFC 5321 @@ -175,9 +168,16 @@ func parseLine(line string) (verb string, args []string, err error) { and the is 512 octets. SMTP extensions may be used to increase this limit. */ - if len(line) > 512 { - return "", []string{}, errors.New("Line too long") + buffer, err := ReadUntill('\n', MAX_CMD_LINE, br) + if err != nil { + if err == ErrLtl { + SkipTillNewline(br) + return string(buffer), []string{}, err + } + + return string(buffer), []string{}, err } + line := string(buffer) // Strip \n and \r line = strings.TrimSuffix(line, "\n") diff --git a/smtp/parser_test.go b/smtp/parser_test.go index 16003e1..138428c 100644 --- a/smtp/parser_test.go +++ b/smtp/parser_test.go @@ -48,7 +48,6 @@ func TestParser(t *testing.T) { } for _, expectedCommand := range expectedCommands { - Print(expectedCommand) command, err := p.ParseCommand(br) So(err, ShouldEqual, nil) So(command, ShouldResemble, expectedCommand) @@ -59,8 +58,6 @@ func TestParser(t *testing.T) { Convey("Testing parser DATA cmd", t, func() { commands := "" commands += "DATA\r\n" - commands += "Some usefull data.\r\n" - commands += ".\r\n" commands += "quit\r\n" br := bufio.NewReader(strings.NewReader(commands)) @@ -69,13 +66,6 @@ func TestParser(t *testing.T) { command, err := p.ParseCommand(br) So(err, ShouldEqual, nil) So(command, ShouldHaveSameTypeAs, DataCmd{}) - dataCommand, ok := command.(DataCmd) - So(ok, ShouldEqual, true) - br2 := bufio.NewReader(dataCommand.R.r) - line, _ := br2.ReadString('\n') - So(line, ShouldEqual, "Some usefull data.\r\n") - line, _ = br2.ReadString('\n') - So(line, ShouldEqual, ".\r\n") command, err = p.ParseCommand(br) So(err, ShouldEqual, nil) @@ -140,16 +130,16 @@ func TestParser(t *testing.T) { args []string }{ { - line: "HELO", + line: "HELO\r\n", verb: "HELO", }, { - line: "HELO relay.example.org", + line: "HELO relay.example.org\r\n", verb: "HELO", args: []string{"relay.example.org"}, }, { - line: "MAIL FROM:", + line: "MAIL FROM:\r\n", verb: "MAIL", args: []string{"FROM:"}, }, @@ -166,7 +156,8 @@ func TestParser(t *testing.T) { } for _, test := range tests { - verb, args, err := parseLine(test.line) + br := bufio.NewReader(strings.NewReader(test.line)) + verb, args, err := parseLine(br) So(err, ShouldEqual, nil) So(verb, ShouldEqual, test.verb) So(args, ShouldResemble, test.args) @@ -181,13 +172,14 @@ func TestParser(t *testing.T) { addressString string }{ { - line: "RCPT TO:", + line: "RCPT TO:\r\n", addressString: "alice@example.com", }, } for _, test := range tests { - _, args, err := parseLine(test.line) + br := bufio.NewReader(strings.NewReader(test.line)) + _, args, err := parseLine(br) So(err, ShouldEqual, nil) addr, err := parseTO(args) diff --git a/smtp/protocol.go b/smtp/protocol.go index 6715a81..e168ab1 100644 --- a/smtp/protocol.go +++ b/smtp/protocol.go @@ -2,7 +2,6 @@ package smtp import ( "bufio" - "bytes" "errors" "fmt" "io" @@ -33,124 +32,164 @@ var ErrLtl = errors.New("Line too long") // ErrIncomplete Incomplete data error var ErrIncomplete = errors.New("Incomplete data") -type LimitedReader struct { - R io.Reader // underlying reader - N int // max bytes remaining - Delim byte -} +const ( + MAX_DATA_LINE = 1000 + MAX_CMD_LINE = 512 +) -func (l *LimitedReader) Read(p []byte) (int, error) { - if l.N <= 0 { - return 0, io.EOF - } +// ReadUntill reads untill delim is found or max bytes are read. +// If delim was found it returns nil as error. If delim wasn't found after max bytes, +// it returns ErrLtl. +func ReadUntill(delim byte, max int, r io.Reader) ([]byte, error) { + buffer := make([]byte, max) + + n := 0 + for n < max { + read, err := r.Read(buffer[n : n+1]) + if read == 0 || err != nil { + return buffer[0:n], err + } + + if read > 1 { + panic("Should read 1 byte at a time.") + } + + if buffer[n] == delim { + return buffer[0 : n+1], nil + } + + n++ - if len(p) > l.N { - p = p[0:l.N] } - bytesRead := 0 - buf := make([]byte, 1) - for l.N > 0 && bytesRead < len(p) { - n, err := l.R.Read(buf) + return buffer[0:n], ErrLtl +} - if n > 0 { - p[bytesRead] = buf[0] - l.N -= n - bytesRead += n - if buf[0] == l.Delim { - break +// SkipTillNewline removes all data untill a newline is found. +func SkipTillNewline(r io.Reader) error { + var err error + for { + _, err = ReadUntill('\n', 1000, r) + if err != nil { + if err == ErrLtl { + continue } + break } - if err != nil { - return bytesRead, err - } + break } - return bytesRead, nil + return err } -const ( - MAX_LINE = 1000 -) - // DataReader implements the reader that will read the data from a MAIL cmd type DataReader struct { - r io.Reader - buffer []byte + br *bufio.Reader + state int + bytesInLine int } -func NewDataReader(r io.Reader) *DataReader { +func NewDataReader(br *bufio.Reader) *DataReader { dr := &DataReader{ - r: r, - buffer: make([]byte, 0, MAX_LINE), + br: br, } return dr } -func (r *DataReader) Read(p []byte) (int, error) { - var n int = 0 - - if len(r.buffer) > 0 { - n = copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - - limited := &LimitedReader{ - R: r.r, - N: MAX_LINE + 1, - Delim: '\n', - } - - br := bufio.NewReader(limited) +// Implementation from textproto.DotReader.Read +func (r *DataReader) Read(b []byte) (n int, err error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + + br := r.br + for n < len(b) && r.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + err = ErrIncomplete - line, err := br.ReadBytes('\n') - lineLen := len(line) - if lineLen > 0 && line[len(line)-1] != '\n' { - buf := make([]byte, 1) + } + r.bytesInLine++ + if r.bytesInLine > MAX_DATA_LINE { + err = ErrLtl + break + } + switch r.state { + case stateBeginLine: + if c == '.' { + r.state = stateDot + continue + } + if c == '\r' { + r.state = stateCR + continue + } + r.state = stateData - for n, err := r.r.Read(buf); ; { - if n > 0 { - if buf[0] == '\n' { - break - } + case stateDot: + if c == '\r' { + r.state = stateDotCR + continue } + if c == '\n' { + r.state = stateEOF + continue + } + r.state = stateData - if err != nil { + case stateDotCR: + if c == '\n' { + r.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + r.state = stateData + + case stateCR: + if c == '\n' { + r.state = stateBeginLine + r.bytesInLine = 0 break } - - n, err = r.r.Read(buf) + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + r.state = stateData + + case stateData: + if c == '\r' { + r.state = stateCR + continue + } + if c == '\n' { + r.state = stateBeginLine + r.bytesInLine = 0 + } } + b[n] = c + n++ } - fmt.Printf("Read %d bytes\n", lineLen) - - if bytes.Compare(line, []byte(".\r\n")) == 0 || - bytes.Compare(line, []byte(".\r")) == 0 || - bytes.Compare(line, []byte(".\n")) == 0 { - - return 0, io.EOF - } else if lineLen > 2 && line[0] == '.' { - line = line[1:] - lineLen-- - } - - if lineLen > MAX_LINE { - return 0, ErrLtl - } - - n = copy(p, line) - r.buffer = r.buffer[0 : lineLen-n] - copy(r.buffer, line[n:]) - if err == io.EOF { - return 0, ErrIncomplete + if err == nil && r.state == stateEOF { + err = io.EOF } - return n, nil + return } // Cmd All SMTP answers/commands should implement this interface. @@ -314,9 +353,8 @@ type Protocol interface { // Send a SMTP command. Send(Cmd) // Receive a command(will block while waiting for it). - // Returns false if there are no more commands left. Otherwise a command will be returned. - // We need the bool because if we just return nil, the nil will also implement the empty interface... - GetCmd() (*Cmd, bool) + // Returns an error if something wen't wrong. E.g line was too long. + GetCmd() (*Cmd, error) // Close the connection. Close() } @@ -343,14 +381,14 @@ func (p *MtaProtocol) Send(c Cmd) { fmt.Fprintf(p.c, "%s\r\n", c) } -func (p *MtaProtocol) GetCmd() (c *Cmd, ok bool) { +func (p *MtaProtocol) GetCmd() (c *Cmd, err error) { cmd, err := p.parser.ParseCommand(p.br) if err != nil { log.Printf("Could not parse command: %v", err) - return nil, false + return nil, err } - return &cmd, true + return &cmd, nil } func (p *MtaProtocol) Close() {