-
Notifications
You must be signed in to change notification settings - Fork 0
/
worker_interceptor.go
70 lines (60 loc) · 2.33 KB
/
worker_interceptor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package auth
import (
"context"
"fmt"
"net/http"
rbacv1 "github.com/llm-operator/rbac-manager/api/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// WorkerConfig holds the settings NewWorkerInterceptor needs to build a
// WorkerInterceptor.
type WorkerConfig struct {
	// RBACServerAddr is the dial target of the RBAC server's internal gRPC
	// service; NewWorkerInterceptor passes it unchanged to grpc.DialContext.
	RBACServerAddr string
}
// NewWorkerInterceptor connects to the RBAC server at c.RBACServerAddr and
// returns a WorkerInterceptor whose authorization checks go over that
// connection.
//
// The connection is created with grpc.DialContext(ctx, ...) using insecure
// (plaintext) transport credentials. Any error from the dial is returned as-is.
//
// NOTE(review): only the generated client is stored, not the *grpc.ClientConn,
// so the connection can never be closed by the caller. Newer grpc-go releases
// also deprecate DialContext in favor of grpc.NewClient; consider migrating
// once the module's grpc version allows.
func NewWorkerInterceptor(ctx context.Context, c WorkerConfig) (*WorkerInterceptor, error) {
	creds := grpc.WithTransportCredentials(insecure.NewCredentials())
	cc, err := grpc.DialContext(ctx, c.RBACServerAddr, creds)
	if err != nil {
		return nil, err
	}
	rbacClient := rbacv1.NewRbacInternalServiceClient(cc)
	return &WorkerInterceptor{client: rbacClient}, nil
}
// WorkerInterceptor authenticates requests that come from worker clusters by
// sending the presented token to the RBAC server's AuthorizeWorker RPC.
// It provides a gRPC unary server interceptor (Unary) and an HTTP request
// check (InterceptHTTPRequest). Build it with NewWorkerInterceptor; the zero
// value has a nil client and is not usable.
type WorkerInterceptor struct {
	// client issues AuthorizeWorker calls to the RBAC server.
	client rbacv1.RbacInternalServiceClient
}
// Unary returns a grpc.UnaryServerInterceptor that authorizes every call from a
// worker cluster before it reaches the handler.
//
// For each call it:
//  1. extracts the token with ExtractTokenFromContext, returning that function's
//     error unchanged if it fails;
//  2. calls AuthorizeWorker on the RBAC server with the incoming context,
//     returning a codes.Internal status if the RPC fails and a
//     codes.PermissionDenied status if the response is not authorized;
//  3. invokes handler with the context produced by AppendClusterInfoToContext
//     from the authorization response.
func (a *WorkerInterceptor) Unary() grpc.UnaryServerInterceptor {
	// authorize performs steps 1–2 and returns the context to hand to the handler.
	authorize := func(ctx context.Context) (context.Context, error) {
		token, err := ExtractTokenFromContext(ctx)
		if err != nil {
			return nil, err
		}
		result, err := a.client.AuthorizeWorker(ctx, &rbacv1.AuthorizeWorkerRequest{Token: token})
		switch {
		case err != nil:
			return nil, status.Errorf(codes.Internal, "failed to authorize: %v", err)
		case !result.Authorized:
			return nil, status.Errorf(codes.PermissionDenied, "permission denied")
		}
		return AppendClusterInfoToContext(ctx, newClusterInfoFromAuthorizeResponse(result)), nil
	}

	return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		authedCtx, err := authorize(ctx)
		if err != nil {
			return nil, err
		}
		return handler(authedCtx, req)
	}
}
// InterceptHTTPRequest authorizes an HTTP request from a worker cluster.
//
// It reads the token from req's headers via extractTokenFromHeader and asks
// the RBAC server to authorize it, using req.Context() for the RPC. It returns
// the HTTP status code the caller should respond with, the ClusterInfo derived
// from the authorization response (zero value on failure), and an error:
//   - http.StatusUnauthorized, "missing authorization" when no token is found;
//   - http.StatusInternalServerError, "failed to authorize: ..." when the
//     AuthorizeWorker RPC fails; the RPC error is wrapped, so callers can
//     recover it with errors.As / errors.Unwrap;
//   - http.StatusUnauthorized, "permission denied" when the RBAC server reports
//     the token as not authorized;
//   - http.StatusOK and a nil error on success.
func (a *WorkerInterceptor) InterceptHTTPRequest(req *http.Request) (int, ClusterInfo, error) {
	token, found := extractTokenFromHeader(req.Header)
	if !found {
		return http.StatusUnauthorized, ClusterInfo{}, fmt.Errorf("missing authorization")
	}

	aresp, err := a.client.AuthorizeWorker(req.Context(), &rbacv1.AuthorizeWorkerRequest{Token: token})
	if err != nil {
		// %w (rather than %v) keeps the gRPC status error in the chain so callers
		// can inspect its code; the rendered message text is unchanged.
		return http.StatusInternalServerError, ClusterInfo{}, fmt.Errorf("failed to authorize: %w", err)
	}
	if !aresp.Authorized {
		// NOTE(review): Unary reports this same case as codes.PermissionDenied,
		// whose conventional HTTP mapping is 403; 401 is kept here so existing
		// callers see no change.
		return http.StatusUnauthorized, ClusterInfo{}, fmt.Errorf("permission denied")
	}
	return http.StatusOK, newClusterInfoFromAuthorizeResponse(aresp), nil
}