-
Notifications
You must be signed in to change notification settings - Fork 323
/
oauth.go
104 lines (87 loc) · 2.57 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package graph
import (
"context"
"fmt"
"github.com/99designs/gqlgen/graphql"
"github.com/highlight-run/highlight/backend/model"
"github.com/highlight-run/highlight/backend/store"
"github.com/highlight-run/highlight/backend/util"
e "github.com/pkg/errors"
"github.com/samber/lo"
"go.opentelemetry.io/otel/trace"
"sync"
"time"
)
// OAuthValidator is the set of gqlgen extension interfaces that the value
// returned by NewGraphqlOAuthValidator satisfies: it is both a
// graphql.HandlerExtension and a graphql.FieldInterceptor.
type OAuthValidator interface {
graphql.HandlerExtension
graphql.FieldInterceptor
}
// Tracer is the OAuthValidator implementation. For requests that carry an
// OAuth client ID in their context it checks field access against the
// client's operations and keeps a per-client counter of intercepted fields
// that a background loop reports as spans.
//
// A Tracer contains a mutex and must not be copied after first use.
type Tracer struct {
// store is used to look up the OAuth client (via GetOAuth) by client ID.
store *store.Store
// clientRequests maps an OAuth client ID to the number of fields
// intercepted for that client since the map was last reset.
clientRequests map[string]uint64
// clientRequestsMutex guards every read and write of clientRequests.
clientRequestsMutex sync.RWMutex
}
// NewGraphqlOAuthValidator builds a Tracer backed by the given store, starts
// its background flush loop, and returns it as an OAuthValidator.
//
// NOTE(review): the goroutine started here runs flushClientRequests, which
// loops forever; there is no way to stop it from this API.
func NewGraphqlOAuthValidator(store *store.Store) OAuthValidator {
	validator := &Tracer{
		clientRequests: map[string]uint64{},
		store:          store,
	}
	go validator.flushClientRequests()
	return validator
}
// flushClientRequests runs forever, reporting the per-client request counters
// every five seconds. For each client ID with a recorded count it emits one
// "private.oauth.count" span tagged with the client ID and the count, after
// which counting starts again from zero.
//
// It is intended to be run in its own goroutine; it never returns.
func (t *Tracer) flushClientRequests() {
	ctx := context.Background()
	for {
		// Detach the current counters and install a fresh map in a single
		// critical section. Reading under RLock and resetting under a later
		// Lock would silently drop every increment that lands between the two
		// sections; swapping atomically guarantees each increment is reported
		// exactly once.
		t.clientRequestsMutex.Lock()
		pending := t.clientRequests
		t.clientRequests = make(map[string]uint64, len(pending))
		t.clientRequestsMutex.Unlock()

		// pending is no longer reachable by other goroutines, so the spans
		// can be emitted without holding the lock (and without stalling
		// InterceptField while they are created).
		for clientID, requests := range pending {
			span, _ := util.StartSpanFromContext(ctx, "private.oauth.count", util.WithSpanKind(trace.SpanKindServer), util.Tag("client_id", clientID), util.Tag("requests", requests))
			span.Finish()
		}

		time.Sleep(5 * time.Second)
	}
}
// ExtensionName returns the fixed name under which this extension
// identifies itself.
func (t *Tracer) ExtensionName() string {
	const name = "HighlightOAuthValidator"
	return name
}
// Validate performs no checks on the schema and always reports success.
func (t *Tracer) Validate(_ graphql.ExecutableSchema) error {
	return nil
}
// InterceptField gates field resolution for requests made on behalf of an
// OAuth client.
//
// Behaviour, in order:
//   - If the context carries no (or an empty / non-string) OAuth client ID,
//     the field is resolved untouched.
//   - Otherwise a "private.oauth.field" span is opened and the per-client
//     counter is incremented (once per intercepted field).
//   - If the context has no operation context, the field is resolved.
//   - The client is loaded from the store; any lookup error is reported to the
//     caller as AuthorizationError.
//   - Fields whose parent object is neither "Query" nor "Mutation" (or that
//     have no field context) are resolved without further checks.
//   - For root fields, resolution proceeds only if the field name matches the
//     AuthorizedGraphQLOperation of one of the client's operations; otherwise
//     an error naming the field is returned.
func (t *Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (interface{}, error) {
	// A failed type assertion leaves clientID as "", which is treated the
	// same as "no OAuth client on this request".
	clientID, _ := ctx.Value(model.ContextKeys.OAuthClientID).(string)
	if clientID == "" {
		return next(ctx)
	}

	span, ctx := util.StartSpanFromContext(ctx, "private.oauth.field", util.Tag("client_id", clientID))
	defer span.Finish()

	// Incrementing a missing key starts from the zero value, so no explicit
	// initialisation is needed.
	t.clientRequestsMutex.Lock()
	t.clientRequests[clientID]++
	t.clientRequestsMutex.Unlock()

	if !graphql.HasOperationContext(ctx) {
		return next(ctx)
	}

	oauthClient, lookupErr := t.store.GetOAuth(ctx, clientID)
	if lookupErr != nil {
		return nil, AuthorizationError
	}

	// Only root-level Query / Mutation fields are subject to the allow-list.
	fieldCtx := graphql.GetFieldContext(ctx)
	if fieldCtx == nil {
		return next(ctx)
	}
	if parent := fieldCtx.Object; parent != "Query" && parent != "Mutation" {
		return next(ctx)
	}

	requested := fieldCtx.Field.Name
	span.SetAttribute("field_name", requested)

	// TODO(vkorolik) rate limit based on opConfig
	_, authorized := lo.Find(client(oauthClient), func(op *model.OAuthOperation) bool {
		return op.AuthorizedGraphQLOperation == requested
	})
	if authorized {
		return next(ctx)
	}
	return nil, e.New(fmt.Sprintf("403 - AuthorizationError: %s", requested))
}

// client returns the operations configured for the given OAuth client.
func client(c *model.OAuthClientStore) []*model.OAuthOperation {
	return c.Operations
}