Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for custom non-http transports #233

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions client_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,47 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
return false, nil
}

// empty returns true if t is equivalent to an empty TLSConfiguration{} object.
func (t *TLSConfiguration) empty() bool {
if t.ServerCertificate.FromFile != "" {
return false
}

if len(t.ServerCertificate.FromBytes) != 0 {
return false
}

if t.ServerCertificate.FromDirectory != "" {
return false
}

if t.ClientCertificate.FromFile != "" {
return false
}

if len(t.ClientCertificate.FromBytes) != 0 {
return false
}

if t.ClientCertificateKey.FromFile != "" {
return false
}

if len(t.ClientCertificateKey.FromBytes) != 0 {
return false
}

if t.ServerName != "" {
return false
}

if t.InsecureSkipVerify {
return false
}

return true
}

// applyTo applies the user-defined TLS configuration to the given client's
// *tls.Config pointer; it is used to configure the internal http.Client
func (from *TLSConfiguration) applyTo(to *tls.Config) error {
Expand Down
46 changes: 45 additions & 1 deletion client_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

func Test_walkconfigurationfields(t *testing.T) {
func Test_walkConfigurationFields(t *testing.T) {
var (
actual = []string{}
expected = []string{
Expand Down Expand Up @@ -44,3 +44,47 @@ func Test_walkconfigurationfields(t *testing.T) {

require.Subset(t, actual, expected)
}

func Test_TLSConfiguration_empty(t *testing.T) {
cases := map[string]struct {
tls TLSConfiguration
expectedEmpty bool
}{
"empty": {
tls: TLSConfiguration{},
expectedEmpty: true,
},
"with-server-name": {
tls: TLSConfiguration{
ServerName: "my-server",
},
expectedEmpty: false,
},
"with-server-certificate-from-file": {
tls: TLSConfiguration{
ServerCertificate: ServerCertificateEntry{
FromFile: "./cert.pem",
},
},
expectedEmpty: false,
},
"with-server-certificate-from-bytes": {
tls: TLSConfiguration{
ServerCertificate: ServerCertificateEntry{
FromBytes: []byte{1, 1, 2, 3, 5},
},
},
expectedEmpty: false,
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
if tc.expectedEmpty {
require.True(t, tc.tls.empty())
} else {
require.False(t, tc.tls.empty())
}
})
}
}
33 changes: 23 additions & 10 deletions generate/templates/client.handlebars
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,33 @@ func newClient(configuration ClientConfiguration) (*Client, error) {
}
c.parsedBaseAddress = address

transport, ok := c.client.Transport.(*http.Transport);
if !ok {
return nil, fmt.Errorf("the configured base client's transport (%T) is not of type *http.Transport", c.client.Transport)
}

// Adjust the dial context for unix domain socket addresses.
if strings.HasPrefix(configuration.Address, "unix://") {
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://"))
// Adjust the dial context for unix domain socket addresses in the
// internal HTTP transport, if exists.
if httpTransport, ok := c.client.Transport.(*http.Transport); ok {
httpTransport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", strings.TrimPrefix(configuration.Address, "unix://"))
}
} else {
return nil, fmt.Errorf(
"the configured base client's transport (%T) is not of type *http.Transport and cannot be used with the unix:// address",
c.client.Transport,
)
}
}

if err := configuration.TLS.applyTo(transport.TLSClientConfig); err != nil {
return nil, err
if !configuration.TLS.empty() {
// Apply TLS configuration to the internal HTTP transport, if exists.
if httpTransport, ok := c.client.Transport.(*http.Transport); ok {
if err := configuration.TLS.applyTo(httpTransport.TLSClientConfig); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf(
"the configured base client's transport (%T) is not of type *http.Transport and cannot be used with TLS configuration",
c.client.Transport,
)
}
}

{{#with apiInfo}}
Expand Down