From 79ae8cea0271b0adfe4c9fe5179922bc9a4c3fbb Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy <84944216+helsaawy@users.noreply.github.com> Date: Mon, 22 Aug 2022 11:09:02 -0400 Subject: [PATCH] Added HV Socket tests (#240) * Added HV Socket tests Added tests for core Hyper-V socket functionality, including testing CloseRead and CloseWrite, as well as checking addresses are appropriate and timeouts work. Added fuzzing test to check for edge case read/write issues. Signed-off-by: Hamza El-Saawy * pr: asserts, naming, fatal in test Signed-off-by: Hamza El-Saawy Signed-off-by: Hamza El-Saawy --- .gitignore | 3 + go.mod | 2 + go.sum | 2 + hvsock_test.go | 738 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 745 insertions(+) create mode 100644 hvsock_test.go diff --git a/.gitignore b/.gitignore index 9d428772..815e2066 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ *.exe +# testing +testdata + # go workspaces go.work go.work.sum diff --git a/go.mod b/go.mod index 00fe52c3..715f4b85 100644 --- a/go.mod +++ b/go.mod @@ -6,3 +6,5 @@ require ( github.com/sirupsen/logrus v1.7.0 golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 ) + +require golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect diff --git a/go.sum b/go.sum index 0425d53b..7ec6d7ce 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/hvsock_test.go b/hvsock_test.go new file mode 100644 index 00000000..0c84a2c6 --- /dev/null +++ b/hvsock_test.go @@ -0,0 +1,738 @@ +//go:build windows + +package winio + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "strings" + "testing" + "time" + + "golang.org/x/sys/windows" + + "github.com/Microsoft/go-winio/internal/socket" + "github.com/Microsoft/go-winio/pkg/guid" +) + +const testStr = "test" + +func randHvsockAddr() *HvsockAddr { + p := rand.Uint32() //nolint:gosec // used for testing + return &HvsockAddr{ + VMID: HvsockGUIDLoopback(), + ServiceID: VsockServiceID(p), + } +} + +func serverListen(u testUtil) (l *HvsockListener, a *HvsockAddr) { + var err error + for i := 0; i < 3; i++ { + a = randHvsockAddr() + l, err = ListenHvsock(a) + if errors.Is(err, windows.WSAEADDRINUSE) { + u.T.Logf("address collision %v", a) + continue + } + break + } + u.Must(err, "could not listen") + u.T.Cleanup(func() { + if l != nil { + u.Must(l.Close(), "Hyper-V socket listener close") + } + }) + + return l, a +} + +func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) { + l, addr := serverListen(u) + ch := u.Go(func() error { + conn, err := l.Accept() + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } + sv = conn.(*HvsockConn) + if err := l.Close(); err != nil { + return err + } + l = nil + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + cl, err := Dial(ctx, addr) + u.Must(err, "could not dial") + u.T.Cleanup(func() { + if cl != nil { + u.Must(cl.Close(), "client close") + } + }) + + u.WaitErr(ch, time.Second) + u.T.Cleanup(func() { + if sv != nil { + u.Must(sv.Close(), "server close") + } + }) + return cl, sv, addr +} + +func TestHvSockConstants(t *testing.T) { + tests := []struct { + name string + want string + give guid.GUID + }{ + {"wildcard", "00000000-0000-0000-0000-000000000000", HvsockGUIDWildcard()}, + {"broadcast", "ffffffff-ffff-ffff-ffff-ffffffffffff", HvsockGUIDBroadcast()}, + {"loopback", "e0e16197-dd56-4a10-9195-5ee7a155a838", HvsockGUIDLoopback()}, + {"children", "90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd", HvsockGUIDChildren()}, + {"parent", "a42e7cda-d03f-480c-9cc2-a4de20abb878", HvsockGUIDParent()}, + {"silohost", "36bd0c5c-7276-4223-88ba-7d03b654c568", HvsockGUIDSiloHost()}, + {"vsock template", "00000000-facb-11e6-bd58-64006a7986d3", hvsockVsockServiceTemplate()}, + } + for _, tt := range tests { + if tt.give.String() != tt.want { + t.Errorf("%s give: %v; want: %s", tt.name, tt.give, tt.want) + } + } +} + +func TestHvSockListenerAddresses(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + + la := (l.Addr()).(*HvsockAddr) + u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr)) + + ra := rawHvsockAddr{} + sa := HvsockAddr{} + u.Must(socket.GetSockName(windows.Handle(l.sock.handle), &ra)) + sa.fromRaw(&ra) + u.Assert(sa == *addr, fmt.Sprintf("listener local addr give: %v; want: %v", sa, addr)) +} + +func TestHvSockAddresses(t *testing.T) { + u := newUtil(t) + cl, sv, addr := clientServer(u) + + sra := (sv.RemoteAddr()).(*HvsockAddr) + sla := (sv.LocalAddr()).(*HvsockAddr) + cra := (cl.RemoteAddr()).(*HvsockAddr) + cla := (cl.LocalAddr()).(*HvsockAddr) + + t.Run("Info", func(t *testing.T) { + tests := []struct { + name string + give *HvsockAddr + want HvsockAddr + }{ + {"client local", cla, HvsockAddr{HvsockGUIDChildren(), sra.ServiceID}}, + {"client remote", cra, *addr}, + {"server local", sla, HvsockAddr{HvsockGUIDChildren(), addr.ServiceID}}, + {"server remote", sra, HvsockAddr{HvsockGUIDLoopback(), cla.ServiceID}}, + } + for _, tt := range tests { + if *tt.give != tt.want { + t.Errorf("%s address give: %v; want: %v", tt.name, tt.give, tt.want) + } + } + }) + + t.Run("OSinfo", func(t *testing.T) { + u := newUtil(t) + ra := rawHvsockAddr{} + sa := HvsockAddr{} + + localTests := []struct { + name string + giveSock *win32File + wantAddr HvsockAddr + }{ + {"client", cl.sock, HvsockAddr{HvsockGUIDChildren(), cla.ServiceID}}, + // The server sockets local address seems arbitrary, so skip this test + // see comment in `(*HvsockListener) Accept()` for more info + // {"server", sv.sock, _sla}, + } + for _, tt := range localTests { + u.Must(socket.GetSockName(windows.Handle(tt.giveSock.handle), &ra)) + sa.fromRaw(&ra) + if sa != tt.wantAddr { + t.Errorf("%s local addr give: %v; want: %v", tt.name, sa, tt.wantAddr) + } + } + + remoteTests := []struct { + name string + giveConn *HvsockConn + }{ + {"client", cl}, + {"server", sv}, + } + for _, tt := range remoteTests { + u.Must(socket.GetPeerName(windows.Handle(tt.giveConn.sock.handle), &ra)) + sa.fromRaw(&ra) + if sa != tt.giveConn.remote { + t.Errorf("%s remote addr give: %v; want: %v", tt.name, sa, tt.giveConn.remote) + } + } + }) +} + +func TestHvSockReadWrite(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + tests := []struct { + req, rsp string + }{ + {"hello ", "world!"}, + {"ping", "pong"}, + } + + // a sync.WaitGroup doesnt offer a channel to use in a select with a timeout + // could use an errgroup.Group, but for now dual channels work fine + svCh := u.Go(func() error { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } + defer c.Close() + + b := make([]byte, 64) + for _, tt := range tests { + n, err := c.Read(b) + if err != nil { + return fmt.Errorf("server rx: %w", err) + } + + r := string(b[:n]) + if r != tt.req { + return fmt.Errorf("server rx error: got %q; wanted %q", r, tt.req) + } + if _, err = c.Write([]byte(tt.rsp)); err != nil { + return fmt.Errorf("server tx error, could not send %q: %w", tt.rsp, err) + } + } + n, err := c.Read(b) + if n != 0 { + return errors.New("server did not get EOF") + } + if !errors.Is(err, io.EOF) { + return fmt.Errorf("server did not get EOF: %w", err) + } + return nil + }) + + clCh := u.Go(func() error { + cl, err := Dial(context.Background(), addr) + if err != nil { + return fmt.Errorf("client dial: %w", err) + } + defer cl.Close() + + b := make([]byte, 64) + for _, tt := range tests { + _, err := cl.Write([]byte(tt.req)) + if err != nil { + return fmt.Errorf("client tx error, could not send %q: %w", tt.req, err) + } + + n, err := cl.Read(b) + if err != nil { + return fmt.Errorf("client tx: %w", err) + } + + r := string(b[:n]) + if r != tt.rsp { + return fmt.Errorf("client rx error: got %q; wanted %q", b[:n], tt.rsp) + } + } + return cl.CloseWrite() + }) + + u.WaitErr(svCh, 15*time.Second, "server") + u.WaitErr(clCh, 15*time.Second, "client") +} + +func TestHvSockReadTooSmall(t *testing.T) { + u := newUtil(t) + s := "this is a really long string that hopefully takes up more than 16 bytes ..." + l, addr := serverListen(u) + + svCh := u.Go(func() error { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } + defer c.Close() + + b := make([]byte, 16) + ss := "" + for { + n, err := c.Read(b) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("server rx: %w", err) + } + ss += string(b[:n]) + } + + if ss != s { + return fmt.Errorf("got %q, wanted: %q", ss, s) + } + return nil + }) + + clCh := u.Go(func() error { + cl, err := Dial(context.Background(), addr) + if err != nil { + return fmt.Errorf("client dial: %w", err) + } + defer cl.Close() + + if _, err = cl.Write([]byte(s)); err != nil { + return fmt.Errorf("client tx error, could not send: %w", err) + } + return nil + }) + + u.WaitErr(svCh, 15*time.Second, "server") + u.WaitErr(clCh, 15*time.Second, "client") +} + +func TestHvSockCloseReadWriteListener(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + + ch := make(chan struct{}) + svCh := u.Go(func() error { + defer close(ch) + c, err := l.Accept() + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } + defer c.Close() + + hv := c.(*HvsockConn) + // + // test CloseWrite() + // + n, err := c.Write([]byte(testStr)) + if err != nil { + return fmt.Errorf("server tx: %w", err) + } + if n != len(testStr) { + return fmt.Errorf("server wrote %d bytes, wanted %d", n, len(testStr)) + } + + if err := hv.CloseWrite(); err != nil { + return fmt.Errorf("server close write: %w", err) + } + + if _, err = c.Write([]byte(testStr)); !errors.Is(err, windows.WSAESHUTDOWN) { + return fmt.Errorf("server did not shutdown writes: %w", err) + } + // safe to call multiple times + if err := hv.CloseWrite(); err != nil { + return fmt.Errorf("server second close write: %w", err) + } + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err = c.Read(b) + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if n != len(testStr) { + return fmt.Errorf("server read %d bytes, wanted %d", n, len(testStr)) + } + if string(b[:n]) != testStr { + return fmt.Errorf("server got %q; wanted %q", b[:n], testStr) + } + if err := hv.CloseRead(); err != nil { + return fmt.Errorf("server close read: %w", err) + } + + ch <- struct{}{} + + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = c.Read(b) + if !errors.Is(err, windows.WSAESHUTDOWN) { + return fmt.Errorf("server did not shutdown reads: %w", err) + } + // safe to call multiple times + if err := hv.CloseRead(); err != nil { + return fmt.Errorf("server second close read: %w", err) + } + + c.Close() + if err := hv.CloseWrite(); !errors.Is(err, socket.ErrSocketClosed) { + return fmt.Errorf("server close write: %w", err) + } + if err := hv.CloseRead(); !errors.Is(err, socket.ErrSocketClosed) { + return fmt.Errorf("server close read: %w", err) + } + return nil + }) + + cl, err := Dial(context.Background(), addr) + u.Must(err, "could not dial") + defer cl.Close() + + b := make([]byte, 256) + n, err := cl.Read(b) + u.Must(err, "client read") + u.Assert(n == len(testStr), fmt.Sprintf("client read %d bytes, wanted %d", n, len(testStr))) + u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr)) + + n, err = cl.Read(b) + u.Assert(n == 0, "client did not get EOF") + u.Is(err, io.EOF, "client did not get EOF") + + n, err = cl.Write([]byte(testStr)) + u.Must(err, "client write") + u.Assert(n == len(testStr), fmt.Sprintf("client wrote %d bytes, wanted %d", n, len(testStr))) + + u.Wait(ch, time.Second) + + // this should succeed + _, err = cl.Write([]byte("test2")) + u.Must(err, "client write") + u.WaitErr(svCh, time.Second, "server") +} + +func TestHvSockCloseReadWriteDial(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + + ch := make(chan struct{}) + clCh := u.Go(func() error { + defer close(ch) + c, err := l.Accept() + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } + defer c.Close() + + b := make([]byte, 256) + n, err := c.Read(b) + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if string(b[:n]) != testStr { + return fmt.Errorf("server got %q; wanted %q", b[:n], testStr) + } + + n, err = c.Read(b) + if n != 0 { + return fmt.Errorf("server did not get EOF") + } + if !errors.Is(err, io.EOF) { + return errors.New("server did not get EOF") + } + + _, err = c.Write([]byte(testStr)) + if err != nil { + return fmt.Errorf("server tx: %w", err) + } + + ch <- struct{}{} + + _, err = c.Write([]byte(testStr)) + if err != nil { + return fmt.Errorf("server tx: %w", err) + } + return c.Close() + }) + + cl, err := Dial(context.Background(), addr) + u.Must(err, "could not dial") + defer cl.Close() + + // + // test CloseWrite() + // + _, err = cl.Write([]byte(testStr)) + u.Must(err, "client write") + u.Must(cl.CloseWrite(), "client close write") + + _, err = cl.Write([]byte(testStr)) + u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown writes") + + // safe to call multiple times + u.Must(cl.CloseWrite(), "client second close write") + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err := cl.Read(b) + u.Must(err, "client read") + u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr)) + u.Must(cl.CloseRead(), "client close read") + + u.Wait(ch, time.Millisecond) + + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = cl.Read(b) + u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown reads") + + // safe to call multiple times + u.Must(cl.CloseRead(), "client second close write") + + l.Close() + cl.Close() + + wantErr := socket.ErrSocketClosed + u.Is(cl.CloseWrite(), wantErr, "client close write") + u.Is(cl.CloseRead(), wantErr, "client close read") + u.WaitErr(clCh, time.Second, "client") +} + +func TestHvSockDialNoTimeout(t *testing.T) { + u := newUtil(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := u.Go(func() error { + addr := randHvsockAddr() + cl, err := Dial(ctx, addr) + if err == nil { + cl.Close() + } + if !errors.Is(err, windows.WSAECONNREFUSED) { + return err + } + return nil + }) + + // connections usually take about ~500µs + u.WaitErr(ch, 2*time.Millisecond, "dial did not time out") +} + +func TestHvSockDialDeadline(t *testing.T) { + u := newUtil(t) + d := &HvsockDialer{} + d.Deadline = time.Now().Add(50 * time.Microsecond) + d.Retries = 1 + // we need the wait time to be long enough for the deadline goroutine to run first and signal + // timeout + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.Dial(context.Background(), addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } + u.Is(err, context.DeadlineExceeded, "dial did not exceed deadline") +} + +func TestHvSockDialContext(t *testing.T) { + u := newUtil(t) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Microsecond, cancel) + + d := &HvsockDialer{} + d.Retries = 1 + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.Dial(ctx, addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } + u.Is(err, context.Canceled, "dial was not canceled") +} + +func TestHvSockAcceptClose(t *testing.T) { + u := newUtil(t) + l, _ := serverListen(u) + go func() { + time.Sleep(50 * time.Millisecond) + l.Close() + }() + + c, err := l.Accept() + if err == nil { + c.Close() + t.Fatal("listener should not have accepted anything") + } + u.Is(err, ErrFileClosed) +} + +func FuzzHvSockRxTx(f *testing.F) { + for _, b := range [][]byte{ + []byte("hello?"), + []byte("This is a really long string that should be a good example of the really long " + + "payloads that may be sent over hvsockets when really long inputs are being used, tautologically. " + + "That means that we will have to test with really long input sequences, which means that " + + "we need to include really long byte sequences or strings in our testing so that we know that " + + "the sockets can deal with really long inputs. Look at this key mashing: " + + "sdflhsdfgkjdhskljjsad;kljfasd;lfkjsadl ;fasdjfopiwej09q34iur092\"i4o[piwajfliasdkf-012ior]-" + + "01oi3;'lSD