diff --git a/commands/root.go b/commands/root.go index 9342d3a..1db52cd 100644 --- a/commands/root.go +++ b/commands/root.go @@ -4,12 +4,14 @@ import ( "fmt" "net/http" "os" + "strconv" + "os/user" "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/compute/v1" - "google.golang.org/cloud/compute/metadata" + "cloud.google.com/go/compute/metadata" "gopkg.in/inconshreveable/log15.v2" @@ -169,7 +171,9 @@ func (c *RootCommand) runVolumePlugin() error { } h := volume.NewHandler(d) - if err := h.ServeUnix("docker", "gce"); err != nil { + u, _ := user.Lookup("root") + gid, _ := strconv.Atoi(u.Gid) + if err := h.ServeUnix("gce", gid); err != nil { return fmt.Errorf("error starting volume driver server: %s", err) } diff --git a/plugin/volume.go b/plugin/volume.go index 31f2688..8b4606c 100644 --- a/plugin/volume.go +++ b/plugin/volume.go @@ -16,9 +16,10 @@ import ( var WaitStatusTimeout = 100 * time.Second type Volume struct { - Root string - p providers.DiskProvider - fs Filesystem + Root string + p providers.DiskProvider + fs Filesystem + mounts map[string][]string } func NewVolume(c *http.Client, project, zone, instance string) (*Volume, error) { @@ -31,33 +32,34 @@ func NewVolume(c *http.Client, project, zone, instance string) (*Volume, error) Root: "/mnt/", p: p, fs: NewFilesystem(), + mounts: map[string][]string{}, }, nil } -func (v *Volume) Create(r volume.Request) volume.Response { +func (v *Volume) Create(r *volume.CreateRequest) error { log15.Debug("create request received", "name", r.Name) start := time.Now() - config, err := v.createDiskConfig(r) + config, err := v.createDiskConfig(r.Name, r) if err != nil { - return buildReponseError(err) + return err } if err := v.p.Create(config); err != nil { - return buildReponseError(err) + return err } log15.Info("disk created", "disk", r.Name, "elapsed", time.Since(start)) - return volume.Response{} + return nil } -func (v *Volume) List(volume.Request) volume.Response { +func (v *Volume) List() (*volume.ListResponse, error) { log15.Debug("list request received") disks, err := v.p.List() if err != nil { - return buildReponseError(err) + return nil, err } - r := volume.Response{} + r := volume.ListResponse{} for _, d := range disks { if d.Status != "READY" { continue @@ -68,32 +70,32 @@ func (v *Volume) List(volume.Request) volume.Response { }) } - return r + return &r, nil } -func (v *Volume) Capabilities(volume.Request) volume.Response { +func (v *Volume) Capabilities() *volume.CapabilitiesResponse { log15.Debug("capabilities request received") - return volume.Response{ + return &volume.CapabilitiesResponse{ Capabilities: volume.Capability{Scope: "local"}, } } -func (v *Volume) Get(r volume.Request) volume.Response { +func (v *Volume) Get(r *volume.GetRequest) (*volume.GetResponse, error) { log15.Debug("get request received") disks, err := v.p.List() if err != nil { - return buildReponseError(err) + return nil, err } - resp := volume.Response{} + resp := volume.GetResponse{} for _, d := range disks { if d.Name != r.Name { continue } - config, err := v.createDiskConfig(r) + config, err := v.createDiskConfig(r.Name, nil) if err != nil { - return buildReponseError(err) + return nil, err } resp.Volume = &volume.Volume{ @@ -102,71 +104,79 @@ func (v *Volume) Get(r volume.Request) volume.Response { } } - return resp + return &resp, nil } -func (v *Volume) Remove(r volume.Request) volume.Response { +func (v *Volume) Remove(r *volume.RemoveRequest) error { log15.Debug("remove request received", "name", r.Name) start := time.Now() - config, err := v.createDiskConfig(r) + config, err := v.createDiskConfig(r.Name, nil) if err != nil { - return buildReponseError(err) + return err } if err := v.p.Delete(config); err != nil { - return buildReponseError(err) + return err } log15.Info("disk removed", "disk", r.Name, "elapsed", time.Since(start)) - return volume.Response{} + return nil } -func (v *Volume) Path(r volume.Request) volume.Response { - config, err := v.createDiskConfig(r) +func (v *Volume) Path(r *volume.PathRequest) (*volume.PathResponse, error) { + config, err := v.createDiskConfig(r.Name, nil) if err != nil { - return buildReponseError(err) + return nil, err } mnt := config.MountPoint(v.Root) log15.Debug("path request received", "name", r.Name, "mnt", mnt) if err := v.createMountPoint(config); err != nil { - return buildReponseError(err) + return nil, err } - return volume.Response{Mountpoint: mnt} + return &volume.PathResponse{Mountpoint: mnt}, nil } -func (v *Volume) Mount(r volume.Request) volume.Response { +func (v *Volume) Mount(r *volume.MountRequest) (*volume.MountResponse, error) { log15.Debug("mount request received", "name", r.Name) start := time.Now() - - config, err := v.createDiskConfig(r) + config, err := v.createDiskConfig(r.Name, nil) if err != nil { - return buildReponseError(err) + return nil, err } - if err := v.createMountPoint(config); err != nil { - return buildReponseError(err) - } + if v.trackMount(r.ID, r.Name) { + if err := v.createMountPoint(config); err != nil { + v.resetMount(r.Name) + return nil, err + } - if err := v.p.Attach(config); err != nil { - return buildReponseError(err) - } + if err := v.p.Attach(config); err != nil { + v.resetMount(r.Name) + return nil, err + } - if err := v.fs.Format(config.Dev()); err != nil { - return buildReponseError(err) - } + if err := v.fs.Format(config.Dev()); err != nil { + v.resetMount(r.Name) + return nil, err + } - if err := v.fs.Mount(config.Dev(), config.MountPoint(v.Root)); err != nil { - return buildReponseError(err) + if err := v.fs.Mount(config.Dev(), config.MountPoint(v.Root)); err != nil { + v.resetMount(r.Name) + return nil, err + } + + log15.Info("disk mounted", "disk", r.Name, "elapsed", time.Since(start)) + } else { + log15.Info("disk already mounted", "disk", r.Name, "elapsed", time.Since(start)) } - log15.Info("disk mounted", "disk", r.Name, "elapsed", time.Since(start)) - return volume.Response{ + return &volume.MountResponse{ Mountpoint: config.MountPoint(v.Root), - } + }, nil } func (v *Volume) createMountPoint(c *providers.DiskConfig) error { @@ -187,54 +197,91 @@ func (v *Volume) createMountPoint(c *providers.DiskConfig) error { return nil } -func (v *Volume) Unmount(r volume.Request) volume.Response { +func (v *Volume) Unmount(r *volume.UnmountRequest) error { log15.Debug("unmount request received", "name", r.Name) start := time.Now() - config, err := v.createDiskConfig(r) + config, err := v.createDiskConfig(r.Name, nil) if err != nil { - return buildReponseError(err) + return err } - if err := v.fs.Unmount(config.MountPoint(v.Root)); err != nil { - return buildReponseError(err) - } + if v.trackUnmount(r.ID, r.Name) { + if err := v.fs.Unmount(config.MountPoint(v.Root)); err != nil { + return err + } + + if err := v.p.Detach(config); err != nil { + return err + } - if err := v.p.Detach(config); err != nil { - return buildReponseError(err) + log15.Info("disk unmounted", "disk", r.Name, "elapsed", time.Since(start)) + } else { + log15.Info("other mounts active, not unmounting", "disk", r.Name, "elapsed", time.Since(start)) } - log15.Info("disk unmounted", "disk", r.Name, "elapsed", time.Since(start)) - return volume.Response{} + return nil } -func (v *Volume) createDiskConfig(r volume.Request) (*providers.DiskConfig, error) { - config := &providers.DiskConfig{Name: r.Name} - - for key, value := range r.Options { - switch key { - case "Name": - config.Name = value - case "Type": - config.Type = value - case "SizeGb": - var err error - config.SizeGb, err = strconv.ParseInt(value, 10, 64) - if err != nil { - return nil, err +func (v *Volume) createDiskConfig(name string, r *volume.CreateRequest) (*providers.DiskConfig, error) { + config := &providers.DiskConfig{Name: name} + + if r != nil { + for key, value := range r.Options { + switch key { + case "Name": + config.Name = value + case "Type": + config.Type = value + case "SizeGb": + var err error + config.SizeGb, err = strconv.ParseInt(value, 10, 64) + if err != nil { + return nil, err + } + case "SourceSnapshot": + config.SourceSnapshot = value + case "SourceImage": + config.SourceImage = value + default: + return nil, fmt.Errorf("unknown option %q", key) } - case "SourceSnapshot": - config.SourceSnapshot = value - case "SourceImage": - config.SourceImage = value - default: - return nil, fmt.Errorf("unknown option %q", key) } } return config, config.Validate() } -func buildReponseError(err error) volume.Response { - log15.Error("request failed", "error", err.Error()) - return volume.Response{Err: err.Error()} +// Return if mount action needed +func (v *Volume) trackMount(id string, name string) bool { + existing, ok := v.mounts[name] + if !ok || len(existing) == 0 { + v.mounts[name] = []string{id} + return true + } else { + v.mounts[name] = append(existing, id) + return false + } +} + +// Return if unmount action needed +func (v *Volume) trackUnmount(id string, name string) bool { + existing, ok := v.mounts[name] + if !ok { + // In case service restarted and state has been lost, most likely unmount means unmount + return true + } + + for i, oldId := range existing { + if oldId == id { + v.mounts[name] = append(v.mounts[name][:i], v.mounts[name][i+1:]...) + return len(v.mounts[name]) == 0 + } + } + + return false +} + +// Reset tracked mount in case of errors such as disks already being attached +func (v *Volume) resetMount(name string) { + delete(v.mounts, name) }