Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 78 additions & 108 deletions go/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,75 +12,103 @@ import (
"os"
"strings"
"testing"
"testing/synctest"

"time"

"connectrpc.com/connect"
"github.com/golang-jwt/jwt/v5"
"github.com/metal-stack/api/go/client"
apiv2 "github.com/metal-stack/api/go/metalstack/api/v2"
"github.com/metal-stack/api/go/metalstack/api/v2/apiv2connect"
infrav2 "github.com/metal-stack/api/go/metalstack/infra/v2"
"github.com/metal-stack/api/go/metalstack/infra/v2/infrav2connect"
"github.com/stretchr/testify/require"
)

func Test_Client(t *testing.T) {
var (
vs = &mockVersionService{}
ts = &mockTokenService{}
mux = http.NewServeMux()
log = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
)

mux.Handle(apiv2connect.NewVersionServiceHandler(vs))
mux.Handle(apiv2connect.NewTokenServiceHandler(ts))
server := httptest.NewTLSServer(mux)
server.EnableHTTP2 = true
defer func() {
server.Close()
}()

tokenString, err := generateToken(2 * time.Second)
require.NoError(t, err)

c, err := client.New(&client.DialConfig{
BaseURL: server.URL,
Token: tokenString,
Transport: server.Client().Transport,
TokenRenewal: &client.TokenRenewal{
PersistTokenFn: func(token string) error {
ts.token = token
t.Log("token persisted:", token)
return nil
synctest.Test(t, func(t *testing.T) {
tokenString, err := generateToken(2 * time.Second)
require.NoError(t, err)
var renewedToken string

c, err := client.New(&client.DialConfig{
BaseURL: "http://localhost",
Token: tokenString,

Interceptors: []connect.Interceptor{
client.NewTestInterceptor(t, []client.ClientCall{
{
WantRequest: &apiv2.VersionServiceGetRequest{},
WantResponse: func() connect.AnyResponse {
return connect.NewResponse(&apiv2.VersionServiceGetResponse{
Version: &apiv2.Version{Version: "1.0"},
})
},
},
{
WantRequest: &apiv2.VersionServiceGetRequest{},
WantResponse: func() connect.AnyResponse {
return connect.NewResponse(&apiv2.VersionServiceGetResponse{
Version: &apiv2.Version{Version: "1.0"},
})
},
},
{
WantRequest: &apiv2.TokenServiceRefreshRequest{},
WantResponse: func() connect.AnyResponse {
tokenString, err := generateToken(2 * time.Second)
require.NoError(t, err)

return connect.NewResponse(&apiv2.TokenServiceRefreshResponse{
Secret: tokenString,
})
},
},
{
WantRequest: &apiv2.VersionServiceGetRequest{},
WantResponse: func() connect.AnyResponse {
return connect.NewResponse(&apiv2.VersionServiceGetResponse{
Version: &apiv2.Version{Version: "1.0"},
})
},
},
}),
},
TokenRenewal: &client.TokenRenewal{
PersistTokenFn: func(token string) error {
renewedToken = token
return nil
},
},
},
Log: log,
Log: log,
})

require.NoError(t, err)
v, err := c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)
require.Empty(t, renewedToken)

time.Sleep(1 * time.Second)
v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)
require.Empty(t, renewedToken)

time.Sleep(3 * time.Second)
v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)
require.NotEmpty(t, renewedToken)
require.NotEqual(t, renewedToken, tokenString, "haven't changed")
})
require.NoError(t, err)
v, err := c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)
require.False(t, ts.wasCalled)
require.Equal(t, tokenString, vs.token)

time.Sleep(1 * time.Second)
v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)
require.False(t, ts.wasCalled)
require.Equal(t, tokenString, vs.token)

time.Sleep(1 * time.Second)
v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{})
require.NoError(t, err)
require.NotNil(t, v)
require.Equal(t, "1.0", v.Version.Version)

require.True(t, ts.wasCalled)
require.NotEqual(t, tokenString, ts.token, "token must have changed")
}

func generateToken(duration time.Duration) (string, error) {
Expand All @@ -102,64 +130,6 @@ func generateToken(duration time.Duration) (string, error) {
return tokenString, nil
}

type mockVersionService struct {
token string
}

func (m *mockVersionService) Get(ctx context.Context, req *apiv2.VersionServiceGetRequest) (*apiv2.VersionServiceGetResponse, error) {
callinfo, _ := connect.CallInfoForHandlerContext(ctx)
authHeader := callinfo.RequestHeader().Get("Authorization")

_, token, found := strings.Cut(authHeader, "Bearer ")

if !found {
return nil, fmt.Errorf("unable to extract token from header:%s", authHeader)
}

m.token = token
return &apiv2.VersionServiceGetResponse{Version: &apiv2.Version{Version: "1.0"}}, nil
}

type mockTokenService struct {
wasCalled bool
token string
}

// Create implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) Create(context.Context, *apiv2.TokenServiceCreateRequest) (*apiv2.TokenServiceCreateResponse, error) {
panic("unimplemented")
}

// Get implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) Get(context.Context, *apiv2.TokenServiceGetRequest) (*apiv2.TokenServiceGetResponse, error) {
panic("unimplemented")
}

// List implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) List(context.Context, *apiv2.TokenServiceListRequest) (*apiv2.TokenServiceListResponse, error) {
panic("unimplemented")
}

// Refresh implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) Refresh(ctx context.Context, _ *apiv2.TokenServiceRefreshRequest) (*apiv2.TokenServiceRefreshResponse, error) {
token, err := generateToken(2 * time.Second)
if err != nil {
return nil, err
}
m.wasCalled = true
return &apiv2.TokenServiceRefreshResponse{Token: &apiv2.Token{}, Secret: token}, nil
}

// Revoke implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) Revoke(context.Context, *apiv2.TokenServiceRevokeRequest) (*apiv2.TokenServiceRevokeResponse, error) {
panic("unimplemented")
}

// Update implements apiv2connect.TokenServiceHandler.
func (m *mockTokenService) Update(context.Context, *apiv2.TokenServiceUpdateRequest) (*apiv2.TokenServiceUpdateResponse, error) {
panic("unimplemented")
}

func Test_ClientInterceptors(t *testing.T) {
var (
bs = &mockBMCService{}
Expand Down
72 changes: 72 additions & 0 deletions go/client/test_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package client

import (
"context"
"reflect"
"testing"

"connectrpc.com/connect"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/testing/protocmp"
)

type TestClientInterceptor struct {
t *testing.T
calls []ClientCall
count int
}

type ClientCall struct {
WantRequest proto.Message
WantResponse func() connect.AnyResponse
WantError *connect.Error
}

func NewTestInterceptor(t *testing.T, calls []ClientCall) *TestClientInterceptor {
return &TestClientInterceptor{
t: t,
calls: calls,
}
}

func (t *TestClientInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
defer func() { t.count++ }()

if t.count >= len(t.calls) {
t.t.Errorf("received an unexpected client call of type %T: %v", ar.Any(), ar.Any())
t.t.FailNow()
}

call := t.calls[t.count]

if diff := cmp.Diff(call.WantRequest, ar.Any(), protocmp.Transform(), IgnoreUnexported(), cmpopts.IgnoreTypes(protoimpl.MessageState{})); diff != "" {
t.t.Errorf("request diff (+got -want):\n %s", diff)
t.t.FailNow()
}

if call.WantError != nil {
return nil, call.WantError
}

return call.WantResponse(), nil
}
}

func (t *TestClientInterceptor) WrapStreamingClient(connect.StreamingClientFunc) connect.StreamingClientFunc {
t.t.Errorf("streaming not supported")
return nil
}

func (t *TestClientInterceptor) WrapStreamingHandler(connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
t.t.Errorf("streaming not supported")
return nil
}

func IgnoreUnexported() cmp.Option {
// the exporter opt allows all unexported fields: https://github.com/google/go-cmp/pull/176
return cmp.Exporter(func(reflect.Type) bool { return true })
}
51 changes: 51 additions & 0 deletions go/client/test_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package client_test

import (
"log/slog"
"testing"

"connectrpc.com/connect"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
client "github.com/metal-stack/api/go/client"
apiv2 "github.com/metal-stack/api/go/metalstack/api/v2"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/testing/protocmp"
)

func TestInterceptor(t *testing.T) {
cl, err := client.New(&client.DialConfig{
BaseURL: "http://this-is-just-for-testing",
Interceptors: []connect.Interceptor{
client.NewTestInterceptor(t, []client.ClientCall{
{
WantRequest: &apiv2.IPServiceGetRequest{
Ip: "1.2.3.4",
},
WantResponse: func() connect.AnyResponse {
return connect.NewResponse(&apiv2.IPServiceGetResponse{
Ip: &apiv2.IP{Ip: "1.2.3.4"},
})
},
},
}),
},
UserAgent: "cli-test",
Log: slog.Default(),
})
require.NoError(t, err)

resp, err := cl.Apiv2().IP().Get(t.Context(), &apiv2.IPServiceGetRequest{
Ip: "1.2.3.4",
})
require.NoError(t, err)

if diff := cmp.Diff(&apiv2.IPServiceGetResponse{
Ip: &apiv2.IP{
Ip: "1.2.3.4",
},
}, resp, protocmp.Transform(), client.IgnoreUnexported(), cmpopts.IgnoreTypes(protoimpl.MessageState{})); diff != "" {
t.Errorf("diff = %s", diff)
}
}
Loading