diff --git a/lxd/db/operations.go b/lxd/db/operations.go index 0597d5cd2c..6e7ccbe6ec 100644 --- a/lxd/db/operations.go +++ b/lxd/db/operations.go @@ -22,6 +22,12 @@ func (c *ClusterTx) OperationsUUIDs() ([]string, error) { return query.SelectStrings(c.tx, stmt, c.nodeID) } +// OperationNodes returns a list of nodes that have running operations +func (c *ClusterTx) OperationNodes() ([]string, error) { + stmt := "SELECT DISTINCT nodes.address FROM operations JOIN nodes ON nodes.id = node_id" + return query.SelectStrings(c.tx, stmt) +} + // OperationByUUID returns the operation with the given UUID. func (c *ClusterTx) OperationByUUID(uuid string) (Operation, error) { null := Operation{} diff --git a/lxd/operations.go b/lxd/operations.go index 299ab81072..a36cec4ec6 100644 --- a/lxd/operations.go +++ b/lxd/operations.go @@ -14,6 +14,7 @@ import ( "github.com/lxc/lxd/lxd/cluster" "github.com/lxc/lxd/lxd/db" + "github.com/lxc/lxd/lxd/node" "github.com/lxc/lxd/lxd/util" "github.com/lxc/lxd/shared" "github.com/lxc/lxd/shared/api" @@ -453,35 +454,41 @@ func operationAPIGet(d *Daemon, r *http.Request) Response { var body *api.Operation - // First check the local cache, then the cluster database table. + // First check if the query is for a local operation from this node op, err := operationGet(id) if err == nil { _, body, err = op.Render() if err != nil { return SmartError(err) } - } else { - var address string - err = d.cluster.Transaction(func(tx *db.ClusterTx) error { - operation, err := tx.OperationByUUID(id) - if err != nil { - return err - } - address = operation.NodeAddress - return nil - }) - if err != nil { - return SmartError(err) - } - cert := d.endpoints.NetworkCert() - client, err := cluster.Connect(address, cert, false) - if err != nil { - return SmartError(err) - } - body, _, err = client.GetOperation(id) + + return SyncResponse(true, body) + } + + // Then check if the query is from an operation on another node, and, if so, forward it + var address string + err = d.cluster.Transaction(func(tx *db.ClusterTx) error { + operation, err := tx.OperationByUUID(id) if err != nil { - return SmartError(err) + return err } + + address = operation.NodeAddress + return nil + }) + if err != nil { + return SmartError(err) + } + + cert := d.endpoints.NetworkCert() + client, err := cluster.Connect(address, cert, false) + if err != nil { + return SmartError(err) + } + + body, _, err = client.GetOperation(id) + if err != nil { + return SmartError(err) } return SyncResponse(true, body) @@ -490,14 +497,41 @@ func operationAPIGet(d *Daemon, r *http.Request) Response { func operationAPIDelete(d *Daemon, r *http.Request) Response { id := mux.Vars(r)["id"] + // First check if the query is for a local operation from this node op, err := operationGet(id) + if err == nil { + _, err = op.Cancel() + if err != nil { + return BadRequest(err) + } + + return EmptySyncResponse + } + + // Then check if the query is from an operation on another node, and, if so, forward it + var address string + err = d.cluster.Transaction(func(tx *db.ClusterTx) error { + operation, err := tx.OperationByUUID(id) + if err != nil { + return err + } + + address = operation.NodeAddress + return nil + }) if err != nil { - return NotFound(err) + return SmartError(err) } - _, err = op.Cancel() + cert := d.endpoints.NetworkCert() + client, err := cluster.Connect(address, cert, false) if err != nil { - return BadRequest(err) + return SmartError(err) + } + + err = client.DeleteOperation(id) + if err != nil { + return SmartError(err) } return EmptySyncResponse @@ -506,38 +540,165 @@ func operationAPIDelete(d *Daemon, r *http.Request) Response { var operationCmd = Command{name: "operations/{id}", get: operationAPIGet, delete: operationAPIDelete} func operationsAPIGet(d *Daemon, r *http.Request) Response { - var md shared.Jmap - recursion := util.IsRecursionRequest(r) - md = shared.Jmap{} + localOperationURLs := func() (shared.Jmap, error) { + // Get all the operations + operationsLock.Lock() + ops := operations + operationsLock.Unlock() - operationsLock.Lock() - ops := operations - operationsLock.Unlock() + // Build a list of URLs + body := shared.Jmap{} - for _, v := range ops { - status := strings.ToLower(v.status.String()) - _, ok := md[status] - if !ok { - if recursion { - md[status] = make([]*api.Operation, 0) - } else { - md[status] = make([]string, 0) + for _, v := range ops { + status := strings.ToLower(v.status.String()) + _, ok := body[status] + if !ok { + body[status] = make([]string, 0) } + + body[status] = append(body[status].([]string), v.url) } - if !recursion { - md[status] = append(md[status].([]string), v.url) - continue + return body, nil + } + + localOperations := func() (shared.Jmap, error) { + // Get all the operations + operationsLock.Lock() + ops := operations + operationsLock.Unlock() + + // Build a list of operations + body := shared.Jmap{} + + for _, v := range ops { + status := strings.ToLower(v.status.String()) + _, ok := body[status] + if !ok { + body[status] = make([]*api.Operation, 0) + } + + _, op, err := v.Render() + if err != nil { + return nil, err + } + + body[status] = append(body[status].([]*api.Operation), op) } - _, body, err := v.Render() + return body, nil + } + + // Check if called from a cluster node + if isClusterNotification(r) { + // Only return the local data + if recursion { + // Recursive queries + body, err := localOperations() + if err != nil { + return InternalError(err) + } + + return SyncResponse(true, body) + } + + // Normal queries + body, err := localOperationURLs() + if err != nil { + return InternalError(err) + } + + return SyncResponse(true, body) + } + + // Start with local operations + var md shared.Jmap + var err error + + if recursion { + md, err = localOperations() + if err != nil { + return InternalError(err) + } + } else { + md, err = localOperationURLs() + if err != nil { + return InternalError(err) + } + } + + // Check if clustered + clustered, err := cluster.Enabled(d.db) + if err != nil { + return InternalError(err) + } + + // Return now if not clustered + if !clustered { + return SyncResponse(true, md) + } + + // Get all nodes with running operations + var nodes []string + err = d.cluster.Transaction(func(tx *db.ClusterTx) error { + var err error + + nodes, err = tx.OperationNodes() if err != nil { + return err + } + + return nil + }) + if err != nil { + return SmartError(err) + } + + // Get local address + localAddress, err := node.HTTPSAddress(d.db) + if err != nil { + return InternalError(err) + } + + cert := d.endpoints.NetworkCert() + for _, node := range nodes { + if node == localAddress { continue } - md[status] = append(md[status].([]*api.Operation), body) + // Connect to the remote server + client, err := cluster.Connect(node, cert, true) + if err != nil { + return SmartError(err) + } + + // Get operation data + ops, err := client.GetOperations() + if err != nil { + return SmartError(err) + } + + // Merge with existing data + for _, op := range ops { + status := strings.ToLower(op.Status) + + _, ok := md[status] + if !ok { + if recursion { + md[status] = make([]*api.Operation, 0) + } else { + md[status] = make([]string, 0) + } + } + + if recursion { + md[status] = append(md[status].([]*api.Operation), &op) + } else { + md[status] = append(md[status].([]string), fmt.Sprintf("/1.0/operations/%s", op.ID)) + } + } } return SyncResponse(true, md) @@ -546,23 +707,51 @@ func operationsAPIGet(d *Daemon, r *http.Request) Response { var operationsCmd = Command{name: "operations", get: operationsAPIGet} func operationAPIWaitGet(d *Daemon, r *http.Request) Response { + id := mux.Vars(r)["id"] + timeout, err := shared.AtoiEmptyDefault(r.FormValue("timeout"), -1) if err != nil { return InternalError(err) } - id := mux.Vars(r)["id"] + // First check if the query is for a local operation from this node op, err := operationGet(id) + if err == nil { + _, err = op.WaitFinal(timeout) + if err != nil { + return InternalError(err) + } + + _, body, err := op.Render() + if err != nil { + return SmartError(err) + } + + return SyncResponse(true, body) + } + + // Then check if the query is from an operation on another node, and, if so, forward it + var address string + err = d.cluster.Transaction(func(tx *db.ClusterTx) error { + operation, err := tx.OperationByUUID(id) + if err != nil { + return err + } + + address = operation.NodeAddress + return nil + }) if err != nil { - return NotFound(err) + return SmartError(err) } - _, err = op.WaitFinal(timeout) + cert := d.endpoints.NetworkCert() + client, err := cluster.Connect(address, cert, false) if err != nil { - return InternalError(err) + return SmartError(err) } - _, body, err := op.Render() + _, body, err := client.GetOperationWait(id, timeout) if err != nil { return SmartError(err) } @@ -618,15 +807,13 @@ func (r *forwardedOperationWebSocket) String() string { func operationAPIWebsocketGet(d *Daemon, r *http.Request) Response { id := mux.Vars(r)["id"] - // First check if the websocket is for a local operation from this - // node. + // First check if the query is for a local operation from this node op, err := operationGet(id) if err == nil { return &operationWebSocket{r, op} } - // Secondly check if the websocket is from an operation on another - // node, and, if so, proxy it. + // Then check if the query is from an operation on another node, and, if so, forward it secret := r.FormValue("secret") if secret == "" { return BadRequest(fmt.Errorf("missing secret")) @@ -638,6 +825,7 @@ func operationAPIWebsocketGet(d *Daemon, r *http.Request) Response { if err != nil { return err } + address = operation.NodeAddress return nil }) @@ -656,6 +844,7 @@ func operationAPIWebsocketGet(d *Daemon, r *http.Request) Response { if err != nil { return SmartError(err) } + return &forwardedOperationWebSocket{req: r, id: id, source: source} }