diff --git a/pkg/common/webclient/client.go b/pkg/common/webclient/client.go new file mode 100644 index 00000000..5369c8e0 --- /dev/null +++ b/pkg/common/webclient/client.go @@ -0,0 +1,107 @@ +package webclient + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" +) + +// client is a wrapper around http.Client with common notification service functionality +type client struct { + headers http.Header + indent string + httpClient http.Client + parse ParserFunc + write WriterFunc +} + +// SetParser overrides the parser for the incoming response content +func (c *client) SetParser(parse ParserFunc) { + c.parse = parse +} + +// SetWriter overrides the writer for the outgoing request content +func (c *client) SetWriter(write WriterFunc) { + c.write = write +} + +// Headers return the default headers for requests +func (c *client) Headers() http.Header { + return c.headers +} + +// HTTPClient returns the underlying http.WebClient used by the WebClient +func (c *client) HTTPClient() *http.Client { + return &c.httpClient +} + +// Get fetches url using GET and unmarshals into the passed response +func (c *client) Get(url string, response interface{}) error { + return c.request(http.MethodGet, url, response, nil) +} + +// Post sends a serialized representation of request and deserializes the result into response +func (c *client) Post(url string, request interface{}, response interface{}) error { + body, err := c.write(request) + if err != nil { + return fmt.Errorf("error creating payload: %v", err) + } + + return c.request(http.MethodPost, url, response, bytes.NewReader(body)) +} + +// ErrorResponse tries to deserialize any response body into the supplied struct, returning whether successful or not +func (c *client) ErrorResponse(err error, response interface{}) bool { + jerr, isWebError := err.(ClientError) + if !isWebError { + return false + } + + return c.parse([]byte(jerr.Body), response) == nil +} + +func (c *client) request(method, url string, response interface{}, payload io.Reader) error { + req, err := http.NewRequest(method, url, payload) + if err != nil { + return err + } + + for key, val := range c.headers { + req.Header.Set(key, val[0]) + } + + res, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("error sending payload: %v", err) + } + + return c.parseResponse(res, response) +} + +func (c *client) parseResponse(res *http.Response, response interface{}) error { + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + + if res.StatusCode >= 400 { + err = fmt.Errorf("got HTTP %v", res.Status) + } + + if err == nil { + err = c.parse(body, response) + } + + if err != nil { + if body == nil { + body = []byte{} + } + return ClientError{ + StatusCode: res.StatusCode, + Body: string(body), + err: err, + } + } + + return nil +} diff --git a/pkg/util/jsonclient/jsonclient_test.go b/pkg/common/webclient/client_test.go similarity index 57% rename from pkg/util/jsonclient/jsonclient_test.go rename to pkg/common/webclient/client_test.go index 23dbf8d9..2dc27f82 100644 --- a/pkg/util/jsonclient/jsonclient_test.go +++ b/pkg/common/webclient/client_test.go @@ -1,22 +1,16 @@ -package jsonclient_test +package webclient_test import ( "errors" - "github.com/containrrr/shoutrrr/pkg/util/jsonclient" + "github.com/containrrr/shoutrrr/pkg/common/webclient" "github.com/onsi/gomega/ghttp" "net/http" - "testing" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -func TestJSONClient(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "JSONClient Suite") -} - -var _ = Describe("JSONClient", func() { +var _ = Describe("WebClient", func() { var server *ghttp.Server BeforeEach(func() { @@ -27,7 +21,7 @@ var _ = Describe("JSONClient", func() { It("should return an error", func() { server.AppendHandlers(ghttp.RespondWith(http.StatusOK, "invalid json")) res := &mockResponse{} - err := jsonclient.Get(server.URL(), &res) + err := webclient.GetJSON(server.URL(), &res) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).To(MatchError("invalid character 'i' looking for beginning of value")) Expect(res.Status).To(BeEmpty()) @@ -38,7 +32,7 @@ var _ = Describe("JSONClient", func() { It("should return an error", func() { server.AppendHandlers(ghttp.RespondWith(http.StatusOK, nil)) res := &mockResponse{} - err := jsonclient.Get(server.URL(), &res) + err := webclient.GetJSON(server.URL(), &res) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).To(MatchError("unexpected end of JSON input")) Expect(res.Status).To(BeEmpty()) @@ -48,7 +42,47 @@ var _ = Describe("JSONClient", func() { It("should deserialize GET response", func() { server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"})) res := &mockResponse{} - err := jsonclient.Get(server.URL(), &res) + err := webclient.GetJSON(server.URL(), &res) + Expect(server.ReceivedRequests()).Should(HaveLen(1)) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Status).To(Equal("OK")) + }) + + It("should update the parser and writer", func() { + client := webclient.NewJSONClient() + client.SetParser(func(raw []byte, v interface{}) error { + return errors.New(`mock parser`) + }) + server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"})) + err := client.Get(server.URL(), nil) + Expect(err).To(MatchError(`mock parser`)) + + client.SetWriter(func(v interface{}) ([]byte, error) { + return nil, errors.New(`mock writer`) + }) + err = client.Post(server.URL(), nil, nil) + Expect(err).To(MatchError(`error creating payload: mock writer`)) + }) + + It("should unwrap serialized error responses", func() { + client := webclient.NewJSONClient() + err := webclient.ClientError{Body: `{"Status": "BadStuff"}`} + res := &mockResponse{} + Expect(client.ErrorResponse(err, res)).To(BeTrue()) + Expect(res.Status).To(Equal(`BadStuff`)) + }) + + It("should send any additional headers that has been added", func() { + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyHeaderKV(`Authentication`, `you don't need to see my identification`), + ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"}), + ), + ) + client := webclient.NewJSONClient() + client.Headers().Set(`Authentication`, `you don't need to see my identification`) + res := &mockResponse{} + err := client.Get(server.URL(), &res) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).ToNot(HaveOccurred()) Expect(res.Status).To(Equal("OK")) @@ -66,7 +100,7 @@ var _ = Describe("JSONClient", func() { ghttp.RespondWithJSONEncoded(http.StatusOK, &mockResponse{Status: "That's Numberwang!"})), ) - err := jsonclient.Post(server.URL(), &req, &res) + err := webclient.PostJSON(server.URL(), &req, &res) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).ToNot(HaveOccurred()) Expect(res.Status).To(Equal("That's Numberwang!")) @@ -74,14 +108,14 @@ var _ = Describe("JSONClient", func() { It("should return error on error status responses", func() { server.AppendHandlers(ghttp.RespondWith(404, "Not found!")) - err := jsonclient.Post(server.URL(), &mockRequest{}, &mockResponse{}) + err := webclient.PostJSON(server.URL(), &mockRequest{}, &mockResponse{}) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).To(MatchError("got HTTP 404 Not Found")) }) It("should return error on invalid request", func() { server.AppendHandlers(ghttp.VerifyRequest("POST", "/")) - err := jsonclient.Post(server.URL(), func() {}, &mockResponse{}) + err := webclient.PostJSON(server.URL(), func() {}, &mockResponse{}) Expect(server.ReceivedRequests()).Should(HaveLen(0)) Expect(err).To(MatchError("error creating payload: json: unsupported type: func()")) }) @@ -93,10 +127,10 @@ var _ = Describe("JSONClient", func() { ghttp.RespondWithJSONEncoded(http.StatusOK, res)), ) - err := jsonclient.Post(server.URL(), nil, &[]bool{}) + err := webclient.PostJSON(server.URL(), nil, &[]bool{}) Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).To(MatchError("json: cannot unmarshal object into Go value of type []bool")) - Expect(jsonclient.ErrorBody(err)).To(MatchJSON(`{"Status":"cool skirt"}`)) + Expect(webclient.ErrorBody(err)).To(MatchJSON(`{"Status":"cool skirt"}`)) }) }) @@ -106,24 +140,24 @@ var _ = Describe("JSONClient", func() { }) }) -var _ = Describe("Error", func() { +var _ = Describe("ClientError", func() { When("no internal error has been set", func() { It("should return a generic message with status code", func() { - errorWithNoError := jsonclient.Error{StatusCode: http.StatusEarlyHints} + errorWithNoError := webclient.ClientError{StatusCode: http.StatusEarlyHints} Expect(errorWithNoError.String()).To(Equal("unknown error (HTTP 103)")) }) }) Describe("ErrorBody", func() { When("passed a non-json error", func() { It("should return an empty string", func() { - Expect(jsonclient.ErrorBody(errors.New("unrelated error"))).To(BeEmpty()) + Expect(webclient.ErrorBody(errors.New("unrelated error"))).To(BeEmpty()) }) }) - When("passed a jsonclient.Error", func() { + When("passed a jsonclient.ClientError", func() { It("should return the request body from that error", func() { errorBody := `{"error": "bad user"}` - jsonError := jsonclient.Error{Body: errorBody} - Expect(jsonclient.ErrorBody(jsonError)).To(MatchJSON(errorBody)) + jsonError := webclient.ClientError{Body: errorBody} + Expect(webclient.ErrorBody(jsonError)).To(MatchJSON(errorBody)) }) }) }) diff --git a/pkg/util/jsonclient/error.go b/pkg/common/webclient/error.go similarity index 50% rename from pkg/util/jsonclient/error.go rename to pkg/common/webclient/error.go index 7317f8ae..56dc5581 100644 --- a/pkg/util/jsonclient/error.go +++ b/pkg/common/webclient/error.go @@ -1,28 +1,28 @@ -package jsonclient +package webclient import "fmt" -// Error contains additional http/JSON details -type Error struct { +// ClientError contains additional http/JSON details +type ClientError struct { StatusCode int Body string err error } -func (je Error) Error() string { +func (je ClientError) Error() string { return je.String() } -func (je Error) String() string { +func (je ClientError) String() string { if je.err == nil { return fmt.Sprintf("unknown error (HTTP %v)", je.StatusCode) } return je.err.Error() } -// ErrorBody returns the request body from an Error +// ErrorBody returns the request body from a ClientError func ErrorBody(e error) string { - if jsonError, ok := e.(Error); ok { + if jsonError, ok := e.(ClientError); ok { return jsonError.Body } return "" diff --git a/pkg/util/jsonclient/interface.go b/pkg/common/webclient/interface.go similarity index 54% rename from pkg/util/jsonclient/interface.go rename to pkg/common/webclient/interface.go index 7316026c..86a95810 100644 --- a/pkg/util/jsonclient/interface.go +++ b/pkg/common/webclient/interface.go @@ -1,10 +1,16 @@ -package jsonclient +package webclient -import "net/http" +import ( + "net/http" +) -type Client interface { +// WebClient ... +type WebClient interface { Get(url string, response interface{}) error Post(url string, request interface{}, response interface{}) error Headers() http.Header ErrorResponse(err error, response interface{}) bool + SetParser(ParserFunc) + SetWriter(WriterFunc) + HTTPClient() *http.Client } diff --git a/pkg/common/webclient/json.go b/pkg/common/webclient/json.go new file mode 100644 index 00000000..699405b9 --- /dev/null +++ b/pkg/common/webclient/json.go @@ -0,0 +1,37 @@ +package webclient + +import ( + "encoding/json" + "net/http" +) + +// JSONContentType is the default mime type for JSON +const JSONContentType = "application/json" + +// DefaultJSONClient is the singleton instance of WebClient using http.DefaultClient +var DefaultJSONClient = NewJSONClient() + +// GetJSON fetches url using GET and unmarshals into the passed response using DefaultJSONClient +func GetJSON(url string, response interface{}) error { + return DefaultJSONClient.Get(url, response) +} + +// PostJSON sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultJSONClient +func PostJSON(url string, request interface{}, response interface{}) error { + return DefaultJSONClient.Post(url, request, response) +} + +// NewJSONClient returns a WebClient using the default http.Client and JSON serialization +func NewJSONClient() WebClient { + var c client + c = client{ + headers: http.Header{ + "Content-Type": []string{JSONContentType}, + }, + parse: json.Unmarshal, + write: func(v interface{}) ([]byte, error) { + return json.MarshalIndent(v, "", c.indent) + }, + } + return &c +} diff --git a/pkg/common/webclient/service.go b/pkg/common/webclient/service.go new file mode 100644 index 00000000..86dd3f19 --- /dev/null +++ b/pkg/common/webclient/service.go @@ -0,0 +1,75 @@ +package webclient + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "github.com/containrrr/shoutrrr/pkg/types" + "net/http" +) + +// ParserFunc are functions that deserialize a struct from the passed bytes +type ParserFunc func(raw []byte, v interface{}) error + +// WriterFunc are functions that serialize the passed struct into a byte stream +type WriterFunc func(v interface{}) ([]byte, error) + +var _ types.TLSClient = &ClientService{} +var _ types.HTTPService = &ClientService{} + +// ClientService is a Composable that adds a generic web request client to the service +type ClientService struct { + client *client + certPool *x509.CertPool +} + +// HTTPClient returns the underlying http.WebClient used in the Service +func (s *ClientService) HTTPClient() *http.Client { + s.Initialize() + return s.client.HTTPClient() +} + +// WebClient returns the WebClient instance, initializing it if necessary +func (s *ClientService) WebClient() WebClient { + s.Initialize() + return s.client +} + +// Initialize sets up the WebClient in the default state using JSON serialization and headers +func (s *ClientService) Initialize() { + if s.client != nil { + return + } + + s.client = &client{ + httpClient: http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{}, + }, + }, + headers: http.Header{ + "Content-Type": []string{JSONContentType}, + }, + parse: json.Unmarshal, + write: func(v interface{}) ([]byte, error) { + return json.MarshalIndent(v, "", s.client.indent) + }, + } +} + +// AddTrustedRootCertificate adds the specified PEM certificate to the pool of trusted root CAs +func (s *ClientService) AddTrustedRootCertificate(caPEM []byte) bool { + s.Initialize() + if s.certPool == nil { + certPool, err := x509.SystemCertPool() + if err != nil { + certPool = x509.NewCertPool() + } + s.certPool = certPool + if tp, ok := s.client.httpClient.Transport.(*http.Transport); ok { + tp.TLSClientConfig.RootCAs = s.certPool + } + } + + return s.certPool.AppendCertsFromPEM(caPEM) +} diff --git a/pkg/common/webclient/service_test.go b/pkg/common/webclient/service_test.go new file mode 100644 index 00000000..59c0f143 --- /dev/null +++ b/pkg/common/webclient/service_test.go @@ -0,0 +1,47 @@ +package webclient_test + +import ( + "github.com/containrrr/shoutrrr/pkg/common/webclient" + "net/http" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ClientService", func() { + + When("getting the web client from an empty service", func() { + It("should return an initialized web client", func() { + service := &webclient.ClientService{} + Expect(service.WebClient()).ToNot(BeNil()) + }) + }) + + When("getting the http client from an empty service", func() { + It("should return an initialized http client", func() { + service := &webclient.ClientService{} + Expect(service.HTTPClient()).ToNot(BeNil()) + }) + }) + + When("no certs have been added", func() { + It("should use nil as the certificate pool", func() { + service := &webclient.ClientService{} + tp := service.HTTPClient().Transport.(*http.Transport) + Expect(tp.TLSClientConfig.RootCAs).To(BeNil()) + }) + }) + + When("a custom cert have been added", func() { + It("should use a custom certificate pool", func() { + service := &webclient.ClientService{} + + // Adding an empty cert should fail + addedOk := service.AddTrustedRootCertificate([]byte{}) + Expect(addedOk).To(BeFalse()) + + tp := service.HTTPClient().Transport.(*http.Transport) + Expect(tp.TLSClientConfig.RootCAs).ToNot(BeNil()) + }) + }) +}) diff --git a/pkg/common/webclient/webclient_suite_test.go b/pkg/common/webclient/webclient_suite_test.go new file mode 100644 index 00000000..c8a859bf --- /dev/null +++ b/pkg/common/webclient/webclient_suite_test.go @@ -0,0 +1,13 @@ +package webclient_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestWebClient(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "WebClient Suite") +} diff --git a/pkg/services/discord/discord.go b/pkg/services/discord/discord.go index ae1bc6cc..e0e34b99 100644 --- a/pkg/services/discord/discord.go +++ b/pkg/services/discord/discord.go @@ -94,11 +94,9 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e return err } - if err := service.config.SetURL(configURL); err != nil { - return err - } + err := service.config.SetURL(configURL) - return nil + return err } // CreateAPIURLFromConfig takes a discord config object and creates a post url diff --git a/pkg/services/gotify/gotify.go b/pkg/services/gotify/gotify.go index 1a2024e6..8682b4ac 100644 --- a/pkg/services/gotify/gotify.go +++ b/pkg/services/gotify/gotify.go @@ -1,15 +1,13 @@ package gotify import ( - "bytes" - "crypto/tls" - "encoding/json" "fmt" "net/http" "net/url" "strings" "time" + "github.com/containrrr/shoutrrr/pkg/common/webclient" "github.com/containrrr/shoutrrr/pkg/format" "github.com/containrrr/shoutrrr/pkg/services/standard" "github.com/containrrr/shoutrrr/pkg/types" @@ -18,9 +16,9 @@ import ( // Service providing Gotify as a notification service type Service struct { standard.Standard + webclient.ClientService config *Config pkr format.PropKeyResolver - Client *http.Client } // Initialize loads ServiceConfig from configURL and sets logger for this Service @@ -32,19 +30,19 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e service.pkr = format.NewPropKeyResolver(service.config) err := service.config.SetURL(configURL) - service.Client = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - // If DisableTLS is specified, we might still need to disable TLS verification - // since the default configuration of Gotify redirects HTTP to HTTPS - // Note that this cannot be overridden using params, only using the config URL - InsecureSkipVerify: service.config.DisableTLS, - }, - }, - // Set a reasonable timeout to prevent one bad transfer from block all subsequent ones - Timeout: 10 * time.Second, + client := service.HTTPClient() + if service.config.DisableTLS { + // If DisableTLS is specified, we might still need to disable TLS verification + // since the default configuration of Gotify redirects HTTP to HTTPS + // Note that this cannot be overridden using params, only using the config URL + if tp, ok := client.Transport.(*http.Transport); ok { + tp.TLSClientConfig.InsecureSkipVerify = true + } } + // Set a reasonable timeout to prevent one bad transfer from blocking all subsequent ones + client.Timeout = 10 * time.Second + return err } @@ -96,24 +94,22 @@ func (service *Service) Send(message string, params *types.Params) error { if err != nil { return err } - jsonBody, err := json.Marshal(JSON{ + + res := JSON{} + req := JSON{ Message: message, Title: config.Title, Priority: config.Priority, - }) - if err != nil { - return err } - jsonBuffer := bytes.NewBuffer(jsonBody) - resp, err := service.Client.Post(postURL, "application/json", jsonBuffer) + + err = service.WebClient().Post(postURL, &req, &res) if err != nil { + errorRes := errorResponse{} + if service.WebClient().ErrorResponse(err, errorRes) { + err = errorRes + } return fmt.Errorf("failed to send notification to Gotify: %s", err) } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("Gotify notification returned %d HTTP status code", resp.StatusCode) - } return nil } diff --git a/pkg/services/gotify/gotify_json.go b/pkg/services/gotify/gotify_json.go index 4a916361..45c40af1 100644 --- a/pkg/services/gotify/gotify_json.go +++ b/pkg/services/gotify/gotify_json.go @@ -6,3 +6,13 @@ type JSON struct { Title string `json:"title"` Priority int `json:"priority"` } + +type errorResponse struct { + HTTPError string `json:"error"` + HTTPErrorCode string `json:"errorCode"` + Description string `json:"errorDescription"` +} + +func (e errorResponse) Error() string { + return e.Description +} diff --git a/pkg/services/gotify/gotify_test.go b/pkg/services/gotify/gotify_test.go index 262c244c..fe9d6def 100644 --- a/pkg/services/gotify/gotify_test.go +++ b/pkg/services/gotify/gotify_test.go @@ -119,11 +119,11 @@ var _ = Describe("the Gotify plugin URL building and token validation functions" It("should not report an error if the server accepts the payload", func() { serviceURL, _ := url.Parse("gotify://my.gotify.tld/Aaa.bbb.ccc.ddd") err = service.Initialize(serviceURL, logger) - httpmock.ActivateNonDefault(service.Client) + httpmock.ActivateNonDefault(service.HTTPClient()) Expect(err).NotTo(HaveOccurred()) targetURL := "https://my.gotify.tld/message?token=Aaa.bbb.ccc.ddd" - httpmock.RegisterResponder("POST", targetURL, httpmock.NewStringResponder(200, "")) + httpmock.RegisterResponder("POST", targetURL, httpmock.NewStringResponder(200, `{"id":"1"}`)) err = service.Send("Message", nil) Expect(err).NotTo(HaveOccurred()) @@ -131,7 +131,7 @@ var _ = Describe("the Gotify plugin URL building and token validation functions" It("should not panic if an error occurs when sending the payload", func() { serviceURL, _ := url.Parse("gotify://my.gotify.tld/Aaa.bbb.ccc.ddd") err = service.Initialize(serviceURL, logger) - httpmock.ActivateNonDefault(service.Client) + httpmock.ActivateNonDefault(service.HTTPClient()) Expect(err).NotTo(HaveOccurred()) targetURL := "https://my.gotify.tld/message?token=Aaa.bbb.ccc.ddd" diff --git a/pkg/services/pushbullet/pushbullet.go b/pkg/services/pushbullet/pushbullet.go index 0f2062b2..84ef4d9d 100644 --- a/pkg/services/pushbullet/pushbullet.go +++ b/pkg/services/pushbullet/pushbullet.go @@ -2,10 +2,10 @@ package pushbullet import ( "fmt" + "github.com/containrrr/shoutrrr/pkg/common/webclient" "github.com/containrrr/shoutrrr/pkg/format" "github.com/containrrr/shoutrrr/pkg/services/standard" "github.com/containrrr/shoutrrr/pkg/types" - "github.com/containrrr/shoutrrr/pkg/util/jsonclient" "net/url" ) @@ -16,7 +16,7 @@ const ( // Service providing Pushbullet as a notification service type Service struct { standard.Standard - client jsonclient.Client + webclient.ClientService config *Config pkr format.PropKeyResolver } @@ -31,8 +31,7 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e return err } - service.client = jsonclient.NewClient() - service.client.Headers().Set("Access-Token", service.config.Token) + service.WebClient().Headers().Set("Access-Token", service.config.Token) return nil } @@ -45,14 +44,14 @@ func (service *Service) Send(message string, params *types.Params) error { } for _, target := range config.Targets { - if err := doSend(&config, target, message, service.client); err != nil { + if err := doSend(&config, target, message, service.WebClient()); err != nil { return err } } return nil } -func doSend(config *Config, target string, message string, client jsonclient.Client) error { +func doSend(config *Config, target string, message string, client webclient.WebClient) error { push := NewNotePush(message, config.Title) push.SetTarget(target) diff --git a/pkg/services/pushbullet/pushbullet_test.go b/pkg/services/pushbullet/pushbullet_test.go index 03c68529..ef01f053 100644 --- a/pkg/services/pushbullet/pushbullet_test.go +++ b/pkg/services/pushbullet/pushbullet_test.go @@ -1,13 +1,14 @@ package pushbullet_test import ( - "errors" - . "github.com/containrrr/shoutrrr/pkg/services/pushbullet" + "github.com/containrrr/shoutrrr/pkg/services/pushbullet" "github.com/containrrr/shoutrrr/pkg/util" "github.com/jarcoal/httpmock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + + "errors" "net/url" "os" "testing" @@ -19,14 +20,14 @@ func TestPushbullet(t *testing.T) { } var ( - service *Service + service *pushbullet.Service envPushbulletURL *url.URL ) var _ = Describe("the pushbullet service", func() { BeforeSuite(func() { - service = &Service{} + service = &pushbullet.Service{} envPushbulletURL, _ = url.Parse(os.Getenv("SHOUTRRR_PUSHBULLET_URL")) }) @@ -49,7 +50,7 @@ var _ = Describe("the pushbullet service", func() { When("generating a config object", func() { It("should set token", func() { pushbulletURL, _ := url.Parse("pushbullet://tokentokentokentokentokentokentoke") - config := Config{} + config := pushbullet.Config{} err := config.SetURL(pushbulletURL) Expect(config.Token).To(Equal("tokentokentokentokentokentokentoke")) @@ -57,7 +58,7 @@ var _ = Describe("the pushbullet service", func() { }) It("should set the device from path", func() { pushbulletURL, _ := url.Parse("pushbullet://tokentokentokentokentokentokentoke/test") - config := Config{} + config := pushbullet.Config{} err := config.SetURL(pushbulletURL) Expect(err).NotTo(HaveOccurred()) @@ -66,7 +67,7 @@ var _ = Describe("the pushbullet service", func() { }) It("should set the channel from path", func() { pushbulletURL, _ := url.Parse("pushbullet://tokentokentokentokentokentokentoke/foo#bar") - config := Config{} + config := pushbullet.Config{} err := config.SetURL(pushbulletURL) Expect(err).NotTo(HaveOccurred()) @@ -79,7 +80,7 @@ var _ = Describe("the pushbullet service", func() { It("should be identical after de-/serialization", func() { testURL := "pushbullet://tokentokentokentokentokentokentoke/device?title=Great+News" - config := &Config{} + config := &pushbullet.Config{} err := config.SetURL(util.URLMust(testURL)) Expect(err).NotTo(HaveOccurred(), "verifying") @@ -92,21 +93,21 @@ var _ = Describe("the pushbullet service", func() { Describe("building the payload", func() { It("Email target should only populate one the correct field", func() { - push := PushRequest{} + push := pushbullet.PushRequest{} push.SetTarget("iam@email.com") Expect(push.Email).To(Equal("iam@email.com")) Expect(push.DeviceIden).To(BeEmpty()) Expect(push.ChannelTag).To(BeEmpty()) }) It("Device target should only populate one the correct field", func() { - push := PushRequest{} + push := pushbullet.PushRequest{} push.SetTarget("device") Expect(push.Email).To(BeEmpty()) Expect(push.DeviceIden).To(Equal("device")) Expect(push.ChannelTag).To(BeEmpty()) }) It("Channel target should only populate one the correct field", func() { - push := PushRequest{} + push := pushbullet.PushRequest{} push.SetTarget("#channel") Expect(push.Email).To(BeEmpty()) Expect(push.DeviceIden).To(BeEmpty()) @@ -118,7 +119,7 @@ var _ = Describe("the pushbullet service", func() { var err error targetURL := "https://api.pushbullet.com/v2/pushes" BeforeEach(func() { - httpmock.Activate() + httpmock.ActivateNonDefault(service.HTTPClient()) }) AfterEach(func() { httpmock.DeactivateAndReset() @@ -127,7 +128,7 @@ var _ = Describe("the pushbullet service", func() { err = initService("pushbullet://tokentokentokentokentokentokentoke/test") Expect(err).NotTo(HaveOccurred()) - responder, _ := httpmock.NewJsonResponder(200, &PushResponse{}) + responder, _ := httpmock.NewJsonResponder(200, &pushbullet.PushResponse{}) httpmock.RegisterResponder("POST", targetURL, responder) err = service.Send("Message", nil) diff --git a/pkg/services/services_test.go b/pkg/services/services_test.go index 2add3172..5f93a91e 100644 --- a/pkg/services/services_test.go +++ b/pkg/services/services_test.go @@ -2,7 +2,6 @@ package services_test import ( "github.com/containrrr/shoutrrr/pkg/router" - "github.com/containrrr/shoutrrr/pkg/services/gotify" "github.com/containrrr/shoutrrr/pkg/types" "github.com/jarcoal/httpmock" "log" @@ -41,6 +40,8 @@ var serviceURLs = map[string]string{ var serviceResponses = map[string]string{ "pushbullet": `{"created": 0}`, + "telegram": `{"ok": true}`, + "gotify": `{"id": 0}`, } var logger = log.New(GinkgoWriter, "Test", log.LstdFlags) @@ -78,6 +79,7 @@ var _ = Describe("services", func() { } httpmock.Activate() + // Always return an "OK" result, as the http request isn't what is under test respStatus := http.StatusOK if key == "discord" || key == "ifttt" { @@ -88,9 +90,8 @@ var _ = Describe("services", func() { service, err := serviceRouter.Locate(configURL) Expect(err).NotTo(HaveOccurred()) - if key == "gotify" { - gotifyService := service.(*gotify.Service) - httpmock.ActivateNonDefault(gotifyService.Client) + if httpService, isHTTPService := service.(types.HTTPService); isHTTPService { + httpmock.ActivateNonDefault(httpService.HTTPClient()) } err = service.Send("test", (*types.Params)(&map[string]string{ diff --git a/pkg/services/slack/slack.go b/pkg/services/slack/slack.go index 926cac3c..6c649744 100644 --- a/pkg/services/slack/slack.go +++ b/pkg/services/slack/slack.go @@ -1,15 +1,11 @@ package slack import ( - "bytes" - "encoding/json" "fmt" - "github.com/containrrr/shoutrrr/pkg/format" - "github.com/containrrr/shoutrrr/pkg/util/jsonclient" - "io/ioutil" - "net/http" "net/url" + "github.com/containrrr/shoutrrr/pkg/common/webclient" + "github.com/containrrr/shoutrrr/pkg/format" "github.com/containrrr/shoutrrr/pkg/services/standard" "github.com/containrrr/shoutrrr/pkg/types" ) @@ -17,6 +13,7 @@ import ( // Service sends notifications to a pre-configured channel or user type Service struct { standard.Standard + webclient.ClientService config *Config pkr format.PropKeyResolver } @@ -55,16 +52,25 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e service.config = &Config{} service.pkr = format.NewPropKeyResolver(service.config) - return service.config.setURL(&service.pkr, configURL) + if err := service.config.setURL(&service.pkr, configURL); err != nil { + return err + } + + client := service.WebClient() + if service.config.Token.IsAPIToken() { + client.Headers().Set("Authorization", service.config.Token.Authorization()) + } else { + client.SetParser(parseWebhookResponse) + } + + return nil } func (service *Service) sendAPI(config *Config, payload interface{}) error { response := APIResponse{} - jsonClient := jsonclient.NewClient() - jsonClient.Headers().Set("Authorization", config.Token.Authorization()) - if err := jsonClient.Post(apiPostMessage, payload, &response); err != nil { + if err := service.WebClient().Post(apiPostMessage, payload, &response); err != nil { return err } @@ -83,22 +89,15 @@ func (service *Service) sendAPI(config *Config, payload interface{}) error { } func (service *Service) sendWebhook(config *Config, payload interface{}) error { - payloadBytes, err := json.Marshal(payload) - var res *http.Response - res, err = http.Post(config.Token.WebhookURL(), jsonclient.ContentType, bytes.NewBuffer(payloadBytes)) + var response *string + err := service.WebClient().Post(config.Token.WebhookURL(), payload, &response) if err != nil { return fmt.Errorf("failed to invoke webhook: %v", err) } - defer res.Body.Close() - resBytes, _ := ioutil.ReadAll(res.Body) - response := string(resBytes) - switch response { + switch *response { case "": - if res.StatusCode != http.StatusOK { - return fmt.Errorf("webhook status: %v", res.Status) - } // Treat status 200 as no error regardless of actual content fallthrough case "ok": diff --git a/pkg/services/slack/slack_config.go b/pkg/services/slack/slack_config.go index 76e81a22..069ffa8d 100644 --- a/pkg/services/slack/slack_config.go +++ b/pkg/services/slack/slack_config.go @@ -11,13 +11,13 @@ import ( // Config for the slack service type Config struct { standard.EnumlessConfig - BotName string `optional:"uses bot default" key:"botname,username" desc:"Bot name"` - Icon string `key:"icon,icon_emoji,icon_url" default:"" optional:"" desc:"Use emoji or URL as icon (based on presence of http(s):// prefix)"` - Token Token `desc:"API Bot token" url:"user,pass"` - Color string `key:"color" optional:"default border color" desc:"Message left-hand border color"` - Title string `key:"title" optional:"omitted" desc:"Prepended text above the message"` - Channel string `url:"host" desc:"Channel to send messages to in Cxxxxxxxxxx format"` - ThreadTS string `key:"thread_ts" optional:"" desc:"ts value of the parent message (to send message as reply in thread)"` + BotName string `optional:"uses bot default" key:"botname,username" desc:"Bot name"` + Icon string `key:"icon,icon_emoji,icon_url" default:"" optional:"" desc:"Use emoji or URL as icon (based on presence of http(s):// prefix)"` + Token Token `desc:"API- or Webhook token" url:"user,pass"` + Color string `key:"color" optional:"default border color" desc:"Message left-hand border color"` + Title string `key:"title" optional:"omitted" desc:"Prepended text above the message"` + Channel string `url:"host" desc:"Channel to send messages to in Cxxxxxxxxxx format (ignored for webhooks)"` + ThreadTS string `key:"thread_ts" optional:"" desc:"ts value of the parent message (to send message as reply in thread)"` } // GetURL returns a URL representation of it's current field values diff --git a/pkg/services/slack/slack_test.go b/pkg/services/slack/slack_test.go index aee8a158..2e1ec2ae 100644 --- a/pkg/services/slack/slack_test.go +++ b/pkg/services/slack/slack_test.go @@ -31,10 +31,12 @@ var ( var _ = Describe("the slack service", func() { BeforeSuite(func() { - service = &Service{} logger = log.New(GinkgoWriter, "Test", log.LstdFlags) envSlackURL, _ = url.Parse(os.Getenv("SHOUTRRR_SLACK_URL")) }) + BeforeEach(func() { + service = &Service{} + }) When("running integration tests", func() { It("should not error out", func() { @@ -169,7 +171,7 @@ var _ = Describe("the slack service", func() { When("sending via webhook URL", func() { var err error BeforeEach(func() { - httpmock.Activate() + httpmock.ActivateNonDefault(service.HTTPClient()) }) AfterEach(func() { httpmock.DeactivateAndReset() @@ -181,7 +183,7 @@ var _ = Describe("the slack service", func() { Expect(err).NotTo(HaveOccurred()) targetURL := "https://hooks.slack.com/services/AAAAAAAAA/BBBBBBBBB/123456789123456789123456" - httpmock.RegisterResponder("POST", targetURL, httpmock.NewStringResponder(200, "")) + httpmock.RegisterResponder("POST", targetURL, httpmock.NewStringResponder(200, ``)) err = service.Send("Message", nil) Expect(err).NotTo(HaveOccurred()) @@ -202,6 +204,7 @@ var _ = Describe("the slack service", func() { var err error BeforeEach(func() { httpmock.Activate() + httpmock.ActivateNonDefault(service.HTTPClient()) }) AfterEach(func() { httpmock.DeactivateAndReset() diff --git a/pkg/services/slack/slack_webhook.go b/pkg/services/slack/slack_webhook.go new file mode 100644 index 00000000..cc8a4eeb --- /dev/null +++ b/pkg/services/slack/slack_webhook.go @@ -0,0 +1,8 @@ +package slack + +func parseWebhookResponse(raw []byte, v interface{}) error { + var res = v.(**string) + s := string(raw) + *res = &s + return nil +} diff --git a/pkg/services/smtp/smtp.go b/pkg/services/smtp/smtp.go index 191ae6b9..0d6c94b9 100644 --- a/pkg/services/smtp/smtp.go +++ b/pkg/services/smtp/smtp.go @@ -2,8 +2,10 @@ package smtp import ( "crypto/tls" + "crypto/x509" "fmt" "github.com/containrrr/shoutrrr/pkg/format" + "github.com/containrrr/shoutrrr/pkg/util" "io" "math/rand" "net" @@ -22,6 +24,7 @@ type Service struct { config *Config multipartBoundary string propKeyResolver format.PropKeyResolver + certPool *x509.CertPool } const ( @@ -30,6 +33,8 @@ const ( contentMultipart = "multipart/alternative; boundary=%s" ) +var _ types.TLSClient = &Service{} + // Initialize loads ServiceConfig from configURL and sets logger for this Service func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) error { service.Logger.SetLogger(logger) @@ -64,7 +69,7 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e // Send a notification message to e-mail recipients func (service *Service) Send(message string, params *types.Params) error { - client, err := getClientConnection(service.config) + client, err := getClientConnection(service.config, service.certPool) if err != nil { return fail(FailGetSMTPClient, err) } @@ -77,7 +82,20 @@ func (service *Service) Send(message string, params *types.Params) error { return service.doSend(client, message, &config) } -func getClientConnection(config *Config) (*smtp.Client, error) { +// AddTrustedRootCertificate adds the specified PEM certificate to the pool of trusted root CAs +func (service *Service) AddTrustedRootCertificate(caPEM []byte) bool { + if service.certPool == nil { + certPool, err := x509.SystemCertPool() + if err != nil { + service.Logf(`error getting system certs: %v`, err) + certPool = x509.NewCertPool() + } + service.certPool = certPool + } + return service.certPool.AppendCertsFromPEM(caPEM) +} + +func getClientConnection(config *Config, pool *x509.CertPool) (*smtp.Client, error) { var conn net.Conn var err error @@ -85,9 +103,7 @@ func getClientConnection(config *Config) (*smtp.Client, error) { addr := fmt.Sprintf("%s:%d", config.Host, config.Port) if useImplicitTLS(config.Encryption, config.Port) { - conn, err = tls.Dial("tcp", addr, &tls.Config{ - ServerName: config.Host, - }) + conn, err = tls.Dial("tcp", addr, getTLSConfig(config.Host, pool)) } else { conn, err = net.Dial("tcp", addr) } @@ -114,9 +130,7 @@ func (service *Service) doSend(client *smtp.Client, message string, config *Conf if supported, _ := client.Extension("StartTLS"); !supported { service.Logf("Warning: StartTLS enabled, but server did not report support for it. Connection is NOT encrypted") } else { - if err := client.StartTLS(&tls.Config{ - ServerName: config.Host, - }); err != nil { + if err := client.StartTLS(getTLSConfig(config.Host, service.certPool)); err != nil { return fail(FailEnableStartTLS, err) } } @@ -263,6 +277,15 @@ func (service *Service) writeMessagePart(wc io.WriteCloser, message string, temp return nil } +func getTLSConfig(hostname string, pool *x509.CertPool) *tls.Config { + tlsConfig := &tls.Config{ + ServerName: hostname, + RootCAs: pool, + } + util.ConfigureFallbackCertVerification(tlsConfig) + return tlsConfig +} + func writeMultipartHeader(wc io.WriteCloser, boundary string, contentType string) error { suffix := "\n" if len(contentType) < 1 { diff --git a/pkg/services/telegram/telegram.go b/pkg/services/telegram/telegram.go index 45f0364c..e3f0b47d 100644 --- a/pkg/services/telegram/telegram.go +++ b/pkg/services/telegram/telegram.go @@ -2,6 +2,7 @@ package telegram import ( "errors" + "github.com/containrrr/shoutrrr/pkg/common/webclient" "github.com/containrrr/shoutrrr/pkg/format" "net/url" @@ -17,6 +18,7 @@ const ( // Service sends notifications to a given telegram chat type Service struct { standard.Standard + webclient.ClientService config *Config pkr format.PropKeyResolver } @@ -42,6 +44,7 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e Preview: true, Notification: true, } + service.ClientService.Initialize() service.pkr = format.NewPropKeyResolver(service.config) if err := service.config.setURL(&service.pkr, configURL); err != nil { return err @@ -51,8 +54,10 @@ func (service *Service) Initialize(configURL *url.URL, logger types.StdLogger) e } func (service *Service) sendMessageForChatIDs(message string, config *Config) error { + client := &Client{Token: config.Token, WebClient: service.WebClient()} for _, chat := range service.config.Chats { - if err := sendMessageToAPI(message, chat, config); err != nil { + payload := createSendMessagePayload(message, chat, config) + if _, err := client.SendMessage(&payload); err != nil { return err } } @@ -63,10 +68,3 @@ func (service *Service) sendMessageForChatIDs(message string, config *Config) er func (service *Service) GetConfig() *Config { return service.config } - -func sendMessageToAPI(message string, chat string, config *Config) error { - client := &Client{token: config.Token} - payload := createSendMessagePayload(message, chat, config) - _, err := client.SendMessage(&payload) - return err -} diff --git a/pkg/services/telegram/telegram_client.go b/pkg/services/telegram/telegram_client.go index 08624089..61cdc57a 100644 --- a/pkg/services/telegram/telegram_client.go +++ b/pkg/services/telegram/telegram_client.go @@ -1,27 +1,27 @@ package telegram import ( - "encoding/json" "fmt" - "github.com/containrrr/shoutrrr/pkg/util/jsonclient" + "github.com/containrrr/shoutrrr/pkg/common/webclient" ) // Client for Telegram API type Client struct { - token string + WebClient webclient.WebClient + Token string } func (c *Client) apiURL(endpoint string) string { - return fmt.Sprintf(apiFormat, c.token, endpoint) + return fmt.Sprintf(apiFormat, c.Token, endpoint) } // GetBotInfo returns the bot User info func (c *Client) GetBotInfo() (*User, error) { response := &userResponse{} - err := jsonclient.Get(c.apiURL("getMe"), response) + err := c.WebClient.Get(c.apiURL("getMe"), response) if !response.OK { - return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + return nil, c.getErrorResponse(err) } return &response.Result, nil @@ -37,10 +37,10 @@ func (c *Client) GetUpdates(offset int, limit int, timeout int, allowedUpdates [ AllowedUpdates: allowedUpdates, } response := &updatesResponse{} - err := jsonclient.Post(c.apiURL("getUpdates"), request, response) + err := c.WebClient.Post(c.apiURL("getUpdates"), request, response) if !response.OK { - return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + return nil, c.getErrorResponse(err) } return response.Result, nil @@ -50,20 +50,20 @@ func (c *Client) GetUpdates(offset int, limit int, timeout int, allowedUpdates [ func (c *Client) SendMessage(message *SendMessagePayload) (*Message, error) { response := &messageResponse{} - err := jsonclient.Post(c.apiURL("sendMessage"), message, response) + err := c.WebClient.Post(c.apiURL("sendMessage"), message, response) if !response.OK { - return nil, GetErrorResponse(jsonclient.ErrorBody(err)) + return nil, c.getErrorResponse(err) } return response.Result, nil } // GetErrorResponse retrieves the error message from a failed request -func GetErrorResponse(body string) error { - response := &errorResponse{} - if err := json.Unmarshal([]byte(body), response); err == nil { - return response +func (c *Client) getErrorResponse(err error) error { + errResponse := &ErrorResponse{} + if c.WebClient.ErrorResponse(err, errResponse) { + return errResponse } - return nil + return err } diff --git a/pkg/services/telegram/telegram_client_test.go b/pkg/services/telegram/telegram_client_test.go new file mode 100644 index 00000000..ae78cae3 --- /dev/null +++ b/pkg/services/telegram/telegram_client_test.go @@ -0,0 +1,45 @@ +package telegram_test + +import ( + "github.com/containrrr/shoutrrr/pkg/common/webclient" + "github.com/containrrr/shoutrrr/pkg/services/telegram" + "github.com/jarcoal/httpmock" + "net/http" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var client *telegram.Client + +var _ = Describe("the telegram client", func() { + BeforeEach(func() { + client = &telegram.Client{WebClient: webclient.NewJSONClient(), Token: `Test`} + httpmock.ActivateNonDefault(client.WebClient.HTTPClient()) + }) + AfterEach(func() { + httpmock.DeactivateAndReset() + }) + When("an error is returned from the API", func() { + It("should return the error description", func() { + + errRes := httpmock.NewJsonResponderOrPanic(http.StatusNotAcceptable, telegram.ErrorResponse{ + OK: false, + Description: "no.", + }) + httpmock.RegisterResponder("POST", `https://api.telegram.org/botTest/getUpdates`, errRes) + httpmock.RegisterResponder("GET", `https://api.telegram.org/botTest/getMe`, errRes) + httpmock.RegisterResponder("POST", `https://api.telegram.org/botTest/sendMessage`, errRes) + + _, err := client.GetUpdates(0, 1, 10, []string{}) + Expect(err).To(MatchError(`no.`)) + + _, err = client.GetBotInfo() + Expect(err).To(MatchError(`no.`)) + + _, err = client.SendMessage(&telegram.SendMessagePayload{}) + Expect(err).To(MatchError(`no.`)) + }) + }) + +}) diff --git a/pkg/services/telegram/telegram_generator.go b/pkg/services/telegram/telegram_generator.go index 2b3e945f..7087c0ac 100644 --- a/pkg/services/telegram/telegram_generator.go +++ b/pkg/services/telegram/telegram_generator.go @@ -1,6 +1,7 @@ package telegram import ( + "github.com/containrrr/shoutrrr/pkg/common/webclient" f "github.com/containrrr/shoutrrr/pkg/format" "github.com/containrrr/shoutrrr/pkg/types" "github.com/containrrr/shoutrrr/pkg/util/generator" @@ -41,7 +42,7 @@ func (g *Generator) Generate(_ types.Service, props map[string]string, _ []strin ud.Writeln("Fetching bot info...") // ud.Writeln("Session token: %v", g.sessionToken) - g.client = &Client{token: token} + g.client = &Client{Token: token, WebClient: webclient.NewJSONClient()} botInfo, err := g.client.GetBotInfo() if err != nil { return &Config{}, err diff --git a/pkg/services/telegram/telegram_json.go b/pkg/services/telegram/telegram_json.go index a44930ed..c46482bd 100644 --- a/pkg/services/telegram/telegram_json.go +++ b/pkg/services/telegram/telegram_json.go @@ -41,13 +41,14 @@ func createSendMessagePayload(message string, channel string, config *Config) Se return payload } -type errorResponse struct { +// ErrorResponse is the generic response from the API when an error occurred +type ErrorResponse struct { OK bool `json:"ok"` ErrorCode int `json:"error_code"` Description string `json:"description"` } -func (e *errorResponse) Error() string { +func (e *ErrorResponse) Error() string { return e.Description } diff --git a/pkg/services/telegram/telegram_test.go b/pkg/services/telegram/telegram_test.go index bff1fd8e..01177081 100644 --- a/pkg/services/telegram/telegram_test.go +++ b/pkg/services/telegram/telegram_test.go @@ -128,7 +128,7 @@ var _ = Describe("the telegram service", func() { Describe("sending the payload", func() { var err error BeforeEach(func() { - httpmock.Activate() + httpmock.ActivateNonDefault(telegram.HTTPClient()) }) AfterEach(func() { httpmock.DeactivateAndReset() @@ -138,7 +138,7 @@ var _ = Describe("the telegram service", func() { err = telegram.Initialize(serviceURL, logger) Expect(err).NotTo(HaveOccurred()) - setupResponder("sendMessage", telegram.GetConfig().Token, 200, "") + setupResponder("sendMessage", telegram.GetConfig().Token, 200, `{"ok": true}`) err = telegram.Send("Message", nil) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/types/service.go b/pkg/types/service.go index e9d960d3..a0e0aafa 100644 --- a/pkg/types/service.go +++ b/pkg/types/service.go @@ -1,6 +1,7 @@ package types import ( + "net/http" "net/url" ) @@ -11,3 +12,8 @@ type Service interface { Initialize(serviceURL *url.URL, logger StdLogger) error SetLogger(logger StdLogger) } + +// HTTPService is the common interface for services that use a http.Client to send notifications +type HTTPService interface { + HTTPClient() *http.Client +} diff --git a/pkg/types/tls_client.go b/pkg/types/tls_client.go new file mode 100644 index 00000000..30aa004b --- /dev/null +++ b/pkg/types/tls_client.go @@ -0,0 +1,6 @@ +package types + +// TLSClient is the interface that needs to be implemented for custom TLS certificate support +type TLSClient interface { + AddTrustedRootCertificate([]byte) bool +} diff --git a/pkg/util/jsonclient/jsonclient.go b/pkg/util/jsonclient/jsonclient.go deleted file mode 100644 index 3e7fbe5c..00000000 --- a/pkg/util/jsonclient/jsonclient.go +++ /dev/null @@ -1,119 +0,0 @@ -package jsonclient - -import ( - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" -) - -// ContentType is the default mime type for JSON -const ContentType = "application/json" - -// DefaultClient is the singleton instance of jsonclient using http.DefaultClient -var DefaultClient = NewClient() - -// Get fetches url using GET and unmarshals into the passed response using DefaultClient -func Get(url string, response interface{}) error { - return DefaultClient.Get(url, response) -} - -// Post sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultClient -func Post(url string, request interface{}, response interface{}) error { - return DefaultClient.Post(url, request, response) -} - -// Client is a JSON wrapper around http.Client -type client struct { - httpClient *http.Client - headers http.Header - indent string -} - -func NewClient() Client { - return &client{ - httpClient: http.DefaultClient, - headers: http.Header{ - "Content-Type": []string{ContentType}, - }, - } -} - -// Headers return the default headers for requests -func (c *client) Headers() http.Header { - return c.headers -} - -// Get fetches url using GET and unmarshals into the passed response -func (c *client) Get(url string, response interface{}) error { - res, err := c.httpClient.Get(url) - if err != nil { - return err - } - - return parseResponse(res, response) -} - -// Post sends request as JSON and unmarshals the response JSON into the supplied struct -func (c *client) Post(url string, request interface{}, response interface{}) error { - var err error - var body []byte - - body, err = json.MarshalIndent(request, "", c.indent) - if err != nil { - return fmt.Errorf("error creating payload: %v", err) - } - - req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("error creating request: %v", err) - } - - for key, val := range c.headers { - req.Header.Set(key, val[0]) - } - - var res *http.Response - res, err = c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("error sending payload: %v", err) - } - - return parseResponse(res, response) -} - -func (c *client) ErrorResponse(err error, response interface{}) bool { - jerr, isJsonError := err.(Error) - if !isJsonError { - return false - } - - return json.Unmarshal([]byte(jerr.Body), response) == nil -} - -func parseResponse(res *http.Response, response interface{}) error { - defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) - - if res.StatusCode >= 400 { - err = fmt.Errorf("got HTTP %v", res.Status) - } - - if err == nil { - err = json.Unmarshal(body, response) - } - - if err != nil { - if body == nil { - body = []byte{} - } - return Error{ - StatusCode: res.StatusCode, - Body: string(body), - err: err, - } - } - - return nil -} diff --git a/pkg/util/util.go b/pkg/util/util.go index a47b1ed6..5ab4bb71 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -39,6 +39,7 @@ func TestLogger() *log.Logger { // DiscardLogger is a logger that discards any output written to it var DiscardLogger = log.New(ioutil.Discard, "", 0) +// URLMust parses the specified URL and panics (with offset) if it fails func URLMust(rawURL string) *url.URL { parsed, err := url.Parse(rawURL) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) diff --git a/pkg/util/util_tls.go b/pkg/util/util_tls.go new file mode 100644 index 00000000..0e8e3a7a --- /dev/null +++ b/pkg/util/util_tls.go @@ -0,0 +1,61 @@ +package util + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "runtime" +) + +// ConfigureFallbackCertVerification will set the VerifyPeerCertificate callback to a custom function that tries to +// validate the peer certificate using the system certificate store if verifying using the root CAs in the config fails +// +// Workaround for https://github.com/golang/go/issues/16736 +// Based on example https://pkg.go.dev/crypto/tls@go1.14#example-Config-VerifyPeerCertificate +func ConfigureFallbackCertVerification(conf *tls.Config) { + + if runtime.GOOS != `windows` { + // Only needed on windows + return + } + + if conf.RootCAs == nil { + // No custom certs have been added + return + } + + // Unless we enable this, the regular verification will still abort the handshake + conf.InsecureSkipVerify = true + + conf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + opts := x509.VerifyOptions{ + Roots: conf.RootCAs, + DNSName: conf.ServerName, + Intermediates: x509.NewCertPool(), + } + + targetCert := certs[0] + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + _, err := targetCert.Verify(opts) + if err != nil { + // Try again using no root store as CryptoAPI will be used to verify instead + opts.Roots = nil + _, err = targetCert.Verify(opts) + } + + return err + } +}