Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup: defer to close server in tests #110367

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
125 changes: 62 additions & 63 deletions pkg/client/tests/portfoward_test.go
Expand Up @@ -129,81 +129,80 @@ func TestForwardPorts(t *testing.T) {
}

for testName, test := range tests {
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
t.Run(testName, func(t *testing.T) {
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
defer server.Close()

transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
if err != nil {
t.Fatal(err)
}
url, _ := url.Parse(server.URL)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)

stopChan := make(chan struct{}, 1)
readyChan := make(chan struct{})

pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
if err != nil {
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
}
transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
if err != nil {
t.Fatal(err)
}
url, _ := url.Parse(server.URL)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)

doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
<-pf.Ready
stopChan := make(chan struct{}, 1)
readyChan := make(chan struct{})

forwardedPorts, err := pf.GetPorts()
if err != nil {
t.Fatal(err)
}
pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
if err != nil {
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
}

remoteToLocalMap := map[int32]int32{}
for _, forwardedPort := range forwardedPorts {
remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
}
doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
<-pf.Ready

for port, data := range test.clientSends {
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
forwardedPorts, err := pf.GetPorts()
if err != nil {
t.Errorf("%s: error dialing %d: %s", testName, port, err)
server.Close()
continue
t.Fatal(err)
}
defer clientConn.Close()

n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
t.Errorf("%s: Error sending data '%s': %s", testName, data, err)
server.Close()
continue
}
if n == 0 {
t.Errorf("%s: unexpected write of 0 bytes", testName)
server.Close()
continue
remoteToLocalMap := map[int32]int32{}
for _, forwardedPort := range forwardedPorts {
remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
}
b := make([]byte, 4)
_, err = clientConn.Read(b)
if err != nil && err != io.EOF {
t.Errorf("%s: Error reading data: %s", testName, err)
server.Close()
continue

clientSend := func(port int32, data string) error {
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
if err != nil {
return fmt.Errorf("%s: error dialing %d: %s", testName, port, err)

}
defer clientConn.Close()

n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err)
}
if n == 0 {
return fmt.Errorf("%s: unexpected write of 0 bytes", testName)
}
b := make([]byte, 4)
_, err = clientConn.Read(b)
if err != nil && err != io.EOF {
return fmt.Errorf("%s: Error reading data: %s", testName, err)
}
if !bytes.Equal([]byte(test.serverSends[port]), b) {
return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
}
return nil
}
if !bytes.Equal([]byte(test.serverSends[port]), b) {
t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
server.Close()
continue
for port, data := range test.clientSends {
if err := clientSend(port, data); err != nil {
t.Error(err)
}
}
}
// tell r.ForwardPorts to stop
close(stopChan)
// tell r.ForwardPorts to stop
close(stopChan)

// wait for r.ForwardPorts to actually return
err = <-doneChan
if err != nil {
t.Errorf("%s: unexpected error: %s", testName, err)
}
server.Close()
// wait for r.ForwardPorts to actually return
err = <-doneChan
if err != nil {
t.Errorf("%s: unexpected error: %s", testName, err)
}
})
}

}
Expand Down
174 changes: 85 additions & 89 deletions pkg/client/tests/remotecommand_test.go
Expand Up @@ -195,108 +195,104 @@ func TestStream(t *testing.T) {
} else {
name = testCase.TestName + " (attach)"
}
var (
streamIn io.Reader
streamOut, streamErr io.Writer
)
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}

requestReceived := make(chan struct{})
server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))

url, _ := url.ParseRequestURI(server.URL)
config := restclient.ClientContentConfig{
GroupVersion: schema.GroupVersion{Group: "x"},
Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
}
c, err := restclient.NewRESTClient(url, "", config, nil, nil)
if err != nil {
t.Fatalf("failed to create a client: %v", err)
}
req := c.Post().Resource("testing")

if exec {
req.Param("command", "ls")
req.Param("command", "/")
}

if len(testCase.Stdin) > 0 {
req.Param(api.ExecStdinParam, "1")
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
}
t.Run(name, func(t *testing.T) {
var (
streamIn io.Reader
streamOut, streamErr io.Writer
)
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}

requestReceived := make(chan struct{})
server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
defer server.Close()

url, _ := url.ParseRequestURI(server.URL)
config := restclient.ClientContentConfig{
GroupVersion: schema.GroupVersion{Group: "x"},
Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
}
c, err := restclient.NewRESTClient(url, "", config, nil, nil)
if err != nil {
t.Fatalf("failed to create a client: %v", err)
}
req := c.Post().Resource("testing")

if len(testCase.Stdout) > 0 {
req.Param(api.ExecStdoutParam, "1")
streamOut = localOut
}
if exec {
req.Param("command", "ls")
req.Param("command", "/")
}

if testCase.Tty {
req.Param(api.ExecTTYParam, "1")
} else if len(testCase.Stderr) > 0 {
req.Param(api.ExecStderrParam, "1")
streamErr = localErr
}
if len(testCase.Stdin) > 0 {
req.Param(api.ExecStdinParam, "1")
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
}

conf := &restclient.Config{
Host: server.URL,
}
transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
if err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
if err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
err = e.Stream(remoteclient.StreamOptions{
Stdin: streamIn,
Stdout: streamOut,
Stderr: streamErr,
Tty: testCase.Tty,
})
hasErr := err != nil

if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%s: expected an error", name)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
}
if len(testCase.Stdout) > 0 {
req.Param(api.ExecStdoutParam, "1")
streamOut = localOut
}

server.Close()
continue
}
if testCase.Tty {
req.Param(api.ExecTTYParam, "1")
} else if len(testCase.Stderr) > 0 {
req.Param(api.ExecStderrParam, "1")
streamErr = localErr
}

if hasErr {
t.Errorf("%s: unexpected error: %v", name, err)
server.Close()
continue
}
conf := &restclient.Config{
Host: server.URL,
}
transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
if err != nil {
t.Fatalf("%s: unexpected error: %v", name, err)
}
e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
if err != nil {
t.Fatalf("%s: unexpected error: %v", name, err)
}
err = e.Stream(remoteclient.StreamOptions{
Stdin: streamIn,
Stdout: streamOut,
Stderr: streamErr,
Tty: testCase.Tty,
})
hasErr := err != nil

if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%s: expected an error", name)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
}
}
return
}

if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Errorf("%s: expected stdout data %q, got %q", name, e, a)
if hasErr {
t.Fatalf("%s: unexpected error: %v", name, err)
}
}

if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Errorf("%s: expected stderr data %q, got %q", name, e, a)
if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Fatalf("%s: expected stdout data %q, got %q", name, e, a)
}
}
}

select {
case <-requestReceived:
case <-time.After(time.Minute):
t.Errorf("%s: expected fakeServerInstance to receive request", name)
}
if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Fatalf("%s: expected stderr data %q, got %q", name, e, a)
}
}

server.Close()
select {
case <-requestReceived:
case <-time.After(time.Minute):
t.Errorf("%s: expected fakeServerInstance to receive request", name)
}
})
}
}
}
Expand Down