Skip to content

Commit

Permalink
add: improve server auth
Browse files Browse the repository at this point in the history
  • Loading branch information
coufalja committed Jan 22, 2024
1 parent 8ecac00 commit 84afc1a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
13 changes: 11 additions & 2 deletions cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"
"time"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/jamf/regatta/cert"
Expand Down Expand Up @@ -41,8 +42,14 @@ func createAPIServer() (*regattaserver.RegattaServer, error) {
addr, secure, net := resolveUrl(viper.GetString("api.address"))
opts := []grpc.ServerOption{
grpc.KeepaliveParams(keepalive.ServerParameters{MaxConnectionAge: 60 * time.Second}),
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
grpc_auth.StreamServerInterceptor(defaultAuthFunc),
grpc_prometheus.StreamServerInterceptor,
)),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
grpc_prometheus.UnaryServerInterceptor,
grpc_auth.UnaryServerInterceptor(defaultAuthFunc),
)),
}
if secure {
c, err := cert.New(viper.GetString("api.cert-filename"), viper.GetString("api.key-filename"))
Expand Down Expand Up @@ -104,6 +111,8 @@ func authFunc(token string) func(ctx context.Context) (context.Context, error) {
}
}

var defaultAuthFunc = authFunc("")

type tokenCredentials string

func (t tokenCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
Expand Down
17 changes: 9 additions & 8 deletions regattaserver/maintenance.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ type ResetServer struct {
}

func (m *ResetServer) Reset(ctx context.Context, req *regattapb.ResetRequest) (*regattapb.ResetResponse, error) {
ctx, err := m.AuthFunc(ctx)
if err != nil {
return nil, err
}
reset := func(name string) error {
t, err := m.Tables.GetTable(name)
if err != nil {
Expand Down Expand Up @@ -64,6 +60,10 @@ func (m *ResetServer) Reset(ctx context.Context, req *regattapb.ResetRequest) (*
return &regattapb.ResetResponse{}, nil
}

func (m *ResetServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
return m.AuthFunc(ctx)
}

// BackupServer implements some Maintenance service methods from proto/regatta.proto.
type BackupServer struct {
regattapb.UnimplementedMaintenanceServer
Expand All @@ -72,10 +72,7 @@ type BackupServer struct {
}

func (m *BackupServer) Backup(req *regattapb.BackupRequest, srv regattapb.Maintenance_BackupServer) error {
ctx, err := m.AuthFunc(srv.Context())
if err != nil {
return err
}
ctx := srv.Context()
table, err := m.Tables.GetTable(string(req.Table))
if err != nil {
return err
Expand Down Expand Up @@ -149,6 +146,10 @@ func (m *BackupServer) Restore(srv regattapb.Maintenance_RestoreServer) error {
return srv.SendAndClose(&regattapb.RestoreResponse{})
}

func (m *BackupServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
return m.AuthFunc(ctx)
}

type backupReader struct {
stream regattapb.Maintenance_RestoreServer
}
Expand Down
16 changes: 4 additions & 12 deletions regattaserver/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ type TablesServer struct {
}

func (t *TablesServer) Create(ctx context.Context, req *regattapb.CreateTableRequest) (*regattapb.CreateTableResponse, error) {
_, err := t.AuthFunc(ctx)
if err != nil {
return nil, err
}
if len(req.Name) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "name must be set")
}
Expand All @@ -40,10 +36,6 @@ func (t *TablesServer) Create(ctx context.Context, req *regattapb.CreateTableReq
}

func (t *TablesServer) Delete(ctx context.Context, req *regattapb.DeleteTableRequest) (*regattapb.DeleteTableResponse, error) {
_, err := t.AuthFunc(ctx)
if err != nil {
return nil, err
}
if len(req.Name) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "name must be set")
}
Expand All @@ -57,10 +49,6 @@ func (t *TablesServer) Delete(ctx context.Context, req *regattapb.DeleteTableReq
}

func (t *TablesServer) List(ctx context.Context, _ *regattapb.ListTablesRequest) (*regattapb.ListTablesResponse, error) {
_, err := t.AuthFunc(ctx)
if err != nil {
return nil, err
}
ts, err := t.Tables.GetTables()
if err != nil {
if serrors.IsSafeToRetry(err) {
Expand All @@ -81,6 +69,10 @@ func (t *TablesServer) List(ctx context.Context, _ *regattapb.ListTablesRequest)
return resp, nil
}

func (t *TablesServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
return t.AuthFunc(ctx)
}

type ReadonlyTablesServer struct {
TablesServer
}
Expand Down
5 changes: 0 additions & 5 deletions regattaserver/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,6 @@ func TestTablesServer_List(t *testing.T) {
},
}},
},
{
name: "deny all",
fields: fields{AuthFunc: denyAll},
wantErr: require.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 84afc1a

Please sign in to comment.