From e4deba077dae63f1fae31d4991f9e837f92f1c25 Mon Sep 17 00:00:00 2001 From: Jim Ma Date: Thu, 25 May 2023 17:28:07 +0800 Subject: [PATCH] chore: check grpc peer info for download service (#2385) Signed-off-by: Jim Ma --- client/daemon/rpcserver/rpcserver.go | 13 +++++++ client/daemon/rpcserver/rpcserver_test.go | 42 +++++++++++++++-------- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/client/daemon/rpcserver/rpcserver.go b/client/daemon/rpcserver/rpcserver.go index 48e3a20a75e..40d0cee9016 100644 --- a/client/daemon/rpcserver/rpcserver.go +++ b/client/daemon/rpcserver/rpcserver.go @@ -379,6 +379,19 @@ func (s *server) CheckHealth(context.Context, *emptypb.Empty) (*emptypb.Empty, e func (s *server) Download(req *dfdaemonv1.DownRequest, stream dfdaemonv1.Daemon_DownloadServer) error { s.Keep() ctx := stream.Context() + pr, ok := grpcpeer.FromContext(ctx) + + if !ok { + return status.Error(codes.FailedPrecondition, "invalid grpc peer info") + } + + // currently, we only use daemon to download file via unix domain socket + if pr.Addr.Network() != "unix" { + err := fmt.Sprintf("invalid incoming source: %v", pr.Addr.String()) + logger.Errorf(err) + return status.Error(codes.Unauthenticated, err) + } + if req.Recursive { return s.recursiveDownload(ctx, req, stream) } diff --git a/client/daemon/rpcserver/rpcserver_test.go b/client/daemon/rpcserver/rpcserver_test.go index a0df6472708..3120aeb7e9d 100644 --- a/client/daemon/rpcserver/rpcserver_test.go +++ b/client/daemon/rpcserver/rpcserver_test.go @@ -18,15 +18,16 @@ package rpcserver import ( "context" - "fmt" "io" + "io/ioutil" "net" + "os" + "path" "sync" "testing" "github.com/distribution/distribution/v3/uuid" "github.com/golang/mock/gomock" - "github.com/phayes/freeport" testifyassert "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/health" @@ -679,7 +680,13 @@ func TestServer_ServeDownload(t *testing.T) { peerHost: &schedulerv1.PeerHost{}, peerTaskManager: mockPeerTaskManager, } - client := setupPeerServerAndClient(t, s, assert, s.ServeDownload) + + socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***") + assert.Nil(err, "make temp dir should be ok") + socketPath := path.Join(socketDir, "rpc.sock") + defer os.RemoveAll(socketDir) + + client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServeDownload) request := &dfdaemonv1.DownRequest{ Uuid: uuid.Generate().String(), Url: "http://localhost/test", @@ -745,7 +752,13 @@ func TestServer_ServePeer(t *testing.T) { peerHost: &schedulerv1.PeerHost{}, storageManager: mockStorageManger, } - client := setupPeerServerAndClient(t, s, assert, s.ServePeer) + + socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***") + assert.Nil(err, "make temp dir should be ok") + socketPath := path.Join(socketDir, "rpc.sock") + defer os.RemoveAll(socketDir) + + client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServePeer) defer s.peerServer.GracefulStop() var tests = []struct { @@ -1077,7 +1090,12 @@ func TestServer_SyncPieceTasks(t *testing.T) { peerTaskManager: mockTaskManager, } - client := setupPeerServerAndClient(t, s, assert, s.ServePeer) + socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***") + assert.Nil(err, "make temp dir should be ok") + socketPath := path.Join(socketDir, "rpc.sock") + defer os.RemoveAll(socketDir) + + client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServePeer) syncClient, err := client.SyncPieceTasks( context.Background(), &commonv1.PieceTaskRequest{ @@ -1141,19 +1159,15 @@ func TestServer_SyncPieceTasks(t *testing.T) { } } -func setupPeerServerAndClient(t *testing.T, srv *server, assert *testifyassert.Assertions, serveFunc func(listener net.Listener) error) dfdaemonclient.V1 { +func setupPeerServerAndClient(t *testing.T, socket string, srv *server, assert *testifyassert.Assertions, serveFunc func(listener net.Listener) error) dfdaemonclient.V1 { if srv.healthServer == nil { srv.healthServer = health.NewServer() } srv.downloadServer = dfdaemonserver.New(srv, srv.healthServer) srv.peerServer = dfdaemonserver.New(srv, srv.healthServer) - port, err := freeport.GetFreePort() - if err != nil { - t.Fatal(err) - } - ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - assert.Nil(err, "get free port should be ok") + ln, err := net.Listen("unix", socket) + assert.Nil(err, "listen unix socket should be ok") go func() { if err := serveFunc(ln); err != nil { t.Error(err) @@ -1161,8 +1175,8 @@ func setupPeerServerAndClient(t *testing.T, srv *server, assert *testifyassert.A }() netAddr := &dfnet.NetAddr{ - Type: dfnet.TCP, - Addr: fmt.Sprintf(":%d", port), + Type: dfnet.UNIX, + Addr: socket, } client, err := dfdaemonclient.GetInsecureV1(context.Background(), netAddr.String()) assert.Nil(err, "grpc dial should be ok")