Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 36 additions & 174 deletions cmd/admin/handlers/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"

"github.com/jmpsec/osctrl/cmd/admin/sessions"
"github.com/jmpsec/osctrl/pkg/handlers"
"github.com/jmpsec/osctrl/pkg/nodes"
"github.com/jmpsec/osctrl/pkg/queries"
"github.com/jmpsec/osctrl/pkg/settings"
Expand Down Expand Up @@ -130,95 +131,25 @@ func (h *HandlersAdmin) QueryRunPOSTHandler(w http.ResponseWriter, r *http.Reque
adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err)
return
}
// List all the nodes that match the query
var expected []uint
targetNodesID := []uint{}
// TODO: Refactor this to use osctrl-api instead of direct DB queries
// Extract targets by environment
if len(q.Environments) > 0 {
expected = []uint{}
for _, e := range q.Environments {
// TODO: Check if user has permissions to query the environment
if (e != "") && h.Envs.Exists(e) {
nodes, err := h.Nodes.GetByEnv(e, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID))
if err != nil {
adminErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err)
return
}
for _, n := range nodes {
expected = append(expected, n.ID)
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
// Prepare data for the handler code
data := handlers.ProcessingQuery{
Envs: q.Environments,
Platforms: q.Platforms,
UUIDs: q.UUIDs,
Hosts: q.Hosts,
Tags: q.Tags,
EnvID: env.ID,
InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID),
}
// Create platform target
if len(q.Platforms) > 0 {
expected = []uint{}
platforms, _ := h.Nodes.GetEnvIDPlatforms(env.ID)
for _, p := range q.Platforms {
if (p != "") && checkValidPlatform(platforms, p) {
nodes, err := h.Nodes.GetByPlatform(env.ID, p, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID))
if err != nil {
adminErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err)
return
}
for _, n := range nodes {
expected = append(expected, n.ID)
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
manager := handlers.Managers{
Nodes: h.Nodes,
Envs: h.Envs,
Tags: h.Tags,
}
// Create UUIDs target
if len(q.UUIDs) > 0 {
expected = []uint{}
for _, u := range q.UUIDs {
if u != "" {
node, err := h.Nodes.GetByUUID(u)
if err != nil {
log.Err(err).Msgf("error getting node %s and failed to create node query for it", u)
continue
}
expected = append(expected, node.ID)
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
}
// Create hostnames target
if len(q.Hosts) > 0 {
expected = []uint{}
for _, _h := range q.Hosts {
if _h != "" {
node, err := h.Nodes.GetByIdentifier(_h)
if err != nil {
log.Err(err).Msgf("error getting node %s and failed to create node query for it", _h)
continue
}
expected = append(expected, node.ID)
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
}
// Create tags target
if len(q.Tags) > 0 {
expected = []uint{}
for _, _t := range q.Tags {
if _t != "" {
exist, tag := h.Tags.ExistsGet(tags.GetStrTagName(_t), env.ID)
if exist {
tagged, err := h.Tags.GetTaggedNodes(tag)
if err != nil {
log.Err(err).Msgf("error getting tagged nodes for tag %s", _t)
continue
}
for _, tn := range tagged {
expected = append(expected, tn.NodeID)
}
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery)
if err != nil {
adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err)
return
}
// If the list is empty, we don't need to create node queries
if len(targetNodesID) != 0 {
Expand Down Expand Up @@ -271,7 +202,7 @@ func (h *HandlersAdmin) CarvesRunPOSTHandler(w http.ResponseWriter, r *http.Requ
}
// Parse request JSON body
log.Debug().Msg("Decoding POST body")
var c DistributedCarveRequest
var c DistributedQueryRequest
if err := json.NewDecoder(r.Body).Decode(&c); err != nil {
adminErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err)
return
Expand All @@ -297,94 +228,25 @@ func (h *HandlersAdmin) CarvesRunPOSTHandler(w http.ResponseWriter, r *http.Requ
adminErrorResponse(w, "error creating carve", http.StatusInternalServerError, err)
return
}
// List all the nodes that match the query
var expected []uint
targetNodesID := []uint{}
// Extract targets by environment
if len(c.Environments) > 0 {
expected = []uint{}
for _, e := range c.Environments {
// TODO: Check if user has permissions to query the environment
if (e != "") && h.Envs.Exists(e) {
nodes, err := h.Nodes.GetByEnv(e, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID))
if err != nil {
adminErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err)
return
}
for _, n := range nodes {
expected = append(expected, n.ID)
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
// Prepare data for the handler code
data := handlers.ProcessingQuery{
Envs: c.Environments,
Platforms: c.Platforms,
UUIDs: c.UUIDs,
Hosts: c.Hosts,
Tags: c.Tags,
EnvID: env.ID,
InactiveHours: h.Settings.InactiveHours(settings.NoEnvironmentID),
}
// Create platform target
if len(c.Platforms) > 0 {
expected = []uint{}
platforms, _ := h.Nodes.GetEnvIDPlatforms(env.ID)
for _, p := range c.Platforms {
if (p != "") && checkValidPlatform(platforms, p) {
nodes, err := h.Nodes.GetByPlatform(env.ID, p, nodes.ActiveNodes, h.Settings.InactiveHours(settings.NoEnvironmentID))
if err != nil {
adminErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err)
return
}
for _, n := range nodes {
expected = append(expected, n.ID)
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
manager := handlers.Managers{
Nodes: h.Nodes,
Envs: h.Envs,
Tags: h.Tags,
}
// Create UUIDs target
if len(c.UUIDs) > 0 {
expected = []uint{}
for _, u := range c.UUIDs {
if u != "" {
node, err := h.Nodes.GetByUUID(u)
if err != nil {
log.Err(err).Msgf("error getting node %s and failed to create carve query for it", u)
continue
}
expected = append(expected, node.ID)
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
}
// Create hostnames target
if len(c.Hosts) > 0 {
expected = []uint{}
for _, _h := range c.Hosts {
if _h != "" {
node, err := h.Nodes.GetByIdentifier(_h)
if err != nil {
log.Err(err).Msgf("error getting node %s and failed to create carve query for it", _h)
continue
}
expected = append(expected, node.ID)
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
}
// Create tags target
if len(c.Tags) > 0 {
expected = []uint{}
for _, _t := range c.Tags {
if _t != "" {
exist, tag := h.Tags.ExistsGet(tags.GetStrTagName(_t), env.ID)
if exist {
tagged, err := h.Tags.GetTaggedNodes(tag)
if err != nil {
log.Err(err).Msgf("error getting tagged nodes for tag %s", _t)
continue
}
for _, tn := range tagged {
expected = append(expected, tn.NodeID)
}
}
}
}
targetNodesID = utils.Intersect(targetNodesID, expected)
targetNodesID, err := handlers.CreateQueryCarve(data, manager, newQuery)
if err != nil {
adminErrorResponse(w, "error creating query", http.StatusInternalServerError, err)
return
}
// If the list is empty, we don't need to create node queries
if len(targetNodesID) != 0 {
Expand Down
11 changes: 0 additions & 11 deletions cmd/admin/handlers/types-requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,6 @@ type DistributedQueryRequest struct {
Save bool `json:"save"`
Name string `json:"name"`
Query string `json:"query"`
ExpHours int `json:"exp_hours"`
}

// DistributedCarveRequest to receive carve requests
type DistributedCarveRequest struct {
CSRFToken string `json:"csrftoken"`
Environments []string `json:"environment_list"`
Platforms []string `json:"platform_list"`
UUIDs []string `json:"uuid_list"`
Hosts []string `json:"host_list"`
Tags []string `json:"tag_list"`
Path string `json:"path"`
ExpHours int `json:"exp_hours"`
}
Expand Down
31 changes: 2 additions & 29 deletions cmd/admin/handlers/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,6 @@ func adminOKResponse(w http.ResponseWriter, msg string) {
utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, AdminResponse{Message: msg})
}

// Helper to verify if a platform is valid
func checkValidPlatform(platforms []string, platform string) bool {
for _, p := range platforms {
if p == platform {
return true
}
}
return false
}

// Helper to check if the CSRF token is valid
func checkCSRFToken(ctxToken, receivedToken string) bool {
return (strings.TrimSpace(ctxToken) == strings.TrimSpace(receivedToken))
Expand Down Expand Up @@ -80,17 +70,12 @@ func generateCarveQuery(file string, glob bool) string {
}

// Helper to generate the file carve query
func newCarveReady(user, path string, exp time.Time, envid uint, req DistributedCarveRequest) queries.DistributedQuery {
func newCarveReady(user, path string, exp time.Time, envid uint, req DistributedQueryRequest) queries.DistributedQuery {
return queries.DistributedQuery{
Query: generateCarveQuery(path, false),
Name: generateCarveName(),
Creator: user,
Expected: 0,
Executions: 0,
Active: true,
Completed: false,
Deleted: false,
Expired: false,
Expiration: exp,
Type: queries.CarveQueryType,
Path: path,
Expand All @@ -102,25 +87,13 @@ func newCarveReady(user, path string, exp time.Time, envid uint, req Distributed
// Helper to determine if a query may be a carve
func newQueryReady(user, query string, exp time.Time, envid uint, req DistributedQueryRequest) queries.DistributedQuery {
if strings.Contains(query, "carve") {
cReq := DistributedCarveRequest{
Environments: req.Environments,
Platforms: req.Platforms,
UUIDs: req.UUIDs,
Hosts: req.Hosts,
Tags: req.Tags,
}
return newCarveReady(user, query, exp, envid, cReq)
return newCarveReady(user, query, exp, envid, req)
}
return queries.DistributedQuery{
Query: query,
Name: generateQueryName(),
Creator: user,
Expected: 0,
Executions: 0,
Active: true,
Completed: false,
Deleted: false,
Expired: false,
Expiration: exp,
Type: queries.StandardQueryType,
EnvironmentID: envid,
Expand Down
Loading
Loading