diff --git a/services/mailer/sender/message.go b/services/mailer/sender/message.go index 55f03e4f7ec63..dd790fbf94a69 100644 --- a/services/mailer/sender/message.go +++ b/services/mailer/sender/message.go @@ -40,6 +40,11 @@ func (m *Message) ToMessage() *gomail.Msg { if m.ReplyTo != "" { msg.SetGenHeader("Reply-To", m.ReplyTo) } + if setting.MailService.OverrideEnvelopeFrom { + if err := msg.EnvelopeFrom(setting.MailService.EnvelopeFrom); err != nil { + log.Error("Failed to set Envelope-From header: %v", err) + } + } for header := range m.Headers { msg.SetGenHeader(gomail.Header(header), m.Headers[header]...) } diff --git a/services/mailer/sender/message_test.go b/services/mailer/sender/message_test.go index ae153ebf05d0c..30423abfbc646 100644 --- a/services/mailer/sender/message_test.go +++ b/services/mailer/sender/message_test.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/modules/setting" "github.com/stretchr/testify/assert" + gomail "github.com/wneessen/go-mail" ) func TestGenerateMessageID(t *testing.T) { @@ -99,6 +100,25 @@ func TestToMessage(t *testing.T) { }, header) } +func TestToMessageEnvelopeFromOverride(t *testing.T) { + oldConf := setting.MailService + defer func() { + setting.MailService = oldConf + }() + + setting.MailService = &setting.Mailer{ + From: "test@gitea.com", + OverrideEnvelopeFrom: true, + EnvelopeFrom: "bounce@gitea.com", + } + + msg := (&Message{FromAddress: "test@gitea.com", To: "user@example.com"}).ToMessage() + + envelope := msg.GetAddrHeaderString(gomail.HeaderEnvelopeFrom) + assert.Len(t, envelope, 1) + assert.Equal(t, "", envelope[0]) +} + func extractMailHeaderAndContent(t *testing.T, mail string) (map[string]string, string) { header := make(map[string]string) diff --git a/services/mailer/sender/smtp.go b/services/mailer/sender/smtp.go index 3207eee32fcec..b930c58ec4cbc 100644 --- a/services/mailer/sender/smtp.go +++ b/services/mailer/sender/smtp.go @@ -4,154 +4,222 @@ package sender import ( + "context" "crypto/tls" - "errors" "fmt" "io" - "net" "os" + "strconv" "strings" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + gomail "github.com/wneessen/go-mail" "github.com/wneessen/go-mail/smtp" ) +type gomailClient interface { + Close() error + DialAndSend(...*gomail.Msg) error + DialToSMTPClientWithContext(context.Context) (*smtp.Client, error) + CloseWithSMTPClient(*smtp.Client) error + SetSMTPAuth(gomail.SMTPAuthType) + SetSMTPAuthCustom(smtp.Auth) +} + +var ( + newGomailClient = func(host string, opts ...gomail.Option) (gomailClient, error) { return gomail.NewClient(host, opts...) } + probeSMTPServerFunc = probeSMTPServer +) + // SMTPSender Sender SMTP mail sender type SMTPSender struct{} var _ Sender = &SMTPSender{} // Send send email -func (s *SMTPSender) Send(from string, to []string, msg io.WriterTo) error { +func (s *SMTPSender) Send(_ string, _ []string, msg io.WriterTo) error { opts := setting.MailService - var network string - var address string - if opts.Protocol == "smtp+unix" { - network = "unix" - address = opts.SMTPAddr - } else { - network = "tcp" - address = net.JoinHostPort(opts.SMTPAddr, opts.SMTPPort) + mailMsg, ok := msg.(*gomail.Msg) + if !ok { + return fmt.Errorf("unexpected message type %T", msg) } - conn, err := net.Dial(network, address) - if err != nil { - return fmt.Errorf("failed to establish network connection to SMTP server: %w", err) + host := opts.SMTPAddr + protocol := opts.Protocol + if protocol == "" { + protocol = "smtp" } - defer conn.Close() - - var tlsconfig *tls.Config - if opts.Protocol == "smtps" || opts.Protocol == "smtp+starttls" { - tlsconfig = &tls.Config{ - InsecureSkipVerify: opts.ForceTrustServerCert, - ServerName: opts.SMTPAddr, - } - if opts.UseClientCert { - cert, err := tls.LoadX509KeyPair(opts.ClientCertFile, opts.ClientKeyFile) + var clientOpts []gomail.Option + if opts.EnableHelo { + helo := opts.HeloHostname + if helo == "" { + var err error + helo, err = os.Hostname() if err != nil { - return fmt.Errorf("could not load SMTP client certificate: %w", err) + return fmt.Errorf("could not retrieve system hostname: %w", err) } - tlsconfig.Certificates = []tls.Certificate{cert} } + clientOpts = append(clientOpts, gomail.WithHELO(helo)) } - if opts.Protocol == "smtps" { - conn = tls.Client(conn, tlsconfig) + authHost := opts.SMTPAddr + + switch protocol { + case "smtp+unix": + host = "unix://" + opts.SMTPAddr + clientOpts = append(clientOpts, gomail.WithTLSPolicy(gomail.NoTLS)) + case "smtps": + port, err := parseSMTPPort(opts.SMTPPort) + if err != nil { + return err + } + tlsConfig, err := buildTLSConfig(opts) + if err != nil { + return err + } + clientOpts = append(clientOpts, + gomail.WithPort(port), + gomail.WithTLSConfig(tlsConfig), + gomail.WithSSL(), + ) + case "smtp+starttls": + port, err := parseSMTPPort(opts.SMTPPort) + if err != nil { + return err + } + tlsConfig, err := buildTLSConfig(opts) + if err != nil { + return err + } + clientOpts = append(clientOpts, + gomail.WithPort(port), + gomail.WithTLSConfig(tlsConfig), + gomail.WithTLSPolicy(gomail.TLSOpportunistic), + ) + default: + port, err := parseSMTPPort(opts.SMTPPort) + if err != nil { + return err + } + clientOpts = append(clientOpts, + gomail.WithPort(port), + gomail.WithTLSPolicy(gomail.NoTLS), + ) } - host := "localhost" - if opts.Protocol == "smtp+unix" { - host = opts.SMTPAddr + if opts.User != "" { + clientOpts = append(clientOpts, + gomail.WithUsername(opts.User), + gomail.WithPassword(opts.Passwd), + ) } - client, err := smtp.NewClient(conn, host) + + client, err := newGomailClient(host, clientOpts...) if err != nil { - return fmt.Errorf("could not initiate SMTP session: %w", err) + return fmt.Errorf("could not create go-mail client: %w", err) } - - if opts.EnableHelo { - hostname := opts.HeloHostname - if len(hostname) == 0 { - hostname, err = os.Hostname() - if err != nil { - return fmt.Errorf("could not retrieve system hostname: %w", err) - } + defer func() { + if closeErr := client.Close(); closeErr != nil { + log.Error("Closing SMTP client failed: %v", closeErr) } + }() - if err = client.Hello(hostname); err != nil { - return fmt.Errorf("failed to issue HELO command: %w", err) + if opts.User != "" { + hasAuth, authOptions, hasStartTLS, probeErr := probeSMTPServerFunc(client) + if probeErr != nil { + return fmt.Errorf("failed to probe SMTP capabilities: %w", probeErr) } - } - - if opts.Protocol == "smtp+starttls" { - hasStartTLS, _ := client.Extension("STARTTLS") - if hasStartTLS { - if err = client.StartTLS(tlsconfig); err != nil { - return fmt.Errorf("failed to start TLS connection: %w", err) - } - } else { + if protocol == "smtp+starttls" && !hasStartTLS { log.Warn("StartTLS requested, but SMTP server does not support it; falling back to regular SMTP") } - } - - canAuth, options := client.Extension("AUTH") - if len(opts.User) > 0 { - if !canAuth { - return errors.New("SMTP server does not support AUTH, but credentials provided") + if !hasAuth { + return fmt.Errorf("SMTP server does not support AUTH, but credentials provided") } - var auth smtp.Auth - - if strings.Contains(options, "CRAM-MD5") { - auth = smtp.CRAMMD5Auth(opts.User, opts.Passwd) - } else if strings.Contains(options, "PLAIN") { - auth = smtp.PlainAuth("", opts.User, opts.Passwd, host, false) - } else if strings.Contains(options, "LOGIN") { - // Patch for AUTH LOGIN - auth = LoginAuth(opts.User, opts.Passwd) - } else if strings.Contains(options, "NTLM") { - auth = NtlmAuth(opts.User, opts.Passwd) + authOptions = strings.ToUpper(authOptions) + var selectedAuth smtp.Auth + switch { + case strings.Contains(authOptions, "CRAM-MD5"): + selectedAuth = smtp.CRAMMD5Auth(opts.User, opts.Passwd) + case strings.Contains(authOptions, "PLAIN"): + selectedAuth = smtp.PlainAuth("", opts.User, opts.Passwd, authHost, false) + case strings.Contains(authOptions, "LOGIN"): + selectedAuth = LoginAuth(opts.User, opts.Passwd) + case strings.Contains(authOptions, "NTLM"): + selectedAuth = NtlmAuth(opts.User, opts.Passwd) } - if auth != nil { - if err = client.Auth(auth); err != nil { - return fmt.Errorf("failed to authenticate SMTP: %w", err) - } + if selectedAuth != nil { + client.SetSMTPAuthCustom(selectedAuth) + } else if supportsAutoDiscover(authOptions) { + client.SetSMTPAuth(gomail.SMTPAuthAutoDiscover) } } - if opts.OverrideEnvelopeFrom { - if err = client.Mail(opts.EnvelopeFrom); err != nil { - return fmt.Errorf("failed to issue MAIL command: %w", err) - } - } else { - if err = client.Mail(fmt.Sprintf("<%s>", from)); err != nil { - return fmt.Errorf("failed to issue MAIL command: %w", err) - } + if err := client.DialAndSend(mailMsg); err != nil { + return fmt.Errorf("failed to send message via SMTP: %w", err) } - for _, rec := range to { - if err = client.Rcpt(rec); err != nil { - return fmt.Errorf("failed to issue RCPT command: %w", err) - } - } + return nil +} - w, err := client.Data() +func parseSMTPPort(port string) (int, error) { + if port == "" { + return 0, fmt.Errorf("SMTP port is not configured") + } + portNum, err := strconv.Atoi(port) if err != nil { - return fmt.Errorf("failed to issue DATA command: %w", err) - } else if _, err = msg.WriteTo(w); err != nil { - return fmt.Errorf("SMTP write failed: %w", err) - } else if err = w.Close(); err != nil { - return fmt.Errorf("SMTP close failed: %w", err) + return 0, fmt.Errorf("invalid SMTP port %q: %w", port, err) } + return portNum, nil +} - err = client.Quit() +func buildTLSConfig(opts *setting.Mailer) (*tls.Config, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: opts.ForceTrustServerCert, + ServerName: opts.SMTPAddr, + } + if opts.UseClientCert { + cert, err := tls.LoadX509KeyPair(opts.ClientCertFile, opts.ClientKeyFile) + if err != nil { + return nil, fmt.Errorf("could not load SMTP client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + return tlsConfig, nil +} + +func probeSMTPServer(client gomailClient) (bool, string, bool, error) { + smtpClient, err := client.DialToSMTPClientWithContext(context.Background()) if err != nil { - log.Error("Quit client failed: %v", err) + return false, "", false, err } + defer func() { + if closeErr := client.CloseWithSMTPClient(smtpClient); closeErr != nil { + log.Debug("Closing SMTP probe client failed: %v", closeErr) + } + }() - return nil + hasStartTLS, _ := smtpClient.Extension("STARTTLS") + hasAuth, authOptions := smtpClient.Extension("AUTH") + return hasAuth, authOptions, hasStartTLS, nil +} + +func supportsAutoDiscover(options string) bool { + for _, mech := range []string{ + "SCRAM-SHA-256-PLUS", + "SCRAM-SHA-256", + "SCRAM-SHA-1-PLUS", + "SCRAM-SHA-1", + "XOAUTH2", + } { + if strings.Contains(options, mech) { + return true + } + } + return false } diff --git a/services/mailer/sender/smtp_test.go b/services/mailer/sender/smtp_test.go new file mode 100644 index 0000000000000..6aabe2047baaa --- /dev/null +++ b/services/mailer/sender/smtp_test.go @@ -0,0 +1,157 @@ +package sender + +import ( + "context" + "io" + "testing" + + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" + gomail "github.com/wneessen/go-mail" + "github.com/wneessen/go-mail/smtp" +) + +type writerToFunc func(io.Writer) (int64, error) + +func (f writerToFunc) WriteTo(w io.Writer) (int64, error) { + return f(w) +} + +type fakeGomailClient struct { + dialAndSendCalled bool + dialAndSendErr error + closeCalled bool + probeCalled bool + setAuthCalled bool + authType gomail.SMTPAuthType + setCustomCalled bool + probeErr error +} + +func (f *fakeGomailClient) Close() error { + f.closeCalled = true + return nil +} + +func (f *fakeGomailClient) DialAndSend(_ ...*gomail.Msg) error { + f.dialAndSendCalled = true + return f.dialAndSendErr +} + +func (f *fakeGomailClient) DialToSMTPClientWithContext(context.Context) (*smtp.Client, error) { + f.probeCalled = true + return &smtp.Client{}, f.probeErr +} + +func (f *fakeGomailClient) CloseWithSMTPClient(*smtp.Client) error { + return nil +} + +func (f *fakeGomailClient) SetSMTPAuth(auth gomail.SMTPAuthType) { + f.setAuthCalled = true + f.authType = auth +} + +func (f *fakeGomailClient) SetSMTPAuthCustom(smtp.Auth) { + f.setCustomCalled = true +} + +func TestSMTPSenderRejectsNonGomailMessage(t *testing.T) { + oldConf := setting.MailService + t.Cleanup(func() { + setting.MailService = oldConf + }) + setting.MailService = &setting.Mailer{ + Protocol: "smtp", + SMTPAddr: "localhost", + SMTPPort: "25", + } + + err := new(SMTPSender).Send("", nil, writerToFunc(func(io.Writer) (int64, error) { + return 0, nil + })) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "unexpected message type") +} + +func TestSMTPSenderAuthUnsupported(t *testing.T) { + fakeClient := &fakeGomailClient{} + overrideClient := func(host string, opts ...gomail.Option) (gomailClient, error) { + return fakeClient, nil + } + oldClientFactory := newGomailClient + oldProbe := probeSMTPServerFunc + t.Cleanup(func() { + newGomailClient = oldClientFactory + probeSMTPServerFunc = oldProbe + }) + newGomailClient = overrideClient + probeSMTPServerFunc = func(gomailClient) (bool, string, bool, error) { + return false, "", false, nil + } + + oldConf := setting.MailService + t.Cleanup(func() { + setting.MailService = oldConf + }) + setting.MailService = &setting.Mailer{ + Protocol: "smtp", + SMTPAddr: "smtp.example.com", + SMTPPort: "25", + User: "user", + Passwd: "pass", + } + + msg := gomail.NewMsg() + msg.SetBodyString("text/plain", "body") + + err := new(SMTPSender).Send("", nil, msg) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not support AUTH") + assert.False(t, fakeClient.dialAndSendCalled) + assert.True(t, fakeClient.closeCalled) +} + +func TestSMTPSenderAuthAutoDiscover(t *testing.T) { + fakeClient := &fakeGomailClient{} + overrideClient := func(host string, opts ...gomail.Option) (gomailClient, error) { + return fakeClient, nil + } + oldClientFactory := newGomailClient + oldProbe := probeSMTPServerFunc + t.Cleanup(func() { + newGomailClient = oldClientFactory + probeSMTPServerFunc = oldProbe + }) + newGomailClient = overrideClient + probeSMTPServerFunc = func(gomailClient) (bool, string, bool, error) { + return true, "SCRAM-SHA-256 XOAUTH2", true, nil + } + + oldConf := setting.MailService + t.Cleanup(func() { + setting.MailService = oldConf + }) + setting.MailService = &setting.Mailer{ + Protocol: "smtp+starttls", + SMTPAddr: "smtp.example.com", + SMTPPort: "587", + User: "user", + Passwd: "pass", + } + + msg := gomail.NewMsg() + msg.SetBodyString("text/plain", "body") + + err := new(SMTPSender).Send("", nil, msg) + + assert.NoError(t, err) + assert.True(t, fakeClient.setAuthCalled) + assert.Equal(t, gomail.SMTPAuthAutoDiscover, fakeClient.authType) + assert.False(t, fakeClient.setCustomCalled) + assert.True(t, fakeClient.dialAndSendCalled) + assert.True(t, fakeClient.closeCalled) +}