Skip to content

Commit

Permalink
chore: check grpc peer info for download service (#2385)
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Ma <majinjing3@gmail.com>
  • Loading branch information
jim3ma committed May 25, 2023
1 parent c275421 commit e4deba0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
13 changes: 13 additions & 0 deletions client/daemon/rpcserver/rpcserver.go
Expand Up @@ -379,6 +379,19 @@ func (s *server) CheckHealth(context.Context, *emptypb.Empty) (*emptypb.Empty, e
func (s *server) Download(req *dfdaemonv1.DownRequest, stream dfdaemonv1.Daemon_DownloadServer) error {
s.Keep()
ctx := stream.Context()
pr, ok := grpcpeer.FromContext(ctx)

if !ok {
return status.Error(codes.FailedPrecondition, "invalid grpc peer info")
}

// currently, we only use daemon to download file via unix domain socket
if pr.Addr.Network() != "unix" {
err := fmt.Sprintf("invalid incoming source: %v", pr.Addr.String())
logger.Errorf(err)
return status.Error(codes.Unauthenticated, err)
}

if req.Recursive {
return s.recursiveDownload(ctx, req, stream)
}
Expand Down
42 changes: 28 additions & 14 deletions client/daemon/rpcserver/rpcserver_test.go
Expand Up @@ -18,15 +18,16 @@ package rpcserver

import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"path"
"sync"
"testing"

"github.com/distribution/distribution/v3/uuid"
"github.com/golang/mock/gomock"
"github.com/phayes/freeport"
testifyassert "github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
Expand Down Expand Up @@ -679,7 +680,13 @@ func TestServer_ServeDownload(t *testing.T) {
peerHost: &schedulerv1.PeerHost{},
peerTaskManager: mockPeerTaskManager,
}
client := setupPeerServerAndClient(t, s, assert, s.ServeDownload)

socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***")
assert.Nil(err, "make temp dir should be ok")
socketPath := path.Join(socketDir, "rpc.sock")
defer os.RemoveAll(socketDir)

client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServeDownload)
request := &dfdaemonv1.DownRequest{
Uuid: uuid.Generate().String(),
Url: "http://localhost/test",
Expand Down Expand Up @@ -745,7 +752,13 @@ func TestServer_ServePeer(t *testing.T) {
peerHost: &schedulerv1.PeerHost{},
storageManager: mockStorageManger,
}
client := setupPeerServerAndClient(t, s, assert, s.ServePeer)

socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***")
assert.Nil(err, "make temp dir should be ok")
socketPath := path.Join(socketDir, "rpc.sock")
defer os.RemoveAll(socketDir)

client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServePeer)
defer s.peerServer.GracefulStop()

var tests = []struct {
Expand Down Expand Up @@ -1077,7 +1090,12 @@ func TestServer_SyncPieceTasks(t *testing.T) {
peerTaskManager: mockTaskManager,
}

client := setupPeerServerAndClient(t, s, assert, s.ServePeer)
socketDir, err := ioutil.TempDir(os.TempDir(), "d7y-test-***")
assert.Nil(err, "make temp dir should be ok")
socketPath := path.Join(socketDir, "rpc.sock")
defer os.RemoveAll(socketDir)

client := setupPeerServerAndClient(t, socketPath, s, assert, s.ServePeer)
syncClient, err := client.SyncPieceTasks(
context.Background(),
&commonv1.PieceTaskRequest{
Expand Down Expand Up @@ -1141,28 +1159,24 @@ func TestServer_SyncPieceTasks(t *testing.T) {
}
}

func setupPeerServerAndClient(t *testing.T, srv *server, assert *testifyassert.Assertions, serveFunc func(listener net.Listener) error) dfdaemonclient.V1 {
func setupPeerServerAndClient(t *testing.T, socket string, srv *server, assert *testifyassert.Assertions, serveFunc func(listener net.Listener) error) dfdaemonclient.V1 {
if srv.healthServer == nil {
srv.healthServer = health.NewServer()
}
srv.downloadServer = dfdaemonserver.New(srv, srv.healthServer)
srv.peerServer = dfdaemonserver.New(srv, srv.healthServer)
port, err := freeport.GetFreePort()
if err != nil {
t.Fatal(err)
}

ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
assert.Nil(err, "get free port should be ok")
ln, err := net.Listen("unix", socket)
assert.Nil(err, "listen unix socket should be ok")
go func() {
if err := serveFunc(ln); err != nil {
t.Error(err)
}
}()

netAddr := &dfnet.NetAddr{
Type: dfnet.TCP,
Addr: fmt.Sprintf(":%d", port),
Type: dfnet.UNIX,
Addr: socket,
}
client, err := dfdaemonclient.GetInsecureV1(context.Background(), netAddr.String())
assert.Nil(err, "grpc dial should be ok")
Expand Down

0 comments on commit e4deba0

Please sign in to comment.