From 78c507c7b38b75e42b9d88facf63d58e6d041e4d Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 25 Apr 2022 17:00:35 -0500 Subject: [PATCH] add unit test for automtls --- server_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/server_test.go b/server_test.go index 5fec4502..c8a28593 100644 --- a/server_test.go +++ b/server_test.go @@ -82,6 +82,75 @@ func TestServer_testMode(t *testing.T) { t.Logf("HELLO") } +func TestServer_testMode_AutoMTLS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + closeCh := make(chan struct{}) + go Serve(&ServeConfig{ + HandshakeConfig: testVersionedHandshake, + VersionedPlugins: map[int]PluginSet{ + 2: testGRPCPluginMap, + }, + GRPCServer: DefaultGRPCServer, + Logger: hclog.NewNullLogger(), + Test: &ServeTestConfig{ + Context: ctx, + ReattachConfigCh: nil, + CloseCh: closeCh, + }, + }) + + // Connect! + process := helperProcess("test-mtls") + c := NewClient(&ClientConfig{ + Cmd: process, + HandshakeConfig: testVersionedHandshake, + VersionedPlugins: map[int]PluginSet{ + 2: testGRPCPluginMap, + }, + AllowedProtocols: []Protocol{ProtocolGRPC}, + AutoMTLS: true, + }) + client, err := c.Client() + if err != nil { + t.Fatalf("err: %s", err) + } + + // Grab the impl + raw, err := client.Dispense("test") + if err != nil { + t.Fatalf("err should be nil, got %s", err) + } + + tester, ok := raw.(testInterface) + if !ok { + t.Fatalf("bad: %#v", raw) + } + + n := tester.Double(3) + if n != 6 { + t.Fatal("invalid response", n) + } + + // ensure we can make use of bidirectional communication with AutoMTLS + // enabled + err = tester.Bidirectional() + if err != nil { + t.Fatal("invalid response", err) + } + + // Pinging should work + if err := client.Ping(); err != nil { + t.Fatalf("should not err: %s", err) + } + + c.Kill() + // Canceling should cause an exit + cancel() + <-closeCh +} + func TestRmListener_impl(t *testing.T) { var _ net.Listener = new(rmListener) } @@ -145,7 +214,6 @@ func TestProtocolSelection_no_server(t *testing.T) { if protocol != ProtocolNetRPC { t.Fatalf("bad protocol %s", protocol) } - } func TestServer_testStdLogger(t *testing.T) {