Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for custom root CAs #185

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
140 changes: 140 additions & 0 deletions pkg/common/webclient/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package webclient

import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
)

// DefaultJsonClient is the singleton instance of WebClient using http.DefaultClient
var DefaultJsonClient = NewJSONClient()

// GetJson fetches url using GET and unmarshals into the passed response using DefaultJsonClient
func GetJson(url string, response interface{}) error {
return DefaultJsonClient.Get(url, response)
}

// PostJson sends request as JSON and unmarshals the response JSON into the supplied struct using DefaultJsonClient
func PostJson(url string, request interface{}, response interface{}) error {
return DefaultJsonClient.Post(url, request, response)
}

// WebClient is a JSON wrapper around http.WebClient
piksel marked this conversation as resolved.
Show resolved Hide resolved
type client struct {
headers http.Header
indent string
HttpClient http.Client
parse ParserFunc
write WriterFunc
}

// SetTransport overrides the http.RoundTripper for the web client, mainly used for testing
func (c *client) SetTransport(transport http.RoundTripper) {
c.HttpClient.Transport = transport
}

// SetParser overrides the parser for the incoming response content
func (c *client) SetParser(parse ParserFunc) {
c.parse = parse
}

// SetWriter overrides the writer for the outgoing request content
func (c *client) SetWriter(write WriterFunc) {
c.write = write
}

// NewJSONClient returns a WebClient using the default http.Client and JSON serialization
func NewJSONClient() WebClient {
var c client
c = client{
headers: http.Header{
"Content-Type": []string{JsonContentType},
},
parse: json.Unmarshal,
write: func(v interface{}) ([]byte, error) {
return json.MarshalIndent(v, "", c.indent)
},
}
return &c
}

// Headers return the default headers for requests
func (c *client) Headers() http.Header {
return c.headers
}

// Get fetches url using GET and unmarshals into the passed response
func (c *client) Get(url string, response interface{}) error {
res, err := c.HttpClient.Get(url)
if err != nil {
return err
}

return c.parseResponse(res, response)
}

// Post sends a serialized representation of request and deserializes the result into response
func (c *client) Post(url string, request interface{}, response interface{}) error {
var err error
var body []byte

body, err = c.write(request)
if err != nil {
return fmt.Errorf("error creating payload: %v", err)
}

req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("error creating request: %v", err)
}

for key, val := range c.headers {
req.Header.Set(key, val[0])
}

var res *http.Response
res, err = c.HttpClient.Do(req)
if err != nil {
return fmt.Errorf("error sending payload: %v", err)
}

return c.parseResponse(res, response)
}

// ErrorResponse tries to deserialize any response body into the supplied struct, returning whether successful or not
func (c *client) ErrorResponse(err error, response interface{}) bool {
jerr, isWebError := err.(ClientError)
if !isWebError {
return false
}

return c.parse([]byte(jerr.Body), response) == nil
}

func (c *client) parseResponse(res *http.Response, response interface{}) error {
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)

if res.StatusCode >= 400 {
err = fmt.Errorf("got HTTP %v", res.Status)
}

if err == nil {
err = c.parse(body, response)
}

if err != nil {
if body == nil {
body = []byte{}
}
return ClientError{
StatusCode: res.StatusCode,
Body: string(body),
err: err,
}
}

return nil
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
package jsonclient_test
package webclient_test

import (
"errors"
"github.com/containrrr/shoutrrr/pkg/util/jsonclient"
"github.com/containrrr/shoutrrr/pkg/common/webclient"
"github.com/onsi/gomega/ghttp"
"net/http"
"testing"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

func TestJSONClient(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "JSONClient Suite")
}

var _ = Describe("JSONClient", func() {
var server *ghttp.Server

Expand All @@ -27,7 +21,7 @@ var _ = Describe("JSONClient", func() {
It("should return an error", func() {
server.AppendHandlers(ghttp.RespondWith(http.StatusOK, "invalid json"))
res := &mockResponse{}
err := jsonclient.Get(server.URL(), &res)
err := webclient.GetJson(server.URL(), &res)
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).To(MatchError("invalid character 'i' looking for beginning of value"))
Expect(res.Status).To(BeEmpty())
Expand All @@ -38,7 +32,7 @@ var _ = Describe("JSONClient", func() {
It("should return an error", func() {
server.AppendHandlers(ghttp.RespondWith(http.StatusOK, nil))
res := &mockResponse{}
err := jsonclient.Get(server.URL(), &res)
err := webclient.GetJson(server.URL(), &res)
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).To(MatchError("unexpected end of JSON input"))
Expect(res.Status).To(BeEmpty())
Expand All @@ -48,7 +42,7 @@ var _ = Describe("JSONClient", func() {
It("should deserialize GET response", func() {
server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, mockResponse{Status: "OK"}))
res := &mockResponse{}
err := jsonclient.Get(server.URL(), &res)
err := webclient.GetJson(server.URL(), &res)
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("OK"))
Expand All @@ -66,22 +60,22 @@ var _ = Describe("JSONClient", func() {
ghttp.RespondWithJSONEncoded(http.StatusOK, &mockResponse{Status: "That's Numberwang!"})),
)

err := jsonclient.Post(server.URL(), &req, &res)
err := webclient.PostJson(server.URL(), &req, &res)
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("That's Numberwang!"))
})

It("should return error on error status responses", func() {
server.AppendHandlers(ghttp.RespondWith(404, "Not found!"))
err := jsonclient.Post(server.URL(), &mockRequest{}, &mockResponse{})
err := webclient.PostJson(server.URL(), &mockRequest{}, &mockResponse{})
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).To(MatchError("got HTTP 404 Not Found"))
})

It("should return error on invalid request", func() {
server.AppendHandlers(ghttp.VerifyRequest("POST", "/"))
err := jsonclient.Post(server.URL(), func() {}, &mockResponse{})
err := webclient.PostJson(server.URL(), func() {}, &mockResponse{})
Expect(server.ReceivedRequests()).Should(HaveLen(0))
Expect(err).To(MatchError("error creating payload: json: unsupported type: func()"))
})
Expand All @@ -93,10 +87,10 @@ var _ = Describe("JSONClient", func() {
ghttp.RespondWithJSONEncoded(http.StatusOK, res)),
)

err := jsonclient.Post(server.URL(), nil, &[]bool{})
err := webclient.PostJson(server.URL(), nil, &[]bool{})
Expect(server.ReceivedRequests()).Should(HaveLen(1))
Expect(err).To(MatchError("json: cannot unmarshal object into Go value of type []bool"))
Expect(jsonclient.ErrorBody(err)).To(MatchJSON(`{"Status":"cool skirt"}`))
Expect(webclient.ErrorBody(err)).To(MatchJSON(`{"Status":"cool skirt"}`))
})
})

Expand All @@ -106,24 +100,24 @@ var _ = Describe("JSONClient", func() {
})
})

var _ = Describe("Error", func() {
var _ = Describe("ClientError", func() {
When("no internal error has been set", func() {
It("should return a generic message with status code", func() {
errorWithNoError := jsonclient.Error{StatusCode: http.StatusEarlyHints}
errorWithNoError := webclient.ClientError{StatusCode: http.StatusEarlyHints}
Expect(errorWithNoError.String()).To(Equal("unknown error (HTTP 103)"))
})
})
Describe("ErrorBody", func() {
When("passed a non-json error", func() {
It("should return an empty string", func() {
Expect(jsonclient.ErrorBody(errors.New("unrelated error"))).To(BeEmpty())
Expect(webclient.ErrorBody(errors.New("unrelated error"))).To(BeEmpty())
})
})
When("passed a jsonclient.Error", func() {
When("passed a jsonclient.ClientError", func() {
It("should return the request body from that error", func() {
errorBody := `{"error": "bad user"}`
jsonError := jsonclient.Error{Body: errorBody}
Expect(jsonclient.ErrorBody(jsonError)).To(MatchJSON(errorBody))
jsonError := webclient.ClientError{Body: errorBody}
Expect(webclient.ErrorBody(jsonError)).To(MatchJSON(errorBody))
})
})
})
Expand Down
14 changes: 7 additions & 7 deletions pkg/util/jsonclient/error.go → pkg/common/webclient/error.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
package jsonclient
package webclient

import "fmt"

// Error contains additional http/JSON details
type Error struct {
// ClientError contains additional http/JSON details
type ClientError struct {
StatusCode int
Body string
err error
}

func (je Error) Error() string {
func (je ClientError) Error() string {
return je.String()
}

func (je Error) String() string {
func (je ClientError) String() string {
if je.err == nil {
return fmt.Sprintf("unknown error (HTTP %v)", je.StatusCode)
}
return je.err.Error()
}

// ErrorBody returns the request body from an Error
// ErrorBody returns the request body from a ClientError
func ErrorBody(e error) string {
if jsonError, ok := e.(Error); ok {
if jsonError, ok := e.(ClientError); ok {
return jsonError.Body
}
return ""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package jsonclient
package webclient

import "net/http"
import (
"net/http"
)

type Client interface {
// WebClient ...
type WebClient interface {
Get(url string, response interface{}) error
Post(url string, request interface{}, response interface{}) error
Headers() http.Header
ErrorResponse(err error, response interface{}) bool
SetTransport(http.RoundTripper)
SetParser(ParserFunc)
SetWriter(WriterFunc)
}
78 changes: 78 additions & 0 deletions pkg/common/webclient/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package webclient

import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"github.com/containrrr/shoutrrr/pkg/types"
"net/http"
)

// ParserFunc are functions that deserialize a struct from the passed bytes
type ParserFunc func(raw []byte, v interface{}) error

// WriterFunc are functions that serialize the passed struct into a byte stream
type WriterFunc func(v interface{}) ([]byte, error)

var _ types.TLSClient = &ClientService{}
var _ types.HTTPService = &ClientService{}

// JsonContentType is the default mime type for JSON
const JsonContentType = "application/json"

// ClientService is a Composable that adds a generic web request client to the service
type ClientService struct {
client *client
certPool *x509.CertPool
}

// HTTPClient returns the underlying http.WebClient used in the Service
func (s *ClientService) HTTPClient() *http.Client {
s.Initialize()
return &s.client.HttpClient
}

// WebClient returns the WebClient instance, initializing it if necessary
func (s *ClientService) WebClient() WebClient {
s.Initialize()
return s.client
}

// Initialize sets up the WebClient in the default state using JSON serialization and headers
func (s *ClientService) Initialize() {
if s.client != nil {
return
}

s.client = &client{
HttpClient: http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{},
},
},
headers: http.Header{
"Content-Type": []string{JsonContentType},
},
parse: json.Unmarshal,
write: func(v interface{}) ([]byte, error) {
return json.MarshalIndent(v, "", s.client.indent)
},
}
}

// AddTrustedRootCertificate adds the specified PEM certificate to the pool of trusted root CAs
func (s *ClientService) AddTrustedRootCertificate(caPEM []byte) bool {
s.Initialize()
if s.certPool == nil {
certPool, err := x509.SystemCertPool()
if err != nil {
certPool = x509.NewCertPool()
}
s.certPool = certPool
if tp, ok := s.client.HttpClient.Transport.(*http.Transport); ok {
tp.TLSClientConfig.RootCAs = s.certPool
}
}

return s.certPool.AppendCertsFromPEM(caPEM)
}