diff --git a/internal/api/handler/agents.go b/internal/api/handler/agents.go new file mode 100644 index 00000000..ee07b492 --- /dev/null +++ b/internal/api/handler/agents.go @@ -0,0 +1,456 @@ +package handler + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "net/http" + "strings" + "time" + + "github.com/compliance-framework/api/internal/api" + "github.com/compliance-framework/api/internal/service/relational" + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "go.uber.org/zap" + "gorm.io/gorm" +) + +type AgentHandler struct { + sugar *zap.SugaredLogger + db *gorm.DB +} + +type agentResponse struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created-at"` + UpdatedAt time.Time `json:"updated-at"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + IsActive bool `json:"is-active"` + LastAuthenticatedAt *time.Time `json:"last-authenticated-at,omitempty"` + ServiceAccountKeys int64 `json:"service-account-key-count"` +} + +type agentKeyResponse struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created-at"` + UpdatedAt time.Time `json:"updated-at"` + Name *string `json:"name,omitempty"` + ClientID string `json:"client-id"` + LastUsedAt *time.Time `json:"last-used-at,omitempty"` + ExpiresAt *time.Time `json:"expires-at,omitempty"` + NeverExpires bool `json:"never-expires"` + RevokedAt *time.Time `json:"revoked-at,omitempty"` +} + +type agentKeyCreateResponse struct { + agentKeyResponse + ClientSecret string `json:"client-secret"` +} + +type createAgentRequest struct { + Name string `json:"name" validate:"required"` + Description *string `json:"description"` + IsActive *bool `json:"is-active"` +} + +type updateAgentRequest struct { + Name *string `json:"name"` + Description *string `json:"description"` + IsActive *bool `json:"is-active"` +} + +type createAgentKeyRequest struct { + Name *string `json:"name"` + ExpiresAt *time.Time `json:"expires-at,omitempty"` + NeverExpires bool `json:"never-expires,omitempty"` +} + +func NewAgentHandler(sugar *zap.SugaredLogger, db *gorm.DB) *AgentHandler { + return &AgentHandler{sugar: sugar, db: db} +} + +func (h *AgentHandler) Register(api *echo.Group) { + api.GET("", h.ListAgents) + api.POST("", h.CreateAgent) + api.GET("/:id", h.GetAgent) + api.PUT("/:id", h.UpdateAgent) + api.DELETE("/:id", h.DeleteAgent) + api.POST("/:id/keys", h.CreateAgentKey) + api.GET("/:id/keys", h.ListAgentKeys) + api.GET("/:id/keys/:keyId", h.GetAgentKey) + api.DELETE("/:id/keys/:keyId", h.DeleteAgentKey) +} + +func (h *AgentHandler) ListAgents(ctx echo.Context) error { + var agents []relational.Agent + if err := h.db.Order("created_at asc").Find(&agents).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + keyCounts, err := h.listAgentKeyCounts(agents) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + resp := make([]agentResponse, 0, len(agents)) + for _, agent := range agents { + item := buildAgentResponse(&agent, keyCounts[agent.ID.String()]) + resp = append(resp, item) + } + + return ctx.JSON(http.StatusOK, GenericDataListResponse[agentResponse]{Data: resp}) +} + +func (h *AgentHandler) GetAgent(ctx echo.Context) error { + agent, err := h.getAgentByParam(ctx.Param("id")) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if agent == nil { + return ctx.JSON(http.StatusNotFound, api.NewError(gorm.ErrRecordNotFound)) + } + + resp, err := h.buildAgentResponse(agent) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + return ctx.JSON(http.StatusOK, GenericDataResponse[agentResponse]{Data: resp}) +} + +func (h *AgentHandler) CreateAgent(ctx echo.Context) error { + var req createAgentRequest + if err := ctx.Bind(&req); err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if err := ctx.Validate(&req); err != nil { + return ctx.JSON(http.StatusBadRequest, api.Validator(err)) + } + + agent := &relational.Agent{ + Name: req.Name, + Description: req.Description, + IsActive: true, + } + if req.IsActive != nil { + agent.IsActive = *req.IsActive + } + + if err := h.db.Create(agent).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + resp, err := h.buildAgentResponse(agent) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + return ctx.JSON(http.StatusCreated, GenericDataResponse[agentResponse]{Data: resp}) +} + +func (h *AgentHandler) UpdateAgent(ctx echo.Context) error { + agent, err := h.getAgentByParam(ctx.Param("id")) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if agent == nil { + return ctx.JSON(http.StatusNotFound, api.NewError(gorm.ErrRecordNotFound)) + } + + var req updateAgentRequest + if err := ctx.Bind(&req); err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + + if req.Name != nil { + agent.Name = *req.Name + } + if req.Description != nil { + agent.Description = req.Description + } + if req.IsActive != nil { + agent.IsActive = *req.IsActive + } + + if err := h.db.Save(agent).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + resp, err := h.buildAgentResponse(agent) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + return ctx.JSON(http.StatusOK, GenericDataResponse[agentResponse]{Data: resp}) +} + +func (h *AgentHandler) DeleteAgent(ctx echo.Context) error { + agent, err := h.getAgentByParam(ctx.Param("id")) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if agent == nil { + return ctx.JSON(http.StatusNotFound, api.NewError(gorm.ErrRecordNotFound)) + } + if err := h.db.Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() + + if err := tx.Model(agent). + Update("is_active", false).Error; err != nil { + return err + } + + if err := tx.Model(&relational.AgentServiceAccountKey{}). + Where("agent_id = ? AND revoked_at IS NULL", *agent.ID). + Update("revoked_at", now).Error; err != nil { + return err + } + + if err := tx.Delete(agent).Error; err != nil { + return err + } + + return nil + }); err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + return ctx.NoContent(http.StatusNoContent) +} + +func (h *AgentHandler) CreateAgentKey(ctx echo.Context) error { + agent, err := h.getAgentByParam(ctx.Param("id")) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if agent == nil { + return ctx.JSON(http.StatusNotFound, api.NewError(gorm.ErrRecordNotFound)) + } + + var req createAgentKeyRequest + if err := ctx.Bind(&req); err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + + expiresAt, err := normalizeAgentKeyExpiry(req.ExpiresAt, req.NeverExpires) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + + clientSecret, err := generateAgentSecret() + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + key := &relational.AgentServiceAccountKey{ + AgentID: agent.ID, + Name: normalizeOptionalString(req.Name), + ClientID: uuid.NewString(), + ExpiresAt: expiresAt, + } + if err := key.SetSecret(clientSecret); err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + if err := h.db.Create(key).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + return ctx.JSON(http.StatusCreated, GenericDataResponse[agentKeyCreateResponse]{ + Data: agentKeyCreateResponse{ + agentKeyResponse: buildAgentKeyResponse(key), + ClientSecret: clientSecret, + }, + }) +} + +func (h *AgentHandler) ListAgentKeys(ctx echo.Context) error { + agent, err := h.getAgentByParam(ctx.Param("id")) + if err != nil { + return ctx.JSON(http.StatusBadRequest, api.NewError(err)) + } + if agent == nil { + return ctx.JSON(http.StatusNotFound, api.NewError(gorm.ErrRecordNotFound)) + } + + var keys []relational.AgentServiceAccountKey + if err := h.db.Where("agent_id = ?", *agent.ID).Order("created_at asc").Find(&keys).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + resp := make([]agentKeyResponse, 0, len(keys)) + for _, key := range keys { + resp = append(resp, buildAgentKeyResponse(&key)) + } + return ctx.JSON(http.StatusOK, GenericDataListResponse[agentKeyResponse]{Data: resp}) +} + +func (h *AgentHandler) GetAgentKey(ctx echo.Context) error { + key, status, err := h.getAgentKey(ctx.Param("id"), ctx.Param("keyId")) + if err != nil { + return ctx.JSON(status, api.NewError(err)) + } + return ctx.JSON(http.StatusOK, GenericDataResponse[agentKeyResponse]{Data: buildAgentKeyResponse(key)}) +} + +func (h *AgentHandler) DeleteAgentKey(ctx echo.Context) error { + key, status, err := h.getAgentKey(ctx.Param("id"), ctx.Param("keyId")) + if err != nil { + return ctx.JSON(status, api.NewError(err)) + } + now := time.Now().UTC() + key.RevokedAt = &now + if err := h.db.Save(key).Error; err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + return ctx.NoContent(http.StatusNoContent) +} + +func (h *AgentHandler) getAgentByParam(agentID string) (*relational.Agent, error) { + agentUUID, err := uuid.Parse(agentID) + if err != nil { + return nil, err + } + var agent relational.Agent + if err := h.db.First(&agent, agentUUID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return &agent, nil +} + +func (h *AgentHandler) getAgentKey(agentID, keyID string) (*relational.AgentServiceAccountKey, int, error) { + agent, err := h.getAgentByParam(agentID) + if err != nil { + return nil, http.StatusBadRequest, err + } + if agent == nil { + return nil, http.StatusNotFound, gorm.ErrRecordNotFound + } + keyUUID, err := uuid.Parse(keyID) + if err != nil { + return nil, http.StatusBadRequest, err + } + var key relational.AgentServiceAccountKey + if err := h.db.Where("agent_id = ?", *agent.ID).First(&key, keyUUID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, http.StatusNotFound, err + } + return nil, http.StatusInternalServerError, err + } + return &key, http.StatusOK, nil +} + +func (h *AgentHandler) buildAgentResponse(agent *relational.Agent) (agentResponse, error) { + keyCounts, err := h.listAgentKeyCounts([]relational.Agent{*agent}) + if err != nil { + return agentResponse{}, err + } + + return buildAgentResponse(agent, keyCounts[agent.ID.String()]), nil +} + +func buildAgentResponse(agent *relational.Agent, keyCount int64) agentResponse { + return agentResponse{ + ID: agent.ID.String(), + CreatedAt: agent.CreatedAt, + UpdatedAt: agent.UpdatedAt, + Name: agent.Name, + Description: agent.Description, + IsActive: agent.IsActive, + LastAuthenticatedAt: agent.LastAuthenticatedAt, + ServiceAccountKeys: keyCount, + } +} + +func (h *AgentHandler) listAgentKeyCounts(agents []relational.Agent) (map[string]int64, error) { + counts := make(map[string]int64, len(agents)) + if len(agents) == 0 { + return counts, nil + } + + agentIDs := make([]uuid.UUID, 0, len(agents)) + for _, agent := range agents { + if agent.ID == nil { + continue + } + agentIDs = append(agentIDs, *agent.ID) + counts[agent.ID.String()] = 0 + } + if len(agentIDs) == 0 { + return counts, nil + } + + type keyCountResult struct { + AgentID uuid.UUID + KeyCount int64 + } + + var results []keyCountResult + if err := h.db.Model(&relational.AgentServiceAccountKey{}). + Select("agent_id, COUNT(*) AS key_count"). + Where("agent_id IN ? AND revoked_at IS NULL", agentIDs). + Group("agent_id"). + Scan(&results).Error; err != nil { + return nil, err + } + + for _, result := range results { + counts[result.AgentID.String()] = result.KeyCount + } + + return counts, nil +} + +func buildAgentKeyResponse(key *relational.AgentServiceAccountKey) agentKeyResponse { + return agentKeyResponse{ + ID: key.ID.String(), + CreatedAt: key.CreatedAt, + UpdatedAt: key.UpdatedAt, + Name: key.Name, + ClientID: key.ClientID, + LastUsedAt: key.LastUsedAt, + ExpiresAt: key.ExpiresAt, + NeverExpires: key.ExpiresAt == nil, + RevokedAt: key.RevokedAt, + } +} + +func generateAgentSecret() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func normalizeOptionalString(value *string) *string { + if value == nil { + return nil + } + + trimmed := strings.TrimSpace(*value) + if trimmed == "" { + return nil + } + + return &trimmed +} + +func normalizeAgentKeyExpiry(expiresAt *time.Time, neverExpires bool) (*time.Time, error) { + if neverExpires && expiresAt != nil { + return nil, errors.New("expires-at cannot be combined with never-expires") + } + if expiresAt == nil { + if neverExpires { + return nil, nil + } + return nil, errors.New("expires-at is required unless never-expires is true") + } + + normalized := expiresAt.UTC() + if !normalized.After(time.Now().UTC()) { + return nil, errors.New("expires-at must be in the future") + } + + return &normalized, nil +} diff --git a/internal/api/handler/agents_integration_test.go b/internal/api/handler/agents_integration_test.go new file mode 100644 index 00000000..c5fd1e1b --- /dev/null +++ b/internal/api/handler/agents_integration_test.go @@ -0,0 +1,216 @@ +//go:build integration + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/compliance-framework/api/internal/api" + "github.com/compliance-framework/api/internal/service/relational" + "github.com/compliance-framework/api/internal/tests" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" +) + +type AgentAPIIntegrationSuite struct { + tests.IntegrationTestSuite + server *api.Server +} + +func TestAgentAPI(t *testing.T) { + suite.Run(t, new(AgentAPIIntegrationSuite)) +} + +func (suite *AgentAPIIntegrationSuite) SetupTest() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + + logger, _ := zap.NewDevelopment() + metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) + suite.server = api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) + RegisterHandlers(suite.server, logger.Sugar(), suite.DB, suite.Config, &APIServices{}) +} + +func (suite *AgentAPIIntegrationSuite) authedRequest(method, path string, body any) (*httptest.ResponseRecorder, *http.Request) { + token, err := suite.GetAuthToken() + suite.Require().NoError(err) + + payload := []byte{} + if body != nil { + data, marshalErr := json.Marshal(body) + suite.Require().NoError(marshalErr) + payload = data + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(method, path, bytes.NewReader(payload)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) + return rec, req +} + +func (suite *AgentAPIIntegrationSuite) TestAgentCRUDAndKeys() { + createRec, createReq := suite.authedRequest(http.MethodPost, "/api/admin/agents", map[string]any{ + "name": "agent-one", + "description": "integration agent", + }) + suite.server.E().ServeHTTP(createRec, createReq) + require.Equal(suite.T(), http.StatusCreated, createRec.Code) + + var created GenericDataResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(createRec.Body.Bytes(), &created)) + require.Equal(suite.T(), "agent-one", created.Data.Name) + require.Equal(suite.T(), int64(0), created.Data.ServiceAccountKeys) + + listRec, listReq := suite.authedRequest(http.MethodGet, "/api/admin/agents", nil) + suite.server.E().ServeHTTP(listRec, listReq) + require.Equal(suite.T(), http.StatusOK, listRec.Code) + + var listed GenericDataListResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(listRec.Body.Bytes(), &listed)) + require.Len(suite.T(), listed.Data, 1) + + getRec, getReq := suite.authedRequest(http.MethodGet, fmt.Sprintf("/api/admin/agents/%s", created.Data.ID), nil) + suite.server.E().ServeHTTP(getRec, getReq) + require.Equal(suite.T(), http.StatusOK, getRec.Code) + + updateRec, updateReq := suite.authedRequest(http.MethodPut, fmt.Sprintf("/api/admin/agents/%s", created.Data.ID), map[string]any{ + "name": "agent-one-updated", + "is-active": false, + }) + suite.server.E().ServeHTTP(updateRec, updateReq) + require.Equal(suite.T(), http.StatusOK, updateRec.Code) + + var updated GenericDataResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(updateRec.Body.Bytes(), &updated)) + require.Equal(suite.T(), "agent-one-updated", updated.Data.Name) + require.False(suite.T(), updated.Data.IsActive) + + keyCreateRec, keyCreateReq := suite.authedRequest(http.MethodPost, fmt.Sprintf("/api/admin/agents/%s/keys", created.Data.ID), map[string]any{ + "name": "primary", + "never-expires": true, + }) + suite.server.E().ServeHTTP(keyCreateRec, keyCreateReq) + require.Equal(suite.T(), http.StatusCreated, keyCreateRec.Code) + + var keyCreated GenericDataResponse[agentKeyCreateResponse] + require.NoError(suite.T(), json.Unmarshal(keyCreateRec.Body.Bytes(), &keyCreated)) + require.NotEmpty(suite.T(), keyCreated.Data.ClientID) + require.NotEmpty(suite.T(), keyCreated.Data.ClientSecret) + require.True(suite.T(), keyCreated.Data.NeverExpires) + + keyListRec, keyListReq := suite.authedRequest(http.MethodGet, fmt.Sprintf("/api/admin/agents/%s/keys", created.Data.ID), nil) + suite.server.E().ServeHTTP(keyListRec, keyListReq) + require.Equal(suite.T(), http.StatusOK, keyListRec.Code) + + var keyList GenericDataListResponse[agentKeyResponse] + require.NoError(suite.T(), json.Unmarshal(keyListRec.Body.Bytes(), &keyList)) + require.Len(suite.T(), keyList.Data, 1) + require.Equal(suite.T(), keyCreated.Data.ClientID, keyList.Data[0].ClientID) + require.True(suite.T(), keyList.Data[0].NeverExpires) + + keyGetRec, keyGetReq := suite.authedRequest(http.MethodGet, fmt.Sprintf("/api/admin/agents/%s/keys/%s", created.Data.ID, keyCreated.Data.ID), nil) + suite.server.E().ServeHTTP(keyGetRec, keyGetReq) + require.Equal(suite.T(), http.StatusOK, keyGetRec.Code) + + keyDeleteRec, keyDeleteReq := suite.authedRequest(http.MethodDelete, fmt.Sprintf("/api/admin/agents/%s/keys/%s", created.Data.ID, keyCreated.Data.ID), nil) + suite.server.E().ServeHTTP(keyDeleteRec, keyDeleteReq) + require.Equal(suite.T(), http.StatusNoContent, keyDeleteRec.Code) + + deleteRec, deleteReq := suite.authedRequest(http.MethodDelete, fmt.Sprintf("/api/admin/agents/%s", created.Data.ID), nil) + suite.server.E().ServeHTTP(deleteRec, deleteReq) + require.Equal(suite.T(), http.StatusNoContent, deleteRec.Code) +} + +func (suite *AgentAPIIntegrationSuite) TestCreateAgentKeyWithExpiry() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + + createRec, createReq := suite.authedRequest(http.MethodPost, "/api/admin/agents", map[string]any{ + "name": "agent-two", + }) + suite.server.E().ServeHTTP(createRec, createReq) + require.Equal(suite.T(), http.StatusCreated, createRec.Code) + + var created GenericDataResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(createRec.Body.Bytes(), &created)) + + expiresAt := time.Now().UTC().Add(2 * time.Hour).Format(time.RFC3339) + keyCreateRec, keyCreateReq := suite.authedRequest(http.MethodPost, fmt.Sprintf("/api/admin/agents/%s/keys", created.Data.ID), map[string]any{ + "name": "expiring", + "expires-at": expiresAt, + }) + suite.server.E().ServeHTTP(keyCreateRec, keyCreateReq) + require.Equal(suite.T(), http.StatusCreated, keyCreateRec.Code) + + var keyCreated GenericDataResponse[agentKeyCreateResponse] + require.NoError(suite.T(), json.Unmarshal(keyCreateRec.Body.Bytes(), &keyCreated)) + require.False(suite.T(), keyCreated.Data.NeverExpires) + require.NotNil(suite.T(), keyCreated.Data.ExpiresAt) +} + +func (suite *AgentAPIIntegrationSuite) TestCreateAgentKeyRequiresExplicitExpiryDecision() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + + createRec, createReq := suite.authedRequest(http.MethodPost, "/api/admin/agents", map[string]any{ + "name": "agent-three", + }) + suite.server.E().ServeHTTP(createRec, createReq) + require.Equal(suite.T(), http.StatusCreated, createRec.Code) + + var created GenericDataResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(createRec.Body.Bytes(), &created)) + + keyCreateRec, keyCreateReq := suite.authedRequest(http.MethodPost, fmt.Sprintf("/api/admin/agents/%s/keys", created.Data.ID), map[string]any{ + "name": "missing-expiry-choice", + }) + suite.server.E().ServeHTTP(keyCreateRec, keyCreateReq) + require.Equal(suite.T(), http.StatusBadRequest, keyCreateRec.Code) + require.Contains(suite.T(), keyCreateRec.Body.String(), "expires-at is required unless never-expires is true") +} + +func (suite *AgentAPIIntegrationSuite) TestDeleteAgentRevokesKeysAndDeactivatesAgent() { + createRec, createReq := suite.authedRequest(http.MethodPost, "/api/admin/agents", map[string]any{ + "name": "agent-delete-test", + }) + suite.server.E().ServeHTTP(createRec, createReq) + require.Equal(suite.T(), http.StatusCreated, createRec.Code) + + var created GenericDataResponse[agentResponse] + require.NoError(suite.T(), json.Unmarshal(createRec.Body.Bytes(), &created)) + + keyCreateRec, keyCreateReq := suite.authedRequest(http.MethodPost, fmt.Sprintf("/api/admin/agents/%s/keys", created.Data.ID), map[string]any{ + "name": "primary", + "never-expires": true, + }) + suite.server.E().ServeHTTP(keyCreateRec, keyCreateReq) + require.Equal(suite.T(), http.StatusCreated, keyCreateRec.Code) + + var keyCreated GenericDataResponse[agentKeyCreateResponse] + require.NoError(suite.T(), json.Unmarshal(keyCreateRec.Body.Bytes(), &keyCreated)) + + deleteRec, deleteReq := suite.authedRequest(http.MethodDelete, fmt.Sprintf("/api/admin/agents/%s", created.Data.ID), nil) + suite.server.E().ServeHTTP(deleteRec, deleteReq) + require.Equal(suite.T(), http.StatusNoContent, deleteRec.Code) + + var agent relational.Agent + err := suite.DB.Unscoped().First(&agent, "id = ?", created.Data.ID).Error + require.NoError(suite.T(), err) + require.False(suite.T(), agent.IsActive) + require.NotNil(suite.T(), agent.DeletedAt) + require.True(suite.T(), agent.DeletedAt.Valid) + + var key relational.AgentServiceAccountKey + err = suite.DB.First(&key, "id = ?", keyCreated.Data.ID).Error + require.NoError(suite.T(), err) + require.NotNil(suite.T(), key.RevokedAt) +} diff --git a/internal/api/handler/api.go b/internal/api/handler/api.go index afbc7d21..78fee679 100644 --- a/internal/api/handler/api.go +++ b/internal/api/handler/api.go @@ -45,10 +45,15 @@ func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB filterHandler.Register(server.API().Group("/filters")) heartbeatHandler := NewHeartbeatHandler(logger, db) - heartbeatHandler.Register(server.API().Group("/agent/heartbeat")) + agentIngestMiddleware := middleware.AgentJWTOrPublicMiddleware(db, config.JWTPublicKey, !config.StrictDisablePublicAgentEndpoints) + heartbeatHandler.RegisterCreate(server.API().Group("/agent/heartbeat"), agentIngestMiddleware) + // Keep the legacy operator-facing metrics route stable while protecting it with user auth. + heartbeatHandler.RegisterOverTime(server.API().Group("/agent/heartbeat"), middleware.JWTMiddleware(config.JWTPublicKey)) evidenceHandler := NewEvidenceHandler(logger, services.EvidenceService) - evidenceHandler.Register(server.API().Group("/evidence")) + evidenceGroup := server.API().Group("/evidence") + evidenceHandler.RegisterCreate(evidenceGroup, agentIngestMiddleware) + evidenceHandler.RegisterReadRoutes(evidenceGroup) poamService := poamsvc.NewPoamService(db) riskService := riskrel.NewRiskService(db) @@ -78,8 +83,7 @@ func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB riskTemplateHandler.Register(riskTemplateGroup) agentRiskTemplateGroup := server.API().Group("/agent/risk-templates") - agentRiskTemplateGroup.Use(middleware.AgentJWTMiddleware(config.JWTPublicKey)) - riskTemplateHandler.RegisterAgent(agentRiskTemplateGroup) + riskTemplateHandler.RegisterAgent(agentRiskTemplateGroup, agentIngestMiddleware) subjectTemplateHandler := templatehandlers.NewSubjectTemplateHandler(logger, db) subjectTemplateGroup := server.API().Group("/admin/subject-templates") @@ -88,8 +92,13 @@ func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB subjectTemplateHandler.Register(subjectTemplateGroup) agentSubjectTemplateGroup := server.API().Group("/agent/subject-templates") - agentSubjectTemplateGroup.Use(middleware.AgentJWTMiddleware(config.JWTPublicKey)) - subjectTemplateHandler.RegisterAgent(agentSubjectTemplateGroup) + subjectTemplateHandler.RegisterAgent(agentSubjectTemplateGroup, agentIngestMiddleware) + + agentHandler := NewAgentHandler(logger, db) + agentsGroup := server.API().Group("/admin/agents") + agentsGroup.Use(middleware.JWTMiddleware(config.JWTPublicKey)) + agentsGroup.Use(middleware.RequireAdminGroups(db, config, logger)) + agentHandler.Register(agentsGroup) userHandler := NewUserHandler(logger, db) diff --git a/internal/api/handler/auth/agent.go b/internal/api/handler/auth/agent.go new file mode 100644 index 00000000..0fe25fde --- /dev/null +++ b/internal/api/handler/auth/agent.go @@ -0,0 +1,144 @@ +package auth + +import ( + "errors" + "net/http" + "strings" + "time" + + "github.com/compliance-framework/api/internal/api" + "github.com/compliance-framework/api/internal/authn" + "github.com/compliance-framework/api/internal/service/relational" + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +type agentTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +func (h *AuthHandler) GetAgentToken(ctx echo.Context) error { + clientID, clientSecret, err := getAgentCredentials(ctx) + if err != nil { + return ctx.JSON(http.StatusUnauthorized, api.NewError(err)) + } + + var key relational.AgentServiceAccountKey + if err := h.db.Where("client_id = ?", clientID).First(&key).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + h.logAgentAuthEvent(nil, nil, clientID, relational.AgentAuthEventOutcomeFailure, "unknown_client_id", ctx) + return ctx.JSON(http.StatusUnauthorized, api.NewError(errors.New("invalid client credentials"))) + } + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + if key.AgentID == nil { + h.logAgentAuthEvent(nil, &key, clientID, relational.AgentAuthEventOutcomeFailure, "missing_agent_id", ctx) + return ctx.JSON(http.StatusUnauthorized, api.NewError(errors.New("invalid client credentials"))) + } + + var agent relational.Agent + if err := h.db.Where("id = ?", *key.AgentID).First(&agent).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + h.logAgentAuthEvent(nil, &key, clientID, relational.AgentAuthEventOutcomeFailure, "agent_not_found", ctx) + return ctx.JSON(http.StatusUnauthorized, api.NewError(errors.New("invalid client credentials"))) + } + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + now := time.Now().UTC() + if key.IsRevoked(now) { + h.logAgentAuthEvent(&agent, &key, clientID, relational.AgentAuthEventOutcomeFailure, "key_revoked", ctx) + return ctx.JSON(http.StatusForbidden, api.NewError(errors.New("agent key is revoked"))) + } + if key.IsExpired(now) { + h.logAgentAuthEvent(&agent, &key, clientID, relational.AgentAuthEventOutcomeFailure, "key_expired", ctx) + return ctx.JSON(http.StatusForbidden, api.NewError(errors.New("agent key is expired"))) + } + if !agent.IsActive { + h.logAgentAuthEvent(&agent, &key, clientID, relational.AgentAuthEventOutcomeFailure, "agent_inactive", ctx) + return ctx.JSON(http.StatusForbidden, api.NewError(errors.New("agent is inactive"))) + } + if !key.CheckSecret(clientSecret) { + h.logAgentAuthEvent(&agent, &key, clientID, relational.AgentAuthEventOutcomeFailure, "invalid_secret", ctx) + return ctx.JSON(http.StatusUnauthorized, api.NewError(errors.New("invalid client credentials"))) + } + + token, err := authn.GenerateAgentJWTToken(&agent, &key, h.config.JWTPrivateKey) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + if err := h.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&relational.AgentServiceAccountKey{}). + Where("id = ?", key.ID.String()). + Update("last_used_at", now).Error; err != nil { + return err + } + if err := tx.Model(&relational.Agent{}). + Where("id = ?", agent.ID.String()). + Update("last_authenticated_at", now).Error; err != nil { + return err + } + event := h.newAgentAuthEvent(&agent, &key, clientID, relational.AgentAuthEventOutcomeSuccess, nil, ctx) + return tx.Create(event).Error + }); err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + return ctx.JSON(http.StatusOK, &agentTokenResponse{ + AccessToken: *token, + TokenType: "Bearer", + ExpiresIn: 86400, + }) +} + +func getAgentCredentials(ctx echo.Context) (string, string, error) { + if clientID, clientSecret, ok := ctx.Request().BasicAuth(); ok { + clientID = strings.TrimSpace(clientID) + clientSecret = strings.TrimSpace(clientSecret) + if clientID != "" && clientSecret != "" { + return clientID, clientSecret, nil + } + } + + clientID := strings.TrimSpace(ctx.FormValue("client_id")) + clientSecret := strings.TrimSpace(ctx.FormValue("client_secret")) + if clientID == "" || clientSecret == "" { + return "", "", errors.New("missing client credentials") + } + return clientID, clientSecret, nil +} + +func (h *AuthHandler) logAgentAuthEvent(agent *relational.Agent, key *relational.AgentServiceAccountKey, principal string, outcome string, reason string, ctx echo.Context) { + event := h.newAgentAuthEvent(agent, key, principal, outcome, &reason, ctx) + if err := h.db.Create(event).Error; err != nil { + h.sugar.Warnw("Failed to log agent auth event", "error", err) + } +} + +func (h *AuthHandler) newAgentAuthEvent(agent *relational.Agent, key *relational.AgentServiceAccountKey, principal string, outcome string, reason *string, ctx echo.Context) *relational.AgentAuthEvent { + event := &relational.AgentAuthEvent{ + AuthMethod: relational.AgentAuthMethodServiceAccount, + Outcome: outcome, + Principal: &principal, + Reason: reason, + } + if agent != nil && agent.ID != nil { + agentID := *agent.ID + event.AgentID = &agentID + } + if key != nil && key.ID != nil { + keyID := *key.ID + event.CredentialID = &keyID + } + if remoteAddr := strings.TrimSpace(ctx.RealIP()); remoteAddr != "" { + event.RemoteAddr = &remoteAddr + } + if userAgent := strings.TrimSpace(ctx.Request().UserAgent()); userAgent != "" { + event.UserAgent = &userAgent + } + return event +} diff --git a/internal/api/handler/auth/auth.go b/internal/api/handler/auth/auth.go index 5f074b87..0051a095 100644 --- a/internal/api/handler/auth/auth.go +++ b/internal/api/handler/auth/auth.go @@ -43,6 +43,7 @@ func NewAuthHandler(logger *zap.SugaredLogger, db *gorm.DB, config *config.Confi func (h *AuthHandler) Register(api *echo.Group) { api.POST("/login", h.LoginUser) api.POST("/token", h.GetOAuth2Token) + api.POST("/agent/token", h.GetAgentToken) api.GET("/publickey.pub", h.GetPublicKeyPEM) api.GET("/publickey", h.GetJWK) diff --git a/internal/api/handler/auth/auth_integration_test.go b/internal/api/handler/auth/auth_integration_test.go index c344c857..2d8c0744 100644 --- a/internal/api/handler/auth/auth_integration_test.go +++ b/internal/api/handler/auth/auth_integration_test.go @@ -10,9 +10,12 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "testing" + "time" "github.com/compliance-framework/api/internal/api" + "github.com/compliance-framework/api/internal/service/relational" "github.com/compliance-framework/api/internal/tests" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -113,3 +116,135 @@ func (suite *AuthAPIIntegrationSuite) TestPublicKeyEndpoint() { respKey, _ := pem.Decode(rec.Body.Bytes()) suite.Require().NotNil(respKey, "Expected PEM-encoded public key in response") } + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenWithBasicAuth() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("auth-agent") + suite.Require().NoError(err) + key, secret, err := suite.CreateAgentKey(agent, "auth-key") + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth(key.ClientID, secret) + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusOK, rec.Code) + suite.Contains(rec.Body.String(), "access_token") +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenWithFormCredentials() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("form-agent") + suite.Require().NoError(err) + key, secret, err := suite.CreateAgentKey(agent, "form-key") + suite.Require().NoError(err) + + form := url.Values{} + form.Set("client_id", key.ClientID) + form.Set("client_secret", secret) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", bytes.NewBufferString(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusOK, rec.Code) +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenRejectsBadSecret() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("bad-secret-agent") + suite.Require().NoError(err) + key, _, err := suite.CreateAgentKey(agent, "bad-secret-key") + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth(key.ClientID, "wrong-secret") + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusUnauthorized, rec.Code) +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenRejectsUnknownClientID() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth("missing-client-id", "wrong-secret") + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusUnauthorized, rec.Code) + + var events []relational.AgentAuthEvent + suite.Require().NoError(suite.DB.Order("created_at asc").Find(&events).Error) + suite.Require().Len(events, 1) + suite.Equal(relational.AgentAuthEventOutcomeFailure, events[0].Outcome) + suite.Equal(relational.AgentAuthMethodServiceAccount, events[0].AuthMethod) + suite.NotNil(events[0].Principal) + suite.Equal("missing-client-id", *events[0].Principal) + suite.NotNil(events[0].Reason) + suite.Equal("unknown_client_id", *events[0].Reason) + suite.Nil(events[0].AgentID) + suite.Nil(events[0].CredentialID) +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenRejectsRevokedKey() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("revoked-agent") + suite.Require().NoError(err) + key, secret, err := suite.CreateAgentKey(agent, "revoked-key") + suite.Require().NoError(err) + now := time.Now().UTC() + key.RevokedAt = &now + suite.Require().NoError(suite.DB.Save(key).Error) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth(key.ClientID, secret) + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusForbidden, rec.Code) +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenRejectsInactiveAgent() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("inactive-agent") + suite.Require().NoError(err) + agent.IsActive = false + suite.Require().NoError(suite.DB.Save(agent).Error) + key, secret, err := suite.CreateAgentKey(agent, "inactive-key") + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth(key.ClientID, secret) + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusForbidden, rec.Code) +} + +func (suite *AuthAPIIntegrationSuite) TestAgentTokenRejectsExpiredKey() { + err := suite.IntegrationTestSuite.Migrator.Refresh() + suite.Require().NoError(err) + + agent, err := suite.CreateAgent("expired-agent") + suite.Require().NoError(err) + key, secret, err := suite.CreateAgentKey(agent, "expired-key") + suite.Require().NoError(err) + expiresAt := time.Now().UTC().Add(-time.Minute) + key.ExpiresAt = &expiresAt + suite.Require().NoError(suite.DB.Save(key).Error) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/auth/agent/token", nil) + req.SetBasicAuth(key.ClientID, secret) + suite.server.E().ServeHTTP(rec, req) + suite.Equal(http.StatusForbidden, rec.Code) +} diff --git a/internal/api/handler/evidence.go b/internal/api/handler/evidence.go index 19407aa4..67911241 100644 --- a/internal/api/handler/evidence.go +++ b/internal/api/handler/evidence.go @@ -43,6 +43,22 @@ func (h *EvidenceHandler) Register(api *echo.Group) { api.GET("/compliance-by-filter/:id", h.ComplianceByFilter) } +func (h *EvidenceHandler) RegisterCreate(api *echo.Group, middlewares ...echo.MiddlewareFunc) { + api.POST("", h.Create, middlewares...) +} + +func (h *EvidenceHandler) RegisterReadRoutes(api *echo.Group) { + api.GET("/:id", h.Get) + api.GET("/history/:id", h.History) + api.GET("/latest/:id", h.Latest) + api.POST("/search", h.Search) + api.GET("/for-control/:id", h.ForControl) + api.GET("/status-over-time/:id", h.StatusOverTimeByUUID) + api.POST("/status-over-time", h.StatusOverTime) + api.GET("/compliance-by-control/:id", h.ComplianceByControl) + api.GET("/compliance-by-filter/:id", h.ComplianceByFilter) +} + type EvidenceActivityStep struct { UUID uuid.UUID Title string diff --git a/internal/api/handler/evidence_integration_test.go b/internal/api/handler/evidence_integration_test.go index d5dcdc0f..70be30a4 100644 --- a/internal/api/handler/evidence_integration_test.go +++ b/internal/api/handler/evidence_integration_test.go @@ -36,9 +36,21 @@ type EvidenceApiIntegrationSuite struct { tests.IntegrationTestSuite } +func (suite *EvidenceApiIntegrationSuite) setupServer() *api.Server { + logger, _ := zap.NewDevelopment() + metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) + server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) + services := &APIServices{} + evidenceSvc := evidencesvc.NewEvidenceService(suite.DB, logger.Sugar(), suite.Config, nil) + services.EvidenceService = evidenceSvc + RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + return server +} + func (suite *EvidenceApiIntegrationSuite) TestCreate() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = false // Create two catalogs with the same group ID structure evidence := EvidenceCreateRequest{ @@ -156,13 +168,7 @@ func (suite *EvidenceApiIntegrationSuite) TestCreate() { }, } - logger, _ := zap.NewDevelopment() - metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) - server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) - services := &APIServices{} - evidenceSvc := evidencesvc.NewEvidenceService(suite.DB, logger.Sugar(), suite.Config, nil) - services.EvidenceService = evidenceSvc - RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + server := suite.setupServer() rec := httptest.NewRecorder() reqBody, _ := json.Marshal(evidence) req := httptest.NewRequest(http.MethodPost, "/api/evidence", bytes.NewReader(reqBody)) @@ -176,6 +182,82 @@ func (suite *EvidenceApiIntegrationSuite) TestCreate() { suite.Equal(int64(1), count) } +func (suite *EvidenceApiIntegrationSuite) TestCreateRequiresAgentAuthWhenUnsafeDisabled() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + + server := suite.setupServer() + rec := httptest.NewRecorder() + reqBody, _ := json.Marshal(EvidenceCreateRequest{ + UUID: uuid.New(), + Title: "Evidence", + Start: time.Now().Add(-time.Hour), + End: time.Now().Add(-time.Minute), + }) + req := httptest.NewRequest(http.MethodPost, "/api/evidence", bytes.NewReader(reqBody)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + server.E().ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusUnauthorized, rec.Code) +} + +func (suite *EvidenceApiIntegrationSuite) TestCreateWithAgentTokenWhenUnsafeDisabled() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + + server := suite.setupServer() + agent, err := suite.CreateAgent("evidence-agent") + suite.Require().NoError(err) + key, _, err := suite.CreateAgentKey(agent, "evidence-key") + suite.Require().NoError(err) + token, err := suite.GetAgentToken(agent, key) + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + reqBody, _ := json.Marshal(EvidenceCreateRequest{ + UUID: uuid.New(), + Title: "Evidence", + Start: time.Now().Add(-time.Hour), + End: time.Now().Add(-time.Minute), + }) + req := httptest.NewRequest(http.MethodPost, "/api/evidence", bytes.NewReader(reqBody)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) + server.E().ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusCreated, rec.Code) +} + +func (suite *EvidenceApiIntegrationSuite) TestCreateRejectsExpiredAgentKeyWhenUnsafeDisabled() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + + server := suite.setupServer() + agent, err := suite.CreateAgent("expired-evidence-agent") + suite.Require().NoError(err) + key, _, err := suite.CreateAgentKey(agent, "expired-evidence-key") + suite.Require().NoError(err) + expiresAt := time.Now().UTC().Add(-time.Minute) + key.ExpiresAt = &expiresAt + suite.Require().NoError(suite.DB.Save(key).Error) + token, err := suite.GetAgentToken(agent, key) + suite.Require().NoError(err) + + rec := httptest.NewRecorder() + reqBody, _ := json.Marshal(EvidenceCreateRequest{ + UUID: uuid.New(), + Title: "Evidence", + Start: time.Now().Add(-time.Hour), + End: time.Now().Add(-time.Minute), + }) + req := httptest.NewRequest(http.MethodPost, "/api/evidence", bytes.NewReader(reqBody)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) + server.E().ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusForbidden, rec.Code) +} + func (suite *EvidenceApiIntegrationSuite) TestSearch() { suite.Run("Returns the single latest evidence for a stream", func() { err := suite.Migrator.Refresh() diff --git a/internal/api/handler/heartbeat.go b/internal/api/handler/heartbeat.go index 7a437b8d..7ddb5a9d 100644 --- a/internal/api/handler/heartbeat.go +++ b/internal/api/handler/heartbeat.go @@ -29,6 +29,14 @@ func (h *HeartbeatHandler) Register(api *echo.Group) { api.GET("/over-time", h.OverTime) } +func (h *HeartbeatHandler) RegisterCreate(api *echo.Group, middlewares ...echo.MiddlewareFunc) { + api.POST("", h.Create, middlewares...) +} + +func (h *HeartbeatHandler) RegisterOverTime(api *echo.Group, middlewares ...echo.MiddlewareFunc) { + api.GET("/over-time", h.OverTime, middlewares...) +} + type HeartbeatCreateRequest struct { UUID uuid.UUID `json:"uuid,omitempty" validate:"required"` CreatedAt time.Time `json:"created_at,omitempty" validate:"required"` diff --git a/internal/api/handler/heartbeat_integration_test.go b/internal/api/handler/heartbeat_integration_test.go index 86c2b638..5dbf962a 100644 --- a/internal/api/handler/heartbeat_integration_test.go +++ b/internal/api/handler/heartbeat_integration_test.go @@ -31,17 +31,23 @@ type HeartbeatApiIntegrationSuite struct { tests.IntegrationTestSuite } +func (suite *HeartbeatApiIntegrationSuite) setupServer() *api.Server { + logger, _ := zap.NewDevelopment() + metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) + server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) + services := &APIServices{} + RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + return server +} + func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreateValidation() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = false // Create two catalogs with the same group ID structure heartbeat := HeartbeatCreateRequest{} - logger, _ := zap.NewDevelopment() - metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) - server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) - services := &APIServices{} - RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + server := suite.setupServer() rec := httptest.NewRecorder() reqBody, _ := json.Marshal(heartbeat) req := httptest.NewRequest(http.MethodPost, "/api/agent/heartbeat", bytes.NewReader(reqBody)) @@ -53,6 +59,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreateValidation() { func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreate() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = false // Create two catalogs with the same group ID structure heartbeat := HeartbeatCreateRequest{ @@ -60,11 +67,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreate() { CreatedAt: time.Now(), } - logger, _ := zap.NewDevelopment() - metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) - server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) - services := &APIServices{} - RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + server := suite.setupServer() rec := httptest.NewRecorder() reqBody, _ := json.Marshal(heartbeat) req := httptest.NewRequest(http.MethodPost, "/api/agent/heartbeat", bytes.NewReader(reqBody)) @@ -81,6 +84,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreate() { func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatOverTime() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true // Seed some heartbeats for range 3 { @@ -94,14 +98,13 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatOverTime() { } // Create two catalogs with the same group ID structure - logger, _ := zap.NewDevelopment() - metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) - server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) - services := &APIServices{} - RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + server := suite.setupServer() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/agent/heartbeat/over-time/", nil) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + token, err := suite.GetAuthToken() + suite.Require().NoError(err) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) server.E().ServeHTTP(rec, req) assert.Equal(suite.T(), http.StatusOK, rec.Code) @@ -125,3 +128,17 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatOverTime() { suite.Equal(response.Data[0].Interval.Sub(response.Data[1].Interval).Abs(), 2*time.Minute) suite.Equal(response.Data[1].Interval.Sub(response.Data[2].Interval).Abs(), 2*time.Minute) } + +func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreateRequiresAgentAuthWhenUnsafeDisabled() { + err := suite.Migrator.Refresh() + suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + + server := suite.setupServer() + rec := httptest.NewRecorder() + reqBody, _ := json.Marshal(HeartbeatCreateRequest{UUID: uuid.New(), CreatedAt: time.Now()}) + req := httptest.NewRequest(http.MethodPost, "/api/agent/heartbeat", bytes.NewReader(reqBody)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + server.E().ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusUnauthorized, rec.Code) +} diff --git a/internal/api/handler/templates/risk_template.go b/internal/api/handler/templates/risk_template.go index ceab988c..a32a34f0 100644 --- a/internal/api/handler/templates/risk_template.go +++ b/internal/api/handler/templates/risk_template.go @@ -38,8 +38,8 @@ func (h *RiskTemplateHandler) Register(apiGroup *echo.Group) { apiGroup.DELETE("/:id", h.Delete) } -func (h *RiskTemplateHandler) RegisterAgent(apiGroup *echo.Group) { - apiGroup.POST("/batch", h.BatchUpsert) +func (h *RiskTemplateHandler) RegisterAgent(apiGroup *echo.Group, middlewares ...echo.MiddlewareFunc) { + apiGroup.POST("/batch", h.BatchUpsert, middlewares...) } type threatIDRequest struct { diff --git a/internal/api/handler/templates/risk_template_integration_test.go b/internal/api/handler/templates/risk_template_integration_test.go index 4b3591d3..b680ef77 100644 --- a/internal/api/handler/templates/risk_template_integration_test.go +++ b/internal/api/handler/templates/risk_template_integration_test.go @@ -69,6 +69,11 @@ func (suite *RiskTemplateApiIntegrationSuite) SetupTest() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + suite.setupServer() +} + +func (suite *RiskTemplateApiIntegrationSuite) setupServer() { logger, _ := zap.NewDevelopment() metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) suite.server = api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) @@ -106,6 +111,27 @@ func (suite *RiskTemplateApiIntegrationSuite) unauthenticatedRequest(method, pat return rec, req } +func (suite *RiskTemplateApiIntegrationSuite) agentRequest(method, path string, body any) (*httptest.ResponseRecorder, *http.Request) { + agent, err := suite.CreateAgent("risk-template-agent") + suite.Require().NoError(err) + key, _, err := suite.CreateAgentKey(agent, "risk-template-key") + suite.Require().NoError(err) + token, err := suite.GetAgentToken(agent, key) + suite.Require().NoError(err) + + payload := []byte{} + if body != nil { + data, marshalErr := json.Marshal(body) + suite.Require().NoError(marshalErr) + payload = data + } + rec := httptest.NewRecorder() + req := httptest.NewRequest(method, path, bytes.NewReader(payload)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) + return rec, req +} + func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateCRUD() { createReq := map[string]any{ "plugin-id": "github-repositories", @@ -646,7 +672,7 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertCreateA }, } - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/risk-templates/batch", batchReq) + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/risk-templates/batch", batchReq) suite.server.E().ServeHTTP(rec, req) require.Equal(suite.T(), http.StatusOK, rec.Code) @@ -699,7 +725,7 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertCreateA }, } - rec2, req2 := suite.authedRequest(http.MethodPost, "/api/agent/risk-templates/batch", batchReq2) + rec2, req2 := suite.agentRequest(http.MethodPost, "/api/agent/risk-templates/batch", batchReq2) suite.server.E().ServeHTTP(rec2, req2) require.Equal(suite.T(), http.StatusOK, rec2.Code) @@ -736,7 +762,7 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertEmptyPa } // Send empty template list — both should be deleted. - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ "plugin-id": "batch-delete-plugin", "policy-package": "compliance_framework.delete_test", "templates": []map[string]any{}, @@ -759,7 +785,7 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertEmptyPa } func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertMissingIDReturns400() { - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "policy-package": "compliance_framework.batch_test", "templates": []map[string]any{ @@ -776,7 +802,7 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertMissing } func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertValidationError() { - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "policy-package": "compliance_framework.batch_test", "templates": []map[string]any{ @@ -792,7 +818,10 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertValidat require.Equal(suite.T(), http.StatusBadRequest, rec.Code) } -func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertIsPublic() { +func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertIsPublicWhenUnsafeFlagEnabled() { + suite.Config.StrictDisablePublicAgentEndpoints = false + suite.setupServer() + rec, req := suite.unauthenticatedRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "policy-package": "compliance_framework.batch_test", @@ -801,3 +830,13 @@ func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertIsPubli suite.server.E().ServeHTTP(rec, req) require.Equal(suite.T(), http.StatusOK, rec.Code) } + +func (suite *RiskTemplateApiIntegrationSuite) TestRiskTemplateBatchUpsertRequiresAgentAuthWhenUnsafeDisabled() { + rec, req := suite.unauthenticatedRequest(http.MethodPost, "/api/agent/risk-templates/batch", map[string]any{ + "plugin-id": "batch-plugin", + "policy-package": "compliance_framework.batch_test", + "templates": []map[string]any{}, + }) + suite.server.E().ServeHTTP(rec, req) + require.Equal(suite.T(), http.StatusUnauthorized, rec.Code) +} diff --git a/internal/api/handler/templates/subject_template.go b/internal/api/handler/templates/subject_template.go index e7ae01fd..844c5be4 100644 --- a/internal/api/handler/templates/subject_template.go +++ b/internal/api/handler/templates/subject_template.go @@ -36,8 +36,8 @@ func (h *SubjectTemplateHandler) Register(apiGroup *echo.Group) { apiGroup.PUT("/:id", h.Update) } -func (h *SubjectTemplateHandler) RegisterAgent(apiGroup *echo.Group) { - apiGroup.POST("/batch", h.BatchUpsert) +func (h *SubjectTemplateHandler) RegisterAgent(apiGroup *echo.Group, middlewares ...echo.MiddlewareFunc) { + apiGroup.POST("/batch", h.BatchUpsert, middlewares...) } type subjectTemplateSelectorLabelRequest struct { diff --git a/internal/api/handler/templates/subject_template_integration_test.go b/internal/api/handler/templates/subject_template_integration_test.go index 807bcae8..65d501dd 100644 --- a/internal/api/handler/templates/subject_template_integration_test.go +++ b/internal/api/handler/templates/subject_template_integration_test.go @@ -61,6 +61,11 @@ func (suite *SubjectTemplateApiIntegrationSuite) SetupTest() { err := suite.Migrator.Refresh() suite.Require().NoError(err) + suite.Config.StrictDisablePublicAgentEndpoints = true + suite.setupServer() +} + +func (suite *SubjectTemplateApiIntegrationSuite) setupServer() { logger, _ := zap.NewDevelopment() metrics := api.NewMetricsHandler(context.Background(), logger.Sugar()) suite.server = api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics) @@ -100,6 +105,28 @@ func (suite *SubjectTemplateApiIntegrationSuite) unauthenticatedRequest(method, return rec, req } +func (suite *SubjectTemplateApiIntegrationSuite) agentRequest(method, path string, body any) (*httptest.ResponseRecorder, *http.Request) { + agent, err := suite.CreateAgent("subject-template-agent") + suite.Require().NoError(err) + key, _, err := suite.CreateAgentKey(agent, "subject-template-key") + suite.Require().NoError(err) + token, err := suite.GetAgentToken(agent, key) + suite.Require().NoError(err) + + payload := []byte{} + if body != nil { + data, marshalErr := json.Marshal(body) + suite.Require().NoError(marshalErr) + payload = data + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(method, path, bytes.NewReader(payload)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", *token)) + return rec, req +} + func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateCRUD() { createRec, createCall := suite.authedRequest(http.MethodPost, "/api/admin/subject-templates", map[string]any{ "name": "Runtime component identity", @@ -265,7 +292,7 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertC }, } - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/subject-templates/batch", batchReq) + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/subject-templates/batch", batchReq) suite.server.E().ServeHTTP(rec, req) require.Equal(suite.T(), http.StatusOK, rec.Code) @@ -323,7 +350,7 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertC }, } - rec2, req2 := suite.authedRequest(http.MethodPost, "/api/agent/subject-templates/batch", batchReq2) + rec2, req2 := suite.agentRequest(http.MethodPost, "/api/agent/subject-templates/batch", batchReq2) suite.server.E().ServeHTTP(rec2, req2) require.Equal(suite.T(), http.StatusOK, rec2.Code) @@ -370,7 +397,7 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertE } // Send empty template list — both should be deleted. - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ "plugin-id": "delete-plugin", "templates": []map[string]any{}, }) @@ -385,7 +412,7 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertE } func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertMissingIDReturns400() { - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "templates": []map[string]any{ { @@ -404,7 +431,7 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertM } func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertValidationError() { - rec, req := suite.authedRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ + rec, req := suite.agentRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "templates": []map[string]any{ { @@ -422,7 +449,10 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertV require.Equal(suite.T(), http.StatusBadRequest, rec.Code) } -func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertIsPublic() { +func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertIsPublicWhenUnsafeFlagEnabled() { + suite.Config.StrictDisablePublicAgentEndpoints = false + suite.setupServer() + rec, req := suite.unauthenticatedRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ "plugin-id": "batch-plugin", "templates": []map[string]any{}, @@ -430,3 +460,12 @@ func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertI suite.server.E().ServeHTTP(rec, req) require.Equal(suite.T(), http.StatusOK, rec.Code) } + +func (suite *SubjectTemplateApiIntegrationSuite) TestSubjectTemplateBatchUpsertRequiresAgentAuthWhenUnsafeDisabled() { + rec, req := suite.unauthenticatedRequest(http.MethodPost, "/api/agent/subject-templates/batch", map[string]any{ + "plugin-id": "batch-plugin", + "templates": []map[string]any{}, + }) + suite.server.E().ServeHTTP(rec, req) + require.Equal(suite.T(), http.StatusUnauthorized, rec.Code) +} diff --git a/internal/api/middleware/agent_auth.go b/internal/api/middleware/agent_auth.go deleted file mode 100644 index 047cc5c1..00000000 --- a/internal/api/middleware/agent_auth.go +++ /dev/null @@ -1,17 +0,0 @@ -package middleware - -import ( - "crypto/rsa" - - "github.com/labstack/echo/v4" -) - -// AgentJWTMiddleware returns an Echo middleware function that verifies JWT tokens using the provided RSA public key. -// TODO[gusfcarvalho]: this method is a simple noop for now -func AgentJWTMiddleware(publicKey *rsa.PublicKey) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - return next(c) - } - } -} diff --git a/internal/api/middleware/agent_ingest.go b/internal/api/middleware/agent_ingest.go new file mode 100644 index 00000000..b9beb6af --- /dev/null +++ b/internal/api/middleware/agent_ingest.go @@ -0,0 +1,96 @@ +package middleware + +import ( + "crypto/rsa" + "errors" + "net/http" + "time" + + "github.com/compliance-framework/api/internal/authn" + "github.com/compliance-framework/api/internal/service/relational" + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +type AgentAuthContext struct { + Claims *authn.AgentClaims + Agent *relational.Agent + Key *relational.AgentServiceAccountKey +} + +func AgentJWTMiddleware(db *gorm.DB, publicKey *rsa.PublicKey) echo.MiddlewareFunc { + return AgentJWTOrPublicMiddleware(db, publicKey, false) +} + +func AgentJWTOrPublicMiddleware(db *gorm.DB, publicKey *rsa.PublicKey, allowPublic bool) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if c.Request().Method == http.MethodOptions { + return next(c) + } + + authHeader := c.Request().Header.Get(echo.HeaderAuthorization) + if authHeader == "" { + if allowPublic { + return next(c) + } + return echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed authorization header") + } + + tokenString, err := getTokenFromHeader(authHeader) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, err) + } + + claims, agent, key, err := verifyAgentRequest(db, tokenString, publicKey, c) + if err != nil { + return err + } + + c.Set("agent_claims", claims) + c.Set("agent_auth", &AgentAuthContext{ + Claims: claims, + Agent: agent, + Key: key, + }) + return next(c) + } + } +} + +func verifyAgentRequest(db *gorm.DB, tokenString string, publicKey *rsa.PublicKey, c echo.Context) (*authn.AgentClaims, *relational.Agent, *relational.AgentServiceAccountKey, error) { + claims, err := authn.VerifyAgentJWTToken(tokenString, publicKey) + if err != nil { + return nil, nil, nil, echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired token") + } + + var agent relational.Agent + if err := db.Where("id = ?", claims.AgentID).First(&agent).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil, nil, echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired token") + } + c.Logger().Errorf("failed to load agent for authenticated agent request: %v (agent_id=%v)", err, claims.AgentID) + return nil, nil, nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to load agent") + } + if !agent.IsActive { + return nil, nil, nil, echo.NewHTTPError(http.StatusForbidden, "agent is inactive") + } + + var key relational.AgentServiceAccountKey + if err := db.Where("agent_id = ? AND id = ?", *agent.ID, claims.CredentialID).First(&key).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil, nil, echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired token") + } + c.Logger().Errorf("failed to load agent key for authenticated agent request: %v (agent_id=%v credential_id=%v)", err, agent.ID, claims.CredentialID) + return nil, nil, nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to load agent key") + } + now := time.Now().UTC() + if key.IsRevoked(now) { + return nil, nil, nil, echo.NewHTTPError(http.StatusForbidden, "agent key is revoked") + } + if key.IsExpired(now) { + return nil, nil, nil, echo.NewHTTPError(http.StatusForbidden, "agent key is expired") + } + + return claims, &agent, &key, nil +} diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index 3551233b..e8e4a793 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -45,7 +45,7 @@ func JWTMiddleware(publicKey *rsa.PublicKey) echo.MiddlewareFunc { func getTokenFromHeader(authHeader string) (string, error) { parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { return "", errors.New("missing or malformed authorization header") } return parts[1], nil diff --git a/internal/api/middleware/auth_test.go b/internal/api/middleware/auth_test.go new file mode 100644 index 00000000..d22ef854 --- /dev/null +++ b/internal/api/middleware/auth_test.go @@ -0,0 +1,13 @@ +package middleware + +import "testing" + +func TestGetTokenFromHeaderAcceptsCaseInsensitiveBearer(t *testing.T) { + token, err := getTokenFromHeader("bearer token-value") + if err != nil { + t.Fatalf("expected lowercase bearer scheme to be accepted, got %v", err) + } + if token != "token-value" { + t.Fatalf("expected token %q, got %q", "token-value", token) + } +} diff --git a/internal/authn/jwt.go b/internal/authn/jwt.go index 8ecbe82f..de1185ed 100644 --- a/internal/authn/jwt.go +++ b/internal/authn/jwt.go @@ -2,14 +2,21 @@ package authn import ( "crypto/rsa" + "errors" "time" "github.com/compliance-framework/api/internal/service/relational" "github.com/golang-jwt/jwt/v5" ) +const ( + TokenKindUser = "user" + TokenKindAgent = "agent" +) + type UserClaims struct { jwt.RegisteredClaims + TokenKind string `json:"token_kind"` GivenName string `json:"given_name"` FamilyName string `json:"family_name"` } @@ -19,6 +26,14 @@ type PasswordResetClaims struct { Email string `json:"email"` } +type AgentClaims struct { + jwt.RegisteredClaims + TokenKind string `json:"token_kind"` + AgentID string `json:"agent_id"` + CredentialID string `json:"credential_id"` + AuthMethod string `json:"auth_method"` +} + func GenerateJWTToken(user *relational.User, privateKey *rsa.PrivateKey) (*string, error) { now := time.Now() claims := UserClaims{ @@ -29,6 +44,7 @@ func GenerateJWTToken(user *relational.User, privateKey *rsa.PrivateKey) (*strin ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)), NotBefore: jwt.NewNumericDate(now), }, + TokenKind: TokenKindUser, GivenName: user.FirstName, FamilyName: user.LastName, } @@ -40,6 +56,29 @@ func GenerateJWTToken(user *relational.User, privateKey *rsa.PrivateKey) (*strin return &tokenString, nil } +func GenerateAgentJWTToken(agent *relational.Agent, key *relational.AgentServiceAccountKey, privateKey *rsa.PrivateKey) (*string, error) { + now := time.Now() + claims := AgentClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "compliance-framework", + Subject: key.ClientID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)), + NotBefore: jwt.NewNumericDate(now), + }, + TokenKind: TokenKindAgent, + AgentID: agent.ID.String(), + CredentialID: key.ID.String(), + AuthMethod: relational.AgentAuthMethodServiceAccount, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(privateKey) + if err != nil { + return nil, err + } + return &tokenString, nil +} + func GeneratePasswordResetToken(email string, privateKey *rsa.PrivateKey) (*string, error) { now := time.Now() claims := PasswordResetClaims{ @@ -71,6 +110,28 @@ func VerifyJWTToken(tokenString string, publicKey *rsa.PublicKey) (*UserClaims, return nil, err } if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + if claims.TokenKind != "" && claims.TokenKind != TokenKindUser { + return nil, errors.New("unexpected token kind") + } + return claims, nil + } + return nil, jwt.ErrTokenMalformed +} + +func VerifyAgentJWTToken(tokenString string, publicKey *rsa.PublicKey) (*AgentClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &AgentClaims{}, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, jwt.ErrSignatureInvalid + } + return publicKey, nil + }) + if err != nil { + return nil, err + } + if claims, ok := token.Claims.(*AgentClaims); ok && token.Valid { + if claims.TokenKind != TokenKindAgent { + return nil, errors.New("unexpected token kind") + } return claims, nil } return nil, jwt.ErrTokenMalformed diff --git a/internal/authn/jwt_test.go b/internal/authn/jwt_test.go new file mode 100644 index 00000000..ef3e97b0 --- /dev/null +++ b/internal/authn/jwt_test.go @@ -0,0 +1,47 @@ +package authn + +import ( + "testing" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/relational" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestUserAndAgentTokensAreSeparated(t *testing.T) { + privateKey, publicKey, err := config.GenerateKeyPair(2048) + require.NoError(t, err) + + user := &relational.User{ + Email: "dummy@example.com", + FirstName: "Dummy", + LastName: "User", + } + userToken, err := GenerateJWTToken(user, privateKey) + require.NoError(t, err) + + agentID := uuid.New() + keyID := uuid.New() + agent := &relational.Agent{UUIDModel: relational.UUIDModel{ID: &agentID}} + key := &relational.AgentServiceAccountKey{ + UUIDModel: relational.UUIDModel{ID: &keyID}, + ClientID: "client-id", + } + agentToken, err := GenerateAgentJWTToken(agent, key, privateKey) + require.NoError(t, err) + + userClaims, err := VerifyJWTToken(*userToken, publicKey) + require.NoError(t, err) + require.Equal(t, TokenKindUser, userClaims.TokenKind) + + _, err = VerifyJWTToken(*agentToken, publicKey) + require.Error(t, err) + + agentClaims, err := VerifyAgentJWTToken(*agentToken, publicKey) + require.NoError(t, err) + require.Equal(t, TokenKindAgent, agentClaims.TokenKind) + + _, err = VerifyAgentJWTToken(*userToken, publicKey) + require.Error(t, err) +} diff --git a/internal/config/config.go b/internal/config/config.go index 9de85287..8f3f09bb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,28 +19,29 @@ var ( ) type Config struct { - AppPort string - Environment string - DBDriver string - DBConnectionString string - DBDebug bool - JWTSecret string - JWTPrivateKey *rsa.PrivateKey - JWTPublicKey *rsa.PublicKey - APIAllowedOrigins []string - MetricsEnabled bool - MetricsPort string - WebBaseURL string - SSO *SSOConfig - Email *EmailConfig - Worker *WorkerConfig - EvidenceDefaultExpiryMonths int // Default expiration in months for evidence without explicit expiry - DigestEnabled bool // Enable or disable the digest scheduler - DigestSchedule string // Cron schedule for digest emails - Workflow *WorkflowConfig - Risk *RiskConfig - PprofEnabled bool // Enable or disable pprof debugging server - PprofPort string // Port for pprof debugging server + AppPort string + Environment string + DBDriver string + DBConnectionString string + DBDebug bool + JWTSecret string + JWTPrivateKey *rsa.PrivateKey + JWTPublicKey *rsa.PublicKey + APIAllowedOrigins []string + MetricsEnabled bool + MetricsPort string + WebBaseURL string + SSO *SSOConfig + Email *EmailConfig + Worker *WorkerConfig + EvidenceDefaultExpiryMonths int // Default expiration in months for evidence without explicit expiry + DigestEnabled bool // Enable or disable the digest scheduler + DigestSchedule string // Cron schedule for digest emails + Workflow *WorkflowConfig + Risk *RiskConfig + PprofEnabled bool // Enable or disable pprof debugging server + PprofPort string // Port for pprof debugging server + StrictDisablePublicAgentEndpoints bool } func NewConfig(logger *zap.SugaredLogger) *Config { @@ -211,28 +212,29 @@ func NewConfig(logger *zap.SugaredLogger) *Config { } return &Config{ - AppPort: appPort, - Environment: environment, - DBDriver: dbDriver, - DBConnectionString: stripQuotes(viper.GetString("db_connection")), - DBDebug: viper.GetBool("db_debug"), - JWTSecret: stripQuotes(viper.GetString("jwt_secret")), - JWTPrivateKey: jwtPrivateKey, - JWTPublicKey: jwtPublicKey, - APIAllowedOrigins: allowedOrigins, - MetricsEnabled: metricsEnabled, - MetricsPort: metricsPort, - WebBaseURL: webBaseURL, - SSO: ssoConfig, - Email: emailConfig, - Worker: workerConfig, - EvidenceDefaultExpiryMonths: evidenceDefaultExpiryMonths, - DigestEnabled: digestEnabled, - DigestSchedule: digestSchedule, - Workflow: workflowConfig, - Risk: riskConfig, - PprofEnabled: pprofEnabled, - PprofPort: pprofPort, + AppPort: appPort, + Environment: environment, + DBDriver: dbDriver, + DBConnectionString: stripQuotes(viper.GetString("db_connection")), + DBDebug: viper.GetBool("db_debug"), + JWTSecret: stripQuotes(viper.GetString("jwt_secret")), + JWTPrivateKey: jwtPrivateKey, + JWTPublicKey: jwtPublicKey, + APIAllowedOrigins: allowedOrigins, + MetricsEnabled: metricsEnabled, + MetricsPort: metricsPort, + WebBaseURL: webBaseURL, + SSO: ssoConfig, + Email: emailConfig, + Worker: workerConfig, + EvidenceDefaultExpiryMonths: evidenceDefaultExpiryMonths, + DigestEnabled: digestEnabled, + DigestSchedule: digestSchedule, + Workflow: workflowConfig, + Risk: riskConfig, + PprofEnabled: pprofEnabled, + PprofPort: pprofPort, + StrictDisablePublicAgentEndpoints: viper.GetBool("strict_disable_public_agent_endpoints"), } } diff --git a/internal/service/migrator.go b/internal/service/migrator.go index 1d006a36..ec527306 100644 --- a/internal/service/migrator.go +++ b/internal/service/migrator.go @@ -142,6 +142,9 @@ func MigrateUp(db *gorm.DB) error { &poamrel.PoamItemControlLink{}, &poamrel.PoamItemFindingLink{}, &relational.User{}, + &relational.Agent{}, + &relational.AgentServiceAccountKey{}, + &relational.AgentAuthEvent{}, &Heartbeat{}, &relational.Evidence{}, &relational.Labels{}, @@ -354,6 +357,9 @@ func MigrateDown(db *gorm.DB) error { &poamrel.PoamItemMilestone{}, &poamrel.PoamItem{}, + &relational.AgentAuthEvent{}, + &relational.AgentServiceAccountKey{}, + &relational.Agent{}, &relational.User{}, &Heartbeat{}, diff --git a/internal/service/relational/agents.go b/internal/service/relational/agents.go new file mode 100644 index 00000000..d3ac16eb --- /dev/null +++ b/internal/service/relational/agents.go @@ -0,0 +1,125 @@ +package relational + +import ( + "errors" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +const ( + AgentAuthMethodServiceAccount = "service_account" + AgentAuthEventOutcomeSuccess = "success" + AgentAuthEventOutcomeFailure = "failure" +) + +var ErrAgentSecretRequired = errors.New("agent secret is required") + +type Agent struct { + UUIDModel + + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + DeletedAt gorm.DeletedAt `json:"deletedAt" gorm:"index"` + + Name string `json:"name" gorm:"not null"` + Description *string `json:"description,omitempty"` + IsActive bool `json:"isActive" gorm:"default:true"` + LastAuthenticatedAt *time.Time `json:"lastAuthenticatedAt,omitempty"` + + ServiceAccountKeys []AgentServiceAccountKey `json:"serviceAccountKeys,omitempty"` +} + +func (Agent) TableName() string { + return "ccf_agents" +} + +type AgentServiceAccountKey struct { + UUIDModel + + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + DeletedAt gorm.DeletedAt `json:"deletedAt" gorm:"index"` + + AgentID *uuid.UUID `json:"agentId" gorm:"type:uuid;not null;index"` + Agent Agent `json:"-" gorm:"foreignKey:AgentID;references:ID"` + + Name *string `json:"name,omitempty"` + ClientID string `json:"clientId" gorm:"uniqueIndex;not null"` + SecretHash string `json:"-"` + LastUsedAt *time.Time `json:"lastUsedAt,omitempty"` + ExpiresAt *time.Time `json:"expiresAt,omitempty"` + RevokedAt *time.Time `json:"revokedAt,omitempty"` +} + +func (AgentServiceAccountKey) TableName() string { + return "ccf_agent_service_account_keys" +} + +func (k *AgentServiceAccountKey) SetSecret(secret string) error { + if strings.TrimSpace(secret) == "" { + return ErrAgentSecretRequired + } + + hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + if err != nil { + return err + } + + k.SecretHash = string(hash) + return nil +} + +func (k *AgentServiceAccountKey) CheckSecret(secret string) bool { + if strings.TrimSpace(k.SecretHash) == "" || strings.TrimSpace(secret) == "" { + return false + } + + return bcrypt.CompareHashAndPassword([]byte(k.SecretHash), []byte(secret)) == nil +} + +func (k *AgentServiceAccountKey) IsExpired(at time.Time) bool { + if k.ExpiresAt == nil { + return false + } + + return !k.ExpiresAt.After(at.UTC()) +} + +func (k *AgentServiceAccountKey) IsRevoked(at time.Time) bool { + if k.RevokedAt == nil { + return false + } + + return !k.RevokedAt.After(at.UTC()) +} + +type AgentAuthEvent struct { + UUIDModel + + CreatedAt time.Time `json:"createdAt"` + + AgentID *uuid.UUID `json:"agentId,omitempty" gorm:"type:uuid;index"` + CredentialID *uuid.UUID `json:"credentialId,omitempty" gorm:"type:uuid;index"` + AuthMethod string `json:"authMethod" gorm:"type:varchar(64);not null;index"` + Outcome string `json:"outcome" gorm:"type:varchar(32);not null;index"` + Principal *string `json:"principal,omitempty"` + Reason *string `json:"reason,omitempty"` + RemoteAddr *string `json:"remoteAddr,omitempty"` + UserAgent *string `json:"userAgent,omitempty"` +} + +func (AgentAuthEvent) TableName() string { + return "ccf_agent_auth_events" +} + +func (e *AgentAuthEvent) BeforeUpdate(_ *gorm.DB) error { + return errors.New("agent auth events are append-only") +} + +func (e *AgentAuthEvent) BeforeDelete(_ *gorm.DB) error { + return errors.New("agent auth events are append-only") +} diff --git a/internal/tests/integration.go b/internal/tests/integration.go index 6dec83ec..d142f252 100644 --- a/internal/tests/integration.go +++ b/internal/tests/integration.go @@ -4,6 +4,8 @@ package tests import ( "context" + "fmt" + "time" "github.com/compliance-framework/api/internal/authn" "github.com/compliance-framework/api/internal/config" @@ -91,3 +93,31 @@ func (suite *IntegrationTestSuite) GetAuthToken() (*string, error) { return authn.GenerateJWTToken(&dummyUser, suite.Config.JWTPrivateKey) } + +func (suite *IntegrationTestSuite) CreateAgent(name string) (*relational.Agent, error) { + agent := &relational.Agent{ + Name: name, + IsActive: true, + } + return agent, suite.DB.Create(agent).Error +} + +func (suite *IntegrationTestSuite) CreateAgentKey(agent *relational.Agent, name string) (*relational.AgentServiceAccountKey, string, error) { + secret := fmt.Sprintf("secret-%d", time.Now().UnixNano()) + key := &relational.AgentServiceAccountKey{ + AgentID: agent.ID, + Name: &name, + ClientID: fmt.Sprintf("client-%d", time.Now().UnixNano()), + } + if err := key.SetSecret(secret); err != nil { + return nil, "", err + } + if err := suite.DB.Create(key).Error; err != nil { + return nil, "", err + } + return key, secret, nil +} + +func (suite *IntegrationTestSuite) GetAgentToken(agent *relational.Agent, key *relational.AgentServiceAccountKey) (*string, error) { + return authn.GenerateAgentJWTToken(agent, key, suite.Config.JWTPrivateKey) +} diff --git a/internal/tests/migrate.go b/internal/tests/migrate.go index 7dbad2b1..aadbde5a 100644 --- a/internal/tests/migrate.go +++ b/internal/tests/migrate.go @@ -162,6 +162,9 @@ func (t *TestMigrator) Up() error { &relational.AssessmentLog{}, &relational.AssessmentLogEntry{}, &relational.Attestation{}, + &relational.Agent{}, + &relational.AgentServiceAccountKey{}, + &relational.AgentAuthEvent{}, &relational.User{}, &service.Heartbeat{}, @@ -357,6 +360,9 @@ func (t *TestMigrator) Down() error { "poam_findings", "poam_risks", + &relational.AgentAuthEvent{}, + &relational.AgentServiceAccountKey{}, + &relational.Agent{}, &relational.User{}, &service.Heartbeat{}, diff --git a/sdk/client.go b/sdk/client.go index 6cd81637..8816eda9 100644 --- a/sdk/client.go +++ b/sdk/client.go @@ -1,18 +1,31 @@ package sdk import ( + "bytes" "context" + "encoding/json" "fmt" "io" "net/http" "strings" + "sync" + "time" "go.uber.org/zap" ) +const agentTokenExpirySkew = time.Minute + +type AgentAuthConfig struct { + ClientID string + ClientSecret string +} + type Config struct { BaseURL string Logger *zap.SugaredLogger + + AgentAuth *AgentAuthConfig } type Client struct { @@ -20,39 +33,314 @@ type Client struct { config *Config + tokenMu sync.Mutex + tokenRefreshCh chan struct{} + cachedAccessToken string + cachedTokenType string + cachedTokenExpiresAt time.Time + Evidence *evidenceClient RiskTemplate *riskTemplateClient SubjectTemplate *subjectTemplateClient + + Heartbeat *heartbeatClient } func NewClient(client *http.Client, config *Config) *Client { - return &Client{ + if client == nil { + client = http.DefaultClient + } + if config == nil { + config = &Config{} + } + + c := &Client{ httpClient: client, config: config, - Evidence: &evidenceClient{ - httpClient: client, - config: config, - }, - RiskTemplate: &riskTemplateClient{ - httpClient: client, - config: config, - }, - SubjectTemplate: &subjectTemplateClient{ - httpClient: client, - config: config, - }, } + + c.Evidence = &evidenceClient{client: c} + c.RiskTemplate = &riskTemplateClient{client: c} + c.SubjectTemplate = &subjectTemplateClient{client: c} + c.Heartbeat = &heartbeatClient{client: c} + + return c } func (c *Client) NewRequest(ctx context.Context, method string, path string, reader io.Reader) (*http.Response, error) { + if !c.hasAgentAuth() { + return c.executeStreamingRequest(ctx, method, path, reader, "") + } + + return c.doStreamingRequest(ctx, method, path, reader) +} + +func (c *Client) doJSONRequest(ctx context.Context, method string, path string, payload any) (*http.Response, error) { + var body []byte + if payload != nil { + var err error + body, err = json.Marshal(payload) + if err != nil { + return nil, err + } + } + + return c.doRequest(ctx, method, path, body) +} + +func (c *Client) doRequest(ctx context.Context, method string, path string, body []byte) (*http.Response, error) { + if !c.hasAgentAuth() { + return c.executeRequest(ctx, method, path, body, "") + } + + tokenType, accessToken, err := c.getAgentAccessToken(ctx) + if err != nil { + return nil, err + } + + resp, err := c.executeRequest(ctx, method, path, body, formatAuthorizationHeader(tokenType, accessToken)) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + + closeResponseBody(resp, c.config.Logger) + c.invalidateAgentAccessToken() + + tokenType, accessToken, err = c.getAgentAccessToken(ctx) + if err != nil { + return nil, err + } + + return c.executeRequest(ctx, method, path, body, formatAuthorizationHeader(tokenType, accessToken)) +} + +func (c *Client) doStreamingRequest(ctx context.Context, method string, path string, reader io.Reader) (*http.Response, error) { + bodyFactory, canRetry := makeReplayableRequestBody(reader) + + tokenType, accessToken, err := c.getAgentAccessToken(ctx) + if err != nil { + return nil, err + } + + reqBody, err := bodyFactory() + if err != nil { + return nil, err + } + resp, err := c.executeStreamingRequest(ctx, method, path, reqBody, formatAuthorizationHeader(tokenType, accessToken)) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized || !canRetry { + return resp, nil + } + + closeResponseBody(resp, c.config.Logger) + c.invalidateAgentAccessToken() + + tokenType, accessToken, err = c.getAgentAccessToken(ctx) + if err != nil { + return nil, err + } + + reqBody, err = bodyFactory() + if err != nil { + return nil, err + } + return c.executeStreamingRequest(ctx, method, path, reqBody, formatAuthorizationHeader(tokenType, accessToken)) +} + +func (c *Client) executeRequest(ctx context.Context, method string, path string, body []byte, authorization string) (*http.Response, error) { + return c.executeStreamingRequest(ctx, method, path, bytes.NewReader(body), authorization) +} + +func (c *Client) executeStreamingRequest(ctx context.Context, method string, path string, body io.Reader, authorization string) (*http.Response, error) { path = strings.TrimPrefix(path, "/") url := strings.TrimSuffix(c.config.BaseURL, "/") - req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", url, path), reader) + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s/%s", url, path), body) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") + if authorization != "" { + req.Header.Set("Authorization", authorization) + } return c.httpClient.Do(req) } + +func (c *Client) hasAgentAuth() bool { + return c.config != nil && + c.config.AgentAuth != nil && + strings.TrimSpace(c.config.AgentAuth.ClientID) != "" && + strings.TrimSpace(c.config.AgentAuth.ClientSecret) != "" +} + +func (c *Client) getAgentAccessToken(ctx context.Context) (string, string, error) { + for { + tokenType, accessToken, refreshCh, shouldFetch := c.getAgentAccessTokenState() + if accessToken != "" { + return tokenType, accessToken, nil + } + if shouldFetch { + return c.fetchAndCacheAgentAccessToken(ctx, refreshCh) + } + + select { + case <-refreshCh: + case <-ctx.Done(): + return "", "", ctx.Err() + } + } +} + +func (c *Client) invalidateAgentAccessToken() { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + + c.cachedAccessToken = "" + c.cachedTokenType = "" + c.cachedTokenExpiresAt = time.Time{} +} + +func (c *Client) getAgentAccessTokenState() (string, string, chan struct{}, bool) { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + + if c.hasFreshAgentAccessTokenLocked() { + return c.cachedTokenType, c.cachedAccessToken, nil, false + } + if c.tokenRefreshCh != nil { + return "", "", c.tokenRefreshCh, false + } + + c.tokenRefreshCh = make(chan struct{}) + return "", "", c.tokenRefreshCh, true +} + +func (c *Client) hasFreshAgentAccessTokenLocked() bool { + return c.cachedAccessToken != "" && time.Now().UTC().Add(agentTokenExpirySkew).Before(c.cachedTokenExpiresAt) +} + +func (c *Client) fetchAndCacheAgentAccessToken(ctx context.Context, refreshCh chan struct{}) (string, string, error) { + tokenType, accessToken, expiresAt, err := c.fetchAgentAccessToken(ctx) + + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + + if err == nil { + c.cachedAccessToken = accessToken + c.cachedTokenType = tokenType + c.cachedTokenExpiresAt = expiresAt + } + if c.tokenRefreshCh == refreshCh { + close(c.tokenRefreshCh) + c.tokenRefreshCh = nil + } + if err != nil { + return "", "", err + } + + return c.cachedTokenType, c.cachedAccessToken, nil +} + +func (c *Client) fetchAgentAccessToken(ctx context.Context) (string, string, time.Time, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/api/auth/agent/token", strings.TrimSuffix(c.config.BaseURL, "/")), nil) + if err != nil { + return "", "", time.Time{}, err + } + req.SetBasicAuth(strings.TrimSpace(c.config.AgentAuth.ClientID), strings.TrimSpace(c.config.AgentAuth.ClientSecret)) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", "", time.Time{}, err + } + defer closeResponseBody(resp, c.config.Logger) + + if resp.StatusCode != http.StatusOK { + return "", "", time.Time{}, fmt.Errorf("agent auth failed with status code: %d", resp.StatusCode) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", "", time.Time{}, err + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return "", "", time.Time{}, fmt.Errorf("agent auth response missing access_token") + } + + return tokenResp.TokenType, tokenResp.AccessToken, time.Now().UTC().Add(time.Duration(tokenResp.ExpiresIn) * time.Second), nil +} + +func formatAuthorizationHeader(tokenType, accessToken string) string { + normalizedTokenType := strings.TrimSpace(tokenType) + if normalizedTokenType == "" { + normalizedTokenType = "Bearer" + } else if strings.EqualFold(normalizedTokenType, "bearer") { + normalizedTokenType = "Bearer" + } + + return fmt.Sprintf("%s %s", normalizedTokenType, accessToken) +} + +func readRequestBody(reader io.Reader) ([]byte, error) { + if reader == nil { + return nil, nil + } + + return io.ReadAll(reader) +} + +func makeReplayableRequestBody(reader io.Reader) (func() (io.Reader, error), bool) { + if reader == nil { + return func() (io.Reader, error) { return nil, nil }, true + } + + switch r := reader.(type) { + case *bytes.Reader: + payload := make([]byte, r.Len()) + snapshot := *r + if _, err := snapshot.Read(payload); err != nil { + return func() (io.Reader, error) { return nil, err }, false + } + return func() (io.Reader, error) { return bytes.NewReader(payload), nil }, true + case *strings.Reader: + payload := make([]byte, r.Len()) + snapshot := *r + if _, err := snapshot.Read(payload); err != nil { + return func() (io.Reader, error) { return nil, err }, false + } + return func() (io.Reader, error) { return strings.NewReader(string(payload)), nil }, true + case *bytes.Buffer: + payload := append([]byte(nil), r.Bytes()...) + return func() (io.Reader, error) { return bytes.NewReader(payload), nil }, true + case io.ReadSeeker: + payload, err := readRequestBody(r) + if err != nil { + return func() (io.Reader, error) { return nil, err }, false + } + if _, err := r.Seek(0, io.SeekStart); err != nil { + return func() (io.Reader, error) { return nil, err }, false + } + return func() (io.Reader, error) { return bytes.NewReader(payload), nil }, true + default: + return func() (io.Reader, error) { return reader, nil }, false + } +} + +func closeResponseBody(resp *http.Response, logger *zap.SugaredLogger) { + if resp == nil || resp.Body == nil { + return + } + + if err := resp.Body.Close(); err != nil && logger != nil { + logger.Errorw("failed to close response body", "err", err) + } +} diff --git a/sdk/client_auth_test.go b/sdk/client_auth_test.go new file mode 100644 index 00000000..22f9d652 --- /dev/null +++ b/sdk/client_auth_test.go @@ -0,0 +1,621 @@ +package sdk + +import ( + "context" + "encoding/base64" + "errors" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/compliance-framework/api/sdk/types" + "github.com/google/uuid" +) + +func newAuthenticatedTestClient(handler roundTripFunc) *Client { + return NewClient(&http.Client{Transport: handler}, &Config{ + BaseURL: "http://example.test", + AgentAuth: &AgentAuthConfig{ + ClientID: "client-id", + ClientSecret: "client-secret", + }, + }) +} + +type trackingReader struct { + data []byte + reads int +} + +func (r *trackingReader) Read(p []byte) (int, error) { + if len(r.data) == 0 { + return 0, io.EOF + } + r.reads++ + n := copy(p, r.data) + r.data = r.data[n:] + return n, nil +} + +func TestClientAgentAuthUsesTokenEndpointWithBasicAuth(t *testing.T) { + var ( + mu sync.Mutex + tokenMethod string + tokenPath string + tokenAuthorization string + protectedAuthorization string + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + mu.Lock() + defer mu.Unlock() + + switch r.URL.Path { + case "/api/auth/agent/token": + tokenMethod = r.Method + tokenPath = r.URL.Path + tokenAuthorization = r.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token-1","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedAuthorization = r.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if tokenMethod != http.MethodPost { + t.Fatalf("expected token method %q, got %q", http.MethodPost, tokenMethod) + } + if tokenPath != "/api/auth/agent/token" { + t.Fatalf("expected token path %q, got %q", "/api/auth/agent/token", tokenPath) + } + + expectedBasic := "Basic " + base64.StdEncoding.EncodeToString([]byte("client-id:client-secret")) + if tokenAuthorization != expectedBasic { + t.Fatalf("expected basic auth %q, got %q", expectedBasic, tokenAuthorization) + } + if protectedAuthorization != "Bearer token-1" { + t.Fatalf("expected protected request auth %q, got %q", "Bearer token-1", protectedAuthorization) + } +} + +func TestClientNewRequestLeavesRequestUnauthenticatedWithoutAgentAuth(t *testing.T) { + var authHeader string + + client := NewClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + authHeader = r.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, &Config{BaseURL: "http://example.test"}) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if authHeader != "" { + t.Fatalf("expected no authorization header, got %q", authHeader) + } +} + +func TestClientNewRequestDoesNotPrebufferBodyWithoutAgentAuth(t *testing.T) { + reader := &trackingReader{data: []byte(`{"hello":"world"}`)} + readCountBeforeRoundTrip := -1 + + client := NewClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + readCountBeforeRoundTrip = reader.reads + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + if string(body) != `{"hello":"world"}` { + t.Fatalf("unexpected body %q", string(body)) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, &Config{BaseURL: "http://example.test"}) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", reader) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if readCountBeforeRoundTrip != 0 { + t.Fatalf("expected request body to remain unread before round trip, got %d reads", readCountBeforeRoundTrip) + } +} + +func TestClientNewRequestDoesNotPrebufferNonReplayableBodyWithAgentAuth(t *testing.T) { + reader := &trackingReader{data: []byte(`{"hello":"world"}`)} + readCountBeforeRoundTrip := -1 + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token-1","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + readCountBeforeRoundTrip = reader.reads + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + if string(body) != `{"hello":"world"}` { + t.Fatalf("unexpected body %q", string(body)) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", reader) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if readCountBeforeRoundTrip != 0 { + t.Fatalf("expected non-replayable request body to remain unread before round trip, got %d reads", readCountBeforeRoundTrip) + } +} + +func TestClientReusesCachedAgentTokenUntilNearExpiry(t *testing.T) { + var ( + tokenCalls int + protectedCalls int + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"cached-token","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + for range 2 { + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + } + + if tokenCalls != 1 { + t.Fatalf("expected one token request, got %d", tokenCalls) + } + if protectedCalls != 2 { + t.Fatalf("expected two protected requests, got %d", protectedCalls) + } +} + +func TestClientTreatsNearlyExpiredTokenAsStale(t *testing.T) { + var tokenCalls int + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"short-lived","token_type":"bearer","expires_in":30}`)), + Header: make(http.Header), + }, nil + case "/api/test": + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + for range 2 { + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + } + + if tokenCalls != 2 { + t.Fatalf("expected near-expiry token to be refreshed, got %d token requests", tokenCalls) + } +} + +func TestClientRetriesOnceOnProtectedCall401(t *testing.T) { + var ( + tokenCalls int + protectedAuthHeaders []string + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + token := "token-1" + if tokenCalls == 2 { + token = "token-2" + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"` + token + `","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedAuthHeaders = append(protectedAuthHeaders, r.Header.Get("Authorization")) + status := http.StatusUnauthorized + if len(protectedAuthHeaders) == 2 { + status = http.StatusOK + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if tokenCalls != 2 { + t.Fatalf("expected two token requests, got %d", tokenCalls) + } + if len(protectedAuthHeaders) != 2 { + t.Fatalf("expected two protected requests, got %d", len(protectedAuthHeaders)) + } + if protectedAuthHeaders[0] != "Bearer token-1" || protectedAuthHeaders[1] != "Bearer token-2" { + t.Fatalf("unexpected protected auth headers: %#v", protectedAuthHeaders) + } +} + +func TestClientDoesNotRetryOnProtectedCall403(t *testing.T) { + var ( + tokenCalls int + protectedCalls int + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token-1","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedCalls++ + return &http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if tokenCalls != 1 { + t.Fatalf("expected one token request, got %d", tokenCalls) + } + if protectedCalls != 1 { + t.Fatalf("expected one protected request, got %d", protectedCalls) + } +} + +func TestClientDoesNotLoopOnSecond401(t *testing.T) { + var ( + tokenCalls int + protectedCalls int + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedCalls++ + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if tokenCalls != 2 { + t.Fatalf("expected two token requests, got %d", tokenCalls) + } + if protectedCalls != 2 { + t.Fatalf("expected two protected requests, got %d", protectedCalls) + } +} + +func TestClientDoesNotRetry401ForNonReplayableRequestBody(t *testing.T) { + var ( + tokenCalls int + protectedCalls int + ) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + tokenCalls++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case "/api/test": + protectedCalls++ + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + resp, err := client.NewRequest(context.Background(), http.MethodPost, "/api/test", &trackingReader{data: []byte(`{}`)}) + if err != nil { + t.Fatalf("new request: %v", err) + } + closeResponseBody(resp, nil) + + if tokenCalls != 1 { + t.Fatalf("expected one token request, got %d", tokenCalls) + } + if protectedCalls != 1 { + t.Fatalf("expected one protected request, got %d", protectedCalls) + } +} + +func TestClientGetAgentAccessTokenWaitersRespectContextDuringRefresh(t *testing.T) { + var tokenCalls atomic.Int32 + fetchStarted := make(chan struct{}) + releaseFetch := make(chan struct{}) + firstFetchDone := make(chan error, 1) + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + if tokenCalls.Add(1) == 1 { + close(fetchStarted) + } + <-releaseFetch + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token-1","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + go func() { + _, _, err := client.getAgentAccessToken(context.Background()) + firstFetchDone <- err + }() + + select { + case <-fetchStarted: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for token fetch to start") + } + + waitCtx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + + _, _, err := client.getAgentAccessToken(waitCtx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected waiter context deadline exceeded, got %v", err) + } + if tokenCalls.Load() != 1 { + t.Fatalf("expected one in-flight token request, got %d", tokenCalls.Load()) + } + + close(releaseFetch) + + select { + case err := <-firstFetchDone: + if err != nil { + t.Fatalf("expected first token fetch to succeed, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first token fetch to finish") + } +} + +func TestAuthenticatedSDKMethodsAttachBearerToken(t *testing.T) { + type call struct { + name string + run func(context.Context, *Client) error + path string + } + + cases := []call{ + { + name: "evidence", + path: "/api/evidence", + run: func(ctx context.Context, client *Client) error { + return client.Evidence.Create(ctx, types.Evidence{ + UUID: uuid.New(), + Title: "evidence", + Start: time.Now().Add(-time.Hour), + End: time.Now().Add(-time.Minute), + Status: types.ObjectiveStatus{ + State: "satisfied", + }, + }) + }, + }, + { + name: "risk-template", + path: "/api/agent/risk-templates/batch", + run: func(ctx context.Context, client *Client) error { + return client.RiskTemplate.Upsert(ctx, "plugin-a", "package-a", types.RiskTemplate{ + ID: uuid.NewString(), + Name: "template-a", + Title: "Template A", + Statement: "Template statement", + ViolationIds: []string{"violation-a"}, + }) + }, + }, + { + name: "subject-template", + path: "/api/agent/subject-templates/batch", + run: func(ctx context.Context, client *Client) error { + return client.SubjectTemplate.Upsert(ctx, "plugin-a", types.SubjectTemplate{ + ID: uuid.NewString(), + Name: "template-a", + Type: "component", + IdentityLabelKeys: []string{"asset_id"}, + SourceMode: "runtime-derived", + SelectorLabels: []types.SubjectTemplateSelectorLabel{ + {Key: "_plugin", Value: "plugin-a"}, + }, + LabelSchema: []types.SubjectTemplateLabelSchema{ + {Key: "asset_id"}, + }, + }) + }, + }, + { + name: "heartbeat", + path: "/api/agent/heartbeat", + run: func(ctx context.Context, client *Client) error { + return client.Heartbeat.Create(ctx, types.Heartbeat{ + UUID: uuid.New(), + CreatedAt: time.Now().UTC(), + }) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var authHeader string + + client := newAuthenticatedTestClient(func(r *http.Request) (*http.Response, error) { + switch r.URL.Path { + case "/api/auth/agent/token": + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"token-1","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + case tc.path: + authHeader = r.Header.Get("Authorization") + status := http.StatusCreated + if tc.path != "/api/evidence" && tc.path != "/api/agent/heartbeat" { + status = http.StatusOK + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + default: + t.Fatalf("unexpected path %q", r.URL.Path) + return nil, nil + } + }) + + if err := tc.run(context.Background(), client); err != nil { + t.Fatalf("%s call failed: %v", tc.name, err) + } + if authHeader != "Bearer token-1" { + t.Fatalf("expected bearer auth header, got %q", authHeader) + } + }) + } +} diff --git a/sdk/evidence.go b/sdk/evidence.go index c7522547..32b8a122 100644 --- a/sdk/evidence.go +++ b/sdk/evidence.go @@ -1,9 +1,7 @@ package sdk import ( - "bytes" "context" - "encoding/json" "fmt" "net/http" @@ -11,30 +9,16 @@ import ( ) type evidenceClient struct { - httpClient *http.Client - config *Config + client *Client } func (r *evidenceClient) Create(ctx context.Context, evidence ...types.Evidence) error { for _, evid := range evidence { - reqBody, _ := json.Marshal(evid) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/api/evidence", r.config.BaseURL), bytes.NewReader(reqBody)) + response, err := r.client.doJSONRequest(ctx, http.MethodPost, "/api/evidence", evid) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - response, err := r.httpClient.Do(req) - if err != nil { - return err - } - defer func() { - err := response.Body.Close() - if err != nil { - if r.config.Logger != nil { - r.config.Logger.Error("failed to close response body", "err", err) - } - } - }() + closeResponseBody(response, r.client.config.Logger) if response.StatusCode != http.StatusCreated { return fmt.Errorf("unexpected api response status code: %d", response.StatusCode) diff --git a/sdk/evidence_test.go b/sdk/evidence_test.go index 165de87f..61bf0849 100644 --- a/sdk/evidence_test.go +++ b/sdk/evidence_test.go @@ -4,13 +4,13 @@ package sdk_test import ( "context" - "fmt" + "testing" + "time" + "github.com/compliance-framework/api/internal" "github.com/compliance-framework/api/sdk/types" "github.com/google/uuid" "github.com/stretchr/testify/suite" - "testing" - "time" ) func TestEvidenceSDK(t *testing.T) { @@ -23,8 +23,9 @@ type EvidenceSDKIntegrationSuite struct { func (suite *EvidenceSDKIntegrationSuite) TestCreate() { suite.Run("Evidence can be created through the SDK", func() { - client := suite.GetSDKTestClient() - fmt.Println(client) + suite.Require().NoError(suite.Migrator.Refresh()) + client, err := suite.GetAuthenticatedSDKTestClient() + suite.Require().NoError(err) // Create two catalogs with the same group ID structure evidence := types.Evidence{ UUID: uuid.New(), @@ -139,7 +140,7 @@ func (suite *EvidenceSDKIntegrationSuite) TestCreate() { State: "not-satisfied", // "satisfied" | "not-satisfied" }, } - err := client.Evidence.Create(context.TODO(), evidence) + err = client.Evidence.Create(context.TODO(), evidence) suite.NoError(err) }) } diff --git a/sdk/heartbeat.go b/sdk/heartbeat.go new file mode 100644 index 00000000..f8ebe22e --- /dev/null +++ b/sdk/heartbeat.go @@ -0,0 +1,27 @@ +package sdk + +import ( + "context" + "fmt" + "net/http" + + "github.com/compliance-framework/api/sdk/types" +) + +type heartbeatClient struct { + client *Client +} + +func (h *heartbeatClient) Create(ctx context.Context, heartbeat types.Heartbeat) error { + response, err := h.client.doJSONRequest(ctx, http.MethodPost, "/api/agent/heartbeat", heartbeat) + if err != nil { + return err + } + closeResponseBody(response, h.client.config.Logger) + + if response.StatusCode != http.StatusCreated { + return fmt.Errorf("unexpected api response status code: %d", response.StatusCode) + } + + return nil +} diff --git a/sdk/heartbeat_integration_test.go b/sdk/heartbeat_integration_test.go new file mode 100644 index 00000000..f4883612 --- /dev/null +++ b/sdk/heartbeat_integration_test.go @@ -0,0 +1,39 @@ +//go:build integration + +package sdk_test + +import ( + "context" + "testing" + "time" + + "github.com/compliance-framework/api/internal/service" + "github.com/compliance-framework/api/sdk/types" + "github.com/google/uuid" + "github.com/stretchr/testify/suite" +) + +func TestHeartbeatSDK(t *testing.T) { + suite.Run(t, new(HeartbeatSDKIntegrationSuite)) +} + +type HeartbeatSDKIntegrationSuite struct { + IntegrationBaseTestSuite +} + +func (suite *HeartbeatSDKIntegrationSuite) TestCreateWithAgentAuth() { + suite.Require().NoError(suite.Migrator.Refresh()) + + client, err := suite.GetAuthenticatedSDKTestClient() + suite.Require().NoError(err) + + err = client.Heartbeat.Create(context.Background(), types.Heartbeat{ + UUID: uuid.New(), + CreatedAt: time.Now().UTC(), + }) + suite.Require().NoError(err) + + var count int64 + suite.Require().NoError(suite.DB.Model(&service.Heartbeat{}).Count(&count).Error) + suite.Equal(int64(1), count) +} diff --git a/sdk/heartbeat_test.go b/sdk/heartbeat_test.go new file mode 100644 index 00000000..4eda222b --- /dev/null +++ b/sdk/heartbeat_test.go @@ -0,0 +1,85 @@ +package sdk + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/compliance-framework/api/sdk/types" + "github.com/google/uuid" +) + +func TestHeartbeatCreatePostsPayload(t *testing.T) { + var ( + gotMethod string + gotPath string + gotContentType string + gotBody string + ) + + client := NewClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + gotMethod = r.Method + gotPath = r.URL.Path + gotContentType = r.Header.Get("Content-Type") + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + gotBody = string(body) + + return &http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, &Config{BaseURL: "http://example.test"}) + + err := client.Heartbeat.Create(context.Background(), types.Heartbeat{ + UUID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + CreatedAt: time.Date(2026, time.April, 7, 12, 0, 0, 0, time.UTC), + }) + if err != nil { + t.Fatalf("create heartbeat: %v", err) + } + + if gotMethod != http.MethodPost { + t.Fatalf("expected method %q, got %q", http.MethodPost, gotMethod) + } + if gotPath != "/api/agent/heartbeat" { + t.Fatalf("expected path %q, got %q", "/api/agent/heartbeat", gotPath) + } + if gotContentType != "application/json" { + t.Fatalf("expected content type %q, got %q", "application/json", gotContentType) + } + if !strings.Contains(gotBody, "\"uuid\":\"11111111-1111-1111-1111-111111111111\"") { + t.Fatalf("expected uuid in heartbeat payload, got %q", gotBody) + } + if !strings.Contains(gotBody, "\"created_at\":\"2026-04-07T12:00:00Z\"") { + t.Fatalf("expected created_at in heartbeat payload, got %q", gotBody) + } +} + +func TestHeartbeatCreateReturnsErrorOnUnexpectedStatus(t *testing.T) { + client := NewClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusTeapot, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, &Config{BaseURL: "http://example.test"}) + + err := client.Heartbeat.Create(context.Background(), types.Heartbeat{ + UUID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + CreatedAt: time.Now().UTC(), + }) + if err == nil { + t.Fatal("expected error for unexpected status code") + } + if !strings.Contains(err.Error(), "418") { + t.Fatalf("expected error to mention status code 418, got %q", err.Error()) + } +} diff --git a/sdk/integration_base_test.go b/sdk/integration_base_test.go index c15859cb..a2fb68fa 100644 --- a/sdk/integration_base_test.go +++ b/sdk/integration_base_test.go @@ -4,6 +4,7 @@ package sdk_test import ( "context" + "fmt" "net" "net/http" "strings" @@ -11,6 +12,7 @@ import ( "github.com/compliance-framework/api/internal/api" "github.com/compliance-framework/api/internal/api/handler" + authhandler "github.com/compliance-framework/api/internal/api/handler/auth" "github.com/compliance-framework/api/internal/authn" "github.com/compliance-framework/api/internal/config" "github.com/compliance-framework/api/internal/service/relational" @@ -18,6 +20,7 @@ import ( "github.com/compliance-framework/api/internal/tests" "github.com/compliance-framework/api/sdk" + "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go" @@ -45,6 +48,36 @@ func (suite *IntegrationBaseTestSuite) GetSDKTestClient() *sdk.Client { return sdk.NewClient(http.DefaultClient, config) } +func (suite *IntegrationBaseTestSuite) GetAuthenticatedSDKTestClient() (*sdk.Client, error) { + agent := &relational.Agent{ + Name: fmt.Sprintf("sdk-agent-%d", time.Now().UnixNano()), + IsActive: true, + } + if err := suite.DB.Create(agent).Error; err != nil { + return nil, err + } + + clientSecret := fmt.Sprintf("sdk-secret-%d", time.Now().UnixNano()) + key := &relational.AgentServiceAccountKey{ + AgentID: agent.ID, + ClientID: uuid.NewString(), + } + if err := key.SetSecret(clientSecret); err != nil { + return nil, err + } + if err := suite.DB.Create(key).Error; err != nil { + return nil, err + } + + return sdk.NewClient(http.DefaultClient, &sdk.Config{ + BaseURL: "http://" + suite.Server.E().ListenerAddr().String(), + AgentAuth: &sdk.AgentAuthConfig{ + ClientID: key.ClientID, + ClientSecret: clientSecret, + }, + }), nil +} + func (suite *IntegrationBaseTestSuite) SetupSuite() { ctx := context.Background() @@ -55,6 +88,7 @@ func (suite *IntegrationBaseTestSuite) SetupSuite() { cfg.JWTPrivateKey = privKey cfg.JWTPublicKey = pubKey + cfg.StrictDisablePublicAgentEndpoints = true suite.Config = cfg postgresContainer, err := postgresContainers.Run(ctx, @@ -113,6 +147,7 @@ func (suite *IntegrationBaseTestSuite) SetupSuite() { } handler.RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, services) + authhandler.RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, metrics, nil, nil) suite.Server = server diff --git a/sdk/risk_template.go b/sdk/risk_template.go index 8db41c50..bea6220d 100644 --- a/sdk/risk_template.go +++ b/sdk/risk_template.go @@ -1,9 +1,7 @@ package sdk import ( - "bytes" "context" - "encoding/json" "fmt" "net/http" @@ -11,8 +9,7 @@ import ( ) type riskTemplateClient struct { - httpClient *http.Client - config *Config + client *Client } type upsertRiskTemplatesRequest struct { @@ -31,27 +28,11 @@ func (r *riskTemplateClient) Upsert(ctx context.Context, pluginID string, policy PolicyPackage: policyPackage, Templates: riskTemplates, } - reqBody, err := json.Marshal(reqData) + response, err := r.client.doJSONRequest(ctx, http.MethodPost, "/api/agent/risk-templates/batch", reqData) if err != nil { return err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/api/agent/risk-templates/batch", r.config.BaseURL), bytes.NewReader(reqBody)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - response, err := r.httpClient.Do(req) - if err != nil { - return err - } - defer func() { - err := response.Body.Close() - if err != nil { - if r.config.Logger != nil { - r.config.Logger.Error("failed to close response body", "err", err) - } - } - }() + closeResponseBody(response, r.client.config.Logger) if response.StatusCode != http.StatusCreated && response.StatusCode != http.StatusOK { return fmt.Errorf("unexpected api response status code: %d", response.StatusCode) diff --git a/sdk/risk_template_integration_test.go b/sdk/risk_template_integration_test.go new file mode 100644 index 00000000..f6cae57b --- /dev/null +++ b/sdk/risk_template_integration_test.go @@ -0,0 +1,42 @@ +//go:build integration + +package sdk_test + +import ( + "context" + "testing" + + templaterel "github.com/compliance-framework/api/internal/service/relational/templates" + "github.com/compliance-framework/api/sdk/types" + "github.com/google/uuid" + "github.com/stretchr/testify/suite" +) + +func TestRiskTemplateSDK(t *testing.T) { + suite.Run(t, new(RiskTemplateSDKIntegrationSuite)) +} + +type RiskTemplateSDKIntegrationSuite struct { + IntegrationBaseTestSuite +} + +func (suite *RiskTemplateSDKIntegrationSuite) TestUpsertWithAgentAuth() { + suite.Require().NoError(suite.Migrator.Refresh()) + + client, err := suite.GetAuthenticatedSDKTestClient() + suite.Require().NoError(err) + + templateID := uuid.NewString() + err = client.RiskTemplate.Upsert(context.Background(), "plugin-a", "package-a", types.RiskTemplate{ + ID: templateID, + Name: "Template A", + Title: "Template A", + Statement: "Template statement", + ViolationIds: []string{"violation-a"}, + }) + suite.Require().NoError(err) + + var count int64 + suite.Require().NoError(suite.DB.Model(&templaterel.RiskTemplate{}).Where("id = ?", templateID).Count(&count).Error) + suite.Equal(int64(1), count) +} diff --git a/sdk/subject_template.go b/sdk/subject_template.go index bbec3b3e..29b4414e 100644 --- a/sdk/subject_template.go +++ b/sdk/subject_template.go @@ -1,9 +1,7 @@ package sdk import ( - "bytes" "context" - "encoding/json" "fmt" "net/http" @@ -11,8 +9,7 @@ import ( ) type subjectTemplateClient struct { - httpClient *http.Client - config *Config + client *Client } type upsertSubjectTemplatesRequest struct { @@ -30,27 +27,11 @@ func (r *subjectTemplateClient) Upsert(ctx context.Context, pluginID string, sub Templates: subjectTemplates, } - reqBody, err := json.Marshal(reqData) + response, err := r.client.doJSONRequest(ctx, http.MethodPost, "/api/agent/subject-templates/batch", reqData) if err != nil { return err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/api/agent/subject-templates/batch", r.config.BaseURL), bytes.NewReader(reqBody)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - response, err := r.httpClient.Do(req) - if err != nil { - return err - } - defer func() { - err := response.Body.Close() - if err != nil { - if r.config.Logger != nil { - r.config.Logger.Error("failed to close response body", "err", err) - } - } - }() + closeResponseBody(response, r.client.config.Logger) if response.StatusCode != http.StatusCreated && response.StatusCode != http.StatusOK { return fmt.Errorf("unexpected api response status code: %d", response.StatusCode) diff --git a/sdk/subject_template_integration_test.go b/sdk/subject_template_integration_test.go new file mode 100644 index 00000000..1239407a --- /dev/null +++ b/sdk/subject_template_integration_test.go @@ -0,0 +1,48 @@ +//go:build integration + +package sdk_test + +import ( + "context" + "testing" + + templaterel "github.com/compliance-framework/api/internal/service/relational/templates" + "github.com/compliance-framework/api/sdk/types" + "github.com/google/uuid" + "github.com/stretchr/testify/suite" +) + +func TestSubjectTemplateSDK(t *testing.T) { + suite.Run(t, new(SubjectTemplateSDKIntegrationSuite)) +} + +type SubjectTemplateSDKIntegrationSuite struct { + IntegrationBaseTestSuite +} + +func (suite *SubjectTemplateSDKIntegrationSuite) TestUpsertWithAgentAuth() { + suite.Require().NoError(suite.Migrator.Refresh()) + + client, err := suite.GetAuthenticatedSDKTestClient() + suite.Require().NoError(err) + + templateID := uuid.NewString() + err = client.SubjectTemplate.Upsert(context.Background(), "plugin-a", types.SubjectTemplate{ + ID: templateID, + Name: "Template A", + Type: "component", + IdentityLabelKeys: []string{"asset_id"}, + SourceMode: "runtime-derived", + SelectorLabels: []types.SubjectTemplateSelectorLabel{ + {Key: "_plugin", Value: "plugin-a"}, + }, + LabelSchema: []types.SubjectTemplateLabelSchema{ + {Key: "asset_id"}, + }, + }) + suite.Require().NoError(err) + + var count int64 + suite.Require().NoError(suite.DB.Model(&templaterel.SubjectTemplate{}).Where("id = ?", templateID).Count(&count).Error) + suite.Equal(int64(1), count) +} diff --git a/sdk/types/types.go b/sdk/types/types.go index 47e873ea..61729946 100644 --- a/sdk/types/types.go +++ b/sdk/types/types.go @@ -230,3 +230,8 @@ type SubjectTemplate struct { SelectorLabels []SubjectTemplateSelectorLabel `json:"selector-labels"` LabelSchema []SubjectTemplateLabelSchema `json:"label-schema"` } + +type Heartbeat struct { + UUID uuid.UUID `json:"uuid"` + CreatedAt time.Time `json:"created_at"` +}