Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/test/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
)

func McpInitRequest() mcp.InitializeRequest {
initRequest := mcp.InitializeRequest{}
initRequest := mcp.InitializeRequest{
Request: mcp.Request{Method: "initialize"},
}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.33.7"}
return initRequest
Expand Down
22 changes: 22 additions & 0 deletions internal/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package test
import (
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -49,3 +50,24 @@ func WaitForServer(tcpAddr *net.TCPAddr) error {
}
return err
}

// WaitForHealthz waits for the /healthz endpoint to return a non-404 response
func WaitForHealthz(tcpAddr *net.TCPAddr) error {
url := fmt.Sprintf("http://%s/healthz", tcpAddr.String())
var resp *http.Response
var err error
for i := 0; i < 100; i++ {
resp, err = http.Get(url)
if err == nil {
_ = resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
return nil
}
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
return err
}
return fmt.Errorf("healthz endpoint returned 404 after retries")
}
13 changes: 11 additions & 2 deletions pkg/http/http_authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (s *AuthorizationSuite) SetupTest() {
s.BaseHttpSuite.SetupTest()

// Capture logs
s.logBuffer.Reset()
s.klogState = klog.CaptureState()
flags := flag.NewFlagSet("test", flag.ContinueOnError)
klog.InitFlags(flags)
Expand Down Expand Up @@ -59,14 +60,14 @@ func (s *AuthorizationSuite) TearDownTest() {

func (s *AuthorizationSuite) StartClient(options ...transport.StreamableHTTPCOption) {
var err error
s.mcpClient, err = client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), options...)
s.mcpClient, err = client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%s/mcp", s.StaticConfig.Port), options...)
s.Require().NoError(err, "Expected no error creating Streamable HTTP MCP client")
err = s.mcpClient.Start(s.T().Context())
s.Require().NoError(err, "Expected no error starting Streamable HTTP MCP client")
}

func (s *AuthorizationSuite) HttpGet(authHeader string) *http.Response {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), nil)
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%s/mcp", s.StaticConfig.Port), nil)
s.Require().NoError(err, "Failed to create request")
if authHeader != "" {
req.Header.Set("Authorization", authHeader)
Expand Down Expand Up @@ -339,6 +340,7 @@ func (s *AuthorizationSuite) TestAuthorizationRawToken() {
for _, c := range cases {
s.StaticConfig.OAuthAudience = c.audience
s.StaticConfig.ValidateToken = c.validateToken
s.logBuffer.Reset()
s.StartServer()
s.StartClient(transport.WithHTTPHeaders(map[string]string{
"Authorization": "Bearer " + tokenBasicNotExpired,
Expand All @@ -362,7 +364,9 @@ func (s *AuthorizationSuite) TestAuthorizationRawToken() {
})
})
_ = s.mcpClient.Close()
s.mcpClient = nil
s.StopServer()
s.Require().NoError(s.WaitForShutdown())
}
}

Expand Down Expand Up @@ -407,7 +411,9 @@ func (s *AuthorizationSuite) TestAuthorizationOidcToken() {
})
})
_ = s.mcpClient.Close()
s.mcpClient = nil
s.StopServer()
s.Require().NoError(s.WaitForShutdown())
}
}

Expand Down Expand Up @@ -440,6 +446,7 @@ func (s *AuthorizationSuite) TestAuthorizationOidcTokenExchange() {
s.StaticConfig.StsClientSecret = "test-sts-client-secret"
s.StaticConfig.StsAudience = "backend-audience"
s.StaticConfig.StsScopes = []string{"backend-scope"}
s.logBuffer.Reset()
s.StartServer()
s.StartClient(transport.WithHTTPHeaders(map[string]string{
"Authorization": "Bearer " + validOidcClientToken,
Expand All @@ -463,7 +470,9 @@ func (s *AuthorizationSuite) TestAuthorizationOidcTokenExchange() {
})
})
_ = s.mcpClient.Close()
s.mcpClient = nil
s.StopServer()
s.Require().NoError(s.WaitForShutdown())
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/http/http_mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (s *McpTransportSuite) TearDownTest() {
}

func (s *McpTransportSuite) TestSseTransport() {
sseClient, sseClientErr := client.NewSSEMCPClient(fmt.Sprintf("http://127.0.0.1:%d/sse", s.TcpAddr.Port))
sseClient, sseClientErr := client.NewSSEMCPClient(fmt.Sprintf("http://127.0.0.1:%s/sse", s.StaticConfig.Port))
s.Require().NoError(sseClientErr, "Expected no error creating SSE MCP client")
startErr := sseClient.Start(s.T().Context())
s.Require().NoError(startErr, "Expected no error starting SSE MCP client")
Expand All @@ -44,7 +44,7 @@ func (s *McpTransportSuite) TestSseTransport() {
}

func (s *McpTransportSuite) TestStreamableHttpTransport() {
httpClient, httpClientErr := client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%d/mcp", s.TcpAddr.Port), transport.WithContinuousListening())
httpClient, httpClientErr := client.NewStreamableHttpClient(fmt.Sprintf("http://127.0.0.1:%s/mcp", s.StaticConfig.Port), transport.WithContinuousListening())
s.Require().NoError(httpClientErr, "Expected no error creating Streamable HTTP MCP client")
startErr := httpClient.Start(s.T().Context())
s.Require().NoError(startErr, "Expected no error starting Streamable HTTP MCP client")
Expand Down
15 changes: 8 additions & 7 deletions pkg/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
type BaseHttpSuite struct {
suite.Suite
MockServer *test.MockServer
TcpAddr *net.TCPAddr
StaticConfig *config.StaticConfig
mcpServer *mcp.Server
OidcProvider *oidc.Provider
Expand All @@ -43,18 +42,19 @@ type BaseHttpSuite struct {
}

func (s *BaseHttpSuite) SetupTest() {
var err error
http.DefaultClient.Timeout = 10 * time.Second
s.MockServer = test.NewMockServer()
s.TcpAddr, err = test.RandomPortAddress()
s.Require().NoError(err, "Expected no error getting random port address")
s.MockServer.Handle(&test.DiscoveryClientHandler{})
s.StaticConfig = config.Default()
s.StaticConfig.KubeConfig = s.MockServer.KubeconfigFile(s.T())
s.StaticConfig.Port = strconv.Itoa(s.TcpAddr.Port)
}

func (s *BaseHttpSuite) StartServer() {
var err error

tcpAddr, err := test.RandomPortAddress()
s.Require().NoError(err, "Expected no error getting random port address")
s.StaticConfig.Port = strconv.Itoa(tcpAddr.Port)

s.mcpServer, err = mcp.NewServer(mcp.Configuration{StaticConfig: s.StaticConfig})
s.Require().NoError(err, "Expected no error creating MCP server")
s.Require().NotNil(s.mcpServer, "MCP server should not be nil")
Expand All @@ -64,7 +64,8 @@ func (s *BaseHttpSuite) StartServer() {
cancelCtx, s.StopServer = context.WithCancel(gc)
group.Go(func() error { return Serve(cancelCtx, s.mcpServer, s.StaticConfig, s.OidcProvider, nil) })
s.WaitForShutdown = group.Wait
s.Require().NoError(test.WaitForServer(s.TcpAddr), "HTTP server did not start in time")
s.Require().NoError(test.WaitForServer(tcpAddr), "HTTP server did not start in time")
s.Require().NoError(test.WaitForHealthz(tcpAddr), "HTTP server /healthz endpoint did not respond with non-404 in time")
}

func (s *BaseHttpSuite) TearDownTest() {
Expand Down