diff --git a/cmd/admin/handlers/post.go b/cmd/admin/handlers/post.go index e29c0f5a..6ef27cbb 100644 --- a/cmd/admin/handlers/post.go +++ b/cmd/admin/handlers/post.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/jmpsec/osctrl/cmd/admin/sessions" + "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" @@ -130,95 +131,25 @@ func (h *HandlersAdmin) QueryRunPOSTHandler(w http.ResponseWriter, r *http.Reque adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - // List all the nodes that match the query - var expected []uint - targetNodesID := []uint{} - // TODO: Refactor this to use osctrl-api instead of direct DB queries - // Extract targets by environment - if len(q.Environments) > 0 { - expected = []uint{} - for _, e := range q.Environments { - // TODO: Check if user has permissions to query the environment - if (e != "") && h.Envs.Exists(e) { - nodes, err := h.Nodes.GetByEnv(e, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - adminErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: q.Environments, + Platforms: q.Platforms, + UUIDs: q.UUIDs, + Hosts: q.Hosts, + Tags: q.Tags, + EnvID: env.ID, + InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID), } - // Create platform target - if len(q.Platforms) > 0 { - expected = []uint{} - platforms, _ := h.Nodes.GetEnvIDPlatforms(env.ID) - for _, p := range q.Platforms { - if (p != "") && checkValidPlatform(platforms, p) { - nodes, err := h.Nodes.GetByPlatform(env.ID, p, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - adminErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + manager := handlers.Managers{ + Nodes: h.Nodes, + Envs: h.Envs, + Tags: h.Tags, } - // Create UUIDs target - if len(q.UUIDs) > 0 { - expected = []uint{} - for _, u := range q.UUIDs { - if u != "" { - node, err := h.Nodes.GetByUUID(u) - if err != nil { - log.Err(err).Msgf("error getting node %s and failed to create node query for it", u) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create hostnames target - if len(q.Hosts) > 0 { - expected = []uint{} - for _, _h := range q.Hosts { - if _h != "" { - node, err := h.Nodes.GetByIdentifier(_h) - if err != nil { - log.Err(err).Msgf("error getting node %s and failed to create node query for it", _h) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create tags target - if len(q.Tags) > 0 { - expected = []uint{} - for _, _t := range q.Tags { - if _t != "" { - exist, tag := h.Tags.ExistsGet(tags.GetStrTagName(_t), env.ID) - if exist { - tagged, err := h.Tags.GetTaggedNodes(tag) - if err != nil { - log.Err(err).Msgf("error getting tagged nodes for tag %s", _t) - continue - } - for _, tn := range tagged { - expected = append(expected, tn.NodeID) - } - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) + if err != nil { + adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err) + return } // If the list is empty, we don't need to create node queries if len(targetNodesID) != 0 { @@ -271,7 +202,7 @@ func (h *HandlersAdmin) CarvesRunPOSTHandler(w http.ResponseWriter, r *http.Requ } // Parse request JSON body log.Debug().Msg("Decoding POST body") - var c DistributedCarveRequest + var c DistributedQueryRequest if err := json.NewDecoder(r.Body).Decode(&c); err != nil { adminErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) return @@ -297,94 +228,25 @@ func (h *HandlersAdmin) CarvesRunPOSTHandler(w http.ResponseWriter, r *http.Requ adminErrorResponse(w, "error creating carve", http.StatusInternalServerError, err) return } - // List all the nodes that match the query - var expected []uint - targetNodesID := []uint{} - // Extract targets by environment - if len(c.Environments) > 0 { - expected = []uint{} - for _, e := range c.Environments { - // TODO: Check if user has permissions to query the environment - if (e != "") && h.Envs.Exists(e) { - nodes, err := h.Nodes.GetByEnv(e, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - adminErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: c.Environments, + Platforms: c.Platforms, + UUIDs: c.UUIDs, + Hosts: c.Hosts, + Tags: c.Tags, + EnvID: env.ID, + InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID), } - // Create platform target - if len(c.Platforms) > 0 { - expected = []uint{} - platforms, _ := h.Nodes.GetEnvIDPlatforms(env.ID) - for _, p := range c.Platforms { - if (p != "") && checkValidPlatform(platforms, p) { - nodes, err := h.Nodes.GetByPlatform(env.ID, p, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - adminErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + manager := handlers.Managers{ + Nodes: h.Nodes, + Envs: h.Envs, + Tags: h.Tags, } - // Create UUIDs target - if len(c.UUIDs) > 0 { - expected = []uint{} - for _, u := range c.UUIDs { - if u != "" { - node, err := h.Nodes.GetByUUID(u) - if err != nil { - log.Err(err).Msgf("error getting node %s and failed to create carve query for it", u) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create hostnames target - if len(c.Hosts) > 0 { - expected = []uint{} - for _, _h := range c.Hosts { - if _h != "" { - node, err := h.Nodes.GetByIdentifier(_h) - if err != nil { - log.Err(err).Msgf("error getting node %s and failed to create carve query for it", _h) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create tags target - if len(c.Tags) > 0 { - expected = []uint{} - for _, _t := range c.Tags { - if _t != "" { - exist, tag := h.Tags.ExistsGet(tags.GetStrTagName(_t), env.ID) - if exist { - tagged, err := h.Tags.GetTaggedNodes(tag) - if err != nil { - log.Err(err).Msgf("error getting tagged nodes for tag %s", _t) - continue - } - for _, tn := range tagged { - expected = append(expected, tn.NodeID) - } - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) + if err != nil { + adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err) + return } // If the list is empty, we don't need to create node queries if len(targetNodesID) != 0 { diff --git a/cmd/admin/handlers/types-requests.go b/cmd/admin/handlers/types-requests.go index b87eb89b..e15f973c 100644 --- a/cmd/admin/handlers/types-requests.go +++ b/cmd/admin/handlers/types-requests.go @@ -22,17 +22,6 @@ type DistributedQueryRequest struct { Save bool `json:"save"` Name string `json:"name"` Query string `json:"query"` - ExpHours int `json:"exp_hours"` -} - -// DistributedCarveRequest to receive carve requests -type DistributedCarveRequest struct { - CSRFToken string `json:"csrftoken"` - Environments []string `json:"environment_list"` - Platforms []string `json:"platform_list"` - UUIDs []string `json:"uuid_list"` - Hosts []string `json:"host_list"` - Tags []string `json:"tag_list"` Path string `json:"path"` ExpHours int `json:"exp_hours"` } diff --git a/cmd/admin/handlers/utils.go b/cmd/admin/handlers/utils.go index 07e95993..0d7deb4e 100644 --- a/cmd/admin/handlers/utils.go +++ b/cmd/admin/handlers/utils.go @@ -37,16 +37,6 @@ func adminOKResponse(w http.ResponseWriter, msg string) { utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, AdminResponse{Message: msg}) } -// Helper to verify if a platform is valid -func checkValidPlatform(platforms []string, platform string) bool { - for _, p := range platforms { - if p == platform { - return true - } - } - return false -} - // Helper to check if the CSRF token is valid func checkCSRFToken(ctxToken, receivedToken string) bool { return (strings.TrimSpace(ctxToken) == strings.TrimSpace(receivedToken)) @@ -80,17 +70,12 @@ func generateCarveQuery(file string, glob bool) string { } // Helper to generate the file carve query -func newCarveReady(user, path string, exp time.Time, envid uint, req DistributedCarveRequest) queries.DistributedQuery { +func newCarveReady(user, path string, exp time.Time, envid uint, req DistributedQueryRequest) queries.DistributedQuery { return queries.DistributedQuery{ Query: generateCarveQuery(path, false), Name: generateCarveName(), Creator: user, - Expected: 0, - Executions: 0, Active: true, - Completed: false, - Deleted: false, - Expired: false, Expiration: exp, Type: queries.CarveQueryType, Path: path, @@ -102,25 +87,13 @@ func newCarveReady(user, path string, exp time.Time, envid uint, req Distributed // Helper to determine if a query may be a carve func newQueryReady(user, query string, exp time.Time, envid uint, req DistributedQueryRequest) queries.DistributedQuery { if strings.Contains(query, "carve") { - cReq := DistributedCarveRequest{ - Environments: req.Environments, - Platforms: req.Platforms, - UUIDs: req.UUIDs, - Hosts: req.Hosts, - Tags: req.Tags, - } - return newCarveReady(user, query, exp, envid, cReq) + return newCarveReady(user, query, exp, envid, req) } return queries.DistributedQuery{ Query: query, Name: generateQueryName(), Creator: user, - Expected: 0, - Executions: 0, Active: true, - Completed: false, - Deleted: false, - Expired: false, Expiration: exp, Type: queries.StandardQueryType, EnvironmentID: envid, diff --git a/cmd/api/handlers/carves.go b/cmd/api/handlers/carves.go index 8b2485ff..59bd7966 100644 --- a/cmd/api/handlers/carves.go +++ b/cmd/api/handlers/carves.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jmpsec/osctrl/pkg/carves" + "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/types" @@ -173,7 +174,7 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - var c types.ApiDistributedCarveRequest + var c types.ApiDistributedQueryRequest // Parse request JSON body if err := json.NewDecoder(r.Body).Decode(&c); err != nil { apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) @@ -188,21 +189,13 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { if c.ExpHours == 0 { expTime = time.Time{} } - query := carves.GenCarveQuery(c.Path, false) // Prepare and create new carve - carveName := carves.GenCarveName() - newQuery := queries.DistributedQuery{ - Query: query, - Name: carveName, + Query: carves.GenCarveQuery(c.Path, false), + Name: carves.GenCarveName(), Creator: ctx[ctxUser], - Expected: 0, - Executions: 0, Active: true, - Expired: false, Expiration: expTime, - Completed: false, - Deleted: false, Type: queries.CarveQueryType, Path: c.Path, EnvironmentID: env.ID, @@ -211,15 +204,36 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - // Create UUID target - if (c.UUID != "") && h.Nodes.CheckByUUID(c.UUID) { - if err := h.Queries.CreateTarget(carveName, queries.QueryTargetUUID, c.UUID); err != nil { - apiErrorResponse(w, "error creating carve UUID target", http.StatusInternalServerError, err) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: c.Environments, + Platforms: c.Platforms, + UUIDs: c.UUIDs, + Hosts: c.Hosts, + Tags: c.Tags, + EnvID: env.ID, + InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID), + } + manager := handlers.Managers{ + Nodes: h.Nodes, + Envs: h.Envs, + Tags: h.Tags, + } + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) + if err != nil { + apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) + return + } + // If the list is empty, we don't need to create node queries + if len(targetNodesID) != 0 { + if err := h.Queries.CreateNodeQueries(targetNodesID, newQuery.ID); err != nil { + log.Err(err).Msgf("error creating node queries for carve %s", newQuery.Name) + apiErrorResponse(w, "error creating node queries", http.StatusInternalServerError, err) return } } // Update value for expected - if err := h.Queries.SetExpected(carveName, 1, env.ID); err != nil { + if err := h.Queries.SetExpected(newQuery.Name, len(targetNodesID), env.ID); err != nil { apiErrorResponse(w, "error setting expected", http.StatusInternalServerError, err) return } diff --git a/cmd/api/handlers/queries.go b/cmd/api/handlers/queries.go index 6b6de107..56304b35 100644 --- a/cmd/api/handlers/queries.go +++ b/cmd/api/handlers/queries.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/types" @@ -114,18 +114,12 @@ func (h *HandlersApi) QueriesRunHandler(w http.ResponseWriter, r *http.Request) expTime = time.Time{} } // Prepare and create new query - queryName := queries.GenQueryName() newQuery := queries.DistributedQuery{ Query: q.Query, - Name: queryName, + Name: queries.GenQueryName(), Creator: ctx[ctxUser], - Expected: 0, - Executions: 0, Active: true, - Expired: false, Expiration: expTime, - Completed: false, - Deleted: false, Hidden: q.Hidden, Type: queries.StandardQueryType, EnvironmentID: env.ID, @@ -134,87 +128,26 @@ func (h *HandlersApi) QueriesRunHandler(w http.ResponseWriter, r *http.Request) apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - // Get the query id - newQuery, err = h.Queries.Get(queryName, env.ID) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: q.Environments, + Platforms: q.Platforms, + UUIDs: q.UUIDs, + Hosts: q.Hosts, + Tags: q.Tags, + EnvID: env.ID, + InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID), + } + manager := handlers.Managers{ + Nodes: h.Nodes, + Envs: h.Envs, + Tags: h.Tags, + } + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) if err != nil { apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - - // List all the nodes that match the query - var expected []uint - - targetNodesID := []uint{} - // Current logic is to select nodes meeting all criteria in the query - // TODO: I believe we should only allow to list nodes in one environment in URL paths - // We will refactor this part to be tag based queries and add more options to the query - if len(q.Environments) > 0 { - expected = []uint{} - for _, e := range q.Environments { - if (e != "") && h.Envs.Exists(e) { - nodes, err := h.Nodes.GetByEnv(e, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - apiErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create platform target - if len(q.Platforms) > 0 { - expected = []uint{} - platforms, _ := h.Nodes.GetAllPlatforms() - for _, p := range q.Platforms { - if (p != "") && checkValidPlatform(platforms, p) { - nodes, err := h.Nodes.GetByPlatform(env.ID, p, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID)) - if err != nil { - apiErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err) - return - } - for _, n := range nodes { - expected = append(expected, n.ID) - } - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create UUIDs target - if len(q.UUIDs) > 0 { - expected = []uint{} - for _, u := range q.UUIDs { - if u != "" { - node, err := h.Nodes.GetByUUID(u) - if err != nil { - log.Warn().Msgf("error getting node %s and failed to create node query for it", u) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // Create hostnames target - // Currently we are using the GetByIdentifier function and it need be more clear - // about the definition of the identifier - if len(q.Hosts) > 0 { - expected = []uint{} - for _, hostName := range q.Hosts { - if hostName != "" { - node, err := h.Nodes.GetByIdentifier(hostName) - if err != nil { - log.Warn().Msgf("error getting node %s and failed to create node query for it", hostName) - continue - } - expected = append(expected, node.ID) - } - } - targetNodesID = utils.Intersect(targetNodesID, expected) - } - // If the list is empty, we don't need to create node queries if len(targetNodesID) != 0 { if err := h.Queries.CreateNodeQueries(targetNodesID, newQuery.ID); err != nil { @@ -223,14 +156,13 @@ func (h *HandlersApi) QueriesRunHandler(w http.ResponseWriter, r *http.Request) return } } - // Update value for expected - if err := h.Queries.SetExpected(queryName, len(targetNodesID), env.ID); err != nil { + if err := h.Queries.SetExpected(newQuery.Name, len(targetNodesID), env.ID); err != nil { apiErrorResponse(w, "error setting expected", http.StatusInternalServerError, err) return } // Return query name as serialized response - log.Debug().Msgf("Created query %s", newQuery.Name) + log.Debug().Msgf("Created query %s with id %d", newQuery.Name, newQuery.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiQueriesResponse{Name: newQuery.Name}) } diff --git a/cmd/api/handlers/utils.go b/cmd/api/handlers/utils.go index 83fffd4d..42b9f982 100644 --- a/cmd/api/handlers/utils.go +++ b/cmd/api/handlers/utils.go @@ -43,13 +43,3 @@ func apiErrorResponse(w http.ResponseWriter, msg string, code int, err error) { log.Debug().Msgf("apiErrorResponse %s: %v", msg, err) utils.HTTPResponse(w, utils.JSONApplicationUTF8, code, types.ApiErrorResponse{Error: msg}) } - -// Helper to verify if a platform is valid -func checkValidPlatform(platforms []string, platform string) bool { - for _, p := range platforms { - if p == platform { - return true - } - } - return false -} diff --git a/cmd/cli/api-carve.go b/cmd/cli/api-carve.go index b8ccd8c8..ab10d732 100644 --- a/cmd/cli/api-carve.go +++ b/cmd/cli/api-carve.go @@ -98,11 +98,15 @@ func (api *OsctrlAPI) CompleteCarve(env, name string) (types.ApiGenericResponse, } // RunCarve to initiate a carve in osctrl -func (api *OsctrlAPI) RunCarve(env, uuid, fPath string, exp int) (types.ApiQueriesResponse, error) { - c := types.ApiDistributedCarveRequest{ - UUID: uuid, - Path: fPath, - ExpHours: exp, +func (api *OsctrlAPI) RunCarve(env, fPath string, uuids, hosts, platforms, tags []string, hidden bool, exp int) (types.ApiQueriesResponse, error) { + c := types.ApiDistributedQueryRequest{ + UUIDs: uuids, + Hosts: hosts, + Platforms: platforms, + Tags: tags, + Path: fPath, + Hidden: hidden, + ExpHours: exp, } var r types.ApiQueriesResponse reqURL := fmt.Sprintf("%s%s", api.Configuration.URL, path.Join(APIPath, APICarves, env)) diff --git a/cmd/cli/api-query.go b/cmd/cli/api-query.go index 15c52f3e..ada29971 100644 --- a/cmd/cli/api-query.go +++ b/cmd/cli/api-query.go @@ -82,9 +82,12 @@ func (api *OsctrlAPI) CompleteQuery(env, name string) (types.ApiGenericResponse, } // RunQuery to initiate a query in osctrl -func (api *OsctrlAPI) RunQuery(env, uuid, query string, hidden bool, exp int) (types.ApiQueriesResponse, error) { +func (api *OsctrlAPI) RunQuery(env, query string, uuids, hosts, platforms, tags []string, hidden bool, exp int) (types.ApiQueriesResponse, error) { q := types.ApiDistributedQueryRequest{ - UUIDs: []string{uuid}, + UUIDs: uuids, + Hosts: hosts, + Platforms: platforms, + Tags: tags, Query: query, Hidden: hidden, ExpHours: exp, diff --git a/cmd/cli/carve.go b/cmd/cli/carve.go index 6b915046..c74e255a 100644 --- a/cmd/cli/carve.go +++ b/cmd/cli/carve.go @@ -6,9 +6,13 @@ import ( "fmt" "os" "strconv" + "strings" + "time" "github.com/jmpsec/osctrl/pkg/carves" + "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/settings" "github.com/olekukonko/tablewriter" "github.com/urfave/cli/v2" ) @@ -303,30 +307,49 @@ func runCarve(c *cli.Context) error { fmt.Println("❌ environment is required") os.Exit(1) } - uuid := c.String("uuid") - if uuid == "" { + uuidStr := c.String("uuid") + if uuidStr == "" { fmt.Println("❌ UUID is required") os.Exit(1) } + uuidList := []string{uuidStr} + if strings.Contains(uuidStr, ",") { + uuidList = strings.Split(uuidStr, ",") + } + platformStr := c.String("platform") + platformList := []string{platformStr} + if strings.Contains(platformStr, ",") { + platformList = strings.Split(platformStr, ",") + } + hostStr := c.String("host") + hostList := []string{hostStr} + if strings.Contains(hostStr, ",") { + hostList = strings.Split(hostStr, ",") + } + tagStr := c.String("tag") + tagList := []string{tagStr} + if strings.Contains(tagStr, ",") { + tagList = strings.Split(tagStr, ",") + } expHours := c.Int("expiration") - var cName string + hidden := c.Bool("hidden") + cName := carves.GenCarveName() if dbFlag { e, err := envs.Get(env) if err != nil { return fmt.Errorf("❌ %w", err) } - carveName := carves.GenCarveName() + expTime := queries.QueryExpiration(expHours) + if expHours == 0 { + expTime = time.Time{} + } newQuery := queries.DistributedQuery{ Query: carves.GenCarveQuery(path, false), - Name: carveName, + Name: cName, Creator: appName, - Expected: 0, - Executions: 0, Active: true, - Expired: false, - Expiration: queries.QueryExpiration(expHours), - Completed: false, - Deleted: false, + Expiration: expTime, + Hidden: hidden, Type: queries.CarveQueryType, Path: path, EnvironmentID: e.ID, @@ -334,18 +357,36 @@ func runCarve(c *cli.Context) error { if err := queriesmgr.Create(&newQuery); err != nil { return fmt.Errorf("❌ %w", err) } - if (uuid != "") && nodesmgr.CheckByUUID(uuid) { - if err := queriesmgr.CreateTarget(carveName, queries.QueryTargetUUID, uuid); err != nil { - return fmt.Errorf("❌ error creating target - %w", err) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: []string{}, + Platforms: platformList, + UUIDs: uuidList, + Hosts: hostList, + Tags: tagList, + EnvID: e.ID, + InactiveHours: settingsmgr.InactiveHours(settings.NoEnvironmentID), + } + manager := handlers.Managers{ + Nodes: nodesmgr, + Envs: envs, + Tags: tagsmgr, + } + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) + if err != nil { + return fmt.Errorf("❌ error creating query carve - %w", err) + } + // If the list is empty, we don't need to create node queries + if len(targetNodesID) != 0 { + if err := queriesmgr.CreateNodeQueries(targetNodesID, newQuery.ID); err != nil { + return fmt.Errorf("❌ error creating node queries - %w", err) } } - if err := queriesmgr.SetExpected(carveName, 1, e.ID); err != nil { + if err := queriesmgr.SetExpected(cName, len(targetNodesID), e.ID); err != nil { return fmt.Errorf("❌ error setting expected - %w", err) } - cName = carveName - return nil } else if apiFlag { - c, err := osctrlAPI.RunCarve(env, uuid, path, expHours) + c, err := osctrlAPI.RunCarve(env, path, uuidList, hostList, platformList, tagList, hidden, expHours) if err != nil { return fmt.Errorf("❌ error running carve - %w", err) } diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 868cf936..f366cd48 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1348,7 +1348,22 @@ func init() { &cli.StringFlag{ Name: "uuid", Aliases: []string{"u"}, - Usage: "Node UUID to be used", + Usage: "Node UUID(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "host", + Aliases: []string{"hostname", "H"}, + Usage: "Node hostname(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "platform", + Aliases: []string{"p"}, + Usage: "Node platform(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "tag", + Aliases: []string{"t"}, + Usage: "Tag(s) to be used. Comma separated for multiple values", }, &cli.BoolFlag{ Name: "hidden", @@ -1492,7 +1507,22 @@ func init() { &cli.StringFlag{ Name: "uuid", Aliases: []string{"u"}, - Usage: "Node UUID to be used", + Usage: "Node UUID(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "host", + Aliases: []string{"hostname", "H"}, + Usage: "Node hostname(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "platform", + Aliases: []string{"p"}, + Usage: "Node platform(s) to be used. Comma separated for multiple values", + }, + &cli.StringFlag{ + Name: "tag", + Aliases: []string{"t"}, + Usage: "Tag(s) to be used. Comma separated for multiple values", }, &cli.IntFlag{ Name: "expiration", diff --git a/cmd/cli/query.go b/cmd/cli/query.go index ad90e042..4043c86d 100644 --- a/cmd/cli/query.go +++ b/cmd/cli/query.go @@ -6,8 +6,12 @@ import ( "fmt" "os" "strconv" + "strings" + "time" + "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/settings" "github.com/olekukonko/tablewriter" "github.com/urfave/cli/v2" ) @@ -241,31 +245,48 @@ func runQuery(c *cli.Context) error { fmt.Println("❌ environment is required") os.Exit(1) } - uuid := c.String("uuid") - if uuid == "" { + uuidStr := c.String("uuid") + if uuidStr == "" { fmt.Println("❌ UUID is required") os.Exit(1) } + uuidList := []string{uuidStr} + if strings.Contains(uuidStr, ",") { + uuidList = strings.Split(uuidStr, ",") + } + platformStr := c.String("platform") + platformList := []string{platformStr} + if strings.Contains(platformStr, ",") { + platformList = strings.Split(platformStr, ",") + } + hostStr := c.String("host") + hostList := []string{hostStr} + if strings.Contains(hostStr, ",") { + hostList = strings.Split(hostStr, ",") + } + tagStr := c.String("tag") + tagList := []string{tagStr} + if strings.Contains(tagStr, ",") { + tagList = strings.Split(tagStr, ",") + } expHours := c.Int("expiration") hidden := c.Bool("hidden") - var queryName string + queryName := queries.GenQueryName() if dbFlag { e, err := envs.Get(env) if err != nil { return fmt.Errorf("❌ error env get - %w", err) } - queryName = queries.GenQueryName() + expTime := queries.QueryExpiration(expHours) + if expHours == 0 { + expTime = time.Time{} + } newQuery := queries.DistributedQuery{ Query: query, Name: queryName, Creator: appName, - Expected: 0, - Executions: 0, Active: true, - Expired: false, - Expiration: queries.QueryExpiration(expHours), - Completed: false, - Deleted: false, + Expiration: expTime, Hidden: hidden, Type: queries.StandardQueryType, EnvironmentID: e.ID, @@ -273,16 +294,36 @@ func runQuery(c *cli.Context) error { if err := queriesmgr.Create(&newQuery); err != nil { return fmt.Errorf("❌ error query create - %w", err) } - if (uuid != "") && nodesmgr.CheckByUUID(uuid) { - if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetUUID, uuid); err != nil { - return fmt.Errorf("❌ error create target - %w", err) + // Prepare data for the handler code + data := handlers.ProcessingQuery{ + Envs: []string{}, + Platforms: platformList, + UUIDs: uuidList, + Hosts: hostList, + Tags: tagList, + EnvID: e.ID, + InactiveHours: settingsmgr.InactiveHours(settings.NoEnvironmentID), + } + manager := handlers.Managers{ + Nodes: nodesmgr, + Envs: envs, + Tags: tagsmgr, + } + targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery) + if err != nil { + return fmt.Errorf("❌ error creating query carve - %w", err) + } + // If the list is empty, we don't need to create node queries + if len(targetNodesID) != 0 { + if err := queriesmgr.CreateNodeQueries(targetNodesID, newQuery.ID); err != nil { + return fmt.Errorf("❌ error creating node queries - %w", err) } } - if err := queriesmgr.SetExpected(queryName, 1, e.ID); err != nil { + if err := queriesmgr.SetExpected(queryName, len(targetNodesID), e.ID); err != nil { return fmt.Errorf("❌ error set expected - %w", err) } } else if apiFlag { - q, err := osctrlAPI.RunQuery(env, uuid, query, hidden, expHours) + q, err := osctrlAPI.RunQuery(env, query, uuidList, hostList, platformList, tagList, hidden, expHours) if err != nil { return fmt.Errorf("❌ error run query - %w", err) } diff --git a/osctrl-api.yaml b/osctrl-api.yaml index 2a00b365..409d62e5 100644 --- a/osctrl-api.yaml +++ b/osctrl-api.yaml @@ -814,7 +814,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/ApiDistributedCarveRequest" + $ref: "#/components/schemas/ApiDistributedQueryRequest" responses: 200: description: successful operation diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go new file mode 100644 index 00000000..010e3690 --- /dev/null +++ b/pkg/handlers/handlers.go @@ -0,0 +1,115 @@ +package handlers + +import ( + "fmt" + + "github.com/jmpsec/osctrl/pkg/environments" + "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/tags" + "github.com/jmpsec/osctrl/pkg/utils" +) + +type ProcessingQuery struct { + Envs []string + Platforms []string + UUIDs []string + Hosts []string + Tags []string + EnvID uint + InactiveHours int64 +} + +type Managers struct { + Envs *environments.EnvManager + Nodes *nodes.NodeManager + Tags *tags.TagManager +} + +// CreateQueryCarve - Create On-demand Query or Carve, to be used in osctrl-admin or osctrl-api +func CreateQueryCarve(data ProcessingQuery, manager Managers, newQuery queries.DistributedQuery) ([]uint, error) { + var expected []uint + targetNodesID := []uint{} + // Environments target + if len(data.Envs) > 0 { + expected = []uint{} + for _, e := range data.Envs { + // TODO: Check if user has permissions to query the environment + if (e != "") && manager.Envs.Exists(e) { + nodes, err := manager.Nodes.GetByEnv(e, nodes.ActiveNodes, data.InactiveHours) + if err != nil { + return targetNodesID, fmt.Errorf("error getting nodes by environment: %w", err) + } + for _, n := range nodes { + expected = append(expected, n.ID) + } + } + } + targetNodesID = utils.Intersect(targetNodesID, expected) + } + // Platforms target + if len(data.Platforms) > 0 { + expected = []uint{} + platforms, _ := manager.Nodes.GetEnvIDPlatforms(data.EnvID) + for _, p := range data.Platforms { + if (p != "") && utils.Contains(platforms, p) { + nodes, err := manager.Nodes.GetByPlatform(data.EnvID, p, nodes.ActiveNodes, data.InactiveHours) + if err != nil { + return targetNodesID, fmt.Errorf("error getting nodes by platform: %w", err) + } + for _, n := range nodes { + expected = append(expected, n.ID) + } + } + } + targetNodesID = utils.Intersect(targetNodesID, expected) + } + // UUIDs target + if len(data.UUIDs) > 0 { + expected = []uint{} + for _, u := range data.UUIDs { + if u != "" { + node, err := manager.Nodes.GetByUUIDEnv(u, data.EnvID) + if err != nil { + return targetNodesID, fmt.Errorf("error getting node %s and failed to create node query for it: %w", u, err) + } + expected = append(expected, node.ID) + } + } + targetNodesID = utils.Intersect(targetNodesID, expected) + } + // Hostnames target + if len(data.Hosts) > 0 { + expected = []uint{} + for _, _h := range data.Hosts { + if _h != "" { + node, err := manager.Nodes.GetByIdentifierEnv(_h, data.EnvID) + if err != nil { + return targetNodesID, fmt.Errorf("error getting node %s and failed to create node query for it: %w", _h, err) + } + expected = append(expected, node.ID) + } + } + targetNodesID = utils.Intersect(targetNodesID, expected) + } + // Tags target + if len(data.Tags) > 0 { + expected = []uint{} + for _, _t := range data.Tags { + if _t != "" { + exist, tag := manager.Tags.ExistsGet(tags.GetStrTagName(_t), data.EnvID) + if exist { + tagged, err := manager.Tags.GetTaggedNodes(tag) + if err != nil { + return targetNodesID, fmt.Errorf("error getting tagged nodes for tag %s: %w", _t, err) + } + for _, tn := range tagged { + expected = append(expected, tn.NodeID) + } + } + } + } + targetNodesID = utils.Intersect(targetNodesID, expected) + } + return targetNodesID, nil +} diff --git a/pkg/nodes/nodes.go b/pkg/nodes/nodes.go index e1d98ca9..efdcf285 100644 --- a/pkg/nodes/nodes.go +++ b/pkg/nodes/nodes.go @@ -111,6 +111,22 @@ func (n *NodeManager) GetByIdentifier(identifier string) (OsqueryNode, error) { return node, nil } +// GetByIdentifierEnv to retrieve full node object from DB, by uuid or hostname or localname +// UUID is expected uppercase +func (n *NodeManager) GetByIdentifierEnv(identifier string, envid uint) (OsqueryNode, error) { + var node OsqueryNode + if err := n.DB.Where( + "(uuid = ? OR hostname = ? OR localname = ?) AND environment_id = ?", + strings.ToUpper(identifier), + identifier, + identifier, + envid, + ).First(&node).Error; err != nil { + return node, err + } + return node, nil +} + // GetByUUID to retrieve full node object from DB, by uuid // UUID is expected uppercase func (n *NodeManager) GetByUUID(uuid string) (OsqueryNode, error) { diff --git a/pkg/types/types.go b/pkg/types/types.go index a35446b4..7bfabb2d 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -51,17 +51,11 @@ type ApiDistributedQueryRequest struct { Hosts []string `json:"host_list"` Tags []string `json:"tag_list"` Query string `json:"query"` + Path string `json:"path"` Hidden bool `json:"hidden"` ExpHours int `json:"exp_hours"` } -// ApiDistributedCarveRequest to receive query requests -type ApiDistributedCarveRequest struct { - UUID string `json:"uuid"` - Path string `json:"path"` - ExpHours int `json:"exp_hours"` -} - // ApiNodeGenericRequest to receive generic node requests type ApiNodeGenericRequest struct { UUID string `json:"uuid"` diff --git a/pkg/utils/http-utils_test.go b/pkg/utils/http-utils_test.go index 1c2e4187..e7013c0e 100644 --- a/pkg/utils/http-utils_test.go +++ b/pkg/utils/http-utils_test.go @@ -1,11 +1,8 @@ package utils import ( - "bytes" - "log" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" @@ -51,14 +48,6 @@ func testingMock(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("the test works")) } -func captureOutput(f func()) string { - var buf bytes.Buffer - log.SetOutput(&buf) - f() - log.SetOutput(os.Stderr) - return buf.String() -} - func TestSendRequest(t *testing.T) { server := serverMock() defer server.Close() diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 9127d61f..5c40d0dd 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -32,6 +32,7 @@ func RandomForNames() string { return hex.EncodeToString(hasher.Sum(nil)) } +// Intersect returns the intersection of two slices of uints func Intersect(slice1, slice2 []uint) []uint { if len(slice1) == 0 { return slice2 @@ -40,12 +41,10 @@ func Intersect(slice1, slice2 []uint) []uint { if len(slice2) == 0 { return slice1 } - set := make(map[uint]struct{}) for _, item := range slice1 { set[item] = struct{}{} // Add items from slice1 to the set } - intersection := []uint{} for _, item := range slice2 { if _, exists := set[item]; exists { @@ -53,6 +52,15 @@ func Intersect(slice1, slice2 []uint) []uint { delete(set, item) // Ensure uniqueness in the result } } - return intersection } + +// Contains checks if string is in the slice +func Contains(all []string, target string) bool { + for _, s := range all { + if s == target { + return true + } + } + return false +}