From 2a4d6d563234e2bc9b1670a93b6e1b078005c055 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <209825114+claude[bot]@users.noreply.github.com> Date: Thu, 28 Aug 2025 03:19:53 +0000 Subject: [PATCH] feat(a2a): Add functions to list and fetch A2A agents - Add ListAgents() function to list all available A2A agents - Add GetAgent() function to fetch specific agent by ID - Download latest OpenAPI spec with A2A endpoints - Generate updated types with A2AAgentCard and ListAgentsResponse - Add comprehensive tests with table-driven scenarios - Include error handling for 403, 401, 404, 500 status codes - Add timeout and cancellation support - Follow existing SDK patterns and conventions - Include detailed documentation with usage examples Resolves #24 Co-authored-by: Eden Reich --- sdk.go | 102 +++++++++++++ sdk_test.go | 430 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 532 insertions(+) diff --git a/sdk.go b/sdk.go index 54d9a9a..ee11094 100644 --- a/sdk.go +++ b/sdk.go @@ -28,6 +28,8 @@ type Client interface { ListModels(ctx context.Context) (*ListModelsResponse, error) ListProviderModels(ctx context.Context, provider Provider) (*ListModelsResponse, error) ListTools(ctx context.Context) (*ListToolsResponse, error) + ListAgents(ctx context.Context) (*ListAgentsResponse, error) + GetAgent(ctx context.Context, id string) (*A2AAgentCard, error) GenerateContent(ctx context.Context, provider Provider, model string, messages []Message) (*CreateChatCompletionResponse, error) GenerateContentStream(ctx context.Context, provider Provider, model string, messages []Message) (<-chan SSEvent, error) HealthCheck(ctx context.Context) error @@ -549,6 +551,106 @@ func (c *clientImpl) ListTools(ctx context.Context) (*ListToolsResponse, error) return result, nil } +// ListAgents returns all available A2A agents. +// Only accessible when EXPOSE_A2A is enabled on the server. +// +// Example: +// +// client := sdk.NewClient(&sdk.ClientOptions{ +// BaseURL: "http://localhost:8080/v1", +// APIKey: "your-api-key", +// }) +// ctx := context.Background() +// agents, err := client.ListAgents(ctx) +// if err != nil { +// log.Fatalf("Error listing agents: %v", err) +// } +// fmt.Printf("Available agents: %+v\n", agents.Data) +func (c *clientImpl) ListAgents(ctx context.Context) (*ListAgentsResponse, error) { + resp, err := c.executeWithRetry(ctx, func() (*resty.Response, error) { + return c.http.R(). + SetContext(ctx). + SetResult(&ListAgentsResponse{}). + Get(fmt.Sprintf("%s/a2a/agents", c.baseURL)) + }) + + if err != nil { + return nil, err + } + + if resp.IsError() { + var errorResp Error + if err := json.Unmarshal(resp.Body(), &errorResp); err == nil && errorResp.Error != nil { + return nil, fmt.Errorf("API error: %s (status code: %d)", *errorResp.Error, resp.StatusCode()) + } + + errMsg := fmt.Sprintf("failed to list A2A agents, status code: %d", resp.StatusCode()) + + if len(resp.Body()) > 0 { + errMsg = fmt.Sprintf("%s, response body: %s", errMsg, string(resp.Body())) + } + + return nil, fmt.Errorf("%s", errMsg) + } + + result, ok := resp.Result().(*ListAgentsResponse) + if !ok || result == nil { + return nil, fmt.Errorf("failed to parse response") + } + + return result, nil +} + +// GetAgent returns a specific A2A agent by its unique identifier. +// Only accessible when EXPOSE_A2A is enabled on the server. +// +// Example: +// +// client := sdk.NewClient(&sdk.ClientOptions{ +// BaseURL: "http://localhost:8080/v1", +// APIKey: "your-api-key", +// }) +// ctx := context.Background() +// agent, err := client.GetAgent(ctx, "agent-id-123") +// if err != nil { +// log.Fatalf("Error getting agent: %v", err) +// } +// fmt.Printf("Agent details: %+v\n", agent) +func (c *clientImpl) GetAgent(ctx context.Context, id string) (*A2AAgentCard, error) { + resp, err := c.executeWithRetry(ctx, func() (*resty.Response, error) { + return c.http.R(). + SetContext(ctx). + SetResult(&A2AAgentCard{}). + Get(fmt.Sprintf("%s/a2a/agents/%s", c.baseURL, id)) + }) + + if err != nil { + return nil, err + } + + if resp.IsError() { + var errorResp Error + if err := json.Unmarshal(resp.Body(), &errorResp); err == nil && errorResp.Error != nil { + return nil, fmt.Errorf("API error: %s (status code: %d)", *errorResp.Error, resp.StatusCode()) + } + + errMsg := fmt.Sprintf("failed to get A2A agent, status code: %d", resp.StatusCode()) + + if len(resp.Body()) > 0 { + errMsg = fmt.Sprintf("%s, response body: %s", errMsg, string(resp.Body())) + } + + return nil, fmt.Errorf("%s", errMsg) + } + + result, ok := resp.Result().(*A2AAgentCard) + if !ok || result == nil { + return nil, fmt.Errorf("failed to parse response") + } + + return result, nil +} + // GenerateContent generates content using the specified provider and model. // // Example: diff --git a/sdk_test.go b/sdk_test.go index e7291a6..c842738 100644 --- a/sdk_test.go +++ b/sdk_test.go @@ -1966,3 +1966,433 @@ func TestRetryConfigWithNilCallback(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 2, callCount) } + +func TestListAgents(t *testing.T) { + tests := []struct { + name string + mockResponse ListAgentsResponse + expectedAgents int + }{ + { + name: "successful agents listing", + mockResponse: ListAgentsResponse{ + Object: "list", + Data: []A2AAgentCard{ + { + Id: "agent1", + Name: "Test Agent 1", + Description: "A test A2A agent for development", + Version: "1.0.0", + Url: "https://agent1.example.com", + Capabilities: map[string]interface{}{"chat": true, "reasoning": true}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text", "json"}, + Skills: []map[string]interface{}{{"name": "chat", "type": "conversation"}}, + }, + { + Id: "agent2", + Name: "Test Agent 2", + Description: "Another test A2A agent", + Version: "2.1.0", + Url: "https://agent2.example.com", + Capabilities: map[string]interface{}{"analysis": true, "reporting": true}, + DefaultInputModes: []string{"text", "image"}, + DefaultOutputModes: []string{"text", "pdf"}, + Skills: []map[string]interface{}{{"name": "analysis", "type": "data_processing"}}, + }, + }, + }, + expectedAgents: 2, + }, + { + name: "empty agents list", + mockResponse: ListAgentsResponse{ + Object: "list", + Data: []A2AAgentCard{}, + }, + expectedAgents: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/a2a/agents", r.URL.Path, "Path should be /v1/a2a/agents") + assert.Equal(t, http.MethodGet, r.Method, "Method should be GET") + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(tt.mockResponse) + assert.NoError(t, err) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + }) + + ctx := context.Background() + agents, err := client.ListAgents(ctx) + + assert.NoError(t, err) + assert.NotNil(t, agents) + assert.Equal(t, "list", agents.Object) + assert.Len(t, agents.Data, tt.expectedAgents) + + if tt.expectedAgents > 0 { + assert.Equal(t, "agent1", agents.Data[0].Id) + assert.Equal(t, "Test Agent 1", agents.Data[0].Name) + assert.Equal(t, "A test A2A agent for development", agents.Data[0].Description) + assert.Equal(t, "1.0.0", agents.Data[0].Version) + assert.Equal(t, "https://agent1.example.com", agents.Data[0].Url) + assert.NotNil(t, agents.Data[0].Capabilities) + assert.Contains(t, agents.Data[0].DefaultInputModes, "text") + assert.Contains(t, agents.Data[0].DefaultOutputModes, "text") + } + }) + } +} + +func TestListAgents_APIError(t *testing.T) { + tests := []struct { + name string + statusCode int + responseBody map[string]interface{} + expectedError string + }{ + { + name: "A2A not exposed", + statusCode: http.StatusForbidden, + responseBody: map[string]interface{}{ + "error": "A2A agents endpoint is not exposed. Set EXPOSE_A2A=true to enable.", + }, + expectedError: "API error", + }, + { + name: "unauthorized", + statusCode: http.StatusUnauthorized, + responseBody: map[string]interface{}{ + "error": "Unauthorized access", + }, + expectedError: "API error", + }, + { + name: "internal server error", + statusCode: http.StatusInternalServerError, + responseBody: map[string]interface{}{ + "error": "Internal server error", + }, + expectedError: "HTTP 500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + err := json.NewEncoder(w).Encode(tt.responseBody) + assert.NoError(t, err) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + }) + + ctx := context.Background() + agents, err := client.ListAgents(ctx) + + assert.Error(t, err) + assert.Nil(t, agents) + assert.Contains(t, err.Error(), tt.expectedError) + }) + } +} + +func TestGetAgent(t *testing.T) { + tests := []struct { + name string + agentID string + mockResponse A2AAgentCard + }{ + { + name: "successful agent retrieval", + agentID: "test-agent-id", + mockResponse: A2AAgentCard{ + Id: "test-agent-id", + Name: "Detailed Test Agent", + Description: "A comprehensive test agent with full details", + Version: "3.2.1", + Url: "https://detailed-agent.example.com", + IconUrl: stringPtr("https://detailed-agent.example.com/icon.png"), + DocumentationUrl: stringPtr("https://detailed-agent.example.com/docs"), + Capabilities: map[string]interface{}{ + "chat": true, + "reasoning": true, + "analysis": true, + "vision": true, + }, + DefaultInputModes: []string{"text", "image", "audio"}, + DefaultOutputModes: []string{"text", "json", "image"}, + Skills: []map[string]interface{}{ + {"name": "conversation", "type": "chat", "enabled": true}, + {"name": "document_analysis", "type": "analysis", "enabled": true}, + {"name": "image_processing", "type": "vision", "enabled": true}, + }, + Provider: &map[string]interface{}{ + "name": "Test Provider", + "version": "1.0", + "url": "https://provider.example.com", + }, + Security: &[]map[string]interface{}{ + {"type": "bearer", "scheme": "JWT"}, + }, + SecuritySchemes: &map[string]interface{}{ + "bearerAuth": map[string]interface{}{ + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + }, + }, + SupportsAuthenticatedExtendedCard: boolPtr(true), + }, + }, + { + name: "minimal agent data", + agentID: "minimal-agent", + mockResponse: A2AAgentCard{ + Id: "minimal-agent", + Name: "Minimal Agent", + Description: "Basic agent with minimal configuration", + Version: "1.0.0", + Url: "https://minimal.example.com", + Capabilities: map[string]interface{}{"basic": true}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, + Skills: []map[string]interface{}{{"name": "basic", "type": "simple"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := fmt.Sprintf("/v1/a2a/agents/%s", tt.agentID) + assert.Equal(t, expectedPath, r.URL.Path, "Path should match agent ID") + assert.Equal(t, http.MethodGet, r.Method, "Method should be GET") + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(tt.mockResponse) + assert.NoError(t, err) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + }) + + ctx := context.Background() + agent, err := client.GetAgent(ctx, tt.agentID) + + assert.NoError(t, err) + assert.NotNil(t, agent) + assert.Equal(t, tt.mockResponse.Id, agent.Id) + assert.Equal(t, tt.mockResponse.Name, agent.Name) + assert.Equal(t, tt.mockResponse.Description, agent.Description) + assert.Equal(t, tt.mockResponse.Version, agent.Version) + assert.Equal(t, tt.mockResponse.Url, agent.Url) + assert.Equal(t, tt.mockResponse.Capabilities, agent.Capabilities) + assert.Equal(t, tt.mockResponse.DefaultInputModes, agent.DefaultInputModes) + assert.Equal(t, tt.mockResponse.DefaultOutputModes, agent.DefaultOutputModes) + + if tt.name == "successful agent retrieval" { + assert.Equal(t, tt.mockResponse.IconUrl, agent.IconUrl) + assert.Equal(t, tt.mockResponse.DocumentationUrl, agent.DocumentationUrl) + assert.Equal(t, tt.mockResponse.Provider, agent.Provider) + assert.Equal(t, tt.mockResponse.Security, agent.Security) + assert.Equal(t, tt.mockResponse.SecuritySchemes, agent.SecuritySchemes) + assert.Equal(t, tt.mockResponse.SupportsAuthenticatedExtendedCard, agent.SupportsAuthenticatedExtendedCard) + } + }) + } +} + +func TestGetAgent_APIError(t *testing.T) { + tests := []struct { + name string + agentID string + statusCode int + responseBody map[string]interface{} + expectedError string + }{ + { + name: "agent not found", + agentID: "nonexistent-agent", + statusCode: http.StatusNotFound, + responseBody: map[string]interface{}{ + "error": "Agent not found", + }, + expectedError: "API error", + }, + { + name: "A2A not exposed", + agentID: "test-agent", + statusCode: http.StatusForbidden, + responseBody: map[string]interface{}{ + "error": "A2A agents endpoint is not exposed. Set EXPOSE_A2A=true to enable.", + }, + expectedError: "API error", + }, + { + name: "unauthorized", + agentID: "test-agent", + statusCode: http.StatusUnauthorized, + responseBody: map[string]interface{}{ + "error": "Unauthorized access", + }, + expectedError: "API error", + }, + { + name: "internal server error", + agentID: "test-agent", + statusCode: http.StatusInternalServerError, + responseBody: map[string]interface{}{ + "error": "Internal server error", + }, + expectedError: "HTTP 500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + err := json.NewEncoder(w).Encode(tt.responseBody) + assert.NoError(t, err) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + }) + + ctx := context.Background() + agent, err := client.GetAgent(ctx, tt.agentID) + + assert.Error(t, err) + assert.Nil(t, agent) + assert.Contains(t, err.Error(), tt.expectedError) + }) + } +} + +func TestGetAgent_InvalidID(t *testing.T) { + tests := []struct { + name string + agentID string + expected string + }{ + { + name: "empty agent ID", + agentID: "", + expected: "a2a/agents/", + }, + { + name: "agent ID with special characters", + agentID: "agent@123test", + expected: "a2a/agents/agent@123test", + }, + { + name: "agent ID with spaces", + agentID: "agent with spaces", + expected: "a2a/agents/agent with spaces", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, tt.expected, "Path should contain expected agent ID") + assert.Equal(t, http.MethodGet, r.Method, "Method should be GET") + + agent := A2AAgentCard{ + Id: tt.agentID, + Name: "Test Agent", + Description: "Test agent description", + Version: "1.0.0", + Url: "https://test.example.com", + Capabilities: map[string]interface{}{"test": true}, + DefaultInputModes: []string{"text"}, + DefaultOutputModes: []string{"text"}, + Skills: []map[string]interface{}{{"name": "test", "type": "basic"}}, + } + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(agent) + assert.NoError(t, err) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + }) + + ctx := context.Background() + agent, err := client.GetAgent(ctx, tt.agentID) + + assert.NoError(t, err) + assert.NotNil(t, agent) + assert.Equal(t, tt.agentID, agent.Id) + }) + } +} + +func TestA2AWithTimeout(t *testing.T) { + tests := []struct { + name string + function string + }{ + { + name: "ListAgents with timeout", + function: "ListAgents", + }, + { + name: "GetAgent with timeout", + function: "GetAgent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + baseURL := server.URL + "/v1" + client := NewClient(&ClientOptions{ + BaseURL: baseURL, + Timeout: 100 * time.Millisecond, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + var err error + if tt.function == "ListAgents" { + _, err = client.ListAgents(ctx) + } else { + _, err = client.GetAgent(ctx, "test-agent") + } + + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + }) + } +} +