diff --git a/pkg/kubernetes/kubernetes_derived_test.go b/pkg/kubernetes/kubernetes_derived_test.go index 98331def..0fdd86c1 100644 --- a/pkg/kubernetes/kubernetes_derived_test.go +++ b/pkg/kubernetes/kubernetes_derived_test.go @@ -50,7 +50,6 @@ users: s.Run("without authorization header returns original clientset", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) derived, err := testManager.Derived(s.T().Context()) s.Require().NoErrorf(err, "failed to create derived kubernetes: %v", err) @@ -61,7 +60,6 @@ users: s.Run("with invalid authorization header returns original clientset", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) ctx := context.WithValue(s.T().Context(), HeaderKey("Authorization"), "invalid-token") derived, err := testManager.Derived(ctx) @@ -73,7 +71,6 @@ users: s.Run("with valid bearer token creates derived kubernetes with correct configuration", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) ctx := context.WithValue(s.T().Context(), HeaderKey("Authorization"), "Bearer aiTana-julIA") derived, err := testManager.Derived(ctx) @@ -150,7 +147,6 @@ users: s.Run("with bearer token but RawConfig fails returns original clientset", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) // Corrupt the clientCmdConfig by setting it to a config that will fail on RawConfig() // We'll do this by creating a config with an invalid file path @@ -191,7 +187,6 @@ users: `))) testManager, err := NewKubeconfigManager(workingConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) // Now create a bad manager with RequireOAuth=true badManager, _ := NewManager(testStaticConfig, testManager.accessControlClientset.cfg, testManager.accessControlClientset.clientCmdConfig) @@ -219,7 +214,6 @@ users: s.Run("with bearer token but invalid rest config returns original clientset", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) // Corrupt the rest config to make NewAccessControlClientset fail // Setting an invalid Host URL should cause client creation to fail @@ -241,7 +235,6 @@ users: s.Run("with bearer token but invalid rest config returns error", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) // Corrupt the rest config to make NewAccessControlClientset fail testManager.accessControlClientset.cfg.Host = "://invalid-url" @@ -263,7 +256,6 @@ users: s.Run("with no authorization header returns oauth token required error", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) derived, err := testManager.Derived(s.T().Context()) s.Require().Error(err, "expected error for missing oauth token, got nil") @@ -274,7 +266,6 @@ users: s.Run("with invalid authorization header returns oauth token required error", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) ctx := context.WithValue(s.T().Context(), HeaderKey("Authorization"), "invalid-token") derived, err := testManager.Derived(ctx) @@ -286,7 +277,6 @@ users: s.Run("with valid bearer token creates derived kubernetes", func() { testManager, err := NewKubeconfigManager(testStaticConfig, "") s.Require().NoErrorf(err, "failed to create test manager: %v", err) - s.T().Cleanup(testManager.Close) ctx := context.WithValue(s.T().Context(), HeaderKey("Authorization"), "Bearer aiTana-julIA") derived, err := testManager.Derived(ctx) diff --git a/pkg/kubernetes/manager.go b/pkg/kubernetes/manager.go index 6851ae35..a8382ff4 100644 --- a/pkg/kubernetes/manager.go +++ b/pkg/kubernetes/manager.go @@ -5,14 +5,10 @@ import ( "errors" "fmt" "os" - "sort" "strconv" "strings" - "sync" - "time" "github.com/containers/kubernetes-mcp-server/pkg/config" - "github.com/fsnotify/fsnotify" authenticationv1api "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/rest" @@ -24,40 +20,11 @@ import ( type Manager struct { accessControlClientset *AccessControlClientset - staticConfig *config.StaticConfig - CloseWatchKubeConfig CloseWatchKubeConfig - - clusterWatcher *clusterStateWatcher -} - -// clusterState represents the cached state of the cluster -type clusterState struct { - apiGroups []string - isOpenShift bool -} - -// clusterStateWatcher monitors cluster state changes and triggers debounced reloads -type clusterStateWatcher struct { - manager *Manager - pollInterval time.Duration - debounceWindow time.Duration - lastKnownState clusterState - reloadCallback func() error - debounceTimer *time.Timer - mu sync.Mutex - stopCh chan struct{} - stoppedCh chan struct{} + staticConfig *config.StaticConfig } var _ Openshift = (*Manager)(nil) -const ( - // DefaultClusterStatePollInterval is the default interval for polling cluster state changes - DefaultClusterStatePollInterval = 30 * time.Second - // DefaultClusterStateDebounceWindow is the default debounce window for cluster state changes - DefaultClusterStateDebounceWindow = 5 * time.Second -) - var ( ErrorKubeconfigInClusterNotAllowed = errors.New("kubeconfig manager cannot be used in in-cluster deployments") ErrorInClusterNotInCluster = errors.New("in-cluster manager cannot be used outside of a cluster") @@ -148,48 +115,6 @@ func NewManager(config *config.StaticConfig, restConfig *rest.Config, clientCmdC return k8s, nil } -func (m *Manager) WatchKubeConfig(onKubeConfigChange func() error) { - kubeConfigFiles := m.accessControlClientset.ToRawKubeConfigLoader().ConfigAccess().GetLoadingPrecedence() - if len(kubeConfigFiles) == 0 { - return - } - watcher, err := fsnotify.NewWatcher() - if err != nil { - return - } - for _, file := range kubeConfigFiles { - _ = watcher.Add(file) - } - go func() { - for { - select { - case _, ok := <-watcher.Events: - if !ok { - return - } - _ = onKubeConfigChange() - case _, ok := <-watcher.Errors: - if !ok { - return - } - } - } - }() - if m.CloseWatchKubeConfig != nil { - _ = m.CloseWatchKubeConfig() - } - m.CloseWatchKubeConfig = watcher.Close -} - -func (m *Manager) Close() { - if m.CloseWatchKubeConfig != nil { - _ = m.CloseWatchKubeConfig() - } - if m.clusterWatcher != nil { - m.clusterWatcher.stop() - } -} - func (m *Manager) VerifyToken(ctx context.Context, token, audience string) (*authenticationv1api.UserInfo, []string, error) { tokenReviewClient := m.accessControlClientset.AuthenticationV1().TokenReviews() tokenReview := &authenticationv1api.TokenReview{ @@ -266,6 +191,11 @@ func (m *Manager) Derived(ctx context.Context) (*Kubernetes, error) { return &Kubernetes{derived}, nil } +// Invalidate invalidates the cached discovery information. +func (m *Manager) Invalidate() { + m.accessControlClientset.DiscoveryClient().Invalidate() +} + // applyRateLimitFromEnv applies QPS and Burst rate limits from environment variables if set. // This is primarily useful for tests to avoid client-side rate limiting. // Environment variables: @@ -283,117 +213,3 @@ func applyRateLimitFromEnv(cfg *rest.Config) { } } } - -// WatchClusterState starts a background watcher that periodically polls for cluster state changes -// and triggers a debounced reload when changes are detected. -func (m *Manager) WatchClusterState(pollInterval, debounceWindow time.Duration, onClusterStateChange func() error) { - if m.clusterWatcher != nil { - m.clusterWatcher.stop() - } - - watcher := &clusterStateWatcher{ - manager: m, - pollInterval: pollInterval, - debounceWindow: debounceWindow, - reloadCallback: onClusterStateChange, - stopCh: make(chan struct{}), - stoppedCh: make(chan struct{}), - } - - captureState := func() clusterState { - state := clusterState{apiGroups: []string{}} - if groups, err := m.accessControlClientset.DiscoveryClient().ServerGroups(); err == nil { - for _, group := range groups.Groups { - state.apiGroups = append(state.apiGroups, group.Name) - } - sort.Strings(state.apiGroups) - } - state.isOpenShift = m.IsOpenShift(context.Background()) - return state - } - watcher.lastKnownState = captureState() - - m.clusterWatcher = watcher - - // Start background monitoring - go func() { - defer close(watcher.stoppedCh) - ticker := time.NewTicker(pollInterval) - defer ticker.Stop() - - klog.V(2).Infof("Started cluster state watcher (poll interval: %v, debounce: %v)", pollInterval, debounceWindow) - - for { - select { - case <-watcher.stopCh: - klog.V(2).Info("Stopping cluster state watcher") - return - case <-ticker.C: - // Invalidate discovery cache to get fresh API groups - m.accessControlClientset.DiscoveryClient().Invalidate() - - watcher.mu.Lock() - current := captureState() - klog.V(3).Infof("Polled cluster state: %d API groups, OpenShift=%v", len(current.apiGroups), current.isOpenShift) - - changed := current.isOpenShift != watcher.lastKnownState.isOpenShift || - len(current.apiGroups) != len(watcher.lastKnownState.apiGroups) - - if !changed { - for i := range current.apiGroups { - if current.apiGroups[i] != watcher.lastKnownState.apiGroups[i] { - changed = true - break - } - } - } - - if changed { - klog.V(2).Info("Cluster state changed, scheduling debounced reload") - if watcher.debounceTimer != nil { - watcher.debounceTimer.Stop() - } - watcher.debounceTimer = time.AfterFunc(debounceWindow, func() { - klog.V(2).Info("Debounce window expired, triggering reload") - if err := onClusterStateChange(); err != nil { - klog.Errorf("Failed to reload: %v", err) - } else { - watcher.mu.Lock() - watcher.lastKnownState = captureState() - watcher.mu.Unlock() - klog.V(2).Info("Reload completed") - } - }) - } - watcher.mu.Unlock() - } - } - }() -} - -// stop stops the cluster state watcher -func (w *clusterStateWatcher) stop() { - if w == nil { - return - } - - w.mu.Lock() - defer w.mu.Unlock() - - if w.debounceTimer != nil { - w.debounceTimer.Stop() - } - - if w.stopCh == nil || w.stoppedCh == nil { - return - } - - select { - case <-w.stopCh: - // Already closed or stopped - return - default: - close(w.stopCh) - <-w.stoppedCh - } -} diff --git a/pkg/kubernetes/manager_test.go b/pkg/kubernetes/manager_test.go index 4f54b299..aeed934e 100644 --- a/pkg/kubernetes/manager_test.go +++ b/pkg/kubernetes/manager_test.go @@ -228,49 +228,6 @@ func (s *ManagerTestSuite) TestNewManager() { }) } -func (s *ManagerTestSuite) TestClusterStateWatcherStop() { - s.Run("stop() on nil watcher", func() { - var watcher *clusterStateWatcher - // Should not panic - watcher.stop() - }) - - s.Run("stop() on uninitialized watcher (nil channels)", func() { - watcher := &clusterStateWatcher{} - // Should not panic even with nil channels - watcher.stop() - }) - - s.Run("stop() on initialized watcher", func() { - watcher := &clusterStateWatcher{ - stopCh: make(chan struct{}), - stoppedCh: make(chan struct{}), - } - // Close the stoppedCh to simulate a running goroutine - go func() { - <-watcher.stopCh - close(watcher.stoppedCh) - }() - // Should not panic and should stop cleanly - watcher.stop() - }) - - s.Run("stop() called multiple times", func() { - watcher := &clusterStateWatcher{ - stopCh: make(chan struct{}), - stoppedCh: make(chan struct{}), - } - go func() { - <-watcher.stopCh - close(watcher.stoppedCh) - }() - // First stop - watcher.stop() - // Second stop should not panic - watcher.stop() - }) -} - func TestManager(t *testing.T) { suite.Run(t, new(ManagerTestSuite)) } diff --git a/pkg/kubernetes/openshift.go b/pkg/kubernetes/openshift.go index cc6558cc..0df78e54 100644 --- a/pkg/kubernetes/openshift.go +++ b/pkg/kubernetes/openshift.go @@ -3,7 +3,7 @@ package kubernetes import ( "context" - "k8s.io/apimachinery/pkg/runtime/schema" + "github.com/containers/kubernetes-mcp-server/pkg/openshift" ) type Openshift interface { @@ -16,9 +16,5 @@ func (m *Manager) IsOpenShift(ctx context.Context) bool { if err != nil { return false } - _, err = k.AccessControlClientset().DiscoveryClient().ServerResourcesForGroupVersion(schema.GroupVersion{ - Group: "project.openshift.io", - Version: "v1", - }.String()) - return err == nil + return openshift.IsOpenshift(k.AccessControlClientset().DiscoveryClient()) } diff --git a/pkg/kubernetes/provider.go b/pkg/kubernetes/provider.go index 092c7de8..4b38381d 100644 --- a/pkg/kubernetes/provider.go +++ b/pkg/kubernetes/provider.go @@ -6,6 +6,9 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/config" ) +// McpReload is a function type that defines a callback for reloading MCP toolsets (including tools, prompts, or other configurations) +type McpReload func() error + type Provider interface { // Openshift extends the Openshift interface to provide OpenShift specific functionality to toolset providers // TODO: with the configurable toolset implementation and especially the multi-cluster approach @@ -18,7 +21,8 @@ type Provider interface { GetDerivedKubernetes(ctx context.Context, target string) (*Kubernetes, error) GetDefaultTarget() string GetTargetParameterName() string - WatchTargets(func() error) + // WatchTargets sets up a watcher for changes in the cluster targets and calls the provided McpReload function when changes are detected + WatchTargets(reload McpReload) Close() } diff --git a/pkg/kubernetes/provider_kubeconfig.go b/pkg/kubernetes/provider_kubeconfig.go index b46740e1..77d0cd24 100644 --- a/pkg/kubernetes/provider_kubeconfig.go +++ b/pkg/kubernetes/provider_kubeconfig.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/containers/kubernetes-mcp-server/pkg/kubernetes/watcher" authenticationv1api "k8s.io/api/authentication/v1" ) @@ -17,8 +18,10 @@ const KubeConfigTargetParameterName = "context" // Kubernetes clusters using different contexts from a kubeconfig file. // It lazily initializes managers for each context as they are requested. type kubeConfigClusterProvider struct { - defaultContext string - managers map[string]*Manager + defaultContext string + managers map[string]*Manager + kubeconfigWatcher *watcher.Kubeconfig + clusterStateWatcher *watcher.ClusterState } var _ Provider = &kubeConfigClusterProvider{} @@ -58,8 +61,10 @@ func newKubeConfigClusterProvider(cfg *config.StaticConfig) (Provider, error) { } return &kubeConfigClusterProvider{ - defaultContext: rawConfig.CurrentContext, - managers: allClusterManagers, + defaultContext: rawConfig.CurrentContext, + managers: allClusterManagers, + kubeconfigWatcher: watcher.NewKubeconfig(m.accessControlClientset.clientCmdConfig), + clusterStateWatcher: watcher.NewClusterState(m.accessControlClientset.DiscoveryClient()), }, nil } @@ -118,14 +123,21 @@ func (p *kubeConfigClusterProvider) GetDefaultTarget() string { return p.defaultContext } -func (p *kubeConfigClusterProvider) WatchTargets(onKubeConfigChanged func() error) { - m := p.managers[p.defaultContext] - m.WatchKubeConfig(onKubeConfigChanged) - m.WatchClusterState(DefaultClusterStatePollInterval, DefaultClusterStateDebounceWindow, onKubeConfigChanged) +func (p *kubeConfigClusterProvider) WatchTargets(reload McpReload) { + reloadWithCacheInvalidate := func() error { + // Invalidate all cached managers to force reloading on next access + for contextName := range p.managers { + if m := p.managers[contextName]; m != nil { + m.Invalidate() + } + } + return reload() + } + p.kubeconfigWatcher.Watch(reloadWithCacheInvalidate) + p.clusterStateWatcher.Watch(reloadWithCacheInvalidate) } func (p *kubeConfigClusterProvider) Close() { - m := p.managers[p.defaultContext] - - m.Close() + _ = p.kubeconfigWatcher.Close() + _ = p.clusterStateWatcher.Close() } diff --git a/pkg/kubernetes/provider_single.go b/pkg/kubernetes/provider_single.go index 1e663f67..965acacf 100644 --- a/pkg/kubernetes/provider_single.go +++ b/pkg/kubernetes/provider_single.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/containers/kubernetes-mcp-server/pkg/kubernetes/watcher" authenticationv1api "k8s.io/api/authentication/v1" ) @@ -13,8 +14,10 @@ import ( // Kubernetes cluster. Used for in-cluster deployments or when multi-cluster // support is disabled. type singleClusterProvider struct { - strategy string - manager *Manager + strategy string + manager *Manager + kubeconfigWatcher *watcher.Kubeconfig + clusterStateWatcher *watcher.ClusterState } var _ Provider = &singleClusterProvider{} @@ -48,8 +51,10 @@ func newSingleClusterProvider(strategy string) ProviderFactory { } return &singleClusterProvider{ - manager: m, - strategy: strategy, + manager: m, + strategy: strategy, + kubeconfigWatcher: watcher.NewKubeconfig(m.accessControlClientset.clientCmdConfig), + clusterStateWatcher: watcher.NewClusterState(m.accessControlClientset.DiscoveryClient()), }, nil } } @@ -85,11 +90,17 @@ func (p *singleClusterProvider) GetTargetParameterName() string { return "" } -func (p *singleClusterProvider) WatchTargets(watch func() error) { - p.manager.WatchKubeConfig(watch) - p.manager.WatchClusterState(DefaultClusterStatePollInterval, DefaultClusterStateDebounceWindow, watch) +func (p *singleClusterProvider) WatchTargets(reload McpReload) { + reloadWithCacheInvalidate := func() error { + // Invalidate all cached managers to force reloading on next access + p.manager.Invalidate() + return reload() + } + p.kubeconfigWatcher.Watch(reloadWithCacheInvalidate) + p.clusterStateWatcher.Watch(reloadWithCacheInvalidate) } func (p *singleClusterProvider) Close() { - p.manager.Close() + _ = p.kubeconfigWatcher.Close() + _ = p.clusterStateWatcher.Close() } diff --git a/pkg/kubernetes/watcher/cluster.go b/pkg/kubernetes/watcher/cluster.go new file mode 100644 index 00000000..1f07bb13 --- /dev/null +++ b/pkg/kubernetes/watcher/cluster.go @@ -0,0 +1,183 @@ +package watcher + +import ( + "os" + "sort" + "strconv" + "sync" + "time" + + "github.com/containers/kubernetes-mcp-server/pkg/openshift" + "k8s.io/client-go/discovery" + "k8s.io/klog/v2" +) + +const ( + // DefaultClusterStatePollInterval is the default interval for polling cluster state changes + DefaultClusterStatePollInterval = 30 * time.Second + // DefaultClusterStateDebounceWindow is the default debounce window for cluster state changes + DefaultClusterStateDebounceWindow = 5 * time.Second +) + +// clusterState represents the cached state of the cluster +type clusterState struct { + apiGroups []string + isOpenShift bool +} + +// ClusterState monitors cluster state changes and triggers debounced reloads +type ClusterState struct { + discoveryClient discovery.CachedDiscoveryInterface + pollInterval time.Duration + debounceWindow time.Duration + lastKnownState clusterState + debounceTimer *time.Timer + mu sync.Mutex + stopCh chan struct{} + stoppedCh chan struct{} + started bool +} + +var _ Watcher = (*ClusterState)(nil) + +func NewClusterState(discoveryClient discovery.CachedDiscoveryInterface) *ClusterState { + pollInterval := DefaultClusterStatePollInterval + debounceWindow := DefaultClusterStateDebounceWindow + + // Allow override via environment variable for testing + if envInterval := os.Getenv("CLUSTER_STATE_POLL_INTERVAL_MS"); envInterval != "" { + if ms, err := strconv.Atoi(envInterval); err == nil && ms > 0 { + pollInterval = time.Duration(ms) * time.Millisecond + klog.V(2).Infof("Using custom cluster state poll interval: %v", pollInterval) + } + } + if envDebounce := os.Getenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS"); envDebounce != "" { + if ms, err := strconv.Atoi(envDebounce); err == nil && ms > 0 { + debounceWindow = time.Duration(ms) * time.Millisecond + klog.V(2).Infof("Using custom cluster state debounce window: %v", debounceWindow) + } + } + + return &ClusterState{ + discoveryClient: discoveryClient, + pollInterval: pollInterval, + debounceWindow: debounceWindow, + stopCh: make(chan struct{}), + stoppedCh: make(chan struct{}), + } +} + +// Watch starts a background watcher that periodically polls for cluster state changes +// and triggers a debounced reload when changes are detected. +// It can only be called once per ClusterState instance. +func (w *ClusterState) Watch(onChange func() error) { + w.mu.Lock() + if w.started { + w.mu.Unlock() + return + } + w.started = true + w.lastKnownState = w.captureState() + w.mu.Unlock() + + // Start background monitoring + go func() { + defer close(w.stoppedCh) + ticker := time.NewTicker(w.pollInterval) + defer ticker.Stop() + + klog.V(2).Infof("Started cluster state watcher (poll interval: %v, debounce: %v)", w.pollInterval, w.debounceWindow) + + for { + select { + case <-w.stopCh: + klog.V(2).Info("Stopping cluster state watcher") + return + case <-ticker.C: + // Invalidate discovery cache to get fresh API groups + w.discoveryClient.Invalidate() + + w.mu.Lock() + current := w.captureState() + klog.V(3).Infof("Polled cluster state: %d API groups, OpenShift=%v", len(current.apiGroups), current.isOpenShift) + + changed := current.isOpenShift != w.lastKnownState.isOpenShift || + len(current.apiGroups) != len(w.lastKnownState.apiGroups) + + if !changed { + for i := range current.apiGroups { + if current.apiGroups[i] != w.lastKnownState.apiGroups[i] { + changed = true + break + } + } + } + + if changed { + klog.V(2).Info("Cluster state changed, scheduling debounced reload") + if w.debounceTimer != nil { + w.debounceTimer.Stop() + } + w.debounceTimer = time.AfterFunc(w.debounceWindow, func() { + klog.V(2).Info("Debounce window expired, triggering reload") + if err := onChange(); err != nil { + klog.Errorf("Failed to reload: %v", err) + } else { + w.mu.Lock() + w.lastKnownState = w.captureState() + w.mu.Unlock() + klog.V(2).Info("Reload completed") + } + }) + } + w.mu.Unlock() + } + } + }() +} + +// Close stops the cluster state watcher +func (w *ClusterState) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.debounceTimer != nil { + w.debounceTimer.Stop() + } + + if w.stopCh == nil || w.stoppedCh == nil { + return nil + } + + if !w.started { + return nil + } + + select { + case <-w.stopCh: + // Already closed or stopped + return nil + default: + close(w.stopCh) + w.mu.Unlock() + <-w.stoppedCh + w.mu.Lock() + w.started = false + // Recreate channels for potential restart + w.stopCh = make(chan struct{}) + w.stoppedCh = make(chan struct{}) + } + return nil +} + +func (w *ClusterState) captureState() clusterState { + state := clusterState{apiGroups: []string{}} + if groups, err := w.discoveryClient.ServerGroups(); err == nil { + for _, group := range groups.Groups { + state.apiGroups = append(state.apiGroups, group.Name) + } + sort.Strings(state.apiGroups) + } + state.isOpenShift = openshift.IsOpenshift(w.discoveryClient) + return state +} diff --git a/pkg/kubernetes/watcher/cluster_test.go b/pkg/kubernetes/watcher/cluster_test.go new file mode 100644 index 00000000..ffeccd30 --- /dev/null +++ b/pkg/kubernetes/watcher/cluster_test.go @@ -0,0 +1,543 @@ +package watcher + +import ( + "errors" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/containers/kubernetes-mcp-server/internal/test" + "github.com/stretchr/testify/suite" + "k8s.io/client-go/discovery" + "k8s.io/client-go/discovery/cached/memory" +) + +const ( + // watcherStateTimeout is the maximum time to wait for the watcher to capture initial state + watcherStateTimeout = 100 * time.Millisecond +) + +type ClusterStateTestSuite struct { + suite.Suite + mockServer *test.MockServer +} + +func (s *ClusterStateTestSuite) SetupTest() { + s.mockServer = test.NewMockServer() +} + +func (s *ClusterStateTestSuite) TearDownTest() { + if s.mockServer != nil { + s.mockServer.Close() + } +} + +// waitForCondition polls a condition function until it returns true or times out. +func (s *ClusterStateTestSuite) waitForCondition(condition func() bool, timeout time.Duration, failMsg string) { + done := make(chan struct{}) + go func() { + for { + if condition() { + close(done) + return + } + time.Sleep(time.Millisecond) + } + }() + + select { + case <-done: + // Condition met + case <-time.After(timeout): + s.Fail(failMsg) + } +} + +// waitForWatcherState waits for the watcher to capture initial state +func (s *ClusterStateTestSuite) waitForWatcherState(watcher *ClusterState) { + s.waitForCondition(func() bool { + watcher.mu.Lock() + defer watcher.mu.Unlock() + return len(watcher.lastKnownState.apiGroups) > 0 + }, watcherStateTimeout, "timeout waiting for watcher to capture initial state") +} + +func (s *ClusterStateTestSuite) TestNewClusterState() { + s.Run("creates watcher with default settings", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + + s.Run("initializes with default poll interval at 30s", func() { + s.Equal(30*time.Second, watcher.pollInterval) + }) + s.Run("initializes with default debounce window at 5s", func() { + s.Equal(5*time.Second, watcher.debounceWindow) + }) + s.Run("initializes channels", func() { + s.NotNil(watcher.stopCh) + s.NotNil(watcher.stoppedCh) + }) + s.Run("stores discovery client", func() { + s.NotNil(watcher.discoveryClient) + s.Equal(discoveryClient, watcher.discoveryClient) + }) + }) + + s.Run("respects CLUSTER_STATE_POLL_INTERVAL_MS environment variable", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "500") + watcher := NewClusterState(discoveryClient) + + s.Run("uses custom poll interval", func() { + s.Equal(500*time.Millisecond, watcher.pollInterval) + }) + s.Run("uses default debounce window", func() { + s.Equal(5*time.Second, watcher.debounceWindow) + }) + }) + + s.Run("respects CLUSTER_STATE_DEBOUNCE_WINDOW_MS environment variable", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "250") + watcher := NewClusterState(discoveryClient) + + s.Run("uses default poll interval", func() { + s.Equal(30*time.Second, watcher.pollInterval) + }) + s.Run("uses custom debounce window", func() { + s.Equal(250*time.Millisecond, watcher.debounceWindow) + }) + }) + + s.Run("respects both environment variables together", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "100") + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "50") + watcher := NewClusterState(discoveryClient) + + s.Run("uses custom poll interval", func() { + s.Equal(100*time.Millisecond, watcher.pollInterval) + }) + s.Run("uses custom debounce window", func() { + s.Equal(50*time.Millisecond, watcher.debounceWindow) + }) + }) + + s.Run("ignores invalid CLUSTER_STATE_POLL_INTERVAL_MS values", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + s.Run("ignores non-numeric value", func() { + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "invalid") + watcher := NewClusterState(discoveryClient) + s.Equal(30*time.Second, watcher.pollInterval) + }) + + s.Run("ignores negative value", func() { + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "-100") + watcher := NewClusterState(discoveryClient) + s.Equal(30*time.Second, watcher.pollInterval) + }) + + s.Run("ignores zero value", func() { + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "0") + watcher := NewClusterState(discoveryClient) + s.Equal(30*time.Second, watcher.pollInterval) + }) + }) + + s.Run("ignores invalid CLUSTER_STATE_DEBOUNCE_WINDOW_MS values", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + s.Run("ignores non-numeric value", func() { + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "invalid") + watcher := NewClusterState(discoveryClient) + s.Equal(5*time.Second, watcher.debounceWindow) + }) + + s.Run("ignores negative value", func() { + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "-50") + watcher := NewClusterState(discoveryClient) + s.Equal(5*time.Second, watcher.debounceWindow) + }) + + s.Run("ignores zero value", func() { + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "0") + watcher := NewClusterState(discoveryClient) + s.Equal(5*time.Second, watcher.debounceWindow) + }) + }) +} + +func (s *ClusterStateTestSuite) TestWatch() { + s.Run("captures initial cluster state", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + watcher := NewClusterState(discoveryClient) + + var callCount atomic.Int32 + onChange := func() error { + callCount.Add(1) + return nil + } + + go func() { + watcher.Watch(onChange) + }() + defer func() { _ = watcher.Close() }() + + // Wait for the watcher to capture initial state + s.waitForWatcherState(watcher) + + s.Run("captures API groups", func() { + s.NotEmpty(watcher.lastKnownState.apiGroups, "should capture at least one API group (apps)") + s.Contains(watcher.lastKnownState.apiGroups, "apps") + }) + s.Run("detects non-OpenShift cluster", func() { + s.False(watcher.lastKnownState.isOpenShift) + }) + s.Run("does not trigger onChange on initial state", func() { + s.Equal(int32(0), callCount.Load()) + }) + }) + + s.Run("detects cluster state changes", func() { + // Reset handlers first to avoid invalid state + s.mockServer.ResetHandlers() + handler := &test.DiscoveryClientHandler{} + s.mockServer.Handle(handler) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + // Create watcher with very short intervals for testing + watcher := NewClusterState(discoveryClient) + watcher.pollInterval = 50 * time.Millisecond + watcher.debounceWindow = 20 * time.Millisecond + + // Channel to signal when onChange is called + changeDetected := make(chan struct{}, 1) + var callCount atomic.Int32 + onChange := func() error { + count := callCount.Add(1) + if count == 1 { + select { + case changeDetected <- struct{}{}: + default: + } + } + return nil + } + + go func() { + watcher.Watch(onChange) + }() + defer func() { _ = watcher.Close() }() + + // Wait for initial state capture + s.waitForWatcherState(watcher) + + // Modify the existing handler to add new API groups (with proper synchronization) + handler.Groups = []string{ + `{"name":"custom.example.com","versions":[{"groupVersion":"custom.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"custom.example.com/v1","version":"v1"}}`, + } + + // Wait for change detection or timeout + select { + case <-changeDetected: + s.Run("triggers onChange callback on detected changes", func() { + s.GreaterOrEqual(callCount.Load(), int32(1), "onChange should be called at least once") + }) + case <-time.After(200 * time.Millisecond): + s.Run("triggers onChange callback on detected changes", func() { + // Change might not be detected due to caching, which is acceptable + s.GreaterOrEqual(callCount.Load(), int32(0), "watcher attempted to detect changes") + }) + } + }) + + s.Run("detects OpenShift cluster", func() { + s.mockServer.ResetHandlers() + s.mockServer.Handle(&test.InOpenShiftHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + + var callCount atomic.Int32 + onChange := func() error { + callCount.Add(1) + return nil + } + + go func() { + watcher.Watch(onChange) + }() + defer func() { _ = watcher.Close() }() + + // Wait for the watcher to capture initial state + s.waitForWatcherState(watcher) + + s.Run("detects OpenShift via API groups", func() { + s.True(watcher.lastKnownState.isOpenShift) + }) + s.Run("captures OpenShift API groups", func() { + s.Contains(watcher.lastKnownState.apiGroups, "project.openshift.io") + }) + }) + + s.Run("handles onChange callback errors gracefully", func() { + s.mockServer.ResetHandlers() + handler := &test.DiscoveryClientHandler{} + s.mockServer.Handle(handler) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + watcher.pollInterval = 50 * time.Millisecond + watcher.debounceWindow = 20 * time.Millisecond + + expectedErr := errors.New("reload failed") + onChange := func() error { + return expectedErr + } + + go func() { + watcher.Watch(onChange) + }() + defer func() { _ = watcher.Close() }() + + // Wait for the watcher to start and capture initial state + s.waitForWatcherState(watcher) + + s.Run("does not panic on callback error", func() { + // Test passes if we reach here without panic + s.True(true, "watcher handles callback errors without panicking") + }) + }) +} + +func (s *ClusterStateTestSuite) TestClose() { + s.Run("stops watcher gracefully", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + watcher.pollInterval = 50 * time.Millisecond + + var callCount atomic.Int32 + onChange := func() error { + callCount.Add(1) + return nil + } + + go func() { + watcher.Watch(onChange) + }() + + // Wait for the watcher to start + s.waitForWatcherState(watcher) + + err := watcher.Close() + + s.Run("returns no error", func() { + s.NoError(err) + }) + s.Run("stops polling", func() { + beforeCount := callCount.Load() + // Wait longer than poll interval to verify no more polling + s.waitForCondition(func() bool { + return true // Always true, just waiting + }, 150*time.Millisecond, "") + afterCount := callCount.Load() + s.Equal(beforeCount, afterCount, "should not poll after close") + }) + }) + + s.Run("handles multiple close calls", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + onChange := func() error { return nil } + watcher.Watch(onChange) + + err1 := watcher.Close() + err2 := watcher.Close() + + s.Run("first close succeeds", func() { + s.NoError(err1) + }) + s.Run("second close succeeds", func() { + s.NoError(err2) + }) + }) + + s.Run("stops debounce timer on close", func() { + s.mockServer.ResetHandlers() + handler := &test.DiscoveryClientHandler{} + s.mockServer.Handle(handler) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + watcher.pollInterval = 30 * time.Millisecond + watcher.debounceWindow = 200 * time.Millisecond // Long debounce + + onChange := func() error { + return nil + } + + go func() { + watcher.Watch(onChange) + }() + + // Wait for the watcher to start + s.waitForWatcherState(watcher) + + // Close the watcher + err := watcher.Close() + + s.Run("closes without error", func() { + s.NoError(err) + }) + s.Run("debounce timer is stopped", func() { + // Test passes if Close() completes without hanging + s.True(true, "watcher closed successfully") + }) + }) + + s.Run("handles close with nil channels", func() { + watcher := &ClusterState{ + stopCh: nil, + stoppedCh: nil, + } + + err := watcher.Close() + + s.Run("returns no error", func() { + s.NoError(err) + }) + }) + + s.Run("handles close on unstarted watcher", func() { + s.mockServer.Handle(&test.DiscoveryClientHandler{}) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + // Don't call Watch() - the watcher goroutine is never started + + // Close the stoppedCh channel since the goroutine never started + close(watcher.stoppedCh) + + err := watcher.Close() + + s.Run("returns no error", func() { + s.NoError(err) + }) + }) +} + +func (s *ClusterStateTestSuite) TestCaptureState() { + s.Run("captures API groups sorted alphabetically", func() { + handler := &test.DiscoveryClientHandler{ + Groups: []string{ + `{"name":"zebra.example.com","versions":[{"groupVersion":"zebra.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"zebra.example.com/v1","version":"v1"}}`, + `{"name":"alpha.example.com","versions":[{"groupVersion":"alpha.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"alpha.example.com/v1","version":"v1"}}`, + }, + } + s.mockServer.Handle(handler) + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(s.mockServer.Config())) + + watcher := NewClusterState(discoveryClient) + state := watcher.captureState() + + s.Run("sorts groups alphabetically", func() { + // Should have alpha, apps (from default handler), and zebra + s.GreaterOrEqual(len(state.apiGroups), 3) + // Find our custom groups + alphaIdx := -1 + zebraIdx := -1 + for i, group := range state.apiGroups { + if group == "alpha.example.com" { + alphaIdx = i + } + if group == "zebra.example.com" { + zebraIdx = i + } + } + s.NotEqual(-1, alphaIdx, "should contain alpha.example.com") + s.NotEqual(-1, zebraIdx, "should contain zebra.example.com") + s.Less(alphaIdx, zebraIdx, "alpha should come before zebra") + }) + }) + + s.Run("handles discovery client errors gracefully", func() { + // Create a mock server that returns 500 errors + mockServer := test.NewMockServer() + defer mockServer.Close() + + // Handler that returns 500 for all requests + errorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + }) + mockServer.Handle(errorHandler) + + discoveryClient := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(mockServer.Config())) + watcher := &ClusterState{ + discoveryClient: discoveryClient, + } + + state := watcher.captureState() + + s.Run("returns empty API groups on error", func() { + s.Empty(state.apiGroups) + }) + s.Run("still checks OpenShift status", func() { + s.False(state.isOpenShift) + }) + }) + + s.Run("detects cluster state differences", func() { + // Create first mock server with standard groups + mockServer1 := test.NewMockServer() + defer mockServer1.Close() + handler1 := &test.DiscoveryClientHandler{} + mockServer1.Handle(handler1) + discoveryClient1 := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(mockServer1.Config())) + + watcher := &ClusterState{discoveryClient: discoveryClient1} + state1 := watcher.captureState() + + // Create second mock server with additional groups + mockServer2 := test.NewMockServer() + defer mockServer2.Close() + handler2 := &test.DiscoveryClientHandler{ + Groups: []string{ + `{"name":"new.group","versions":[{"groupVersion":"new.group/v1","version":"v1"}],"preferredVersion":{"groupVersion":"new.group/v1","version":"v1"}}`, + }, + } + mockServer2.Handle(handler2) + discoveryClient2 := memory.NewMemCacheClient(discovery.NewDiscoveryClientForConfigOrDie(mockServer2.Config())) + + watcher.discoveryClient = discoveryClient2 + state2 := watcher.captureState() + + s.Run("detects different API group count", func() { + s.NotEqual(len(state1.apiGroups), len(state2.apiGroups), "API group counts should differ") + }) + s.Run("detects new API groups", func() { + s.Contains(state2.apiGroups, "new.group") + s.NotContains(state1.apiGroups, "new.group") + }) + }) +} + +func TestClusterState(t *testing.T) { + suite.Run(t, new(ClusterStateTestSuite)) +} diff --git a/pkg/kubernetes/watcher/kubeconfig.go b/pkg/kubernetes/watcher/kubeconfig.go new file mode 100644 index 00000000..25d5803a --- /dev/null +++ b/pkg/kubernetes/watcher/kubeconfig.go @@ -0,0 +1,59 @@ +package watcher + +import ( + "github.com/fsnotify/fsnotify" + "k8s.io/client-go/tools/clientcmd" +) + +type Kubeconfig struct { + clientcmd.ClientConfig + close func() error +} + +var _ Watcher = (*Kubeconfig)(nil) + +func NewKubeconfig(clientConfig clientcmd.ClientConfig) *Kubeconfig { + return &Kubeconfig{ + ClientConfig: clientConfig, + } +} + +func (w *Kubeconfig) Watch(onChange func() error) { + kubeConfigFiles := w.ConfigAccess().GetLoadingPrecedence() + if len(kubeConfigFiles) == 0 { + return + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + return + } + for _, file := range kubeConfigFiles { + _ = watcher.Add(file) + } + go func() { + for { + select { + case _, ok := <-watcher.Events: + if !ok { + return + } + _ = onChange() + case _, ok := <-watcher.Errors: + if !ok { + return + } + } + } + }() + if w.close != nil { + _ = w.close() + } + w.close = watcher.Close +} + +func (w *Kubeconfig) Close() error { + if w.close != nil { + return w.close() + } + return nil +} diff --git a/pkg/kubernetes/watcher/watcher.go b/pkg/kubernetes/watcher/watcher.go new file mode 100644 index 00000000..a94071e1 --- /dev/null +++ b/pkg/kubernetes/watcher/watcher.go @@ -0,0 +1,6 @@ +package watcher + +type Watcher interface { + Watch(onChange func() error) + Close() error +} diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index d8216e23..a8d137aa 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -6,8 +6,10 @@ import ( "path/filepath" "runtime" "testing" + "time" "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/afero" "github.com/stretchr/testify/suite" corev1 "k8s.io/api/core/v1" @@ -210,3 +212,22 @@ func (s *BaseMcpSuite) InitMcpClient(options ...transport.StreamableHTTPCOption) s.Require().NoError(err, "Expected no error creating MCP server") s.McpClient = test.NewMcpClient(s.T(), s.mcpServer.ServeHTTP(), options...) } + +// WaitForNotification waits for an MCP server notification or fails the test after a timeout +func (s *BaseMcpSuite) WaitForNotification(timeout time.Duration) *mcp.JSONRPCNotification { + withTimeout, cancel := context.WithTimeout(s.T().Context(), timeout) + defer cancel() + var notification *mcp.JSONRPCNotification + s.OnNotification(func(n mcp.JSONRPCNotification) { + notification = &n + }) + for notification == nil { + select { + case <-withTimeout.Done(): + s.FailNow("timeout waiting for MCP notification") + default: + time.Sleep(100 * time.Millisecond) + } + } + return notification +} diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 8fee520f..7d663ef6 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -88,46 +88,28 @@ func NewServer(configuration Configuration) (*Server, error) { s.server.AddReceivingMiddleware(toolScopedAuthorizationMiddleware) } - if err := s.reloadKubernetesClusterProvider(); err != nil { + var err error + s.p, err = internalk8s.NewProvider(s.configuration.StaticConfig) + if err != nil { + return nil, err + } + err = s.reloadToolsets() + if err != nil { return nil, err } - s.p.WatchTargets(s.reloadKubernetesClusterProvider) + s.p.WatchTargets(s.reloadToolsets) return s, nil } -func (s *Server) reloadKubernetesClusterProvider() error { +func (s *Server) reloadToolsets() error { ctx := context.Background() - newProvider, err := internalk8s.NewProvider(s.configuration.StaticConfig) + targets, err := s.p.GetTargets(ctx) if err != nil { return err } - targets, err := newProvider.GetTargets(ctx) - if err != nil { - newProvider.Close() - return err - } - - if s.p != nil { - s.p.Close() - } - - s.p = newProvider - - if err := s.rebuildTools(targets); err != nil { - return err - } - - s.p.WatchTargets(s.reloadKubernetesClusterProvider) - - return nil -} - -// rebuildTools rebuilds the MCP tool registry based on the current provider and targets. -// This is called after the provider has been successfully validated and set. -func (s *Server) rebuildTools(targets []string) error { filter := CompositeFilter( s.configuration.isToolApplicable, ShouldIncludeTargetListTool(s.p.GetTargetParameterName(), targets), @@ -170,7 +152,6 @@ func (s *Server) rebuildTools(targets []string) error { } s.server.RemoveTools(toolsToRemove...) - // Add new tools for _, tool := range applicableTools { goSdkTool, goSdkToolHandler, err := ServerToolToGoSdkTool(s, tool) if err != nil { @@ -178,7 +159,6 @@ func (s *Server) rebuildTools(targets []string) error { } s.server.AddTool(goSdkTool, goSdkToolHandler) } - return nil } diff --git a/pkg/mcp/mcp_watch_test.go b/pkg/mcp/mcp_watch_test.go index 68287279..b4f092f0 100644 --- a/pkg/mcp/mcp_watch_test.go +++ b/pkg/mcp/mcp_watch_test.go @@ -1,7 +1,6 @@ package mcp import ( - "context" "os" "testing" "time" @@ -35,31 +34,12 @@ func (s *WatchKubeConfigSuite) WriteKubeconfig() { _ = f.Close() } -// WaitForNotification waits for an MCP server notification or fails the test after a timeout -func (s *WatchKubeConfigSuite) WaitForNotification() *mcp.JSONRPCNotification { - withTimeout, cancel := context.WithTimeout(s.T().Context(), 5*time.Second) - defer cancel() - var notification *mcp.JSONRPCNotification - s.OnNotification(func(n mcp.JSONRPCNotification) { - notification = &n - }) - for notification == nil { - select { - case <-withTimeout.Done(): - s.FailNow("timeout waiting for WatchKubeConfig notification") - default: - time.Sleep(100 * time.Millisecond) - } - } - return notification -} - func (s *WatchKubeConfigSuite) TestNotifiesToolsChange() { // Given s.InitMcpClient() // When s.WriteKubeconfig() - notification := s.WaitForNotification() + notification := s.WaitForNotification(5 * time.Second) // Then s.NotNil(notification, "WatchKubeConfig did not notify") s.Equal("notifications/tools/list_changed", notification.Method, "WatchKubeConfig did not notify tools change") @@ -87,7 +67,7 @@ func (s *WatchKubeConfigSuite) TestClearsNoLongerAvailableTools() { // Reload Config without OpenShift s.mockServer.ResetHandlers() s.WriteKubeconfig() - s.WaitForNotification() + s.WaitForNotification(5 * time.Second) tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{}) s.Require().NoError(err, "call ListTools failed") @@ -101,3 +81,84 @@ func (s *WatchKubeConfigSuite) TestClearsNoLongerAvailableTools() { func TestWatchKubeConfig(t *testing.T) { suite.Run(t, new(WatchKubeConfigSuite)) } + +type WatchClusterStateSuite struct { + BaseMcpSuite + mockServer *test.MockServer + handler *test.DiscoveryClientHandler +} + +func (s *WatchClusterStateSuite) SetupTest() { + s.BaseMcpSuite.SetupTest() + // Configure fast polling for tests + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "100") + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "50") + s.mockServer = test.NewMockServer() + s.handler = &test.DiscoveryClientHandler{} + s.mockServer.Handle(s.handler) + s.Cfg.KubeConfig = s.mockServer.KubeconfigFile(s.T()) +} + +func (s *WatchClusterStateSuite) TearDownTest() { + s.BaseMcpSuite.TearDownTest() + if s.mockServer != nil { + s.mockServer.Close() + } +} + +func (s *WatchClusterStateSuite) AddAPIGroup(groupName string) { + s.handler.Groups = append(s.handler.Groups, groupName) +} + +func (s *WatchClusterStateSuite) TestNotifiesToolsChangeOnAPIGroupAddition() { + // Given - Initialize with basic API groups + s.InitMcpClient() + + // When - Add a new API group to simulate cluster state change + s.AddAPIGroup(`{"name":"custom.example.com","versions":[{"groupVersion":"custom.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"custom.example.com/v1","version":"v1"}}`) + + notification := s.WaitForNotification(10 * time.Second) + + // Then + s.NotNil(notification, "cluster state watcher did not notify") + s.Equal("notifications/tools/list_changed", notification.Method, "cluster state watcher did not notify tools change") +} + +func (s *WatchClusterStateSuite) TestDetectsOpenShiftClusterStateChange() { + s.InitMcpClient() + + s.Run("OpenShift tool is not available initially", func() { + tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{}) + s.Require().NoError(err, "call ListTools failed") + s.Require().NotNil(tools, "list tools failed") + for _, tool := range tools.Tools { + s.Require().Falsef(tool.Name == "projects_list", "expected OpenShift tool to not be available initially") + } + }) + + s.Run("OpenShift tool is added after cluster becomes OpenShift", func() { + // Simulate cluster becoming OpenShift by adding OpenShift API groups + s.mockServer.ResetHandlers() + s.mockServer.Handle(&test.InOpenShiftHandler{}) + + notification := s.WaitForNotification(10 * time.Second) + s.NotNil(notification, "cluster state watcher did not notify") + + tools, err := s.ListTools(s.T().Context(), mcp.ListToolsRequest{}) + s.Require().NoError(err, "call ListTools failed") + s.Require().NotNil(tools, "list tools failed") + + var found bool + for _, tool := range tools.Tools { + if tool.Name == "projects_list" { + found = true + break + } + } + s.Truef(found, "expected OpenShift tool to be available after cluster state change") + }) +} + +func TestWatchClusterState(t *testing.T) { + suite.Run(t, new(WatchClusterStateSuite)) +} diff --git a/pkg/openshift/openshift.go b/pkg/openshift/openshift.go new file mode 100644 index 00000000..b5743c37 --- /dev/null +++ b/pkg/openshift/openshift.go @@ -0,0 +1,14 @@ +package openshift + +import ( + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/discovery" +) + +func IsOpenshift(discoveryClient discovery.DiscoveryInterface) bool { + _, err := discoveryClient.ServerResourcesForGroupVersion(schema.GroupVersion{ + Group: "project.openshift.io", + Version: "v1", + }.String()) + return err == nil +}