diff --git a/domainfronted_test.go b/domainfronted_test.go index 75aa4bd..aba8881 100644 --- a/domainfronted_test.go +++ b/domainfronted_test.go @@ -2,6 +2,7 @@ package domainfronted import ( "io/ioutil" + "net" "net/http" "strconv" "strings" @@ -43,38 +44,63 @@ func TestHttpClientWithBadEnproxyConn(t *testing.T) { assert.Error(t, err, "HttpClient using a non-existent host should have failed") } -func TestRoundTrip(t *testing.T) { +func TestBadPKFile(t *testing.T) { server := &Server{ Addr: "localhost:0", - AllowNonGlobalDestinations: true, CertContext: &CertContext{ - PKFile: "testpk.pem", + PKFile: "", ServerCertFile: "testcert.pem", }, } - l, err := server.Listen() + _, err := server.Listen() + assert.Error(t, err, "Listen should have failed") +} + +func TestBadCertificateFile(t *testing.T) { + server := &Server{ + Addr: "localhost:0", + CertContext: &CertContext{ + PKFile: "testpk.pem", + ServerCertFile: "", + }, + } + _, err := server.Listen() + assert.Error(t, err, "Listen should have failed") +} + +func TestNonGlobalAddress(t *testing.T) { + l := startServer(t, false) + client := clientFor(t, l) + defer client.Close() + + gotConn := false + var gotConnMutex sync.Mutex + tl, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Unable to listen: %s", err) } go func() { - err = server.Serve(l) - if err != nil { - t.Fatalf("Unable to serve: %s", err) - } + tl.Accept() + gotConnMutex.Lock() + gotConn = true + gotConnMutex.Unlock() }() - addrParts := strings.Split(l.Addr().String(), ":") - host := addrParts[0] - port, err := strconv.Atoi(addrParts[1]) - if err != nil { - t.Fatalf("Unable to parse port: %s", err) - } + conn, err := client.Dial("tcp", l.Addr().String()) + defer conn.Close() - client := NewClient(&ClientConfig{ - Host: host, - Port: port, - InsecureSkipVerify: true, - }) + data := []byte("Some Meaningless Data") + conn.Write(data) + // Give enproxy time to flush + time.Sleep(500 * time.Millisecond) + _, err = conn.Write(data) + assert.Error(t, err, "Sending data after previous attempt to write to local address should have failed") + assert.False(t, gotConn, "Sending data to local address should never have resulted in connection") +} + +func TestRoundTrip(t *testing.T) { + l := startServer(t, true) + client := clientFor(t, l) defer client.Close() proxy.Test(t, client) @@ -161,3 +187,40 @@ func TestIntegration(t *testing.T) { assert.NotEqual(t, time.Duration(0), actualConnectTime, "Should have received a connectTime") assert.NotEqual(t, time.Duration(0), actualHandshakeTime, "Should have received a handshakeTime") } + +func startServer(t *testing.T, allowNonGlobal bool) net.Listener { + server := &Server{ + Addr: "localhost:0", + AllowNonGlobalDestinations: allowNonGlobal, + CertContext: &CertContext{ + PKFile: "testpk.pem", + ServerCertFile: "testcert.pem", + }, + } + l, err := server.Listen() + if err != nil { + t.Fatalf("Unable to listen: %s", err) + } + go func() { + err = server.Serve(l) + if err != nil { + t.Fatalf("Unable to serve: %s", err) + } + }() + return l +} + +func clientFor(t *testing.T, l net.Listener) *Client { + addrParts := strings.Split(l.Addr().String(), ":") + host := addrParts[0] + port, err := strconv.Atoi(addrParts[1]) + if err != nil { + t.Fatalf("Unable to parse port: %s", err) + } + + return NewClient(&ClientConfig{ + Host: host, + Port: port, + InsecureSkipVerify: true, + }) +}