Skip to content

Commit

Permalink
events: Add filters to keep track of local and other subscriptions (#…
Browse files Browse the repository at this point in the history
…24201)

This adds a very basic implementation of a list of namespace+eventType
combinations that each node is interested in by just running the
glob operations in for-loops. Some parallelization is possible, but
not enabled by default.

It only wires up keeping track of what the local event bus is interested
in for now (but doesn't use it yet to filter messages).

Also updates the cloudevents source URL to indicate the Vault node that generated the event.

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
  • Loading branch information
Christopher Swenson and tomhjp committed Nov 30, 2023
1 parent 56f793d commit 9d39b6f
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 42 deletions.
3 changes: 3 additions & 0 deletions changelog/24201.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:change
events: Source URL is now `vault://{vault node}`
```
6 changes: 5 additions & 1 deletion vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,11 @@ func NewCore(conf *CoreConfig) (*Core, error) {
eventsLogger := conf.Logger.Named("events")
c.allLoggers = append(c.allLoggers, eventsLogger)
// start the event system
events, err := eventbus.NewEventBus(eventsLogger)
nodeID, err := c.LoadNodeID()
if err != nil {
return nil, err
}
events, err := eventbus.NewEventBus(nodeID, eventsLogger)
if err != nil {
return nil, err
}
Expand Down
71 changes: 40 additions & 31 deletions vault/eventbus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ const (
)

var (
ErrNotStarted = errors.New("event broker has not been started")
cloudEventsFormatterFilter *cloudevents.FormatterFilter
subscriptions atomic.Int64 // keeps track of event subscription count in all event buses
ErrNotStarted = errors.New("event broker has not been started")
subscriptions atomic.Int64 // keeps track of event subscription count in all event buses

// these metadata fields will have the plugin mount path prepended to them
metadataPrependPathFields = []string{
Expand All @@ -49,11 +48,13 @@ var (
// EventBus contains the main logic of running an event broker for Vault.
// Start() must be called before the EventBus will accept events for sending.
type EventBus struct {
logger hclog.Logger
broker *eventlogger.Broker
started atomic.Bool
formatterNodeID eventlogger.NodeID
timeout time.Duration
logger hclog.Logger
broker *eventlogger.Broker
started atomic.Bool
formatterNodeID eventlogger.NodeID
timeout time.Duration
filters *Filters
cloudEventsFormatterFilter *cloudevents.FormatterFilter
}

type pluginEventBus struct {
Expand All @@ -72,6 +73,7 @@ type asyncChanNode struct {
closeOnce sync.Once
cancelFunc context.CancelFunc
pipelineID eventlogger.PipelineID
removeFilter func()
removePipeline func(ctx context.Context, t eventlogger.EventType, id eventlogger.PipelineID) (bool, error)
}

Expand Down Expand Up @@ -162,21 +164,7 @@ func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.Even
return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data)
}

func init() {
// TODO: maybe this should relate to the Vault core somehow?
sourceUrl, err := url.Parse("https://vaultproject.io/")
if err != nil {
panic(err)
}
cloudEventsFormatterFilter = &cloudevents.FormatterFilter{
Source: sourceUrl,
Predicate: func(_ context.Context, e interface{}) (bool, error) {
return true, nil
},
}
}

func NewEventBus(logger hclog.Logger) (*EventBus, error) {
func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) {
broker, err := eventlogger.NewBroker()
if err != nil {
return nil, err
Expand All @@ -192,11 +180,25 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
logger = hclog.Default().Named("events")
}

sourceUrl, err := url.Parse("vault://" + localNodeID)
if err != nil {
return nil, err
}

cloudEventsFormatterFilter := &cloudevents.FormatterFilter{
Source: sourceUrl,
Predicate: func(_ context.Context, e interface{}) (bool, error) {
return true, nil
},
}

return &EventBus{
logger: logger,
broker: broker,
formatterNodeID: formatterNodeID,
timeout: defaultTimeout,
logger: logger,
broker: broker,
formatterNodeID: formatterNodeID,
timeout: defaultTimeout,
cloudEventsFormatterFilter: cloudEventsFormatterFilter,
filters: NewFilters(localNodeID),
}, nil
}

Expand All @@ -215,7 +217,7 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP
return nil, nil, err
}

err = bus.broker.RegisterNode(bus.formatterNodeID, cloudEventsFormatterFilter)
err = bus.broker.RegisterNode(bus.formatterNodeID, bus.cloudEventsFormatterFilter)
if err != nil {
return nil, nil, err
}
Expand All @@ -240,7 +242,12 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP
}

ctx, cancel := context.WithCancel(ctx)
asyncNode := newAsyncNode(ctx, bus.logger, bus.broker)

bus.filters.addPattern(bus.filters.self, namespacePathPatterns, pattern)

asyncNode := newAsyncNode(ctx, bus.logger, bus.broker, func() {
bus.filters.removePattern(bus.filters.self, namespacePathPatterns, pattern)
})
err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode)
if err != nil {
defer cancel()
Expand Down Expand Up @@ -301,7 +308,7 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin
}
}

// Filter for correct event type, including wildcards.
// NodeFilter for correct event type, including wildcards.
if !glob.Glob(pattern, eventRecv.EventType) {
return false, nil
}
Expand All @@ -315,11 +322,12 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin
}, nil
}

func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker) *asyncChanNode {
func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker, removeFilter func()) *asyncChanNode {
return &asyncChanNode{
ctx: ctx,
ch: make(chan *eventlogger.Event),
logger: logger,
removeFilter: removeFilter,
removePipeline: broker.RemovePipelineAndNodes,
}
}
Expand All @@ -328,6 +336,7 @@ func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.
func (node *asyncChanNode) Close(ctx context.Context) {
node.closeOnce.Do(func() {
defer node.cancelFunc()
node.removeFilter()
removed, err := node.removePipeline(ctx, eventTypeAll, node.pipelineID)

switch {
Expand Down
29 changes: 19 additions & 10 deletions vault/eventbus/bus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

// TestBusBasics tests that basic event sending and subscribing function.
func TestBusBasics(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func TestBusBasics(t *testing.T) {
// TestBusIgnoresSendContext tests that the context is ignored when sending to an event,
// so that we do not give up too quickly.
func TestBusIgnoresSendContext(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestBusIgnoresSendContext(t *testing.T) {
// TestSubscribeNonRootNamespace verifies that events for non-root namespaces
// aren't filtered out by the bus.
func TestSubscribeNonRootNamespace(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestSubscribeNonRootNamespace(t *testing.T) {

// TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus.
func TestNamespaceFiltering(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestNamespaceFiltering(t *testing.T) {

// TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers.
func TestBus2Subscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
subscriptions.Store(0)
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -396,7 +396,7 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
// TestBusWildcardSubscriptions tests that a single subscription can receive
// multiple event types using * for glob patterns.
func TestBusWildcardSubscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -471,7 +471,7 @@ func TestBusWildcardSubscriptions(t *testing.T) {
// TestDataPathIsPrependedWithMount tests that "data_path", if present in the
// metadata, is prepended with the plugin's mount.
func TestDataPathIsPrependedWithMount(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -591,7 +591,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) {

// TestBexpr tests go-bexpr filters are evaluated on an event.
func TestBexpr(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -671,7 +671,7 @@ func TestBexpr(t *testing.T) {
// TestPipelineCleanedUp ensures pipelines are properly cleaned up after
// subscriptions are closed.
func TestPipelineCleanedUp(t *testing.T) {
bus, err := NewEventBus(nil)
bus, err := NewEventBus("", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -683,6 +683,10 @@ func TestPipelineCleanedUp(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// check that the filters are set
if !bus.filters.anyMatch(namespace.RootNamespace, eventType) {
t.Fatal()
}
if !bus.broker.IsAnyPipelineRegistered(eventTypeAll) {
cancel()
t.Fatal()
Expand All @@ -693,4 +697,9 @@ func TestPipelineCleanedUp(t *testing.T) {
if bus.broker.IsAnyPipelineRegistered(eventTypeAll) {
t.Fatal()
}

// and that the filters are cleaned up
if bus.filters.anyMatch(namespace.RootNamespace, eventType) {
t.Fatal()
}
}
120 changes: 120 additions & 0 deletions vault/eventbus/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package eventbus

import (
"slices"
"sort"
"sync"

"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
"github.com/ryanuber/go-glob"
)

// Filters keeps track of all the event patterns that each node is interested in.
type Filters struct {
lock sync.RWMutex
self nodeID
filters map[nodeID]*NodeFilter
}

// nodeID is used to syntactically indicate that the string is a node's name identifier.
type nodeID string

// pattern is used to represent one or more combinations of patterns
type pattern struct {
eventTypePattern string
namespacePatterns []string
}

// NodeFilter keeps track of all patterns that a particular node is interested in.
type NodeFilter struct {
patterns []pattern
}

func (nf *NodeFilter) match(ns *namespace.Namespace, eventType logical.EventType) bool {
if nf == nil {
return false
}
for _, p := range nf.patterns {
if glob.Glob(p.eventTypePattern, string(eventType)) {
for _, nsp := range p.namespacePatterns {
if glob.Glob(nsp, ns.Path) {
return true
}
}
}
}
return false
}

// NewFilters creates an empty set of filters to keep track of each node's pattern interests.
func NewFilters(self string) *Filters {
return &Filters{
self: nodeID(self),
filters: map[nodeID]*NodeFilter{},
}
}

// addPattern adds a pattern to a node's list.
func (f *Filters) addPattern(node nodeID, namespacePatterns []string, eventTypePattern string) {
f.lock.Lock()
defer f.lock.Unlock()
if _, ok := f.filters[node]; !ok {
f.filters[node] = &NodeFilter{}
}
nsPatterns := slices.Clone(namespacePatterns)
sort.Strings(nsPatterns)
f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns})
}

func (f *Filters) addNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) {
f.addPattern(node, []string{ns.Path}, eventTypePattern)
}

// removePattern removes a pattern from a node's list.
func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTypePattern string) {
nsPatterns := slices.Clone(namespacePatterns)
sort.Strings(nsPatterns)
check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns}
f.lock.Lock()
defer f.lock.Unlock()
filters, ok := f.filters[node]
if !ok {
return
}
filters.patterns = slices.DeleteFunc(filters.patterns, func(m pattern) bool {
return m.eventTypePattern == check.eventTypePattern &&
slices.Equal(m.namespacePatterns, check.namespacePatterns)
})
}

func (f *Filters) removeNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) {
f.removePattern(node, []string{ns.Path}, eventTypePattern)
}

// anyMatch returns true if any node's pattern list matches the arguments.
func (f *Filters) anyMatch(ns *namespace.Namespace, eventType logical.EventType) bool {
f.lock.RLock()
defer f.lock.RUnlock()
for _, nf := range f.filters {
if nf.match(ns, eventType) {
return true
}
}
return false
}

// nodeMatch returns true if the given node's pattern list matches the arguments.
func (f *Filters) nodeMatch(node nodeID, ns *namespace.Namespace, eventType logical.EventType) bool {
f.lock.RLock()
defer f.lock.RUnlock()
return f.filters[node].match(ns, eventType)
}

// localMatch returns true if the local node's pattern list matches the arguments.
func (f *Filters) localMatch(ns *namespace.Namespace, eventType logical.EventType) bool {
return f.nodeMatch(f.self, ns, eventType)
}
Loading

0 comments on commit 9d39b6f

Please sign in to comment.