Skip to content

Commit

Permalink
chore: refactor Mount() for clarity (#96)
Browse files Browse the repository at this point in the history
It looks like the Mount() method once used goroutines, and while this
was removed, the complexity was not. This change does not change any
behavior, but removes overly-complex structure.

At the same time, VolumeDriver.volumes was changes to a map of value
types (rather than pointers) to make it explict when updates were
happening, as anyone with access to the pointer could make changes
without the protection of a lock, but when a value, it's necessary to do
an explicit Put() to update the map, so access to data is more
controlled.

There are still some potential data races in this code, for instance if
two Mount() calls happen at the same time, then the mount counter might
be incremented only once. The risk from this is low, but they should be
fixed in the future.
  • Loading branch information
blgm committed Jul 6, 2023
1 parent 39b2964 commit fcc5ef9
Showing 1 changed file with 56 additions and 89 deletions.
145 changes: 56 additions & 89 deletions volume_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"os"
"path/filepath"
"sync"
"time"

"code.cloudfoundry.org/dockerdriver"
Expand All @@ -23,17 +22,15 @@ import (

type NfsVolumeInfo struct {
Opts map[string]interface{} `json:"-"` // don't store opts
wg sync.WaitGroup
mountError string
dockerdriver.VolumeInfo // see dockerdriver.resources.go
dockerdriver.VolumeInfo // see dockerdriver.resources.go
}

type OsHelper interface {
Umask(mask int) (oldmask int)
}

type VolumeDriver struct {
volumes *syncmap.SyncMap[*NfsVolumeInfo]
volumes *syncmap.SyncMap[NfsVolumeInfo]
os osshim.Os
filepath filepathshim.Filepath
ioutil ioutilshim.Ioutil
Expand All @@ -46,7 +43,7 @@ type VolumeDriver struct {

func NewVolumeDriver(logger lager.Logger, os osshim.Os, filepath filepathshim.Filepath, ioutil ioutilshim.Ioutil, time timeshim.Time, mountChecker mountchecker.MountChecker, mountPathRoot string, mounter Mounter, oshelper OsHelper) *VolumeDriver {
d := &VolumeDriver{
volumes: syncmap.New[*NfsVolumeInfo](),
volumes: syncmap.New[NfsVolumeInfo](),
os: os,
filepath: filepath,
ioutil: ioutil,
Expand Down Expand Up @@ -97,7 +94,7 @@ func (d *VolumeDriver) Create(env dockerdriver.Env, createRequest dockerdriver.C
Opts: createRequest.Opts,
}

d.volumes.Put(createRequest.Name, &volInfo)
d.volumes.Put(createRequest.Name, volInfo)
} else {
existing.Opts = createRequest.Opts

Expand Down Expand Up @@ -134,103 +131,60 @@ func (d *VolumeDriver) Mount(env dockerdriver.Env, mountRequest dockerdriver.Mou
return dockerdriver.MountResponse{Err: "Missing mandatory 'volume_name'"}
}

var doMount bool
var opts map[string]interface{}
var mountPath string
var wg *sync.WaitGroup

ret := func() dockerdriver.MountResponse {

volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
return dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' must be created before being mounted", mountRequest.Name)}
}

mountPath = d.mountPath(driverhttp.EnvWithLogger(logger, env), volume.Name)

logger.Info("mounting-volume", lager.Data{"id": volume.Name, "mountpoint": mountPath})
logger.Info("mount-source", lager.Data{"source": volume.Opts["source"].(string)})

if volume.MountCount < 1 {
doMount = true
volume.wg.Add(1)
opts = map[string]interface{}{}
for k, v := range volume.Opts {
opts[k] = v
}
}

volume.Mountpoint = mountPath
volume.MountCount++
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
return dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' must be created before being mounted", mountRequest.Name)}
}

logger.Info("volume-ref-count-incremented", lager.Data{"name": volume.Name, "count": volume.MountCount})
mountPath := d.mountPath(driverhttp.EnvWithLogger(logger, env), volume.Name)
volume.Mountpoint = mountPath
logger.Info("mounting-volume", lager.Data{"id": volume.Name, "mountpoint": mountPath})
logger.Info("mount-source", lager.Data{"source": volume.Opts["source"].(string)})

if err := d.persistState(driverhttp.EnvWithLogger(logger, env)); err != nil {
logger.Error("persist-state-failed", err)
return dockerdriver.MountResponse{Err: fmt.Sprintf("persist state failed when mounting: %s", err.Error())}
}
doMount := volume.MountCount < 1
volume.MountCount++
logger.Info("volume-ref-count-incremented", lager.Data{"name": volume.Name, "count": volume.MountCount})

wg = &volume.wg
return dockerdriver.MountResponse{Mountpoint: volume.Mountpoint}
}()

if ret.Err != "" {
return ret
d.volumes.Put(mountRequest.Name, volume)
if err := d.persistState(driverhttp.EnvWithLogger(logger, env)); err != nil {
logger.Error("persist-state-failed", err)
return dockerdriver.MountResponse{Err: fmt.Sprintf("persist state failed when mounting: %s", err.Error())}
}

if doMount {
mountStartTime := d.time.Now()

err := d.mount(driverhttp.EnvWithLogger(logger, env), opts, mountPath)
err := d.mount(driverhttp.EnvWithLogger(logger, env), copyOpts(volume.Opts), mountPath)

mountEndTime := d.time.Now()
mountDuration := mountEndTime.Sub(mountStartTime)
if mountDuration > 8*time.Second {
logger.Error("mount-duration-too-high", nil, lager.Data{"mount-duration-in-second": mountDuration / time.Second, "warning": "This may result in container creation failure!"})
}

func() {
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
ret = dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' not found", mountRequest.Name)}
} else if err != nil {
if _, ok := err.(dockerdriver.SafeError); ok {
errBytes, m_err := json.Marshal(err)
if m_err != nil {
logger.Error("failed-to-marshal-safeerror", m_err)
volume.mountError = err.Error()
}
volume.mountError = string(errBytes)
} else {
volume.mountError = err.Error()
}
switch err.(type) {
case nil:
return dockerdriver.MountResponse{Mountpoint: volume.Mountpoint}
case dockerdriver.SafeError:
errBytes, mErr := json.Marshal(err)
if mErr != nil {
logger.Error("failed-to-marshal-safeerror", mErr)
return dockerdriver.MountResponse{Err: err.Error()}
}
}()

wg.Done()
}

wg.Wait()

return func() dockerdriver.MountResponse {
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
return dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' not found", mountRequest.Name)}
} else if volume.mountError != "" {
return dockerdriver.MountResponse{Err: volume.mountError}
} else {
// Check the volume to make sure it's still mounted before handing it out again.
if !doMount && !d.mounter.Check(driverhttp.EnvWithLogger(logger, env), volume.Name, volume.Mountpoint) {
wg.Add(1)
defer wg.Done()
if err := d.mount(driverhttp.EnvWithLogger(logger, env), volume.Opts, mountPath); err != nil {
logger.Error("remount-volume-failed", err)
return dockerdriver.MountResponse{Err: fmt.Sprintf("Error remounting volume: %s", err.Error())}
}
return dockerdriver.MountResponse{Err: string(errBytes)}
default:
return dockerdriver.MountResponse{Err: err.Error()}
}
} else {
// Check the volume to make sure it's still mounted before handing it out again.
if !d.mounter.Check(driverhttp.EnvWithLogger(logger, env), volume.Name, volume.Mountpoint) {
if err := d.mount(driverhttp.EnvWithLogger(logger, env), volume.Opts, mountPath); err != nil {
logger.Error("remount-volume-failed", err)
return dockerdriver.MountResponse{Err: fmt.Sprintf("Error remounting volume: %s", err.Error())}
}
return dockerdriver.MountResponse{Mountpoint: volume.Mountpoint}
}
}()
return dockerdriver.MountResponse{Mountpoint: volume.Mountpoint}
}
}

func (d *VolumeDriver) Path(env dockerdriver.Env, pathRequest dockerdriver.PathRequest) dockerdriver.PathResponse {
Expand Down Expand Up @@ -258,6 +212,8 @@ func (d *VolumeDriver) Path(env dockerdriver.Env, pathRequest dockerdriver.PathR

func (d *VolumeDriver) Unmount(env dockerdriver.Env, unmountRequest dockerdriver.UnmountRequest) dockerdriver.ErrorResponse {
logger := env.Logger().Session("unmount", lager.Data{"volume": unmountRequest.Name})
logger.Info("start")
defer logger.Info("end")

if unmountRequest.Name == "" {
return dockerdriver.ErrorResponse{Err: "Missing mandatory 'volume_name'"}
Expand Down Expand Up @@ -285,8 +241,11 @@ func (d *VolumeDriver) Unmount(env dockerdriver.Env, unmountRequest dockerdriver
volume.MountCount--
logger.Info("volume-ref-count-decremented", lager.Data{"name": volume.Name, "count": volume.MountCount})

if volume.MountCount < 1 {
switch volume.MountCount {
case 0:
d.volumes.Delete(unmountRequest.Name)
default:
d.volumes.Put(unmountRequest.Name, volume)
}

if err := d.persistState(driverhttp.EnvWithLogger(logger, env)); err != nil {
Expand Down Expand Up @@ -343,15 +302,15 @@ func (d *VolumeDriver) Get(env dockerdriver.Env, getRequest dockerdriver.GetRequ
}
}

func (d *VolumeDriver) getVolume(env dockerdriver.Env, volumeName string) (*NfsVolumeInfo, error) {
func (d *VolumeDriver) getVolume(env dockerdriver.Env, volumeName string) (NfsVolumeInfo, error) {
logger := env.Logger().Session("get-volume")

if vol, ok := d.volumes.Get(volumeName); ok {
logger.Info("getting-volume", lager.Data{"name": volumeName})
return vol, nil
}

return &NfsVolumeInfo{}, errors.New("Volume not found")
return NfsVolumeInfo{}, errors.New("Volume not found")
}

func (d *VolumeDriver) Capabilities(env dockerdriver.Env) dockerdriver.CapabilitiesResponse {
Expand Down Expand Up @@ -544,3 +503,11 @@ func (d *VolumeDriver) Drain(env dockerdriver.Env) error {

return nil
}

func copyOpts(input map[string]any) map[string]any {
output := make(map[string]any)
for k, v := range input {
output[k] = v
}
return output
}

0 comments on commit fcc5ef9

Please sign in to comment.